In [14]:
import sys
import os
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from tqdm import tqdm
from SEMPIDataLoader import ListenerSpeakerFeatureDataset
from multimodal_xattention import EarlyFusion

import numpy as np
from types import SimpleNamespace

import matplotlib.pyplot as plt
from scipy.stats import pearsonr

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = ListenerSpeakerFeatureDataset(
    csv_path="AudioVideo_Feature_Paths.csv",
    frame_length=64,
    root_dir="./",
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)


In [16]:
sample = dataset[0]
speaker_feat, listener_feat = sample["features"]
listener_dim, speaker_dim = listener_feat.shape[0], speaker_feat.shape[0]
print(f"Listener feature shape: {listener_feat.shape}")
print(f"Speaker feature shape: {speaker_feat.shape}")

Listener feature shape: torch.Size([329, 64])
Speaker feature shape: torch.Size([424, 64])


In [66]:
config = {
    'ckpt_root': './pretrained',
    'activation_fn': 'tanh',
    'extra_dropout': 0,
    'hidden_size': 128,
    'dropout': 0.1,
    'weight_decay': 0.01,
    'expnum': 8,
    'openfacefeat': 1,
    'openfacefeat_extramlp': 1,
    'openfacefeat_extramlp_dim': 128,
    'ablation': 8,
    'num_labels': 1
}

config = SimpleNamespace(**config)

model = EarlyFusion(config=config).to(device).double()
for name, param in model.named_parameters():
    print(f"Parameter {name}: dtype = {param.dtype}")
print(model)

Parameter extra_mlp.0.weight: dtype = torch.float64
Parameter extra_mlp.0.bias: dtype = torch.float64
Parameter out.fc1.weight: dtype = torch.float64
Parameter out.fc1.bias: dtype = torch.float64
Parameter out.fc2.weight: dtype = torch.float64
Parameter out.fc2.bias: dtype = torch.float64
Parameter out.fc3.weight: dtype = torch.float64
Parameter out.fc3.bias: dtype = torch.float64
Parameter cross_attention.in_proj_weight: dtype = torch.float64
Parameter cross_attention.in_proj_bias: dtype = torch.float64
Parameter cross_attention.out_proj.weight: dtype = torch.float64
Parameter cross_attention.out_proj.bias: dtype = torch.float64
Parameter audio_mlp.0.weight: dtype = torch.float64
Parameter audio_mlp.0.bias: dtype = torch.float64
Parameter fusion_mlp.0.weight: dtype = torch.float64
Parameter fusion_mlp.0.bias: dtype = torch.float64
EarlyFusion(
  (extra_mlp): Sequential(
    (0): Linear(in_features=329, out_features=128, bias=True)
    (1): Tanh()
  )
  (out): Classifier(
    (dropout)

In [67]:
# Count total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")

Total number of parameters: 289153
Number of trainable parameters: 289153


In [68]:
def compute_ccc_batched(y_pred, y_true):
    # Flatten 
    y_true_np = y_true.flatten()
    y_pred_np = y_pred.flatten()
    mean_true = np.mean(y_true_np)
    mean_pred = np.mean(y_pred_np)
    std_true = np.std(y_true_np)
    std_pred = np.std(y_pred_np)

    # Pearson
    rho, _ = pearsonr(y_true_np, y_pred_np)

    # CCC from Pearson
    ccc = (2 * rho * std_true * std_pred) / (std_true**2 + std_pred**2 + (mean_true - mean_pred)**2)
    
    return ccc

import numpy as np
from scipy.stats import pearsonr

def compute_pearson_correlation_batched(y_pred, y_true):
    # Flatten 
    y_true_np = y_true.flatten()
    y_pred_np = y_pred.flatten()

    # Calculate Pearson 
    rho, _ = pearsonr(y_true_np, y_pred_np)

    return rho

In [69]:
# Training settings
num_epochs = 30
learning_rate = 1e-3
early_stop_patience = 10
best_model_path = 'best_early_fusion_model.pth'

# Define the loss function
criterion = torch.nn.MSELoss()

# Define the optimizer with your settings
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=config.weight_decay)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

# Training tracking
best_val_loss = float('inf')
no_improve_count = 0
train_losses = []
val_losses = []
val_cccs = []
val_pccs = []

# Create figure for real-time plotting
plt.figure(figsize=(15, 5))



<Figure size 1500x500 with 0 Axes>

<Figure size 1500x500 with 0 Axes>

In [70]:
try:
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Training phase
        model.train()
        total_train_loss = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False):
            # Get features and target
            speaker_feat, listener_feat = batch["features"]
            engagement = batch["score"]
            print(f"speaker_feat dtype before conversion: {speaker_feat.dtype}")
            print(f"listener_feat dtype before conversion: {listener_feat.dtype}")
            
            # Move to device and reshape
            speaker_feat = speaker_feat.to(device).double()
            listener_feat = listener_feat.to(device).double()
            engagement = engagement.to(device)
            engagement = engagement.view(-1, 1)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Create a list of listener features with transposed dimensions
            listener_feat_list = []
            for i in range(listener_feat.size(0)):
                listener_feat_list.append(listener_feat[i].transpose(0, 1))
            
            # Forward pass with properly formatted features
            output = model(audio_paths=speaker_feat, openfacefeat_=listener_feat_list)
            
            # Calculate loss
            loss = criterion(output, engagement)
            
            # Backward pass
            loss.backward()
            
            # Update weights
            optimizer.step()
            
            # Track loss
            total_train_loss += loss.item()
        
        # Calculate average training loss
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Training Loss: {avg_train_loss:.4f}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_preds = []
        val_targets = []
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
                # Get features and target
                speaker_feat, listener_feat = batch["features"]
                engagement = batch["score"]
                
                # Move to device and reshape
                speaker_feat = speaker_feat.to(device)
                listener_feat = listener_feat.to(device)
                engagement = engagement.to(device)
                engagement = engagement.view(-1, 1)

                # Create a list of listener features with transposed dimensions
                listener_feat_list = []
                for i in range(listener_feat.size(0)):
                    listener_feat_list.append(listener_feat[i].transpose(0, 1))
                
                # Forward pass
                output = model(audio_paths=speaker_feat, openfacefeat_=listener_feat_list)
                
                # Calculate loss
                loss = criterion(output, engagement)
                
                # Track loss
                val_loss += loss.item() * engagement.size(0)
                
                # Store predictions and targets for metrics
                val_preds.append(output.cpu())
                val_targets.append(engagement.cpu())
        
        # Calculate average validation loss
        avg_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)
        
        # Calculate metrics
        val_preds_combined = torch.cat(val_preds).numpy()
        val_targets_combined = torch.cat(val_targets).numpy()
        
        val_ccc = compute_ccc_batched(val_preds_combined, val_targets_combined)
        val_pcc = compute_pearson_correlation_batched(val_preds_combined, val_targets_combined)
        
        val_cccs.append(val_ccc)
        val_pccs.append(val_pcc)
        
        print(f"Validation Loss: {avg_val_loss:.4f}, CCC: {val_ccc:.4f}, PCC: {val_pcc:.4f}")
        
        # Update learning rate
        scheduler.step(avg_val_loss)
        
        # Check if this is the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Save model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'val_ccc': val_ccc,
                'val_pcc': val_pcc,
                'config': config,
            }, best_model_path)
            print(f"✓ Saved new best model with validation loss: {avg_val_loss:.4f}")
            no_improve_count = 0
        else:
            no_improve_count += 1
            print(f"✗ No improvement for {no_improve_count} epochs")
        
        # Early stopping
        if no_improve_count >= early_stop_patience:
            print(f"Early stopping after {epoch+1} epochs without improvement")
            break
        
        # Update plots
        plt.clf()
        
        # Plot 1: Loss curves
        plt.subplot(1, 3, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Loss Curves')
        
        # Plot 2: CCC curve
        plt.subplot(1, 3, 2)
        plt.plot(val_cccs, label='Validation CCC', color='green')
        plt.xlabel('Epoch')
        plt.ylabel('CCC')
        plt.legend()
        plt.title('Concordance Correlation Coefficient')
        
        # Plot 3: PCC curve
        plt.subplot(1, 3, 3)
        plt.plot(val_pccs, label='Validation PCC', color='purple')
        plt.xlabel('Epoch')
        plt.ylabel('PCC')
        plt.legend()
        plt.title('Pearson Correlation Coefficient')
        
        plt.tight_layout()
        plt.savefig('training_metrics.png')
        plt.pause(0.1)
        
except KeyboardInterrupt:
    print("Training interrupted by user")

Epoch 1/30


                                                       

speaker_feat dtype before conversion: torch.float32
listener_feat dtype before conversion: torch.float32




RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x21056 and 768x128)

In [None]:
# Final visualization
plt.figure(figsize=(15, 5))

# Plot 1: Loss curves
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curves')

# Plot 2: CCC curve
plt.subplot(1, 3, 2)
plt.plot(val_cccs, label='Validation CCC', color='green')
plt.xlabel('Epoch')
plt.ylabel('CCC')
plt.legend()
plt.title('Concordance Correlation Coefficient')

# Plot 3: PCC curve
plt.subplot(1, 3, 3)
plt.plot(val_pccs, label='Validation PCC', color='purple')
plt.xlabel('Epoch')
plt.ylabel('PCC')
plt.legend()
plt.title('Pearson Correlation Coefficient')

plt.tight_layout()
plt.savefig('final_training_metrics.png')
plt.show()

In [None]:
# Load the best model for evaluation
if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with:")
    print(f"- Validation Loss: {checkpoint['val_loss']:.4f}")
    print(f"- Validation CCC: {checkpoint['val_ccc']:.4f}")
    print(f"- Validation PCC: {checkpoint['val_pcc']:.4f}")

# Comprehensive evaluation on validation set
def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            speaker_feat, listener_feat = batch["features"]
            engagement = batch["score"]
            
            speaker_feat = speaker_feat.to(device)
            listener_feat = listener_feat.to(device)
            engagement = engagement.to(device)
            engagement = engagement.view(-1, 1)
            
            output = model(listener_feat, speaker_feat)
            
            all_preds.append(output.cpu().numpy())
            all_targets.append(engagement.cpu().numpy())
    
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    
    # Calculate metrics
    mse = np.mean((all_targets - all_preds) ** 2)
    rmse = np.sqrt(mse)
    ccc = compute_ccc_batched(all_preds, all_targets)
    pcc = compute_pearson_correlation_batched(all_preds, all_targets)
    
    return {
        'mse': mse,
        'rmse': rmse,
        'ccc': ccc,
        'pcc': pcc,
        'predictions': all_preds,
        'targets': all_targets
    }

# Run final evaluation
print("Running final evaluation on validation set...")
eval_results = evaluate_model(model, val_loader, device)

print("\nFinal Evaluation Metrics:")
print(f"- MSE: {eval_results['mse']:.4f}")
print(f"- RMSE: {eval_results['rmse']:.4f}")
print(f"- CCC: {eval_results['ccc']:.4f}")
print(f"- PCC: {eval_results['pcc']:.4f}")

# Plot predictions vs targets for a sample
plt.figure(figsize=(10, 6))
sample_size = min(100, len(eval_results['predictions']))
indices = np.random.choice(len(eval_results['predictions']), sample_size, replace=False)

plt.scatter(eval_results['targets'][indices], eval_results['predictions'][indices], alpha=0.5)
plt.plot([min(eval_results['targets']), max(eval_results['targets'])], 
         [min(eval_results['targets']), max(eval_results['targets'])], 'r--')
plt.xlabel('True Values')
plt.ylabel('Predicted Values')
plt.title('Predicted vs True Values')
plt.grid(True)
plt.savefig('prediction_scatter.png')
plt.show()



NameError: name 'plt' is not defined