## Fine-tune CLIP on a fashion dataset

In [None]:
import random
import torch
from PIL import Image
from matplotlib import pyplot as plt
from transformers import CLIPProcessor, CLIPModel

import model as fashion_clip

In [None]:
# define training arguments

# run on the GPU when torch reports CUDA as available, otherwise on the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# values handed to FashionCLIPTrainer in the training cell
# NOTE(review): LR / WD are presumably learning rate and weight decay, and
# patience presumably an early-stopping patience -- confirm in model.py
BATCH_SIZE, EPOCH = 64, 4
LR, WD = 1e-6, 1e-4
patience = 2

# when True, the data-loading cell writes reduced-size copies of the datasets
adjust_data_size = False

# directory passed to the trainer for its outputs / directory image paths are relative to
save_dir, data_dir = "./results/", "./data/"

## Load data and model

In [None]:
# Paths of the dataset splits that will be loaded below. They start out as the
# full-size files and are redirected to the reduced copies when sampling is on.
train_path = "./data/train_data.json"
val_path = "./data/val_data.json"
test_path = "./data/test_data.json"

# sample smaller datasets for quick test, with a balanced distribution for each class label
if adjust_data_size:
    small_train_path = "./data/small_train.json"
    small_val_path = "./data/small_val.json"
    small_test_path = "./data/small_test.json"

    # NOTE(review): the last argument is the requested sample size; whether it
    # is a total or a per-class count is defined in model.py -- confirm there
    fashion_clip.adjust_dataset_size(train_path, small_train_path, 200)
    fashion_clip.adjust_dataset_size(val_path, small_val_path, 40)
    fashion_clip.adjust_dataset_size(test_path, small_test_path, 40)

    # actually use the subsets that were just written; previously the full
    # files were loaded regardless, so the flag had no effect on the run
    train_path, val_path, test_path = small_train_path, small_val_path, small_test_path

# load data
train_data = fashion_clip.load_data(train_path)
val_data = fashion_clip.load_data(val_path)
test_data = fashion_clip.load_data(test_path)

# Unique class labels of the training split. Sorted because set iteration order
# for strings differs between interpreter runs, which would otherwise change
# the label -> index mapping used by the evaluation cells from run to run.
labels = sorted({data["class_label"] for data in train_data})

In [None]:
# load model
# the model weights and the matching preprocessor come from the same checkpoint
checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)
model = model.to(DEVICE)


def _materialize(records):
    """Build the dataset for *records* and hold all of its items in a list."""
    return list(fashion_clip.get_dataset(records, data_dir, processor, DEVICE))


# create dataset
train_dataset = _materialize(train_data)
val_dataset = _materialize(val_data)
test_dataset = _materialize(test_data)

## Evaluation before training

In [None]:
# evaluation before training as a baseline
# evaluation task: prediction on class label

# one text prompt per class label (the space before the label was missing,
# which glued the label onto the word "of")
texts = [f"a photo of {cl}" for cl in labels]

# top3 prediction for a single image
# randomly choose an example of the testset; randrange excludes the upper
# bound, whereas randint(0, len(test_data)) could return len(test_data) and
# make the lookups below raise IndexError
index = random.randrange(len(test_data))
image = Image.open(data_dir + test_data[index]['image_path'])
gold_label = test_data[index]["class_label"]        # gold label of the chosen example

# get text and image features; gradients are not needed for evaluation
# (harmless if the helper already disables them itself)
with torch.no_grad():
    text_features, image_feature = fashion_clip.get_features(texts, image, model, processor, DEVICE)
fashion_clip.make_single_prediction(text_features, image_feature, 3, labels)
print(f"Correct label: {gold_label}")

# top1 precision for all images in test data
# Image.open is lazy and keeps the file open until the pixels are read, so
# opening the whole test set at once can run out of file descriptors. Copy
# each image into memory and close its file straight away instead.
images = []
for data in test_data:
    with Image.open(data_dir + data['image_path']) as img:
        images.append(img.copy())
all_gold_labels = [labels.index(data["class_label"]) for data in test_data]     # gold labels of all images in the testset

# we're using the same texts so no need to recalculate text features here
with torch.no_grad():
    image_features = fashion_clip.image_features(images, model, processor, DEVICE)
fashion_clip.make_full_prediction(text_features, image_features, all_gold_labels, 1)

## Train

In [None]:
# use Huggingface's Trainer to finetune the model, the best model will be saved in save_dir
# positional arguments for the trainer wrapper, in the order it is called with
trainer_args = (
    model,
    train_dataset,
    val_dataset,
    save_dir,
    LR,
    WD,
    patience,
    BATCH_SIZE,
    EPOCH,
)
clip_trainer = fashion_clip.FashionCLIPTrainer(*trainer_args)

# start fine-tuning
clip_trainer.trainer.train()

In [None]:
# run the trainer's evaluation; as the last expression of the cell, its return
# value is what the notebook displays
# NOTE(review): presumably this scores val_dataset (the dataset passed to
# FashionCLIPTrainer above) -- confirm in model.py
clip_trainer.trainer.evaluate()

In [None]:
# plot train and val loss
log_history = clip_trainer.trainer.state.log_history

# every log entry except the final one is considered; an entry contributes to
# a curve only if it carries the corresponding key
considered = log_history[:-1]
train_losses = [entry["loss"] for entry in considered if "loss" in entry]
eval_losses = [entry["eval_loss"] for entry in considered if "eval_loss" in entry]

# draw both curves on the same axes
for series, name in ((train_losses, "train loss"), (eval_losses, "val loss")):
    plt.plot(series, label=name)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# write the figure next to the other training outputs
plt.savefig(save_dir+"loss.png")

## Evaluation after training

In [None]:
# Trainer.train() can leave the model in training mode; switch to eval mode so
# the comparison with the baseline does not depend on which cells were run
# in between.
model.eval()

# top 3 prediction of a single image
# (same image, prompts and gold label as in the baseline cell)
# gradients are not needed for evaluation; harmless if the helper already
# disables them itself
with torch.no_grad():
    text_features, image_feature = fashion_clip.get_features(texts, image, model, processor, DEVICE)
fashion_clip.make_single_prediction(text_features, image_feature, 3, labels)
print(f"Correct label: {gold_label}")

# top 1 accuracy of all images
with torch.no_grad():
    image_features = fashion_clip.image_features(images, model, processor, DEVICE)
fashion_clip.make_full_prediction(text_features, image_features, all_gold_labels, 1)