In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [2]:
class EmotionDataset(Dataset):
    """Dataset of (face embedding, pose embedding, multi-hot label) triples.

    The JSON document at ``json_path`` is indexed by split name; the value for
    a split is iterated as a list of records. Each record is read through the
    keys ``face_embedding``, ``pose_embedding`` and ``multi_hot``. A record is
    skipped when any of the three fields is absent / ``None`` or contains a
    ``None`` element.

    Args:
        json_path: Path of the annotation JSON file.
        split: Name of the split (top-level JSON key) to load.

    Raises:
        KeyError: If ``split`` is not present in the JSON document.
    """

    def __init__(self, json_path, split='train'):
        # Be explicit about the encoding so loading does not depend on the
        # platform's default locale.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if split not in data:
            # Same exception type as a plain data[split] lookup, but with a
            # message that tells the user which splits actually exist.
            raise KeyError(
                f"split {split!r} not found in {json_path!r}; "
                f"available splits: {sorted(data)}"
            )

        self.samples = []
        for item in data[split]:
            fields = (
                item.get('face_embedding'),
                item.get('pose_embedding'),
                item.get('multi_hot'),
            )
            if self._is_complete(fields):
                self.samples.append(fields)

    @staticmethod
    def _is_complete(fields):
        """Return True when no field is None and no field contains a None element."""
        if any(field is None for field in fields):
            return False
        return not any(value is None for field in fields for value in field)

    def __len__(self):
        """Number of records that survived filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``idx``-th (face, pose, label) triple as float32 tensors."""
        face, pose, label = self.samples[idx]
        return (
            torch.tensor(face, dtype=torch.float32),
            torch.tensor(pose, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32)
        )

In [3]:
class ChunkedViT(nn.Module):
    """Transformer encoder over chunked face and pose embeddings.

    Each input vector is split into equally sized chunks, every chunk is
    linearly projected to ``hidden_dim`` and treated as one token. A learned
    CLS token is prepended, a learned positional embedding is added, and the
    sequence is passed through a ``nn.TransformerEncoder``. The CLS output is
    fed through LayerNorm -> Linear -> Sigmoid.

    Because the head ends in ``nn.Sigmoid``, ``forward`` returns values in
    [0, 1] rather than raw logits.

    Args:
        face_dim: Length of one face embedding vector.
        face_chunks: Number of tokens the face vector is split into;
            must divide ``face_dim``.
        pose_dim: Length of one pose embedding vector.
        pose_chunks: Number of tokens the pose vector is split into;
            must divide ``pose_dim``.
        hidden_dim: Token / model width of the transformer.
        num_classes: Number of outputs produced per sample.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(
        self,
        face_dim=512,
        face_chunks=8,
        pose_dim=34,
        pose_chunks=2,
        hidden_dim=256,
        num_classes=7,
        n_heads=4,
        n_layers=4
    ):
        super().__init__()
        assert face_dim % face_chunks == 0, (
            f"face_dim ({face_dim}) must be divisible by face_chunks ({face_chunks})"
        )
        assert pose_dim % pose_chunks == 0, (
            f"pose_dim ({pose_dim}) must be divisible by pose_chunks ({pose_chunks})"
        )

        self.f_chunk_size = face_dim // face_chunks
        self.p_chunk_size = pose_dim // pose_chunks
        self.face_chunks = face_chunks
        self.pose_chunks = pose_chunks

        # One shared projection per modality: every chunk of a modality is
        # mapped to hidden_dim with the same weights.
        self.face_proj = nn.Linear(self.f_chunk_size, hidden_dim)
        self.pose_proj = nn.Linear(self.p_chunk_size, hidden_dim)

        # Sequence layout: [CLS, face tokens..., pose tokens...]
        total_tokens = 1 + face_chunks + pose_chunks
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, total_tokens, hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
            nn.Sigmoid()
        )

    def forward(self, face, pose):
        """Compute per-class scores for a batch.

        Args:
            face: Tensor whose first dimension is the batch size and whose
                remaining elements per sample number
                ``face_chunks * f_chunk_size``.
            pose: Tensor whose first dimension is the batch size and whose
                remaining elements per sample number
                ``pose_chunks * p_chunk_size``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with sigmoid-activated
            scores.
        """
        B = face.size(0)

        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted; view would raise a RuntimeError.
        face_tokens = face.reshape(B, self.face_chunks, self.f_chunk_size)
        face_tokens = self.face_proj(face_tokens)

        pose_tokens = pose.reshape(B, self.pose_chunks, self.p_chunk_size)
        pose_tokens = self.pose_proj(pose_tokens)

        # Broadcast the single learned CLS token across the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat([cls_tokens, face_tokens, pose_tokens], dim=1)
        x = x + self.pos_embedding

        x = self.transformer(x)
        cls_out = x[:, 0]  # encoder output at the CLS position
        return self.mlp_head(cls_out)

In [4]:
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Binarize scores at ``threshold`` and score them against ``y_true``.

    Scores strictly greater than ``threshold`` become 1, all others 0.
    Precision, recall and F1 are micro-averaged with ``zero_division=0``;
    accuracy is whatever ``accuracy_score`` returns for the given arrays.

    Args:
        y_true: Ground-truth label array.
        y_pred: Array of predicted scores supporting ``>`` and ``.astype``.
        threshold: Cut-off applied to ``y_pred``.

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    binarized = (y_pred > threshold).astype(int)

    # The three micro-averaged scorers share identical keyword arguments.
    micro_kwargs = {"average": 'micro', "zero_division": 0}
    micro_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )

    results = {"accuracy": accuracy_score(y_true, binarized)}
    for key, scorer in micro_scorers:
        results[key] = scorer(y_true, binarized, **micro_kwargs)
    return results

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is unpacked as ``(face, pose, label)``, moved to ``device``,
    pushed through ``model`` and used for one optimizer step.

    Args:
        model: Module called as ``model(face, pose)``.
        dataloader: Iterable of 3-tuples of tensors; must support ``len``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Callable ``criterion(output, label)`` returning a scalar
            loss tensor.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the sum of per-batch
        losses divided by the number of batches, and ``metrics`` is the dict
        produced by ``compute_metrics`` over the whole epoch.
    """
    model.train()
    running_loss = 0.0
    label_batches = []
    pred_batches = []

    progress = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress:
        face, pose, label = [tensor.to(device) for tensor in batch]

        optimizer.zero_grad()
        output = model(face, pose)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        label_batches.append(label.cpu().numpy())
        # detach: the predictions are only kept for metrics, not for autograd
        pred_batches.append(output.detach().cpu().numpy())

    epoch_metrics = compute_metrics(np.vstack(label_batches), np.vstack(pred_batches))
    return running_loss / len(dataloader), epoch_metrics

def evaluate(model, dataloader, criterion, device, mode="Validation"):
    """Score ``model`` on ``dataloader`` without updating its weights.

    The model is switched to eval mode and all forward passes run under
    ``torch.no_grad()``.

    Args:
        model: Module called as ``model(face, pose)``.
        dataloader: Iterable of ``(face, pose, label)`` tensor batches; must
            support ``len``.
        criterion: Callable ``criterion(output, label)`` returning a scalar
            loss tensor.
        device: Device the batch tensors are moved to.
        mode: Text shown as the progress-bar description.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the sum of per-batch
        losses divided by the number of batches, and ``metrics`` is the dict
        produced by ``compute_metrics`` over all batches.
    """
    model.eval()
    loss_sum = 0.0
    targets, scores = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=mode, leave=False):
            face, pose, label = (tensor.to(device) for tensor in batch)
            output = model(face, pose)

            loss_sum += criterion(output, label).item()
            targets.append(label.cpu().numpy())
            scores.append(output.cpu().numpy())

    # Stack the per-batch arrays row-wise before scoring the full pass.
    stacked_targets = np.vstack(targets)
    stacked_scores = np.vstack(scores)
    return loss_sum / len(dataloader), compute_metrics(stacked_targets, stacked_scores)

def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=20, save_path="best_vit.pt"):
    """Train for ``epochs`` epochs and checkpoint on validation-loss improvement.

    After every epoch the training and validation loss/metrics are printed.
    Whenever the validation loss is strictly lower than the best seen so far,
    ``model.state_dict()`` is written to ``save_path``.

    Args:
        model: Module to train.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable forwarded to both loops.
        device: Device forwarded to both loops.
        epochs: Number of epochs to run.
        save_path: File the best state dict is saved to.
    """
    lowest_val_loss = float('inf')

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_metrics = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Acc: {train_metrics['accuracy']:.4f} | Prec: {train_metrics['precision']:.4f} | Rec: {train_metrics['recall']:.4f} | F1: {train_metrics['f1']:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Acc: {val_metrics['accuracy']:.4f} | Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f} | F1: {val_metrics['f1']:.4f}")

        # Guard clause: skip checkpointing unless validation loss strictly
        # improved (a NaN loss compares False and is therefore skipped too).
        if not val_loss < lowest_val_loss:
            continue

        lowest_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print("Best model saved.")

def test_model(model, test_loader, criterion, device, save_path="best_vit.pt"):
    """Load the checkpoint at ``save_path`` into ``model`` and report test metrics.

    Args:
        model: Module whose architecture matches the saved state dict.
        test_loader: Loader passed to ``evaluate``.
        criterion: Loss callable forwarded to ``evaluate``.
        device: Device used for evaluation; also the device the checkpoint
            tensors are mapped to while loading.
        save_path: File containing a state dict written by ``torch.save``.
    """
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU machine can still be loaded on a CPU-only one.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(save_path, map_location=device)
    model.load_state_dict(state_dict)

    test_loss, test_metrics = evaluate(model, test_loader, criterion, device, mode="Test")
    print(f"\nTest Loss: {test_loss:.4f} | Acc: {test_metrics['accuracy']:.4f} | Prec: {test_metrics['precision']:.4f} | Rec: {test_metrics['recall']:.4f} | F1: {test_metrics['f1']:.4f}")

In [5]:
# --- Configuration ---
json_path = "final_annotations.json"
batch_size = 32
epochs = 20
save_path = "best_chunked_vit.pt"
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs before the model and loaders are created so that weight
# initialisation, dropout and shuffling repeat across Restart & Run All.
torch.manual_seed(seed)
np.random.seed(seed)

# --- Data ---
train_dataset = EmotionDataset(json_path, split="train")
val_dataset = EmotionDataset(json_path, split="val")
test_dataset = EmotionDataset(json_path, split="test")

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# --- Model, optimiser, loss ---
model = ChunkedViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# ChunkedViT's head ends in Sigmoid, so BCELoss is applied to its outputs.
criterion = nn.BCELoss()

# --- Train, then evaluate the best checkpoint on the test split ---
train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs=epochs, save_path=save_path)
test_model(model, test_loader, criterion, device, save_path=save_path)


Epoch 1/20


                                                             

Train Loss: 0.2741 | Acc: 0.5670 | Prec: 0.8105 | Rec: 0.6168 | F1: 0.7005
Val   Loss: 0.2691 | Acc: 0.5638 | Prec: 0.8147 | Rec: 0.6138 | F1: 0.7001
Best model saved.

Epoch 2/20


                                                             

Train Loss: 0.2698 | Acc: 0.5682 | Prec: 0.8153 | Rec: 0.6171 | F1: 0.7025
Val   Loss: 0.2697 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000

Epoch 3/20


                                                             

Train Loss: 0.2686 | Acc: 0.5668 | Prec: 0.8136 | Rec: 0.6181 | F1: 0.7025
Val   Loss: 0.2683 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000
Best model saved.

Epoch 4/20


                                                             

Train Loss: 0.2685 | Acc: 0.5679 | Prec: 0.8148 | Rec: 0.6174 | F1: 0.7025
Val   Loss: 0.2706 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000

Epoch 5/20


                                                             

Train Loss: 0.2678 | Acc: 0.5673 | Prec: 0.8142 | Rec: 0.6178 | F1: 0.7025
Val   Loss: 0.2708 | Acc: 0.5638 | Prec: 0.8149 | Rec: 0.6135 | F1: 0.7000

Epoch 6/20


                                                             

Train Loss: 0.2634 | Acc: 0.5647 | Prec: 0.8119 | Rec: 0.6217 | F1: 0.7042
Val   Loss: 0.2609 | Acc: 0.5707 | Prec: 0.8103 | Rec: 0.6340 | F1: 0.7114
Best model saved.

Epoch 7/20


                                                             

Train Loss: 0.2604 | Acc: 0.5646 | Prec: 0.8087 | Rec: 0.6300 | F1: 0.7082
Val   Loss: 0.2583 | Acc: 0.5616 | Prec: 0.8034 | Rec: 0.6389 | F1: 0.7118
Best model saved.

Epoch 8/20


                                                             

Train Loss: 0.2591 | Acc: 0.5652 | Prec: 0.8116 | Rec: 0.6303 | F1: 0.7096
Val   Loss: 0.2567 | Acc: 0.5633 | Prec: 0.8229 | Rec: 0.6127 | F1: 0.7024
Best model saved.

Epoch 9/20


                                                             

Train Loss: 0.2581 | Acc: 0.5647 | Prec: 0.8109 | Rec: 0.6285 | F1: 0.7082
Val   Loss: 0.2571 | Acc: 0.5490 | Prec: 0.7821 | Rec: 0.6576 | F1: 0.7145

Epoch 10/20


                                                             

Train Loss: 0.2575 | Acc: 0.5629 | Prec: 0.8086 | Rec: 0.6304 | F1: 0.7085
Val   Loss: 0.2573 | Acc: 0.5638 | Prec: 0.8219 | Rec: 0.6144 | F1: 0.7032

Epoch 11/20


                                                             

Train Loss: 0.2558 | Acc: 0.5641 | Prec: 0.8089 | Rec: 0.6329 | F1: 0.7102
Val   Loss: 0.2562 | Acc: 0.5618 | Prec: 0.8262 | Rec: 0.6095 | F1: 0.7015
Best model saved.

Epoch 12/20


                                                             

Train Loss: 0.2545 | Acc: 0.5668 | Prec: 0.8105 | Rec: 0.6337 | F1: 0.7113
Val   Loss: 0.2545 | Acc: 0.5670 | Prec: 0.8194 | Rec: 0.6237 | F1: 0.7083
Best model saved.

Epoch 13/20


                                                             

Train Loss: 0.2539 | Acc: 0.5667 | Prec: 0.8073 | Rec: 0.6393 | F1: 0.7135
Val   Loss: 0.2532 | Acc: 0.5681 | Prec: 0.8208 | Rec: 0.6206 | F1: 0.7068
Best model saved.

Epoch 14/20


                                                             

Train Loss: 0.2530 | Acc: 0.5664 | Prec: 0.8095 | Rec: 0.6357 | F1: 0.7122
Val   Loss: 0.2551 | Acc: 0.5684 | Prec: 0.8037 | Rec: 0.6437 | F1: 0.7148

Epoch 15/20


                                                             

Train Loss: 0.2526 | Acc: 0.5667 | Prec: 0.8059 | Rec: 0.6403 | F1: 0.7136
Val   Loss: 0.2520 | Acc: 0.5733 | Prec: 0.8143 | Rec: 0.6404 | F1: 0.7170
Best model saved.

Epoch 16/20


                                                             

Train Loss: 0.2513 | Acc: 0.5691 | Prec: 0.8091 | Rec: 0.6416 | F1: 0.7157
Val   Loss: 0.2553 | Acc: 0.5547 | Prec: 0.7870 | Rec: 0.6581 | F1: 0.7168

Epoch 17/20


                                                             

Train Loss: 0.2503 | Acc: 0.5672 | Prec: 0.8068 | Rec: 0.6432 | F1: 0.7157
Val   Loss: 0.2521 | Acc: 0.5707 | Prec: 0.8095 | Rec: 0.6434 | F1: 0.7170

Epoch 18/20


                                                             

Train Loss: 0.2488 | Acc: 0.5680 | Prec: 0.8079 | Rec: 0.6439 | F1: 0.7166
Val   Loss: 0.2532 | Acc: 0.5738 | Prec: 0.8098 | Rec: 0.6402 | F1: 0.7151

Epoch 19/20


                                                             

Train Loss: 0.2520 | Acc: 0.5670 | Prec: 0.8079 | Rec: 0.6380 | F1: 0.7130
Val   Loss: 0.2549 | Acc: 0.5621 | Prec: 0.8234 | Rec: 0.6155 | F1: 0.7044

Epoch 20/20


                                                             

Train Loss: 0.2478 | Acc: 0.5673 | Prec: 0.8091 | Rec: 0.6417 | F1: 0.7157
Val   Loss: 0.2560 | Acc: 0.5667 | Prec: 0.8051 | Rec: 0.6512 | F1: 0.7200


                                                       


Test Loss: 0.2474 | Acc: 0.5663 | Prec: 0.8198 | Rec: 0.6430 | F1: 0.7207
