In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
class EmotionDataset(Dataset):
    """Dataset of (face embedding, pose embedding, multi-hot label) triples.

    The annotation file is a JSON object mapping a split name to a list of
    items. For each item the keys ``"face_embedding"``, ``"pose_embedding"``
    and ``"multi_hot"`` are read; an item is skipped when any of the three is
    missing or ``null``.

    Args:
        json_path: Path to the JSON annotation file.
        split: Key of the split to load. A split that is absent from the file
            produces an empty dataset.
    """

    def __init__(self, json_path, split="train"):
        # Read raw bytes so that ``json`` detects the encoding (UTF-8/16/32)
        # itself instead of relying on the platform's default text encoding.
        with open(json_path, "rb") as f:
            data = json.load(f)
        self.samples = []
        for item in data.get(split, []):
            face = item.get("face_embedding")
            pose = item.get("pose_embedding")
            label = item.get("multi_hot")
            # Incomplete annotations are dropped rather than raising.
            if face is None or pose is None or label is None:
                continue
            self.samples.append((face, pose, label))

    def __len__(self):
        """Return the number of complete samples kept for this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as three ``float32`` tensors (face, pose, label)."""
        face, pose, label = self.samples[idx]
        return (
            torch.tensor(face, dtype=torch.float32),
            torch.tensor(pose, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32),
        )

In [3]:
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Binarise scores and compute accuracy plus micro-averaged P/R/F1.

    Args:
        y_true: Ground-truth labels, passed unchanged to scikit-learn.
        y_pred: Array of scores; entries strictly greater than ``threshold``
            become 1, all others 0.
        threshold: Decision threshold applied to ``y_pred``.

    Returns:
        dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``.

    Note:
        Assuming ``y_true`` is a multi-label indicator array (TODO confirm
        against the caller), scikit-learn's ``accuracy_score`` is exact-match
        (subset) accuracy, which is stricter than per-label accuracy.
    """
    hard_preds = (y_pred > threshold).astype(int)
    # Shared keyword arguments for the three micro-averaged scores.
    micro_kwargs = {"average": "micro", "zero_division": 0}
    return {
        "accuracy": accuracy_score(y_true, hard_preds),
        "precision": precision_score(y_true, hard_preds, **micro_kwargs),
        "recall": recall_score(y_true, hard_preds, **micro_kwargs),
        "f1": f1_score(y_true, hard_preds, **micro_kwargs),
    }

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(face, pose)``.
        dataloader: Iterable yielding ``(face, pose, label)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        tuple: ``(mean_batch_loss, metrics)``. ``mean_batch_loss`` is the sum
        of the per-batch ``loss.item()`` values divided by the number of
        batches seen; ``metrics`` is the dict returned by ``compute_metrics``
        for the labels and outputs collected during the pass. The outputs are
        recorded in train mode, while the weights are still being updated.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    all_labels, all_preds = [], []
    for face, pose, label in tqdm(dataloader, desc="Training", leave=False):
        face, pose, label = face.to(device), pose.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(face, pose)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        all_labels.append(label.cpu().numpy())
        # detach() so the stored predictions do not keep the autograd graph alive.
        all_preds.append(output.detach().cpu().numpy())
    # Count the batches actually processed: this avoids an opaque np.vstack
    # failure on an empty loader and does not require len(dataloader).
    num_batches = len(all_labels)
    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    return total_loss / num_batches, compute_metrics(y_true, y_pred)

def evaluate(model, dataloader, criterion, device, mode="Validation"):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Module called as ``model(face, pose)``; switched to eval mode.
        dataloader: Iterable yielding ``(face, pose, label)`` batches.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device every batch is moved to before the forward pass.
        mode: Text shown as the progress-bar description.

    Returns:
        tuple: ``(mean_batch_loss, metrics)``. ``mean_batch_loss`` is the sum
        of the per-batch ``loss.item()`` values divided by the number of
        batches seen; ``metrics`` is the dict returned by ``compute_metrics``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.

    Note:
        The model is left in eval mode when this function returns.
    """
    model.eval()
    total_loss = 0.0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for face, pose, label in tqdm(dataloader, desc=mode, leave=False):
            face, pose, label = face.to(device), pose.to(device), label.to(device)
            output = model(face, pose)
            loss = criterion(output, label)
            total_loss += loss.item()
            all_labels.append(label.cpu().numpy())
            all_preds.append(output.cpu().numpy())
    # Count the batches actually processed: this avoids an opaque np.vstack
    # failure on an empty loader and does not require len(dataloader).
    num_batches = len(all_labels)
    if num_batches == 0:
        raise ValueError(f"evaluate ({mode}): dataloader yielded no batches")
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    return total_loss / num_batches, compute_metrics(y_true, y_pred)

In [4]:
def train_model(model, train_loader, val_loader, optimizer, criterion, device,
                epochs=20, save_path="best_vit_crossattn.pt"):
    """Train for a fixed number of epochs and checkpoint the best model.

    Every epoch runs ``train_one_epoch`` followed by ``evaluate`` on the
    validation loader and prints one summary line for each. Whenever the
    validation loss is strictly lower than the best seen so far, the model's
    ``state_dict`` is written to ``save_path``.

    Args:
        model: Model to optimise.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate``.
        optimizer: Optimizer passed to ``train_one_epoch``.
        criterion: Loss passed to both helpers.
        device: Device passed to both helpers.
        epochs: Number of epochs to run.
        save_path: Destination file for the best ``state_dict``.
    """
    best_val_loss = float('inf')
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_metrics = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_metrics = evaluate(
            model, val_loader, criterion, device, mode="Validation"
        )

        # Build both report lines first, then emit them.
        train_report = (
            f"Train Loss: {train_loss:.4f} | "
            f"Acc: {train_metrics['accuracy']:.4f} | "
            f"Prec: {train_metrics['precision']:.4f} | "
            f"Rec: {train_metrics['recall']:.4f} | "
            f"F1: {train_metrics['f1']:.4f}"
        )
        val_report = (
            f" Val  Loss: {val_loss:.4f} | "
            f"Acc: {val_metrics['accuracy']:.4f} | "
            f"Prec: {val_metrics['precision']:.4f} | "
            f"Rec: {val_metrics['recall']:.4f} | "
            f"F1: {val_metrics['f1']:.4f}"
        )
        print(train_report)
        print(val_report)

        # Checkpoint only on a strict improvement of the validation loss.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("→ Best model saved.")

def test_model(model, test_loader, criterion, device, save_path="best_vit_crossattn.pt"):
    """Restore the checkpoint at ``save_path`` and report test-set metrics.

    Args:
        model: Model whose architecture matches the saved ``state_dict``.
        test_loader: Loader passed to ``evaluate``.
        criterion: Loss passed to ``evaluate``.
        device: Device the checkpoint tensors are mapped to and on which the
            evaluation runs.
        save_path: File containing the ``state_dict`` to load.
    """
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored on another (e.g. a CPU-only machine).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(save_path, map_location=device)
    model.load_state_dict(state_dict)
    test_loss, test_metrics = evaluate(model, test_loader, criterion, device, mode="Test")
    print(f"\nTest Loss: {test_loss:.4f} | "
          f"Acc: {test_metrics['accuracy']:.4f} | "
          f"Prec: {test_metrics['precision']:.4f} | "
          f"Rec: {test_metrics['recall']:.4f} | "
          f"F1: {test_metrics['f1']:.4f}")


In [5]:
class ChunkedCrossAttnViT(nn.Module):
    """Two-stream transformer that fuses chunked face and pose embeddings.

    Each input vector is split into equally sized chunks, and every chunk is
    linearly projected to ``hidden_dim`` to form a token. The face and pose
    token sequences are encoded by separate transformer encoders, exchange
    information through two cross-attention modules, and are then
    concatenated behind a learned CLS token, offset by a learned positional
    embedding, and passed through a fusion encoder. The CLS output feeds an
    MLP head that ends in a sigmoid, so outputs lie in [0, 1].

    Args:
        face_dim: Length of the face vector; must be divisible by ``face_chunks``.
        face_chunks: Number of face tokens.
        pose_dim: Length of the pose vector; must be divisible by ``pose_chunks``.
        pose_chunks: Number of pose tokens.
        hidden_dim: Token width used throughout the model.
        num_classes: Number of output units.
        n_heads: Attention heads in every attention module.
        face_layers: Depth of the face encoder.
        pose_layers: Depth of the pose encoder.
        fusion_layers: Depth of the fusion encoder.
    """

    def __init__(
        self,
        face_dim: int = 512,
        face_chunks: int = 8,
        pose_dim: int = 34,
        pose_chunks: int = 2,
        hidden_dim: int = 256,
        num_classes: int = 7,
        n_heads: int = 4,
        face_layers: int = 2,
        pose_layers: int = 2,
        fusion_layers: int = 4,
    ):
        super().__init__()
        assert face_dim % face_chunks == 0, "face_dim must be divisible by face_chunks"
        assert pose_dim % pose_chunks == 0, "pose_dim must be divisible by pose_chunks"
        self.face_chunks, self.pose_chunks = face_chunks, pose_chunks
        self.f_chunk_size = face_dim // face_chunks
        self.p_chunk_size = pose_dim // pose_chunks

        # Chunk -> token projections, one per modality.
        self.face_proj = nn.Linear(self.f_chunk_size, hidden_dim)
        self.pose_proj = nn.Linear(self.p_chunk_size, hidden_dim)

        # Per-modality self-attention encoders.
        self.face_enc = self._encoder_stack(hidden_dim, n_heads, face_layers)
        self.pose_enc = self._encoder_stack(hidden_dim, n_heads, pose_layers)

        # Cross-attention: face queries pose, and pose queries face.
        self.f2p_attn = nn.MultiheadAttention(hidden_dim, n_heads, batch_first=True)
        self.p2f_attn = nn.MultiheadAttention(hidden_dim, n_heads, batch_first=True)

        # One CLS token plus every face and pose token.
        sequence_len = face_chunks + pose_chunks + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, sequence_len, hidden_dim))

        self.fusion_enc = self._encoder_stack(hidden_dim, n_heads, fusion_layers)

        head_width = hidden_dim // 2
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, head_width),
            nn.GELU(),
            nn.Linear(head_width, num_classes),
            nn.Sigmoid(),
        )

    @staticmethod
    def _encoder_stack(width: int, heads: int, depth: int) -> nn.TransformerEncoder:
        """Build a batch-first ``nn.TransformerEncoder`` of ``depth`` layers."""
        layer = nn.TransformerEncoderLayer(d_model=width, nhead=heads, batch_first=True)
        return nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, face: torch.Tensor, pose: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per class for every sample in the batch.

        Args:
            face: Tensor viewable as ``(B, face_chunks, face_dim // face_chunks)``.
            pose: Tensor viewable as ``(B, pose_chunks, pose_dim // pose_chunks)``.

        Returns:
            The MLP head's output for the CLS position of each sample.
        """
        batch = face.size(0)

        face_tokens = face.view(batch, self.face_chunks, self.f_chunk_size)
        face_tokens = self.face_enc(self.face_proj(face_tokens))

        pose_tokens = pose.view(batch, self.pose_chunks, self.p_chunk_size)
        pose_tokens = self.pose_enc(self.pose_proj(pose_tokens))

        # Residual cross-attention. The pose stream attends to the face
        # tokens *after* they have been updated with pose context.
        face_ctx, _ = self.f2p_attn(query=face_tokens, key=pose_tokens, value=pose_tokens)
        face_tokens = face_tokens + face_ctx
        pose_ctx, _ = self.p2f_attn(query=pose_tokens, key=face_tokens, value=face_tokens)
        pose_tokens = pose_tokens + pose_ctx

        cls_tokens = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat([cls_tokens, face_tokens, pose_tokens], dim=1)
        sequence = sequence + self.pos_emb

        # Only the CLS position (index 0) is passed to the head.
        summary = self.fusion_enc(sequence)[:, 0]
        return self.mlp_head(summary)

In [7]:
# Run configuration for the baseline model.
json_path = "final_annotations.json"
batch_size = 32
epochs     = 20
save_path  = "best_vit_chunked_crossattn.pt"
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed       = 42

# Seed torch before the loaders and the model are created, so the shuffle
# order and the initial weights are repeatable from one run to the next.
torch.manual_seed(seed)

train_ds = EmotionDataset(json_path, split="train")
val_ds   = EmotionDataset(json_path, split="val")
test_ds  = EmotionDataset(json_path, split="test")

# Only the training loader shuffles; validation and test keep file order.
train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_ld   = DataLoader(val_ds,   batch_size=batch_size)
test_ld  = DataLoader(test_ds,  batch_size=batch_size)

model     = ChunkedCrossAttnViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The model head already ends in a sigmoid, so BCELoss receives probabilities.
criterion = nn.BCELoss()

In [8]:
# Baseline run: train with checkpointing, then score the saved checkpoint on the test split.
train_model(
    model, train_ld, val_ld, optimizer, criterion, device,
    epochs=epochs,
    save_path=save_path,
)
test_model(model, test_ld, criterion, device, save_path=save_path)


Epoch 1/20


                                                             

Train Loss: 0.2723 | Acc: 0.5674 | Prec: 0.8155 | Rec: 0.6174 | F1: 0.7028
 Val  Loss: 0.2601 | Acc: 0.5636 | Prec: 0.8207 | Rec: 0.6125 | F1: 0.7015
→ Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2583 | Acc: 0.5639 | Prec: 0.8135 | Rec: 0.6245 | F1: 0.7066
 Val  Loss: 0.2654 | Acc: 0.5638 | Prec: 0.8184 | Rec: 0.6135 | F1: 0.7013

Epoch 3/20


                                                             

Train Loss: 0.2552 | Acc: 0.5640 | Prec: 0.8098 | Rec: 0.6301 | F1: 0.7087
 Val  Loss: 0.2565 | Acc: 0.5658 | Prec: 0.8187 | Rec: 0.6168 | F1: 0.7035
→ Best model saved.

Epoch 4/20


                                                             

Train Loss: 0.2510 | Acc: 0.5638 | Prec: 0.8103 | Rec: 0.6347 | F1: 0.7118
 Val  Loss: 0.2565 | Acc: 0.5430 | Prec: 0.7678 | Rec: 0.6735 | F1: 0.7176
→ Best model saved.

Epoch 5/20


                                                             

Train Loss: 0.2479 | Acc: 0.5668 | Prec: 0.8096 | Rec: 0.6425 | F1: 0.7164
 Val  Loss: 0.2547 | Acc: 0.5547 | Prec: 0.7871 | Rec: 0.6576 | F1: 0.7166
→ Best model saved.

Epoch 6/20


                                                             

Train Loss: 0.2434 | Acc: 0.5688 | Prec: 0.8094 | Rec: 0.6485 | F1: 0.7200
 Val  Loss: 0.2638 | Acc: 0.5653 | Prec: 0.8102 | Rec: 0.6299 | F1: 0.7088

Epoch 7/20


                                                             

Train Loss: 0.2388 | Acc: 0.5671 | Prec: 0.8068 | Rec: 0.6564 | F1: 0.7239
 Val  Loss: 0.2606 | Acc: 0.5258 | Prec: 0.7524 | Rec: 0.6796 | F1: 0.7141

Epoch 8/20


                                                             

Train Loss: 0.2325 | Acc: 0.5683 | Prec: 0.8081 | Rec: 0.6630 | F1: 0.7284
 Val  Loss: 0.2608 | Acc: 0.5547 | Prec: 0.7872 | Rec: 0.6548 | F1: 0.7150

Epoch 9/20


                                                             

Train Loss: 0.2257 | Acc: 0.5695 | Prec: 0.8104 | Rec: 0.6714 | F1: 0.7343
 Val  Loss: 0.2666 | Acc: 0.5458 | Prec: 0.7813 | Rec: 0.6523 | F1: 0.7110

Epoch 10/20


                                                              

Train Loss: 0.2181 | Acc: 0.5765 | Prec: 0.8151 | Rec: 0.6861 | F1: 0.7451
 Val  Loss: 0.2721 | Acc: 0.5227 | Prec: 0.7742 | Rec: 0.6230 | F1: 0.6904

Epoch 11/20


                                                             

Train Loss: 0.2093 | Acc: 0.5792 | Prec: 0.8209 | Rec: 0.6955 | F1: 0.7530
 Val  Loss: 0.2856 | Acc: 0.5473 | Prec: 0.7919 | Rec: 0.6228 | F1: 0.6972

Epoch 12/20


                                                             

Train Loss: 0.1985 | Acc: 0.5938 | Prec: 0.8311 | Rec: 0.7107 | F1: 0.7662
 Val  Loss: 0.2875 | Acc: 0.5190 | Prec: 0.7435 | Rec: 0.6632 | F1: 0.7011

Epoch 13/20


                                                             

Train Loss: 0.1882 | Acc: 0.6086 | Prec: 0.8393 | Rec: 0.7326 | F1: 0.7823
 Val  Loss: 0.2977 | Acc: 0.5104 | Prec: 0.7414 | Rec: 0.6512 | F1: 0.6934

Epoch 14/20


                                                             

Train Loss: 0.1787 | Acc: 0.6226 | Prec: 0.8437 | Rec: 0.7513 | F1: 0.7949
 Val  Loss: 0.3104 | Acc: 0.4879 | Prec: 0.7196 | Rec: 0.6398 | F1: 0.6774

Epoch 15/20


                                                             

Train Loss: 0.1682 | Acc: 0.6397 | Prec: 0.8551 | Rec: 0.7684 | F1: 0.8094
 Val  Loss: 0.3129 | Acc: 0.4981 | Prec: 0.7365 | Rec: 0.6292 | F1: 0.6787

Epoch 16/20


                                                             

Train Loss: 0.1579 | Acc: 0.6615 | Prec: 0.8674 | Rec: 0.7852 | F1: 0.8243
 Val  Loss: 0.3341 | Acc: 0.4773 | Prec: 0.7066 | Rec: 0.6262 | F1: 0.6640

Epoch 17/20


                                                             

Train Loss: 0.1457 | Acc: 0.6831 | Prec: 0.8765 | Rec: 0.8048 | F1: 0.8391
 Val  Loss: 0.3595 | Acc: 0.4879 | Prec: 0.7129 | Rec: 0.6542 | F1: 0.6823

Epoch 18/20


                                                             

Train Loss: 0.1353 | Acc: 0.7014 | Prec: 0.8852 | Rec: 0.8221 | F1: 0.8525
 Val  Loss: 0.3675 | Acc: 0.4616 | Prec: 0.7011 | Rec: 0.6153 | F1: 0.6554

Epoch 19/20


                                                             

Train Loss: 0.1254 | Acc: 0.7239 | Prec: 0.8928 | Rec: 0.8360 | F1: 0.8635
 Val  Loss: 0.3803 | Acc: 0.4442 | Prec: 0.6802 | Rec: 0.6508 | F1: 0.6651

Epoch 20/20


                                                             

Train Loss: 0.1162 | Acc: 0.7402 | Prec: 0.8983 | Rec: 0.8503 | F1: 0.8736
 Val  Loss: 0.3981 | Acc: 0.4624 | Prec: 0.6924 | Rec: 0.6396 | F1: 0.6650


                                                       


Test Loss: 0.2477 | Acc: 0.5599 | Prec: 0.8008 | Rec: 0.6671 | F1: 0.7279




# Hyperparameter tuning

In [23]:
class ChunkedCrossAttnViT_tuned(nn.Module):
    """Chunked cross-attention transformer with a configurable dropout rate.

    Each input vector is split into equally sized chunks that are projected
    to ``hidden_dim`` tokens. Face and pose tokens are encoded separately,
    exchange information through two cross-attention modules, and are fused
    behind a learned CLS token by a third encoder. ``dropout_rate`` is applied
    inside all three transformer encoders and once in the MLP head; the two
    cross-attention modules are built without it. The head ends in a sigmoid,
    so outputs lie in [0, 1].

    Args:
        face_dim: Length of the face vector; must be divisible by ``face_chunks``.
        face_chunks: Number of face tokens.
        pose_dim: Length of the pose vector; must be divisible by ``pose_chunks``.
        pose_chunks: Number of pose tokens.
        hidden_dim: Token width used throughout the model.
        num_classes: Number of output units.
        n_heads: Attention heads in every attention module.
        face_layers: Depth of the face encoder.
        pose_layers: Depth of the pose encoder.
        fusion_layers: Depth of the fusion encoder.
        dropout_rate: Dropout probability for the encoders and the head.
    """

    def __init__(
        self,
        face_dim: int = 512,
        face_chunks: int = 8,
        pose_dim: int = 34,
        pose_chunks: int = 2,
        hidden_dim: int = 256,
        num_classes: int = 7,
        n_heads: int = 4,
        face_layers: int = 2,
        pose_layers: int = 2,
        fusion_layers: int = 4,
        dropout_rate: float = 0.3,
    ):
        super().__init__()
        assert face_dim % face_chunks == 0
        assert pose_dim % pose_chunks == 0

        self.face_chunks, self.pose_chunks = face_chunks, pose_chunks
        self.f_chunk_size = face_dim // face_chunks
        self.p_chunk_size = pose_dim // pose_chunks

        # Chunk -> token projections, one per modality.
        self.face_proj = nn.Linear(self.f_chunk_size, hidden_dim)
        self.pose_proj = nn.Linear(self.p_chunk_size, hidden_dim)

        # Per-modality self-attention encoders.
        self.face_enc = self._encoder_stack(hidden_dim, n_heads, face_layers, dropout_rate)
        self.pose_enc = self._encoder_stack(hidden_dim, n_heads, pose_layers, dropout_rate)

        # Cross-attention modules; no dropout argument is passed to these.
        self.f2p_attn = nn.MultiheadAttention(hidden_dim, n_heads, batch_first=True)
        self.p2f_attn = nn.MultiheadAttention(hidden_dim, n_heads, batch_first=True)

        # One CLS token plus every face and pose token.
        sequence_len = face_chunks + pose_chunks + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_emb = nn.Parameter(torch.randn(1, sequence_len, hidden_dim))

        self.fusion_enc = self._encoder_stack(hidden_dim, n_heads, fusion_layers, dropout_rate)

        head_width = hidden_dim // 2
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, head_width),
            nn.GELU(),
            nn.Dropout(dropout_rate),  # regularise the hidden head activations
            nn.Linear(head_width, num_classes),
            nn.Sigmoid(),
        )

    @staticmethod
    def _encoder_stack(width: int, heads: int, depth: int, drop: float) -> nn.TransformerEncoder:
        """Build a batch-first ``nn.TransformerEncoder`` of ``depth`` layers."""
        layer = nn.TransformerEncoderLayer(
            d_model=width, nhead=heads, dropout=drop, batch_first=True
        )
        return nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, face: torch.Tensor, pose: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per class for every sample in the batch.

        Args:
            face: Tensor viewable as ``(B, face_chunks, face_dim // face_chunks)``.
            pose: Tensor viewable as ``(B, pose_chunks, pose_dim // pose_chunks)``.

        Returns:
            The MLP head's output for the CLS position of each sample.
        """
        batch = face.size(0)

        face_tokens = face.view(batch, self.face_chunks, self.f_chunk_size)
        face_tokens = self.face_enc(self.face_proj(face_tokens))

        pose_tokens = pose.view(batch, self.pose_chunks, self.p_chunk_size)
        pose_tokens = self.pose_enc(self.pose_proj(pose_tokens))

        # Residual cross-attention. The pose stream attends to the face
        # tokens *after* they have been updated with pose context.
        face_ctx, _ = self.f2p_attn(query=face_tokens, key=pose_tokens, value=pose_tokens)
        face_tokens = face_tokens + face_ctx
        pose_ctx, _ = self.p2f_attn(query=pose_tokens, key=face_tokens, value=face_tokens)
        pose_tokens = pose_tokens + pose_ctx

        cls_tokens = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat([cls_tokens, face_tokens, pose_tokens], dim=1)
        sequence = sequence + self.pos_emb

        # Only the CLS position (index 0) is passed to the head.
        summary = self.fusion_enc(sequence)[:, 0]
        return self.mlp_head(summary)

In [24]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. A
    value counts as an improvement only when it is strictly lower than
    ``best_loss - min_delta``. Anything else (a regression, a plateau, or a
    NaN loss) increments ``counter``. Once ``counter`` reaches ``patience``,
    ``early_stop`` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Margin by which the loss must beat ``best_loss``.
        verbose: When true, print the counter on every non-improving call.

    Attributes:
        counter: Consecutive non-improving calls since the last improvement.
        best_loss: Best loss seen so far (``None`` before the first call).
        early_stop: Set to ``True`` when training should stop.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0, verbose: bool = False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        if self.best_loss is None:
            # The first observation becomes the baseline.
            self.best_loss = val_loss
            return

        # Written as a positive "<" test so that NaN, which makes every
        # comparison false, is counted as "no improvement" rather than
        # resetting patience and overwriting best_loss.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

In [25]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs: int = 20,
    save_path: str = "best_vit_crossattn_tuned.pt",
    patience: int = 5
):
    """Train with best-checkpoint saving and early stopping.

    Every epoch runs ``train_one_epoch`` followed by ``evaluate`` on the
    validation loader and prints one summary line for each. The model's
    ``state_dict`` is written to ``save_path`` whenever the validation loss
    is strictly lower than the best seen so far. An ``EarlyStopping``
    instance (verbose, with the given ``patience``) is fed the validation
    loss after the checkpoint step and ends the loop once it fires.

    Args:
        model: Model to optimise.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate``.
        optimizer: Optimizer passed to ``train_one_epoch``.
        criterion: Loss passed to both helpers.
        device: Device passed to both helpers.
        epochs: Maximum number of epochs.
        save_path: Destination file for the best ``state_dict``.
        patience: Patience handed to ``EarlyStopping``.
    """
    stopper = EarlyStopping(patience=patience, verbose=True)
    best_val_loss = float('inf')

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        train_loss, train_metrics = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_metrics = evaluate(
            model, val_loader, criterion, device, mode="Validation"
        )

        # Build both report lines first, then emit them.
        train_report = (
            f"Train Loss: {train_loss:.4f} | "
            f"Acc: {train_metrics['accuracy']:.4f} | "
            f"Prec: {train_metrics['precision']:.4f} | "
            f"Rec: {train_metrics['recall']:.4f} | "
            f"F1: {train_metrics['f1']:.4f}"
        )
        val_report = (
            f" Val  Loss: {val_loss:.4f} | "
            f"Acc: {val_metrics['accuracy']:.4f} | "
            f"Prec: {val_metrics['precision']:.4f} | "
            f"Rec: {val_metrics['recall']:.4f} | "
            f"F1: {val_metrics['f1']:.4f}"
        )
        print(train_report)
        print(val_report)

        # Checkpoint only on a strict improvement of the validation loss.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("→ Best model saved.")

        # The stopper tracks its own best loss, independent of the checkpoint logic.
        stopper(val_loss)
        if stopper.early_stop:
            print("→ Early stopping triggered.")
            break

In [26]:
# Tuned run: dropout-regularised model trained with early stopping, reusing
# the loaders and device created for the baseline run.
model = ChunkedCrossAttnViT_tuned(dropout_rate=0.3)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

train_model(
    model, train_ld, val_ld, optimizer, criterion, device,
    epochs=20,
    save_path="best_vit_crossattn_tuned.pt",
    patience=5,
)
test_model(
    model, test_ld, criterion, device,
    save_path="best_vit_crossattn_tuned.pt",
)



Epoch 1/20


                                                              

Train Loss: 0.2844 | Acc: 0.5583 | Prec: 0.7984 | Rec: 0.6231 | F1: 0.7000
 Val  Loss: 0.2665 | Acc: 0.5624 | Prec: 0.8129 | Rec: 0.6148 | F1: 0.7001
→ Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2671 | Acc: 0.5616 | Prec: 0.8075 | Rec: 0.6258 | F1: 0.7052
 Val  Loss: 0.2631 | Acc: 0.5641 | Prec: 0.7980 | Rec: 0.6439 | F1: 0.7127
→ Best model saved.

Epoch 3/20


                                                             

Train Loss: 0.2631 | Acc: 0.5598 | Prec: 0.8087 | Rec: 0.6278 | F1: 0.7068
 Val  Loss: 0.2612 | Acc: 0.5498 | Prec: 0.7847 | Rec: 0.6553 | F1: 0.7142
→ Best model saved.

Epoch 4/20


                                                             

Train Loss: 0.2595 | Acc: 0.5636 | Prec: 0.8082 | Rec: 0.6340 | F1: 0.7106
 Val  Loss: 0.2571 | Acc: 0.5627 | Prec: 0.8023 | Rec: 0.6381 | F1: 0.7108
→ Best model saved.

Epoch 5/20


                                                             

Train Loss: 0.2586 | Acc: 0.5637 | Prec: 0.8092 | Rec: 0.6333 | F1: 0.7105
 Val  Loss: 0.2571 | Acc: 0.5636 | Prec: 0.8186 | Rec: 0.6161 | F1: 0.7031
EarlyStopping counter: 1 out of 5

Epoch 6/20


                                                             

Train Loss: 0.2558 | Acc: 0.5618 | Prec: 0.8085 | Rec: 0.6318 | F1: 0.7093
 Val  Loss: 0.2568 | Acc: 0.5621 | Prec: 0.8019 | Rec: 0.6415 | F1: 0.7128
→ Best model saved.

Epoch 7/20


                                                             

Train Loss: 0.2527 | Acc: 0.5646 | Prec: 0.8087 | Rec: 0.6360 | F1: 0.7120
 Val  Loss: 0.2599 | Acc: 0.5387 | Prec: 0.7724 | Rec: 0.6649 | F1: 0.7147
EarlyStopping counter: 1 out of 5

Epoch 8/20


                                                             

Train Loss: 0.2521 | Acc: 0.5626 | Prec: 0.8081 | Rec: 0.6358 | F1: 0.7117
 Val  Loss: 0.2596 | Acc: 0.5493 | Prec: 0.7807 | Rec: 0.6570 | F1: 0.7135
EarlyStopping counter: 2 out of 5

Epoch 9/20


                                                             

Train Loss: 0.2496 | Acc: 0.5631 | Prec: 0.8091 | Rec: 0.6365 | F1: 0.7125
 Val  Loss: 0.2604 | Acc: 0.5461 | Prec: 0.7767 | Rec: 0.6598 | F1: 0.7135
EarlyStopping counter: 3 out of 5

Epoch 10/20


                                                             

Train Loss: 0.2477 | Acc: 0.5625 | Prec: 0.8075 | Rec: 0.6399 | F1: 0.7140
 Val  Loss: 0.2589 | Acc: 0.5593 | Prec: 0.7976 | Rec: 0.6458 | F1: 0.7137
EarlyStopping counter: 4 out of 5

Epoch 11/20


                                                             

Train Loss: 0.2454 | Acc: 0.5615 | Prec: 0.8072 | Rec: 0.6416 | F1: 0.7150
 Val  Loss: 0.2614 | Acc: 0.5416 | Prec: 0.7659 | Rec: 0.6675 | F1: 0.7133
EarlyStopping counter: 5 out of 5
→ Early stopping triggered.


                                                       


Test Loss: 0.2483 | Acc: 0.5630 | Prec: 0.8124 | Rec: 0.6480 | F1: 0.7210


