# Fine-Tuning CLIP and OWL-ViT

In [1]:
import sys, os

# Make the project's own modules (../src) and vendored third-party code (../externals)
# importable from this notebook. Both are prepended to the search path, with ../externals
# ending up ahead of ../src, exactly as two successive insert(0, ...) calls would leave them.
sys.path[0:0] = [os.path.abspath('../externals'), os.path.abspath('../src')]

In [2]:
import torch

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# Note that `gpu` is the name used throughout the notebook even when it ends up being the CPU.
_device_type = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
if _device_type == "cpu":
    print("Warning: no GPU detected, falling back to CPU")
gpu = torch.device(_device_type)
del _device_type

In [3]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic for viewing logs inline.
%load_ext tensorboard

## CLIP

In [4]:
from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer

# Weights, image pre-processor and tokenizer are all loaded from the same checkpoint so that
# pixel normalisation and vocabulary match what the model expects.
clip_version = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(clip_version)
model = model.to(gpu)  # nn.Module.to moves the parameters and returns the same module

image_processor = CLIPImageProcessor.from_pretrained(clip_version)
tokenizer = CLIPTokenizer.from_pretrained(clip_version)

In [5]:
import reload_recursive

%reload prompt_dataset
from prompt_dataset import PromptDataset
import pandas as pd
from torch.utils.data import DataLoader

training_rel_samples = pd.read_pickle("../data/finetuning/train_rel_samples_all.pkl")
training_attr_samples = pd.read_pickle("../data/finetuning/train_attr_samples_all.pkl")
training_samples = pd.concat([training_rel_samples, training_attr_samples])

def compute_prompt(sample):
    """Build the text prompt for one attribute or relation sample.

    Attribute samples have a non-missing ``object_name`` and yield
    ``"<attr_value> <object_name>"``. Relation samples have a missing ``object_name``
    (NaN/None after the two sample frames are concatenated) and yield
    ``"<object0_name> <rel> <object1_name>"``.

    Args:
        sample: a row (``pd.Series``) or mapping supporting ``sample[key]`` lookups.

    Returns:
        The prompt string.
    """
    is_relation = pd.isna(sample["object_name"])
    if is_relation:
        return f"{sample['object0_name']} {sample['rel']} {sample['object1_name']}"
    return f"{sample['attr_value']} {sample['object_name']}"
    
# Training split as (image, prompt, ...) samples; "scale" mode presumably resizes each
# crop to 224x224 (behaviour lives in ../src/prompt_dataset.py — confirm).
training_set = PromptDataset(
    training_samples,
    prompt_transform=compute_prompt,
    img_size=224,
    mode="scale",
)
# Standalone loader for manual inspection; the Trainer below is given `training_set`
# directly and builds its own loader from it.
training_loader = DataLoader(training_set, shuffle=True, batch_size=32)

In [6]:
# Validation data: the 10k-sample relation and attribute pickles, prompted the same way as training.
validation_rel_samples = pd.read_pickle("../data/finetuning/val_rel_samples_10k.pkl")
validation_attr_samples = pd.read_pickle("../data/finetuning/val_attr_samples_10k.pkl")
# ignore_index=True: the two frames are indexed independently, so a plain concat can produce
# duplicate index labels; renumber the combined frame 0..n-1 instead.
validation_samples = pd.concat([validation_rel_samples, validation_attr_samples], ignore_index=True)

validation_set = PromptDataset(validation_samples, prompt_transform=compute_prompt, img_size=224, mode="scale")
# Evaluation metrics do not depend on sample order, so keep validation batches deterministic.
validation_loader = DataLoader(validation_set, batch_size=32, shuffle=False)

In [7]:
# Size of the combined training frame as (rows, columns).
(len(training_samples), len(training_samples.columns))

(708469, 13)

In [8]:
# Size of the combined validation frame as (rows, columns).
(len(validation_samples), len(validation_samples.columns))

(75064, 13)

In [9]:
from transformers import DataCollator
from typing import List, Dict, Any
import numpy as np

class CLIPCollator:
    """Collate (image, prompt, ...) samples into keyword inputs for ``CLIPModel`` under ``Trainer``.

    Each feature is read positionally: element 0 is the image, element 1 the text prompt;
    any further elements (e.g. a label) are ignored. Images are passed to the processor with
    resizing and center-cropping disabled, so they are expected to already be at the model's
    input resolution.

    Args:
        image_processor: callable turning a list of images into pixel inputs. Defaults to the
            notebook-level ``image_processor`` (looked up when a batch is collated).
        tokenizer: callable turning a list of prompts into token inputs. Defaults to the
            notebook-level ``tokenizer`` (looked up when a batch is collated).
    """

    def __init__(self, image_processor=None, tokenizer=None):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __call__(self, features) -> Dict[str, Any]:
        """Return the merged processor/tokenizer outputs plus ``return_loss=True``."""
        # Fall back to the globals from the CLIP loading cell so `CLIPCollator()` keeps working.
        processor = self.image_processor if self.image_processor is not None else image_processor
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        images = [f[0] for f in features]
        texts = [f[1] for f in features]

        image_inputs = processor(images, return_tensors="pt", do_resize=False, do_center_crop=False)
        # truncation=True: CLIP's text encoder has a fixed context length, so an over-long
        # prompt must be cut rather than crash the forward pass.
        text_inputs = tok(texts, return_tensors="pt", padding=True, truncation=True)

        # return_loss=True makes CLIPModel compute its contrastive loss for Trainer.
        return {**image_inputs, **text_inputs, "return_loss": True}

In [10]:
from transformers import TrainingArguments
from datetime import datetime

# Timestamp of when this cell ran; not referenced by the cells shown here.
now = datetime.now()

# Trainer configuration for the CLIP fine-tuning run, grouped by concern.
training_args = TrainingArguments(
    # checkpointing
    output_dir="model_snapshots/clip_finetune2",
    save_steps=2000,
    save_total_limit=5,
    # optimisation
    learning_rate=2e-7,
    warmup_ratio=0.25,
    # batching
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # periodic evaluation on the eval dataset
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=2000,
    # logging
    report_to="tensorboard",
    logging_steps=500,
    # keep every sample field: CLIPCollator unpacks the raw dataset items itself
    remove_unused_columns=False,
)

In [11]:
from torch.utils.tensorboard import SummaryWriter
from transformers.integrations import TensorBoardCallback

# Custom TensorBoard layout: plot train and eval loss on a single chart.
layout = {
    "combined": {
        "loss": ["Multiline", ["train/loss", "eval/loss"]]
    },
}

writer = SummaryWriter(log_dir="runs/clip_finetune2")
writer.add_custom_scalars(layout)

from transformers import Trainer
trainer = Trainer(model=model, args=training_args, train_dataset=training_set, eval_dataset=validation_set, data_collator=CLIPCollator())

# Replace the TensorBoard callback installed by report_to="tensorboard" with one that writes
# through our own writer (which carries the custom layout). Building a fresh callback avoids
# dereferencing None when no TensorBoardCallback was registered.
trainer.remove_callback(TensorBoardCallback)
tb_callback = TensorBoardCallback(tb_writer=writer)
trainer.add_callback(tb_callback)

try:
    result = trainer.train()
finally:
    # Persist buffered events even if training is interrupted (e.g. KeyboardInterrupt);
    # flushing an already-closed SummaryWriter is a no-op.
    writer.flush()



Step,Training Loss,Validation Loss
2000,1.3581,1.470445
4000,1.0587,1.188135
6000,0.8474,1.013854
8000,0.72,0.914984
10000,0.6526,0.857171
12000,0.6257,0.819098
14000,0.5904,0.792808
16000,0.5449,0.765773
18000,0.5201,0.747992
20000,0.4976,0.737006


KeyboardInterrupt: 

## OWL-ViT

In [23]:
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

# Detector weights and their matching pre/post-processor, loaded from one checkpoint.
# The model is not moved to `gpu` in this cell.
owl_version = "google/owlvit-base-patch32"
owl_model = AutoModelForZeroShotObjectDetection.from_pretrained(owl_version)
owl_processor = AutoProcessor.from_pretrained(owl_version)

In [None]:
import torch
from collections.abc import Mapping
from torchvision.io import read_image, ImageReadMode
# NOTE(review): assumes crop/resize/pad are torchvision's functional transforms; the call
# signatures used below (crop(img, top, left, height, width), resize(img, size, antialias=...),
# pad(img, (left_right, top_bottom))) match them.
from torchvision.transforms.functional import crop, pad, resize

# NOTE(review): this cell redefines `PromptDataset`, shadowing the class imported from
# ../src/prompt_dataset in the CLIP section; re-running those cells afterwards uses this one.
# NOTE(review): `get_scaled_bbox` is neither defined nor imported in this notebook; it must be
# brought into scope before items are fetched (presumably from the module the src
# PromptDataset uses — confirm).


class PromptDataset(torch.utils.data.Dataset):
    """Dataset of bounding-box crops paired with a text prompt and a target value.

    Each item is ``(image, prompt, y)``:
    - ``image`` is the entry's bounding box cropped from ``../data/images/<image_id>.jpg``
      and brought to ``img_size`` x ``img_size``;
    - ``prompt`` is ``prompt_transform(entry)``;
    - ``y`` is ``entry['y']``.

    Args:
        scene_graphs: the entries. For a mapping, its values are used (in iteration order).
            An object with ``.iloc`` (e.g. a pandas DataFrame) is indexed positionally, one
            entry per row. Anything else is indexed directly with ``[idx]``. Each entry must
            provide ``image_id`` and ``y``, ``bbox_h``/``bbox_w`` in ``"pad"`` mode, and
            whatever ``get_scaled_bbox`` and ``prompt_transform`` read.
        prompt_transform: callable mapping an entry to its prompt text; required before
            items are fetched.
        img_size: side length in pixels of the returned square image.
        mode: ``"scale"`` resizes the crop straight to a square. ``"pad"`` resizes while
            preserving the aspect ratio, then pads to a square.

    Raises:
        RuntimeError: from ``__getitem__`` if ``prompt_transform`` is missing or ``mode`` is
            not ``"pad"``/``"scale"``.
    """

    def __init__(self, scene_graphs, prompt_transform=None, img_size=224, mode="scale"):
        if isinstance(scene_graphs, Mapping):
            # Dict views (e.g. `.items()`) are not indexable, so materialise the entries.
            self.scene_graphs = list(scene_graphs.values())
        else:
            self.scene_graphs = scene_graphs
        self.prompt_transform = prompt_transform
        self.img_size = img_size
        self.mode = mode

    def __len__(self):
        return len(self.scene_graphs)

    def _entry(self, idx):
        # DataFrames must be indexed positionally; `df[idx]` would look up a column.
        if hasattr(self.scene_graphs, "iloc"):
            return self.scene_graphs.iloc[idx]
        return self.scene_graphs[idx]

    def __getitem__(self, idx):
        if self.prompt_transform is None:
            raise RuntimeError("PromptDataset needs a prompt_transform to build prompts!")

        entry = self._entry(idx)
        # read_image returns a (C, H, W) tensor, so shape[1] is height and shape[2] width.
        image = read_image(f"../data/images/{entry['image_id']}.jpg", ImageReadMode.RGB)

        # crop bounding box
        y, x, h, w = get_scaled_bbox(entry, image.shape[1], image.shape[2])
        image = crop(image, y, x, h, w)

        if self.mode == "pad":
            # Resize the longer side to img_size, keeping the aspect ratio. The shorter side
            # is rounded to an even size so the padding below splits evenly.
            if entry["bbox_h"] > entry["bbox_w"]:
                resize_dimensions = (self.img_size, 2*round((self.img_size*entry["bbox_w"]/entry["bbox_h"])/2))
            else:
                resize_dimensions = (2*round((self.img_size*entry["bbox_h"]/entry["bbox_w"])/2), self.img_size)
            image = resize(image, resize_dimensions, antialias=True)

            # pad left/right and top/bottom to square dimensions
            image = pad(image, ((self.img_size - resize_dimensions[1])//2, (self.img_size - resize_dimensions[0])//2))

        elif self.mode == "scale":
            # resize the image to the target dimensions (aspect ratio not preserved)
            image = resize(image, (self.img_size, self.img_size), antialias=True)

        else:
            raise RuntimeError("Unsupported image processing mode!")

        return (image, self.prompt_transform(entry), entry['y'])

In [27]:

# Trainer configuration for the OWL-ViT run: three epochs, evaluated once per epoch.
# Note: this rebinds `training_args`, replacing the CLIP configuration defined above.
training_args = TrainingArguments(
    # checkpointing
    output_dir="owl_finetune",
    save_steps=2000,
    save_total_limit=5,
    # schedule and optimisation
    num_train_epochs=3,
    learning_rate=1e-7,
    warmup_ratio=0.2,
    # batching
    per_device_train_batch_size=8,
    # evaluation and logging
    evaluation_strategy="epoch",
    logging_first_step=True,
    # keep every sample field for the custom collation
    remove_unused_columns=False,
)
