## EPAlign Prompt and Vision Finetune

### Config

In [None]:
import os
import numpy as np
import torch
from PIL import Image
import torch
from torch import nn, optim
import pandas as pd
import clip
from torch.utils.data import Dataset, DataLoader, BatchSampler
from tqdm import tqdm
import logging

# Name of the dataset the model is fine-tuned on, e.g. "RAF" or "MELD".
DATASET = "RAF"  # MELD

# BATCH_SIZE should be smaller than or equal to the number of emotion
# categories, e.g. 7 for RAF-DB.
BATCH_SIZE = 7
EPOCH = 100
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Project root: the current working directory with its last two path
# components stripped off.
PROJECT_PATH = os.path.join("/", *os.getcwd().split(os.sep)[:-2])
# Name of the pretrained model, e.g. ViT-B/32.
PRETRAIN_MODEL = "ViT-B/32"
# Directory holding the pretrained model, e.g. EPAlign/ckpt/base.
PRETRAIN_MODEL_PATH = PROJECT_PATH + "/EPAlign/ckpt/base"
# RAF-DB image directory, e.g. data/RAF/compound/Image/original.
RAF_DATA_PATH = "/".join(
    [PROJECT_PATH, "data", DATASET, "compound", "Image", "original"]
)
# RAF-DB label file, e.g. data/RAF/compound/EmoLabel/list_patition_label.txt.
RAF_LABEL_PATH = "/".join(
    [PROJECT_PATH, "data", DATASET, "compound", "EmoLabel", "list_patition_label.txt"]
)
# Directory for log files, e.g. EPAlign/log.
LOG_PATH = PROJECT_PATH + "/EPAlign/log"
# Directory where checkpoints are saved, e.g. EPAlign/ckpt/RAF.
CKPT_PATH = "/".join([PROJECT_PATH, "EPAlign", "ckpt", DATASET])


### Use Preprocess

In [None]:
# Load the pretrained CLIP model (JIT disabled) together with the image
# preprocessing transform that clip.load returns alongside it; the download
# root for the weights is PRETRAIN_MODEL_PATH.
model, preprocess = clip.load(
    PRETRAIN_MODEL,
    device=device,
    jit=False,
    download_root=PRETRAIN_MODEL_PATH,
)

### Define Dataset

In [None]:
class RAFDataset(Dataset):
    """RAF-DB dataset yielding ``(image, emotion_text, label)`` triples.

    Args:
        data_path: Directory that contains the image files.
        mode: ``"train"`` keeps the rows whose file name contains "train";
            every other value keeps the rows whose file name contains "test".
        datalist: Path of the space-separated, header-less label file with one
            ``<file name> <label>`` row per image.
        preprocess: Optional callable applied to each opened image.
    """

    def __init__(self,
                 data_path="path/to/data",
                 mode="train",
                 datalist="path/to/label.txt",
                 preprocess=None):
        self.data_path = data_path
        self.mode = mode
        self.datalist = datalist
        self.preprocess = preprocess
        self.data = self.load_data()
        # A label-file path containing "compound" selects the 11 compound
        # emotions; any other path selects the 7 basic emotions.
        if 'compound' in self.datalist:
            emotion_names = (
                "Happily Surprised", "Happily Disgusted", "Sadly Fearful",
                "Sadly Angry", "Sadly Surprised", "Sadly Disgusted",
                "Fearfully Angry", "Fearfully Surprised", "Angrily Surprised",
                "Angrily Disgusted", "Disgustedly Surprised",
            )
        else:
            emotion_names = (
                "Surprise", "Fear", "Disgust", "Happiness",
                "Sadness", "Anger", "Neutral",
            )
        # Label ids are 1-based.
        self.label2text = dict(enumerate(emotion_names, start=1))

    def __len__(self):
        """Return the number of rows kept for the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, emotion_text, label)`` for the row at position ``idx``."""
        row = self.data.iloc[idx]
        img = Image.open('/'.join((self.data_path, row["path"])))
        if self.preprocess:
            img = self.preprocess(img)
        label = row["label"]
        return img, self.label2text[int(label)], label

    def load_data(self):
        """Read the label file and keep only the rows of the selected split."""
        table = pd.read_csv(self.datalist, sep=" ", header=None)
        table.columns = ["path", "label"]
        # Any mode other than "train" falls back to the test split.
        keyword = "train" if self.mode == "train" else "test"
        return table[table["path"].str.contains(keyword)]

# Build the train and test splits for the configured dataset.
if DATASET == "RAF":
    train_dataset = RAFDataset(mode='train', data_path=RAF_DATA_PATH, datalist=RAF_LABEL_PATH, preprocess=preprocess)
    test_dataset = RAFDataset(mode='test', data_path=RAF_DATA_PATH, datalist=RAF_LABEL_PATH, preprocess=preprocess)
else:
    # No other dataset is constructed here; fail with a clear message instead
    # of a NameError on the line below.
    raise NotImplementedError(f"no dataset loader is defined for DATASET={DATASET!r}")
# Cell output: number of train and test samples.
len(train_dataset), len(test_dataset)

### Define Batch Sampler (ensures no class appears twice in a batch)

In [None]:
class BalancedBatchSampler(BatchSampler):
    """Batch sampler that picks ``n_classes`` distinct classes and ``n_samples`` indices per class.

    Every yielded batch is a list of ``n_classes * n_samples`` dataset indices.
    The classes of a batch are drawn without replacement, so a class is never
    picked twice for the same batch.

    Args:
        labels: 1-D tensor of class labels, one per dataset item; it must
            support ``.numpy()``.
        n_classes: Number of distinct classes per batch. It must not exceed the
            number of distinct labels, otherwise ``np.random.choice`` raises
            while iterating.
        n_samples: Number of indices taken from each chosen class.

    Note:
        ``BatchSampler.__init__`` is not called; this class only provides its
        own ``__iter__`` and ``__len__``.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        # Convert once instead of once per class.
        labels_np = self.labels.numpy()
        self.labels_set = list(set(labels_np))
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Per class: how many of its (shuffled) indices have been handed out.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` (not `<`) so the number of yielded batches equals __len__()
        # even when the dataset size is an exact multiple of batch_size.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Not enough unused indices left for another draw from this
                # class: reshuffle it and start over.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Return the number of batches yielded per iteration."""
        return self.n_dataset // self.batch_size
    
# Read the labels straight from each dataset's label table. Iterating the
# dataset itself would open and preprocess every image just to get its label.
train_labels = torch.tensor(train_dataset.data["label"].tolist())
# Each batch holds one sample from each of BATCH_SIZE distinct classes.
train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)

test_labels = torch.tensor(test_dataset.data["label"].tolist())
test_sampler = BalancedBatchSampler(test_labels, BATCH_SIZE, 1)
test_dataloader = DataLoader(test_dataset, batch_sampler=test_sampler)

### Train Config

In [None]:
# https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient if present, to float32 in place.

    Parameters whose ``grad`` is ``None`` only have their data cast, instead
    of raising ``AttributeError``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

# Two cross-entropy criteria: one for the per-image logits and one for the
# per-text logits.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

# Every model parameter is optimised. Alternative kept for reference:
#   import itertools
#   parameters = itertools.chain(model.visual.parameters(), [model.logit_scale])
parameters = model.parameters()
lr = 1e-5

# Adam with its default betas / eps / weight decay. Alternative kept for reference:
#   optim.Adam(parameters, lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
optimizer = optim.Adam(parameters, lr=lr)
# Cosine annealing with T_max = batches per epoch * number of epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=len(train_dataloader) * EPOCH,
)


### Train Log

In [None]:
# The log directory must exist before the file handler opens its file.
os.makedirs(LOG_PATH, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Also write INFO-and-above records to a per-dataset log file, using the same
# record format as the basic configuration.
file_handler = logging.FileHandler(f"{LOG_PATH}/log_prompt_vision_{DATASET}.txt")
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setLevel(logging.INFO)

# Attach the file handler to the root logger.
log = logging.getLogger()
log.addHandler(file_handler)
log.info('finetune start...')

### Train

In [None]:
# Fine-tuning loop: one training pass and one evaluation pass per epoch; the
# weights with the lowest mean test loss seen so far are saved as best_model.pt.
best_te_loss = 1e5
best_ep = -1
os.makedirs(CKPT_PATH, exist_ok=True)
for epoch in range(EPOCH):
    log.info(f"running epoch {epoch}, best test loss {best_te_loss} after epoch {best_ep}")

    # ---- training pass ----
    step = 0
    tr_loss = 0
    model.train()
    pbar = tqdm(train_dataloader, leave=False)
    for batch in pbar:
        step += 1
        optimizer.zero_grad()

        images, texts, _ = batch
        images = images.to(device)
        texts = clip.tokenize(texts).to(device)
        logits_per_image, logits_per_text = model(images, texts)
        # The i-th image is paired with the i-th text. Size the targets from
        # the actual batch instead of assuming it always equals BATCH_SIZE.
        ground_truth = torch.arange(images.size(0), device=device)

        # Average of the image-side and text-side cross-entropy losses.
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        tr_loss += total_loss.item()
        if device == "cpu":
            optimizer.step()
        else:
            # Off-CPU path: cast parameters and gradients to fp32 for the
            # optimizer update, then let CLIP convert the weights back
            # (see https://github.com/openai/CLIP/issues/57).
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
        # The scheduler is stepped once per batch on either device.
        scheduler.step()
        pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
    tr_loss /= step

    # ---- evaluation pass ----
    step = 0
    te_loss = 0
    with torch.no_grad():
        model.eval()
        test_pbar = tqdm(test_dataloader, leave=False)
        for batch in test_pbar:
            step += 1
            images, texts, _ = batch
            images = images.to(device)
            texts = clip.tokenize(texts).to(device)
            logits_per_image, logits_per_text = model(images, texts)
            ground_truth = torch.arange(images.size(0), device=device)

            total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            te_loss += total_loss.item()
            test_pbar.set_description(f"test batchCE: {total_loss.item()}", refresh=True)
        te_loss /= step

    # Keep only the checkpoint with the lowest mean test loss.
    if te_loss < best_te_loss:
        best_te_loss = te_loss
        best_ep = epoch
        torch.save(model.state_dict(), f"{CKPT_PATH}/best_model.pt")
    log.info(f"epoch {epoch}, tr_loss {tr_loss}, te_loss {te_loss}")
    # Per-epoch checkpoints are disabled:
    # torch.save(model.state_dict(), f'{CKPT_PATH}/model_{epoch}.pt')
