In [None]:
import pickle
import random
import torch

# Load function
def load_pickle(path):
    """Deserialize and return the single object stored in the pickle file at `path`.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only call this on pickles you produced or otherwise trust.
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

# Paths to the three pre-split pickles
train_pickle = 'cvae_train.pkl'
val_pickle = 'cvae_val.pkl'
test_pickle = 'cvae_test.pkl'

# Load data (read in train -> val -> test order)
train_data, val_data, test_data = [
    load_pickle(split_path) for split_path in (train_pickle, val_pickle, test_pickle)
]

# Print a few randomly chosen samples from each split
def print_sample_shapes(data_dict, split_name, check_generated=False):
    """Print feature shapes and two trait labels for up to three random samples.

    Args:
        data_dict: mapping video id -> sample dict holding tensors under
            'original_text' / 'original_vision' / 'original_audio' and trait scores
            under 'agreeableness' / 'openness'.
        split_name: label shown (upper-cased) in the header line.
        check_generated: when True, also print 'generated_vision' / 'generated_audio'
            shapes for samples that have them.
    """
    video_ids = list(data_dict)
    print(f"\n--- {split_name.upper()} ({len(video_ids)} samples) ---")
    chosen = random.sample(video_ids, min(3, len(video_ids)))
    for video_id in chosen:
        entry = data_dict[video_id]
        print(f"Video ID: {video_id}")
        for field in ('original_text', 'original_vision', 'original_audio'):
            print(f"  {field} shape: {entry[field].shape}")
        if check_generated:
            for field in ('generated_vision', 'generated_audio'):
                if field in entry:
                    print(f"  {field} shape: {entry[field].shape}")
        print(f"  Traits: Agreeableness {entry['agreeableness']}, Openness {entry['openness']}")

# Print examples (generated-feature shapes are only checked for the test split)
for split_dict, split_label, show_generated in (
    (train_data, "Train", False),
    (val_data, "Validation", False),
    (test_data, "Test", True),
):
    print_sample_shapes(split_dict, split_label, check_generated=show_generated)



--- TRAIN (6000 samples) ---
Video ID: FxVUG2R1y0Q.004
  original_text shape: torch.Size([768])
  original_vision shape: torch.Size([512, 1, 5])
  original_audio shape: torch.Size([1024, 149, 5])
  Traits: Agreeableness 0.835164835164835, Openness 0.7999999999999999
Video ID: TD3H2DOSi1Y.001
  original_text shape: torch.Size([768])
  original_vision shape: torch.Size([512, 1, 5])
  original_audio shape: torch.Size([1024, 149, 5])
  Traits: Agreeableness 0.6703296703296703, Openness 0.6777777777777777
Video ID: wK_ExIjn5Q8.002
  original_text shape: torch.Size([768])
  original_vision shape: torch.Size([512, 1, 5])
  original_audio shape: torch.Size([1024, 149, 5])
  Traits: Agreeableness 0.4395604395604395, Openness 0.5444444444444444

--- VALIDATION (2000 samples) ---
Video ID: tvKUJujTUEo.002
  original_text shape: torch.Size([768])
  original_vision shape: torch.Size([512, 1, 5])
  original_audio shape: torch.Size([1024, 149, 5])
  Traits: Agreeableness 0.7032967032967032, Openness

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import pickle
import numpy as np
import os

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

def pad_tensor_to_shape(x, target_shape):
    """Zero-pad `x` at the end of each trailing dimension up to `target_shape`.

    Dimensions are matched from the right. A dimension already at least as large as
    its target is left unchanged (this never truncates).
    """
    extra_per_dim = [
        max(want - have, 0)
        for have, want in zip(reversed(x.shape), reversed(target_shape))
    ]
    # F.pad takes (before, after) pairs starting from the last dimension.
    pad = [amount for extra in extra_per_dim for amount in (0, extra)]
    return F.pad(x, pad)

class PersonalityDataset(Dataset):
    """Map-style dataset yielding ``(input_feats, traits)`` for one fusion task.

    ``input_feats`` is a dict of float tensors restricted to the modalities the task
    uses: 'audio' (zero-padded up to (1024, 5), never truncated), 'vision' (zero-padded
    up to (512, 5)) and 'text' (as stored). ``traits`` is a float32 vector ordered
    (agreeableness, openness, neuroticism, extraversion, conscientiousness).

    On the 'test' split, tasks 'middle_vision' / 'lower_vision' read audio from
    'generated_audio', and tasks 'middle_audio' / 'lower_audio' read vision from
    'generated_vision'. All other cases use the 'original_*' features.
    """

    def __init__(self, data_dict, split, task_type):
        """
        Args:
            data_dict: mapping video id -> sample dict (as loaded from the pickle)
            split: 'train', 'val', 'test'
            task_type: 'upper', 'middle_audio', 'middle_vision', 'lower_audio', 'lower_vision'
        """
        self.data = data_dict
        self.split = split
        self.task_type = task_type
        self.keys = list(data_dict.keys())

    def __len__(self):
        return len(self.keys)

    def _audio_features(self, sample, vid):
        """Audio collapsed to 2-D and zero-padded up to (1024, 5)."""
        if self.split == 'test' and self.task_type in ('middle_vision', 'lower_vision'):
            raw = sample['generated_audio'].squeeze(1).permute(1, 0).float()
        else:
            raw = sample['original_audio'].float()

        rank = raw.dim()
        if rank == 3:
            collapsed = raw.mean(dim=1)  # average over the middle axis, e.g. (1024, frames, 5) -> (1024, 5)
        elif rank == 2:
            collapsed = raw
        elif rank == 1:
            collapsed = raw.unsqueeze(1).repeat(1, 5)  # tile a single vector over 5 columns
        else:
            raise ValueError(f"Unexpected audio feature shape {raw.shape} for video {vid}")
        return pad_tensor_to_shape(collapsed, (1024,5))

    def _vision_features(self, sample):
        """Vision with axis 1 squeezed, zero-padded up to (512, 5)."""
        if self.split == 'test' and self.task_type in ('middle_audio', 'lower_audio'):
            raw = sample['generated_vision'].squeeze(1).permute(1, 0).float()
        else:
            raw = sample['original_vision'].squeeze(1).float()
        return pad_tensor_to_shape(raw, (512,5))

    def __getitem__(self, idx):
        vid = self.keys[idx]
        sample = self.data[vid]

        # Features are read in the order text -> audio -> vision.
        text_feat = sample['original_text'].float()
        audio_feat = self._audio_features(sample, vid)
        vision_feat = self._vision_features(sample)

        # Modalities fed to the model; key order is preserved in the returned dict.
        if self.task_type in ('upper', 'middle_audio', 'middle_vision'):
            wanted = ('audio', 'vision', 'text')
        elif self.task_type == 'lower_audio':
            wanted = ('vision', 'text')
        elif self.task_type == 'lower_vision':
            wanted = ('audio', 'text')
        else:
            raise ValueError(f"Unknown task type {self.task_type}")
        available = {'audio': audio_feat, 'vision': vision_feat, 'text': text_feat}
        input_feats = {name: available[name] for name in wanted}

        trait_names = ('agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness')
        traits = torch.tensor([sample[name] for name in trait_names], dtype=torch.float32)

        return input_feats, traits


In [None]:
class ModalityLSTM(nn.Module):
    """Summarise a feature sequence with a single-layer LSTM.

    Input is laid out as (batch, feature_dim, time_steps). Output is the final
    hidden state of the last layer, with shape (batch, hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        # `lstm` is part of the state_dict layout; keep the attribute name stable.
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)

    def forward(self, x):
        seq = x.permute(0, 2, 1)  # -> (batch, time_steps, feature_dim) for batch_first
        _, (hidden, _) = self.lstm(seq)
        return hidden[-1]

class TextEncoder(nn.Module):
    """Project a text embedding to `output_dim` features with a single linear layer."""

    def __init__(self, input_dim=768, output_dim=256):
        super().__init__()
        # `fc` is part of the state_dict layout; keep the attribute name stable.
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        projected = self.fc(x)
        return projected


In [None]:
class EarlyFusionRegressor(nn.Module):
    """Concatenate per-modality feature vectors and regress targets with a 2-layer MLP."""

    def __init__(self, input_dims, hidden_dim=256, output_dim=5):
        """
        Args:
            input_dims: mapping modality name -> feature width. Its key order fixes
                the concatenation order used in ``forward``.
            hidden_dim: width of the hidden layer.
            output_dim: number of regression targets.
        """
        super().__init__()
        self.input_dims = input_dims
        total_input_dim = sum(input_dims.values())
        self.fc1 = nn.Linear(total_input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        """
        Args:
            inputs: dict with exactly the modalities named in ``input_dims``
                (e.g. {'audio': ..., 'vision': ..., 'text': ...}).

        Raises:
            ValueError: if the modality names differ from ``input_dims``.
        """
        if inputs.keys() != self.input_dims.keys():
            raise ValueError(
                f"Expected modalities {list(self.input_dims)}, got {list(inputs)}"
            )
        # Concatenate in the declared order, not the caller's dict order. Otherwise a
        # differently ordered dict would silently feed features into the wrong fc1 columns.
        x = torch.cat([inputs[name] for name in self.input_dims], dim=-1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class LateFusionRegressor(nn.Module):
    """One MLP head per modality; the prediction is the mean of the heads' outputs."""

    def __init__(self, input_dims, hidden_dim=256, output_dim=5):
        super().__init__()
        heads = {}
        for name, dim in input_dims.items():
            heads[name] = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )
        # ModuleDict registers each head's parameters under `modalities.<name>`.
        self.modalities = nn.ModuleDict(heads)

    def forward(self, inputs):
        # Each provided modality goes through its own head, then the heads are averaged.
        per_modality = [self.modalities[name](feat) for name, feat in inputs.items()]
        return torch.stack(per_modality, dim=0).mean(dim=0)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
import os
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import pearsonr

In [None]:
class Trainer:
    """Supervised training loop with validation-loss early stopping.

    Batches are ``(batch_x, batch_y)`` where ``batch_x`` is a dict of tensors. Every
    tensor is moved to ``device`` before the forward pass.
    """

    def __init__(self, model, optimizer, loss_fn=nn.MSELoss(), device='cuda'):
        """
        Args:
            model: module called as ``model(batch_x)``; moved to ``device``.
            optimizer: optimizer over ``model``'s parameters.
            loss_fn: criterion called as ``loss_fn(preds, batch_y)``.
            device: device for the model and batches.
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.device = device

    def train(self, train_loader, val_loader, epochs=50, patience=5):
        """Train for up to `epochs` epochs, then restore the best-validation weights.

        Training stops early once `patience` consecutive epochs fail to improve the
        validation loss.
        """
        best_val_loss = float('inf')
        best_state = None
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            total_train_loss = 0.0
            for batch_x, batch_y in train_loader:
                batch_x = {k: v.to(self.device) for k, v in batch_x.items()}
                batch_y = batch_y.to(self.device)

                preds = self.model(batch_x)
                loss = self.loss_fn(preds, batch_y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_train_loss += loss.item()

            avg_train_loss = total_train_loss / max(len(train_loader), 1)

            val_loss = self.evaluate_loss(val_loader)
            print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns tensors that share storage with the live parameters.
                # Without a copy, later optimizer steps would overwrite this "best" snapshot.
                best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered!")
                break

        # best_state is None when no epoch ran, or when the validation loss never
        # dropped below +inf (e.g. it was NaN). In that case keep the current weights.
        if best_state is not None:
            self.model.load_state_dict(best_state)

    def evaluate_loss(self, loader):
        """Mean of the per-batch losses over `loader`, computed without gradients."""
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = {k: v.to(self.device) for k, v in batch_x.items()}
                batch_y = batch_y.to(self.device)

                preds = self.model(batch_x)
                loss = self.loss_fn(preds, batch_y)
                total_loss += loss.item()

        return total_loss / max(len(loader), 1)

In [None]:
class Evaluator:
    """Score a trained model on a loader with per-trait regression metrics."""

    @staticmethod
    def evaluate(model, loader, device='cuda'):
        """Predict over `loader` without gradients and return ``compute_metrics`` output.

        Args:
            model: module called as ``model(batch_x)`` with ``batch_x`` a dict of tensors.
            loader: iterable of ``(batch_x, batch_y)`` pairs.
            device: device the inputs are moved to before the forward pass.
        """
        model.eval()
        preds_list, labels_list = [], []

        with torch.no_grad():
            for batch_x, batch_y in loader:
                batch_x = {k: v.to(device) for k, v in batch_x.items()}
                preds_list.append(model(batch_x).cpu())
                # Labels are only compared on the CPU, so they never need to visit `device`.
                labels_list.append(batch_y.cpu())

        preds = torch.cat(preds_list, dim=0).numpy()
        labels = torch.cat(labels_list, dim=0).numpy()

        return Evaluator.compute_metrics(preds, labels)

    @staticmethod
    def compute_metrics(preds, labels):
        """Per-column regression metrics plus their unweighted mean.

        Args:
            preds, labels: 2-D arrays of shape (num_samples, num_traits).

        Returns:
            ``{'trait_1': {...}, ..., 'trait_N': {...}, 'average': {...}}``. Each inner
            dict maps 'MAE', 'ACC' (= 1 - MAE), 'MSE', 'R2', 'PCC' and 'CCC' to a plain
            Python float, so the result can be passed straight to ``json.dump``.
        """
        metric_names = ('MAE', 'ACC', 'MSE', 'R2', 'PCC', 'CCC')
        # float64 avoids float32 accumulation error. It also avoids numpy float32
        # scalars, which json cannot encode.
        preds = np.asarray(preds, dtype=np.float64)
        labels = np.asarray(labels, dtype=np.float64)
        num_traits = preds.shape[1]
        results = {}

        for i in range(num_traits):
            y_pred = preds[:, i]
            y_true = labels[:, i]

            mae = float(np.mean(np.abs(y_pred - y_true)))
            mse = float(mean_squared_error(y_true, y_pred))
            r2 = float(r2_score(y_true, y_pred))
            pcc, _ = pearsonr(y_true, y_pred)  # NaN if either column is constant
            pcc = float(pcc)

            # Lin's CCC computed from the population covariance. This is algebraically
            # equal to the PCC * std_p * std_l form, but stays finite when the
            # predictions are constant (where PCC is NaN).
            mean_p, mean_l = y_pred.mean(), y_true.mean()
            var_p, var_l = y_pred.var(), y_true.var()
            cov = np.mean((y_pred - mean_p) * (y_true - mean_l))
            denom = var_p + var_l + (mean_p - mean_l) ** 2
            ccc = float(2 * cov / denom) if denom > 0 else float('nan')

            results[f'trait_{i+1}'] = {
                'MAE': mae,
                'ACC': 1.0 - mae,
                'MSE': mse,
                'R2': r2,
                'PCC': pcc,
                'CCC': ccc,
            }

        # Aggregate metrics: unweighted mean over traits
        results['average'] = {
            name: float(np.mean([results[f'trait_{i+1}'][name] for i in range(num_traits)]))
            for name in metric_names
        }
        return results


In [None]:
class ExperimentRunner:
    """Train and test one (task_type, fusion_type) configuration end to end.

    Each ``run`` call builds fresh PersonalityDatasets and DataLoaders for the three
    splits, plus new per-modality encoders and a fusion head. It then trains with early
    stopping and evaluates on the test split.
    """

    def __init__(self, train_data, val_data, test_data, batch_size=64, device='cuda'):
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.device = device

    def run(self, task_type, fusion_type):
        """Train a model for `task_type` with `fusion_type` fusion and return test metrics.

        Args:
            task_type: a PersonalityDataset task type ('upper', 'middle_audio', ...).
            fusion_type: 'early' or 'late'.

        Returns:
            The metrics dict produced by ``Evaluator.evaluate``, so callers can save it.

        Raises:
            ValueError: if `fusion_type` is not 'early' or 'late'.
        """
        # Validate before any expensive setup; unknown values used to fall through to late fusion.
        if fusion_type not in ('early', 'late'):
            raise ValueError(f"Unknown fusion type {fusion_type!r}; expected 'early' or 'late'")

        print(f"\n=== Running Task: {task_type.upper()}, Fusion: {fusion_type.upper()} ===")

        train_dataset = PersonalityDataset(self.train_data, split='train', task_type=task_type)
        val_dataset = PersonalityDataset(self.val_data, split='val', task_type=task_type)
        test_dataset = PersonalityDataset(self.test_data, split='test', task_type=task_type)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        # Only the modality names are used below; every encoder emits 256-d features.
        input_dims = self.get_input_dims(train_dataset)

        audio_lstm = ModalityLSTM(1024, hidden_dim=256) if 'audio' in input_dims else None
        vision_lstm = ModalityLSTM(512, hidden_dim=256) if 'vision' in input_dims else None
        text_encoder = TextEncoder(768, output_dim=256)

        fusion_cls = EarlyFusionRegressor if fusion_type == 'early' else LateFusionRegressor
        model = fusion_cls(input_dims={k: 256 for k in input_dims})

        full_model = FullModel(audio_lstm, vision_lstm, text_encoder, model, device=self.device)
        optimizer = torch.optim.Adam(full_model.parameters(), lr=1e-3)

        trainer = Trainer(full_model, optimizer, loss_fn=nn.MSELoss(), device=self.device)
        trainer.train(train_loader, val_loader, epochs=50, patience=5)

        results = Evaluator.evaluate(full_model, test_loader, device=self.device)

        print("Results:")
        for trait, metrics in results.items():
            print(f"{trait}: {metrics}")
        # Returned so callers (the save cells) can persist the metrics.
        return results

    def get_input_dims(self, dataset):
        """Return {modality: last-dim size} for the first sample of `dataset`."""
        sample, _ = dataset[0]
        return {k: v.shape[-1] for k, v in sample.items()}

In [None]:
class FullModel(nn.Module):
    """Per-modality encoders followed by a fusion head.

    Audio and vision encoders are optional (pass None to omit a modality). The text
    encoder and fusion model are required. The fusion model receives a dict of encoded
    features keyed by 'audio' / 'vision' / 'text', in that insertion order.
    """

    def __init__(self, audio_lstm, vision_lstm, text_encoder, fusion_model, device='cuda'):
        super().__init__()
        # `is not None` rather than truthiness: container modules (nn.Sequential,
        # nn.ModuleList) define __len__, so an empty one would be falsy and silently dropped.
        self.audio_lstm = audio_lstm.to(device) if audio_lstm is not None else None
        self.vision_lstm = vision_lstm.to(device) if vision_lstm is not None else None
        self.text_encoder = text_encoder.to(device)
        self.fusion_model = fusion_model.to(device)

    def forward(self, inputs):
        """Encode each modality present in `inputs` that has an encoder, then fuse."""
        feats = {}
        if 'audio' in inputs and self.audio_lstm is not None:
            feats['audio'] = self.audio_lstm(inputs['audio'])
        if 'vision' in inputs and self.vision_lstm is not None:
            feats['vision'] = self.vision_lstm(inputs['vision'])
        if 'text' in inputs:
            feats['text'] = self.text_encoder(inputs['text'])

        return self.fusion_model(feats)

In [None]:
# Experiment grid: every task type x every fusion strategy.
TASKS = [
    'upper',
    'middle_audio',
    'middle_vision',
    'lower_audio',
    'lower_vision',
]
FUSIONS = [
    'early',
    'late',
]

In [None]:
# One runner shared by every experiment cell below (batch size 16, default device).
runner = ExperimentRunner(
    train_data,
    val_data,
    test_data,
    batch_size=16,
)


In [1]:
# Quick run: upper-bound task with early fusion.
runner.run(task_type='upper', fusion_type='early')

In [None]:
# OCEAN trait names, in conventional O-C-E-A-N order.
# NOTE(review): PersonalityDataset targets (and therefore Evaluator's trait_1..trait_5)
# are ordered agreeableness, openness, neuroticism, extraversion, conscientiousness.
# Do not use this list to label trait_i outputs without reordering.
OCEAN_TRAITS = [
    'Openness',
    'Conscientiousness',
    'Extraversion',
    'Agreeableness',
    'Neuroticism',
]

In [9]:
# Upper Bound - Late Fusion
import json
import os

task = 'upper'
fusion = 'late'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_upper_late.json ✅


In [10]:
# Upper Bound - Early Fusion
import json
import os

task = 'upper'
fusion = 'early'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_upper_early.json ✅


In [11]:
# Middle Audio - Early Fusion
import json
import os

task = 'middle_audio'
fusion = 'early'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_middle_audio_early.json ✅


In [12]:
# Middle Audio - Late Fusion
import json
import os

task = 'middle_audio'
fusion = 'late'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_middle_audio_late.json ✅


In [13]:
# Middle Vision - Early Fusion
import json
import os

task = 'middle_vision'
fusion = 'early'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_middle_vision_early.json ✅


In [14]:
# Middle Vision - Late Fusion
import json
import os

task = 'middle_vision'
fusion = 'late'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_middle_vision_late.json ✅


In [15]:
# Lower Audio - Early Fusion
import json
import os

task = 'lower_audio'
fusion = 'early'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_lower_audio_early.json ✅


In [16]:
# Lower Audio - Late Fusion
import json
import os

task = 'lower_audio'
fusion = 'late'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_lower_audio_late.json ✅


In [17]:
# Lower Vision - Early Fusion
import json
import os

task = 'lower_vision'
fusion = 'early'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_lower_vision_early.json ✅


In [18]:
# Lower Vision - Late Fusion
import json
import os

task = 'lower_vision'
fusion = 'late'
results = runner.run(task, fusion)

# Refuse to write a `null` results file if run() produced nothing.
if results is None:
    raise RuntimeError(f"runner.run({task!r}, {fusion!r}) returned no results; nothing to save")

results_path = f'results/results_{task}_{fusion}.json'
os.makedirs('results', exist_ok=True)
with open(results_path, 'w') as f:
    # default=float converts numpy scalar metrics (e.g. np.float32), which json cannot encode.
    json.dump(results, f, indent=4, default=float)
print(f"Saved results to {results_path} ✅")


Saved results to results/results_lower_vision_late.json ✅


In [None]:
class ConditionalDataset(Dataset):
    """Per-video ``(key, cond_feat, text_feat)`` triples for running a trained CVAE.

    target='visual': the condition is the audio feature, returned as (segments, 1024).
    target='audio':  the condition is the vision feature, returned as (segments, 512).
    Segments are zero-padded at the end up to 5 (never truncated). Samples whose
    conditioning feature is malformed yield None, so a collate_fn such as
    ``safe_collate`` can drop them.
    """

    def __init__(self, data, target='visual'):
        self.data = data
        self.keys = list(data.keys())
        self.target = target

    def __len__(self):
        return len(self.keys)

    def pad_segments(self, tensor, target_segments=5):
        """Append zero rows to a (segments, dim) tensor until it has `target_segments` rows.

        The padding is built with ``tensor.new_zeros`` so it matches the input's dtype
        and device.
        """
        missing = target_segments - tensor.size(0)
        if missing > 0:
            padding = tensor.new_zeros(missing, tensor.size(1))
            tensor = torch.cat([tensor, padding], dim=0)
        return tensor

    def __getitem__(self, idx):
        """Return ``(key, cond_feat, text_feat)`` for sample `idx`, or None if malformed.

        Raises:
            ValueError: if ``self.target`` is neither 'visual' nor 'audio'.
        """
        key = self.keys[idx]
        sample = self.data[key]

        text_feat = sample['original_text'].float()  # (768,)

        if self.target == 'visual':
            # Condition on audio. Cast to float32, like the text and vision paths, so the
            # features match the CVAE's float32 weights whatever dtype was pickled.
            audio_feat = sample['original_audio'].float()

            if audio_feat.dim() == 3:
                audio_feat = audio_feat.mean(dim=1)  # average the middle axis -> (1024, segments)
            elif audio_feat.dim() == 1:
                audio_feat = audio_feat.unsqueeze(1)  # (1024, 1)
            elif audio_feat.dim() != 2:
                return None  # bad data: unsupported rank

            if audio_feat.shape[0] != 1024 or audio_feat.shape[-1] == 0:
                return None

            cond_feat = audio_feat.permute(1, 0)  # (segments, 1024)

        elif self.target == 'audio':
            vision_feat = sample['original_vision']
            if vision_feat.dim() == 3:
                vision_feat = vision_feat.squeeze(1).float()  # (512, 1, T) -> (512, T)
            elif vision_feat.dim() == 2:
                vision_feat = vision_feat.float()
            else:
                return None

            # squeeze(1) is a no-op when axis 1 is not size 1; such a tensor stays 3-D
            # and cannot be permuted to (segments, 512), so treat it as bad data.
            if vision_feat.dim() != 2 or vision_feat.shape[0] != 512 or vision_feat.shape[-1] == 0:
                return None

            cond_feat = vision_feat.permute(1, 0)  # (segments, 512)

        else:
            raise ValueError("target must be 'visual' or 'audio'")

        cond_feat = self.pad_segments(cond_feat, target_segments=5)
        return key, cond_feat, text_feat


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SegmentConditionalVAE(nn.Module):
    """Conditional VAE that generates one missing modality, segment by segment.

    The model is conditioned on the available modality's per-segment features plus a
    text embedding that is broadcast to every segment:
      * target='visual': condition = audio (audio_dim), output = visual (visual_dim)
      * target='audio':  condition = visual (visual_dim), output = audio (audio_dim)
    The encoder and decoder MLPs act on the last dimension, so each segment is
    encoded and decoded independently.

    NOTE: attribute names and Sequential indices determine the state_dict keys. The
    checkpoints loaded elsewhere via load_state_dict depend on them, so do not rename
    or reorder these submodules.

    Raises:
        ValueError: if `target` is neither 'visual' nor 'audio'.
    """
    def __init__(self, text_dim=768, audio_dim=1024, visual_dim=512, latent_dim=128, target='visual'):
        super().__init__()

        # Decide which modality is missing (target), and which is available (conditioning)
        self.target = target
        self.latent_dim = latent_dim

        if target == 'visual':
            self.input_modality_dim = audio_dim
            self.output_dim = visual_dim
        elif target == 'audio':
            self.input_modality_dim = visual_dim
            self.output_dim = audio_dim
        else:
            raise ValueError("target must be 'visual' or 'audio'")

        # NOTE(review): not read by any method of this class; the segment count comes
        # from the input tensor's T dimension.
        self.segment_dim = 5  # number of segments (fixed for now)

        # ------ ENCODER -------
        # Per-segment MLP: [condition ; text] -> 256-d hidden, then linear heads for mu / logvar.
        self.encoder = nn.Sequential(
            nn.Linear(self.input_modality_dim + text_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # ------ DECODER -------
        # Per-segment MLP: [z ; condition ; text] -> 256 -> ... -> output_dim. The final
        # Linear has no activation, so outputs are unbounded.
        self.decoder_input = nn.Linear(latent_dim + self.input_modality_dim + text_dim, 256)
        self.decoder = nn.Sequential(
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, self.output_dim)
    )

    def encode(self, condition_segments, text):
        """
        Encode the conditioning modality + text into μ and logσ², per segment.

        Args:
            condition_segments: (B, T, input_modality_dim)
            text: (B, text_dim), repeated across the T segments

        Returns:
            mu, logvar: each (B, T, latent_dim)
        """
        B, T, _ = condition_segments.shape
        text = text.unsqueeze(1).expand(-1, T, -1)  # (B, 5, 768)
        concat = torch.cat([condition_segments, text], dim=-1)  # (B, 5, cond+768)
        hidden = self.encoder(concat)                # (B, 5, 256)
        mu = self.fc_mu(hidden)                      # (B, 5, latent_dim)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Sample z from N(μ, σ²) using the reparameterization trick:
        z = μ + exp(0.5 · logσ²) · ε, with ε ~ N(0, I).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, condition_segments, text):
        """
        Decode the latent z together with the conditioning inputs into the target modality.

        Args:
            z: (B, T, latent_dim)
            condition_segments: (B, T, input_modality_dim)
            text: (B, text_dim), repeated across the T segments

        Returns:
            (B, T, output_dim)
        """
        B, T, _ = z.shape
        text = text.unsqueeze(1).expand(-1, T, -1)  # (B, 5, 768)
        dec_input = torch.cat([z, condition_segments, text], dim=-1)  # (B, 5, latent+cond+text)
        dec_input = self.decoder_input(dec_input)  # (B, 5, 256)
        out = self.decoder(dec_input)              # (B, 5, output_dim)
        return out

    def forward(self, condition_segments, text, use_mean=False):
        """
        Forward pass for training or inference.

        Args:
            condition_segments: (B, T, input_modality_dim)
            text: (B, 768)
            use_mean (bool): If True, use z = μ (deterministic output). Otherwise, sample
                z = μ + σ * ε.

        Returns:
            out: reconstructed modality, (B, T, output_dim)
            mu, logvar: latent distribution parameters, each (B, T, latent_dim)
        """
        mu, logvar = self.encode(condition_segments, text)

        # Safety: warn if use_mean=True during training (not recommended).
        # This prints on every such forward call.
        if self.training and use_mean:
            print("[Warning] use_mean=True during training — model will behave deterministically. "
                  "You probably want to set use_mean=False for stochastic training.")

        if use_mean:
            z = mu
        else:
            z = self.reparameterize(mu, logvar)

        out = self.decode(z, condition_segments, text)
        return out, mu, logvar


    def sample_from_prior(self, condition_segments, text):
        """
        Pure generative inference: sample z ~ N(0, I) per segment and decode it.
        The encoder is not used.

        Returns:
            (B, T, output_dim), with B and T taken from condition_segments.
        """
        B, T, _ = condition_segments.shape
        z = torch.randn((B, T, self.latent_dim), device=condition_segments.device)
        out = self.decode(z, condition_segments, text)
        return out


In [None]:
import torch

# Rebuild the CVAE that generates the *visual* modality and restore its trained weights.
# To reconstruct audio instead, use target='audio' with the audio checkpoint.
model = SegmentConditionalVAE(target='visual')
checkpoint = torch.load('cvae_visual_beta_anneal.pt', map_location='cuda')  # a state_dict
model.load_state_dict(checkpoint)

# Inference mode on the GPU; the cell's last expression displays the module summary.
model.eval().cuda()


SegmentConditionalVAE(
  (encoder): Sequential(
    (0): Linear(in_features=1792, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
  )
  (fc_mu): Linear(in_features=256, out_features=128, bias=True)
  (fc_logvar): Linear(in_features=256, out_features=128, bias=True)
  (decoder_input): Linear(in_features=1920, out_features=256, bias=True)
  (decoder): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=512, bias=True)
    (6): ReLU()
    (7): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [None]:
from tqdm import tqdm

def reconstruct(model, loader, store_dict, feature_name, device=None):
    """Run a CVAE over `loader` and store each deterministic reconstruction in `store_dict`.

    Args:
        model: CVAE called as ``model(cond, text, use_mean=True) -> (recon, mu, logvar)``.
        loader: yields ``(keys, cond_feats, text_feats)`` batches (e.g. via ``safe_collate``).
        store_dict: mapping key -> sample dict. ``store_dict[key][feature_name]`` is set
            for every batch key present in it; other keys are ignored.
        feature_name: entry name to write, e.g. 'reconstructed_visual'.
        device: where the inputs are sent. Defaults to the device of the model's
            parameters instead of assuming CUDA.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        for keys, cond_feats, text_feats in tqdm(loader):
            # use_mean=True decodes from z = mu, so the output is deterministic.
            recon_x, _, _ = model(cond_feats.to(device), text_feats.to(device), use_mean=True)
            recon_x = recon_x.cpu()

            for key, recon in zip(keys, recon_x):
                if key in store_dict:
                    # Clone so each stored tensor owns its memory. A bare row view shares
                    # the whole batch's storage, and pickling a view serializes that
                    # entire storage.
                    store_dict[key][feature_name] = recon.clone()



In [None]:
# import sys
# !{sys.executable} -m pip install tqdm

Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Using cached tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1


In [None]:
def safe_collate(batch):
    """Collate ``(key, cond_feat, text_feat)`` triples, dropping None entries.

    ConditionalDataset yields None for malformed samples; those are removed here.

    Returns:
        (keys, cond_feats, text_feats): a list of keys, plus the two feature sequences
        stacked along a new leading batch dimension.

    Raises:
        ValueError: if every sample in the batch was None, so there is nothing to stack.
    """
    valid = [item for item in batch if item is not None]
    if not valid:
        raise ValueError(
            f"safe_collate: no valid samples in a batch of {len(batch)} (all were None)"
        )

    keys, cond_feats, text_feats = zip(*valid)
    return list(keys), torch.stack(cond_feats, dim=0), torch.stack(text_feats, dim=0)


In [None]:
# # ---- Choose your target ----
# TARGET = 'visual'   # 'visual' or 'audio'

# # ---- Load the right CVAE model ----
# model = SegmentConditionalVAE(target=TARGET)
# model.load_state_dict(torch.load('cvae_visual_beta_anneal.pt', map_location='cuda'))
# model.cuda()

# # ---- Prepare datasets/loaders ----
# train_dataset = ConditionalDataset(train_data, target=TARGET)
# val_dataset = ConditionalDataset(val_data, target=TARGET)

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, collate_fn=safe_collate)
# val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=safe_collate)

# # ---- Create dicts to store reconstructed data ----
# reconstructed_train = {}
# reconstructed_val = {}

# # ---- Run Reconstruction ----
# feature_name = 'reconstructed_' + TARGET  # 'reconstructed_visual' or 'reconstructed_audio'

# reconstruct(model, train_loader, train_data, reconstructed_train, feature_name)
# reconstruct(model, val_loader, val_data, reconstructed_val, feature_name)

# # ---- Save the new pickles ----
# with open(f'reconstructed_cvae_train_{TARGET}.pkl', 'wb') as f:
#     pickle.dump(reconstructed_train, f)

# with open(f'reconstructed_cvae_val_{TARGET}.pkl', 'wb') as f:
#     pickle.dump(reconstructed_val, f)

# print(f"✅ Done reconstructing and saving for {TARGET}!")

In [None]:
# Shallow-copy each sample dict. We only *add* new 'reconstructed_*' keys below, so the
# original tensors can be shared while train_data / val_data themselves stay unmodified.
reconstructed_train = {k: v.copy() for k, v in train_data.items()}
reconstructed_val = {k: v.copy() for k, v in val_data.items()}

# ---- Reconstruct for both targets ----
for TARGET, MODEL_PATH in [('visual', 'cvae_visual_beta_anneal.pt'), ('audio', 'cvae_audio_beta_anneal.pt')]:

    print(f"🔵 Now reconstructing {TARGET}...")

    # Load the CVAE trained to generate this target modality
    model = SegmentConditionalVAE(target=TARGET)
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cuda'))
    model.cuda()

    # Prepare datasets/loaders. Malformed samples come back as None and safe_collate drops them.
    train_dataset = ConditionalDataset(train_data, target=TARGET)
    val_dataset = ConditionalDataset(val_data, target=TARGET)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, collate_fn=safe_collate)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=safe_collate)

    # Run reconstruction
    feature_name = 'reconstructed_' + TARGET  # 'reconstructed_visual' or 'reconstructed_audio'

    reconstruct(model, train_loader, reconstructed_train, feature_name)
    reconstruct(model, val_loader, reconstructed_val, feature_name)

    # Report samples that were skipped, so they are visible. Downstream code falls back
    # to the original features for them.
    for split_name, store in (('train', reconstructed_train), ('val', reconstructed_val)):
        n_missing = sum(feature_name not in sample for sample in store.values())
        if n_missing:
            print(f"⚠️ {n_missing}/{len(store)} {split_name} samples have no '{feature_name}' (dropped as malformed)")

print("✅ All modalities reconstructed!")

# ---- Save the merged pickles ----
with open('reconstructed_cvae_train.pkl', 'wb') as f:
    pickle.dump(reconstructed_train, f)

with open('reconstructed_cvae_val.pkl', 'wb') as f:
    pickle.dump(reconstructed_val, f)

print("✅ Final reconstructed pickles saved!")

🔵 Now reconstructing visual...


100%|██████████| 94/94 [00:18<00:00,  5.04it/s]
100%|██████████| 32/32 [00:06<00:00,  4.93it/s]


🔵 Now reconstructing audio...


100%|██████████| 94/94 [00:00<00:00, 506.55it/s]
100%|██████████| 32/32 [00:00<00:00, 499.42it/s]


✅ All modalities reconstructed!
✅ Final reconstructed pickles saved!


In [None]:
# Inspect which feature keys the first validation sample now carries.
first_val_id = list(reconstructed_val)[0]
reconstructed_val[first_val_id].keys()

NameError: name 'reconstructed_val' is not defined

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

def pad_tensor_to_shape(x, target_shape):
    """Zero-pad `x` at the end of each trailing dimension up to `target_shape`.

    Dimensions are matched from the right. A dimension already at least as large as
    its target is left unchanged (this never truncates).
    NOTE: this redefines the identical helper from an earlier cell.
    """
    extra_per_dim = [
        max(want - have, 0)
        for have, want in zip(reversed(x.shape), reversed(target_shape))
    ]
    # F.pad takes (before, after) pairs starting from the last dimension.
    pad = [amount for extra in extra_per_dim for amount in (0, extra)]
    return F.pad(x, pad)

class PersonalityDataset(Dataset):
    """Per-video dataset yielding ``(input_feats, traits)`` pairs for trait regression.

    ``input_feats`` maps modality name to a float tensor: ``'audio'`` is zero-padded
    up to (1024, 5), ``'vision'`` up to (512, 5), and ``'text'`` is ``original_text``
    cast to float. ``traits`` is a float32 tensor of five scores in ``TRAIT_KEYS`` order.

    Which modalities are returned depends on ``task_type`` (``TASK_MODALITIES``).
    For the tasks in ``_AUDIO_SUBSTITUTED`` / ``_VISION_SUBSTITUTED`` that modality
    is replaced by ``generated_*`` features on the ``'test'`` split, and by
    ``reconstructed_*`` features (when present, else the original) on other splits.
    """

    # Modalities fed to the model for each task type, in input_feats order
    # (ExperimentRunner.get_input_dims iterates this dict, so order matters).
    TASK_MODALITIES = {
        'upper': ('audio', 'vision', 'text'),
        'middle_audio': ('audio', 'vision', 'text'),
        'middle_vision': ('audio', 'vision', 'text'),
        'lower_audio': ('vision', 'text'),
        'lower_vision': ('audio', 'text'),
    }
    # Task types whose audio / vision input comes from generated or reconstructed features.
    _AUDIO_SUBSTITUTED = ('middle_vision', 'lower_vision')
    _VISION_SUBSTITUTED = ('middle_audio', 'lower_audio')
    # Order of the entries in the returned trait vector.
    TRAIT_KEYS = ('agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness')

    def __init__(self, data_dict, split, task_type):
        """
        Args:
            data_dict: dictionary from loaded pickle, keyed by video id
            split: 'train', 'val', 'test'
            task_type: 'upper', 'middle_audio', 'middle_vision', 'lower_audio', 'lower_vision'

        Raises:
            ValueError: if ``task_type`` is not a supported task type.
        """
        # Fail fast at construction instead of on the first item access.
        if task_type not in self.TASK_MODALITIES:
            raise ValueError(f"Unknown task type {task_type}")
        self.data = data_dict
        self.split = split
        self.task_type = task_type
        self.keys = list(data_dict.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        vid = self.keys[idx]
        sample = self.data[vid]

        modalities = self.TASK_MODALITIES.get(self.task_type)
        if modalities is None:  # task_type was reassigned after construction
            raise ValueError(f"Unknown task type {self.task_type}")

        builders = {
            'audio': self._audio_features,
            'vision': self._vision_features,
            'text': self._text_features,
        }
        # Build only the modalities this task uses (e.g. lower_audio no longer
        # averages the unused original audio for every sample).
        input_feats = {name: builders[name](sample, vid) for name in modalities}

        traits = torch.tensor([sample[key] for key in self.TRAIT_KEYS], dtype=torch.float32)
        return input_feats, traits

    def _text_features(self, sample, vid):
        """Return the sample's ``original_text`` tensor as float."""
        return sample['original_text'].float()

    def _audio_features(self, sample, vid):
        """Return the sample's audio as a 2-D float tensor zero-padded up to (1024, 5).

        Raises:
            ValueError: if the chosen audio tensor is not 1-, 2- or 3-dimensional.
        """
        if self.task_type in self._AUDIO_SUBSTITUTED:
            if self.split == 'test':
                audio = sample['generated_audio'].squeeze(1).permute(1, 0).float()
            elif 'reconstructed_audio' in sample:
                audio = sample['reconstructed_audio'].permute(1, 0).float()  # (5,1024) -> (1024,5)
            else:
                audio = sample['original_audio'].float()
        else:
            audio = sample['original_audio'].float()

        if audio.dim() == 3:
            audio = audio.mean(dim=1)  # collapse the middle axis, e.g. (1024, 149, 5) -> (1024, 5)
        elif audio.dim() == 1:
            audio = audio.unsqueeze(1).repeat(1, 5)  # repeat a single vector across 5 columns
        elif audio.dim() != 2:
            raise ValueError(f"Unexpected audio feature shape {audio.shape} for video {vid}")

        return pad_tensor_to_shape(audio, (1024, 5))

    def _vision_features(self, sample, vid):
        """Return the sample's vision features as a float tensor zero-padded up to (512, 5)."""
        if self.task_type in self._VISION_SUBSTITUTED:
            if self.split == 'test':
                vision = sample['generated_vision'].squeeze(1).permute(1, 0).float()
            elif 'reconstructed_visual' in sample:
                vision = sample['reconstructed_visual'].permute(1, 0).float()  # (5,512) -> (512,5)
            else:
                vision = sample['original_vision'].squeeze(1).float()
        else:
            vision = sample['original_vision'].squeeze(1).float()

        return pad_tensor_to_shape(vision, (512, 5))


In [None]:
import json
import os

def make_json_serializable(obj):
    """Recursively convert ``obj`` into plain Python types that ``json.dump`` accepts.

    - dicts are rebuilt with converted values (keys are left as-is);
    - lists and tuples become lists of converted items;
    - numpy / torch scalars, arrays and tensors (anything with ``tolist``) become
      Python numbers or nested lists;
    - other objects exposing ``item()`` are unwrapped to a Python scalar;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # json.dump writes tuples as lists anyway; converting lets us recurse into them.
        return [make_json_serializable(v) for v in obj]
    if hasattr(obj, "tolist"):
        # Unlike .item(), .tolist() also works for multi-element arrays/tensors.
        converted = obj.tolist()
        return make_json_serializable(converted) if isinstance(converted, list) else converted
    if hasattr(obj, "item"):
        return obj.item()
    return obj

class ExperimentRunner:
    """Train and evaluate one personality regressor per (task_type, fusion_type) pair.

    Each ``run`` builds ``PersonalityDataset`` train/val/test splits, trains a
    ``FullModel`` (audio/vision ``ModalityLSTM``s, a ``TextEncoder`` and an early or
    late fusion regressor) with ``Trainer`` using MSE loss and Adam, evaluates it on
    the test split with ``Evaluator`` and writes the metrics to
    ``<results_dir>/<prefix>results_<task>_<fusion>.json``.

    Args:
        train_data, val_data, test_data: per-split sample dicts.
        batch_size: DataLoader batch size.
        device: device string passed to the model, trainer and evaluator.
        lr: Adam learning rate.
        epochs: ``epochs`` argument forwarded to ``Trainer.train``.
        patience: ``patience`` argument forwarded to ``Trainer.train``.
        results_dir: directory the JSON result files are written to.
    """

    FUSION_TYPES = ('early', 'late')

    def __init__(self, train_data, val_data, test_data, batch_size=64, device='cuda',
                 lr=1e-3, epochs=50, patience=5, results_dir='results'):
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.device = device
        self.lr = lr
        self.epochs = epochs
        self.patience = patience
        self.results_dir = results_dir

    def run(self, task_type, fusion_type, result_prefix=''):
        """Train, evaluate and save results for one configuration.

        Returns:
            The metrics dict returned by ``Evaluator.evaluate``.

        Raises:
            ValueError: if ``fusion_type`` is not 'early' or 'late'.
        """
        # Without this check any value other than 'early' would silently train a
        # late-fusion model and save its results under the misspelled name.
        if fusion_type not in self.FUSION_TYPES:
            raise ValueError(f"Unknown fusion type {fusion_type!r}; expected one of {self.FUSION_TYPES}")

        print(f"\n=== Running Task: {task_type.upper()}, Fusion: {fusion_type.upper()} ===")

        train_dataset = PersonalityDataset(self.train_data, split='train', task_type=task_type)
        val_dataset = PersonalityDataset(self.val_data, split='val', task_type=task_type)
        test_dataset = PersonalityDataset(self.test_data, split='test', task_type=task_type)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        # Only the modality names (keys) are used below; the fusion head is told
        # every modality embedding is 256-d, matching the encoders' output size.
        input_dims = self.get_input_dims(train_dataset)

        audio_lstm = ModalityLSTM(1024, hidden_dim=256) if 'audio' in input_dims else None
        vision_lstm = ModalityLSTM(512, hidden_dim=256) if 'vision' in input_dims else None
        text_encoder = TextEncoder(768, output_dim=256)

        regressor_cls = EarlyFusionRegressor if fusion_type == 'early' else LateFusionRegressor
        model = regressor_cls(input_dims={k: 256 for k in input_dims})

        full_model = FullModel(audio_lstm, vision_lstm, text_encoder, model, device=self.device)
        optimizer = torch.optim.Adam(full_model.parameters(), lr=self.lr)

        trainer = Trainer(full_model, optimizer, loss_fn=nn.MSELoss(), device=self.device)
        trainer.train(train_loader, val_loader, epochs=self.epochs, patience=self.patience)

        evaluator = Evaluator()
        results = evaluator.evaluate(full_model, test_loader, device=self.device)

        print("Results:")
        for trait, metrics in results.items():
            print(f"{trait}: {metrics}")

        # Save results (os / json are imported at the top of this cell)
        os.makedirs(self.results_dir, exist_ok=True)
        serializable_results = make_json_serializable(results)

        filename = os.path.join(self.results_dir, f"{result_prefix}results_{task_type}_{fusion_type}.json")
        with open(filename, 'w') as f:
            json.dump(serializable_results, f, indent=4)

        print(f"✅ Results saved to {filename}")

        return results

    def get_input_dims(self, dataset):
        """Map each modality in ``dataset[0]``'s input dict to the size of its last axis."""
        sample, _ = dataset[0]
        return {k: v.shape[-1] for k, v in sample.items()}


In [None]:
# Reload the splits with the load_pickle helper from the first cell:
# train/val come from the pickles saved by the reconstruction cell (they carry
# 'reconstructed_visual' / 'reconstructed_audio'), test from the original CVAE
# pickle (it carries 'generated_vision' / 'generated_audio').
train_data = load_pickle('reconstructed_cvae_train.pkl')
val_data = load_pickle('reconstructed_cvae_val.pkl')
test_data = load_pickle('cvae_test.pkl')


NameError: name 'reconstructed_train' is not defined

In [None]:
# Confirm the first training sample now carries the reconstructed fields.
train_data[list(train_data)[0]].keys()

dict_keys(['original_text', 'original_vision', 'original_audio', 'agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness', 'reconstructed_visual', 'reconstructed_audio'])

In [None]:
# Confirm the first validation sample now carries the reconstructed fields.
val_data[list(val_data)[0]].keys()

dict_keys(['original_text', 'original_vision', 'original_audio', 'agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness', 'reconstructed_visual', 'reconstructed_audio'])

In [None]:
# Confirm the first test sample carries the generated fields.
test_data[list(test_data)[0]].keys()

dict_keys(['original_text', 'original_vision', 'original_audio', 'agreeableness', 'openness', 'neuroticism', 'extraversion', 'conscientiousness', 'generated_vision', 'generated_audio'])

In [19]:
# One runner, then train/evaluate every (task, fusion) combination for the
# CVAE-based "middle" tasks; each run writes its own cvae_-prefixed JSON file.
runner = ExperimentRunner(train_data, val_data, test_data)

cvae_runs = [(task, fusion)
             for task in ('middle_audio', 'middle_vision')
             for fusion in ('early', 'late')]
for task, fusion in cvae_runs:
    runner.run(task, fusion, result_prefix='cvae_')


In [None]:
!ls ./results/

comparison_table.csv		       results_lower_audio_late.json
cvae_results_middle_audio_early.json   results_lower_vision_early.json
cvae_results_middle_audio_late.json    results_lower_vision_late.json
cvae_results_middle_vision_early.json  results_middle_audio_early.json
cvae_results_middle_vision_late.json   results_middle_audio_late.json
diff_results_middle_audio_early.json   results_middle_vision_early.json
diff_results_middle_audio_late.json    results_middle_vision_late.json
diff_results_middle_vision_early.json  results_upper_early.json
diff_results_middle_vision_late.json   results_upper_late.json
pivot_ccc.csv			       unified_comparison_table.csv
pivot_mae.csv			       unified_comparison_with_diffusion.csv
results_lower_audio_early.json
