In [None]:
# Install the packages this notebook imports into the running kernel's environment.
# NOTE(review): versions are unpinned; pin them if the run must be reproducible.
# NOTE(review): PIL and tqdm are imported in later cells but are not listed here —
# presumably they arrive as transitive dependencies; confirm.
%pip install torch torchvision transformers datasets bitsandbytes accelerate peft tensorboard

In [None]:
import os
from datetime import datetime

from datasets import load_dataset
from torch.utils.data import Dataset

In [None]:
import math
from typing import List, Tuple

from PIL import Image


def generate_grid_configurations(size: int) -> List[Tuple[int, int]]:
    """
    Build the list of candidate canvas resolutions as multiples of ``size``.

    Args:
        size (int): The base edge length that every multiplier is applied to.

    Returns:
        list: Seven ``(first, second)`` tuples, in a fixed order, where each
            element is ``size`` times a small integer multiplier.
    """
    multipliers = ((2, 2), (1, 2), (1, 3), (1, 4), (4, 1), (3, 1), (2, 1))
    return [(first * size, second * size) for first, second in multipliers]


def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that preserves the most image area.

    Each candidate is scored by the area the image would occupy after being
    scaled (aspect ratio preserved) to fit inside it, capped at the image's own
    area. Ties on that score are broken in favour of the candidate with the
    least unused area; if that also ties, the earliest candidate wins.

    Args:
        original_size (tuple): The original image size as (width, height).
        possible_resolutions (list): Candidate resolutions as
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The chosen (width, height), or None if no candidate was given.
    """
    src_width, src_height = original_size
    src_area = src_width * src_height

    chosen = None
    # Score is (usable area, -wasted area); a larger tuple is a better fit.
    best_score = (0, float("-inf"))

    for cand_width, cand_height in possible_resolutions:
        ratio = min(cand_width / src_width, cand_height / src_height)
        fitted_area = int(src_width * ratio) * int(src_height * ratio)
        usable_area = min(fitted_area, src_area)
        score = (usable_area, usable_area - cand_width * cand_height)

        if score > best_score:
            best_score = score
            chosen = (cand_width, cand_height)

    return chosen


def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target canvas, keeping its aspect ratio, and pad with black.

    The image is scaled by the smaller of the two axis ratios so it fits
    entirely within the target, then pasted centred onto a black RGB canvas of
    exactly the target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The canvas size as (width, height).

    Returns:
        PIL.Image.Image: A new RGB image of size ``target_resolution``.
    """
    src_width, src_height = image.size
    canvas_width, canvas_height = target_resolution

    ratio_w = canvas_width / src_width
    ratio_h = canvas_height / src_height

    # The tighter axis fills the canvas; the other is rounded up and clamped
    # so rounding can never exceed the canvas.
    if ratio_w < ratio_h:
        fitted_size = (
            canvas_width,
            min(math.ceil(src_height * ratio_w), canvas_height),
        )
    else:
        fitted_size = (
            min(math.ceil(src_width * ratio_h), canvas_width),
            canvas_height,
        )

    scaled = image.resize(fitted_size)

    canvas = Image.new("RGB", (canvas_width, canvas_height), (0, 0, 0))
    offset = (
        (canvas_width - fitted_size[0]) // 2,
        (canvas_height - fitted_size[1]) // 2,
    )
    canvas.paste(scaled, offset)

    return canvas


def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles, scanning rows top to bottom, left to right.

    Crop boxes always span ``patch_size`` in both directions, so boxes in the
    last row or column extend past the image edge when the image size is not a
    multiple of ``patch_size``.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The edge length of each tile.

    Returns:
        list: The cropped tiles in row-major order.
    """
    img_width, img_height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_height, patch_size)
        for left in range(0, img_width, patch_size)
    ]


def slice_anyres_image(image, patch_size=378):
    """
    Turn one image into a square thumbnail followed by its grid tiles.

    The image is fitted onto the best-matching canvas from
    ``generate_grid_configurations(patch_size)``, the canvas is cut into
    ``patch_size`` tiles, and a ``patch_size`` x ``patch_size`` resize of the
    whole original image (aspect ratio not preserved) is placed in front.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Edge length of the thumbnail and of every tile.

    Returns:
        list: ``[thumbnail, tile_0, tile_1, ...]``.
    """
    candidate_resolutions = generate_grid_configurations(patch_size)
    canvas_resolution = select_best_resolution(image.size, candidate_resolutions)

    canvas = resize_and_pad_image(image, canvas_resolution)
    tiles = divide_to_patches(canvas, patch_size)

    thumbnail = image.resize((patch_size, patch_size))

    return [thumbnail, *tiles]

In [None]:
class DocciDataset(Dataset):
    """
    Map-style dataset exposing one split of ``google/docci`` as image/QA samples.

    Every item pairs the sample's image with a single fixed question
    ("Describe this image.") whose answer is the sample's ``description``.
    """

    def __init__(self, split="train"):
        """Load ``google/docci`` and keep the requested ``split``."""
        all_splits = load_dataset("google/docci", trust_remote_code=True)
        self.data = all_splits[split]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "qa": [{"question": ..., "answer": ...}]}`` for ``idx``."""
        record = self.data[idx]
        image = record["image"]
        qa_pair = {
            "question": "Describe this image.",
            "answer": record["description"],
        }
        return {"image": image, "qa": [qa_pair]}


# Build both splits up front, keyed by split name, so later cells can index
# datasets["train"] / datasets["test"].
datasets = {
    "train": DocciDataset("train"),
    "test": DocciDataset("test"),
}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tokenizer and model are loaded from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained("qresearch/doubutsu-2b-pt-378")
# trust_remote_code=True lets the checkpoint repository supply its own model
# code; weights are loaded as float16 and the model is moved to the GPU, so
# this cell requires a CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    "qresearch/doubutsu-2b-pt-378",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapter configuration: rank 128 with alpha 256, attached to the listed
# attention and MLP projection modules (matched by module name).
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)

# Called on the base model before it is wrapped.
# NOTE(review): presumably required so gradient checkpointing works while the
# base embeddings are frozen — confirm.
model.enable_input_require_grads()
# Rebinds `model` to the PEFT-wrapped model; later cells use this wrapper.
model = get_peft_model(model, lora_config)
# Prints a summary of trainable vs. total parameter counts.
model.print_trainable_parameters()

In [None]:
# Number of full passes over the training dataloader.
EPOCHS = 3
# Samples per dataloader batch (one forward/backward pass).
BATCH_SIZE = 1
# Backward passes accumulated before each optimizer step.
GRAD_ACCUM_STEPS = 32
# Peak learning rate used by lr_schedule.
LR = 2e-4
# Weight-decay coefficient (0.0 means no decay).
WEIGHT_DECAY = 0.0

# Sequence positions accounted for per image patch when collate_fn sizes the
# attention mask and labels.
# NOTE(review): presumably 27 x 27 vision tokens per 378px patch — confirm
# against the vision encoder.
IMG_TOKENS = 729

In [None]:
def collate_fn(batch):
    """
    Collate dataset samples into the tensors consumed by ``compute_loss``.

    For each sample the image is sliced into patches and the question/answer
    pair is tokenized with the tokenizer's chat template. Text is right-padded
    to the longest sample, and the image span is sized to the largest patch
    count in the batch, so every row has length
    ``max_image_tokens + max_length`` regardless of how many patches each
    individual image produced.

    The attention mask and labels use the same layout ``compute_loss`` builds
    for the input embeddings: when the tokenizer defines a BOS id the image
    span sits after the first text token, otherwise it sits in front of all
    text. Padding positions inside the image span and after the text are
    masked out (mask 0, label -100).

    Args:
        batch (list): Samples shaped like
            ``{"image": PIL image, "qa": [{"question": str, "answer": str}]}``.
            Only the first QA pair of each sample is used.

    Returns:
        tuple: ``(all_image_patches, input_ids, attention_mask, labels)`` where
            ``all_image_patches`` is a list with one list of patches per sample,
            ``input_ids`` has shape (batch, max_length), and ``attention_mask``
            and ``labels`` have shape (batch, max_image_tokens + max_length).
    """
    all_image_patches = []
    all_input_ids = []
    max_length = 0
    max_image_tokens = 0

    for sample in batch:
        image_patches = slice_anyres_image(sample["image"])
        all_image_patches.append(image_patches)
        max_image_tokens = max(max_image_tokens, len(image_patches) * IMG_TOKENS)

        messages = [
            {"role": "human", "content": sample["qa"][0]["question"]},
            {"role": "assistant", "content": sample["qa"][0]["answer"]},
        ]

        input_ids = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=False
        )
        all_input_ids.append(input_ids.squeeze(0))
        max_length = max(max_length, input_ids.size(1))

    # Number of text positions that precede the image span; mirrors the
    # placement rule used when the embeddings are concatenated in compute_loss.
    prefix_length = 1 if tokenizer.bos_token_id is not None else 0

    padded_input_ids = []
    attention_masks = []
    labels = []

    for input_ids, image_patches in zip(all_input_ids, all_image_patches):
        image_tokens = len(image_patches) * IMG_TOKENS
        image_padding = max_image_tokens - image_tokens
        padding_length = max_length - input_ids.size(0)

        padded_input_id = torch.cat(
            [
                input_ids,
                torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long),
            ]
        )
        padded_input_ids.append(padded_input_id)

        # Build the text and image masks separately, then splice them in the
        # same order as the embeddings so padded image slots are masked where
        # they actually sit in the sequence.
        text_mask = torch.cat(
            [
                torch.ones(input_ids.size(0), dtype=torch.long),
                torch.zeros(padding_length, dtype=torch.long),
            ]
        )
        image_mask = torch.cat(
            [
                torch.ones(image_tokens, dtype=torch.long),
                torch.zeros(image_padding, dtype=torch.long),
            ]
        )
        attention_mask = torch.cat(
            [text_mask[:prefix_length], image_mask, text_mask[prefix_length:]]
        )
        attention_masks.append(attention_mask)

        label = padded_input_id.clone()
        label[label == tokenizer.pad_token_id] = -100

        # Everything before the second EOS token is excluded from the loss.
        # NOTE(review): this assumes the chat template emits two EOS tokens
        # before the assistant's answer begins — confirm against the template.
        assistant_starts = (padded_input_id == tokenizer.eos_token_id).nonzero()
        if len(assistant_starts) > 1:
            assistant_start = assistant_starts[1].item()
            label[:assistant_start] = -100
        else:
            label[:] = -100

        # The image span (real and padded slots alike) never contributes to
        # the loss. It is sized to the batch maximum so all rows stack.
        label = torch.cat(
            [
                label[:prefix_length],
                torch.full((max_image_tokens,), -100, dtype=torch.long),
                label[prefix_length:],
            ]
        )

        labels.append(label)

    return (
        all_image_patches,
        torch.stack(padded_input_ids),
        torch.stack(attention_masks),
        torch.stack(labels),
    )

In [None]:
def compute_loss(batch):
    """
    Run one forward pass on a collated batch and return the language-model loss.

    Image patches are encoded without gradient tracking, projected with
    ``model.mm_projector``, flattened per sample, zero-padded to the longest
    sample in the batch, and concatenated with the text token embeddings before
    being passed to ``model.text_model``.

    Args:
        batch (tuple): ``(image_patches, input_ids, attention_masks, labels)``
            as produced by ``collate_fn``.

    Returns:
        The ``loss`` attribute of the text model's output.
    """
    image_patches, input_ids, attention_masks, labels = batch

    input_ids, attention_masks, labels = (
        tensor.to("cuda") for tensor in (input_ids, attention_masks, labels)
    )

    def embed_patches(patches):
        # Only the projector participates in autograd; the encoder call is
        # wrapped in no_grad.
        with torch.no_grad():
            features = torch.stack([model.encode_image(img) for img in patches])
        projected = model.mm_projector(features)
        # Flatten to [num_patches * IMG_TOKENS, dim_llm]
        flat = projected.view(-1, projected.size(-1))
        return flat.to(model.dtype)

    per_sample_embs = [embed_patches(patches) for patches in image_patches]
    longest = max(emb.size(0) for emb in per_sample_embs)

    def pad_rows(emb):
        # Zero rows appended so every sample has `longest` image positions.
        filler = torch.zeros(
            longest - emb.size(0),
            emb.size(1),
            device=emb.device,
            dtype=model.dtype,
        )
        return torch.cat([emb, filler], dim=0)

    img_embs = torch.stack([pad_rows(emb) for emb in per_sample_embs])

    tok_embs = model.text_model.get_input_embeddings()(input_ids)

    # With a BOS id defined, the image span goes right after the first text
    # token; otherwise it is placed ahead of all text.
    if tokenizer.bos_token_id is None:
        pieces = [img_embs, tok_embs]
    else:
        pieces = [tok_embs[:, :1, :], img_embs, tok_embs[:, 1:, :]]
    inputs_embeds = torch.cat(pieces, dim=1)

    outputs = model.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attention_masks,
    )
    return outputs.loss

In [None]:
from bitsandbytes.optim import Adam8bit
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def lr_schedule(step, max_steps):
    """
    Learning rate for ``step``: linear warmup followed by cosine decay.

    The first 10% of ``max_steps`` (rounded down) ramp linearly from 0 to the
    module-level ``LR``; the remaining steps follow half a cosine from ``LR``
    down to 0 at ``max_steps``.

    Args:
        step (int): Index of the current optimizer step.
        max_steps (int): Total number of optimizer steps in the run.

    Returns:
        float: The learning rate to use at ``step``.
    """
    warmup_steps = int(0.1 * max_steps)

    if step >= warmup_steps:
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return LR * 0.5 * (1 + math.cos(math.pi * progress))

    return LR * step / warmup_steps


# Training dataloader; collate_fn turns raw samples into padded tensors.
dataloaders = {
    "train": DataLoader(
        datasets["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    )
}

# Each run logs to its own timestamped TensorBoard directory under ./logs.
run_name = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

log_dir = os.path.join("./logs", run_name)
os.makedirs(log_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)


# TEST THIS MORE. SIGLIP TRAINING SEEMS TO BE FINE THOUGH
model.train()
model.text_model.gradient_checkpointing_enable()
model.vision_model.gradient_checkpointing_enable()

# Number of optimizer updates over the whole run: one per GRAD_ACCUM_STEPS
# dataloader batches.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
optimizer = Adam8bit(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# Put the optimizer on the schedule before the first update. Without this the
# first optimizer step would run at the peak LR and bypass the warmup, because
# the schedule is only applied after each step below.
initial_lr = lr_schedule(0, total_steps) if total_steps > 0 else LR
for param_group in optimizer.param_groups:
    param_group["lr"] = initial_lr

global_step = 0
try:
    for epoch in range(EPOCHS):
        for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
            global_step += 1

            # NOTE(review): the loss is not divided by GRAD_ACCUM_STEPS, so
            # accumulated gradients are a sum rather than a mean over the
            # micro-batches — confirm this is intended.
            loss = compute_loss(batch)
            loss.backward()

            if global_step % GRAD_ACCUM_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()

                # Learning rate for the next optimizer update.
                lr = lr_schedule(global_step // GRAD_ACCUM_STEPS, total_steps)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

            writer.add_scalar("Loss/train", loss.item(), global_step)
            writer.add_scalar("LR", optimizer.param_groups[0]["lr"], global_step)
    # NOTE(review): gradients from a trailing partial accumulation window
    # (fewer than GRAD_ACCUM_STEPS batches) are never applied.
finally:
    # Close the event file even if training is interrupted.
    writer.close()

# Save the adapter to ./qwenlora for the testing cells.
model.save_pretrained("qwenlora")

# Testing


In [None]:
from datasets import load_dataset  # noqa: F811
from torch.utils.data import Dataset


class CaptchaDataset(Dataset):
    """
    Map-style dataset of image/QA samples backed by a ``google/docci`` split.

    NOTE(review): despite its name, this class loads ``google/docci`` and
    produces the same sample structure as the training dataset defined earlier
    in the notebook — confirm whether a different dataset was intended.
    """

    def __init__(self, split="train"):
        """Load ``google/docci`` and keep the requested ``split``."""
        splits = load_dataset("google/docci", trust_remote_code=True)
        self.data = splits[split]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image at ``idx`` with one fixed question and its description as the answer."""
        row = self.data[idx]
        image = row["image"]
        caption = row["description"]
        return {
            "image": image,
            "qa": [{"question": "Describe this image.", "answer": caption}],
        }


# Rebinds the module-level `datasets` name, replacing the dict built in the
# training section with freshly loaded splits.
datasets = {
    "train": CaptchaDataset("train"),
    "test": CaptchaDataset("test"),
}

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rebinds `tokenizer` and `model`, replacing the objects used for training.
# NOTE(review): this checkpoint id ends in "-756", while the training section
# loaded "qresearch/doubutsu-2b-pt-378" — confirm the adapter trained above is
# meant to be applied to this checkpoint.
tokenizer = AutoTokenizer.from_pretrained("qresearch/doubutsu-2b-pt-756")
# As in training: repository-supplied model code is trusted, weights are
# float16, and a CUDA device is required.
model = AutoModelForCausalLM.from_pretrained(
    "qresearch/doubutsu-2b-pt-756",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
# Load the adapter from ./qwenlora, the directory the training section writes
# with model.save_pretrained("qwenlora").
model.load_adapter("qwenlora")

In [None]:
from IPython.display import display

# Qualitative check: show the first test image, then print each question with
# its reference answer and the model's own answer.
sample = datasets["test"][0]
display(sample["image"])

for qa in sample["qa"]:
    print("Question:", qa["question"])
    print("Ground Truth:", qa["answer"])
    prediction = model.answer_question(
        sample["image"],
        qa["question"],
        tokenizer=tokenizer,
    )
    print("qwenvision:", prediction)