## Data Preparation

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.models import vit_b_16

from transformers import CLIPProcessor, CLIPModel

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# For visualization
import torchvision.transforms.functional as TF
class EmoSet(Dataset):
    """Image-emotion dataset read from JSON split, info and annotation files.

    ``<data_root>/info.json`` supplies the label vocabularies and
    ``<data_root>/<phase>.json`` lists the samples of a split. Each split
    entry is indexed positionally: ``[emotion_name, image_path,
    annotation_path]`` with both paths relative to ``data_root``.

    Every item is a dict with the keys ``image_id``, ``image``,
    ``emotion_label_idx`` plus one ``<attribute>_label_idx`` entry per
    attribute listed in ``ATTRIBUTES_MULTI_CLASS`` / ``ATTRIBUTES_MULTI_LABEL``.
    """

    # Attributes that carry exactly one label per image.
    ATTRIBUTES_MULTI_CLASS = [
        'scene', 'facial_expression', 'human_action', 'brightness', 'colorfulness',
    ]
    # Attributes that carry a list of labels per image (encoded multi-hot).
    ATTRIBUTES_MULTI_LABEL = [
        'object'
    ]
    # Vocabulary size per attribute; only 'object' is read by __getitem__.
    NUM_CLASSES = {
        'brightness': 11,
        'colorfulness': 11,
        'scene': 254,
        'object': 409,
        'facial_expression': 6,
        'human_action': 264,
    }

    def __init__(self,
                 data_root,
                 num_emotion_classes,
                 phase,
                 ):
        """Load label metadata and the sample list of one split.

        Args:
            data_root (str): Directory holding ``info.json``, the split files
                and the image / annotation files they reference.
            num_emotion_classes (int): 8 for the emotion categories, 2 for
                the positive / negative grouping.
            phase (str): ``'train'``, ``'val'`` or ``'test'``.
        """
        assert num_emotion_classes in (8, 2), "num_emotion_classes must be either 8 or 2"
        assert phase in ('train', 'val', 'test'), "phase must be 'train', 'val', or 'test'"
        self.transforms_dict = self.get_data_transforms()

        self.info = self.get_info(data_root, num_emotion_classes)

        # Still reject unknown phases when asserts are stripped (python -O).
        if phase not in self.transforms_dict:
            raise NotImplementedError
        self.transform = self.transforms_dict[phase]

        split_file = os.path.join(data_root, f'{phase}.json')
        with open(split_file, 'r') as f:
            data_store = json.load(f)

        emotion_label2idx = self.info['emotion']['label2idx']
        self.data_store = [
            [
                emotion_label2idx[item[0]],
                # Parse the numeric id from the file name only, so dots or
                # underscores in directory names (e.g. a leading "./") cannot
                # corrupt it.
                int(os.path.basename(item[1]).split('.')[0].split('_')[-1]),
                os.path.join(data_root, item[1]),
                os.path.join(data_root, item[2])
            ]
            for item in data_store
        ]

    @classmethod
    def get_data_transforms(cls):
        """Return the per-phase torchvision pipelines keyed by phase name.

        ``'train'`` uses a random resized crop and horizontal flip; ``'val'``
        and ``'test'`` use the same deterministic resize + center crop.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        def build_eval_pipeline():
            # Built per call so 'val' and 'test' hold separate Compose objects.
            return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

        transforms_dict = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': build_eval_pipeline(),
            'test': build_eval_pipeline(),
        }
        return transforms_dict

    def get_info(self, data_root, num_emotion_classes):
        """Read ``info.json`` and make sure ``info['emotion']`` is populated.

        For 8 classes an existing ``'emotion'`` entry is kept; otherwise it is
        built from the top-level ``'label2idx'`` / ``'idx2label'`` entries.
        For 2 classes ``'emotion'`` is replaced by a fixed mapping of the
        eight emotion names onto 0 (positive) and 1 (negative).

        Returns:
            dict: The parsed info dictionary with an ``'emotion'`` entry.
        """
        assert num_emotion_classes in (8, 2), "num_emotion_classes must be either 8 or 2"
        info_path = os.path.join(data_root, 'info.json')
        with open(info_path, 'r') as f:
            info = json.load(f)

        if num_emotion_classes == 8:
            # Ensure 'emotion' key exists with label2idx and idx2label
            if 'emotion' not in info:
                info['emotion'] = {
                    'label2idx': info.get('label2idx', {}),
                    'idx2label': info.get('idx2label', {})
                }
        elif num_emotion_classes == 2:
            # Map emotions to 'positive' and 'negative'
            emotion_info = {
                'label2idx': {
                    'amusement': 0,
                    'awe': 0,
                    'contentment': 0,
                    'excitement': 0,
                    'anger': 1,
                    'disgust': 1,
                    'fear': 1,
                    'sadness': 1,
                },
                'idx2label': {
                    '0': 'positive',
                    '1': 'negative',
                }
            }
            info['emotion'] = emotion_info
        else:
            raise NotImplementedError

        return info

    def _attribute_label2idx(self, attribute):
        """Return the label-to-index mapping used for ``attribute``.

        A per-attribute mapping (``info[attribute]['label2idx']``) wins when
        present. Otherwise the top-level ``info['label2idx']`` is used, which
        is the lookup this class performed before; an empty mapping is
        returned when neither exists, so every label resolves to "unknown".
        """
        attribute_info = self.info.get(attribute)
        if isinstance(attribute_info, dict) and 'label2idx' in attribute_info:
            return attribute_info['label2idx']
        return self.info.get('label2idx', {})

    def load_image_by_path(self, path):
        """Open the image at ``path`` as RGB and apply the phase transform."""
        image = Image.open(path).convert('RGB')
        image = self.transform(image)
        return image

    def load_annotation_by_path(self, path):
        """Parse and return the JSON annotation stored at ``path``."""
        with open(path, 'r') as f:
            json_data = json.load(f)
        return json_data

    def __getitem__(self, idx):
        """Return the sample dict for position ``idx``.

        Multi-class attributes yield an int index, ``-1`` when the attribute
        or its label is unknown. The multi-label ``'object'`` attribute yields
        a float multi-hot vector of length ``NUM_CLASSES['object']``.
        """
        emotion_label_idx, image_id, image_path, annotation_path = self.data_store[idx]
        image = self.load_image_by_path(image_path)
        annotation_data = self.load_annotation_by_path(annotation_path)
        data = {'image_id': image_id, 'image': image, 'emotion_label_idx': emotion_label_idx}

        # Process multi-class attributes
        for attribute in self.ATTRIBUTES_MULTI_CLASS:
            attribute_label_idx = -1  # Default to -1 if not present
            if attribute in annotation_data:
                label = str(annotation_data[attribute])  # Ensure label is string for dict keys
                attribute_label_idx = self._attribute_label2idx(attribute).get(label, -1)
            data[f'{attribute}_label_idx'] = attribute_label_idx

        # Process multi-label attributes
        for attribute in self.ATTRIBUTES_MULTI_LABEL:
            assert attribute == 'object', "Currently only 'object' attribute is supported as multi-label"
            num_classes = self.NUM_CLASSES[attribute]
            attribute_label_idx = torch.zeros(num_classes, dtype=torch.float)
            if attribute in annotation_data:
                label2idx = self._attribute_label2idx(attribute)
                for label in annotation_data[attribute]:
                    idx_label = label2idx.get(label, None)
                    # Ignore labels that are unknown or outside the vector.
                    if idx_label is not None and 0 <= idx_label < num_classes:
                        attribute_label_idx[idx_label] = 1.0
            data[f'{attribute}_label_idx'] = attribute_label_idx

        return data

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data_store)


def get_transforms(phase='train'):
    """Build the torchvision preprocessing pipeline for ``phase``.

    ``'train'`` gets a random resized crop to 224 and a random horizontal
    flip; any other value gets a resize to 256 followed by a 224 center crop.
    Both variants end with tensor conversion and the same mean/std
    normalization.
    """
    if phase != 'train':
        spatial_steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    else:
        spatial_steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]

    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(spatial_steps + tensor_steps)


def create_dataloaders(data_root, batch_size=32, num_workers=4, num_emotion_classes=8):
    """
    Creates train, validation, and test dataloaders.

    The training loader draws samples through a WeightedRandomSampler whose
    per-sample weight is the inverse frequency of the sample's emotion class,
    so rare classes are drawn as often as common ones.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int): Batch size.
        num_workers (int): Number of subprocesses for data loading.
        num_emotion_classes (int): Number of emotion classes (2 or 8).

    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """
    # Create datasets
    train_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='train')
    val_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='val')
    test_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='test')

    # Handle class imbalance with weighted sampler.
    # data_store rows start with the emotion class index.
    labels = np.asarray([sample[0] for sample in train_dataset.data_store], dtype=np.int64)
    class_sample_count = np.bincount(labels, minlength=num_emotion_classes)
    # A class with no samples is never indexed below; clamping its count to 1
    # only avoids the divide-by-zero warning and an inf entry.
    weight = 1. / np.maximum(class_sample_count, 1)
    samples_weight = torch.from_numpy(weight[labels]).double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader


In [None]:
import random
from torch.utils.data import Subset

def create_sub_dataset(dataset, fraction=0.1):
    """Return a random ``Subset`` holding ``int(len(dataset) * fraction)`` items.

    The indices are chosen by shuffling all positions with the global
    ``random`` module state and keeping the leading slice.
    """
    total = len(dataset)
    shuffled_positions = list(range(total))
    random.shuffle(shuffled_positions)
    keep = int(total * fraction)
    return Subset(dataset, shuffled_positions[:keep])


def create_dataloaders(data_root, batch_size=32, num_workers=4, num_emotion_classes=8, fraction=0.1):
    """
    Creates train, validation, and test dataloaders and uses only a fraction of each dataset.

    The training loader draws samples through a WeightedRandomSampler whose
    per-sample weight is the inverse frequency of the sample's emotion class
    within the selected training subset.

    Args:
        data_root (str): Root directory of the dataset.
        batch_size (int): Batch size.
        num_workers (int): Number of subprocesses for data loading.
        num_emotion_classes (int): Number of emotion classes (2 or 8).
        fraction (float): Fraction of dataset to use, e.g., 0.1 for 1/10.

    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """
    # Create datasets
    full_train_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='train')
    val_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='val')
    test_dataset = EmoSet(data_root=data_root, num_emotion_classes=num_emotion_classes, phase='test')

    # Keep only `fraction` of every split (same sampling order as before:
    # train, then val, then test).
    train_dataset = create_sub_dataset(full_train_dataset, fraction)
    val_dataset = create_sub_dataset(val_dataset, fraction)
    test_dataset = create_sub_dataset(test_dataset, fraction)

    # Handle class imbalance with weighted sampler.
    # Read the class index straight from data_store (first column) for the
    # selected positions; iterating the Subset would load and transform every
    # training image just to obtain its label.
    labels = np.asarray(
        [full_train_dataset.data_store[i][0] for i in train_dataset.indices],
        dtype=np.int64,
    )
    class_sample_count = np.bincount(labels, minlength=num_emotion_classes)
    # A small subset can miss a class entirely; clamp the count so the
    # reciprocal stays finite (such a class is never indexed below).
    weight = 1. / np.maximum(class_sample_count, 1)
    samples_weight = torch.from_numpy(weight[labels]).double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [None]:
# Dataset root; info.json and the <phase>.json split files are read from here.
data_root = './'

# Build the train / validation / test loaders for the 8-class setup.
train_loader, val_loader, test_loader = create_dataloaders(
    data_root=data_root,
    batch_size=32,
    num_workers=4,
    num_emotion_classes=8,
)

## Model Setup

In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16

class ViTSentimentClassifier(nn.Module):
    """ViT-B/16 backbone followed by one linear emotion-classification layer.

    Args:
        num_classes (int): Number of output logits.
        pretrained (bool): Forwarded to ``vit_b_16`` as its ``pretrained`` flag.
    """

    def __init__(self, num_classes=8, pretrained=True):
        super().__init__()
        backbone = vit_b_16(pretrained=pretrained)
        # Swap the built-in head for a pass-through so the backbone returns
        # features instead of its own class logits.
        backbone.heads = nn.Identity()
        self.vit = backbone
        self.classifier = nn.Linear(backbone.hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.vit(x))
from transformers import CLIPProcessor, CLIPModel

class CLIPSentimentClassifier(nn.Module):
    """CLIP image encoder followed by one linear emotion-classification layer.

    Args:
        num_classes (int): Number of output logits.
        pretrained (bool): Accepted for signature parity with
            ViTSentimentClassifier; the body does not read it and the
            checkpoint is always loaded with ``from_pretrained``.
    """

    def __init__(self, num_classes=8, pretrained=True):
        super().__init__()
        # NOTE(review): "clip-vit-base-patch16" has no organisation prefix, so
        # it presumably names a local directory — confirm.
        self.clip_model = CLIPModel.from_pretrained("clip-vit-base-patch16")
        # get_image_features() returns the projected image embedding, so the
        # head must be sized by projection_dim; the composite CLIP config has
        # no top-level hidden_size attribute.
        self.classifier = nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images):
        """Return class logits for a batch of preprocessed pixel tensors."""
        image_embeds = self.clip_model.get_image_features(pixel_values=images)
        return self.classifier(image_embeds)


In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16

class ViTSentimentClassifier(nn.Module):
    """ViT-B/16 feature extractor with a single linear layer on top.

    Args:
        num_classes (int): Number of output logits.
        pretrained (bool): Forwarded to ``vit_b_16`` as its ``pretrained`` flag.
    """

    def __init__(self, num_classes=8, pretrained=True):
        super().__init__()
        backbone = vit_b_16(pretrained=pretrained)
        # An identity head makes the backbone output features rather than
        # its own class logits.
        backbone.heads = nn.Identity()
        self.vit = backbone
        self.classifier = nn.Linear(backbone.hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.vit(x))
from transformers import CLIPProcessor, CLIPModel

class CLIPSentimentClassifier(nn.Module):
    """CLIP image encoder followed by one linear emotion-classification layer.

    The CLIP weights are loaded from the local ``./clip/`` directory.

    Args:
        num_classes (int): Number of output logits.
        pretrained (bool): Accepted for signature parity with
            ViTSentimentClassifier; the body does not read it and the
            checkpoint is always loaded with ``from_pretrained``.
    """

    def __init__(self, num_classes=8, pretrained=True):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained("./clip/")
        print(self.clip_model.config)
        # get_image_features() returns the projected image embedding, so the
        # head must be sized by projection_dim; the composite CLIP config has
        # no top-level hidden_size attribute.
        self.classifier = nn.Linear(self.clip_model.config.projection_dim, num_classes)

    def forward(self, images):
        """Return class logits for a batch of preprocessed pixel tensors."""
        image_embeds = self.clip_model.get_image_features(pixel_values=images)
        return self.classifier(image_embeds)

## Training

In [None]:
from tqdm import tqdm
def _run_epoch(model, loader, criterion, device, description, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The pass is a training pass (train mode, gradients on, optimizer stepped
    per batch) when ``optimizer`` is given, otherwise an evaluation pass
    (eval mode, gradients off). Both metrics are averaged over the samples
    actually processed; they are NaN when the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(loader, desc=description):
            inputs = batch['image'].to(device)
            labels = batch['emotion_label_idx'].to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so a short final batch
            # does not skew the epoch mean.
            loss_sum += loss.item() * inputs.size(0)

            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels).item()
            seen += labels.size(0)

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=20, scheduler=None):
    """
    Trains the given model.

    Each epoch runs one training pass and one validation pass. Loss and
    accuracy are averaged over the number of samples each pass processed, so
    the figures stay correct when the loader's sampler draws a different
    number of samples than ``len(loader.dataset)``.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.
        criterion: Loss function.
        optimizer: Optimizer.
        device: Device to train on ('cuda:2' or 'cpu').
        num_epochs (int): Number of epochs.
        scheduler: Learning rate scheduler; stepped once per epoch when given.

    Returns:
        model: Trained model.
        history: Training history containing loss and accuracy.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _run_epoch(
            model, train_loader, criterion, device,
            f"Epoch {epoch+1}/{num_epochs} - Training",
            optimizer=optimizer,
        )
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        # Validation Phase
        val_epoch_loss, val_epoch_acc = _run_epoch(
            model, val_loader, criterion, device,
            f"Epoch {epoch+1}/{num_epochs} - Validation",
        )
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f} | "
              f"Val Loss: {val_epoch_loss:.4f} | Val Acc: {val_epoch_acc:.4f}")

        if scheduler is not None:
            scheduler.step()

    return model, history
# Device configuration
# Use GPU index 2 when at least three GPUs are visible; on hosts with fewer
# GPUs 'cuda:2' is not a valid device, so fall back to the first GPU, and to
# the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Initialize the model
num_emotion_classes = 8  # or 2 for binary classification
vit_model = ViTSentimentClassifier(num_classes=num_emotion_classes, pretrained=True)
vit_model = vit_model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(vit_model.parameters(), lr=1e-4)

# Learning rate scheduler: multiply the learning rate by 0.1 every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Train the model
num_epochs = 20
trained_vit_model, vit_history = train_model(
    model=vit_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    scheduler=scheduler
)
# Build the CLIP-based classifier and move it to the training device.
# NOTE(review): the loaders normalize images with the mean/std hard-coded in
# the dataset transforms — confirm these match what the CLIP checkpoint expects.
clip_model = CLIPSentimentClassifier(num_classes=num_emotion_classes, pretrained=True).to(device)

# Same recipe as the ViT run: cross-entropy, AdamW at 1e-4, and a step
# schedule that multiplies the learning rate by 0.1 every 5 epochs.
criterion_clip = nn.CrossEntropyLoss()
optimizer_clip = torch.optim.AdamW(clip_model.parameters(), lr=1e-4)
scheduler_clip = torch.optim.lr_scheduler.StepLR(optimizer_clip, step_size=5, gamma=0.1)

# Fine-tune on the shared train / validation loaders.
clip_training_kwargs = dict(
    model=clip_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion_clip,
    optimizer=optimizer_clip,
    device=device,
    num_epochs=num_epochs,
    scheduler=scheduler_clip,
)
trained_clip_model, clip_history = train_model(**clip_training_kwargs)

## Evaluation

In [None]:
def evaluate_model(model, test_loader, device, num_classes=8):
    """
    Evaluates the model on the test set.

    Args:
        model (nn.Module): The trained model.
        test_loader (DataLoader): Test data loader.
        device: Device to perform evaluation on.
        num_classes (int): Number of classes; fixes the confusion matrix to
            ``num_classes x num_classes`` even when some classes never occur
            in the evaluated data.

    Returns:
        metrics: Dictionary containing accuracy, precision, recall, and F1-score
            (precision / recall / F1 are support-weighted averages).
        conf_matrix: Confusion matrix.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            inputs = batch['image'].to(device)
            labels = batch['emotion_label_idx'].to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 yields the same scores as the default but without the
    # warning raised for classes that are never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )
    # Passing the full label set keeps one row/column per class, so the
    # matrix always lines up with a list of num_classes tick labels.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

    metrics = {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    return metrics, conf_matrix

In [None]:
# Score the fine-tuned ViT on the test loader and report its summary metrics.
vit_metrics, vit_conf_matrix = evaluate_model(
    model=trained_vit_model,
    test_loader=test_loader,
    device=device,
    num_classes=num_emotion_classes,
)
print("ViT Model Evaluation Metrics:", vit_metrics, sep="\n")

In [None]:
# Score the fine-tuned CLIP classifier on the test loader and report its summary metrics.
clip_metrics, clip_conf_matrix = evaluate_model(
    model=trained_clip_model,
    test_loader=test_loader,
    device=device,
    num_classes=num_emotion_classes,
)
print("CLIP Model Evaluation Metrics:", clip_metrics, sep="\n")

In [None]:
def plot_confusion_matrix(conf_matrix, classes, title='Confusion Matrix'):
    """
    Draws the confusion matrix as an annotated heatmap and shows the figure.

    Args:
        conf_matrix (np.array): Confusion matrix of integer counts.
        classes (list): Class names used as tick labels on both axes.
        title (str): Title of the plot.
    """
    heatmap_options = {
        'annot': True,
        'fmt': 'd',
        'cmap': 'Blues',
        'xticklabels': classes,
        'yticklabels': classes,
    }

    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, **heatmap_options)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

# Load idx2label from info.json
def load_idx2label(data_root, num_emotion_classes):
    """Return the ``{class_index: class_name}`` mapping for the emotion task.

    Args:
        data_root (str): Directory containing ``info.json``.
        num_emotion_classes (int): 8 to read the names from ``info.json``,
            2 for the fixed positive / negative mapping (no file is read).

    Returns:
        dict: Integer class index mapped to its display name.

    Raises:
        ValueError: If ``num_emotion_classes`` is neither 8 nor 2.
        KeyError: If ``info.json`` holds no ``idx2label`` mapping, neither at
            the top level nor under ``'emotion'``.
    """
    if num_emotion_classes == 2:
        return {0: 'positive', 1: 'negative'}
    if num_emotion_classes != 8:
        # Previously this fell through and failed on an unbound local.
        raise ValueError(
            f"num_emotion_classes must be either 8 or 2, got {num_emotion_classes!r}"
        )

    info_path = os.path.join(data_root, 'info.json')
    with open(info_path, 'r') as f:
        info = json.load(f)

    # Prefer the top-level mapping; otherwise accept the nested layout
    # info['emotion']['idx2label'].
    raw_idx2label = info.get('idx2label')
    if raw_idx2label is None:
        raw_idx2label = info['emotion']['idx2label']

    # JSON object keys are strings; callers index with integers.
    return {int(k): v for k, v in raw_idx2label.items()}

# Use the same class count as training/evaluation so the label names always
# match the confusion matrices (this was hard-coded to 8).
idx2label = load_idx2label(data_root, num_emotion_classes=num_emotion_classes)
class_names = [idx2label[i] for i in range(num_emotion_classes)]

# Plot for ViT
plot_confusion_matrix(vit_conf_matrix, classes=class_names,
                      title='ViT Model Confusion Matrix')

# Plot for CLIP
plot_confusion_matrix(clip_conf_matrix, classes=class_names,
                      title='CLIP Model Confusion Matrix')

## Visualization with Saliency Maps

In [None]:
def generate_saliency_map_vit(model, image, label, device):
    """
    Generates a saliency map for the given image using ViT.

    The map is the absolute gradient of the cross-entropy loss with respect
    to the input pixels, reduced with a maximum over the channel dimension.

    Args:
        model (nn.Module): Trained ViT model.
        image (Tensor): Input image tensor without a batch dimension
            (channels first).
        label (int): True label; a zero-dimensional integer tensor is also
            accepted.
        device: Device.

    Returns:
        saliency: Saliency map as a NumPy array with the image's spatial shape.
    """
    model.eval()
    # Detach first so the batched copy is always a leaf tensor; a non-leaf
    # input would leave .grad unset after backward().
    batched = image.detach().unsqueeze(0).to(device).requires_grad_(True)
    target = torch.as_tensor([int(label)], dtype=torch.long, device=device)

    output = model(batched)
    loss = nn.CrossEntropyLoss()(output, target)
    model.zero_grad()
    loss.backward()

    saliency, _ = torch.max(batched.grad.detach().abs(), dim=1)
    # Drop only the batch dimension so size-1 spatial dimensions survive.
    saliency = saliency.squeeze(0).cpu().numpy()

    return saliency


# Take the first test sample and compute its ViT saliency map.
sample_batch = next(iter(test_loader))
sample_image = sample_batch['image'][0]
sample_label = sample_batch['emotion_label_idx'][0]
saliency = generate_saliency_map_vit(trained_vit_model, sample_image, sample_label, device)
print(idx2label,sample_label)

# The loader output is mean/std-normalized, so its values fall outside the
# [0, 1] range imshow expects. Undo the normalization applied by the dataset
# transforms before displaying the picture.
vit_display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
vit_display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
vit_display_image = (sample_image.cpu() * vit_display_std + vit_display_mean).clamp(0, 1)

# Plotting
plt.figure(figsize=(8, 8))
plt.imshow(vit_display_image.permute(1, 2, 0).numpy())
plt.imshow(saliency, cmap='hot', alpha=0.5)
plt.title(f"Saliency Map - True Label: {idx2label[sample_label.item()]}")
plt.axis('off')
plt.show()

In [None]:
def generate_saliency_map_clip(model, image, label, device):
    """
    Generates a saliency map for the given image using CLIP.

    The map is the absolute gradient of the cross-entropy loss with respect
    to the input pixels, reduced with a maximum over the channel dimension.

    Args:
        model (nn.Module): Trained CLIP model.
        image (Tensor): Input image tensor without a batch dimension
            (channels first).
        label (int): True label; a zero-dimensional integer tensor is also
            accepted.
        device: Device.

    Returns:
        saliency: Saliency map as a NumPy array with the image's spatial shape.
    """
    model.eval()
    # Detach first so the batched copy is always a leaf tensor; a non-leaf
    # input would leave .grad unset after backward().
    batched = image.detach().unsqueeze(0).to(device).requires_grad_(True)
    target = torch.as_tensor([int(label)], dtype=torch.long, device=device)

    output = model(batched)
    loss = nn.CrossEntropyLoss()(output, target)
    model.zero_grad()
    loss.backward()

    saliency, _ = torch.max(batched.grad.detach().abs(), dim=1)
    # Drop only the batch dimension so size-1 spatial dimensions survive.
    saliency = saliency.squeeze(0).cpu().numpy()

    return saliency

saliency_clip = generate_saliency_map_clip(trained_clip_model, sample_image, sample_label, device)

# The loader output is mean/std-normalized, so its values fall outside the
# [0, 1] range imshow expects. Undo the normalization applied by the dataset
# transforms before displaying the picture.
clip_display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
clip_display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
clip_display_image = (sample_image.cpu() * clip_display_std + clip_display_mean).clamp(0, 1)

# Plotting
plt.figure(figsize=(8, 8))
plt.imshow(clip_display_image.permute(1, 2, 0).numpy())
plt.imshow(saliency_clip, cmap='hot', alpha=0.5)
plt.title(f"Saliency Map (CLIP) - True Label: {idx2label[sample_label.item()]}")
plt.axis('off')
plt.show()