In [1]:
# !git clone https://github.com/zengqunzhao/Exp-CLIP.git
!pip install ftfy salesforce-lavis

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting salesforce-lavis
  Downloading salesforce_lavis-1.0.2-py3-none-any.whl.metadata (18 kB)
Collecting contexttimer (from salesforce-lavis)
  Downloading contexttimer-0.3.3.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting decord (from salesforce-lavis)
  Downloading decord-0.6.0-py3-none-manylinux2010_x86_64.whl.metadata (422 bytes)
Collecting fairscale==0.4.4 (from salesforce-lavis)
  Downloading fairscale-0.4.4.tar.gz (235 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.4/235.4 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting iopath (from salesforce-lavis)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m

In [2]:
import sys
sys.path.insert(0, '/kaggle/working/Exp-CLIP')

import torch
from torch import nn
from models.clip import clip
# from models.BLIP2_T5 import *
import torch.nn.functional as F

import argparse
import time
import os
import random
import numpy as np
import datetime
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

import pickle
from tqdm import tqdm


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=512)
parser.add_argument('--batch-size-test-image', type=int, default=512)
parser.add_argument('--batch-size-test-video', type=int, default=64)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--print-freq', type=int, default=100)
parser.add_argument('--milestones', nargs='+', type=int, default=30)
parser.add_argument('--seed', type=int, default = 1)
parser.add_argument('--job-id', type=str, default=str(int(time.time())))
parser.add_argument('--instruction', type=str, default='Please play the role of a facial action describer. Objectively describe the detailed facial actions of the person in the image.')
parser.add_argument('--load-model', type=str, default='CLIP_L14')

# Use parse_known_args() to handle extra arguments
args, unknown = parser.parse_known_args()

random.seed(args.seed)  
np.random.seed(args.seed) 
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

now = datetime.datetime.now()
train_time = now.strftime("%y-%m-%d %H:%M")
print("Training date: ", train_time)
job_id = args.job_id

print('************************')
for k, v in vars(args).items():
    print(k,'=',v)
print('************************')
# Handle unknown arguments if needed
if unknown:
    print('Ignored Arguments:', unknown)

Training date:  25-01-24 10:32
************************
workers = 8
epochs = 5
batch_size = 512
batch_size_test_image = 512
batch_size_test_video = 64
lr = 0.001
weight_decay = 0.0001
momentum = 0.9
print_freq = 100
milestones = 30
seed = 1
job_id = 1737714723
instruction = Please play the role of a facial action describer. Objectively describe the detailed facial actions of the person in the image.
load_model = CLIP_L14
************************
Ignored Arguments: ['-f', '/root/.local/share/jupyter/runtime/kernel-84438eb5-1862-46c3-b91c-e6c57092cf47.json']


# RAF-DB Dataloader

In [16]:
# Define the dataset class
class CustomImageDataset(Dataset):
    def __init__(self, txt_file, img_dir, transform=None):
        self.img_labels = []
        self.img_dir = img_dir
        self.transform = transform

        with open(txt_file, 'r') as file:
            for line in file.readlines():
                filename, label = line.strip().split()
                name, ext = filename.strip().split('.')
                filename = name + '_aligned.' + ext
                self.img_labels.append((filename, int(label)))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [None]:

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dataset and DataLoader setup
txt_file = '/kaggle/input/rafdb-dg/EmoLabel/train_label.txt'
img_dir = '/kaggle/input/rafdb-dg/aligned/aligned'

dataset = CustomImageDataset(txt_file=txt_file, img_dir=img_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# CAER-S Dataloader

In [12]:
class FolderBasedDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        """
        Args:
            root_dirs (list): List of root directories containing class folders.
            transform (callable, optional): A function/transform to apply to the images.
        """
        self.img_labels = []  # To store (image_path, class_label)
        self.transform = transform
        self.labels = []
        self.img_names = []

        # Loop through each root directory
        for root_dir in root_dirs:
            # Iterate over the class folders
            for class_folder in sorted(os.listdir(root_dir)):  # Alphabetically sorted for consistent label assignment
                self.labels.append(class_folder)
                class_path = os.path.join((os.path.join(root_dir, class_folder)), class_folder)
                if os.path.isdir(class_path):  # Ensure it's a folder
                    # Assign the folder name as the class label
                    class_label = class_folder
                    # Collect all image paths in this folder
                    for img_name in os.listdir(class_path):
                        img_path = os.path.join(class_path, img_name)
                        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):  # Supported image formats
                            self.img_labels.append((img_path, class_label))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path, class_label = self.img_labels[idx]
        image = Image.open(img_path).convert('RGB')  # Convert to RGB

        if self.transform:
            image = self.transform(image)

        return image, class_label, img_path


In [13]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# List of root directories for CAERS datasets
root_dirs = [
    '/kaggle/input/caer-s-train-1',
    '/kaggle/input/caer-s-train-2',
    '/kaggle/input/caer-s-train-3'
]

# Initialize dataset
dataset = FolderBasedDataset(root_dirs=root_dirs, transform=transform)

# Initialize DataLoader
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)

# Inspect some samples
for images, labels, img_paths in data_loader:
    print("Batch size:", len(images))
    print("Labels:", labels[:3])  # Print first 3 labels in the batch
    print("Image paths:", img_paths[:3])  # Print first 3 image paths in the batch
    break


Batch size: 8
Labels: ('Neutral', 'Neutral', 'Fear')
Image paths: ('/kaggle/input/caer-s-train-3/Neutral/Neutral/4567.png', '/kaggle/input/caer-s-train-3/Neutral/Neutral/0897.png', '/kaggle/input/caer-s-train-2/Fear/Fear/6343.png')


# CLIP Extraction

In [6]:
# Load CLIP model
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-L/14", device)


100%|████████████████████████████████████████| 890M/890M [00:05<00:00, 172MiB/s]


In [10]:
# Output directory
output_dir = "/kaggle/working/clip_encodings_caer"
os.makedirs(output_dir, exist_ok=True)

# Save a chunk of data
def save_chunk(data, chunk_index):
    file_path = os.path.join(output_dir, f"encodings_chunk_{chunk_index}.pkl")
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Save label mapping
def save_label_mapping(label_mapping):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "wb") as f:
        pickle.dump(label_mapping, f)

# Generate a label mapping
all_labels = sorted(set(dataset.labels))  # Assuming `dataset.labels` contains all possible labels
label_mapping = {label: idx for idx, label in enumerate(all_labels)}

# Save the label mapping
save_label_mapping(label_mapping)

# Main loop to process data
chunk_size = 20000  # Number of entries per pickle file
chunk_data = {}     # Dictionary to store data indexed by filename
chunk_index = 0

# Check for progress file
progress_file = os.path.join(output_dir, "progress.txt")
start_index = 0
if os.path.exists(progress_file):
    with open(progress_file, "r") as f:
        start_index = int(f.read().strip())

print("Starting from batch index:", start_index)

Starting from batch index: 0


In [11]:
clip_model.eval()
with torch.no_grad():
    for batch_idx, (images, labels, img_names) in enumerate(tqdm(data_loader, desc="Processing batches")):
        if batch_idx < start_index:
            continue

        images = images.to(device)

        # Generate image encodings
        features = clip_model.encode_image(images)

        # Assume logit scale is part of the model (adjust if needed)
        logit_scale = clip_model.logit_scale.exp().item()

        for i, feature in enumerate(features):
            img_name = img_names[i]
            label = labels[i]  # String label, e.g., "Angry", "Sad"
            label_idx = label_mapping[label]  # Map to integer index

            # Add data to the dictionary, indexed by filename
            chunk_data[img_name] = {
                "encoding": feature.cpu().numpy(),
                "logit_scale": logit_scale,
                "label": label,
                "label_idx": label_idx,
            }

            # Save the chunk when the buffer is full
            if len(chunk_data) >= chunk_size:
                save_chunk(chunk_data, chunk_index)
                chunk_data = {}
                chunk_index += 1

        # Save progress to resume in case of interruption
        with open(progress_file, "w") as f:
            f.write(str(batch_idx))

# Save any remaining data
if chunk_data:
    save_chunk(chunk_data, chunk_index)

# Clean up progress file
if os.path.exists(progress_file):
    os.remove(progress_file)

print("Encoding complete and saved.")

Processing batches: 100%|██████████| 6126/6126 [22:57<00:00,  4.45it/s]

Encoding complete and saved.





In [36]:
output_dir = "/kaggle/working/clip_encodings_caer"

# Function to load all chunks of encodings
def load_encodings(output_dir):
    encodings = {}
    for file_name in os.listdir(output_dir):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk = pickle.load(f)
                encodings.update(chunk)
    return encodings

# Function to load the label mapping
def load_label_mapping(output_dir):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "rb") as f:
        label_mapping = pickle.load(f)
    return label_mapping

# Load encodings and label mapping
encodings = load_encodings(output_dir)
label_mapping = load_label_mapping(output_dir)

# Print basic info about the loaded data
print(f"Total images encoded: {len(encodings)}")
print(f"Label mapping: {label_mapping}")

Total images encoded: 49007
Label mapping: {'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}


# BLIP_2 Extraction

In [13]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

def get_blip2t5_model(device):
    model, _, _ = load_model_and_preprocess(
        name="blip2_t5",
        model_type="pretrain_flant5xl",
        is_eval=True,
        device=device,
    )
    model.generate = MethodType(generate, model)
    return model

blip_instruction = args.instruction

blip2 = get_blip2t5_model(device)



vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 1.89G/1.89G [00:08<00:00, 230MB/s] 
  state_dict = torch.load(cached_file, map_location="cpu")


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.45G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

100%|██████████| 407M/407M [00:01<00:00, 227MB/s]  
  checkpoint = torch.load(cached_file, map_location="cpu")


In [22]:
# Save a chunk of data indexed by image name
def save_chunk(data, chunk_index, output_dir):
    file_path = os.path.join(output_dir, f"encodings_chunk_{chunk_index}.pkl")
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Function to save text features in chunks, indexed by image name
def save_text_features_in_chunks(dataloader, blip2, tokens, output_dir, chunk_size=1000, batch_size=8):
    os.makedirs(output_dir, exist_ok=True)

    chunk_data = {}
    chunk_index = 0

    # Check for progress file
    progress_file = os.path.join(output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        with open(progress_file, "r") as f:
            start_index = int(f.read().strip())

    print("Starting from batch index:", start_index)

    for batch_idx, (images, labels, img_names) in enumerate(tqdm(dataloader, desc="Processing batches")):
        if batch_idx < start_index:
            continue

        images = images.to(device)

        # Regenerate tokens dynamically if this is the last batch and its size is smaller
        if len(images) < batch_size:
            tokens = get_tokens(blip_instruction, blip2.t5_tokenizer, len(images), device)

        # Generate text features
        _, text_features = blip2.generate({"image": images, "tokens": tokens})
        text_features = torch.mean(text_features, dim=1)
        text_features = text_features.float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for i, feature in enumerate(text_features):
            img_name = img_names[i]
            label = labels[i]
            chunk_data[img_name] = {
                "encoding": feature.cpu().numpy(),
                "label": label,
            }

            if len(chunk_data) >= chunk_size:
                save_chunk(chunk_data, chunk_index, output_dir)
                chunk_data = {}
                chunk_index += 1

        # Save progress
        with open(progress_file, "w") as f:
            f.write(str(batch_idx))

    # Save any remaining data
    if chunk_data:
        save_chunk(chunk_data, chunk_index, output_dir)

    # Clean up progress file
    if os.path.exists(progress_file):
        os.remove(progress_file)

    print("Encoding complete and saved.")

In [23]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
batch_size = 8

tokens = get_tokens(blip_instruction, blip2.t5_tokenizer, batch_size, device)

# Define transformations
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

# # Dataset and DataLoader setup
# txt_file = '/kaggle/input/rafdb-dg/EmoLabel/train_label.txt'
# img_dir = '/kaggle/input/rafdb-dg/aligned/aligned'

# dataset = CustomImageDataset(txt_file=txt_file, img_dir=img_dir, transform=transform)
# data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

output_dir = "/kaggle/working/blip2_encodings_caer"
chunk_size = 20000  # Number of encodings per pickle file

save_text_features_in_chunks(data_loader, blip2, tokens, output_dir, chunk_size=chunk_size, batch_size=batch_size)

Starting from batch index: 0


  return torch.cuda.amp.autocast(dtype=dtype)
Processing batches: 100%|██████████| 6126/6126 [1:34:38<00:00,  1.08it/s]

Encoding complete and saved.





# Consolidating Data

In [4]:
class ProcessedDataset(Dataset):
    def __init__(self, clip_encodings, blip_encodings):
        self.encodings = []
        self.labels = []
        
        for key in clip_encodings.keys():
            self.encodings.append((
                clip_encodings[key]['encoding'],
                blip_encodings[key]['encoding'],
                clip_encodings[key]['logit_scale']
            ))
            self.labels.append(clip_encodings[key]['label'])
            
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_features, text_features, logit_scale = self.encodings[idx]
        label = self.labels[idx]
        return image_features, text_features, logit_scale, label


In [5]:
# Function to load all chunks of encodings
def load_clip_encodings(output_dir):
    encodings = {}
    for file_name in os.listdir(output_dir):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk = pickle.load(f)
                encodings.update(chunk)
    return encodings

def load_blip_encodings(output_dir):
    encodings = {}
    for file_name in sorted(os.listdir(output_dir)):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            print(f'processing {file_name}')
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk_data = pickle.load(f)
                encodings.update(chunk_data)
    return encodings

# Function to load the label mapping
def load_label_mapping(output_dir):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "rb") as f:
        label_mapping = pickle.load(f)
    return label_mapping

clip_encodings = load_clip_encodings('/kaggle/working/clip_encodings_caer')
blip_encodings = load_blip_encodings('/kaggle/working/blip2_encodings_caer')

label_mapping = load_label_mapping('/kaggle/working/clip_encodings_caer')

# Print basic info about the loaded data
print(f"Total images encoded: {len(clip_encodings)}")
print(f"Label mapping: {label_mapping}")

processing encodings_chunk_0.pkl
processing encodings_chunk_1.pkl
processing encodings_chunk_2.pkl
Total images encoded: 49007
Label mapping: {'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}


In [72]:
dataset = ProcessedDataset(clip_encodings, blip_encodings)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [74]:
class ExpCLIP_Train(nn.Module):
    def __init__(self, args):
        super().__init__()

        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    
        self.projection_head = Linear_Matrix_L14()

    def forward(self, image_features):
        image_features = image_features.float()
        image_features = self.projection_head(image_features)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        
        return image_features

class Linear_Matrix_L14(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(768, 2048, bias=False)
    def forward(self, x):
        return self.mlp(x)


In [75]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix="", log_txt_path=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.log_txt_path = log_txt_path

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print_txt = '\t'.join(entries)
        print(print_txt)
        with open(self.log_txt_path, 'a') as f:
            f.write(print_txt + '\n')

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class RecorderMeter(object):
    """Computes and stores the minimum loss value and its epoch index"""
    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32)    # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)  # [epoch, train/val]

    def update(self, idx, train_loss, train_acc):
        self.epoch_losses[idx, 0] = train_loss * 50
        self.epoch_accuracy[idx, 0] = train_acc
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):

        title = 'the accuracy/loss curve of train/val'
        dpi = 80
        width, height = 1600, 800
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.array([i for i in range(self.total_epoch)])  # epochs
        y_axis = np.zeros(self.total_epoch)

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 1
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel('the training epoch', fontsize=16)
        plt.ylabel('accuracy', fontsize=16)

        y_axis[:] = self.epoch_accuracy[:, 0]
        plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_accuracy[:, 1]
        plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 0]
        plt.plot(x_axis, y_axis, color='g', linestyle=':', label='train-loss-x50', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 1]
        plt.plot(x_axis, y_axis, color='y', linestyle=':', label='valid-loss-x50', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        plt.close(fig)

In [80]:
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def train(train_loader, model, criterion, optimizer, epoch, args, log_txt_path):
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Accuracy', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [losses, top1],
                             prefix="Epoch: [{}]".format(epoch),
                             log_txt_path=log_txt_path)

    # switch to train mode
    model.train()

    # Use tqdm to show progress
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch}", unit="batch") as pbar:
        for i, (image_features, text_features, logit_scale, labels) in enumerate(train_loader):
            image_features = image_features.to(device)
            text_features = text_features.to(device)
            logit_scale = logit_scale[0]
            logit_scale = logit_scale.to(device)
            
            n, _ = image_features.shape
            target = torch.arange(n)
            target = target.to(device)

            # compute output
            image_features = model(image_features)
            
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

            loss_vision = criterion(logits_per_image, target)
            loss_text = criterion(logits_per_text, target)
            
            loss = 0.5 * loss_vision + 0.5 * loss_text
            
            # measure accuracy and record loss
            acc1, _ = accuracy(logits_per_image, target, topk=(1, 3))
            losses.update(loss.item(), image_features.size(0))
            top1.update(acc1[0], image_features.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update tqdm progress bar
            pbar.set_postfix({"Loss": losses.avg, "Accuracy": top1.avg.item()})
            pbar.update(1)
            
            # Optionally display progress at intervals
            # if i % args.print_freq == 0:
            #     progress.display(i)
            
    return top1.avg, losses.avg


In [85]:
log_txt_path = './log/' + job_id + '-log.txt'
log_curve_path = './log/' + job_id + '-log.png'
checkpoint_path = './checkpoint/' + job_id

os.makedirs('log', exist_ok = True)
os.makedirs('checkpoint', exist_ok = True)

recorder = RecorderMeter(args.epochs)

In [86]:
import torch.backends.cudnn as cudnn

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

model = ExpCLIP_Train(args)
# define loss function (criterion)
criterion = nn.CrossEntropyLoss().cuda()

# only open learnable part
for name, param in model.named_parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if "projection_head" in name:
        param.requires_grad = True 

model = model.to(device)
        
# define optimizer
optimizer = torch.optim.SGD([{"params": model.projection_head.parameters(), "lr": args.lr}],
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

# define scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.milestones], gamma=0.1)
cudnn.benchmark = True

In [94]:
for epoch in range(0, args.epochs):

    inf = '********************' + str(epoch) + '********************'
    start_time = time.time()
    current_learning_rate = optimizer.state_dict()['param_groups'][0]['lr']
    with open(log_txt_path, 'a') as f:
        f.write(inf + '\n')
        print(inf)
        f.write('Current learning rate: ' + str(current_learning_rate) + '\n')      
        
    # train for one epoch
    train_acc, train_los = train(dataloader, model, criterion, optimizer, epoch, args, log_txt_path)
    scheduler.step()

    # print and save log
    epoch_time = time.time() - start_time
    recorder.update(epoch, train_los, train_acc)
    recorder.plot_curve(log_curve_path)
    print('The train accuracy: {:.3f}'.format(train_acc.item()))
    print('An epoch time: {:.2f}s'.format(epoch_time))
    with open(log_txt_path, 'a') as f:
        f.write('The best accuracy: ' + str(train_acc.item()) + '\n')
        f.write('An epoch time: ' + str(epoch_time) + 's' + '\n')

#  save model and conduct zero-shot prediction
checkpoint_name = checkpoint_path + '-model.pth'
torch.save(model.projection_head.state_dict(), checkpoint_name)

********************0********************


Epoch 0: 100%|██████████| 96/96 [00:09<00:00, 10.66batch/s, Loss=0.705, Accuracy=89.7]


The train accuracy: 89.722
An epoch time: 9.02s
********************1********************


Epoch 1: 100%|██████████| 96/96 [00:09<00:00, 10.58batch/s, Loss=0.702, Accuracy=89.7]


The train accuracy: 89.701
An epoch time: 9.08s
********************2********************


Epoch 2: 100%|██████████| 96/96 [00:09<00:00,  9.79batch/s, Loss=0.701, Accuracy=89.9]


The train accuracy: 89.887
An epoch time: 9.81s
********************3********************


Epoch 3: 100%|██████████| 96/96 [00:09<00:00, 10.54batch/s, Loss=0.699, Accuracy=89.8]


The train accuracy: 89.824
An epoch time: 9.12s
********************4********************


Epoch 4: 100%|██████████| 96/96 [00:09<00:00, 10.66batch/s, Loss=0.7, Accuracy=89.8]  


The train accuracy: 89.757
An epoch time: 9.01s


# Testing

In [3]:
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import matplotlib
matplotlib.use('Agg')
import numpy as np
from data_loader.video_dataloader import test_data_loader
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from models.Text import *
from models.Exp_CLIP import ExpCLIP_Test
import argparse
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import itertools

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [4]:
import time

parser = argparse.ArgumentParser()
parser.add_argument('--load-model', type=str, default='CLIP_L14')
parser.add_argument('--job-id', type=str, default=str(int(time.time())))  # Default job ID is a timestamp

# Use parse_known_args() to handle extra arguments
args, unknown = parser.parse_known_args()

# pretrain_model_path = './checkpoint/' + args.job_id + "-model.pth"
pretrain_model_path = '/kaggle/working/checkpoint/1737699389-model.pth'
# pretrain_model_path  = '/kaggle/working/Exp-CLIP/checkpoint/ExpCLIP_L14_model.pth'

print('************************')
for k, v in vars(args).items():
    print(k, '=', v)
print('************************')

************************
load_model = CLIP_L14
job_id = 1737950964
************************


In [5]:
class ExpCLIP_Test(nn.Module):
    def __init__(self, args):
        super().__init__()

        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    
        if args.load_model == 'CLIP_B32':
            self.clip_model, _ = clip.load("ViT-B/32", device)
        elif args.load_model == 'CLIP_B16':
            self.clip_model, _ = clip.load("ViT-B/16", device)
        elif args.load_model == 'CLIP_L14':
            self.clip_model, _ = clip.load("ViT-L/14", device)

        if args.load_model == 'CLIP_L14':
            self.projection_head = Linear_Matrix_L14()
        else:
            self.projection_head = Linear_Matrix()

    def forward(self, image, text=None, mode_task=None):
        
        if mode_task=='Static_FER':
            image_features = self.clip_model.encode_image(image)
            image_features = image_features.float()
            image_features = self.projection_head(image_features)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        elif mode_task=='Dynamic_FER':
            n, t, c, h, w = image.shape
            image = image.contiguous().view(-1, c, h, w)
            image_features = self.clip_model.encode_image(image)
            image_features = image_features.float()
            image_features = self.projection_head(image_features)
            image_features = image_features.reshape(n, t, -1)
            image_features = torch.mean(image_features, dim=1)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_tokenized = clip.tokenize(text, context_length=77, truncate=True).to('cuda')
        text_features = self.clip_model.encode_text(text_tokenized)
        text_features = text_features.float()
        text_features = self.projection_head(text_features)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        logit_scale = self.clip_model.logit_scale.exp()

        return logit_scale, image_features, text_features


class Linear_Matrix_L14(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(768, 2048, bias=False)
        # self.mlp = nn.Linear(768, 4096, bias=False)
    def forward(self, x):
        return self.mlp(x)

In [6]:
# create model and load pre_trained parameters
model = ExpCLIP_Test(args)
# model = torch.nn.DataParallel(model).cuda() 
state_dict = model.state_dict()
pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )
# pre_train_model = torch.load(pretrain_model_path)
for name, param in pre_train_model.items():
    if "mlp.weight" in name:
        state_dict["projection_head.mlp.weight"].copy_(param)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

100%|████████████████████████████████████████| 890M/890M [00:09<00:00, 101MiB/s]
  pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )


ExpCLIP_Test(
  (clip_model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
            )
            (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamic

In [7]:
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, txt_file, img_dir, transform=None):
        """
        Args:
            txt_file (str): Path to the .txt file with image filenames and labels.
            img_dir (str): Path to the directory containing images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.img_labels = []
        self.img_dir = img_dir
        self.transform = transform

        # Read the txt file to get the image paths and labels
        with open(txt_file, 'r') as file:
            for line in file.readlines():
                # Assuming format: image_filename label
                filename, label = line.strip().split()
                name,ext = filename.strip().split('.')                #imagename is name.jpg  --->  name_algined.jpg
                filename  = name + '_aligned.' + ext
                self.img_labels.append((filename, int(label)-1))  # Store as a tuple (image_filename, label)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)  # Construct the full image path
        image = Image.open(img_path).convert('RGB')  # Open image

        if self.transform:
            image = self.transform(image)  # Apply the transformations

        return image, label

# Define the transformations (similar to the ones in ImageFolder)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Example usage
test_data_path = '/kaggle/input/rafdb-dg/aligned/aligned'  # Path to the folder containing images
test_txt_file = '/kaggle/input/rafdb-dg/EmoLabel/train_label.txt'  # Path to the text file listing image names and labels

test_dataset = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path, transform=transform)

# Create DataLoader for batching and shuffling
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [8]:
test_data = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path,
                              transform=transforms.Compose([transforms.Resize((224, 224)),
                                                 transforms.ToTensor()]))

test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=4,
                                              shuffle=False,
                                              num_workers=8,
                                              pin_memory=True)

In [44]:
with torch.no_grad():
    for i, (images, target) in enumerate(test_loader):
    
        images = images.to(device)
        target = target.to(device)
        break

FER_prompt_ = {'RAFDB':prompt}
zero_shot_prompt = FER_prompt_['RAFDB']
n,_,_,_ = images.shape
logit_scale, image_features, text_features = model(image=images,text=zero_shot_prompt, mode_task="Static_FER") 

In [45]:
prompt_number = int(len(zero_shot_prompt) / 7)

output = logit_scale * image_features @ text_features.t()
output = output.view(n, -1, prompt_number)
output = torch.mean(output, dim=-1)
predicted = output.argmax(dim=1, keepdim=True)

predicted

tensor([[5],
        [4],
        [3],
        [6]], device='cuda:0')

In [46]:
target

tensor([4, 4, 3, 3], device='cuda:0')

In [22]:
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import itertools
import matplotlib.pyplot as plt
import numpy as np

def zero_shot_test(set=0, dataset_=None, mode_task=None, FER_prompt_=None, prompt_type=None):

    DATASET_PATH_MAPPING = {
        "RAFDB": "/kaggle/input/rafdb-dg/aligned/aligned",
        "AffectNet7": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/AffectNet7_Face/test/",
        "AffectNet8": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/AffectNet8_Face/test/",
        "FERPlus": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/FERPlus_Face/test/",
        "DFEW": "./annotation/DFEW_set_"+str(set+1)+"_test.txt",
        "FERV39k": "./annotation/FERV39k_test.txt",
        "MAFW": "./annotation/MAFW_set_"+str(set+1)+"_test.txt",
        "AFEW": "./annotation/AFEW_validation.txt",
    }
    test_data_path = DATASET_PATH_MAPPING[dataset_]
    zero_shot_prompt = FER_prompt_[dataset_]
    
    if dataset_ in ["RAFDB", "AffectNet7", "DFEW", "FERV39k", "AFEW"]:
        prompt_number = int(len(zero_shot_prompt) / 7)
    elif dataset_ in ["AffectNet8", "FERPlus"]:
        prompt_number = int(len(zero_shot_prompt) / 8)
    elif dataset_ in ["MAFW"]:
        prompt_number = int(len(zero_shot_prompt) / 11)

    if mode_task == "Static_FER":
        batch_size_ = 512
        test_data = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path,
                                       transform=transforms.Compose([transforms.Resize((224, 224)),
                                                                     transforms.ToTensor()]))
        confusion_matrix_path = "./confusion_matrix/"+args.load_model+"-"+dataset_+'-'+prompt_type+'.pdf'
    elif mode_task == "Dynamic_FER":
        batch_size_ = 64
        test_data = test_data_loader(list_file=test_data_path,
                                     num_segments=16,
                                     duration=1,
                                     image_size=224)
        confusion_matrix_path = "./confusion_matrix/"+args.load_model+"-"+dataset_+ '-' + str(set)+'-'+prompt_type+'.pdf'

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size_,
                                              shuffle=False,
                                              num_workers=8,
                                              pin_memory=True)
    correct = 0

    # Initialize variables for metrics
    all_predicted = []
    all_targets = []

    # Add a progress bar using tqdm
    with tqdm(total=len(test_loader), desc="Testing Progress", unit="batch") as pbar:
        with torch.no_grad():
            for i, (images, target) in enumerate(test_loader):
                images = images.to(device)
                target = target.to(device)
                
                n, _, _, _ = images.shape
                logit_scale, image_features, text_features = model(image=images, text=zero_shot_prompt, mode_task="Static_FER")

                output = logit_scale * image_features @ text_features.t()
                output = output.view(n, -1, prompt_number)
                output = torch.mean(output, dim=-1)

                predicted = output.argmax(dim=1, keepdim=True) % 6
                correct += predicted.eq(target.view_as(predicted)).sum().item()

                all_predicted.extend(predicted.cpu().numpy().flatten())
                all_targets.extend(target.cpu().numpy().flatten())

                pbar.update(1)

    # Calculate accuracy
    accuracy = 100. * correct / len(test_loader.dataset)
    
    # Calculate F1 score, precision, and recall
    f1 = f1_score(all_targets, all_predicted, average='weighted')
    precision = precision_score(all_targets, all_predicted, average='weighted')
    recall = recall_score(all_targets, all_predicted, average='weighted')

    # Compute confusion matrix
    _confusion_matrix = confusion_matrix(all_targets, all_predicted)
    np.set_printoptions(precision=4)
    normalized_cm = _confusion_matrix.astype('float') / _confusion_matrix.sum(axis=1)[:, np.newaxis]
    normalized_cm = normalized_cm * 100
    list_diag = np.diag(normalized_cm)
    uar = list_diag.mean()

    # Print metrics
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"UAR: {uar:.4f}")

    # Plot normalized confusion matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(normalized_cm, interpolation='nearest', cmap=plt.cm.Reds)
    plt.colorbar()
    tick_marks = np.arange(len(Emotion_name_dic[dataset_]))
    plt.xticks(tick_marks, Emotion_name_dic[dataset_], rotation=45)
    plt.yticks(tick_marks, Emotion_name_dic[dataset_])

    fmt = '.2f'
    thresh = normalized_cm.max() / 2.
    for i, j in itertools.product(range(normalized_cm.shape[0]), range(normalized_cm.shape[1])):
        plt.text(j, i, format(normalized_cm[i, j], fmt), fontsize=12,
                 horizontalalignment="center",
                 color="white" if normalized_cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=18)
    plt.xlabel('Predicted label', fontsize=18)
    plt.tight_layout()
    plt.close()
    
    return uar, accuracy


In [23]:
def zero_shot_test_FER(FER_prompt_,type_):

    datasets_ = ["RAFDB",]
    
    for dataset in datasets_:
        uar, war = zero_shot_test(dataset_=dataset, mode_task="Static_FER", FER_prompt_=FER_prompt_, prompt_type=type_)
        # print(f'************************* {dataset}')
        # print(f"UAR/WAR: {uar:.2f}/{war:.2f}")

In [24]:
class RecorderMeter(object):
    pass    

In [25]:
# prompt = ['Surprise',
# 'Fear',
# 'Disgust',
# 'Happiness',
# 'Sadness',
# 'Anger',
# 'Neutral']
# # for key in label_mapping.keys():
# #     prompt.append(key)

# prompt

emotion_words = [
    ["Surprise", "Amazement", "Astonishment", "Wonder", "Shock", "Bewilderment"],
    ["Fear", "Terror", "Anxiety", "Dread", "Panic", "Apprehension"],
    ["Disgust", "Revulsion", "Contempt", "Loathing", "Repulsion", "Aversion"],
    ["Happiness", "Joy", "Delight", "Bliss", "Contentment", "Elation"],
    ["Sadness", "Sorrow", "Grief", "Melancholy", "Despair", "Heartache"],
    ["Anger", "Rage", "Fury", "Wrath", "Annoyance", "Resentment"],
    ["Neutral", "Indifference", "Apathy", "Calmness", "Detachment", "Equanimity"]
]

In [26]:
prompt = []
for emotions in emotion_words:
    prompt += emotions

prompt_list = [prompt,]

prompt_list.append([f'an expression of {p}' for p in prompt])
prompt_list.append([f'a photo of a face with an expression of {p}' for p in prompt])

In [27]:
for p in prompt_list:
    zero_shot_test_FER({'RAFDB':p}, 'type1')

Testing Progress: 100%|██████████| 24/24 [05:18<00:00, 13.27s/batch]


Accuracy: 57.40%
F1 Score: 0.5128
Precision: 0.5274
Recall: 0.5740
UAR: 47.4929


Testing Progress: 100%|██████████| 24/24 [05:19<00:00, 13.33s/batch]


Accuracy: 61.35%
F1 Score: 0.5296
Precision: 0.4848
Recall: 0.6135
UAR: 47.7489


Testing Progress: 100%|██████████| 24/24 [05:19<00:00, 13.32s/batch]

Accuracy: 47.87%
F1 Score: 0.4588
Precision: 0.5156
Recall: 0.4787
UAR: 40.9927





In [21]:
for i, FER_prompt in enumerate(FER_prompt_list):
    print(f'************************************************************************** Zero-shot Prompt Type: ', FER_prompt_type_list[i])
    type_ = "type"+str(i+1)
    print(FER_prompt['RAFDB'][0])
    # zero_shot_test_FER(FER_prompt, type_)



************************************************************************** Zero-shot Prompt Type:  Class Name
happiness.
************************************************************************** Zero-shot Prompt Type:  An Expression of Name
an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  A Photo of A Face with An Expression of Name
a photo of a face with an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  Expression Ensemble Five
an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  Expression Ensemble Ten
an expression of happiness.
