In [None]:
%pip install torch torchvision transformers datasets bitsandbytes accelerate peft tensorboard

In [None]:
import os
from datetime import datetime

from datasets import load_dataset
from torch.utils.data import Dataset

In [None]:
class DocciDataset(Dataset):
    """Map-style dataset exposing one split of ``google/docci`` as captioning QA samples."""

    def __init__(self, split="train"):
        """Load the ``google/docci`` dataset and keep only the requested ``split``."""
        all_splits = load_dataset("google/docci", trust_remote_code=True)
        self.data = all_splits[split]

    def __len__(self):
        """Number of records in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as ``{"image": ..., "qa": [{"question", "answer"}]}``.

        The question is always the fixed prompt "Describe this image." and the
        answer is the record's ``description`` field.
        """
        record = self.data[idx]
        image = record["image"]
        qa_pair = {
            "question": "Describe this image.",
            "answer": record["description"],
        }
        return {"image": image, "qa": [qa_pair]}


# One DocciDataset per split, keyed by split name ("train" first, then "test").
datasets = {split_name: DocciDataset(split_name) for split_name in ("train", "test")}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tokenizer and model are loaded from the same checkpoint id.
MODEL_ID = "qresearch/doubutsu-2b-pt-756"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# The checkpoint ships custom modelling code (hence trust_remote_code); weights are
# loaded in float16 and then moved onto the GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to("cuda")

In [None]:
from peft import LoraConfig, get_peft_model

# Module-name patterns that receive LoRA adapters. The *_proj names are presumably
# the attention / MLP projection layers and the mm_projector entries the
# multimodal projector's layers -- TODO confirm against the model definition.
lora_target_modules = [
    "q_proj",
    "o_proj",
    "k_proj",
    "v_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "mm_projector.model.0",
    "mm_projector.model.2",
]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=256,
    target_modules=lora_target_modules,
)

# enable_input_require_grads() is called on the base model before it is wrapped;
# the training cell later turns on gradient checkpointing.
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# --- Optimisation ---
LR = 2e-4            # peak learning rate fed to both schedules and optimizers
WEIGHT_DECAY = 0.0   # NOTE(review): not referenced by any cell visible in this notebook
EPOCHS = 3           # full passes over the training split

# --- Batching ---
BATCH_SIZE = 1           # samples per DataLoader batch (micro-batch)
GRAD_ACCUM_STEPS = 32    # micro-batches accumulated per optimizer step

# --- Toggles ---
TRAIN_ENCODER = True     # also optimise vision_model / mm_projector parameters
USE_TENSORBOARD = True   # write loss and LR scalars to ./logs/<run_name>

In [None]:
def collate_fn(batch):
    """Collate dataset samples into images plus padded token tensors.

    For every sample the first QA pair is rendered through the tokenizer's chat
    template as a ``human`` turn followed by an ``assistant`` turn. Sequences are
    right-padded with ``tokenizer.pad_token_id`` to the longest one in the batch.

    Labels are a copy of the padded ids in which
      * padding positions are set to -100, and
      * everything before the second EOS token of the (unpadded) sequence is set
        to -100; if the sequence has fewer than two EOS tokens the whole row is
        set to -100.

    Padding is identified by *position*, not by token value, so real tokens that
    happen to equal ``pad_token_id`` (notably EOS when pad == EOS) are still
    trained on, and padding is never mistaken for an EOS marker.

    Args:
        batch: list of dicts with keys ``"image"`` and ``"qa"`` (a list whose
            first element has ``"question"`` and ``"answer"``).

    Returns:
        Tuple ``(images, input_ids, attention_mask, labels)`` where ``images`` is
        a list and the other three are ``(batch, max_len)`` long tensors.
    """
    images = []
    sequences = []
    for sample in batch:
        images.append(sample["image"])

        first_qa = sample["qa"][0]
        messages = [
            {"role": "human", "content": first_qa["question"]},
            {"role": "assistant", "content": first_qa["answer"]},
        ]
        token_ids = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=False
        )
        # apply_chat_template returns a (1, seq_len) tensor; drop the batch dim.
        sequences.append(token_ids.squeeze(0))

    max_length = max((seq.size(0) for seq in sequences), default=0)

    padded_input_ids = []
    attention_masks = []
    labels = []
    for seq in sequences:
        seq_len = seq.size(0)
        pad_len = max_length - seq_len

        padded = torch.cat(
            [seq, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]
        )
        padded_input_ids.append(padded)

        attention_masks.append(
            torch.cat(
                [
                    torch.ones(seq_len, dtype=torch.long),
                    torch.zeros(pad_len, dtype=torch.long),
                ]
            )
        )

        label = padded.clone()
        # Mask padding by position so genuine tokens equal to pad_token_id survive.
        label[seq_len:] = -100

        # Look for EOS in the real tokens only; padding must not count as EOS.
        eos_positions = (seq == tokenizer.eos_token_id).nonzero()
        if len(eos_positions) > 1:
            # Everything before the second EOS is excluded from the loss.
            label[: eos_positions[1].item()] = -100
        else:
            label[:] = -100

        labels.append(label)

    return (
        images,
        torch.stack(padded_input_ids),
        torch.stack(attention_masks),
        torch.stack(labels),
    )

In [None]:
def compute_loss(batch):
    """Run one forward pass over a collated batch and return the text model's loss.

    Each image is encoded and projected on its own, giving a ``(1, n_tokens, dim)``
    embedding. These are zero-padded along the token axis to the longest image in
    the batch, stacked, and prepended to the text token embeddings. The attention
    mask and labels are extended to match: image padding gets mask 0 and every
    image position gets label -100.

    Args:
        batch: ``(images, input_ids, attention_masks, labels)`` as produced by
            ``collate_fn``.

    Returns:
        ``outputs.loss`` from ``model.text_model``.
    """
    images, input_ids, attention_masks, labels = batch

    input_ids = input_ids.to("cuda")
    attention_masks = attention_masks.to("cuda")
    labels = labels.to("cuda")

    # Encode + project every image separately; lengths may differ per image.
    projected = []
    for image in images:
        projected.append(model.mm_projector(model.encode_image(image)))
    token_counts = [emb.size(1) for emb in projected]
    max_img_length = max(token_counts)

    padded_chunks = []
    mask_rows = []
    for emb, n_tokens in zip(projected, token_counts):
        n_pad = max_img_length - n_tokens

        # Right-pad the token axis with zeros in the model's dtype.
        filler = torch.zeros(
            1, n_pad, emb.size(2), device=emb.device, dtype=model.dtype
        )
        padded_chunks.append(torch.cat([emb, filler], dim=1))

        # 1 for real image tokens, 0 for the padding just added.
        real_part = torch.ones(n_tokens, device=emb.device, dtype=torch.long)
        pad_part = torch.zeros(n_pad, device=emb.device, dtype=torch.long)
        mask_rows.append(torch.cat([real_part, pad_part]))

    img_embs = torch.cat(padded_chunks, dim=0)
    img_attention_masks = torch.stack(mask_rows)

    tok_embs = model.text_model.get_input_embeddings()(input_ids)

    # Image positions never contribute to the loss.
    image_labels = torch.full(
        (labels.shape[0], max_img_length),
        -100,
        device=labels.device,
        dtype=labels.dtype,
    )

    outputs = model.text_model(
        inputs_embeds=torch.cat([img_embs, tok_embs], dim=1),
        labels=torch.cat([image_labels, labels], dim=1),
        attention_mask=torch.cat([img_attention_masks, attention_masks], dim=1),
    )
    return outputs.loss

In [None]:
from bitsandbytes.optim import Adam8bit
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

if USE_TENSORBOARD:
    from torch.utils.tensorboard import SummaryWriter

def text_schedule(step, max_steps, lr):
    """Learning rate for the text model: linear warmup followed by cosine decay.

    The first ``int(0.1 * max_steps)`` steps ramp linearly from 0 to ``lr``; the
    remaining steps follow a half cosine from ``lr`` down to 0.

    Args:
        step: index of the current optimizer step.
        max_steps: total number of optimizer steps planned for the run.
        lr: peak learning rate.

    Returns:
        The learning rate to use at ``step``. For ``step`` beyond ``max_steps``
        the rate stays at the end-of-schedule value (0), and ``max_steps == 0``
        yields ``lr`` instead of raising ``ZeroDivisionError``.
    """
    warmup_steps = int(0.1 * max_steps)
    if step < warmup_steps:
        return lr * step / warmup_steps

    # A run with no planned steps has a zero-length decay phase; avoid dividing by 0.
    decay_steps = max(1, max_steps - warmup_steps)
    # Clamp so overshooting max_steps does not climb back up the cosine.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return lr * 0.5 * (1 + math.cos(math.pi * progress))

def vision_schedule(step, max_steps, lr):
    """Learning rate for the vision parameters: long linear ramp, then the text schedule.

    For the first ``int(0.5 * max_steps)`` steps the rate rises linearly from 0 to
    whatever ``text_schedule`` yields at the end of that ramp; from there on it is
    identical to ``text_schedule``.

    Args:
        step: index of the current optimizer step.
        max_steps: total number of optimizer steps planned for the run.
        lr: peak learning rate passed through to ``text_schedule``.

    Returns:
        The learning rate to use at ``step``.
    """
    ramp_steps = int(0.5 * max_steps)
    # Value the ramp must reach so both schedules meet without a jump.
    ramp_peak = text_schedule(ramp_steps, max_steps, lr)

    return (
        ramp_peak * step / ramp_steps
        if step < ramp_steps
        else text_schedule(step, max_steps, lr)
    )
# Training script: builds the DataLoader, optional TensorBoard writer and the
# optimizers, runs the gradient-accumulation loop, then saves the adapter.
dataloaders = {
    "train": DataLoader(
        datasets["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    )
}

# Timestamped run name, used as the TensorBoard log sub-directory.
run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

if USE_TENSORBOARD:
    log_dir = os.path.join("./logs", run_name)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

model.train()
model.text_model.gradient_checkpointing_enable()

if TRAIN_ENCODER:
    model.vision_model.train()
    model.vision_model.gradient_checkpointing_enable()
else:
    # Encoder frozen: eval mode, and no gradients for the vision model or projector.
    model.vision_model.eval()
    for param in model.vision_model.parameters():
        param.requires_grad = False
    for param in model.mm_projector.parameters():
        param.requires_grad = False


# Number of optimizer updates: total micro-batches over all epochs divided
# (floor) by the accumulation factor.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
# NOTE(review): WEIGHT_DECAY is not passed to either optimizer -- confirm intended.
text_optimizer = Adam8bit(model.text_model.parameters(), lr=LR)

if TRAIN_ENCODER:
    # Vision model and projector share a second optimizer with its own LR schedule.
    vision_params = list(model.vision_model.parameters()) + list(model.mm_projector.parameters())
    vision_optimizer = Adam8bit(vision_params, lr=LR)

# global_step counts micro-batches (not optimizer updates) across all epochs.
global_step = 0
lr_text = 0
lr_vision = 0

for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        global_step += 1
        loss = compute_loss(batch)
        # NOTE(review): the loss is not divided by GRAD_ACCUM_STEPS, so gradients
        # are summed (not averaged) across the accumulation window.
        loss.backward()
        
        # Recompute both learning rates every micro-batch (also used for logging);
        # current_step is the number of completed accumulation windows.
        current_step = global_step // GRAD_ACCUM_STEPS
        lr_text = text_schedule(current_step, total_steps, LR)
        if TRAIN_ENCODER:
            lr_vision = vision_schedule(current_step, total_steps, LR)
        
        # Apply the accumulated gradients once per GRAD_ACCUM_STEPS micro-batches.
        # NOTE(review): if the total micro-batch count is not a multiple of
        # GRAD_ACCUM_STEPS, the trailing gradients are never applied.
        if global_step % GRAD_ACCUM_STEPS == 0:
            for param_group in text_optimizer.param_groups:
                param_group["lr"] = lr_text
            text_optimizer.step()
            text_optimizer.zero_grad()
            
            if TRAIN_ENCODER:
                for param_group in vision_optimizer.param_groups:
                    param_group["lr"] = lr_vision
                vision_optimizer.step()
                vision_optimizer.zero_grad()
        
        # Scalars are logged per micro-batch, indexed by global_step.
        if USE_TENSORBOARD:
            writer.add_scalar("Loss/train", loss.item(), global_step)
            writer.add_scalar("LR/text", lr_text, global_step)
            if TRAIN_ENCODER:
                writer.add_scalar("LR/vision", lr_vision, global_step)

if USE_TENSORBOARD:
    writer.close()

# Persist the trained model/adapter to the "qwenlora" directory; the testing
# section loads it back from there.
model.save_pretrained("qwenlora")



# Testing


In [None]:
from datasets import load_dataset  # noqa: F811
from torch.utils.data import Dataset


class CaptchaDataset(Dataset):
    """Map-style dataset exposing one split of ``google/docci`` as captioning QA samples.

    Despite its name, this class loads the ``google/docci`` dataset, not a
    CAPTCHA dataset.
    """

    def __init__(self, split="train"):
        """Load the ``google/docci`` dataset and keep only the requested ``split``."""
        all_splits = load_dataset("google/docci", trust_remote_code=True)
        self.data = all_splits[split]

    def __len__(self):
        """Number of records in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as ``{"image": ..., "qa": [{"question", "answer"}]}``.

        The question is always the fixed prompt "Describe this image." and the
        answer is the record's ``description`` field.
        """
        record = self.data[idx]
        image = record["image"]
        qa_pair = {
            "question": "Describe this image.",
            "answer": record["description"],
        }
        return {"image": image, "qa": [qa_pair]}


# One CaptchaDataset per split, keyed by split name ("train" first, then "test").
datasets = {split_name: CaptchaDataset(split_name) for split_name in ("train", "test")}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload tokenizer and base model from the same checkpoint id for evaluation.
MODEL_ID = "qresearch/doubutsu-2b-pt-756"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# The checkpoint ships custom modelling code (hence trust_remote_code); weights are
# loaded in float16 and then moved onto the GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to("cuda")

In [None]:
# Attach the adapter from the "qwenlora" directory, which the training section
# writes with model.save_pretrained("qwenlora").
model.load_adapter("qwenlora")

In [None]:
from IPython.display import display

# Qualitative check: show the first test image, then print the prompt, the
# reference description and the model's answer for each of its QA pairs.
sample = datasets["test"][0]
display(sample["image"])

for qa in sample["qa"]:
    print("Question:", qa["question"])
    print("Ground Truth:", qa["answer"])
    prediction = model.answer_question(
        sample["image"],
        qa["question"],
        tokenizer=tokenizer,
    )
    print("qwenvision:", prediction)