In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
# Experiment configuration: every value a reader might want to tune lives here.
json_path    = "final_annotations.json"  # JSON file holding one list of samples per split
batch_size   = 32
epochs       = 20
lr           = 1e-4   # learning rate passed to AdamW
weight_decay = 1e-2   # weight decay passed to AdamW
seed         = 42     # fixed seed for repeatable runs
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG so parameter initialisation and DataLoader
# shuffling repeat across "Restart Kernel -> Run All".
# NOTE(review): some CUDA kernels may still be non-deterministic; see the
# PyTorch reproducibility notes if bit-exact GPU results are required.
torch.manual_seed(seed)

In [3]:
class EmotionDataset(Dataset):
    """Dataset of (face embedding, pose embedding, multi-hot label) triples.

    The JSON file at ``json_path`` is loaded and the list stored under the
    ``split`` key is scanned. An item is kept only when its
    ``'face_embedding'``, ``'pose_embedding'`` and ``'multi_hot'`` entries
    are all present and none of them contains a ``None`` element.

    Args:
        json_path: Path of the JSON annotation file.
        split: Top-level key of the JSON object to read (default ``'train'``).
    """

    # Item keys, in the order they are stored in each sample tuple.
    _FIELDS = ('face_embedding', 'pose_embedding', 'multi_hot')

    def __init__(self, json_path, split='train'):
        with open(json_path, 'r') as handle:
            records = json.load(handle)[split]

        self.samples = []
        for record in records:
            triple = tuple(record.get(key) for key in self._FIELDS)
            # Skip items where a whole field is missing ...
            if any(field is None for field in triple):
                continue
            # ... or where any single value inside a field is missing.
            if any(value is None for field in triple for value in field):
                continue
            self.samples.append(triple)

    def __len__(self):
        """Number of samples that survived filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a 3-tuple of float32 tensors (face, pose, label)."""
        return tuple(
            torch.tensor(field, dtype=torch.float32)
            for field in self.samples[idx]
        )

In [4]:
class FaceViT(nn.Module):
    """Transformer classifier that looks only at the face embedding.

    The face vector is linearly projected to ``hidden_dim`` and treated as a
    single token. A learned [CLS] token is prepended, learned positional
    embeddings are added, and the two-token sequence is run through a
    Transformer encoder. The encoded [CLS] token is then mapped to
    ``num_classes`` sigmoid outputs.

    Args:
        face_dim: Size of the incoming face embedding.
        hidden_dim: Transformer model width.
        num_classes: Number of output units.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, face_dim=512, hidden_dim=256, num_classes=7, n_heads=4, n_layers=4):
        super().__init__()
        self.face_proj = nn.Linear(face_dim, hidden_dim)
        # Learned [CLS] token plus one positional vector per sequence slot
        # (slot 0 = [CLS], slot 1 = face token).
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 2, hidden_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, face, pose):
        """Return per-class sigmoid outputs for a batch of face embeddings.

        ``pose`` is never read; the training and evaluation loops call every
        model as ``model(face, pose)``.
        """
        batch = face.size(0)
        tokens = torch.cat(
            [
                self.cls_token.expand(batch, -1, -1),
                self.face_proj(face).unsqueeze(1),
            ],
            dim=1,
        )
        encoded = self.transformer(tokens + self.pos_embedding)
        # Classify from the encoded [CLS] position only.
        return self.mlp_head(encoded[:, 0])

In [5]:
class PoseViT(nn.Module):
    """Transformer classifier that looks only at the pose embedding.

    The pose vector is linearly projected to ``hidden_dim`` and treated as a
    single token. A learned [CLS] token is prepended, learned positional
    embeddings are added, and the two-token sequence is run through a
    Transformer encoder. The encoded [CLS] token is then mapped to
    ``num_classes`` sigmoid outputs.

    Args:
        pose_dim: Size of the incoming pose embedding.
        hidden_dim: Transformer model width.
        num_classes: Number of output units.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, pose_dim=34, hidden_dim=256, num_classes=7, n_heads=4, n_layers=4):
        super().__init__()
        self.pose_proj = nn.Linear(pose_dim, hidden_dim)
        # Learned [CLS] token plus one positional vector per sequence slot
        # (slot 0 = [CLS], slot 1 = pose token).
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 2, hidden_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, face, pose):
        """Return per-class sigmoid outputs for a batch of pose embeddings.

        ``face`` is never read; the training and evaluation loops call every
        model as ``model(face, pose)``.
        """
        batch = pose.size(0)
        tokens = torch.cat(
            [
                self.cls_token.expand(batch, -1, -1),
                self.pose_proj(pose).unsqueeze(1),
            ],
            dim=1,
        )
        encoded = self.transformer(tokens + self.pos_embedding)
        # Classify from the encoded [CLS] position only.
        return self.mlp_head(encoded[:, 0])

In [6]:
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Binarise scores at ``threshold`` and score them against ``y_true``.

    Args:
        y_true: Ground-truth labels in a form accepted by the sklearn scorers.
        y_pred: NumPy array of predicted scores; entries strictly greater
            than ``threshold`` become 1, all others 0.
        threshold: Decision cut-off applied to ``y_pred``.

    Returns:
        Dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and
        ``"f1"``. Precision, recall and F1 are micro-averaged and report 0
        when their denominator is zero.

    NOTE(review): for 2-D multi-label indicator input, sklearn's
    ``accuracy_score`` is exact-match (subset) accuracy - confirm that is the
    intended "accuracy".
    """
    hard_preds = (y_pred > threshold).astype(int)
    micro_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    report = {"accuracy": accuracy_score(y_true, hard_preds)}
    report.update(
        (name, scorer(y_true, hard_preds, average='micro', zero_division=0))
        for name, scorer in micro_scorers
    )
    return report

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(face, pose)``.
        loader: Iterable yielding ``(face, pose, label)`` tensor batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, metrics)``: the sum of per-batch losses divided by the
        number of batches, and the ``compute_metrics`` dict for all outputs
        collected during the epoch.
    """
    model.train()
    running_loss = 0.0
    targets, scores = [], []

    progress = tqdm(loader, desc="Training", leave=False)
    for face, pose, label in progress:
        face = face.to(device)
        pose = pose.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(face, pose)
        batch_loss = criterion(output, label)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        targets.append(label.cpu().numpy())
        # detach() drops the autograd graph before converting to NumPy.
        scores.append(output.detach().cpu().numpy())

    y_true, y_pred = np.vstack(targets), np.vstack(scores)
    return running_loss / len(loader), compute_metrics(y_true, y_pred)

def evaluate(model, loader, criterion, device, mode="Validation"):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: Module called as ``model(face, pose)``; switched to eval mode.
        loader: Iterable yielding ``(face, pose, label)`` tensor batches.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device every batch is moved to before the forward pass.
        mode: Label shown on the progress bar.

    Returns:
        ``(mean_loss, metrics)``: the sum of per-batch losses divided by the
        number of batches, and the ``compute_metrics`` dict for all outputs.
    """
    model.eval()
    running_loss = 0.0
    targets, scores = [], []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        progress = tqdm(loader, desc=mode, leave=False)
        for face, pose, label in progress:
            face = face.to(device)
            pose = pose.to(device)
            label = label.to(device)

            output = model(face, pose)
            running_loss += criterion(output, label).item()

            targets.append(label.cpu().numpy())
            scores.append(output.cpu().numpy())

    y_true, y_pred = np.vstack(targets), np.vstack(scores)
    return running_loss / len(loader), compute_metrics(y_true, y_pred)

def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs, save_path):
    """Train for ``epochs`` epochs, checkpointing on validation-loss improvement.

    After each epoch the train and validation loss/metrics are printed. When
    the validation loss is strictly lower than every earlier epoch's, the
    model's ``state_dict`` is written to ``save_path``, overwriting the
    previous checkpoint.

    Args:
        model: Module to train.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device batches are moved to.
        epochs: Number of epochs to run.
        save_path: File path of the best-model checkpoint.
    """
    best_val_loss = float('inf')
    for epoch in range(1, epochs+1):
        print(f"\nEpoch {epoch}/{epochs}")

        tr_loss, tr_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)

        train_line = (
            f"Train Loss: {tr_loss:.4f} | Acc: {tr_metrics['accuracy']:.4f} | "
            f"Prec: {tr_metrics['precision']:.4f} | Rec: {tr_metrics['recall']:.4f} | F1: {tr_metrics['f1']:.4f}"
        )
        val_line = (
            f"Val   Loss: {val_loss:.4f} | Acc: {val_metrics['accuracy']:.4f} | "
            f"Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | F1: {val_metrics['f1']:.4f}"
        )
        print(train_line)
        print(val_line)

        # Nothing to save unless this epoch beats the best validation loss so far.
        if not (val_loss < best_val_loss):
            continue
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print("→ Best model saved.")

def test_model(model, test_loader, criterion, device, save_path):
    """Load the checkpoint at ``save_path`` into ``model`` and report test metrics.

    The model's current weights are replaced in place by the checkpoint, the
    model is evaluated on ``test_loader``, and the loss and metrics are printed.

    Args:
        model: Module whose architecture matches the saved ``state_dict``.
        test_loader: Batches of held-out data.
        criterion: Loss called as ``criterion(output, label)``.
        device: Device used for evaluation and for mapping checkpoint tensors.
        save_path: Checkpoint file written by ``train_model``.
    """
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU machine still loads on a CPU-only one.
    state_dict = torch.load(save_path, map_location=device)
    model.load_state_dict(state_dict)
    test_loss, test_metrics = evaluate(model, test_loader, criterion, device, mode="Test")
    print(f"\nTest Loss: {test_loss:.4f} | Acc: {test_metrics['accuracy']:.4f} | "
          f"Prec: {test_metrics['precision']:.4f} | Rec: {test_metrics['recall']:.4f} | F1: {test_metrics['f1']:.4f}")

In [7]:
# One dataset per split, all read from the same annotation file.
train_dataset, val_dataset, test_dataset = (
    EmotionDataset(json_path, split=split_name)
    for split_name in ("train", "val", "test")
)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(held_out, batch_size=batch_size, shuffle=False)
    for held_out in (val_dataset, test_dataset)
)

In [8]:
# Face-only model: train with AdamW + BCE, then score the best checkpoint on the test split.
save_path_f = "best_face_vit.pt"
criterion   = nn.BCELoss()
model_face  = FaceViT().to(device)
optimizer   = optim.AdamW(
    model_face.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

train_model(
    model_face,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs,
    save_path_f,
)
test_model(
    model_face,
    test_loader,
    criterion,
    device,
    save_path_f,
)


Epoch 1/20


                                                             

Train Loss: 0.2695 | Acc: 0.5670 | Prec: 0.8080 | Rec: 0.6234 | F1: 0.7038
Val   Loss: 0.2628 | Acc: 0.5616 | Prec: 0.7965 | Rec: 0.6441 | F1: 0.7122
→ Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2570 | Acc: 0.5637 | Prec: 0.8081 | Rec: 0.6321 | F1: 0.7093
Val   Loss: 0.2532 | Acc: 0.5644 | Prec: 0.8067 | Rec: 0.6346 | F1: 0.7104
→ Best model saved.

Epoch 3/20


                                                             

Train Loss: 0.2538 | Acc: 0.5650 | Prec: 0.8088 | Rec: 0.6328 | F1: 0.7101
Val   Loss: 0.2556 | Acc: 0.5667 | Prec: 0.8091 | Rec: 0.6391 | F1: 0.7142

Epoch 4/20


                                                             

Train Loss: 0.2512 | Acc: 0.5638 | Prec: 0.8066 | Rec: 0.6361 | F1: 0.7113
Val   Loss: 0.2527 | Acc: 0.5693 | Prec: 0.8200 | Rec: 0.6241 | F1: 0.7088
→ Best model saved.

Epoch 5/20


                                                             

Train Loss: 0.2495 | Acc: 0.5679 | Prec: 0.8085 | Rec: 0.6399 | F1: 0.7144
Val   Loss: 0.2547 | Acc: 0.5687 | Prec: 0.8154 | Rec: 0.6280 | F1: 0.7095

Epoch 6/20


                                                             

Train Loss: 0.2484 | Acc: 0.5665 | Prec: 0.8081 | Rec: 0.6412 | F1: 0.7150
Val   Loss: 0.2537 | Acc: 0.5513 | Prec: 0.7828 | Rec: 0.6667 | F1: 0.7201

Epoch 7/20


                                                             

Train Loss: 0.2468 | Acc: 0.5631 | Prec: 0.8047 | Rec: 0.6429 | F1: 0.7148
Val   Loss: 0.2520 | Acc: 0.5641 | Prec: 0.8111 | Rec: 0.6297 | F1: 0.7090
→ Best model saved.

Epoch 8/20


                                                             

Train Loss: 0.2458 | Acc: 0.5678 | Prec: 0.8082 | Rec: 0.6442 | F1: 0.7169
Val   Loss: 0.2523 | Acc: 0.5564 | Prec: 0.7859 | Rec: 0.6624 | F1: 0.7189

Epoch 9/20


                                                             

Train Loss: 0.2444 | Acc: 0.5663 | Prec: 0.8056 | Rec: 0.6465 | F1: 0.7173
Val   Loss: 0.2542 | Acc: 0.5644 | Prec: 0.8202 | Rec: 0.6200 | F1: 0.7062

Epoch 10/20


                                                             

Train Loss: 0.2433 | Acc: 0.5662 | Prec: 0.8065 | Rec: 0.6486 | F1: 0.7190
Val   Loss: 0.2520 | Acc: 0.5656 | Prec: 0.8254 | Rec: 0.6202 | F1: 0.7083

Epoch 11/20


                                                             

Train Loss: 0.2419 | Acc: 0.5682 | Prec: 0.8068 | Rec: 0.6510 | F1: 0.7206
Val   Loss: 0.2593 | Acc: 0.5510 | Prec: 0.8012 | Rec: 0.6155 | F1: 0.6962

Epoch 12/20


                                                             

Train Loss: 0.2409 | Acc: 0.5676 | Prec: 0.8056 | Rec: 0.6527 | F1: 0.7211
Val   Loss: 0.2531 | Acc: 0.5653 | Prec: 0.8249 | Rec: 0.6189 | F1: 0.7072

Epoch 13/20


                                                             

Train Loss: 0.2392 | Acc: 0.5672 | Prec: 0.8060 | Rec: 0.6535 | F1: 0.7218
Val   Loss: 0.2564 | Acc: 0.5544 | Prec: 0.7949 | Rec: 0.6417 | F1: 0.7101

Epoch 14/20


                                                             

Train Loss: 0.2370 | Acc: 0.5687 | Prec: 0.8076 | Rec: 0.6563 | F1: 0.7241
Val   Loss: 0.2595 | Acc: 0.5564 | Prec: 0.7923 | Rec: 0.6505 | F1: 0.7145

Epoch 15/20


                                                             

Train Loss: 0.2349 | Acc: 0.5704 | Prec: 0.8068 | Rec: 0.6609 | F1: 0.7266
Val   Loss: 0.2590 | Acc: 0.5578 | Prec: 0.7912 | Rec: 0.6460 | F1: 0.7113

Epoch 16/20


                                                             

Train Loss: 0.2326 | Acc: 0.5686 | Prec: 0.8073 | Rec: 0.6635 | F1: 0.7284
Val   Loss: 0.2610 | Acc: 0.5538 | Prec: 0.7858 | Rec: 0.6540 | F1: 0.7138

Epoch 17/20


                                                             

Train Loss: 0.2292 | Acc: 0.5677 | Prec: 0.8045 | Rec: 0.6706 | F1: 0.7315
Val   Loss: 0.2611 | Acc: 0.5581 | Prec: 0.7956 | Rec: 0.6372 | F1: 0.7077

Epoch 18/20


                                                             

Train Loss: 0.2256 | Acc: 0.5725 | Prec: 0.8082 | Rec: 0.6742 | F1: 0.7351
Val   Loss: 0.2674 | Acc: 0.5556 | Prec: 0.8028 | Rec: 0.6241 | F1: 0.7022

Epoch 19/20


                                                             

Train Loss: 0.2217 | Acc: 0.5727 | Prec: 0.8073 | Rec: 0.6812 | F1: 0.7389
Val   Loss: 0.2694 | Acc: 0.5338 | Prec: 0.7667 | Rec: 0.6630 | F1: 0.7111

Epoch 20/20


                                                             

Train Loss: 0.2157 | Acc: 0.5745 | Prec: 0.8097 | Rec: 0.6918 | F1: 0.7461
Val   Loss: 0.2707 | Acc: 0.5470 | Prec: 0.7858 | Rec: 0.6452 | F1: 0.7085


                                                       


Test Loss: 0.2470 | Acc: 0.5584 | Prec: 0.8134 | Rec: 0.6375 | F1: 0.7148




In [None]:
# Pose-only model: train with AdamW + BCE, then score the best checkpoint on the test split.
# Note: `optimizer` and `criterion` are rebound here for the pose model.
save_path_p = "best_pose_vit.pt"
criterion   = nn.BCELoss()
model_pose  = PoseViT().to(device)
optimizer   = optim.AdamW(
    model_pose.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

train_model(
    model_pose,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs,
    save_path_p,
)
test_model(
    model_pose,
    test_loader,
    criterion,
    device,
    save_path_p,
)


Epoch 1/20


                                                             

Train Loss: 0.2733 | Acc: 0.5672 | Prec: 0.8116 | Rec: 0.6166 | F1: 0.7008
Val   Loss: 0.2677 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
→ Best model saved.

Epoch 2/20


Training:  56%|█████▋    | 294/521 [00:02<00:01, 126.15it/s]