In [224]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
import random
import pandas as pd
import os
import ast

In [226]:
class TripletDataset(Dataset):
    def __init__(self, path):
        self.anchor_positive_negative_triplets = []
        
        df = pd.read_csv(path, sep='\t')
        for index, row in df.iterrows():
            expected_order = ast.literal_eval(row["expected_order"])
            for i in range(0, 5):
                for j in range(i+1, 5):
                    self.anchor_positive_negative_triplets.append((row["sentence"], 
                                                                   os.path.join("train", row["compound"].replace("'s", "_s"), expected_order[i]), 
                                                                   os.path.join("train", row["compound"].replace("'s", "_s"), expected_order[j])))
        

    def __len__(self):
        return len(self.anchor_positive_negative_triplets)

    def __getitem__(self, idx):
        anchor_text, pos_img_path, neg_img_path = self.anchor_positive_negative_triplets[idx]
        pos_img = Image.open(pos_img_path).convert('RGB')
        neg_img = Image.open(neg_img_path).convert('RGB')

        return (anchor_text, pos_img, neg_img)

In [228]:
def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margin=0.3):
    pos_sim = torch.nn.functional.cosine_similarity(anchor_embedding, positive_embedding)
    neg_sim = torch.nn.functional.cosine_similarity(anchor_embedding, negative_embedding)
    loss = torch.relu(margin + neg_sim - pos_sim).mean()
    return loss

In [230]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [231]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    #task_type=TaskType.FEATURE_EXTRACTION
)
model = get_peft_model(model, lora_config)
print(model.__class__)

<class 'peft.peft_model.PeftModel'>


In [232]:
def collate_fn(batch):
    texts = [item[0] for item in batch]
    pos_images = [item[1] for item in batch]
    neg_images = [item[2] for item in batch]

    inputs_pos = processor(text=texts, images=pos_images, return_tensors='pt', padding=True, truncation=True)
    inputs_neg = processor(text=texts, images=neg_images, return_tensors='pt', padding=True, truncation=True)
    return inputs_pos, inputs_neg

In [238]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
dataset = TripletDataset("train/subtask_a_train.tsv")
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
model.train()
for epoch in range(15):
    for inputs_pos, inputs_neg in loader:
        outputs_pos = model(**inputs_pos)
        outputs_neg = model(**inputs_neg)

        anchor_emb = outputs_pos.text_embeds
        pos_emb = outputs_pos.image_embeds
        neg_emb = outputs_neg.image_embeds

        loss = triplet_loss(anchor_emb, pos_emb, neg_emb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} - Loss: {loss.item():.4f}")

Epoch 1 - Loss: 0.2813
Epoch 2 - Loss: 0.2383
Epoch 3 - Loss: 0.2174
Epoch 4 - Loss: 0.1551
Epoch 5 - Loss: 0.2620
Epoch 6 - Loss: 0.1489
Epoch 7 - Loss: 0.1330
Epoch 8 - Loss: 0.0616
Epoch 9 - Loss: 0.0859
Epoch 10 - Loss: 0.0889
Epoch 11 - Loss: 0.1457
Epoch 12 - Loss: 0.0298
Epoch 13 - Loss: 0.0781
Epoch 14 - Loss: 0.0878
Epoch 15 - Loss: 0.0169
