In [None]:
# Install the notebook's dependencies into the kernel's environment (versions are not pinned).
%pip install torch torchvision transformers datasets bitsandbytes accelerate peft tensorboard

In [None]:
import os
from datetime import datetime

from datasets import load_dataset
from torch.utils.data import Dataset

In [None]:
class DocciDataset(Dataset):
    """Map-style dataset exposing one split of ``google/docci`` as captioning QA samples.

    Every item pairs the record's image with a single fixed question
    ("Describe this image.") whose answer is the record's ``description`` field.
    """

    def __init__(self, split="train"):
        """Load the requested split of ``google/docci``.

        Args:
            split: Split name forwarded to ``load_dataset``.
        """
        self.data = load_dataset("google/docci", split=split, trust_remote_code=True)

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as ``{"image": ..., "qa": [{"question", "answer"}]}``."""
        record = self.data[idx]
        image = record["image"]
        caption_turn = {
            "question": "Describe this image.",
            "answer": record["description"],
        }
        return {"image": image, "qa": [caption_turn]}


# Wrap the "train" and "test" splits; later cells look them up by split name.
datasets = {split_name: DocciDataset(split_name) for split_name in ("train", "test")}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer and the pretrained checkpoint. The weights are requested
# in float16 and the model is then moved to the GPU.
tokenizer = AutoTokenizer.from_pretrained("qresearch/doubutsu-2b-pt-378")

model = AutoModelForCausalLM.from_pretrained(
    "qresearch/doubutsu-2b-pt-378",
    torch_dtype=torch.float16,
    trust_remote_code=True,  # the checkpoint is loaded with remote code enabled
)
model = model.to("cuda")

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapter settings: rank 128 with alpha 256, attached to the projection
# modules listed by name below.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=256,
    target_modules=[
        "q_proj", "o_proj", "k_proj", "v_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Enable input gradients on the base model, wrap it with the adapter, and
# report how many parameters are trainable.
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Training hyperparameters.
EPOCHS = 3  # full passes over the training dataloader
BATCH_SIZE = 2  # samples per dataloader batch
GRAD_ACCUM_STEPS = 16  # batches accumulated between optimizer steps
LR = 2e-4  # peak learning rate (also the optimizer's initial lr)
WEIGHT_DECAY = 0.0  # weight-decay coefficient

# Extra sequence positions reserved for the image in collate_fn's attention
# mask and labels; presumably the vision encoder's token count — TODO confirm.
IMG_TOKENS = 729

In [None]:
def collate_fn(batch):
    """Collate dataset samples into padded tensors for training.

    The first QA pair of every sample is rendered with the tokenizer's chat
    template, right-padded to the longest sequence in the batch, and paired
    with an attention mask and labels that both reserve ``IMG_TOKENS`` extra
    positions for the image embeddings that ``compute_loss`` splices in.

    Label masking (value ``-100``, the usual ignore index):
      * padding positions;
      * everything before the second EOS token, which is treated as the
        boundary before the assistant's reply (this relies on the chat
        template's layout);
      * the whole sequence when fewer than two EOS tokens are found;
      * the ``IMG_TOKENS`` image positions.

    Args:
        batch: List of samples shaped like
            ``{"image": ..., "qa": [{"question": ..., "answer": ...}]}``.

    Returns:
        Tuple ``(images, input_ids, attention_masks, labels)``. ``images`` is
        the list of raw images, ``input_ids`` has shape ``(B, T)`` and the mask
        and labels have shape ``(B, T + IMG_TOKENS)``, where ``T`` is the
        longest tokenized sequence in the batch.
    """
    images = [sample["image"] for sample in batch]

    # Tokenize every conversation first so the batch-wide max length is known.
    all_input_ids = []
    for sample in batch:
        qa = sample["qa"][0]
        messages = [
            {"role": "human", "content": qa["question"]},
            {"role": "assistant", "content": qa["answer"]},
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=False
        )
        all_input_ids.append(input_ids.squeeze(0))  # drop the batch dimension

    max_length = max((ids.size(0) for ids in all_input_ids), default=0)

    has_bos = tokenizer.bos_token_id is not None
    image_ignore = torch.full((IMG_TOKENS,), -100, dtype=torch.long)

    padded_input_ids = []
    attention_masks = []
    labels = []

    for input_ids in all_input_ids:
        seq_len = input_ids.size(0)
        padding_length = max_length - seq_len

        padded = torch.cat(
            [
                input_ids,
                torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long),
            ]
        )
        padded_input_ids.append(padded)

        # Real tokens plus image positions are attended to; padding is not.
        attention_masks.append(
            torch.cat(
                [
                    torch.ones(seq_len + IMG_TOKENS, dtype=torch.long),
                    torch.zeros(padding_length, dtype=torch.long),
                ]
            )
        )

        label = padded.clone()
        label[label == tokenizer.pad_token_id] = -100

        # Supervise only what follows the second EOS occurrence.
        eos_positions = (padded == tokenizer.eos_token_id).nonzero()
        if len(eos_positions) > 1:
            label[: eos_positions[1].item()] = -100
        else:
            # No boundary found: mask the whole sequence.
            label[:] = -100

        # Reserve the image positions. The BOS test looks at the actual input
        # token: the previous check read label[0], which the masking above has
        # already set to -100, so the BOS branch could never be taken. Because
        # label[0] is always masked here, both layouts give the same tensor.
        if has_bos and padded[0] == tokenizer.bos_token_id:
            label = torch.cat([label[:1], image_ignore, label[1:]])
        else:
            label = torch.cat([image_ignore, label])

        labels.append(label)

    return (
        images,
        torch.stack(padded_input_ids),
        torch.stack(attention_masks),
        torch.stack(labels),
    )

In [None]:
def compute_loss(batch):
    """Run one forward pass over a collated batch and return the text model's loss.

    Args:
        batch: Tuple ``(images, input_ids, attention_masks, labels)`` as
            produced by ``collate_fn``.

    Returns:
        The ``loss`` attribute of the text model's output.
    """
    images, input_ids, attention_masks, labels = batch

    # Move the tensors to the GPU; the images are handed to the encoder as-is.
    input_ids, attention_masks, labels = (
        tensor.to("cuda") for tensor in (input_ids, attention_masks, labels)
    )

    # Image encoding runs without gradient tracking; the projector call below
    # sits outside the no_grad block.
    with torch.no_grad():
        encoded = [model.encode_image(image) for image in images]
        image_features = torch.stack(encoded)

    img_embs = model.mm_projector(image_features).squeeze(1)

    embed_tokens = model.text_model.get_input_embeddings()
    tok_embs = embed_tokens(input_ids)

    # Only the first sequence's leading token is inspected; the resulting
    # layout is applied to the whole batch.
    starts_with_bos = (
        tokenizer.bos_token_id is not None
        and input_ids[0, 0] == model.text_model.config.bos_token_id
    )
    if starts_with_bos:
        # Keep BOS first and splice the image embeddings right after it.
        pieces = [tok_embs[:, :1, :], img_embs, tok_embs[:, 1:, :]]
    else:
        pieces = [img_embs, tok_embs]
    inputs_embeds = torch.cat(pieces, dim=1)

    outputs = model.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attention_masks,
    )
    return outputs.loss

In [None]:
import math

from bitsandbytes.optim import Adam8bit
from torch.utils.data import DataLoader
from tqdm import tqdm


def lr_schedule(step, max_steps):
    """Learning rate for ``step``: linear warmup, then cosine decay to zero.

    The first 10% of ``max_steps`` (rounded down) ramp linearly from 0 up to
    ``LR``; the remaining steps follow a half-cosine from ``LR`` down to 0.

    Args:
        step: Index of the optimizer update being scheduled.
        max_steps: Total number of optimizer updates planned.

    Returns:
        The learning rate to use at ``step``. For a degenerate schedule with
        no decay phase (``max_steps <= 0``) the peak ``LR`` is returned.
    """
    warmup_steps = int(0.1 * max_steps)

    if step < warmup_steps:
        return LR * step / warmup_steps

    decay_steps = max_steps - warmup_steps
    if decay_steps <= 0:
        # Nothing to decay over; avoids a division by zero below.
        return LR

    # Clamp so steps past max_steps stay at the floor instead of letting the
    # cosine swing back up.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return LR * 0.5 * (1 + math.cos(math.pi * progress))


# Training dataloader; collate_fn turns raw samples into padded tensors.
dataloaders = {
    "train": DataLoader(
        datasets["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    )
}

run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# NOTE(review): the log directory is created, but nothing in the visible code
# writes to it.
log_dir = os.path.join("./logs", run_name)
os.makedirs(log_dir, exist_ok=True)

# Only the text model is switched to train mode; the whole-model
# `model.train()` call was left disabled by the author.
# TODO: TEST THIS MORE. SIGLIP TRAINING SEEMS TO BE FINE THOUGH
model.text_model.train()
model.text_model.gradient_checkpointing_enable()
model.vision_model.gradient_checkpointing_enable()

# One optimizer update per GRAD_ACCUM_STEPS batches.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
# WEIGHT_DECAY is now passed through so the configured value takes effect.
optimizer = Adam8bit(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

global_step = 0
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        global_step += 1

        # Gradients from successive batches accumulate until the next
        # optimizer step.
        loss = compute_loss(batch)
        loss.backward()

        if global_step % GRAD_ACCUM_STEPS == 0:
            # Set the scheduled LR *before* stepping. Assigning it afterwards
            # made the first update run at the full LR (skipping warmup) and
            # every later update use the previous step's value.
            lr = lr_schedule(global_step // GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            optimizer.step()
            optimizer.zero_grad()


# Persist the trained adapter under "qwenlora".
model.save_pretrained("qwenlora")

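# Testing: the cells below reload the base model, load the saved "qwenlora" adapter, and caption one test sample.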

In [None]:
from datasets import load_dataset  # noqa: F811
from torch.utils.data import Dataset


class CaptchaDataset(Dataset):
    """Captioning-QA view over one split of ``google/docci``.

    Despite the class name, the data loaded here is the ``google/docci``
    dataset: each item pairs a record's image with the fixed question
    "Describe this image." and the record's ``description`` as the answer.
    """

    def __init__(self, split="train"):
        """Load the requested split of ``google/docci``.

        Args:
            split: Split name forwarded to ``load_dataset``.
        """
        # Ask for the one split that is needed rather than loading every
        # split and indexing into the result.
        self.data = load_dataset("google/docci", split=split, trust_remote_code=True)

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as ``{"image": ..., "qa": [{"question", "answer"}]}``."""
        sample = self.data[idx]
        return {
            "image": sample["image"],
            "qa": [
                {
                    "question": "Describe this image.",
                    "answer": sample["description"],
                }
            ],
        }


# Split-name -> dataset mapping; the evaluation cell reads the "test" entry.
datasets = {split_name: CaptchaDataset(split_name) for split_name in ("train", "test")}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the tokenizer and the base checkpoint (float16, on the GPU) for evaluation.
tokenizer = AutoTokenizer.from_pretrained(
    "qresearch/doubutsu-2b-pt-378",
)

model = AutoModelForCausalLM.from_pretrained(
    "qresearch/doubutsu-2b-pt-378",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to("cuda")

In [None]:
# Load the adapter that the training section saved under "qwenlora".
model.load_adapter("qwenlora")

In [None]:
from IPython.display import display

# Show the first test image, then print each question with its reference
# answer followed by the model's own answer.
sample = datasets["test"][0]
display(sample["image"])

for qa in sample["qa"]:
    print("Question:", qa["question"])
    print("Ground Truth:", qa["answer"])
    prediction = model.answer_question(
        sample["image"], qa["question"], tokenizer=tokenizer
    )
    print("qwenvision:", prediction)