In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
class EmotionDataset(Dataset):
    """Dataset of (face embedding, pose embedding, multi-hot label) triples.

    The annotation file is a JSON object indexed by split name; each split is
    a list of records read through the keys ``face_embedding``,
    ``pose_embedding`` and ``multi_hot``. Records where any of the three
    fields is missing/``None``, or contains a ``None`` element, are skipped.

    Args:
        json_path: Path to the annotation JSON file.
        split: Name of the split to load (key into the top-level JSON object).

    Raises:
        KeyError: If ``split`` is not a key of the loaded JSON object.
    """

    def __init__(self, json_path, split='train'):
        # JSON is UTF-8 by specification; be explicit so loading does not
        # depend on the platform's default text encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        try:
            records = data[split]
        except KeyError:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                f"split {split!r} not found in {json_path!r}; "
                f"available splits: {sorted(data)}"
            ) from None

        self.samples = []
        for item in records:
            face = item.get('face_embedding')
            pose = item.get('pose_embedding')
            label = item.get('multi_hot')

            # Drop records with a missing field ...
            if face is None or pose is None or label is None:
                continue
            # ... or with a hole (None element) inside any of the vectors.
            if any(x is None for x in face) or any(x is None for x in pose) or any(x is None for x in label):
                continue

            self.samples.append((face, pose, label))

    def __len__(self):
        """Return the number of records that survived filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(face, pose, label)`` for record ``idx`` as float32 tensors."""
        face, pose, label = self.samples[idx]
        return (
            torch.tensor(face, dtype=torch.float32),
            torch.tensor(pose, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32)
        )

In [3]:
class SimpleViT(nn.Module):
    """Transformer encoder over a three-token sequence: [CLS, face, pose].

    The face and pose inputs are each linearly projected to ``hidden_dim``,
    prepended with a learned CLS token, offset by a learned positional
    embedding, and passed through a ``nn.TransformerEncoder``. The encoded CLS
    token is fed to a LayerNorm -> Linear -> Sigmoid head, so the output holds
    one value in [0, 1] per class.

    Args:
        face_dim: Size of the last dimension of the face input.
        pose_dim: Size of the last dimension of the pose input.
        hidden_dim: Token width used throughout the encoder.
        num_classes: Number of outputs produced by the head.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, face_dim=512, pose_dim=34, hidden_dim=256, num_classes=7, n_heads=4, n_layers=4):
        super().__init__()
        # Attribute names and creation order are kept stable: the names are the
        # state_dict keys, and the order fixes the seeded initialisation.
        self.face_proj = nn.Linear(face_dim, hidden_dim)
        self.pose_proj = nn.Linear(pose_dim, hidden_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        # One positional slot per token: cls, face, pose.
        self.pos_embedding = nn.Parameter(torch.randn(1, 3, hidden_dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, batch_first=True),
            num_layers=n_layers,
        )

        head_layers = [
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid(),
        ]
        self.mlp_head = nn.Sequential(*head_layers)

    def forward(self, face, pose):
        """Return per-class sigmoid outputs for a batch of face/pose inputs.

        ``face`` and ``pose`` are expected to be batched 2-D tensors (batch
        first) so that each becomes a single token after ``unsqueeze(1)``.
        """
        batch = face.size(0)
        tokens = torch.cat(
            [
                self.cls_token.expand(batch, -1, -1),
                self.face_proj(face).unsqueeze(1),
                self.pose_proj(pose).unsqueeze(1),
            ],
            dim=1,
        )
        encoded = self.transformer(tokens + self.pos_embedding)
        # Only the CLS position is used for classification.
        cls_state = encoded[:, 0]
        return self.mlp_head(cls_state)

In [4]:
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Binarise scores at ``threshold`` and score them against ``y_true``.

    Args:
        y_true: Ground-truth label array.
        y_pred: Array of scores; entries strictly greater than ``threshold``
            become 1, everything else 0.
        threshold: Decision cut-off applied to ``y_pred``.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
        Precision/recall/F1 are micro-averaged with ``zero_division=0``.
        NOTE(review): for multilabel indicator inputs sklearn's
        ``accuracy_score`` is exact-match (subset) accuracy - confirm that is
        the intended headline number.
    """
    predicted = (y_pred > threshold).astype(int)

    # Shared keyword arguments for the three micro-averaged scores.
    micro = dict(average='micro', zero_division=0)

    scores = {"accuracy": accuracy_score(y_true, predicted)}
    scores["precision"] = precision_score(y_true, predicted, **micro)
    scores["recall"] = recall_score(y_true, predicted, **micro)
    scores["f1"] = f1_score(y_true, predicted, **micro)
    return scores

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: move tensors to ``device``, zero gradients, forward,
    compute the loss, backpropagate and step the optimizer. Labels and model
    outputs are collected on the CPU and scored with ``compute_metrics`` at
    its default threshold.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the sum of the
        per-batch loss values divided by the number of batches.
    """
    model.train()
    running_loss = 0.0
    label_batches = []
    pred_batches = []

    progress = tqdm(dataloader, desc="Training", leave=False)
    for face, pose, label in progress:
        face = face.to(device)
        pose = pose.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(face, pose)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        label_batches.append(label.cpu().numpy())
        # detach() so the graph is not kept alive by the stored predictions.
        pred_batches.append(output.detach().cpu().numpy())

    metrics = compute_metrics(np.vstack(label_batches), np.vstack(pred_batches))
    mean_loss = running_loss / len(dataloader)
    return mean_loss, metrics

def evaluate(model, dataloader, criterion, device, mode="Validation"):
    """Score ``model`` on ``dataloader`` without updating its weights.

    The model is switched to eval mode and the loop runs under
    ``torch.no_grad()``. ``mode`` is only used as the progress-bar label.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the sum of the
        per-batch loss values divided by the number of batches and ``metrics``
        comes from ``compute_metrics`` at its default threshold.
    """
    model.eval()
    running_loss = 0.0
    label_batches = []
    pred_batches = []

    with torch.no_grad():
        progress = tqdm(dataloader, desc=mode, leave=False)
        for face, pose, label in progress:
            face = face.to(device)
            pose = pose.to(device)
            label = label.to(device)

            output = model(face, pose)
            running_loss += criterion(output, label).item()

            label_batches.append(label.cpu().numpy())
            pred_batches.append(output.cpu().numpy())

    metrics = compute_metrics(np.vstack(label_batches), np.vstack(pred_batches))
    mean_loss = running_loss / len(dataloader)
    return mean_loss, metrics

def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=20, save_path="best_vit.pt"):
    """Train for ``epochs`` epochs, checkpointing on validation-loss improvement.

    Each epoch runs ``train_one_epoch`` followed by ``evaluate`` on
    ``val_loader``, prints both sets of loss/metrics, and writes
    ``model.state_dict()`` to ``save_path`` whenever the validation loss is
    strictly lower than every previous epoch's.
    """
    lowest_val_loss = float('inf')

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Acc: {train_metrics['accuracy']:.4f} | Prec: {train_metrics['precision']:.4f} | Rec: {train_metrics['recall']:.4f} | F1: {train_metrics['f1']:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Acc: {val_metrics['accuracy']:.4f} | Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | F1: {val_metrics['f1']:.4f}")

        # Written as a negated "<" so a NaN loss is never treated as an improvement.
        if not (val_loss < lowest_val_loss):
            continue

        lowest_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print("Best model saved.")

def test_model(model, test_loader, criterion, device, save_path="best_vit.pt"):
    """Load the checkpoint at ``save_path`` into ``model`` and report test metrics.

    Args:
        model: Module whose ``state_dict`` layout matches the checkpoint.
        test_loader: Loader passed to ``evaluate``.
        criterion: Loss function passed to ``evaluate``.
        device: Device used for evaluation and as the target when loading
            the checkpoint tensors.
        save_path: Path of the ``state_dict`` file to load.
    """
    # map_location keeps loading working when the checkpoint was written on a
    # different device than the one available now (e.g. saved on GPU, tested
    # on a CPU-only machine).
    state_dict = torch.load(save_path, map_location=device)
    model.load_state_dict(state_dict)
    test_loss, test_metrics = evaluate(model, test_loader, criterion, device, mode="Test")
    print(f"\nTest Loss: {test_loss:.4f} | Acc: {test_metrics['accuracy']:.4f} | Prec: {test_metrics['precision']:.4f} | Rec: {test_metrics['recall']:.4f} | F1: {test_metrics['f1']:.4f}")

In [5]:
# --- Configuration -----------------------------------------------------------
json_path = "final_annotations.json"
batch_size = 32
epochs = 20
learning_rate = 1e-4
save_path = "best_simple_vit.pt"
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs before the model is built and the loaders are created,
# so weight initialisation and training-set shuffling repeat across runs.
torch.manual_seed(seed)
np.random.seed(seed)

# --- Data --------------------------------------------------------------------
train_dataset = EmotionDataset(json_path, split="train")
val_dataset = EmotionDataset(json_path, split="val")
test_dataset = EmotionDataset(json_path, split="test")

# Only the training loader shuffles; validation/test keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# --- Model, optimiser, loss --------------------------------------------------
model = SimpleViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# BCELoss (not BCEWithLogitsLoss) because SimpleViT's head already ends in a Sigmoid.
criterion = nn.BCELoss()

# --- Train, then evaluate the best checkpoint on the test split --------------
train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=epochs, save_path=save_path)
test_model(model, test_loader, criterion, device, save_path=save_path)


Epoch 1/20


                                                             

Train Loss: 0.2726 | Acc: 0.5664 | Prec: 0.8110 | Rec: 0.6169 | F1: 0.7007
Val   Loss: 0.2710 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2687 | Acc: 0.5679 | Prec: 0.8146 | Rec: 0.6180 | F1: 0.7028
Val   Loss: 0.2697 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
Best model saved.

Epoch 3/20


                                                             

Train Loss: 0.2678 | Acc: 0.5681 | Prec: 0.8152 | Rec: 0.6171 | F1: 0.7025
Val   Loss: 0.2675 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
Best model saved.

Epoch 4/20


                                                             

Train Loss: 0.2669 | Acc: 0.5672 | Prec: 0.8144 | Rec: 0.6174 | F1: 0.7023
Val   Loss: 0.2670 | Acc: 0.5630 | Prec: 0.8163 | Rec: 0.6127 | F1: 0.7000
Best model saved.

Epoch 5/20


                                                             

Train Loss: 0.2658 | Acc: 0.5674 | Prec: 0.8146 | Rec: 0.6174 | F1: 0.7024
Val   Loss: 0.2666 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
Best model saved.

Epoch 6/20


                                                             

Train Loss: 0.2649 | Acc: 0.5667 | Prec: 0.8148 | Rec: 0.6180 | F1: 0.7029
Val   Loss: 0.2642 | Acc: 0.5636 | Prec: 0.8176 | Rec: 0.6129 | F1: 0.7006
Best model saved.

Epoch 7/20


                                                             

Train Loss: 0.2631 | Acc: 0.5657 | Prec: 0.8149 | Rec: 0.6184 | F1: 0.7032
Val   Loss: 0.2636 | Acc: 0.5607 | Prec: 0.8204 | Rec: 0.6082 | F1: 0.6985
Best model saved.

Epoch 8/20


                                                             

Train Loss: 0.2606 | Acc: 0.5651 | Prec: 0.8152 | Rec: 0.6202 | F1: 0.7045
Val   Loss: 0.2610 | Acc: 0.5636 | Prec: 0.8165 | Rec: 0.6155 | F1: 0.7019
Best model saved.

Epoch 9/20


                                                             

Train Loss: 0.2584 | Acc: 0.5641 | Prec: 0.8121 | Rec: 0.6243 | F1: 0.7059
Val   Loss: 0.2577 | Acc: 0.5570 | Prec: 0.8028 | Rec: 0.6357 | F1: 0.7096
Best model saved.

Epoch 10/20


                                                             

Train Loss: 0.2565 | Acc: 0.5622 | Prec: 0.8120 | Rec: 0.6258 | F1: 0.7069
Val   Loss: 0.2586 | Acc: 0.5621 | Prec: 0.8255 | Rec: 0.6114 | F1: 0.7025

Epoch 11/20


                                                             

Train Loss: 0.2558 | Acc: 0.5642 | Prec: 0.8121 | Rec: 0.6274 | F1: 0.7079
Val   Loss: 0.2567 | Acc: 0.5667 | Prec: 0.8133 | Rec: 0.6297 | F1: 0.7098
Best model saved.

Epoch 12/20


                                                             

Train Loss: 0.2537 | Acc: 0.5631 | Prec: 0.8117 | Rec: 0.6282 | F1: 0.7083
Val   Loss: 0.2570 | Acc: 0.5658 | Prec: 0.8098 | Rec: 0.6353 | F1: 0.7120

Epoch 13/20


                                                             

Train Loss: 0.2528 | Acc: 0.5641 | Prec: 0.8112 | Rec: 0.6307 | F1: 0.7097
Val   Loss: 0.2554 | Acc: 0.5653 | Prec: 0.8060 | Rec: 0.6396 | F1: 0.7132
Best model saved.

Epoch 14/20


                                                             

Train Loss: 0.2517 | Acc: 0.5644 | Prec: 0.8113 | Rec: 0.6322 | F1: 0.7107
Val   Loss: 0.2575 | Acc: 0.5641 | Prec: 0.8185 | Rec: 0.6148 | F1: 0.7022

Epoch 15/20


                                                             

Train Loss: 0.2511 | Acc: 0.5652 | Prec: 0.8107 | Rec: 0.6334 | F1: 0.7111
Val   Loss: 0.2545 | Acc: 0.5613 | Prec: 0.8032 | Rec: 0.6398 | F1: 0.7122
Best model saved.

Epoch 16/20


                                                             

Train Loss: 0.2512 | Acc: 0.5646 | Prec: 0.8116 | Rec: 0.6302 | F1: 0.7094
Val   Loss: 0.2578 | Acc: 0.5453 | Prec: 0.7767 | Rec: 0.6622 | F1: 0.7149

Epoch 17/20


                                                             

Train Loss: 0.2489 | Acc: 0.5659 | Prec: 0.8128 | Rec: 0.6330 | F1: 0.7117
Val   Loss: 0.2576 | Acc: 0.5673 | Prec: 0.8167 | Rec: 0.6265 | F1: 0.7090

Epoch 18/20


                                                             

Train Loss: 0.2484 | Acc: 0.5646 | Prec: 0.8106 | Rec: 0.6367 | F1: 0.7132
Val   Loss: 0.2569 | Acc: 0.5541 | Prec: 0.7897 | Rec: 0.6495 | F1: 0.7128

Epoch 19/20


                                                             

Train Loss: 0.2471 | Acc: 0.5669 | Prec: 0.8097 | Rec: 0.6404 | F1: 0.7152
Val   Loss: 0.2550 | Acc: 0.5678 | Prec: 0.8154 | Rec: 0.6325 | F1: 0.7124

Epoch 20/20


                                                             

Train Loss: 0.2461 | Acc: 0.5659 | Prec: 0.8096 | Rec: 0.6413 | F1: 0.7157
Val   Loss: 0.2560 | Acc: 0.5658 | Prec: 0.8077 | Rec: 0.6385 | F1: 0.7132


                                                       


Test Loss: 0.2483 | Acc: 0.5635 | Prec: 0.8166 | Rec: 0.6455 | F1: 0.7210


