# HW2-3 Image captioning

### Library

In [None]:
!pip install unsloth pandas pyarrow torch torchvision datasets transformer accelerate bitsandbytes

In [None]:
import io
import json

import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from torchvision import transforms
from unsloth import FastLanguageModel

### Dataset

In [None]:
# Read the three pre-split parquet files (train / valid / test), in that order.
train_df, valid_df, test_df = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        'train_data.parquet',
        'valid_data.parquet',
        'test_data.parquet',
    )
)

# Preview the first rows of each split, in the same order they were loaded.
for _preview_df in (train_df, valid_df, test_df):
    print(_preview_df.head())

In [None]:
# Image pipeline: resize every image to a fixed 224x224, then convert the PIL
# image to a float tensor (torchvision `ToTensor` scales to [0, 1], CHW).
# NOTE(review): no mean/std normalization is applied here — confirm this
# matches what the vision encoder expects.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])

def _load_rgb_image(src):
    """Open an image and return a fully decoded RGB copy.

    ``src`` may be anything ``PIL.Image.open`` accepts (e.g. a file path),
    raw encoded image ``bytes``, or a ``{"bytes": ..., "path": ...}`` dict
    (the layout Hugging Face ``datasets`` uses for image columns). The
    underlying file is closed before returning.
    """
    if isinstance(src, dict):
        src = src.get("bytes") or src.get("path")
    if isinstance(src, (bytes, bytearray)):
        src = io.BytesIO(src)
    with Image.open(src) as img:
        # convert() loads the pixel data and returns a new image object, so
        # the result stays valid after the file handle is closed.
        return img.convert("RGB")


def preprocess(sample):
    """Convert one caption row into a training example.

    Args:
        sample: a row mapping from the ``Dataset``. Reads ``sample["image"]``
            (any form accepted by ``_load_rgb_image``) and
            ``sample["caption"]``.

    Returns:
        dict with ``"image"``: the RGB image passed through ``transform``
        (resize to 224x224, then ``ToTensor``), and ``"text"``: the caption on
        its own line after the ``<|image|>`` placeholder line.
    """
    # The original called Image.open() without closing it, leaking one file
    # handle per row that Dataset.map visits.
    image = _load_rgb_image(sample["image"])
    return {
        "image": transform(image),
        "text": f"<|image|>\n{sample['caption']}\n",
    }

# Wrap each pandas split in a Hugging Face Dataset and run `preprocess` over
# every row.
# NOTE(review): `map` materialises the returned image tensors into the
# dataset's storage (presumably as nested float lists) — for large splits
# consider lazy preprocessing via `set_transform`; verify memory usage.
train_ds = Dataset.from_pandas(train_df).map(preprocess)
valid_ds = Dataset.from_pandas(valid_df).map(preprocess)

### Pretrained model

In [None]:
# Load the 11B Llama vision model with 4-bit quantized weights.
# The original repo id "unsloth/llama-3-vision-11b" does not exist: there is
# no Llama 3 / 3.1 vision model — the 11B vision Llama is Llama 3.2. Vision
# checkpoints are loaded with unsloth's FastVisionModel, not the text-only
# FastLanguageModel. (Use "unsloth/Llama-3.2-11B-Vision-Instruct" instead if
# the chat-tuned variant is wanted; the raw "<|image|>\n..." prompts used in
# this notebook match the base model.)
# NOTE: for vision models the second return value is the multimodal processor
# (text tokenizer + image processor). It keeps the name `tokenizer` so later
# cells still find it; call it as `tokenizer(images=..., text=...)`.
from unsloth import FastVisionModel

model, tokenizer = FastVisionModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-11B-Vision",
    dtype = torch.float16,
    load_in_4bit = True,
)

### Finetuning

In [None]:
# Fine-tune `model` on `train_ds`, passing `valid_ds` as the second argument
# (presumably used for validation).
# NOTE(review): `FastLanguageModel.for_instructions` and `model.fit` are not
# part of unsloth's documented API, and Hugging Face models have no
# Keras-style `.fit()` — this cell will likely raise AttributeError; verify.
# unsloth's documented vision recipe attaches LoRA adapters with
# `FastVisionModel.get_peft_model(model, r=..., lora_alpha=...,
# lora_dropout=...)` and trains with TRL's `SFTTrainer` plus
# `UnslothVisionDataCollator`.
FastLanguageModel.for_instructions(model)

# Requested settings: 2 epochs, batch size 4, learning rate 2e-5, and LoRA
# rank 64 / alpha 16 / dropout 0.05.
# NOTE(review): with LoRA's usual alpha / r scaling this is 16 / 64 = 0.25;
# confirm the small scale is intended (alpha is often set to r or 2 * r).
model.fit(
    train_ds, 
    valid_ds,
    tokenizer = tokenizer,
    epochs=2,
    batch_size = 4,
    lr = 2e-5,
    lora_r = 64,
    lora_alpha = 16,
    lora_dropout=0.05
)

### Test & Generate Caption

In [None]:
# Generate one caption per test image and write them to submission.json as a
# list of {"idx": ..., "output": ...} records.
# NOTE(review): assumes `tokenizer` is the multimodal processor returned by
# unsloth's vision loader (callable as `tokenizer(images=..., text=...)`) —
# TODO confirm against the model-loading cell. A plain text tokenizer cannot
# produce the pixel inputs the vision model's generate() needs.
results = []

device = "cuda" if torch.cuda.is_available() else "cpu"

# Loop-invariant prompt, hoisted out of the loop.
# NOTE(review): training examples are "<|image|>\n{caption}\n", while this
# prompt adds "Describe this image.\n" — consider matching the training format.
prompt = "<|image|>\nDescribe this image.\n"

for row in test_df.itertuples():
    # Accept a path, raw encoded bytes, or a {"bytes": ..., "path": ...} dict.
    image_src = row.image
    if isinstance(image_src, dict):
        image_src = image_src.get("bytes") or image_src.get("path")
    if isinstance(image_src, (bytes, bytearray)):
        image_src = io.BytesIO(image_src)
    # Close the file handle as soon as the pixels are decoded.
    with Image.open(image_src) as raw_img:
        img = raw_img.convert("RGB")

    # The processor resizes/normalizes the image itself and returns input_ids
    # together with the pixel tensors generate() expects. They are passed as
    # keyword arguments: generate(inputs=<BatchEncoding>, images=...) is not a
    # valid call.
    inputs = tokenizer(images=img, text=prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens = 50,
            do_sample = True,
            temperature = 0.5,
        )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the prompt text does not leak into the submitted caption.
    prompt_len = inputs["input_ids"].shape[-1]
    caption = tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()
    results.append({
        "idx": row.idx,
        "output": caption,
    })

with open("submission.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)