# Install dependencies

In [None]:
!pip install "einops==0.7.0" "bitsandbytes==0.45.2" "accelerate>=0.20.1" "ipywidgets>=8.1" "jupyterlab-widgets>=3"

In [None]:
import sys

import ipywidgets
import jupyterlab_widgets

# Sanity-check the widget packages pinned in the install cell.
# Expect ipywidgets 8.x and jupyterlab-widgets 3.x (JupyterLab 4).
for label, module in (("ipywidgets:", ipywidgets), ("jupyterlab-widgets:", jupyterlab_widgets)):
    print(label, module.__version__)
print(sys.version)


# Import libraries

In [None]:
import json
import math

import torch
from bitsandbytes.optim import Adam8bit
from datasets import load_dataset
from einops import rearrange
from huggingface_hub import notebook_login
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Log in to Hugging Face

In [None]:
# Authenticate with the Hugging Face Hub so that `moondream.push_to_hub(...)`
# later in the notebook can upload the fine-tuned model.
notebook_login()

# Preprocessing Datasets

In [None]:
class DentalDataset(Dataset):
    """One split of an ``imagefolder`` dataset, formatted for moondream fine-tuning.

    Each item is ``{"image": <image>, "qa": [{"question": ..., "answer": ...}]}``.
    The answer is the row's ``text`` field, and every image is paired with the
    same question. This is the shape consumed by ``collate_fn`` and the
    evaluation cells.

    Despite the class name, the default ``data_dir`` points at a biology dataset.

    Args:
        split: split name passed to ``load_dataset`` (e.g. "train", "test").
        data_dir: dataset root directory; ``None`` uses ``DEFAULT_DATA_DIR``.
        question: prompt paired with every image; ``None`` uses ``DEFAULT_QUESTION``.
    """

    DEFAULT_DATA_DIR = "/kaggle/input/biology-dataset/Biology-Dataset"
    DEFAULT_QUESTION = "Describe this image."

    def __init__(self, split='train', data_dir=None, question=None):
        self.data = load_dataset(
            "imagefolder",
            data_dir=self.DEFAULT_DATA_DIR if data_dir is None else data_dir,
            split=split,
        )
        self.question = self.DEFAULT_QUESTION if question is None else question

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return {
            "image": sample["image"],  # expected to be a PIL image from imagefolder — TODO confirm
            "qa": [
                {
                    "question": self.question,
                    "answer": sample["text"],
                }
            ],
        }

# One dataset object per split, keyed by split name (built in train/test/validation order).
# This reuses the name of the `datasets` package, which is safe here because the
# package is only accessed through the `load_dataset` import.
datasets = {split: DentalDataset(split) for split in ("train", "test", "validation")}

# Load base model

In [None]:
DEVICE = "cuda"
# Half precision on GPU; fall back to float32 on CPU, which doesn't support float16.
if DEVICE == "cpu":
    DTYPE = torch.float32
else:
    DTYPE = torch.float16
# Pin the Hub revision so the remote model code loaded with trust_remote_code
# does not change between runs.
MD_REVISION = "2024-05-20"

tokenizer = AutoTokenizer.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
)
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision=MD_REVISION,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)

# Evaluate model before fine-tuning

In [None]:
# Baseline: how the un-tuned model answers on one training sample.
sample = datasets['train'][0]
display(sample['image'])

# Apply the same RGB normalisation as collate_fn and the final evaluation loop,
# so non-RGB images (e.g. grayscale/RGBA) are handled consistently.
image = sample['image']
if hasattr(image, "mode") and image.mode != "RGB":
    image = image.convert("RGB")

# The image is shared by every question of the sample, so encode it only once.
image_embeds = moondream.encode_image(image)

for qa in sample['qa']:
    print('Question:\n', qa['question'])
    print('Ground Truth:\n', qa['answer'])
    print('Moondream:\n', moondream.answer_question(
        image_embeds,
        qa['question'],
        tokenizer=tokenizer,
    ))

# Setting up hyperparameters for fine-tuning

In [None]:
# Training hyperparameters.
EPOCHS = 1            # passes over the training split
BATCH_SIZE = 4        # samples per micro-batch (one forward/backward pass)
GRAD_ACCUM_STEPS = 2  # micro-batches per optimizer step; effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS
LR = 1e-5             # peak learning rate reached by lr_schedule after warmup
USE_WANDB = False     # set True to log metrics to Weights & Biases

In [None]:
# Appended to every answer in collate_fn and supervised by the loss, so the
# model learns to emit it when the answer is complete.
ANSWER_EOS = "<|endoftext|>"

# Number of tokens used to represent each image. collate_fn reserves this many
# label positions (plus one for BOS) for the image embeddings that compute_loss
# splices into the sequence, so it must match the vision encoder's output length.
IMG_TOKENS = 729

def collate_fn(batch):
    """Collate dataset samples into the inputs consumed by ``compute_loss``.

    Text layout per sample: ``BOS`` followed by each (question, answer) pair.

    Labels are aligned with the *full* model input, which is
    ``BOS + image embeddings + text``:
      * the first ``1 + IMG_TOKENS`` label slots cover BOS and the image
        embeddings that ``compute_loss`` inserts after BOS;
      * those slots and all question tokens are -100, so they are ignored by
        the loss;
      * only answer tokens, including ``ANSWER_EOS``, are supervised.

    Returns:
        A tuple ``(images, tokens, labels, attn_mask)``:
          * ``images``: list of preprocessed images;
          * ``tokens``, ``labels``, ``attn_mask``: stacked tensors, right-padded
            to a common length. Tokens are padded with
            ``tokenizer.eos_token_id``, labels with -100, and the mask with
            False.
    """
    # Normalise every image to RGB before the vision encoder's preprocessing.
    rgb_images = []
    for item in batch:
        img = item['image']
        if hasattr(img, "mode") and img.mode != "RGB":
            img = img.convert("RGB")
        rgb_images.append(img)
    images = [moondream.vision_encoder.preprocess(img) for img in rgb_images]

    tokens_acc, labels_acc = [], []
    for item in batch:
        seq_ids = [tokenizer.bos_token_id]
        seq_labels = [-100] * (1 + IMG_TOKENS)

        for turn in item['qa']:
            question_ids = tokenizer(
                f"\n\nQuestion: {turn['question']}\n\nAnswer:",
                add_special_tokens=False,
            ).input_ids
            seq_ids += question_ids
            seq_labels += [-100] * len(question_ids)

            answer_ids = tokenizer(
                f" {turn['answer']}{ANSWER_EOS}",
                add_special_tokens=False,
            ).input_ids
            seq_ids += answer_ids
            seq_labels += answer_ids

        tokens_acc.append(seq_ids)
        labels_acc.append(seq_labels)

    longest = max((len(seq) for seq in labels_acc), default=-1)

    masks_acc = []
    for seq_ids, seq_labels in zip(tokens_acc, labels_acc):
        n_real = len(seq_labels)
        n_pad = longest - n_real
        seq_labels += [-100] * n_pad
        seq_ids += [tokenizer.eos_token_id] * n_pad
        masks_acc.append([1] * n_real + [0] * n_pad)

    def _stack(rows, dtype):
        # Convert each row to a tensor and stack into (batch, length).
        return torch.stack([torch.tensor(row, dtype=dtype) for row in rows])

    return (
        images,
        _stack(tokens_acc, torch.long),
        _stack(labels_acc, torch.long),
        _stack(masks_acc, torch.bool),
    )

def compute_loss(batch):
    """Return the causal-LM loss for one batch produced by ``collate_fn``.

    The vision encoder runs under ``torch.no_grad()``, so only the text model's
    parameters receive gradients from the returned loss.
    """
    images, tokens, labels, attn_mask = batch
    tokens, labels, attn_mask = (t.to(DEVICE) for t in (tokens, labels, attn_mask))

    with torch.no_grad():
        image_embeds = moondream.vision_encoder(images)

    text_embeds = moondream.text_model.get_input_embeddings()(tokens)
    # Splice the image embeddings in right after BOS: [BOS, image..., text...].
    bos_embed, rest_embeds = text_embeds[:, :1, :], text_embeds[:, 1:, :]
    inputs_embeds = torch.cat([bos_embed, image_embeds, rest_embeds], dim=1)

    return moondream.text_model(
        inputs_embeds=inputs_embeds,
        labels=labels,
        attention_mask=attn_mask,
    ).loss

def lr_schedule(step, max_steps, peak_lr=None, warmup_frac=0.1, min_ratio=0.1):
    """Linear-warmup + cosine-decay learning rate for optimizer step ``step``.

    The schedule has two phases:
      * warmup: the lr rises linearly from ``min_ratio * peak_lr`` to
        ``peak_lr`` over the first ``warmup_frac`` of training;
      * decay: it then follows a half cosine back down to
        ``min_ratio * peak_lr``, reached exactly at ``step == max_steps``.

    Args:
        step: optimizer steps taken so far (may be fractional).
        max_steps: total optimizer steps. Values <= 0 are treated as
            "training finished" instead of dividing by zero.
        peak_lr: maximum learning rate; ``None`` uses the global ``LR``.
        warmup_frac: fraction of training spent in linear warmup.
        min_ratio: floor of the schedule as a fraction of ``peak_lr``.

    Returns:
        The learning rate as a float.
    """
    if peak_lr is None:
        peak_lr = LR
    floor = min_ratio * peak_lr
    span = peak_lr - floor

    # Training progress in [0, 1]; steps past max_steps stay at the floor.
    x = min(max(step / max_steps, 0.0), 1.0) if max_steps > 0 else 1.0

    if x < warmup_frac or warmup_frac >= 1.0:
        return floor + span * x / warmup_frac

    # Rescale the decay phase to [0, 1] so the half cosine completes exactly at
    # max_steps and the schedule actually reaches its floor.
    decay = (x - warmup_frac) / (1.0 - warmup_frac)
    return floor + span * (1 + math.cos(math.pi * decay)) / 2

def evaluate(model, dataloader, loss_fn=None):
    """Return the mean per-batch loss over ``dataloader`` without updating weights.

    Args:
        model: module put into eval mode for the duration of the pass. Its
            previous train/eval mode is restored afterwards, even if the pass
            raises.
        dataloader: iterable of collated batches.
        loss_fn: callable mapping a batch to a scalar loss tensor; ``None`` uses
            ``compute_loss``. NOTE: ``compute_loss`` reads the global
            ``moondream``, not ``model``.

    Returns:
        Mean of ``loss_fn(batch).item()`` over all batches, or ``nan`` if the
        dataloader yields no batches.
    """
    if loss_fn is None:
        loss_fn = compute_loss

    was_training = model.training
    model.eval()
    total_loss = 0.0
    count = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Validating"):
                total_loss += loss_fn(batch).item()
                count += 1
    finally:
        # Restore whatever mode the caller had, instead of forcing train().
        model.train(was_training)

    return total_loss / count if count else float("nan")

# Batches are assembled by collate_fn; only the training split is shuffled.
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        collate_fn=collate_fn,
    )
    for split in ("train", "validation")
}

# Train the text model with gradient checkpointing (activations are recomputed
# during the backward pass to save memory).
moondream.text_model.train()
moondream.text_model.transformer.gradient_checkpointing_enable()

# One optimizer step happens every GRAD_ACCUM_STEPS micro-batches.
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS

# 8-bit Adam over the text model's parameters only. The initial lr (0.1 * LR)
# matches the start of the warmup in lr_schedule.
optimizer = Adam8bit(
    [{"params": moondream.text_model.parameters()}],
    lr=0.1 * LR,
    betas=(0.9, 0.95),
    eps=1e-6,
)

if USE_WANDB:
    import wandb

    wandb.init(
        project="moondream-ft",
        config=dict(
            EPOCHS=EPOCHS,
            BATCH_SIZE=BATCH_SIZE,
            GRAD_ACCUM_STEPS=GRAD_ACCUM_STEPS,
            LR=LR,
        ),
    )

# Run validation every EVAL_EVERY micro-batches (forward/backward passes).
EVAL_EVERY = 250

i = 0  # micro-batches processed so far, counted across all epochs
for epoch in range(EPOCHS):
    for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        i += 1

        loss = compute_loss(batch)
        loss.backward()

        # Gradients accumulate over GRAD_ACCUM_STEPS micro-batches per update.
        if i % GRAD_ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()

            # Set the learning rate that the *next* optimizer step will use.
            lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if USE_WANDB:
            wandb.log({
                "loss/train": loss.item(),
                "lr": optimizer.param_groups[0]['lr']
            })

        if i % EVAL_EVERY == 0:
            val_loss = evaluate(moondream.text_model, dataloaders["validation"])
            # `i` was already incremented for this batch, so it is the current step.
            print(f"Validation Loss (step {i}): {val_loss:.4f}")

            if USE_WANDB:
                wandb.log({"loss/val": val_loss})

# If the total number of micro-batches is not a multiple of GRAD_ACCUM_STEPS,
# apply the final partial accumulation window instead of silently dropping it.
if i % GRAD_ACCUM_STEPS != 0:
    optimizer.step()
    optimizer.zero_grad()

if USE_WANDB:
    wandb.finish()

# Save the fine-tuned model

In [None]:
# Save the fine-tuned model locally. The reload cell below expects
# checkpoints/moondream-ft/model.safetensors to exist after this runs.
moondream.save_pretrained("checkpoints/moondream-ft")

# Push to Hugging Face Hub

In [None]:
# Upload the fine-tuned model to a Hub repo named "moondream-ft". This requires
# the notebook_login() cell to have run; a repo id without a namespace
# presumably resolves to the logged-in account.
moondream.push_to_hub("moondream-ft")

# Reload the fine-tuned model for evaluation


In [None]:
# Rebuild the model for evaluation: load the pinned base checkpoint, then
# overlay the fine-tuned weights saved above.
#
# The constants and tokenizer are redefined so this cell can also run in a fresh
# session (after the imports cell) without re-running training.
# NOTE: if training ran in this session, `moondream` is still on the GPU here;
# `del moondream` first if memory is tight.
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16  # CPU doesn't support float16
MD_REVISION = "2024-05-20"
FT_WEIGHTS_PATH = "checkpoints/moondream-ft/model.safetensors"

tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision=MD_REVISION)
base_model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2", revision=MD_REVISION, trust_remote_code=True,
    torch_dtype=DTYPE, device_map={"": DEVICE}
)

# Read the checkpoint on CPU; load_state_dict copies the values into the
# model's existing parameters.
state_dict = load_file(FT_WEIGHTS_PATH, device="cpu")

# strict=False never raises on key mismatches, so check the result explicitly
# rather than always reporting success.
missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)
if missing or unexpected:
    print(f"⚠️ Fine-tuned weights only partially applied "
          f"({len(missing)} missing, {len(unexpected)} unexpected keys) - check the checkpoint.")
else:
    print("✅ Loaded fine-tuned weights.")

moondream2 = base_model.eval().to(DEVICE)

# Evaluation and Inference


In [None]:
# Run the fine-tuned model over the whole test split, preview the first few
# predictions inline, and save all of them to predictions.json.
N_PREVIEW = 8  # number of samples rendered inline

results = []
test_set = datasets['test']

for i, sample in tqdm(enumerate(test_set), total=len(test_set), desc="Predicting"):
    question = sample['qa'][0]['question']
    ground_truth = sample['qa'][0]['answer']

    # Same RGB normalisation as used during training (collate_fn).
    img = sample["image"]
    img = img.convert("RGB") if hasattr(img, "mode") and img.mode != "RGB" else img
    md_answer = moondream2.answer_question(
        moondream2.encode_image(img),
        question,
        tokenizer=tokenizer,
        num_beams=4,
        no_repeat_ngram_size=5,
        early_stopping=True,
    )

    results.append({
        "id": i,
        "question": question,
        "ground_truth": ground_truth,
        "prediction": md_answer,
    })

    if i < N_PREVIEW:
        display(sample['image'])
        print('Question:', question)
        print('Ground Truth:', ground_truth)
        print('Moondream:', md_answer)

with open("predictions.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print("✅ Saved predictions to predictions.json")