In [1]:
# !git clone https://github.com/zengqunzhao/Exp-CLIP.git
!pip install ftfy salesforce-lavis

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting salesforce-lavis
  Downloading salesforce_lavis-1.0.2-py3-none-any.whl.metadata (18 kB)
Collecting contexttimer (from salesforce-lavis)
  Downloading contexttimer-0.3.3.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting decord (from salesforce-lavis)
  Downloading decord-0.6.0-py3-none-manylinux2010_x86_64.whl.metadata (422 bytes)
Collecting fairscale==0.4.4 (from salesforce-lavis)
  Downloading fairscale-0.4.4.tar.gz (235 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.4/235.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting iopath (from salesforce-lavis)
  Downloading iopath-0.1.10.tar.gz (42 

In [2]:
import sys
sys.path.insert(0, '/kaggle/working/Exp-CLIP')

import torch
from torch import nn
from models.clip import clip
from models.BLIP2_T5 import *
import torch.nn.functional as F

import argparse
import time
import os
import random
import numpy as np
import datetime
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

import pickle
from tqdm import tqdm

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


# Start

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=512)
parser.add_argument('--batch-size-test-image', type=int, default=512)
parser.add_argument('--batch-size-test-video', type=int, default=64)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--print-freq', type=int, default=100)
parser.add_argument('--milestones', nargs='+', type=int, default=30)
parser.add_argument('--seed', type=int, default = 1)
parser.add_argument('--job-id', type=str, default=str(int(time.time())))
parser.add_argument('--instruction', type=str, default='Please play the role of a facial action describer. Objectively describe the detailed facial actions of the person in the image.')
parser.add_argument('--load-model', type=str, default='CLIP_L14')

# Use parse_known_args() to handle extra arguments
args, unknown = parser.parse_known_args()

random.seed(args.seed)  
np.random.seed(args.seed) 
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

now = datetime.datetime.now()
train_time = now.strftime("%y-%m-%d %H:%M")
print("Training date: ", train_time)
job_id = args.job_id

print('************************')
for k, v in vars(args).items():
    print(k,'=',v)
print('************************')
# Handle unknown arguments if needed
if unknown:
    print('Ignored Arguments:', unknown)

Training date:  25-02-17 14:37
************************
workers = 8
epochs = 5
batch_size = 512
batch_size_test_image = 512
batch_size_test_video = 64
lr = 0.001
weight_decay = 0.0001
momentum = 0.9
print_freq = 100
milestones = 30
seed = 1
job_id = 1739803077
instruction = Please play the role of a facial action describer. Objectively describe the detailed facial actions of the person in the image.
load_model = CLIP_L14
************************
Ignored Arguments: ['-f', '/root/.local/share/jupyter/runtime/kernel-ba63bfcc-aaa2-4c37-b72c-0a507a6644f8.json']


# RAF-DB Dataloader

In [18]:
# Define the dataset class
class CustomImageDataset(Dataset):
    def __init__(self, txt_file, img_dir, transform=None):
        self.img_labels = []
        self.img_dir = img_dir
        self.transform = transform

        with open(txt_file, 'r') as file:
            for line in file.readlines():
                filename, label = line.strip().split()
                name, ext = filename.strip().split('.')
                filename = name + '_aligned.' + ext
                self.img_labels.append((filename, int(label)))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [19]:

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dataset and DataLoader setup
txt_file = '/kaggle/input/rafdb-dg/EmoLabel/test_label.txt'
img_dir = '/kaggle/input/rafdb-dg/aligned/aligned'

dataset = CustomImageDataset(txt_file=txt_file, img_dir=img_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)

# CAER-S Dataloader

In [9]:
from PIL import Image, UnidentifiedImageError

prompt = ['Surprise',
'Fear',
'Disgust',
'Happy',
'Sad',
'Anger',
'Neutral']

class FolderBasedDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        """
        Args:
            root_dirs (list): List of root directories containing class folders.
            transform (callable, optional): A function/transform to apply to the images.
        """
        self.img_labels = []  # To store (image_path, class_label)
        self.transform = transform
        self.labels = []
        self.img_names = []
        self.dummy = None
        self.dummy_label = None

        # Loop through each root directory
        for root_dir in root_dirs:
            # Iterate over the class folders
            for class_folder in sorted(os.listdir(root_dir)):  # Alphabetically sorted for consistent label assignment
                self.labels.append(class_folder)
                class_path = os.path.join((os.path.join(root_dir, class_folder)), class_folder)
                if os.path.isdir(class_path):  # Ensure it's a folder
                    # Assign the folder name as the class label
                    class_label = prompt.index(class_folder)
                    # Collect all image paths in this folder
                    for img_name in os.listdir(class_path):
                        img_path = os.path.join(class_path, img_name)
                        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):  # Supported image formats
                            self.img_labels.append((img_path, class_label))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path, class_label = self.img_labels[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            if not self.dummy: 
                self.dummy = image
                self.dummy_label = class_label
        except (UnidentifiedImageError, SyntaxError, TypeError):
            # print(f"Error: Unable to identify the image at {img_path}")
            image = self.dummy  # or handle it accordingly
            class_label = self.dummy_label
        if self.transform:
            image = self.transform(image)

        return image, class_label, img_path

def collate_fn(batch):
    return [b for b in batch if b is not None]  # Remove None entries


In [10]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# List of root directories for CAERS datasets
# root_dirs = [
#     '/kaggle/input/caer-s-train-1',
#     '/kaggle/input/caer-s-train-2',
#     '/kaggle/input/caer-s-train-3'
# ]
root_dirs = [
    '/kaggle/input/caer-s-test-1',
    '/kaggle/input/caer-s-test-2',
]

# Initialize dataset
dataset = FolderBasedDataset(root_dirs=root_dirs, transform=transform)

# Initialize DataLoader
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

# Inspect some samples
for images, labels, img_paths in data_loader:
    print("Batch size:", len(images))
    print("Labels:", labels)  # Print first 3 labels in the batch
    print("Image paths:", img_paths[:3])  # Print first 3 image paths in the batch
    break

Batch size: 1
Labels: tensor([0])
Image paths: ['/kaggle/input/caer-s-test-2/Surprise/Surprise/0678.png']


# CLIP Extraction

In [4]:
# Load CLIP model
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-L/14", device)


100%|███████████████████████████████████████| 890M/890M [00:12<00:00, 72.2MiB/s]


In [37]:
# Output directory
output_dir = "/kaggle/working/clip_encodings_caer_test"
os.makedirs(output_dir, exist_ok=True)

# Save a chunk of data
def save_chunk(data, chunk_index):
    file_path = os.path.join(output_dir, f"encodings_chunk_{chunk_index}.pkl")
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Save label mapping
def save_label_mapping(label_mapping):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "wb") as f:
        pickle.dump(label_mapping, f)

# Generate a label mapping
all_labels = sorted(set(dataset.labels))  # Assuming `dataset.labels` contains all possible labels
label_mapping = {label: idx for idx, label in enumerate(all_labels)}

# Save the label mapping
save_label_mapping(label_mapping)

# Main loop to process data
chunk_size = 20000  # Number of entries per pickle file
chunk_data = {}     # Dictionary to store data indexed by filename
chunk_index = 0

# Check for progress file
progress_file = os.path.join(output_dir, "progress.txt")
start_index = 0
if os.path.exists(progress_file):
    with open(progress_file, "r") as f:
        start_index = int(f.read().strip())

print("Starting from batch index:", start_index)

Starting from batch index: 0


In [41]:
clip_model.eval()
with torch.no_grad():
    for batch_idx, (images, labels, img_names) in enumerate(tqdm(data_loader, desc="Processing batches")):
        if batch_idx < start_index:
            continue

        images = images.to(device)

        # Generate image encodings
        features = clip_model.encode_image(images)

        # Assume logit scale is part of the model (adjust if needed)
        logit_scale = clip_model.logit_scale.exp().item()

        for i, feature in enumerate(features):
            img_name = img_names[i]
            label = labels[i]  # String label, e.g., "Angry", "Sad"
            label_idx = label_mapping[label]  # Map to integer index

            # Add data to the dictionary, indexed by filename
            chunk_data[img_name] = {
                "encoding": feature.cpu().numpy(),
                "logit_scale": logit_scale,
                "label": label,
                "label_idx": label_idx,
            }

            # Save the chunk when the buffer is full
            if len(chunk_data) >= chunk_size:
                save_chunk(chunk_data, chunk_index)
                chunk_data = {}
                chunk_index += 1

        # Save progress to resume in case of interruption
        with open(progress_file, "w") as f:
            f.write(str(batch_idx))

# Save any remaining data
if chunk_data:
    save_chunk(chunk_data, chunk_index)

# Clean up progress file
if os.path.exists(progress_file):
    os.remove(progress_file)

print("Encoding complete and saved.")

Processing batches:   2%|▏         | 54/2624 [00:12<09:35,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0501.png


Processing batches:   4%|▎         | 96/2624 [00:21<09:27,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2741.png


Processing batches:   7%|▋         | 191/2624 [00:43<09:07,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0273.png


Processing batches:  11%|█         | 290/2624 [01:05<08:45,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1654.png


Processing batches:  13%|█▎        | 344/2624 [01:17<08:34,  4.43it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0515.png


Processing batches:  15%|█▌        | 400/2624 [01:30<08:18,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2836.png


Processing batches:  17%|█▋        | 445/2624 [01:40<08:07,  4.47it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2464.png


Processing batches:  18%|█▊        | 471/2624 [01:46<08:03,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2302.png


Processing batches:  18%|█▊        | 484/2624 [01:49<08:00,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1132.png


Processing batches:  30%|███       | 789/2624 [02:57<06:52,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2470.png


Processing batches:  34%|███▎      | 885/2624 [03:19<06:29,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0630.png


Processing batches:  34%|███▍      | 891/2624 [03:20<06:28,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1838.png


Processing batches:  36%|███▌      | 941/2624 [03:31<06:17,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0947.png


Processing batches:  38%|███▊      | 1010/2624 [03:47<06:01,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0142.png


Processing batches:  39%|███▉      | 1017/2624 [03:48<06:00,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1495.png


Processing batches:  45%|████▌     | 1181/2624 [04:25<05:23,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0267.png


Processing batches:  48%|████▊     | 1250/2624 [04:41<05:08,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1512.png


Processing batches:  48%|████▊     | 1255/2624 [04:42<05:07,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1248.png


Processing batches:  49%|████▊     | 1276/2624 [04:46<05:02,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0618.png


Processing batches:  49%|████▉     | 1291/2624 [04:50<04:58,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2027.png


Processing batches:  50%|████▉     | 1301/2624 [04:52<04:57,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0181.png


Processing batches:  54%|█████▍    | 1425/2624 [05:20<04:29,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1804.png


Processing batches:  54%|█████▍    | 1429/2624 [05:21<04:27,  4.47it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1810.png


Processing batches:  56%|█████▌    | 1462/2624 [06:08<04:22,  4.42it/s]  

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2188.png


Processing batches:  56%|█████▌    | 1465/2624 [06:09<04:21,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0529.png


Processing batches:  56%|█████▌    | 1475/2624 [06:11<04:17,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1274.png


Processing batches:  58%|█████▊    | 1534/2624 [06:25<04:04,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1145.png


Processing batches:  59%|█████▉    | 1561/2624 [06:31<03:58,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1506.png


Processing batches:  60%|█████▉    | 1570/2624 [06:33<03:56,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1192.png


Processing batches:  61%|██████    | 1607/2624 [06:41<03:49,  4.42it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2796.png


Processing batches:  62%|██████▏   | 1630/2624 [06:46<03:43,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1179.png


Processing batches:  70%|███████   | 1840/2624 [07:33<02:56,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0195.png


Processing batches:  70%|███████   | 1846/2624 [07:35<02:54,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2316.png


Processing batches:  70%|███████   | 1849/2624 [07:35<02:53,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1481.png


Processing batches:  71%|███████   | 1853/2624 [07:36<02:52,  4.47it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1623.png


Processing batches:  75%|███████▌  | 1974/2624 [08:03<02:26,  4.44it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0006.png


Processing batches:  81%|████████▏ | 2136/2624 [08:40<01:49,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2822.png


Processing batches:  82%|████████▏ | 2155/2624 [08:44<01:45,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0298.png


Processing batches:  84%|████████▍ | 2212/2624 [08:57<01:32,  4.45it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0803.png


Processing batches:  86%|████████▋ | 2264/2624 [09:09<01:20,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1260.png


Processing batches:  87%|████████▋ | 2291/2624 [09:15<01:14,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0156.png


Processing batches:  88%|████████▊ | 2316/2624 [09:20<01:09,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2782.png


Processing batches:  93%|█████████▎| 2449/2624 [09:50<00:39,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2769.png


Processing batches:  94%|█████████▍| 2474/2624 [09:56<00:33,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1186.png


Processing batches:  98%|█████████▊| 2564/2624 [10:16<00:13,  4.46it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0953.png


Processing batches: 100%|██████████| 2624/2624 [10:30<00:00,  4.16it/s]

Encoding complete and saved.





In [42]:
output_dir = "/kaggle/working/clip_encodings_caer_test"

# Function to load all chunks of encodings
def load_encodings(output_dir):
    encodings = {}
    for file_name in os.listdir(output_dir):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk = pickle.load(f)
                encodings.update(chunk)
    return encodings

# Function to load the label mapping
def load_label_mapping(output_dir):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "rb") as f:
        label_mapping = pickle.load(f)
    return label_mapping

# Load encodings and label mapping
encodings = load_encodings(output_dir)
label_mapping = load_label_mapping(output_dir)

# Print basic info about the loaded data
print(f"Total images encoded: {len(encodings)}")
print(f"Label mapping: {label_mapping}")

Total images encoded: 20992
Label mapping: {'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}


In [1]:
import pickle
import os

def load_extracted_features(output_dir):
    features_dict = {}
    
    # Get all chunk files in the directory
    chunk_files = sorted([f for f in os.listdir(output_dir) if f.startswith("encodings_chunk_") and f.endswith(".pkl")])
    # print(chunk_files)
    if not chunk_files:
        print("No saved feature chunks found.")
        return features_dict  # Return empty if no data is found

    print(f"Loading {len(chunk_files)} chunks...")

    # Load each chunk and merge into the main dictionary
    for chunk_file in chunk_files:
        file_path = os.path.join(output_dir, chunk_file)
        with open(file_path, "rb") as f:
            chunk_data = pickle.load(f)
            print(len(chunk_data))
            features_dict.update(chunk_data)  # Merge chunk data into the main dictionary
    
    print(f"Loaded {len(features_dict)} image features from {len(chunk_files)} chunks.")
    
    return features_dict


In [2]:
features = load_extracted_features('/kaggle/working/clip_encodings_caer')
len(features)

Loading 3 chunks...
20000
20000
9007
Loaded 49007 image features from 3 chunks.


49007

# BLIP_2 Extraction

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

def get_blip2t5_model(device):
    model, vis_processor, txt_processor = load_model_and_preprocess(
        name="blip2_t5",
        model_type="pretrain_flant5xl",
        is_eval=True,
        device=device,
    )
    model.generate = MethodType(generate, model)
    return model, vis_processor, txt_processor

blip_instruction = args.instruction

blip2, vis_processor, txt_processor = get_blip2t5_model(device)



vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 1.89G/1.89G [00:08<00:00, 226MB/s]
  state_dict = torch.load(cached_file, map_location="cpu")


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.45G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

In [47]:
# Save a chunk of data indexed by image name
def save_chunk(data, chunk_index, output_dir):
    file_path = os.path.join(output_dir, f"encodings_chunk_{chunk_index}.pkl")
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Function to save text features in chunks, indexed by image name
def save_text_features_in_chunks(dataloader, blip2, tokens, output_dir, chunk_size=1000, batch_size=8):
    os.makedirs(output_dir, exist_ok=True)

    chunk_data = {}
    chunk_index = 0

    # Check for progress file
    progress_file = os.path.join(output_dir, "progress.txt")
    start_index = 0
    if os.path.exists(progress_file):
        with open(progress_file, "r") as f:
            start_index = int(f.read().strip())

    print("Starting from batch index:", start_index)

    for batch_idx, (images, labels, img_names) in enumerate(tqdm(dataloader, desc="Processing batches")):
        if batch_idx < start_index:
            continue

        images = images.to(device)

        # Regenerate tokens dynamically if this is the last batch and its size is smaller
        if len(images) < batch_size:
            tokens = get_tokens(blip_instruction, blip2.t5_tokenizer, len(images), device)

        # Generate text features
        _, text_features = blip2.generate({"image": images, "tokens": tokens})
        text_features = torch.mean(text_features, dim=1)
        text_features = text_features.float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for i, feature in enumerate(text_features):
            img_name = img_names[i]
            label = labels[i]
            chunk_data[img_name] = {
                "encoding": feature.cpu().numpy(),
                "label": label,
            }

            if len(chunk_data) >= chunk_size:
                save_chunk(chunk_data, chunk_index, output_dir)
                chunk_data = {}
                chunk_index += 1

        # Save progress
        with open(progress_file, "w") as f:
            f.write(str(batch_idx))

    # Save any remaining data
    if chunk_data:
        save_chunk(chunk_data, chunk_index, output_dir)

    # Clean up progress file
    if os.path.exists(progress_file):
        os.remove(progress_file)

    print("Encoding complete and saved.")

In [48]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
batch_size = 8

tokens = get_tokens(blip_instruction, blip2.t5_tokenizer, batch_size, device)

# Define transformations
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

# # Dataset and DataLoader setup
# txt_file = '/kaggle/input/rafdb-dg/EmoLabel/train_label.txt'
# img_dir = '/kaggle/input/rafdb-dg/aligned/aligned'

# dataset = CustomImageDataset(txt_file=txt_file, img_dir=img_dir, transform=transform)
# data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

output_dir = "/kaggle/working/blip2_encodings_caer_test"
chunk_size = 20000  # Number of encodings per pickle file

save_text_features_in_chunks(data_loader, blip2, tokens, output_dir, chunk_size=chunk_size, batch_size=batch_size)

Starting from batch index: 0


  return torch.cuda.amp.autocast(dtype=dtype)
Processing batches:   1%|          | 15/2624 [00:15<40:22,  1.08it/s] 

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1495.png


Processing batches:   2%|▏         | 61/2624 [00:57<39:43,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1512.png


Processing batches:   3%|▎         | 75/2624 [01:10<39:34,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2836.png


Processing batches:  13%|█▎        | 334/2624 [05:12<35:32,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2464.png


Processing batches:  16%|█▋        | 431/2624 [06:42<34:01,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1804.png


Processing batches:  17%|█▋        | 441/2624 [06:51<33:51,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1506.png


Processing batches:  23%|██▎       | 601/2624 [09:20<31:23,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2188.png


Processing batches:  23%|██▎       | 609/2624 [09:28<31:17,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0298.png


Processing batches:  25%|██▌       | 662/2624 [10:17<30:26,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0195.png


Processing batches:  27%|██▋       | 704/2624 [10:56<29:45,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0006.png


Processing batches:  30%|██▉       | 778/2624 [12:05<28:40,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1654.png


Processing batches:  34%|███▍      | 891/2624 [13:50<26:53,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0181.png


Processing batches:  37%|███▋      | 973/2624 [15:06<25:37,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1192.png


Processing batches:  38%|███▊      | 1006/2624 [15:37<25:06,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2769.png


Processing batches:  47%|████▋     | 1235/2624 [19:10<21:33,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0273.png


Processing batches:  47%|████▋     | 1244/2624 [19:19<21:26,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2027.png


Processing batches:  48%|████▊     | 1248/2624 [19:22<21:20,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0953.png


Processing batches:  50%|████▉     | 1299/2624 [20:10<20:34,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0142.png


Processing batches:  55%|█████▍    | 1432/2624 [22:14<18:30,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0515.png


Processing batches:  55%|█████▌    | 1454/2624 [22:34<18:09,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0156.png


Processing batches:  57%|█████▋    | 1484/2624 [23:02<17:41,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0529.png


Processing batches:  63%|██████▎   | 1659/2624 [25:45<14:58,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1179.png


Processing batches:  64%|██████▍   | 1681/2624 [26:05<14:36,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/1481.png


Processing batches:  65%|██████▍   | 1701/2624 [26:24<14:18,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2470.png


Processing batches:  66%|██████▌   | 1728/2624 [26:49<13:53,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0630.png


Processing batches:  67%|██████▋   | 1749/2624 [27:09<13:33,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2796.png


Processing batches:  67%|██████▋   | 1770/2624 [27:28<13:13,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1274.png


Processing batches:  68%|██████▊   | 1784/2624 [27:41<13:00,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2782.png


Processing batches:  69%|██████▉   | 1820/2624 [28:15<12:27,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1623.png


Processing batches:  70%|███████   | 1839/2624 [28:32<12:10,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/0947.png


Processing batches:  73%|███████▎  | 1926/2624 [29:53<10:49,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1145.png


Processing batches:  73%|███████▎  | 1928/2624 [29:55<10:47,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1838.png


Processing batches:  79%|███████▉  | 2080/2624 [32:17<08:26,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/0618.png


Processing batches:  80%|███████▉  | 2092/2624 [32:28<08:14,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1186.png


Processing batches:  80%|████████  | 2107/2624 [32:42<08:00,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2316.png


Processing batches:  84%|████████▎ | 2197/2624 [34:05<06:37,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0501.png


Processing batches:  88%|████████▊ | 2311/2624 [35:51<04:51,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1248.png


Processing batches:  88%|████████▊ | 2316/2624 [35:56<04:46,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-1/Fear/Fear/2822.png


Processing batches:  89%|████████▉ | 2341/2624 [36:19<04:23,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/1260.png


Processing batches:  89%|████████▉ | 2345/2624 [36:23<04:19,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/2302.png


Processing batches:  92%|█████████▏| 2407/2624 [37:21<03:21,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/1810.png


Processing batches:  92%|█████████▏| 2423/2624 [37:36<03:06,  1.08it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Happy/Happy/0267.png


Processing batches:  95%|█████████▍| 2488/2624 [38:36<02:06,  1.07it/s]

Error: Unable to identify the image at /kaggle/input/caer-s-test-2/Surprise/Surprise/2741.png


Processing batches: 100%|██████████| 2624/2624 [40:43<00:00,  1.07it/s]

Encoding complete and saved.





In [45]:
label_map = {'Surprise': 1,
'Fear' : 2,
'Disgust': 3,
'Happiness': 4,
'Sadness': 5,
'Anger': 6,
'Neutral': 7}

correct = 0
total = 0
missed = 0
miss_list = [] 

for img, label in data_loader:
    try:
        results = libreface.get_facial_attributes(img[0])
        if results:
            exp = label_map[results['facial_expression']]
            total += 1
            if exp == label.item():
                correct += 1
    except TypeError:
        missed += 1
        miss_list.append(img)

Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Processing landmarks did not result on anything...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Using device: cpu for inference...
Usin

AttributeError: 'numpy.ndarray' object has no attribute 'append'

In [46]:
print(f'Total processed: {total}, Accuracy = {correct / total}, missed = {missed}, corrected total = {correct/(total + missed)}')

Total processed: 1995, Accuracy = 0.6807017543859649, missed = 243, corrected total = 0.6067917783735478


# Consolidating Data

In [40]:
class ProcessedDataset(Dataset):
    def __init__(self, clip_encodings, blip_encodings):
        self.encodings = []
        self.labels = []
        
        for key in clip_encodings.keys():
            if blip_encodings[key] is None:
                blip_enc = np.zeros(2048)
            else:
                blip_enc = blip_encodings[key]#['encoding']
            self.encodings.append((
                clip_encodings[key]['encoding'],
                blip_enc,
                clip_encodings[key]['logit_scale']
            ))
            self.labels.append(clip_encodings[key]['label'])
            
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_features, text_features, logit_scale = self.encodings[idx]
        label = self.labels[idx]
        return image_features, text_features, logit_scale, label


In [43]:
# Function to load all chunks of encodings
def load_clip_encodings(output_dir):
    encodings = {}
    for file_name in os.listdir(output_dir):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk = pickle.load(f)
                encodings.update(chunk)
    return encodings

def load_blip_encodings(output_dir):
    encodings = {}
    for file_name in sorted(os.listdir(output_dir)):
        if file_name.startswith("encodings_chunk_") and file_name.endswith(".pkl"):
            print(f'processing {file_name}')
            file_path = os.path.join(output_dir, file_name)
            with open(file_path, "rb") as f:
                chunk_data = pickle.load(f)
                encodings.update(chunk_data)
    return encodings

# Function to load the label mapping
def load_label_mapping(output_dir):
    mapping_path = os.path.join(output_dir, "label_mapping.pkl")
    with open(mapping_path, "rb") as f:
        label_mapping = pickle.load(f)
    return label_mapping

clip_encodings = load_clip_encodings('/kaggle/working/clip_encodings_caer')
# blip_encodings = load_blip_encodings('/kaggle/working/blip2_encodings_caer')
blip_pickle = '/kaggle/input/caer-extracted-features-clip-blip/blip2_AU_CAER_train.pkl'
with open(blip_pickle, 'rb') as f:
    blip_encodings = pickle.load(f)

label_mapping = load_label_mapping('/kaggle/working/clip_encodings_caer')

# Print basic info about the loaded data
print(f"Total images encoded: {len(blip_encodings)}")
print(f"Label mapping: {label_mapping}")

Total images encoded: 49007
Label mapping: {'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}


In [44]:
dataset = ProcessedDataset(clip_encodings, blip_encodings)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [45]:
class ExpCLIP_Train(nn.Module):
    def __init__(self, args):
        super().__init__()

        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    
        self.projection_head = Linear_Matrix_L14()

    def forward(self, image_features):
        image_features = image_features.float()
        image_features = self.projection_head(image_features)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        
        return image_features

class Linear_Matrix_L14(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(768, 2048, bias=False)
    def forward(self, x):
        return self.mlp(x)


In [46]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix="", log_txt_path=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.log_txt_path = log_txt_path

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print_txt = '\t'.join(entries)
        print(print_txt)
        with open(self.log_txt_path, 'a') as f:
            f.write(print_txt + '\n')

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class RecorderMeter(object):
    """Computes and stores the minimum loss value and its epoch index"""
    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32)    # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)  # [epoch, train/val]

    def update(self, idx, train_loss, train_acc):
        self.epoch_losses[idx, 0] = train_loss * 50
        self.epoch_accuracy[idx, 0] = train_acc
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):

        title = 'the accuracy/loss curve of train/val'
        dpi = 80
        width, height = 1600, 800
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.array([i for i in range(self.total_epoch)])  # epochs
        y_axis = np.zeros(self.total_epoch)

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 1
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel('the training epoch', fontsize=16)
        plt.ylabel('accuracy', fontsize=16)

        y_axis[:] = self.epoch_accuracy[:, 0]
        plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_accuracy[:, 1]
        plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 0]
        plt.plot(x_axis, y_axis, color='g', linestyle=':', label='train-loss-x50', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 1]
        plt.plot(x_axis, y_axis, color='y', linestyle=':', label='valid-loss-x50', lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        plt.close(fig)

In [47]:
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def train(train_loader, model, criterion, optimizer, epoch, args, log_txt_path):
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Accuracy', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [losses, top1],
                             prefix="Epoch: [{}]".format(epoch),
                             log_txt_path=log_txt_path)

    # switch to train mode
    model.train()

    # Use tqdm to show progress
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch}", unit="batch") as pbar:
        for i, (image_features, text_features, logit_scale, labels) in enumerate(train_loader):
            image_features = image_features.to(device)
            text_features = text_features.to(device)
            logit_scale = logit_scale[0]
            logit_scale = logit_scale.to(device)
            
            n, _ = image_features.shape
            target = torch.arange(n)
            target = target.to(device)

            # compute output
            image_features = model(image_features)
            
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

            loss_vision = criterion(logits_per_image, target)
            loss_text = criterion(logits_per_text, target)
            
            loss = 0.5 * loss_vision + 0.5 * loss_text
            
            # measure accuracy and record loss
            acc1, _ = accuracy(logits_per_image, target, topk=(1, 3))
            losses.update(loss.item(), image_features.size(0))
            top1.update(acc1[0], image_features.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update tqdm progress bar
            pbar.set_postfix({"Loss": losses.avg, "Accuracy": top1.avg.item()})
            pbar.update(1)
            
            # Optionally display progress at intervals
            # if i % args.print_freq == 0:
            #     progress.display(i)
            
    return top1.avg, losses.avg


In [49]:
log_txt_path = './log/' + '-log.txt'
log_curve_path = './log/' + '-log.png'
checkpoint_path = './checkpoint/'

os.makedirs('log', exist_ok = True)
os.makedirs('checkpoint', exist_ok = True)
epochs = 50
recorder = RecorderMeter(epochs)

In [50]:
import torch.backends.cudnn as cudnn

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

model = ExpCLIP_Train(args)
# define loss function (criterion)
criterion = nn.CrossEntropyLoss().cuda()

# only open learnable part
for name, param in model.named_parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if "projection_head" in name:
        param.requires_grad = True 

pretrain_model_path = '/kaggle/working/checkpoint/1738165527-model_20.pth'

state_dict = model.state_dict()
pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )
# pre_train_model = torch.load(pretrain_model_path)
# for name, param in pre_train_model.items():
#     if "mlp.weight" in name:
#         state_dict["projection_head.mlp.weight"].copy_(param)

model = model.to(device)
        
# define optimizer
optimizer = torch.optim.SGD([{"params": model.projection_head.parameters(), "lr": args.lr}],
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

# define scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.milestones], gamma=0.1)
cudnn.benchmark = True

  pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )


In [51]:
# e_count = 5
e_count = 50

In [52]:
for epoch in range(0, epochs):
    inf = '********************' + str(epoch) + '********************'
    start_time = time.time()
    current_learning_rate = optimizer.state_dict()['param_groups'][0]['lr']
    with open(log_txt_path, 'a') as f:
        f.write(inf + '\n')
        print(inf)
        f.write('Current learning rate: ' + str(current_learning_rate) + '\n')      
        
    # train for one epoch
    train_acc, train_los = train(dataloader, model, criterion, optimizer, epoch, args, log_txt_path)
    scheduler.step()

    # print and save log
    epoch_time = time.time() - start_time
    recorder.update(epoch, train_los, train_acc)
    recorder.plot_curve(log_curve_path)
    print('The train accuracy: {:.3f}'.format(train_acc.item()))
    print('An epoch time: {:.2f}s'.format(epoch_time))
    with open(log_txt_path, 'a') as f:
        f.write('The best accuracy: ' + str(train_acc.item()) + '\n')
        f.write('An epoch time: ' + str(epoch_time) + 's' + '\n')

#  save model and conduct zero-shot prediction
e_count += 5
checkpoint_name = checkpoint_path + f'-model_{e_count}.pth'
torch.save(model.projection_head.state_dict(), checkpoint_name)

********************0********************


Epoch 0: 100%|██████████| 96/96 [00:09<00:00,  9.76batch/s, Loss=6.25, Accuracy=0.488]


The train accuracy: 0.488
An epoch time: 9.85s
********************1********************


Epoch 1: 100%|██████████| 96/96 [00:12<00:00,  7.80batch/s, Loss=6.06, Accuracy=1.52]


The train accuracy: 1.518
An epoch time: 12.32s
********************2********************


Epoch 2: 100%|██████████| 96/96 [00:09<00:00,  9.87batch/s, Loss=5.9, Accuracy=2.82] 


The train accuracy: 2.820
An epoch time: 9.73s
********************3********************


Epoch 3: 100%|██████████| 96/96 [00:09<00:00,  9.70batch/s, Loss=5.69, Accuracy=4.47]


The train accuracy: 4.473
An epoch time: 9.90s
********************4********************


Epoch 4: 100%|██████████| 96/96 [00:10<00:00,  9.42batch/s, Loss=5.43, Accuracy=6.45]


The train accuracy: 6.450
An epoch time: 10.20s
********************5********************


Epoch 5: 100%|██████████| 96/96 [00:09<00:00,  9.78batch/s, Loss=5.25, Accuracy=8.75]


The train accuracy: 8.752
An epoch time: 9.83s
********************6********************


Epoch 6: 100%|██████████| 96/96 [00:09<00:00,  9.70batch/s, Loss=5.12, Accuracy=10.8]


The train accuracy: 10.760
An epoch time: 9.91s
********************7********************


Epoch 7: 100%|██████████| 96/96 [00:10<00:00,  9.11batch/s, Loss=5.01, Accuracy=12.1]


The train accuracy: 12.094
An epoch time: 10.54s
********************8********************


Epoch 8: 100%|██████████| 96/96 [00:10<00:00,  9.21batch/s, Loss=4.92, Accuracy=14.1]


The train accuracy: 14.108
An epoch time: 10.43s
********************9********************


Epoch 9: 100%|██████████| 96/96 [00:09<00:00,  9.64batch/s, Loss=4.84, Accuracy=15.4]


The train accuracy: 15.365
An epoch time: 9.96s
********************10********************


Epoch 10: 100%|██████████| 96/96 [00:10<00:00,  9.27batch/s, Loss=4.76, Accuracy=16.7]


The train accuracy: 16.726
An epoch time: 10.37s
********************11********************


Epoch 11: 100%|██████████| 96/96 [00:09<00:00,  9.82batch/s, Loss=4.69, Accuracy=18.2]


The train accuracy: 18.214
An epoch time: 9.79s
********************12********************


Epoch 12: 100%|██████████| 96/96 [00:09<00:00,  9.98batch/s, Loss=4.63, Accuracy=19.3]


The train accuracy: 19.340
An epoch time: 9.63s
********************13********************


Epoch 13: 100%|██████████| 96/96 [00:10<00:00,  9.30batch/s, Loss=4.58, Accuracy=20.6]


The train accuracy: 20.550
An epoch time: 10.33s
********************14********************


Epoch 14: 100%|██████████| 96/96 [00:09<00:00,  9.83batch/s, Loss=4.53, Accuracy=21.8]


The train accuracy: 21.793
An epoch time: 9.78s
********************15********************


Epoch 15: 100%|██████████| 96/96 [00:09<00:00,  9.63batch/s, Loss=4.48, Accuracy=22.9]


The train accuracy: 22.938
An epoch time: 9.98s
********************16********************


Epoch 16: 100%|██████████| 96/96 [00:10<00:00,  9.21batch/s, Loss=4.44, Accuracy=24]  


The train accuracy: 23.956
An epoch time: 10.43s
********************17********************


Epoch 17: 100%|██████████| 96/96 [00:09<00:00,  9.72batch/s, Loss=4.4, Accuracy=25]   


The train accuracy: 24.966
An epoch time: 9.89s
********************18********************


Epoch 18: 100%|██████████| 96/96 [00:13<00:00,  7.07batch/s, Loss=4.36, Accuracy=25.9]


The train accuracy: 25.878
An epoch time: 13.58s
********************19********************


Epoch 19: 100%|██████████| 96/96 [00:11<00:00,  8.32batch/s, Loss=4.33, Accuracy=26.9]


The train accuracy: 26.890
An epoch time: 11.55s
********************20********************


Epoch 20: 100%|██████████| 96/96 [00:15<00:00,  6.32batch/s, Loss=4.3, Accuracy=27.9] 


The train accuracy: 27.941
An epoch time: 15.20s
********************21********************


Epoch 21: 100%|██████████| 96/96 [00:10<00:00,  8.76batch/s, Loss=4.27, Accuracy=28.7]


The train accuracy: 28.727
An epoch time: 10.96s
********************22********************


Epoch 22: 100%|██████████| 96/96 [00:09<00:00,  9.71batch/s, Loss=4.24, Accuracy=29.8]


The train accuracy: 29.843
An epoch time: 9.90s
********************23********************


Epoch 23: 100%|██████████| 96/96 [00:10<00:00,  9.06batch/s, Loss=4.21, Accuracy=30.5]


The train accuracy: 30.504
An epoch time: 10.60s
********************24********************


Epoch 24: 100%|██████████| 96/96 [00:10<00:00,  9.30batch/s, Loss=4.18, Accuracy=31.3]


The train accuracy: 31.289
An epoch time: 10.32s
********************25********************


Epoch 25: 100%|██████████| 96/96 [00:10<00:00,  9.23batch/s, Loss=4.16, Accuracy=32.1]


The train accuracy: 32.091
An epoch time: 10.41s
********************26********************


Epoch 26: 100%|██████████| 96/96 [00:10<00:00,  8.94batch/s, Loss=4.14, Accuracy=33]  


The train accuracy: 33.036
An epoch time: 10.74s
********************27********************


Epoch 27: 100%|██████████| 96/96 [00:09<00:00,  9.73batch/s, Loss=4.12, Accuracy=33.9]


The train accuracy: 33.889
An epoch time: 9.88s
********************28********************


Epoch 28: 100%|██████████| 96/96 [00:10<00:00,  9.55batch/s, Loss=4.1, Accuracy=34.5] 


The train accuracy: 34.483
An epoch time: 10.06s
********************29********************


Epoch 29: 100%|██████████| 96/96 [00:10<00:00,  9.10batch/s, Loss=4.07, Accuracy=35.4]


The train accuracy: 35.360
An epoch time: 10.56s
********************30********************


Epoch 30: 100%|██████████| 96/96 [00:11<00:00,  8.45batch/s, Loss=4.06, Accuracy=35.3]


The train accuracy: 35.334
An epoch time: 11.36s
********************31********************


Epoch 31: 100%|██████████| 96/96 [00:09<00:00,  9.82batch/s, Loss=4.06, Accuracy=35.4]


The train accuracy: 35.430
An epoch time: 9.78s
********************32********************


Epoch 32: 100%|██████████| 96/96 [00:10<00:00,  9.14batch/s, Loss=4.06, Accuracy=35.5]


The train accuracy: 35.468
An epoch time: 10.51s
********************33********************


Epoch 33: 100%|██████████| 96/96 [00:09<00:00,  9.89batch/s, Loss=4.06, Accuracy=35.6]


The train accuracy: 35.619
An epoch time: 9.71s
********************34********************


Epoch 34: 100%|██████████| 96/96 [00:09<00:00,  9.95batch/s, Loss=4.06, Accuracy=35.9]


The train accuracy: 35.950
An epoch time: 9.65s
********************35********************


Epoch 35: 100%|██████████| 96/96 [00:10<00:00,  9.25batch/s, Loss=4.05, Accuracy=35.9]


The train accuracy: 35.903
An epoch time: 10.39s
********************36********************


Epoch 36: 100%|██████████| 96/96 [00:09<00:00,  9.72batch/s, Loss=4.05, Accuracy=35.9]


The train accuracy: 35.856
An epoch time: 9.88s
********************37********************


Epoch 37: 100%|██████████| 96/96 [00:09<00:00,  9.81batch/s, Loss=4.05, Accuracy=36]  


The train accuracy: 35.958
An epoch time: 9.79s
********************38********************


Epoch 38: 100%|██████████| 96/96 [00:10<00:00,  9.25batch/s, Loss=4.05, Accuracy=35.8]


The train accuracy: 35.764
An epoch time: 10.39s
********************39********************


Epoch 39: 100%|██████████| 96/96 [00:09<00:00,  9.82batch/s, Loss=4.05, Accuracy=36.3]


The train accuracy: 36.256
An epoch time: 9.78s
********************40********************


Epoch 40: 100%|██████████| 96/96 [00:09<00:00,  9.77batch/s, Loss=4.05, Accuracy=36.1]


The train accuracy: 36.144
An epoch time: 9.84s
********************41********************


Epoch 41: 100%|██████████| 96/96 [00:10<00:00,  9.20batch/s, Loss=4.04, Accuracy=36.3]


The train accuracy: 36.260
An epoch time: 10.44s
********************42********************


Epoch 42: 100%|██████████| 96/96 [00:09<00:00,  9.77batch/s, Loss=4.04, Accuracy=36.4]


The train accuracy: 36.389
An epoch time: 9.83s
********************43********************


Epoch 43: 100%|██████████| 96/96 [00:09<00:00,  9.79batch/s, Loss=4.04, Accuracy=36.4]


The train accuracy: 36.356
An epoch time: 9.82s
********************44********************


Epoch 44: 100%|██████████| 96/96 [00:10<00:00,  9.17batch/s, Loss=4.04, Accuracy=36.3]


The train accuracy: 36.350
An epoch time: 10.48s
********************45********************


Epoch 45: 100%|██████████| 96/96 [00:09<00:00,  9.92batch/s, Loss=4.04, Accuracy=36.7]


The train accuracy: 36.711
An epoch time: 9.68s
********************46********************


Epoch 46: 100%|██████████| 96/96 [00:09<00:00,  9.90batch/s, Loss=4.03, Accuracy=36.5]


The train accuracy: 36.474
An epoch time: 9.70s
********************47********************


Epoch 47: 100%|██████████| 96/96 [00:10<00:00,  9.20batch/s, Loss=4.03, Accuracy=36.6]


The train accuracy: 36.629
An epoch time: 10.44s
********************48********************


Epoch 48: 100%|██████████| 96/96 [00:09<00:00,  9.78batch/s, Loss=4.03, Accuracy=36.5]


The train accuracy: 36.481
An epoch time: 9.83s
********************49********************


Epoch 49: 100%|██████████| 96/96 [00:09<00:00,  9.90batch/s, Loss=4.03, Accuracy=36.6]


The train accuracy: 36.576
An epoch time: 9.70s


# Testing

In [11]:
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import matplotlib
matplotlib.use('Agg')
import numpy as np
from data_loader.video_dataloader import test_data_loader
from sklearn.metrics import confusion_matrix
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from models.Text import *
from models.Exp_CLIP import ExpCLIP_Test
import argparse
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import itertools
from collections import Counter

In [12]:
import time

parser = argparse.ArgumentParser()
parser.add_argument('--load-model', type=str, default='CLIP_L14')
parser.add_argument('--job-id', type=str, default=str(int(time.time())))  # Default job ID is a timestamp

# Use parse_known_args() to handle extra arguments
args, unknown = parser.parse_known_args()

print('************************')
for k, v in vars(args).items():
    print(k, '=', v)
print('************************')

************************
load_model = CLIP_L14
job_id = 1738225677
************************


In [27]:
class ExpCLIP_Test(nn.Module):
    def __init__(self, args):
        super().__init__()

        device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    
        if args.load_model == 'CLIP_B32':
            self.clip_model, _ = clip.load("ViT-B/32", device)
        elif args.load_model == 'CLIP_B16':
            self.clip_model, _ = clip.load("ViT-B/16", device)
        elif args.load_model == 'CLIP_L14':
            self.clip_model, _ = clip.load("ViT-L/14", device)

        if args.load_model == 'CLIP_L14':
            self.projection_head = Linear_Matrix_L14()
        else:
            self.projection_head = Linear_Matrix()

    def forward(self, image, text=None, mode_task=None):
        image_features = self.clip_model.encode_image(image)
        image_features = image_features.float()
        image_features = self.projection_head(image_features)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_tokenized = clip.tokenize(text, context_length=77, truncate=True).to('cuda')
        text_features = self.clip_model.encode_text(text_tokenized)
        text_features = text_features.float()
        text_features = self.projection_head(text_features)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        logit_scale = self.clip_model.logit_scale.exp()

        return logit_scale, image_features, text_features
        
    def get_text_features(self, text):
        text_tokenized = clip.tokenize(text, context_length=77, truncate=True).to('cuda')
        text_features = self.clip_model.encode_text(text_tokenized)
        text_features = text_features.float()
        text_features = self.projection_head(text_features)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        logit_scale = self.clip_model.logit_scale.exp()

        return logit_scale, text_features


class Linear_Matrix_L14(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(768, 2048, bias=False)
        # self.mlp = nn.Linear(768, 4096, bias=False)
    def forward(self, x):
        return self.mlp(x)



class ExpCLIP_PostProcess(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.final = nn.Linear(42,7,bias = False)

    def forward(self, out_features):
        logits = self.final(out_features)
        return logits

In [14]:
# create model and load pre_trained parameters
model = ExpCLIP_Test(args)
post_processor = ExpCLIP_PostProcess(args)

# pretrain_model_path = './checkpoint/' + args.job_id + "-model.pth"
pretrain_model_path = '/kaggle/working/checkpoint/1738165742-model_25.pth'
# pretrain_model_path  = '/kaggle/working/Exp-CLIP/checkpoint/ExpCLIP_L14_model.pth'

state_dict = model.state_dict()
pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )
# pre_train_model = torch.load(pretrain_model_path)

for name, param in pre_train_model.items():
    if "mlp.weight" in name:
        state_dict["projection_head.mlp.weight"].copy_(param)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(post_processor.parameters(), lr=1e-3, weight_decay=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
post_processor = post_processor.to(device)
model.eval()
post_processor.train()

100%|███████████████████████████████████████| 890M/890M [00:09<00:00, 93.6MiB/s]
  pre_train_model = torch.load(pretrain_model_path, map_location=torch.device('cpu') )


ExpCLIP_PostProcess(
  (final): Linear(in_features=42, out_features=7, bias=False)
)

In [28]:
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, txt_file, img_dir, transform=None):
        """
        Args:
            txt_file (str): Path to the .txt file with image filenames and labels.
            img_dir (str): Path to the directory containing images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.img_labels = []
        self.img_dir = img_dir
        self.transform = transform
        self.samples = []  # Add an attribute to store (image_path, label)

        # Read the txt file to get the image paths and labels
        with open(txt_file, 'r') as file:
            for line in file.readlines():
                # Assuming format: image_filename label
                filename, label = line.strip().split()
                name, ext = filename.strip().split('.')  # imagename is name.jpg  --->  name_aligned.jpg
                filename = name + '_aligned.' + ext
                full_path = os.path.join(self.img_dir, filename)  # Construct the full image path
                self.img_labels.append((filename, int(label) - 1))  # Store as a tuple (image_filename, label)
                self.samples.append((full_path, int(label) - 1))  # Store full path and label for caching

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)  # Construct the full image path
        image = Image.open(img_path).convert('RGB')  # Open image

        if self.transform:
            image = self.transform(image)  # Apply the transformations

        return image, label


# Define the transformations (similar to the ones in ImageFolder)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Example usage
test_data_path = '/kaggle/input/rafdb-dg/aligned/aligned'  # Path to the folder containing images
test_txt_file = '/kaggle/input/rafdb-dg/EmoLabel/train_label.txt'  # Path to the text file listing image names and labels

test_dataset = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path, transform=transform)

# Create DataLoader for batching and shuffling
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [29]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# List of root directories for CAERS datasets

root_dirs = [
    '/kaggle/input/caer-s-test-1',
    '/kaggle/input/caer-s-test-2',
]

# Initialize dataset
dataset = FolderBasedDataset(root_dirs=root_dirs, transform=transform)

# Initialize DataLoader
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

# Inspect some samples
for images, labels, img_paths in data_loader:
    print("Batch size:", len(images))
    print("Labels:", labels)  # Print first 3 labels in the batch
    print("Image paths:", img_paths[:3])  # Print first 3 image paths in the batch
    break

Batch size: 32
Labels: tensor([0, 0, 4, 1, 0, 3, 6, 4, 5, 3, 5, 0, 1, 0, 1, 5, 4, 4, 3, 3, 4, 4, 0, 3,
        4, 5, 2, 4, 2, 5, 6, 5])
Image paths: ['/kaggle/input/caer-s-test-2/Surprise/Surprise/1452.png', '/kaggle/input/caer-s-test-2/Surprise/Surprise/0907.png', '/kaggle/input/caer-s-test-2/Sad/Sad/0338.png']


In [34]:
def post_train(data_loader, model, post_processor, optimizer, criterion, prompt, dataset_, batch_size_ = 32):
    correct = 0
    all_predicted = []
    all_targets = []

    cache_file = f"./cache/{dataset_}_image_features.pkl"
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            image_feature_cache = pickle.load(f)
    else:
        image_feature_cache = {}

    # Add a progress bar using tqdm
    with tqdm(total=len(data_loader), desc="Post Training Progress", unit="batch") as pbar:
        for i, (images, target, _) in enumerate(data_loader):
            images = images.to(device)
            target = target.to(device)

            n, _, _, _ = images.shape

            # Check cache for image features
            image_ids = test_loader.dataset.samples[i * batch_size_:(i + 1) * batch_size_]
            uncached_ids = [idx for idx, img_id in enumerate(image_ids) if img_id not in image_feature_cache]

            if uncached_ids:
                uncached_images = images[uncached_ids]
                logit_scale, uncached_features, text_features = model(image=uncached_images, text=prompt, mode_task="Static_FER")

                for idx, img_id in zip(uncached_ids, image_ids):
                    image_feature_cache[img_id] = uncached_features[idx].cpu()

            # Retrieve features from the cache
            image_features = torch.stack([image_feature_cache[img_id] for img_id in image_ids]).to(device)
            logit_scale, text_features = model.get_text_features(prompt)


            output = logit_scale * image_features @ text_features.t()
            output = output.view(n, -1, 1)
            output = torch.mean(output, dim=-1)

            optimizer.zero_grad()  # Zero gradients
            final_logits = post_processor(output)
            loss = criterion(final_logits, targets)  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()

            predicted = final_logits.argmax(dim=1, keepdim=True)
            correct += predicted.eq(target.view_as(predicted)).sum().item()

            all_predicted.extend(predicted.cpu().numpy().flatten())
            all_targets.extend(target.cpu().numpy().flatten())

            pbar.update(1)

    # Save the cache
    with open(cache_file, "wb") as f:
        pickle.dump(image_feature_cache, f)

    # Calculate accuracy
    accuracy = 100. * correct / len(test_loader.dataset)
    
    # Calculate F1 score, precision, and recall
    f1 = f1_score(all_targets, all_predicted, average='weighted')
    precision = precision_score(all_targets, all_predicted, average='weighted')
    recall = recall_score(all_targets, all_predicted, average='weighted')

    # Print metrics
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"UAR: {uar:.4f}")
    
    return uar, accuracy



In [5]:
emotion_words = [
    ["Surprise", "Amazement", "Astonishment", "Wonder", "Shock", "Bewilderment"],
    ["Fear", "Terror", "Anxiety", "Dread", "Panic", "Apprehension"],
    ["Disgust", "Revulsion", "Contempt", "Loathing", "Repulsion", "Aversion"],
    ["Happiness", "Joy", "Delight", "Bliss", "Contentment", "Elation"],
    ["Sadness", "Sorrow", "Grief", "Melancholy", "Despair", "Heartache"],
    ["Anger", "Rage", "Fury", "Wrath", "Annoyance", "Resentment"],
    ["Neutral", "Indifference", "Apathy", "Calmness", "Detachment", "Equanimity"]
]

In [6]:
prompt = []
for emotions in emotion_words:
    prompt += emotions

# prompt = []
# for emotions in emotion_words:
#     definer = ''
#     for emotion in emotions:
#         definer += emotion + ', '
#     definer = definer[:len(definer) - 2]
#     prompt.append(definer)

prompt_list = [f'an expression of {p}' for p in prompt]

In [8]:
for prompt in prompt_list:
    encoding = clip.encode_text(prompt)
    break

AttributeError: module 'models.clip.clip' has no attribute 'encode_text'

In [37]:
post_train(data_loader, model, post_processor, optimizer, criterion, prompt_list, 'CAER-S')

Post Training Progress:   0%|          | 0/656 [00:01<?, ?batch/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 66.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 51.12 MiB is free. Process 3675 has 15.84 GiB memory in use. Of the allocated memory 14.95 GiB is allocated by PyTorch, and 615.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [136]:
test_data = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path,
                              transform=transforms.Compose([transforms.Resize((224, 224)),
                                                 transforms.ToTensor()]))

test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=4,
                                              shuffle=False,
                                              num_workers=8,
                                              pin_memory=True)

In [110]:
with torch.no_grad():
    for i, (images, target) in enumerate(test_loader):
    
        images = images.to(device)
        target = target.to(device)
        break

FER_prompt_ = {'RAFDB':prompt}
zero_shot_prompt = FER_prompt_['RAFDB']
n,_,_,_ = images.shape
logit_scale, image_features, text_features = model(image=images,text=zero_shot_prompt, mode_task="Static_FER") 

In [111]:
prompt_number = int(len(zero_shot_prompt) / 7)

output = logit_scale * image_features @ text_features.t()
output = output.view(n, -1, 1)
# output = torch.mean(output, dim=-1)
# predicted = output.argmax(dim=1, keepdim=True)
top_n_predictions = torch.topk(output, 5, dim=1).indices
# top_n_predictions
output.shape

torch.Size([4, 42, 1])

In [112]:
batch_predicted = []
for preds in top_n_predictions:
    # Map the top n classes to their corresponding root class
    root_classes = [p.item() // 6 for p in preds]
    print(root_classes)
    # Count the occurrences of each root class
    root_vote = Counter(root_classes)
    # Get the most common root class as the final prediction
    voted_root_class = root_vote.most_common(1)[0][0]
    batch_predicted.append(voted_root_class)
batch_predicted

[4, 5, 4, 4, 3]
[4, 4, 4, 3, 5]
[3, 0, 3, 4, 4]
[6, 4, 4, 5, 6]


[4, 4, 3, 6]

In [137]:
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import itertools
import matplotlib.pyplot as plt
import numpy as np

from collections import Counter

def zero_shot_test(set=0, dataset_=None, mode_task=None, FER_prompt_=None, prompt_type=None):

    DATASET_PATH_MAPPING = {
        "RAFDB": "/kaggle/input/rafdb-dg/aligned/aligned",
        "AffectNet7": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/AffectNet7_Face/test/",
        "AffectNet8": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/AffectNet8_Face/test/",
        "FERPlus": "/data/EECS-IoannisLab/datasets/Static_FER_Datasets/FERPlus_Face/test/",
        "DFEW": "./annotation/DFEW_set_"+str(set+1)+"_test.txt",
        "FERV39k": "./annotation/FERV39k_test.txt",
        "MAFW": "./annotation/MAFW_set_"+str(set+1)+"_test.txt",
        "AFEW": "./annotation/AFEW_validation.txt",
    }
    test_data_path = DATASET_PATH_MAPPING[dataset_]
    zero_shot_prompt = FER_prompt_[dataset_]
    
    if dataset_ in ["RAFDB", "AffectNet7", "DFEW", "FERV39k", "AFEW"]:
        prompt_number = int(len(zero_shot_prompt) / 7)
    elif dataset_ in ["AffectNet8", "FERPlus"]:
        prompt_number = int(len(zero_shot_prompt) / 8)
    elif dataset_ in ["MAFW"]:
        prompt_number = int(len(zero_shot_prompt) / 11)

    if mode_task == "Static_FER":
        batch_size_ = 256
        test_data = CustomImageDataset(txt_file=test_txt_file, img_dir=test_data_path,
                                       transform=transforms.Compose([transforms.Resize((224, 224)),
                                                                     transforms.ToTensor()]))
        confusion_matrix_path = "./confusion_matrix/"+args.load_model+"-"+dataset_+'-'+prompt_type+'.pdf'
    elif mode_task == "Dynamic_FER":
        batch_size_ = 64
        test_data = test_data_loader(list_file=test_data_path,
                                     num_segments=16,
                                     duration=1,
                                     image_size=224)
        confusion_matrix_path = "./confusion_matrix/"+args.load_model+"-"+dataset_+ '-' + str(set)+'-'+prompt_type+'.pdf'

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size_,
                                              shuffle=False,
                                              num_workers=8,
                                              pin_memory=True)
    correct = 0

    # Initialize variables for metrics
    all_predicted = []
    all_targets = []

    # Define the root emotion mapping
    root_emotion_mapping = {i: idx for idx, emotions in enumerate(emotion_words) for i in range(len(emotions))}
    # print(root_emotion_mapping)
    
    # Number of top classes to consider for voting
    top_n = 5

    # Caching mechanism
    cache_file = f"./cache/{dataset_}_image_features.pkl"
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            image_feature_cache = pickle.load(f)
    else:
        image_feature_cache = {}

    # Add a progress bar using tqdm
    with tqdm(total=len(test_loader), desc="Testing Progress", unit="batch") as pbar:
        with torch.no_grad():
            for i, (images, target) in enumerate(test_loader):
                images = images.to(device)
                target = target.to(device)

                n, _, _, _ = images.shape

                # Check cache for image features
                image_ids = test_loader.dataset.samples[i * batch_size_:(i + 1) * batch_size_]
                uncached_ids = [idx for idx, img_id in enumerate(image_ids) if img_id not in image_feature_cache]

                if uncached_ids:
                    uncached_images = images[uncached_ids]
                    logit_scale, uncached_features, text_features = model(image=uncached_images, text=zero_shot_prompt, mode_task="Static_FER")

                    for idx, img_id in zip(uncached_ids, image_ids):
                        image_feature_cache[img_id] = uncached_features[idx].cpu()

                # Retrieve features from the cache

                image_features = torch.stack([image_feature_cache[img_id] for img_id in image_ids]).to(device)
                logit_scale, text_features = model.get_text_features(zero_shot_prompt)

                # print(zero_shot_prompt, text_features.shape)

                output = logit_scale * image_features @ text_features.t()
                # print(output.shape)
                output = output.view(n, -1, 1)
                # print(output.shape)
                output = torch.mean(output, dim=-1)
                # print(output.shape)

                # # Get the top n predictions for each input
                # top_n_predictions = torch.topk(output, top_n, dim=1).indices
                # # print(top_n_predictions)

                # # Perform voting for root classes
                # batch_predicted = []
                # for preds in top_n_predictions:
                #     # Map the top n classes to their corresponding root class
                #     root_classes = [p.item() // 6 for p in preds]
                #     # Count the occurrences of each root class
                #     root_vote = Counter(root_classes)
                #     # Get the most common root class as the final prediction
                #     voted_root_class = root_vote.most_common(1)[0][0]
                #     batch_predicted.append(voted_root_class)

                # # Add the batch predictions and targets to the respective lists
                # all_predicted.extend(batch_predicted)
                # all_targets.extend(target.cpu().numpy().flatten())

                # correct += sum(p == t for p, t in zip(batch_predicted, target.cpu().numpy()))

                # pbar.update(1)

                predicted = output.argmax(dim=1, keepdim=True) // 6
                correct += predicted.eq(target.view_as(predicted)).sum().item()

                all_predicted.extend(predicted.cpu().numpy().flatten())
                all_targets.extend(target.cpu().numpy().flatten())

                pbar.update(1)

    # Save the cache
    with open(cache_file, "wb") as f:
        pickle.dump(image_feature_cache, f)

    # Calculate accuracy
    accuracy = 100. * correct / len(test_loader.dataset)
    
    # Calculate F1 score, precision, and recall
    f1 = f1_score(all_targets, all_predicted, average='weighted')
    precision = precision_score(all_targets, all_predicted, average='weighted')
    recall = recall_score(all_targets, all_predicted, average='weighted')

    # Compute confusion matrix
    _confusion_matrix = confusion_matrix(all_targets, all_predicted)
    np.set_printoptions(precision=4)
    normalized_cm = _confusion_matrix.astype('float') / _confusion_matrix.sum(axis=1)[:, np.newaxis]
    normalized_cm = normalized_cm * 100
    list_diag = np.diag(normalized_cm)
    uar = list_diag.mean()

    # Print metrics
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"UAR: {uar:.4f}")

    # Plot normalized confusion matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(normalized_cm, interpolation='nearest', cmap=plt.cm.Reds)
    plt.colorbar()
    tick_marks = np.arange(len(Emotion_name_dic[dataset_]))
    plt.xticks(tick_marks, Emotion_name_dic[dataset_], rotation=45)
    plt.yticks(tick_marks, Emotion_name_dic[dataset_])

    fmt = '.2f'
    thresh = normalized_cm.max() / 2.
    for i, j in itertools.product(range(normalized_cm.shape[0]), range(normalized_cm.shape[1])):
        plt.text(j, i, format(normalized_cm[i, j], fmt), fontsize=12,
                 horizontalalignment="center",
                 color="white" if normalized_cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=18)
    plt.xlabel('Predicted label', fontsize=18)
    plt.tight_layout()
    plt.close()
    
    return uar, accuracy



In [138]:
def zero_shot_test_FER(FER_prompt_,type_):

    datasets_ = ["RAFDB",]
    
    for dataset in datasets_:
        uar, war = zero_shot_test(dataset_=dataset, mode_task="Static_FER", FER_prompt_=FER_prompt_, prompt_type=type_)
        # print(f'************************* {dataset}')
        # print(f"UAR/WAR: {uar:.2f}/{war:.2f}")

In [139]:
class RecorderMeter(object):
    pass    

In [6]:
# prompt = ['Surprise',
# 'Fear',
# 'Disgust',
# 'Happiness',
# 'Sadness',
# 'Anger',
# 'Neutral']
# # for key in label_mapping.keys():
# #     prompt.append(key)

# prompt

emotion_words = [
    ["Surprise", "Amazement", "Astonishment", "Wonder", "Shock", "Bewilderment"],
    ["Fear", "Terror", "Anxiety", "Dread", "Panic", "Apprehension"],
    ["Disgust", "Revulsion", "Contempt", "Loathing", "Repulsion", "Aversion"],
    ["Happiness", "Joy", "Delight", "Bliss", "Contentment", "Elation"],
    ["Sadness", "Sorrow", "Grief", "Melancholy", "Despair", "Heartache"],
    ["Anger", "Rage", "Fury", "Wrath", "Annoyance", "Resentment"],
    ["Neutral", "Indifference", "Apathy", "Calmness", "Detachment", "Equanimity"]
]

In [7]:
prompt = []
for emotions in emotion_words:
    prompt += emotions

# prompt = []
# for emotions in emotion_words:
#     definer = ''
#     for emotion in emotions:
#         definer += emotion + ', '
#     definer = definer[:len(definer) - 2]
#     prompt.append(definer)

prompt_list = [prompt,]

prompt_list.append([f'an expression of {p}' for p in prompt])
prompt_list.append([f'a photo of a face with an expression of {p}' for p in prompt])

In [143]:
print(f'testing for e_count = {e_count}')

for p in prompt_list:
    zero_shot_test_FER({'RAFDB':p}, 'type1')

testing for e_count = 25


Testing Progress: 100%|██████████| 48/48 [00:14<00:00,  3.29batch/s]


Accuracy: 37.14%
F1 Score: 0.3514
Precision: 0.5301
Recall: 0.3714
UAR: 30.1760


Testing Progress: 100%|██████████| 48/48 [00:14<00:00,  3.30batch/s]


Accuracy: 50.51%
F1 Score: 0.4534
Precision: 0.5007
Recall: 0.5051
UAR: 33.7579


Testing Progress: 100%|██████████| 48/48 [00:14<00:00,  3.40batch/s]


Accuracy: 30.01%
F1 Score: 0.2972
Precision: 0.4623
Recall: 0.3001
UAR: 27.3783


In [21]:
for i, FER_prompt in enumerate(FER_prompt_list):
    print(f'************************************************************************** Zero-shot Prompt Type: ', FER_prompt_type_list[i])
    type_ = "type"+str(i+1)
    print(FER_prompt['RAFDB'][0])
    # zero_shot_test_FER(FER_prompt, type_)



************************************************************************** Zero-shot Prompt Type:  Class Name
happiness.
************************************************************************** Zero-shot Prompt Type:  An Expression of Name
an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  A Photo of A Face with An Expression of Name
a photo of a face with an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  Expression Ensemble Five
an expression of happiness.
************************************************************************** Zero-shot Prompt Type:  Expression Ensemble Ten
an expression of happiness.


# Working out text encoder

In [5]:
clip_model, _ = clip.load("ViT-L/14", 'cpu')

100%|████████████████████████████████████████| 890M/890M [00:08<00:00, 106MiB/s]


In [None]:
text_tokenized = clip.tokenize(text, context_length=77, truncate=True).to('cuda')
text_features = self.clip_model.encode_text(text_tokenized)
text_features = text_features.float()
text_features = self.projection_head(text_features)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

In [18]:
text_feature_dict = {}

for prompt in prompt_list[1]:
    text_tokenized = clip.tokenize(prompt, context_length=77, truncate=True)
    text_features = clip_model.encode_text(text_tokenized)
    text_features = text_features.float()
    text_feature_dict[prompt] = text_features

In [22]:
import pickle

filepath = 'prompt_features_clip.pkl'

with open(filepath, 'rb') as f:
    file = pickle.load(f)

In [25]:
file['an expression of Surprise']

tensor([[-4.0530e-01, -3.2315e-01, -1.1393e-01,  4.7825e-01,  9.2232e-02,
          4.6974e-01,  1.2963e+00,  5.7603e-01,  7.7125e-01,  5.3472e-01,
          3.1385e-01,  2.9515e-01, -4.0732e-01,  4.0871e-01, -6.7008e-01,
         -4.1172e-01,  3.3998e-01,  5.2299e-01, -2.5251e-01, -1.6546e-01,
         -6.7288e-02,  4.5203e-01,  9.4692e-01,  2.7368e-01,  9.2738e-03,
         -3.6453e-01,  9.0409e-01,  4.1976e-01,  1.1359e-01,  1.5770e-01,
          7.9366e-02, -8.0827e-04,  3.3597e-01, -3.2699e-01,  5.5383e-01,
         -3.1630e-02,  3.6935e-01,  5.6386e-01, -1.6073e-01,  2.1650e-01,
         -1.4766e-01, -1.5909e-01,  7.3177e-01, -9.3351e-02, -1.4118e-01,
          5.0999e-02,  5.9460e-01, -4.3160e-01,  3.2536e-01, -2.3839e-01,
          5.9213e-01,  9.5787e-02,  2.2013e-01,  3.6423e-01,  6.5811e-01,
          5.0414e-01,  3.8529e-01,  2.2709e-01,  1.3345e-01, -1.4895e-01,
         -2.5076e-01,  5.2613e-01, -4.7673e-01, -1.3499e-01, -1.8288e-01,
          2.8535e-01,  2.3071e-01, -5.