In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

In [8]:
class EmotionDataset(Dataset):
    def __init__(self, json_path, split="train"):
        with open(json_path, "r") as f:
            data = json.load(f)
        self.samples = []
        for item in data.get(split, []):
            face = item.get("face_embedding")
            pose = item.get("pose_embedding")
            label = item.get("multi_hot")
            if face is None or pose is None or label is None:
                continue
            self.samples.append((face, pose, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        face, pose, label = self.samples[idx]
        return (
            torch.tensor(face, dtype=torch.float32),
            torch.tensor(pose, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32),
        )

In [9]:
def compute_metrics(y_true, y_pred, threshold=0.5):
    y_pred_bin = (y_pred > threshold).astype(int)
    return {
        "accuracy":  accuracy_score(y_true, y_pred_bin),
        "precision": precision_score(y_true, y_pred_bin, average='micro', zero_division=0),
        "recall":    recall_score(y_true, y_pred_bin, average='micro', zero_division=0),
        "f1":        f1_score(y_true, y_pred_bin, average='micro', zero_division=0),
    }

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    all_labels, all_preds = [], []
    for face, pose, label in tqdm(dataloader, desc="Training", leave=False):
        face, pose, label = face.to(device), pose.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(face, pose)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        all_labels.append(label.cpu().numpy())
        all_preds.append(output.detach().cpu().numpy())
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    return total_loss / len(dataloader), compute_metrics(y_true, y_pred)

def evaluate(model, dataloader, criterion, device, mode="Validation"):
    model.eval()
    total_loss = 0.0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for face, pose, label in tqdm(dataloader, desc=mode, leave=False):
            face, pose, label = face.to(device), pose.to(device), label.to(device)
            output = model(face, pose)
            total_loss += criterion(output, label).item()
            all_labels.append(label.cpu().numpy())
            all_preds.append(output.cpu().numpy())
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    return total_loss / len(dataloader), compute_metrics(y_true, y_pred)

def train_model(model, train_loader, val_loader, optimizer, criterion, device,
                epochs=20, save_path="best_vit.pt"):
    best_val_loss = float('inf')
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device, mode="Validation")
        print(f"Train Loss: {train_loss:.4f} | "
              f"Acc: {train_metrics['accuracy']:.4f} | "
              f"Prec: {train_metrics['precision']:.4f} | "
              f"Rec: {train_metrics['recall']:.4f} | "
              f"F1: {train_metrics['f1']:.4f}")
        print(f" Val  Loss: {val_loss:.4f} | "
              f"Acc: {val_metrics['accuracy']:.4f} | "
              f"Prec: {val_metrics['precision']:.4f} | "
              f"Rec: {val_metrics['recall']:.4f} | "
              f"F1: {val_metrics['f1']:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("→ Best model saved.")

def test_model(model, test_loader, criterion, device, save_path="best_vit.pt"):
    model.load_state_dict(torch.load(save_path))
    model.eval()

    test_loss = 0.0
    y_true, y_pred = [], []

    with torch.no_grad():
        for face, pose, label in test_loader:
            face, pose, label = face.to(device), pose.to(device), label.to(device)
            output = model(face, pose)
            test_loss += criterion(output, label).item() * face.size(0)

            preds = torch.argmax(output, dim=1).cpu().numpy()
            truths = torch.argmax(label, dim=1).cpu().numpy()

            y_pred.extend(preds)
            y_true.extend(truths)

    test_loss /= len(test_loader.dataset)

    acc   = accuracy_score(y_true, y_pred)
    prec  = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    rec   = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1    = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    print(f"\nTest Loss: {test_loss:.4f} | "
          f"Acc: {acc:.4f} | "
          f"Prec: {prec:.4f} | "
          f"Rec: {rec:.4f} | "
          f"F1: {f1:.4f}")

In [10]:
class ChunkedMultiStageViT(nn.Module):
    def __init__(
        self,
        face_dim: int = 512,
        face_chunks: int = 8,
        pose_dim: int = 34,
        pose_chunks: int = 2,
        hidden_dim: int = 256,
        num_classes: int = 7,
        n_heads: int = 4,
        face_layers: int = 2,
        pose_layers: int = 2,
        fusion_layers: int = 4,
    ):
        super().__init__()
        assert face_dim % face_chunks == 0, "face_dim must divide evenly by face_chunks"
        assert pose_dim % pose_chunks == 0, "pose_dim must divide evenly by pose_chunks"

        self.face_chunks = face_chunks
        self.pose_chunks = pose_chunks
        self.f_chunk_size = face_dim // face_chunks
        self.p_chunk_size = pose_dim // pose_chunks

        self.face_proj = nn.Linear(self.f_chunk_size, hidden_dim)
        self.pose_proj = nn.Linear(self.p_chunk_size, hidden_dim)

        face_enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, batch_first=True
        )
        self.face_enc = nn.TransformerEncoder(face_enc_layer, num_layers=face_layers)

        pose_enc_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, batch_first=True
        )
        self.pose_enc = nn.TransformerEncoder(pose_enc_layer, num_layers=pose_layers)

        total_tokens = 1 + face_chunks + pose_chunks
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.fusion_pos = nn.Parameter(torch.randn(1, total_tokens, hidden_dim))

        fusion_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, batch_first=True
        )
        self.fusion_enc = nn.TransformerEncoder(fusion_layer, num_layers=fusion_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, face, pose):
        B = face.size(0)

        f = face.view(B, self.face_chunks, self.f_chunk_size)        
        f = self.face_proj(f)                                        
        f = self.face_enc(f)                                         

        p = pose.view(B, self.pose_chunks, self.p_chunk_size)        
        p = self.pose_proj(p)                                        
        p = self.pose_enc(p)                                        

        cls = self.cls_token.expand(B, -1, -1)                       
        seq = torch.cat([cls, f, p], dim=1)                          
        seq = seq + self.fusion_pos                                 

        fused = self.fusion_enc(seq)                                 
        cls_out = fused[:, 0]                                        
        return self.mlp_head(cls_out)

In [11]:
json_path = "final_annotations.json"
batch_size = 32
epochs     = 20
save_path  = "best_vit_complex.pt"
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_ds = EmotionDataset(json_path, split="train")
val_ds   = EmotionDataset(json_path, split="val")
test_ds  = EmotionDataset(json_path, split="test")

train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_ld   = DataLoader(val_ds,   batch_size=batch_size)
test_ld  = DataLoader(test_ds,  batch_size=batch_size)

model     = ChunkedMultiStageViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCELoss()

In [12]:
train_model(model, train_ld, val_ld, optimizer, criterion, device, epochs, save_path)
test_model(model, test_ld, criterion, device, save_path)


Epoch 1/20


                                                             

Train Loss: 0.2790 | Acc: 0.5649 | Prec: 0.8094 | Rec: 0.6198 | F1: 0.7020
 Val  Loss: 0.2632 | Acc: 0.5638 | Prec: 0.8200 | Rec: 0.6131 | F1: 0.7016
→ Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2622 | Acc: 0.5625 | Prec: 0.8095 | Rec: 0.6268 | F1: 0.7065
 Val  Loss: 0.2577 | Acc: 0.5656 | Prec: 0.8059 | Rec: 0.6385 | F1: 0.7125
→ Best model saved.

Epoch 3/20


                                                              

Train Loss: 0.2574 | Acc: 0.5623 | Prec: 0.8101 | Rec: 0.6288 | F1: 0.7081
 Val  Loss: 0.2576 | Acc: 0.5624 | Prec: 0.7985 | Rec: 0.6460 | F1: 0.7142
→ Best model saved.

Epoch 4/20


                                                              

Train Loss: 0.2536 | Acc: 0.5607 | Prec: 0.8076 | Rec: 0.6334 | F1: 0.7099
 Val  Loss: 0.2565 | Acc: 0.5636 | Prec: 0.8245 | Rec: 0.6112 | F1: 0.7020
→ Best model saved.

Epoch 5/20


                                                              

Train Loss: 0.2496 | Acc: 0.5619 | Prec: 0.8075 | Rec: 0.6386 | F1: 0.7132
 Val  Loss: 0.2568 | Acc: 0.5624 | Prec: 0.7943 | Rec: 0.6477 | F1: 0.7136

Epoch 6/20


                                                              

Train Loss: 0.2454 | Acc: 0.5649 | Prec: 0.8079 | Rec: 0.6458 | F1: 0.7178
 Val  Loss: 0.2573 | Acc: 0.5567 | Prec: 0.7899 | Rec: 0.6428 | F1: 0.7088

Epoch 7/20


                                                             

Train Loss: 0.2401 | Acc: 0.5661 | Prec: 0.8087 | Rec: 0.6520 | F1: 0.7220
 Val  Loss: 0.2579 | Acc: 0.5558 | Prec: 0.7853 | Rec: 0.6523 | F1: 0.7126

Epoch 8/20


                                                             

Train Loss: 0.2342 | Acc: 0.5641 | Prec: 0.8061 | Rec: 0.6597 | F1: 0.7256
 Val  Loss: 0.2668 | Acc: 0.5498 | Prec: 0.7661 | Rec: 0.6465 | F1: 0.7012

Epoch 9/20


                                                              

Train Loss: 0.2267 | Acc: 0.5700 | Prec: 0.8101 | Rec: 0.6734 | F1: 0.7354
 Val  Loss: 0.2691 | Acc: 0.5438 | Prec: 0.7740 | Rec: 0.6460 | F1: 0.7043

Epoch 10/20


                                                             

Train Loss: 0.2194 | Acc: 0.5739 | Prec: 0.8133 | Rec: 0.6865 | F1: 0.7446
 Val  Loss: 0.2697 | Acc: 0.5444 | Prec: 0.7773 | Rec: 0.6419 | F1: 0.7032

Epoch 11/20


                                                              

Train Loss: 0.2100 | Acc: 0.5825 | Prec: 0.8195 | Rec: 0.7009 | F1: 0.7556
 Val  Loss: 0.2907 | Acc: 0.5213 | Prec: 0.7486 | Rec: 0.6551 | F1: 0.6987

Epoch 12/20


                                                             

Train Loss: 0.2005 | Acc: 0.5948 | Prec: 0.8277 | Rec: 0.7174 | F1: 0.7686
 Val  Loss: 0.2887 | Acc: 0.5104 | Prec: 0.7489 | Rec: 0.6432 | F1: 0.6920

Epoch 13/20


                                                             

Train Loss: 0.1911 | Acc: 0.6056 | Prec: 0.8340 | Rec: 0.7329 | F1: 0.7802
 Val  Loss: 0.3187 | Acc: 0.5181 | Prec: 0.7532 | Rec: 0.6262 | F1: 0.6839

Epoch 14/20


                                                             

Train Loss: 0.1823 | Acc: 0.6185 | Prec: 0.8428 | Rec: 0.7474 | F1: 0.7922
 Val  Loss: 0.2959 | Acc: 0.5104 | Prec: 0.7432 | Rec: 0.6503 | F1: 0.6937

Epoch 15/20


                                                             

Train Loss: 0.1715 | Acc: 0.6357 | Prec: 0.8488 | Rec: 0.7674 | F1: 0.8061
 Val  Loss: 0.3135 | Acc: 0.4999 | Prec: 0.7388 | Rec: 0.6157 | F1: 0.6717

Epoch 16/20


                                                             

Train Loss: 0.1622 | Acc: 0.6544 | Prec: 0.8597 | Rec: 0.7834 | F1: 0.8198
 Val  Loss: 0.3362 | Acc: 0.4590 | Prec: 0.6907 | Rec: 0.6290 | F1: 0.6584

Epoch 17/20


                                                             

Train Loss: 0.1526 | Acc: 0.6699 | Prec: 0.8689 | Rec: 0.7984 | F1: 0.8322
 Val  Loss: 0.3483 | Acc: 0.4910 | Prec: 0.7148 | Rec: 0.6480 | F1: 0.6798

Epoch 18/20


                                                             

Train Loss: 0.1446 | Acc: 0.6846 | Prec: 0.8747 | Rec: 0.8072 | F1: 0.8396
 Val  Loss: 0.3672 | Acc: 0.4496 | Prec: 0.6843 | Rec: 0.6611 | F1: 0.6725

Epoch 19/20


                                                             

Train Loss: 0.1346 | Acc: 0.7046 | Prec: 0.8843 | Rec: 0.8258 | F1: 0.8541
 Val  Loss: 0.3820 | Acc: 0.4353 | Prec: 0.6638 | Rec: 0.6432 | F1: 0.6533

Epoch 20/20


                                                             

Train Loss: 0.1259 | Acc: 0.7230 | Prec: 0.8931 | Rec: 0.8376 | F1: 0.8645
 Val  Loss: 0.3871 | Acc: 0.4819 | Prec: 0.7124 | Rec: 0.6419 | F1: 0.6753

Test Loss: 0.2500 | Acc: 0.8182 | Prec: 0.7178 | Rec: 0.8182 | F1: 0.7427


# Fine tuned model

In [13]:
class ChunkedMultiStageViT_fine(nn.Module):
    def __init__(
        self,
        face_dim: int = 512,
        face_chunks: int = 8,
        pose_dim: int = 34,
        pose_chunks: int = 2,
        hidden_dim: int = 256,
        num_classes: int = 7,
        n_heads: int = 4,
        face_layers: int = 2,
        pose_layers: int = 2,
        fusion_layers: int = 4,
        dropout: float = 0.2,           
    ):
        super().__init__()
        assert face_dim % face_chunks == 0, "face_dim must divide evenly by face_chunks"
        assert pose_dim % pose_chunks == 0, "pose_dim must divide evenly by pose_chunks"

        self.face_chunks = face_chunks
        self.pose_chunks = pose_chunks
        self.f_chunk_size = face_dim // face_chunks
        self.p_chunk_size = pose_dim // pose_chunks

        # projections + per-modality encoders
        self.face_proj = nn.Linear(self.f_chunk_size, hidden_dim)
        self.pose_proj = nn.Linear(self.p_chunk_size, hidden_dim)
        fe = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True)
        pe = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True)
        self.face_enc = nn.TransformerEncoder(fe, num_layers=face_layers)
        self.pose_enc = nn.TransformerEncoder(pe, num_layers=pose_layers)

        # fusion
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.fusion_pos = nn.Parameter(torch.randn(1, 1 + face_chunks + pose_chunks, hidden_dim))
        fusion_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True)
        self.fusion_enc = nn.TransformerEncoder(fusion_layer, num_layers=fusion_layers)

        # a bit more dropout before head
        self.dropout = nn.Dropout(dropout)

        # MLP head (already had dropout, but we’ve harmonized rates)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, face, pose):
        B = face.size(0)

        # chunk + encode face
        f = face.view(B, self.face_chunks, self.f_chunk_size)
        f = self.face_proj(f)
        f = self.dropout(f)
        f = self.face_enc(f)

        # chunk + encode pose
        p = pose.view(B, self.pose_chunks, self.p_chunk_size)
        p = self.pose_proj(p)
        p = self.dropout(p)
        p = self.pose_enc(p)

        # prep cls token + fuse
        cls = self.cls_token.expand(B, -1, -1)
        seq = torch.cat([cls, f, p], dim=1) + self.fusion_pos
        fused = self.fusion_enc(seq)

        # take cls, regularize, then head
        cls_out = fused[:, 0]
        cls_out = self.dropout(cls_out)
        return self.mlp_head(cls_out)


In [14]:
class EarlyStopping:
    """
    Stops training if validation loss doesn’t improve after a given patience.
    """
    def __init__(self, patience: int = 3, min_delta: float = 0.0, verbose: bool = False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [15]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs: int = 20,
    save_path: str = "best_vit_chunked_fine_tuned.pt",
    patience: int = 3
):
    best_val_loss = float('inf')
    early_stopper = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_loss, train_metrics = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_metrics = evaluate(
            model, val_loader, criterion, device, mode="Validation"
        )

        print(
            f"Train Loss: {train_loss:.4f} | "
            f"Acc: {train_metrics['accuracy']:.4f} | "
            f"Prec: {train_metrics['precision']:.4f} | "
            f"Rec: {train_metrics['recall']:.4f} | "
            f"F1: {train_metrics['f1']:.4f}"
        )
        print(
            f" Val  Loss: {val_loss:.4f} | "
            f"Acc: {val_metrics['accuracy']:.4f} | "
            f"Prec: {val_metrics['precision']:.4f} | "
            f"Rec: {val_metrics['recall']:.4f} | "
            f"F1: {val_metrics['f1']:.4f}"
        )

        # save best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print("→ Best model saved.")

        # check early stopping
        early_stopper(val_loss)
        if early_stopper.early_stop:
            print("!! Early stopping triggered -- exiting training loop.")
            break

In [18]:
train_model(model, train_ld, val_ld, optimizer, criterion, device,
            epochs=20, save_path="best_vit_chunked_fine_tuned.pt", patience=4)
test_model(model, test_ld, criterion, device, save_path="best_vit_chunked_fine_tuned.pt")


Epoch 1/20


                                                              

Train Loss: 0.2222 | Acc: 0.5747 | Prec: 0.8137 | Rec: 0.6786 | F1: 0.7400
 Val  Loss: 0.2706 | Acc: 0.5501 | Prec: 0.7774 | Rec: 0.6594 | F1: 0.7135
→ Best model saved.

Epoch 2/20


                                                              

Train Loss: 0.2137 | Acc: 0.5817 | Prec: 0.8196 | Rec: 0.6941 | F1: 0.7517
 Val  Loss: 0.2779 | Acc: 0.5219 | Prec: 0.7488 | Rec: 0.6346 | F1: 0.6870
EarlyStopping counter: 1 out of 4

Epoch 3/20


                                                             

Train Loss: 0.2062 | Acc: 0.5900 | Prec: 0.8252 | Rec: 0.7042 | F1: 0.7599
 Val  Loss: 0.2923 | Acc: 0.5021 | Prec: 0.7320 | Rec: 0.6351 | F1: 0.6801
EarlyStopping counter: 2 out of 4

Epoch 4/20


                                                             

Train Loss: 0.1977 | Acc: 0.5989 | Prec: 0.8313 | Rec: 0.7161 | F1: 0.7694
 Val  Loss: 0.2896 | Acc: 0.5281 | Prec: 0.7623 | Rec: 0.6254 | F1: 0.6871
EarlyStopping counter: 3 out of 4

Epoch 5/20


                                                             

Train Loss: 0.1884 | Acc: 0.6095 | Prec: 0.8376 | Rec: 0.7336 | F1: 0.7822
 Val  Loss: 0.2885 | Acc: 0.5184 | Prec: 0.7398 | Rec: 0.6617 | F1: 0.6986
EarlyStopping counter: 4 out of 4
!! Early stopping triggered -- exiting training loop.

Test Loss: 0.2603 | Acc: 0.8111 | Prec: 0.7212 | Rec: 0.8111 | F1: 0.7489
