In [None]:
import pickle

# Location of the pickled reconstruction outputs (absolute cluster path).
pkl_path = '/project/msoleyma_1026/personality_detection/first_impressions_v2_dataset/reconstructed_outputs.pkl'

# NOTE: unpickling can execute arbitrary code; only open files you trust.
with open(pkl_path, 'rb') as pkl_file:
    reconstructed_data = pickle.load(pkl_file)

# Summarise the container before looking at individual entries.
print("Top-level type:", type(reconstructed_data))
print("Number of entries:", len(reconstructed_data))

# Preview the first three entries; values exposing `.shape` are summarised
# by their shape instead of being printed in full.
for position, record in enumerate(reconstructed_data[:3], start=1):
    print(f"\nEntry {position}:")
    print("Type:", type(record))
    if not isinstance(record, dict):
        print(record)
        continue
    print("Keys:", record.keys())
    for name, value in record.items():
        detail = f"shape {value.shape}" if hasattr(value, 'shape') else f"{value}"
        print(f" - {name}: {detail}")


Top-level type: <class 'list'>
Number of entries: 1986

Entry 1:
Type: <class 'dict'>
Keys: dict_keys(['original_text', 'original_vision', 'original_audio', 'ground_truth_text', 'ground_truth_vision', 'ground_truth_audio', 'generated_vision', 'generated_audio', 'agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness'])
 - original_text: shape torch.Size([768, 5])
 - original_vision: shape torch.Size([512, 5])
 - original_audio: shape torch.Size([1024, 5])
 - ground_truth_text: shape torch.Size([32, 3])
 - ground_truth_vision: shape torch.Size([32, 3])
 - ground_truth_audio: shape torch.Size([32, 3])
 - generated_vision: shape torch.Size([32, 3])
 - generated_audio: shape torch.Size([32, 3])
 - agreeableness: 0.36263737082481384
 - openness: 0.47777777910232544
 - neuroticism: 0.3645833432674408
 - extraversion: 0.20560747385025024
 - conscientiousness: 0.33980581164360046

Entry 2:
Type: <class 'dict'>
Keys: dict_keys(['original_text', 'original_vision', 'origina

In [None]:
# Load the test-split annotation pickle from the working directory.
# encoding='latin1' is the usual setting for reading pickles written by
# Python 2 — presumably the case for this file (TODO confirm).
with open('annotation_test.pkl', 'rb') as f:
    annotations = pickle.load(f, encoding='latin1')

In [None]:
# Show which top-level keys the annotation dict provides.
annotations.keys()

dict_keys(['extraversion', 'neuroticism', 'agreeableness', 'conscientiousness', 'interview', 'openness'])

In [None]:
# annotations['openness']

In [None]:
import os
import pickle
import random

# Base paths: original text features live on the shared project volume,
# reconstructed features are read from the current working directory.
base_text_dir = '/project/msoleyma_1026/personality_detection/first_impressions_v2_dataset/text_feature_vectors/'
reconstructed_visual_dir = './reconstructed_visual_feature_vectors/'
reconstructed_audio_dir = './reconstructed_audio_feature_vectors/'

# Annotation files (one pickle per split, relative to the working directory)
annotation_train_file = './annotation_training.pkl'
annotation_val_file = './annotation_validation.pkl'
annotation_test_file = './annotation_test.pkl'

# Sub-directory names under the feature root: train-1 .. train-6, val, test
train_dirs = [f"train-{i}" for i in range(1, 7)]
val_dir = 'val'
test_dir = 'test'

def list_npy_files(directory):
    """Return the names of the entries in `directory` that end with '.npy'.

    Only bare file names are returned (no directory prefix). A path that does
    not exist yields an empty list.
    """
    if not os.path.exists(directory):
        return []
    return list(filter(lambda name: name.endswith('.npy'), os.listdir(directory)))

def print_sample_files(base_dir, subdirs, label):
    """Print a header with the total '.npy' count and up to five random names.

    Args:
        base_dir: directory that contains the sub-directories.
        subdirs: iterable of sub-directory names to scan under `base_dir`.
        label: text shown in the header line.
    """
    gathered = [
        name
        for sub in subdirs
        for name in list_npy_files(os.path.join(base_dir, sub))
    ]
    print(f"\n--- {label} ({len(gathered)} files) ---")
    how_many = min(5, len(gathered))
    for name in random.sample(gathered, how_many):
        print(name)

def print_annotation_keys(annotation_file, label):
    """Print up to five random video ids and their 'openness' score.

    The ids are taken from the 'openness' entry of the pickled annotation
    dict. If `annotation_file` does not exist a "not found" message is
    printed instead.

    Args:
        annotation_file: path to the annotation pickle.
        label: text shown in the header line.
    """
    if not os.path.exists(annotation_file):
        print(f"Annotation file {annotation_file} not found!")
        return

    with open(annotation_file, 'rb') as handle:
        annotations = pickle.load(handle, encoding='latin1')

    # A single trait is enough to enumerate the video ids.
    sample_trait = 'openness'
    video_ids = list(annotations[sample_trait].keys())
    print(f"\n--- {label} ({len(video_ids)} video ids) ---")
    chosen = random.sample(video_ids, min(5, len(video_ids)))
    for vid in chosen:
        print(f"Video ID: {vid} → Openness: {annotations[sample_trait][vid]}")

def test_lookup(annotation_file, feature_file_sample):
    with open(annotation_file, 'rb') as f:
        annotations = pickle.load(f, encoding='latin1')
    feature_id = feature_file_sample.replace('.npy', '')
    annotation_id = feature_id + '.mp4'
    sample_trait = 'openness'
    value = annotations[sample_trait].get(annotation_id, None)
    print(f"\nTrying to match Feature '{feature_id}' with Annotation '{annotation_id}':")
    print(f"  Openness: {value}")

# Random samples of the original text features, one header per split.
for split_subdirs, split_label in (
    (train_dirs, "Train Text Files"),
    ([val_dir], "Validation Text Files"),
    ([test_dir], "Test Text Files"),
):
    print_sample_files(base_text_dir, split_subdirs, split_label)

# Random samples of the reconstructed visual features.
print("\n--- Reconstructed Visual Files ---")
reconstructed_visual_files = list_npy_files(reconstructed_visual_dir)
visual_preview_count = min(5, len(reconstructed_visual_files))
for npy_name in random.sample(reconstructed_visual_files, visual_preview_count):
    print(npy_name)

# Random samples of the reconstructed audio features.
print("\n--- Reconstructed Audio Files ---")
reconstructed_audio_files = list_npy_files(reconstructed_audio_dir)
audio_preview_count = min(5, len(reconstructed_audio_files))
for npy_name in random.sample(reconstructed_audio_files, audio_preview_count):
    print(npy_name)

# Random annotation keys for every split.
for annotation_path, annotation_label in (
    (annotation_train_file, "Training Annotations"),
    (annotation_val_file, "Validation Annotations"),
    (annotation_test_file, "Test Annotations"),
):
    print_annotation_keys(annotation_path, annotation_label)

# Check that a random reconstructed file name maps onto a test annotation key.
if len(reconstructed_visual_files) > 0:
    test_lookup(annotation_test_file, random.choice(reconstructed_visual_files))



--- Train Text Files (6000 files) ---
cpzY1b6wJqs.002.npy
Wx_oe0SxD9w.002.npy
m04e9ylCoK0.003.npy
e07IozLUeKc.005.npy
z-CV743owek.005.npy

--- Validation Text Files (2000 files) ---
sPMNhG1Sehc.001.npy
9rF3BEXetOo.003.npy
lxnV9X8T2Zc.005.npy
R-qB2FX7ZbE.002.npy
mlXZQ8dO0nQ.001.npy

--- Test Text Files (2000 files) ---
XJj34u5IzU0.005.npy
2GHz8LYflE0.003.npy
uYvDMWWq2Jw.003.npy
QXFRE_pjrAE.002.npy
kFak4VnRnRM.000.npy

--- Reconstructed Visual Files ---
F5kL7RWS_f0.005.npy
L4uFD6434Pc.000.npy
XJj34u5IzU0.005.npy
LP5N5uPfTdA.003.npy
_tq-VgoXGMo.004.npy

--- Reconstructed Audio Files ---
RbX4q4KceVk.001.npy
jvHDFrgu9PA.002.npy
VPFrKx72gvo.003.npy
xgwiqo2AsCA.004.npy
p5v75vAZ7F0.000.npy

--- Training Annotations (6000 video ids) ---
Video ID: BZ3FEf_KKso.001.mp4 → Openness: 0.5333333333333333
Video ID: PmJw3uI1qBE.000.mp4 → Openness: 0.3
Video ID: 0mym1CooiTE.003.mp4 → Openness: 0.7888888888888888
Video ID: OMHlfDF99Mw.000.mp4 → Openness: 0.6333333333333333
Video ID: jub-AHFTH_g.000.mp4 → 

In [None]:
import os
import numpy as np
import torch
import pickle
from tqdm import tqdm

# Root of the dataset on the shared project volume (absolute path).
base_dir = '/project/msoleyma_1026/personality_detection/first_impressions_v2_dataset/'

# Per-modality feature roots; each is expected to contain the split sub-directories.
text_dir = os.path.join(base_dir, 'text_feature_vectors')
vision_dir = os.path.join(base_dir, 'video_feature_vectors')
audio_dir = os.path.join(base_dir, 'audio_feature_vectors')

# Reconstructed features are read from the current working directory.
reconstructed_visual_dir = './reconstructed_visual_feature_vectors/'
reconstructed_audio_dir = './reconstructed_audio_feature_vectors/'

# One annotation pickle per split.
annotation_train_file = './annotation_training.pkl'
annotation_val_file = './annotation_validation.pkl'
annotation_test_file = './annotation_test.pkl'

# Sub-directory names per split: train-1 .. train-6, val, test.
train_subdirs = [f"train-{i}" for i in range(1, 7)]
val_subdirs = ['val']
test_subdirs = ['test']


def load_annotations(annotation_path):
    """Unpickle and return the annotation object stored at `annotation_path`.

    The file is read with encoding='latin1'. Unpickling can execute arbitrary
    code, so only trusted files should be passed in.
    """
    with open(annotation_path, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

def collect_feature_files(base_path, subdirs):
    """Return the paths of all '.npy' entries found in the given sub-directories.

    Args:
        base_path: directory containing the sub-directories.
        subdirs: iterable of sub-directory names; missing ones are ignored.

    Returns:
        A list of paths of the form `base_path/subdir/name.npy`, grouped in
        the order of `subdirs`.
    """
    collected = []
    for subdir in subdirs:
        folder = os.path.join(base_path, subdir)
        if not os.path.exists(folder):
            continue
        for name in os.listdir(folder):
            if name.endswith('.npy'):
                collected.append(os.path.join(folder, name))
    return collected

def load_feature(path):
    """Load a '.npy' array as a torch tensor with its dimensions reversed.

    Returns `None` when `path` does not exist. For a 2-D array of shape
    (a, b) the result has shape (b, a).
    """
    if not os.path.exists(path):
        return None
    loaded = np.load(path)
    return torch.tensor(loaded).T  # Shape: (feature_dim, 5)

def lookup_traits(annotations, video_id_no_ext):
    """Fetch the five trait scores for one video id.

    Args:
        annotations: mapping trait name -> {video key -> score}.
        video_id_no_ext: video id without extension; '.mp4' is appended to
            form the lookup key.

    Returns:
        A dict with the keys agreeableness, openness, neuroticism,
        extraversion and conscientiousness (in that order), or `None` when
        any lookup raises KeyError.
    """
    annotation_key = video_id_no_ext + '.mp4'
    trait_names = (
        'agreeableness',
        'openness',
        'neuroticism',
        'extraversion',
        'conscientiousness',
    )
    try:
        return {name: annotations[name][annotation_key] for name in trait_names}
    except KeyError:
        return None

def process_split(feature_dir_base, subdirs, annotations, split_type, recon_visual_dir=None, recon_audio_dir=None):
    """Assemble per-clip feature/label entries for one dataset split.

    The '.npy' files found under `feature_dir_base/<subdir>` define which
    clips are processed. For every clip the text, vision and audio features
    are loaded from the module-level `text_dir`, `vision_dir` and `audio_dir`
    using the same sub-directory and file name, and the five trait scores are
    looked up in `annotations`.

    Args:
        feature_dir_base: directory whose sub-directories are scanned to
            enumerate the clips.
        subdirs: sub-directory names belonging to this split.
        annotations: mapping trait name -> {'<video_id>.mp4' -> score}.
        split_type: split label; reconstructions are attached only when this
            equals 'test'.
        recon_visual_dir: directory holding '<video_id>.npy' reconstructed
            visual features, or None.
        recon_audio_dir: directory holding '<video_id>.npy' reconstructed
            audio features, or None.

    Returns:
        dict mapping video id -> entry dict with 'original_text',
        'original_vision', 'original_audio', the five trait scores and, for
        the test split with both reconstruction directories given,
        'generated_vision' and 'generated_audio'.

    Clips are skipped (and counted) when any original feature file is
    missing, when the annotations lack the clip, or when reconstructions were
    requested but one of the two files is missing.
    """
    output = {}
    feature_files = collect_feature_files(feature_dir_base, subdirs)
    print(f"\nProcessing {split_type} set: {len(feature_files)} total samples...")

    # Reconstructions are attached only for the test split and only when the
    # caller supplied both directories.
    attach_reconstructions = (
        split_type == 'test' and bool(recon_visual_dir) and bool(recon_audio_dir)
    )

    skipped = 0
    for file_path in tqdm(feature_files):
        filename = os.path.basename(file_path)                   # e.g., 'abc123.003.npy'
        video_id = os.path.splitext(filename)[0]                 # e.g., 'abc123.003'
        subdir = os.path.basename(os.path.dirname(file_path))    # e.g., 'train-1'

        # Load original features
        text_feature = load_feature(os.path.join(text_dir, subdir, filename))
        vision_feature = load_feature(os.path.join(vision_dir, subdir, filename))
        audio_feature = load_feature(os.path.join(audio_dir, subdir, filename))

        if text_feature is None or vision_feature is None or audio_feature is None:
            skipped += 1
            continue

        # Lookup traits
        traits = lookup_traits(annotations, video_id)
        if traits is None:
            skipped += 1
            continue

        entry = {
            'original_text': text_feature,
            'original_vision': vision_feature,
            'original_audio': audio_feature,
            **traits
        }

        if attach_reconstructions:
            # Use the directories passed by the caller rather than the
            # module-level defaults, so the parameters actually take effect.
            visual_recon_path = os.path.join(recon_visual_dir, f"{video_id}.npy")
            audio_recon_path = os.path.join(recon_audio_dir, f"{video_id}.npy")

            if os.path.exists(visual_recon_path) and os.path.exists(audio_recon_path):
                entry['generated_vision'] = torch.tensor(np.load(visual_recon_path))
                entry['generated_audio'] = torch.tensor(np.load(audio_recon_path))
            else:
                # Skip samples without both reconstructions
                skipped += 1
                continue

        output[video_id] = entry

    print(f"{split_type.capitalize()} samples processed: {len(output)} (Skipped {skipped})")
    return output

def save_pickle(data, save_path):
    """Pickle `data` to `save_path` and print a one-line confirmation.

    The confirmation reports `len(data)`, so `data` must be sized.
    """
    with open(save_path, 'wb') as sink:
        pickle.dump(data, sink)
    entry_count = len(data)
    print(f"Saved pickle: {save_path} ({entry_count} entries)")


# Load annotation files (one per split)
annotations_train = load_annotations(annotation_train_file)
annotations_val = load_annotations(annotation_val_file)
annotations_test = load_annotations(annotation_test_file)

# Process train: the text feature directory is used to enumerate the clips.
# No reconstruction directories are passed, so only original features are stored.
train_data = process_split(
    feature_dir_base=text_dir,
    subdirs=train_subdirs,
    annotations=annotations_train,
    split_type='train'
)
save_pickle(train_data, 'cvae_train.pkl')

# Process validation (same layout as train)
val_data = process_split(
    feature_dir_base=text_dir,
    subdirs=val_subdirs,
    annotations=annotations_val,
    split_type='val'
)
save_pickle(val_data, 'cvae_val.pkl')

# Process test (with reconstructions): clips lacking either reconstructed
# file are skipped by process_split.
test_data = process_split(
    feature_dir_base=text_dir,
    subdirs=test_subdirs,
    annotations=annotations_test,
    split_type='test',
    recon_visual_dir=reconstructed_visual_dir,
    recon_audio_dir=reconstructed_audio_dir
)
save_pickle(test_data, 'cvae_test.pkl')

# Reached only if every step above completed without raising.
print("\n✅ All pickle files created successfully!")



Processing train set: 6000 total samples...


  return torch.tensor(arr).T  # Shape: (feature_dim, 5)
100%|██████████| 6000/6000 [16:35<00:00,  6.02it/s]  


Train samples processed: 6000 (Skipped 0)
Saved pickle: cvae_train.pkl (6000 entries)

Processing val set: 2000 total samples...


100%|██████████| 2000/2000 [06:02<00:00,  5.51it/s]


Val samples processed: 2000 (Skipped 0)
Saved pickle: cvae_val.pkl (2000 entries)

Processing test set: 2000 total samples...


100%|██████████| 2000/2000 [06:27<00:00,  5.16it/s]


Test samples processed: 1999 (Skipped 1)
Saved pickle: cvae_test.pkl (1999 entries)

✅ All pickle files created successfully!


In [None]:
import pickle

# Paths to the per-split reconstructed-output pickles (relative to the
# working directory). Note these are different files from the cvae_*.pkl
# files written earlier in the notebook.
train_pkl = 'reconstructed_outputs_train.pkl'
val_pkl = 'reconstructed_outputs_valid.pkl'
test_pkl = 'reconstructed_outputs_test.pkl'

# Load
def load_pickle(path):
    """Unpickle and return the object stored at `path` (encoding='latin1').

    Unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle, encoding='latin1')
    return loaded

# NOTE: this rebinds train_data / val_data / test_data. The earlier
# feature-building cell bound the same names to dicts returned by
# process_split; from here on they hold the objects stored in the
# reconstructed_outputs_*.pkl files.
train_data = load_pickle(train_pkl)
val_data = load_pickle(val_pkl)
test_data = load_pickle(test_pkl)

# Print top-level type and size of the training split
print(f"Train Data Type: {type(train_data)}")
print(f"Number of Train Samples: {len(train_data)}")

# # Peek inside one sample
# first_key = list(train_data.keys())[0]
# print(f"Sample Key: {first_key}")
# print(f"Sample Entry Type: {type(train_data[first_key])}")

# # Print shapes if available
# sample = train_data[first_key]
# if isinstance(sample, dict):
#     for k, v in sample.items():
#         if hasattr(v, 'shape'):
#             print(f" - {k}: shape {v.shape}")
#         else:
#             print(f" - {k}: {type(v)}")


Train Data Type: <class 'list'>
Number of Train Samples: 5952


AttributeError: 'list' object has no attribute 'keys'

In [None]:
# Inspect the first training entry.
sample = train_data[0]

print(f"Sample Type: {type(sample)}")
print(f"Sample Keys: {list(sample.keys())}")

# Values exposing `.shape` are summarised by shape; everything else is
# shown with its type and value.
for field, content in sample.items():
    if not hasattr(content, 'shape'):
        print(f" - {field}: {type(content)}, value={content}")
        continue
    print(f" - {field}: shape {content.shape}")


Sample Type: <class 'dict'>
Sample Keys: ['original_text', 'original_vision', 'original_audio', 'ground_truth_text', 'ground_truth_vision', 'ground_truth_audio', 'generated_vision', 'generated_audio', 'agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness']
 - original_text: shape torch.Size([768, 5])
 - original_vision: shape torch.Size([512, 5])
 - original_audio: shape torch.Size([1024, 5])
 - ground_truth_text: shape torch.Size([32, 3])
 - ground_truth_vision: shape torch.Size([32, 3])
 - ground_truth_audio: shape torch.Size([32, 3])
 - generated_vision: shape torch.Size([32, 3])
 - generated_audio: shape torch.Size([32, 3])
 - agreeableness: <class 'float'>, value=0.5054945349693298
 - openness: <class 'float'>, value=0.5777778029441833
 - neuroticism: <class 'float'>, value=0.4479166567325592
 - extraversion: <class 'float'>, value=0.4392523467540741
 - conscientiousness: <class 'float'>, value=0.5631067752838135


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

class DiffusionDownstreamDataset(Dataset):
    """Dataset yielding flattened latent features and the five trait targets.

    Each element of `data_list` is a dict holding 'ground_truth_*' and
    'generated_*' tensors plus the trait scores. Ground-truth features are
    used everywhere except on the test split, where the vision feature is
    replaced by 'generated_vision' for the '*_audio' tasks and the audio
    feature by 'generated_audio' for the '*_vision' tasks.
    """

    def __init__(self, data_list, split, task_type):
        """Store the samples and the selection settings.

        Args:
            data_list: list of per-sample dicts loaded from a pickle.
            split: 'train', 'val' or 'test'.
            task_type: one of 'upper', 'middle_audio', 'middle_vision',
                'lower_audio', 'lower_vision'.
        """
        self.data = data_list
        self.split = split
        self.task_type = task_type

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def mean_pool(self, feature):
        """Flatten `feature` to 1-D.

        Despite the name, no averaging takes place: the tensor is reshaped
        with `view(-1)`. Non-tensor input raises ValueError.
        """
        if not isinstance(feature, torch.Tensor):
            raise ValueError(f"Unsupported feature type: {type(feature)}")
        return feature.view(-1)

    def __getitem__(self, idx):
        """Return `(input_feats, traits)` for sample `idx`.

        `input_feats` maps modality name -> flattened tensor; `traits` is a
        float32 tensor ordered agreeableness, openness, neuroticism,
        extraversion, conscientiousness.
        """
        sample = self.data[idx]
        on_test_split = self.split == 'test'

        # Text always comes from the ground truth.
        text_feat = self.mean_pool(sample['ground_truth_text'])

        # Vision: generated on the test split for the '*_audio' tasks.
        vision_is_generated = on_test_split and self.task_type in ['middle_audio', 'lower_audio']
        vision_source = 'generated_vision' if vision_is_generated else 'ground_truth_vision'
        vision_feat = self.mean_pool(sample[vision_source])

        # Audio: generated on the test split for the '*_vision' tasks.
        audio_is_generated = on_test_split and self.task_type in ['middle_vision', 'lower_vision']
        audio_source = 'generated_audio' if audio_is_generated else 'ground_truth_audio'
        audio_feat = self.mean_pool(sample[audio_source])

        # Choose which modalities are exposed for this task.
        # NOTE(review): 'lower_audio' exposes vision+text and 'lower_vision'
        # exposes audio+text — confirm this pairing is intended.
        if self.task_type in ('upper', 'middle_audio', 'middle_vision'):
            input_feats = {'audio': audio_feat, 'vision': vision_feat, 'text': text_feat}
        elif self.task_type == 'lower_audio':
            input_feats = {'vision': vision_feat, 'text': text_feat}
        elif self.task_type == 'lower_vision':
            input_feats = {'audio': audio_feat, 'text': text_feat}
        else:
            raise ValueError(f"Unknown task type {self.task_type}")

        # Regression targets in a fixed trait order.
        trait_order = ('agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness')
        traits = torch.tensor([sample[name] for name in trait_order], dtype=torch.float32)

        return input_feats, traits


In [None]:
class ModalityLSTM(nn.Module):
    """Single-layer LSTM encoder that returns the final hidden state."""

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        """Encode `x` into a (batch, hidden_dim) tensor.

        Accepted layouts:
            (batch, feature_dim)             -> treated as a length-1 sequence
            (batch, feature_dim, time_steps) -> time axis moved to position 1
        Inputs of any other rank are passed to the LSTM unchanged.
        """
        rank = x.dim()
        if rank == 2:
            # Insert a singleton time axis: (batch, 1, feature_dim)
            x = x.unsqueeze(1)
        elif rank == 3:
            # Swap to the batch-first sequence layout: (batch, time_steps, feature_dim)
            x = x.permute(0, 2, 1)

        _, (final_hidden, _) = self.lstm(x)
        # The last entry along the first axis of the hidden state is returned.
        return final_hidden[-1]


class TextEncoder(nn.Module):
    """Linear projection of a text feature vector to `output_dim` dimensions."""

    def __init__(self, input_dim=768, output_dim=256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project `x` (last dimension `input_dim`) to `output_dim`."""
        projected = self.fc(x)
        return projected


In [None]:
class EarlyFusionRegressor(nn.Module):
    """Concatenate all modality features, then regress with a two-layer MLP.

    Args:
        input_dims: mapping modality name -> feature width; the widths are
            summed to size the first linear layer.
        hidden_dim: width of the hidden layer.
        output_dim: number of regression targets.
    """

    def __init__(self, input_dims, hidden_dim=256, output_dim=5):
        super().__init__()
        self.input_dims = input_dims
        fused_width = sum(input_dims.values())
        self.fc1 = nn.Linear(fused_width, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        """Fuse and regress.

        `inputs` is a dict of modality name -> tensor. The tensors are
        concatenated along the last dimension in the dict's iteration order,
        so callers must keep that order stable between calls.
        """
        fused = torch.cat(tuple(inputs.values()), dim=-1)
        hidden = torch.relu(self.fc1(fused))
        return self.fc2(hidden)

class LateFusionRegressor(nn.Module):
    """One two-layer MLP head per modality; predictions are averaged.

    Args:
        input_dims: mapping modality name -> feature width; one head is
            created per entry.
        hidden_dim: hidden width of every head.
        output_dim: number of regression targets.
    """

    def __init__(self, input_dims, hidden_dim=256, output_dim=5):
        super().__init__()
        heads = {}
        for modality, width in input_dims.items():
            heads[modality] = nn.Sequential(
                nn.Linear(width, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )
        self.modalities = nn.ModuleDict(heads)

    def forward(self, inputs):
        """Run every provided modality through its head and average the results.

        `inputs` is a dict of modality name -> tensor; each key must have a
        matching head.
        """
        per_modality = [self.modalities[name](feat) for name, feat in inputs.items()]
        # Unweighted mean over the modality axis.
        return torch.stack(per_modality, dim=0).mean(dim=0)

In [None]:
import sys
# Install scikit-learn with the interpreter that runs this kernel, so the
# package lands in the kernel's own environment. The version is not pinned.
!{sys.executable} -m pip install scikit-learn

Collecting scikit-learn
  Using cached scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting scipy>=1.6.0 (from scikit-learn)
  Using cached scipy-1.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.2 MB)
Using cached scipy-1.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.3 MB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, scikit-learn
Successfully installed scikit-learn-1.6.1 scipy-1.15.2 threadpoolctl-3.6.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
import os
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import pearsonr

In [None]:
class Trainer:
    """Training loop with validation-based early stopping.

    Args:
        model: module called as `model(batch_x)` where `batch_x` is a dict of
            tensors; it is moved to `device`.
        optimizer: optimizer already bound to the model's parameters.
        loss_fn: callable `(preds, targets) -> scalar loss`.
        device: device string used for the model and every batch.
    """

    def __init__(self, model, optimizer, loss_fn=nn.MSELoss(), device='cuda'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device

    def train(self, train_loader, val_loader, epochs=50, patience=5):
        """Train for up to `epochs` epochs and restore the best weights.

        After every epoch the mean validation loss is computed. Training
        stops once `patience` consecutive epochs fail to improve on the best
        validation loss. When at least one epoch improved, the model is reset
        to the weights from the epoch with the lowest validation loss.

        Args:
            train_loader: sized iterable of `(batch_x, batch_y)` pairs.
            val_loader: sized iterable of `(batch_x, batch_y)` pairs.
            epochs: maximum number of epochs.
            patience: number of non-improving epochs tolerated.
        """
        best_val_loss = float('inf')
        best_state = None
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            total_train_loss = 0
            for batch_x, batch_y in train_loader:
                batch_x = {k: v.to(self.device) for k, v in batch_x.items()}
                batch_y = batch_y.to(self.device)

                preds = self.model(batch_x)
                loss = self.loss_fn(preds, batch_y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_train_loss += loss.item()

            avg_train_loss = total_train_loss / len(train_loader)

            val_loss = self.evaluate_loss(val_loader)
            print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() hands back references to the live parameter
                # tensors, which keep changing as training continues. Clone
                # them so the snapshot really preserves this epoch's weights.
                best_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

        # Nothing to restore when no epoch ran or the validation loss never
        # improved (e.g. it was NaN throughout).
        if best_state is not None:
            self.model.load_state_dict(best_state)

    def evaluate_loss(self, loader):
        """Return the mean per-batch loss over `loader` without updating weights.

        The model is switched to eval mode and gradients are disabled.
        """
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = {k: v.to(self.device) for k, v in batch_x.items()}
                batch_y = batch_y.to(self.device)

                preds = self.model(batch_x)
                loss = self.loss_fn(preds, batch_y)
                total_loss += loss.item()

        return total_loss / len(loader)

In [None]:
class Evaluator:
    """Static helpers that run a model over a loader and score its predictions."""

    @staticmethod
    def evaluate(model, loader, device='cuda'):
        """Collect predictions for every batch in `loader` and score them.

        The model is put in eval mode and run without gradients. Predictions
        and labels are gathered on the CPU, concatenated along the batch
        axis, and passed to `compute_metrics`.
        """
        model.eval()
        collected_preds = []
        collected_labels = []

        with torch.no_grad():
            for batch_x, batch_y in loader:
                on_device = {name: tensor.to(device) for name, tensor in batch_x.items()}
                targets = batch_y.to(device)

                collected_preds.append(model(on_device).cpu())
                collected_labels.append(targets.cpu())

        preds = torch.cat(collected_preds, dim=0).numpy()
        labels = torch.cat(collected_labels, dim=0).numpy()

        return Evaluator.compute_metrics(preds, labels)

    @staticmethod
    def compute_metrics(preds, labels):
        """Compute per-column and averaged regression metrics.

        Args:
            preds: array of shape (num_samples, num_traits).
            labels: array with the same shape as `preds`.

        Returns:
            dict with one entry per column ('trait_1', 'trait_2', ...) plus
            'average'. Each entry holds MAE, ACC (1 - MAE), MSE, R2, PCC
            (Pearson correlation) and CCC (concordance correlation).
        """
        metric_names = ['MAE', 'ACC', 'MSE', 'R2', 'PCC', 'CCC']
        num_traits = preds.shape[1]
        results = {}

        for i in range(num_traits):
            p = preds[:, i]
            l = labels[:, i]

            mae = np.mean(np.abs(p - l))
            acc = 1 - mae
            mse = mean_squared_error(l, p)
            r2 = r2_score(l, p)
            pcc, _ = pearsonr(l, p)

            # Concordance correlation: Pearson correlation scaled by how far
            # the two series differ in mean and variance.
            mean_p = np.mean(p)
            mean_l = np.mean(l)
            var_p = np.var(p)
            var_l = np.var(l)
            ccc = (2 * pcc * np.sqrt(var_p) * np.sqrt(var_l)) / (var_p + var_l + (mean_p - mean_l) ** 2)

            results[f'trait_{i+1}'] = dict(zip(metric_names, (mae, acc, mse, r2, pcc, ccc)))

        # Unweighted mean of each metric across the columns.
        avg_metrics = {}
        for metric in metric_names:
            avg_metrics[metric] = np.mean([results[f'trait_{i+1}'][metric] for i in range(num_traits)])
        results['average'] = avg_metrics

        return results


In [None]:
class ExperimentRunner:
    """Build datasets, model and trainer for one (task, fusion) setting and run it.

    Args:
        train_data, val_data, test_data: per-split sample containers handed
            to `dataset_cls`.
        batch_size: batch size for all three loaders.
        device: device string passed to the model, trainer and evaluator.
        dataset_cls: dataset class called as
            `dataset_cls(data, split=..., task_type=...)`.
    """

    def __init__(self, train_data, val_data, test_data, batch_size=64, device='cuda', dataset_cls=None):
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.device = device
        # NOTE(review): PersonalityDataset is not defined in the visible part
        # of this notebook; the fallback is only evaluated when dataset_cls
        # is None, so always pass dataset_cls unless it is defined elsewhere.
        self.dataset_cls = dataset_cls if dataset_cls is not None else PersonalityDataset  # default is PersonalityDataset

    def run(self, task_type, fusion_type):
        """Train and evaluate one configuration.

        Args:
            task_type: task name forwarded to the dataset class.
            fusion_type: 'early' selects EarlyFusionRegressor; any other
                value selects LateFusionRegressor.

        Returns:
            The metrics dict produced by `Evaluator.evaluate` on the test split.
        """
        print(f"\n=== Running Task: {task_type.upper()}, Fusion: {fusion_type.upper()} ===")

        # Use the custom dataset class
        train_dataset = self.dataset_cls(self.train_data, split='train', task_type=task_type)
        val_dataset = self.dataset_cls(self.val_data, split='val', task_type=task_type)
        test_dataset = self.dataset_cls(self.test_data, split='test', task_type=task_type)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        input_dims = self.get_input_dims(train_dataset)

        # Every encoder maps its modality to this width before fusion.
        encoded_dim = 256

        # Size the encoders from the widths the dataset actually produces
        # instead of a fixed 96, so other feature sizes work unchanged.
        audio_lstm = ModalityLSTM(input_dims['audio'], hidden_dim=encoded_dim) if 'audio' in input_dims else None
        vision_lstm = ModalityLSTM(input_dims['vision'], hidden_dim=encoded_dim) if 'vision' in input_dims else None
        # A text encoder is always built; 96 is the previous fixed width,
        # kept as the fallback when the dataset provides no text feature.
        text_encoder = TextEncoder(input_dims.get('text', 96), output_dim=encoded_dim)

        fused_dims = {k: encoded_dim for k in input_dims}
        if fusion_type == 'early':
            model = EarlyFusionRegressor(input_dims=fused_dims)
        else:
            model = LateFusionRegressor(input_dims=fused_dims)

        full_model = FullModel(audio_lstm, vision_lstm, text_encoder, model, device=self.device)
        optimizer = torch.optim.Adam(full_model.parameters(), lr=1e-3)

        trainer = Trainer(full_model, optimizer, loss_fn=nn.MSELoss(), device=self.device)
        trainer.train(train_loader, val_loader, epochs=50, patience=5)

        results = Evaluator.evaluate(full_model, test_loader, device=self.device)

        print("Results:")
        for trait, metrics in results.items():
            print(f"{trait}: {metrics}")

        return results  # very important: return results!

    def get_input_dims(self, dataset):
        """Return {modality: size of the last dimension} taken from the first sample."""
        sample, _ = dataset[0]
        return {k: v.shape[-1] for k, v in sample.items()}


In [None]:
class FullModel(nn.Module):
    """Per-modality encoders followed by a fusion regressor.

    Args:
        audio_lstm: encoder for the 'audio' input, or None to disable it.
        vision_lstm: encoder for the 'vision' input, or None to disable it.
        text_encoder: encoder for the 'text' input.
        fusion_model: module called with the dict of encoded features.
        device: device every sub-module is moved to.
    """

    def __init__(self, audio_lstm, vision_lstm, text_encoder, fusion_model, device='cuda'):
        super().__init__()
        # Compare with None explicitly: container modules such as an empty
        # nn.Sequential define __len__ and would otherwise count as "absent".
        self.audio_lstm = audio_lstm.to(device) if audio_lstm is not None else None
        self.vision_lstm = vision_lstm.to(device) if vision_lstm is not None else None
        self.text_encoder = text_encoder.to(device)
        self.fusion_model = fusion_model.to(device)

    def forward(self, inputs):
        """Encode each available modality and fuse the results.

        A modality is encoded only when it is present in `inputs` and its
        encoder exists. The encoded dict is filled in the order audio,
        vision, text and then handed to the fusion model.
        """
        feats = {}
        if 'audio' in inputs and self.audio_lstm is not None:
            feats['audio'] = self.audio_lstm(inputs['audio'])
        if 'vision' in inputs and self.vision_lstm is not None:
            feats['vision'] = self.vision_lstm(inputs['vision'])
        if 'text' in inputs:
            feats['text'] = self.text_encoder(inputs['text'])

        return self.fusion_model(feats)

In [None]:
# Shared runner for every experiment cell below. It uses the train/val/test
# objects currently bound in the kernel and requires a CUDA device.
runner = ExperimentRunner(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    batch_size=64,
    device='cuda',
    dataset_cls=DiffusionDownstreamDataset  # dataset class yielding the flattened latent features
)


In [1]:
# Middle Audio - Early Fusion
# `results` is reassigned by every experiment cell; it holds the latest run only.
task = 'middle_audio'
fusion = 'early'
results = runner.run(task, fusion)

In [2]:
# Middle Audio - Late Fusion
# `results` is reassigned by every experiment cell; it holds the latest run only.
task = 'middle_audio'
fusion = 'late'
results = runner.run(task, fusion)

In [3]:
# Middle Vision - Early Fusion
# `results` is reassigned by every experiment cell; it holds the latest run only.
task = 'middle_vision'
fusion = 'early'
results = runner.run(task, fusion)

In [4]:
# Middle Vision - Late Fusion
# `results` is reassigned by every experiment cell; it holds the latest run only.
task = 'middle_vision'
fusion = 'late'
results = runner.run(task, fusion)