In [1]:
import sys
sys.path.append("/media/thatblueboy/Seagate/LOP")
from datasets.SAMM.in_mem_graph_dataset import SAMMAUGraphDataset
from datasets.SAMMLong.in_mem_graph_dataset import SAMMLongAUGraphDataset

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import WeightedRandomSampler
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
import numpy as np
from torch_geometric.nn import GCNConv
from train.utils import create_4_subject_independent_folds, KFold, train_one_epoch, evaluate
import json

In [2]:
# Action units (AUs) to model; this notebook works with AU12 only.
aus = [12]

# Load the SAMM-Long configuration and keep its "lips" section.
with open("/media/thatblueboy/Seagate/LOP/configs/sammlong.json", "r") as file:
    config = json.load(file)
pretrain_config = config["lips"]

# SAMM-Long dataset built with include_flipped=True.
pretrain_dataset = SAMMLongAUGraphDataset(
    frames_path=pretrain_config["frames_path"],
    labels_csv_path="/media/thatblueboy/Seagate/LOP/data/SAMMLong/macro_samm_long_cleaned.csv",
    aus=aus,
    central_landmark=pretrain_config["central_landmark"],
    landmarks=pretrain_config["ROI_index"],
    num_timesteps=pretrain_config["num_timesteps"],
    edges=pretrain_config["edges"],
    include_flipped=True,
    # Must be True so that self-loop edges can get learnable weights.
    self_loop_in_edge_index=True,
    normalizing_factor=200,
    processed_data_path="/media/thatblueboy/Seagate/LOP/pickle/samm_long_au12_with_flips.pkl",
)


Data successfully loaded from /media/thatblueboy/Seagate/LOP/pickle/samm_long_au12_with_flips.pkl


In [3]:
# Action units (AUs) to model; AU12 only.
aus = [12]

# Re-read the SAMM-Long configuration so this cell can run on its own.
with open("/media/thatblueboy/Seagate/LOP/configs/sammlong.json", "r") as file:
    config = json.load(file)
pretrain_config = config["lips"]

# Same SAMM-Long dataset as above, but built with include_flipped=False
# and backed by a separate pickle file.
pretrain_dataset_no_flips = SAMMLongAUGraphDataset(
    frames_path=pretrain_config["frames_path"],
    labels_csv_path="/media/thatblueboy/Seagate/LOP/data/SAMMLong/macro_samm_long_cleaned.csv",
    aus=aus,
    central_landmark=pretrain_config["central_landmark"],
    landmarks=pretrain_config["ROI_index"],
    num_timesteps=pretrain_config["num_timesteps"],
    edges=pretrain_config["edges"],
    include_flipped=False,
    # Must be True so that self-loop edges can get learnable weights.
    self_loop_in_edge_index=True,
    normalizing_factor=200,
    processed_data_path="/media/thatblueboy/Seagate/LOP/pickle/samm_long_au12.pkl",
)

Data successfully loaded from /media/thatblueboy/Seagate/LOP/pickle/samm_long_au12.pkl


In [4]:
# Load the SAMM configuration and keep its "lips" section.
with open("/media/thatblueboy/Seagate/LOP/configs/samm.json", "r") as file:
    config = json.load(file)
train_config = config["lips"]

# SAMM training dataset: built with include_flipped=True and noise=False.
trainset = SAMMAUGraphDataset(
    frames_path=train_config["frames_path"],
    labels_csv_path=train_config["labels_path"],
    aus=aus,
    central_landmark=train_config["central_landmark"],
    landmarks=train_config["ROI_index"],
    num_timesteps=train_config["num_timesteps"],
    edges=train_config["edges"],
    noise=False,
    include_flipped=True,
    # Must be True so that self-loop edges can get learnable weights.
    self_loop_in_edge_index=True,
    normalizing_factor=200,
    processed_data_path="/media/thatblueboy/Seagate/LOP/pickle/samm_au12_with_flips.pkl",
)

Data successfully loaded from /media/thatblueboy/Seagate/LOP/pickle/samm_au12_with_flips.pkl


In [5]:
# Re-read the SAMM configuration so this cell can run on its own.
with open("/media/thatblueboy/Seagate/LOP/configs/samm.json", "r") as file:
    config = json.load(file)
train_config = config["lips"]

# SAMM evaluation dataset: same settings as `trainset` except that it is built
# with include_flipped=False and does not pass `noise`.
testset = SAMMAUGraphDataset(
    frames_path=train_config["frames_path"],
    labels_csv_path=train_config["labels_path"],
    aus=aus,
    central_landmark=train_config["central_landmark"],
    landmarks=train_config["ROI_index"],
    num_timesteps=train_config["num_timesteps"],
    edges=train_config["edges"],
    include_flipped=False,
    # Kept True so the edge-weight shape the model expects matches the training graph.
    self_loop_in_edge_index=True,
    normalizing_factor=200,
    processed_data_path="/media/thatblueboy/Seagate/LOP/pickle/samm_au12.pkl",
)

Data successfully loaded from /media/thatblueboy/Seagate/LOP/pickle/samm_au12.pkl


Reset the labels to AU12 only. (This is required only if the datasets were built with multiple AU labels.)

In [6]:
# Restrict the labels of every dataset to AU12.
# `pretrain_dataset_no_flips` is reset as well: it is the dataset that is
# actually handed to train() for pretraining, so it must agree with the others.
pretrain_dataset.reset_label([12])
pretrain_dataset_no_flips.reset_label([12])
testset.reset_label([12])
trainset.reset_label([12])

In [7]:
# Four folds of subject IDs; each inner list is one fold.
folds = [
    [6, 12, 11, 23, 28, 34, 31, 36],
    [7, 20, 24, 15, 30, 35, 21],
    [10, 18, 13, 14, 9, 17, 25, 37],
    [22, 16, 26, 32, 19, 33],
]

In [8]:
def get_subject_stats(subjects, labels):
    """
    Count negative (0) and positive (1) labels for every subject.

    Args:
        subjects: 1-D array-like of subject identifiers, one entry per sample.
        labels: 1-D array-like of labels aligned with ``subjects``.

    Returns:
        dict mapping each unique subject (as produced by ``np.unique``) to
        ``{'0s': int, '1s': int, 'total': int}``. ``total`` is the number of
        samples labelled 0 or 1; any other label value is not counted.
    """
    # Coerce to ndarrays so plain lists/tuples work too: the boolean-mask
    # indexing below is only valid on arrays.
    subjects = np.asarray(subjects)
    labels = np.asarray(labels)

    subject_stats = {}
    for subject in np.unique(subjects):
        subject_labels = labels[subjects == subject]
        # Built-in ints print cleanly regardless of the installed NumPy version.
        num_zeros = int(np.sum(subject_labels == 0))
        num_ones = int(np.sum(subject_labels == 1))
        subject_stats[subject] = {'0s': num_zeros, '1s': num_ones, 'total': num_zeros + num_ones}

    return subject_stats

In [9]:
# Per-subject label counts on `testset`, using the first label column only.
subject_stats = get_subject_stats(
    testset.subjects,
    testset.all_au_labels[:, 0].numpy(),
)

In [10]:
# Sum the per-subject counts within each fold and report the fold totals.
for i, fold in enumerate(folds):
    fold_stats = {
        key: sum(subject_stats[subject][key] for subject in fold)
        for key in ('0s', '1s', 'total')
    }
    print(f"Fold {i+1}: Subjects {fold}, Samples: {fold_stats['total']}, Label Counts: {fold_stats}")

Fold 1: Subjects [6, 12, 11, 23, 28, 34, 31, 36], Samples: 42, Label Counts: {'0s': 33, '1s': 9, 'total': 42}
Fold 2: Subjects [7, 20, 24, 15, 30, 35, 21], Samples: 36, Label Counts: {'0s': 30, '1s': 6, 'total': 36}
Fold 3: Subjects [10, 18, 13, 14, 9, 17, 25, 37], Samples: 40, Label Counts: {'0s': 32, '1s': 8, 'total': 40}
Fold 4: Subjects [22, 16, 26, 32, 19, 33], Samples: 36, Label Counts: {'0s': 30, '1s': 6, 'total': 36}


In [11]:
subjects = testset.subjects
labels = testset.all_au_labels.numpy()

# Reuse the shared helper instead of re-implementing the counting, and count
# on the first label column only (the same column used to build
# `subject_stats`), so a label matrix with several columns cannot inflate
# the per-subject totals.
per_subject_counts = get_subject_stats(subjects, labels[:, 0])

unique_subjects = np.unique(subjects)
for subject in unique_subjects:
    counts = per_subject_counts[subject]
    num_zeros, num_ones, total_samples = counts['0s'], counts['1s'], counts['total']
    print(subject, '0s:', num_zeros, '     1s:', num_ones, '    total:', total_samples)
    

6 0s: 11      1s: 0     total: 11
7 0s: 5      1s: 5     total: 10
9 0s: 4      1s: 0     total: 4
10 0s: 2      1s: 0     total: 2
11 0s: 11      1s: 9     total: 20
12 0s: 3      1s: 0     total: 3
13 0s: 6      1s: 0     total: 6
14 0s: 3      1s: 8     total: 11
15 0s: 3      1s: 0     total: 3
16 0s: 4      1s: 1     total: 5
17 0s: 6      1s: 0     total: 6
18 0s: 4      1s: 0     total: 4
19 0s: 2      1s: 1     total: 3
20 0s: 7      1s: 1     total: 8
21 0s: 2      1s: 0     total: 2
22 0s: 3      1s: 2     total: 5
23 0s: 1      1s: 0     total: 1
24 0s: 1      1s: 0     total: 1
25 0s: 4      1s: 0     total: 4
26 0s: 13      1s: 0     total: 13
28 0s: 2      1s: 0     total: 2
30 0s: 3      1s: 0     total: 3
31 0s: 1      1s: 0     total: 1
32 0s: 5      1s: 0     total: 5
33 0s: 3      1s: 2     total: 5
34 0s: 3      1s: 0     total: 3
35 0s: 9      1s: 0     total: 9
36 0s: 1      1s: 0     total: 1
37 0s: 3      1s: 0     total: 3


Train: Fold 1, 2, 3
Test: Fold 4

Define Model

In [26]:
class STConv(nn.Module):
    '''
    Spatio-temporal graph convolution block.

    For every sample and every timestep a ``GCNConv`` is applied to the node
    features, using one shared ``edge_index`` and a learnable per-edge weight
    vector (initialised to ones). The result goes through BatchNorm + ReLU,
    a ``(kernel_size, 1)`` ``Conv2d`` + BatchNorm, an optional residual
    connection and a final ReLU.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
        num_edges: Length of the learnable edge-weight vector that is passed to
            ``GCNConv`` as ``edge_weight`` (callers in this notebook pass
            ``edge_index.size(1)``).
        kernel_size: Size of the first kernel dimension of the ``Conv2d``.
        stride: Stride of the first dimension of the ``Conv2d``; the projection
            residual uses the same stride.
        residual: If False, no skip connection is added. If True, the input is
            added unchanged when ``in_channels == out_channels`` and
            ``stride == 1``; otherwise it is projected with a 1x1 ``Conv2d``
            followed by BatchNorm.
    '''
    def __init__(self, in_channels, out_channels, num_edges, kernel_size=3, stride=1, residual=True):
        super(STConv, self).__init__()
        # Sub-modules are created in the original order so that parameter
        # initialisation and state_dict keys stay compatible with saved checkpoints.
        self.gcn = GCNConv(in_channels, out_channels)
        self.gcn_bn = nn.BatchNorm2d(out_channels)
        self.tcn = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
                              padding=((kernel_size - 1) // 2, 0), stride=(stride, 1))
        self.tcn_bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # One learnable weight per edge, starting at 1.
        self.edge_weight = nn.Parameter(torch.ones(num_edges))
        # Currently unused in forward(); kept so the attribute remains available.
        self.dropout = nn.Dropout(0.2)

        if not residual:
            # No skip connection. forward() checks for None explicitly; the
            # previous `lambda x: 0` returned an int, on which `.permute` failed.
            self.residual = None
        elif (in_channels == out_channels) and (stride == 1):
            # Parameter-free identity (adds no state_dict keys, and unlike a
            # lambda it does not prevent pickling the module).
            self.residual = nn.Identity()
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(stride, 1)),
                nn.BatchNorm2d(out_channels) # per feature
            )

    def forward(self, x, edge_index):
        '''
        Args:
            x: Tensor unpacked as [batch_size, num_timesteps, num_nodes, features].
            edge_index: Edge index shared by all samples and timesteps; forwarded to ``GCNConv``.

        Returns:
            Tensor in the same [batch, timesteps, nodes, out_channels] axis order.
        '''
        batch_size, num_timesteps, num_nodes, _ = x.shape

        # Residual branch, computed in [batch, features, nodes, timesteps] layout
        # and permuted back to the input layout.
        if self.residual is None:
            res = 0
        else:
            res = self.residual(x.permute(0, 3, 2, 1)).permute(0, 3, 2, 1)

        x_out = torch.zeros(batch_size, num_timesteps, num_nodes, self.gcn.out_channels, device=x.device)
        # Spatial graph convolution, one (sample, timestep) slice at a time.
        for b in range(batch_size):
            for t in range(num_timesteps):
                x_out[b][t] = self.gcn(x[b][t], edge_index, self.edge_weight)

        x_out = x_out.permute(0, 3, 2, 1).contiguous()  # [batch, features, num_nodes, num_timesteps]
        x_out = self.gcn_bn(x_out)  # normalises per feature/channel
        x_out = self.relu(x_out)

        # NOTE(review): in this layout dim 2 is the node axis, so the
        # (kernel_size, 1) kernel slides over nodes rather than timesteps.
        # A temporal convolution would need the time axis in dim 2. Left
        # unchanged because existing checkpoints were trained with this
        # layout; please confirm the intent.
        x_out = self.tcn(x_out)
        x_out = self.tcn_bn(x_out)
        x_out = x_out.permute(0, 3, 2, 1).contiguous()  # back to [batch, timesteps, nodes, features]

        # x_out = self.dropout(x_out)

        return self.relu(x_out + res)

class STGCN(nn.Module):
    '''
    Binary classifier built from five stacked ``STConv`` blocks.

    The blocks widen the per-node features 4 -> 8 -> 16 -> 32 -> 64. A
    ``Conv2d`` then treats the timestep axis as input channels and uses a
    ``(num_nodes, 1)`` kernel, and an MLP ending in a sigmoid produces the output.

    Args:
        in_channels: Number of input features per node.
        num_nodes: Number of graph nodes (height of the readout kernel).
        num_edges: Number of edges; forwarded to every ``STConv``.
        num_timesteps: Number of timesteps (input channels of the readout conv).
    '''
    def __init__(self, in_channels, num_nodes, num_edges, num_timesteps):
        super(STGCN, self).__init__()

        # Spatio-temporal convolution layers. Attribute names and creation
        # order are kept so that saved state_dicts continue to load.
        self.st_conv1 = STConv(in_channels, out_channels=4, num_edges=num_edges, kernel_size=3)
        self.st_conv2 = STConv(4, out_channels=8, num_edges=num_edges, kernel_size=3)
        self.st_conv3 = STConv(8, out_channels=16, num_edges=num_edges, kernel_size=3)
        self.st_conv4 = STConv(16, out_channels=32, num_edges=num_edges, kernel_size=5)
        self.st_conv5 = STConv(32, out_channels=64, num_edges=num_edges, kernel_size=5)
        # self.st_conv6 = STConv(64, out_channels=128, num_edges=num_edges, kernel_size=5)
        # self.st_conv7 = STConv(128, out_channels=256, num_edges=num_edges, kernel_size=5)

        # Readout: timesteps act as channels, the kernel spans all nodes.
        self.conv = nn.Conv2d(in_channels=num_timesteps, out_channels=1, kernel_size=(num_nodes, 1))

        # Fully connected classifier on the 64 remaining features.
        self.classifier = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()  # Binary classification
        )

    def forward(self, x, edge_index):
        '''
        Args:
            x: Tensor of shape [batch_size, num_timesteps, num_nodes, in_channels].
            edge_index: Edge index forwarded to every ``STConv``.

        Returns:
            Tensor of shape [batch_size] with sigmoid outputs in [0, 1].
        '''
        batch_size = x.size(0)

        x = self.st_conv1(x, edge_index)
        x = self.st_conv2(x, edge_index)
        x = self.st_conv3(x, edge_index)
        x = self.st_conv4(x, edge_index)
        x = self.st_conv5(x, edge_index)
        # x = self.st_conv6(x, edge_index)
        # x = self.st_conv7(x, edge_index)

        # x is [batch, timesteps, nodes, features]; Conv2d reads it as
        # [batch, channels, height, width] and yields [batch, 1, 1, features].
        x = self.conv(x)

        # Flatten to [batch, features]
        x_flat = x.view(batch_size, -1)

        # Classifier output is [batch, 1]
        y = self.classifier(x_flat)

        return y.squeeze(-1)  # [batch]

Pretrain

In [13]:
import torch
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
import os

def train(train_dataset, test_dataset, train_subjects, test_subjects, model_class, model_params,
          num_epochs=10, optimizer_class=None, optimizer_params=None, criterion=None, device=None,
          batch_size=24, balanced_sampling=False, log_dir="runs/pretrain",
          save_every_n_epochs=5, model_path=None):
    """
    Train a model on the samples of ``train_subjects`` and validate it on the
    samples of ``test_subjects`` after every epoch.

    Args:
        train_dataset: Dataset with a ``.subjects`` sequence aligned with its samples.
        test_dataset: Dataset with a ``.subjects`` sequence aligned with its samples.
        train_subjects: Subject identifiers whose samples are used for training.
        test_subjects: Subject identifiers whose samples are used for validation.
        model_class: Model class; instantiated as ``model_class(**model_params)``.
        model_params: Keyword arguments for the model constructor.
        num_epochs: Number of epochs to run.
        optimizer_class: Optimizer class (default: ``torch.optim.Adam``).
        optimizer_params: Optimizer keyword arguments (default: ``{"lr": 1e-3}``).
        criterion: Loss function (default: ``torch.nn.BCELoss()``).
        device: Device or device string (default: CUDA if available, else CPU).
        batch_size: Batch size of both data loaders.
        balanced_sampling: If True, training samples are drawn with replacement
            using inverse-class-frequency weights. Element 2 of each dataset
            item is read as the (single-valued) label.
        log_dir: TensorBoard log directory; checkpoints go to ``<log_dir>/models``.
        save_every_n_epochs: Checkpoint interval in epochs.
        model_path: Optional checkpoint to initialise the weights from. Either
            a raw state dict or a dict containing ``'model_state_dict'``. Only
            the weights are restored; the optimizer is created fresh and epoch
            numbering starts at 1 again.

    Returns:
        The trained model (on ``device``).

    Raises:
        ValueError: If no training sample belongs to ``train_subjects``.
    """
    # Set device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Defaults for loss and optimizer
    if criterion is None:
        criterion = torch.nn.BCELoss()
    if optimizer_class is None:
        optimizer_class = torch.optim.Adam
    if optimizer_params is None:
        optimizer_params = {"lr": 1e-3}

    # Select the sample indices that belong to the requested subjects
    train_indices = [i for i, subj in enumerate(train_dataset.subjects) if subj in train_subjects]
    test_indices = [i for i, subj in enumerate(test_dataset.subjects) if subj in test_subjects]
    if not train_indices:
        raise ValueError("No training samples found for the given train_subjects.")

    train_subset = Subset(train_dataset, train_indices)
    if balanced_sampling:
        # Labels of the selected training samples
        train_labels = torch.tensor([train_dataset[i][2].item() for i in train_indices], dtype=torch.long)

        # Inverse class frequency; the small constant avoids division by zero
        class_counts = torch.bincount(train_labels).float() + 1e-6
        class_weights = 1.0 / class_counts
        weights = class_weights[train_labels]

        # The sampler yields positions into `train_subset` (0 .. len(train_indices) - 1)
        sampler = WeightedRandomSampler(weights, num_samples=len(train_indices), replacement=True)
        train_loader = DataLoader(train_subset, batch_size=batch_size, sampler=sampler)
    else:
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

    test_loader = DataLoader(Subset(test_dataset, test_indices), batch_size=batch_size, shuffle=False)

    # Initialize model, optionally from a checkpoint
    model = model_class(**model_params).to(device)
    if model_path is not None:
        if os.path.exists(model_path):
            # map_location lets a checkpoint saved on another device load here.
            # torch.load unpickles the file: only load checkpoints you trust.
            checkpoint = torch.load(model_path, map_location=device)

            # Handle full checkpoint vs just state dict
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])  # full checkpoint
            else:
                model.load_state_dict(checkpoint)  # just weights

            print(f"Loaded model weights from {model_path}")
        else:
            # Previously this case was skipped silently.
            print(f"Warning: model_path {model_path} does not exist; training from randomly initialized weights.")

    # Initialize optimizer
    optimizer = optimizer_class(model.parameters(), **optimizer_params)

    # Setup for TensorBoard logging and checkpoints
    writer = SummaryWriter(log_dir)
    model_save_dir = os.path.join(log_dir, "models")
    os.makedirs(model_save_dir, exist_ok=True)

    print("Starting pretraining...")

    try:
        for epoch in range(num_epochs):
            # Train one epoch
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

            # Evaluate on the held-out subjects
            val_loss, val_acc, val_f1, val_recall, val_precision = evaluate(model, test_loader, criterion, device)

            # Log metrics to TensorBoard
            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Accuracy/Train", train_acc, epoch)
            writer.add_scalar("Loss/Validation", val_loss, epoch)
            writer.add_scalar("Accuracy/Validation", val_acc, epoch)
            writer.add_scalar("F1/Validation", val_f1, epoch)
            writer.add_scalar("Recall/Validation", val_recall, epoch)
            writer.add_scalar("Precision/Validation", val_precision, epoch)

            print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f} | "
                  f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}, F1={val_f1:.4f}, Recall={val_recall:.4f}")

            # Save model every 'n' epochs (separate name: do not overwrite the `model_path` argument)
            if (epoch + 1) % save_every_n_epochs == 0:
                checkpoint_path = os.path.join(model_save_dir, f"model_epoch_{epoch+1}.pth")
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, checkpoint_path)
                print(f"Model saved at epoch {epoch+1}.")
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

    print("Pretraining completed!")
    return model

Run pretraining

In [21]:
# Pretraining run on the SAMM-Long dataset built without flipped samples:
# folds 1-3 supply the training subjects, fold 4 the validation subjects.
# NOTE(review): the graph dimensions in model_params are read from `trainset`,
# not from the pretraining dataset; confirm both datasets share the same graph.
model = train(
    train_dataset=pretrain_dataset_no_flips,
    test_dataset=pretrain_dataset_no_flips,
    train_subjects=folds[0] + folds[2] + folds[1],
    test_subjects=folds[3],
    model_class=STGCN,
    model_params={
        "in_channels": 2,
        "num_nodes": trainset.num_nodes,
        "num_edges": trainset.edge_index.size(1),
        "num_timesteps": trainset.num_timesteps,
    },
    criterion=nn.BCELoss(),
    optimizer_class=optim.Adam,
    optimizer_params={"lr": 1e-5},
    num_epochs=2000,
    batch_size=32,
    balanced_sampling=True,
    device="cuda",
    log_dir="/media/thatblueboy/Seagate/LOP/logs/pretrain_5_layers_no_flips",
    save_every_n_epochs=50,
)

cuda
Starting pretraining...


  train_labels = torch.tensor([train_dataset[i][2].item() for i in train_indices], dtype=torch.long)


Epoch 1: Train Loss=0.6930, Train Acc=0.50 | Val Loss=0.7147, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 2: Train Loss=0.6918, Train Acc=0.51 | Val Loss=0.7106, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 3: Train Loss=0.6928, Train Acc=0.50 | Val Loss=0.7108, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 4: Train Loss=0.6891, Train Acc=0.55 | Val Loss=0.7112, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 5: Train Loss=0.6928, Train Acc=0.50 | Val Loss=0.7098, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 6: Train Loss=0.6931, Train Acc=0.49 | Val Loss=0.7055, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 7: Train Loss=0.6953, Train Acc=0.46 | Val Loss=0.7015, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 8: Train Loss=0.6916, Train Acc=0.51 | Val Loss=0.7003, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 9: Train Loss=0.6932, Train Acc=0.49 | Val Loss=0.6998, Val Acc=0.18, F1=0.3009, Recall=1.0000
Epoch 10: Train Loss=0.6919, Train Acc=0.50 | Val Loss=0.6988, Val Acc=0.19, F1=0.3036, Rec

Train

In [28]:
# Training run on SAMM: `trainset` for training (folds 1-3), `testset` for
# validation (fold 4). Weights are initialised from the epoch-1850 checkpoint
# of an earlier run; logs and new checkpoints go to a separate "_cntd" directory.
model = train(
    train_dataset=trainset,
    test_dataset=testset,
    train_subjects=folds[0] + folds[2] + folds[1],
    test_subjects=folds[3],
    model_class=STGCN,
    model_params={
        "in_channels": 2,
        "num_nodes": trainset.num_nodes,
        "num_edges": trainset.edge_index.size(1),
        "num_timesteps": trainset.num_timesteps,
    },
    criterion=nn.BCELoss(),
    optimizer_class=optim.AdamW,
    optimizer_params={"lr": 1e-6, "weight_decay": 1e-5},
    # Alternative without weight decay: optimizer_params={"lr": 1e-6}
    num_epochs=2000,
    batch_size=32,
    balanced_sampling=True,
    device="cuda",
    log_dir="/media/thatblueboy/Seagate/LOP/logs/au12_5_layers_lr_6_weight_decay_pretrain_had_no_flips_cntd",
    save_every_n_epochs=50,
    model_path="/media/thatblueboy/Seagate/LOP/logs/au12_5_layers_lr_6_weight_decay_pretrain_had_no_flips/models/model_epoch_1850.pth",
)

cuda
Loaded model weights from /media/thatblueboy/Seagate/LOP/logs/au12_5_layers_lr_6_weight_decay_pretrain_had_no_flips/models/model_epoch_1850.pth
Starting pretraining...


  train_labels = torch.tensor([train_dataset[i][2].item() for i in train_indices], dtype=torch.long)
  state_dict = torch.load(model_path)


Epoch 1: Train Loss=0.2827, Train Acc=0.90 | Val Loss=0.5366, Val Acc=0.81, F1=0.3636, Recall=0.3333
Epoch 2: Train Loss=0.2774, Train Acc=0.91 | Val Loss=0.5730, Val Acc=0.81, F1=0.3636, Recall=0.3333
Epoch 3: Train Loss=0.3239, Train Acc=0.87 | Val Loss=0.5682, Val Acc=0.83, F1=0.5000, Recall=0.5000
Epoch 4: Train Loss=0.2315, Train Acc=0.91 | Val Loss=0.5492, Val Acc=0.81, F1=0.3636, Recall=0.3333
Epoch 5: Train Loss=0.2321, Train Acc=0.91 | Val Loss=0.5203, Val Acc=0.83, F1=0.5000, Recall=0.5000
Epoch 6: Train Loss=0.2240, Train Acc=0.93 | Val Loss=0.5287, Val Acc=0.83, F1=0.5000, Recall=0.5000
Epoch 7: Train Loss=0.2945, Train Acc=0.91 | Val Loss=0.5221, Val Acc=0.83, F1=0.5000, Recall=0.5000
Epoch 8: Train Loss=0.3001, Train Acc=0.90 | Val Loss=0.5501, Val Acc=0.83, F1=0.5000, Recall=0.5000
Epoch 9: Train Loss=0.2559, Train Acc=0.90 | Val Loss=0.5380, Val Acc=0.81, F1=0.3636, Recall=0.3333
Epoch 10: Train Loss=0.3122, Train Acc=0.93 | Val Loss=0.5297, Val Acc=0.83, F1=0.5000, Rec

Evaluate on the test set

In [6]:
def evaluate_subjects(dataset, subjects, model_class, model_params,
                      model_path, criterion=None, device=None, batch_size=24):
    """
    Evaluates a model on a specific subset of subjects in a dataset.

    Args:
        dataset: PyTorch dataset with a .subjects attribute aligned with its samples.
        subjects: List of subject identifiers to evaluate on.
        model_class: Class of the model to instantiate.
        model_params: Dictionary of parameters to pass to the model constructor.
        model_path: Path to the saved model checkpoint (.pth); either a raw
            state dict or a dict containing 'model_state_dict'. If None, the
            freshly initialised (untrained) model is evaluated.
        criterion: Loss function (default: BCELoss).
        device: PyTorch device (default: auto-detect CUDA if available).
        batch_size: Batch size for evaluation.

    Returns:
        Tuple of (loss, accuracy, F1, recall, precision).

    Raises:
        ValueError: If no sample of the dataset belongs to ``subjects``.
    """
    # Local imports keep this cell usable without re-running the import cell.
    import torch
    from torch.utils.data import DataLoader, Subset

    # Set device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set loss function
    if criterion is None:
        criterion = torch.nn.BCELoss()

    # Filter dataset by given subjects
    subject_indices = [i for i, subj in enumerate(dataset.subjects) if subj in subjects]
    if not subject_indices:
        # Metrics over zero samples are meaningless; fail with a clear message.
        raise ValueError("No samples found for the specified subjects.")
    eval_loader = DataLoader(Subset(dataset, subject_indices), batch_size=batch_size, shuffle=False)

    # Initialize and load model
    model = model_class(**model_params).to(device)

    if model_path is not None:
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

    # Put BatchNorm/Dropout layers into inference mode. This is a no-op if
    # evaluate() already switches the mode itself (not visible from here).
    model.eval()

    # Evaluate the model
    print(f"Evaluating on {len(subject_indices)} samples from specified subjects...")
    val_loss, val_acc, val_f1, val_recall, val_precision = evaluate(model, eval_loader, criterion, device)

    print(f"Evaluation Results — Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, "
          f"F1: {val_f1:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}")

    return val_loss, val_acc, val_f1, val_recall, val_precision

