In [2]:
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from tqdm import tqdm
from SEMPIDataLoader import ListenerSpeakerFeatureDataset
import torch
import torch.nn as nn
import sys, os
import numpy as np
from types import SimpleNamespace


In [3]:
sys.path.append(os.path.abspath("code"))
from metr import compute_ccc_batched , compute_pearson_correlation_batched,compute_r2_score_batched

In [4]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = ListenerSpeakerFeatureDataset(
    csv_path="AudioVideo_Feature_Path_v2.csv",
    frame_length=64,
    root_dir="./",
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)

sample = dataset[3]
speaker_feat, listener_feat,_  = sample["features"]
listener_dim, speaker_dim = listener_feat.shape[0], speaker_feat.shape[0]
print(f"Listener feature shape: {listener_feat.shape}")
print(f"Speaker feature shape: {speaker_feat.shape}")



Listener feature shape: torch.Size([1024, 64])
Speaker feature shape: torch.Size([768, 64])


In [None]:
config = { 'activation_fn': 'Tanh',
           'extra_dropout':0 ,
             'hidden_size':128,
             'dropout': 0.1,
             'num_labels': 1}

config = SimpleNamespace(**config)

# MLP and Self-attention on MLP hidden features

In [5]:

class MLPUpToHidden(nn.Module):
    def __init__(self, in_size, hidden_size, dropout, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(in_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        self.activation_fn = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "Tanh": nn.Tanh()
        }[config.activation_fn]

    def forward(self, x):
        x = self.dropout(self.fc1(x))
        x = self.activation_fn(x)
        x = self.dropout(self.fc2(x))
        x = self.activation_fn(x)
        return x  


class CrossAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=1, batch_first=True)

    def forward(self, q, k):
        # q, k: [B, D]
        q = q.unsqueeze(1)  # [B, 1, D]
        k = k.unsqueeze(1)  # [B, 1, D]
        v = k  
        # print(f"q shape: {q.shape}")
        # print(f"k shape: {k.shape}")
        # print(f"v shape: {v.shape}")
        attn_out, _ = self.attn(q, k, v)  # [B, 1, D]
        return attn_out.squeeze(1)  # [B, D]
    
        # q shape: torch.Size([32, 1, 128])
        # k shape: torch.Size([32, 1, 128])
        # v shape: torch.Size([32, 1, 128])

class ListenerSpeakerFusion(nn.Module):
    def __init__(self, config, listener_input_dim=329, listener_seq_len=50,
                 speaker_input_dim=768,
                 reduced_speaker_dim=32):
        super().__init__()
        self.config = config

        self.listener_pool = nn.Sequential(
            nn.Linear(listener_input_dim, 1),
            nn.Flatten(start_dim=1)
        )

        self.speaker_cnn = nn.Sequential(
            nn.Conv1d(speaker_input_dim, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(64, reduced_speaker_dim)
        )

        self.mlp_input_size = listener_seq_len + reduced_speaker_dim
        self.mlp = MLPUpToHidden(
            in_size=self.mlp_input_size,
            hidden_size=config.hidden_size,
            dropout=config.dropout,
            config=config
        )

        self.cross_attention = CrossAttention(dim=config.hidden_size)
        self.final_fc = nn.Linear(config.hidden_size, config.num_labels)



    def forward(self, listener_feats, speaker_feats):

        listener_x = self.listener_pool(listener_feats) 
        # print(f"Listener x shape: {listener_x.shape}") # [B, 1024]
        speaker_y = self.speaker_cnn(speaker_feats)     
        # print("speaker_y shape: ", speaker_y.shape) #[B, 32]
        fused = torch.cat([listener_x, speaker_y], dim=1)  # [B, 1056]
        # print(f"Fused shape: {fused.shape}")
        mlp_features = self.mlp(fused)                     # [B, hidden_size] , hidden_size=128
        # print(f"MLP features shape: {mlp_features.shape}")
        attn_out = self.cross_attention(mlp_features, mlp_features)  # [B, hidden_size]
        # print(f"Attention output shape: {attn_out.shape}") # [B, 128]
        return self.final_fc(attn_out)  # [B, num_labels]



## Experiment 1: No cross validation 

In [None]:

sample = dataset[0]
speaker_feat, listener_feat ,_= sample["features"]

model = ListenerSpeakerFusion(config=config , listener_input_dim=listener_feat.shape[1], listener_seq_len=listener_feat.shape[0], speaker_input_dim=speaker_feat.shape[0],reduced_speaker_dim=32).to(device)
model.train()


ListenerSpeakerFusion(
  (listener_pool): Sequential(
    (0): Linear(in_features=64, out_features=1, bias=True)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (speaker_cnn): Sequential(
    (0): Conv1d(768, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): ReLU()
    (2): AdaptiveAvgPool1d(output_size=1)
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=64, out_features=32, bias=True)
  )
  (mlp): MLPUpToHidden(
    (dropout): Dropout(p=0.1, inplace=False)
    (fc1): Linear(in_features=1056, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (activation_fn): Tanh()
  )
  (cross_attention): CrossAttention(
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
  )
  (final_fc): Linear(in_features=128, out_features=1, bias=True)
)

In [None]:
epochs = 20

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        speaker_feat, listener_feat,_ = batch["features"]
        engagement = batch["score"]

        speaker_feat = speaker_feat.to(device)
        listener_feat = listener_feat.to(device)
        engagement = engagement.to(device)
        engagement = engagement.view(-1, 1)
 
 
        optimizer.zero_grad()
        output = model(listener_feat, speaker_feat)
        loss = criterion(output, engagement)
        # print(output)
        # print(engagement)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss / len(train_loader):.4f}")
    model.eval()
    val_loss = 0
    val_correct = 0
    val_preds = []
    val_targets = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False):
                speaker, listener,_ = batch["features"]
                target = batch["score"].to(device)
                target = target.view(-1, 1)
                speaker = speaker.to(device)
                listener = listener.to(device)

                output = model(listener, speaker)
                loss = criterion(output, target)

                val_loss += loss.item() * target.size(0)
                val_preds.append(output.cpu())
                val_targets.append(target.cpu())
                # print(compute_ccc_batched(output.cpu(), target.cpu()))

    val_loss /= len(val_loader.dataset)
    val_ccc = compute_ccc_batched(
            np.concatenate(val_preds),
            np.concatenate(val_targets)
            
        )
    val_pcc = compute_pearson_correlation_batched(
             np.concatenate(val_preds),
            np.concatenate(val_targets)
           
        )
    val_r2 = compute_r2_score_batched(
         np.concatenate(val_targets),
             np.concatenate(val_preds)
            
           
        )
    val_preds = torch.cat(val_preds).numpy()
    val_targets = torch.cat(val_targets).numpy()
      # if epoch == 5:
        #     print("Val targets and preds:")
        #     print(val_targets , val_preds)
        
    print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f} | Val R2: {val_r2:.4f} |")


    

## Experiment 2: Cross Validation 

In [8]:
from sklearn.model_selection import KFold
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset

epochs = 30
k_folds = 5
batch_size = 32
learning_rate = 1e-3

kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)


fold_results = {}

sample = dataset[0]

for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
    print(f'\n--- Fold {fold+1} / {k_folds} ---')

    # Create data loaders for this fold
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # model = model(config).to(device)
    speaker, listener ,_= sample["features"]
    # print("Listener feature shape: ", listener_feat.shape)
    # print("Speaker feature shape: ", speaker_feat.shape)
    model = ListenerSpeakerFusion(config=config , listener_input_dim=listener.shape[1], listener_seq_len=listener.shape[0], speaker_input_dim=speaker.shape[0], reduced_speaker_dim=32).to(device)


    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            speaker_feat, listener_feat, _ = batch["features"]
            engagement = batch["score"]

            speaker_feat = speaker_feat.to(device)
            listener_feat = listener_feat.to(device)
            engagement = engagement.to(device)
            engagement = engagement.view(-1, 1)

            optimizer.zero_grad()
            # print("before model", listener_feat.shape , speaker_feat.shape)
            output = model(listener_feat, speaker_feat)
            loss = criterion(output, engagement)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
                speaker, listener, _ = batch["features"]
                target = batch["score"].to(device)
                target = target.view(-1, 1)
                speaker = speaker.to(device)
                listener = listener.to(device)

                output = model(listener, speaker)
                loss = criterion(output, target)

                val_loss += loss.item() * target.size(0)
                val_preds.append(output.cpu())
                val_targets.append(target.cpu())

        val_loss /= len(val_loader.dataset)

        val_preds_np = torch.cat(val_preds).numpy()
        val_targets_np = torch.cat(val_targets).numpy()

        val_ccc = compute_ccc_batched(val_preds_np, val_targets_np)
        val_pcc = compute_pearson_correlation_batched(val_preds_np, val_targets_np)
        val_r2 = compute_r2_score_batched(val_targets_np, val_preds_np)

        print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f} | Val R2: {val_r2:.4f} |")

    # Save the result for this fold
    fold_results[fold] = {
        'val_loss': val_loss,
        'val_ccc': val_ccc,
        'val_pcc': val_pcc,
        'val_r2': val_r2
    }

# After K-Folds
print("\n--- K-Fold Cross Validation Results ---")
for fold in fold_results:
    print(f"Fold {fold+1} | Val Loss: {fold_results[fold]['val_loss']:.4f} | CCC: {fold_results[fold]['val_ccc']:.4f} | PCC: {fold_results[fold]['val_pcc']:.4f} | R2: {fold_results[fold]['val_r2']:.4f}")



--- Fold 1 / 5 ---
Epoch 1/30 - Train Loss: 0.0186


                                                              

| Val Loss: 0.0192 | Val CCC: 0.0389 | Val PCC: 0.2100 | Val R2: 0.0090 |
Epoch 2/30 - Train Loss: 0.0166


                                                              

| Val Loss: 0.0189 | Val CCC: 0.1331 | Val PCC: 0.2958 | Val R2: 0.0263 |
Epoch 3/30 - Train Loss: 0.0153


                                                              

| Val Loss: 0.0206 | Val CCC: 0.2639 | Val PCC: 0.3082 | Val R2: -0.0654 |
Epoch 4/30 - Train Loss: 0.0144


                                                              

| Val Loss: 0.0171 | Val CCC: 0.2489 | Val PCC: 0.3504 | Val R2: 0.1151 |
Epoch 5/30 - Train Loss: 0.0133


                                                              

| Val Loss: 0.0169 | Val CCC: 0.1920 | Val PCC: 0.3781 | Val R2: 0.1267 |
Epoch 6/30 - Train Loss: 0.0127


                                                              

| Val Loss: 0.0187 | Val CCC: 0.3519 | Val PCC: 0.3749 | Val R2: 0.0340 |
Epoch 7/30 - Train Loss: 0.0122


                                                              

| Val Loss: 0.0174 | Val CCC: 0.2605 | Val PCC: 0.3448 | Val R2: 0.1023 |
Epoch 8/30 - Train Loss: 0.0118


                                                              

| Val Loss: 0.0197 | Val CCC: 0.3115 | Val PCC: 0.3425 | Val R2: -0.0183 |
Epoch 9/30 - Train Loss: 0.0117


                                                              

| Val Loss: 0.0184 | Val CCC: 0.2990 | Val PCC: 0.3429 | Val R2: 0.0508 |
Epoch 10/30 - Train Loss: 0.0109


                                                               

| Val Loss: 0.0175 | Val CCC: 0.2856 | Val PCC: 0.3517 | Val R2: 0.0974 |
Epoch 11/30 - Train Loss: 0.0107


                                                               

| Val Loss: 0.0188 | Val CCC: 0.2962 | Val PCC: 0.3496 | Val R2: 0.0300 |
Epoch 12/30 - Train Loss: 0.0106


                                                               

| Val Loss: 0.0186 | Val CCC: 0.2913 | Val PCC: 0.3336 | Val R2: 0.0371 |
Epoch 13/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0192 | Val CCC: 0.3126 | Val PCC: 0.3402 | Val R2: 0.0099 |
Epoch 14/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0191 | Val CCC: 0.3144 | Val PCC: 0.3431 | Val R2: 0.0143 |
Epoch 15/30 - Train Loss: 0.0098


                                                               

| Val Loss: 0.0202 | Val CCC: 0.2862 | Val PCC: 0.3206 | Val R2: -0.0414 |
Epoch 16/30 - Train Loss: 0.0098


                                                               

| Val Loss: 0.0196 | Val CCC: 0.2848 | Val PCC: 0.3198 | Val R2: -0.0124 |
Epoch 17/30 - Train Loss: 0.0098


                                                               

| Val Loss: 0.0183 | Val CCC: 0.2834 | Val PCC: 0.3358 | Val R2: 0.0559 |
Epoch 18/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0202 | Val CCC: 0.2965 | Val PCC: 0.3255 | Val R2: -0.0450 |
Epoch 19/30 - Train Loss: 0.0098


                                                               

| Val Loss: 0.0181 | Val CCC: 0.2698 | Val PCC: 0.3262 | Val R2: 0.0652 |
Epoch 20/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0218 | Val CCC: 0.3030 | Val PCC: 0.3120 | Val R2: -0.1271 |
Epoch 21/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0209 | Val CCC: 0.3036 | Val PCC: 0.3240 | Val R2: -0.0784 |
Epoch 22/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0192 | Val CCC: 0.2785 | Val PCC: 0.3136 | Val R2: 0.0102 |
Epoch 23/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0191 | Val CCC: 0.2763 | Val PCC: 0.3130 | Val R2: 0.0152 |
Epoch 24/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0186 | Val CCC: 0.2738 | Val PCC: 0.3189 | Val R2: 0.0388 |
Epoch 25/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0205 | Val CCC: 0.2977 | Val PCC: 0.3154 | Val R2: -0.0592 |
Epoch 26/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0196 | Val CCC: 0.2923 | Val PCC: 0.3230 | Val R2: -0.0099 |
Epoch 27/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0198 | Val CCC: 0.2949 | Val PCC: 0.3197 | Val R2: -0.0222 |
Epoch 28/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0210 | Val CCC: 0.3020 | Val PCC: 0.3154 | Val R2: -0.0838 |
Epoch 29/30 - Train Loss: 0.0090


                                                               

| Val Loss: 0.0204 | Val CCC: 0.2839 | Val PCC: 0.3068 | Val R2: -0.0557 |
Epoch 30/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0200 | Val CCC: 0.2879 | Val PCC: 0.3191 | Val R2: -0.0324 |

--- Fold 2 / 5 ---
Epoch 1/30 - Train Loss: 0.0187


                                                              

| Val Loss: 0.0161 | Val CCC: 0.0815 | Val PCC: 0.2598 | Val R2: 0.0473 |
Epoch 2/30 - Train Loss: 0.0173


                                                              

| Val Loss: 0.0168 | Val CCC: 0.1326 | Val PCC: 0.3485 | Val R2: 0.0086 |
Epoch 3/30 - Train Loss: 0.0162


                                                              

| Val Loss: 0.0151 | Val CCC: 0.2888 | Val PCC: 0.3671 | Val R2: 0.1089 |
Epoch 4/30 - Train Loss: 0.0144


                                                              

| Val Loss: 0.0154 | Val CCC: 0.2481 | Val PCC: 0.3644 | Val R2: 0.0887 |
Epoch 5/30 - Train Loss: 0.0139


                                                              

| Val Loss: 0.0148 | Val CCC: 0.2586 | Val PCC: 0.3634 | Val R2: 0.1284 |
Epoch 6/30 - Train Loss: 0.0134


                                                              

| Val Loss: 0.0148 | Val CCC: 0.2709 | Val PCC: 0.3661 | Val R2: 0.1238 |
Epoch 7/30 - Train Loss: 0.0123


                                                              

| Val Loss: 0.0155 | Val CCC: 0.2816 | Val PCC: 0.3491 | Val R2: 0.0837 |
Epoch 8/30 - Train Loss: 0.0116


                                                              

| Val Loss: 0.0158 | Val CCC: 0.3309 | Val PCC: 0.3721 | Val R2: 0.0700 |
Epoch 9/30 - Train Loss: 0.0117


                                                              

| Val Loss: 0.0148 | Val CCC: 0.3036 | Val PCC: 0.3803 | Val R2: 0.1244 |
Epoch 10/30 - Train Loss: 0.0111


                                                               

| Val Loss: 0.0179 | Val CCC: 0.3504 | Val PCC: 0.3651 | Val R2: -0.0565 |
Epoch 11/30 - Train Loss: 0.0111


                                                               

| Val Loss: 0.0165 | Val CCC: 0.3322 | Val PCC: 0.3579 | Val R2: 0.0266 |
Epoch 12/30 - Train Loss: 0.0108


                                                               

| Val Loss: 0.0165 | Val CCC: 0.2793 | Val PCC: 0.3520 | Val R2: 0.0285 |
Epoch 13/30 - Train Loss: 0.0105


                                                               

| Val Loss: 0.0192 | Val CCC: 0.3222 | Val PCC: 0.3341 | Val R2: -0.1364 |
Epoch 14/30 - Train Loss: 0.0104


                                                               

| Val Loss: 0.0174 | Val CCC: 0.3259 | Val PCC: 0.3482 | Val R2: -0.0259 |
Epoch 15/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0183 | Val CCC: 0.3061 | Val PCC: 0.3240 | Val R2: -0.0776 |
Epoch 16/30 - Train Loss: 0.0103


                                                               

| Val Loss: 0.0157 | Val CCC: 0.2780 | Val PCC: 0.3377 | Val R2: 0.0755 |
Epoch 17/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0168 | Val CCC: 0.3117 | Val PCC: 0.3385 | Val R2: 0.0070 |
Epoch 18/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0159 | Val CCC: 0.2865 | Val PCC: 0.3492 | Val R2: 0.0599 |
Epoch 19/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0162 | Val CCC: 0.2913 | Val PCC: 0.3330 | Val R2: 0.0451 |
Epoch 20/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0168 | Val CCC: 0.3186 | Val PCC: 0.3432 | Val R2: 0.0055 |
Epoch 21/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0181 | Val CCC: 0.3243 | Val PCC: 0.3384 | Val R2: -0.0678 |
Epoch 22/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0164 | Val CCC: 0.2872 | Val PCC: 0.3273 | Val R2: 0.0340 |
Epoch 23/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0164 | Val CCC: 0.2805 | Val PCC: 0.3204 | Val R2: 0.0300 |
Epoch 24/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0191 | Val CCC: 0.3023 | Val PCC: 0.3117 | Val R2: -0.1264 |
Epoch 25/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0181 | Val CCC: 0.3145 | Val PCC: 0.3316 | Val R2: -0.0712 |
Epoch 26/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0192 | Val CCC: 0.3125 | Val PCC: 0.3289 | Val R2: -0.1334 |
Epoch 27/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0167 | Val CCC: 0.2882 | Val PCC: 0.3211 | Val R2: 0.0119 |
Epoch 28/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0198 | Val CCC: 0.2913 | Val PCC: 0.3167 | Val R2: -0.1698 |
Epoch 29/30 - Train Loss: 0.0094


                                                               

| Val Loss: 0.0172 | Val CCC: 0.2880 | Val PCC: 0.3169 | Val R2: -0.0156 |
Epoch 30/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0183 | Val CCC: 0.3025 | Val PCC: 0.3275 | Val R2: -0.0796 |

--- Fold 3 / 5 ---
Epoch 1/30 - Train Loss: 0.0184


                                                              

| Val Loss: 0.0186 | Val CCC: 0.1164 | Val PCC: 0.2395 | Val R2: 0.0544 |
Epoch 2/30 - Train Loss: 0.0166


                                                              

| Val Loss: 0.0178 | Val CCC: 0.2020 | Val PCC: 0.3106 | Val R2: 0.0930 |
Epoch 3/30 - Train Loss: 0.0160


                                                              

| Val Loss: 0.0179 | Val CCC: 0.1353 | Val PCC: 0.3170 | Val R2: 0.0913 |
Epoch 4/30 - Train Loss: 0.0144


                                                              

| Val Loss: 0.0179 | Val CCC: 0.1572 | Val PCC: 0.3003 | Val R2: 0.0887 |
Epoch 5/30 - Train Loss: 0.0138


                                                              

| Val Loss: 0.0190 | Val CCC: 0.2892 | Val PCC: 0.3301 | Val R2: 0.0354 |
Epoch 6/30 - Train Loss: 0.0128


                                                              

| Val Loss: 0.0197 | Val CCC: 0.3030 | Val PCC: 0.3340 | Val R2: -0.0003 |
Epoch 7/30 - Train Loss: 0.0119


                                                              

| Val Loss: 0.0176 | Val CCC: 0.1770 | Val PCC: 0.3300 | Val R2: 0.1031 |
Epoch 8/30 - Train Loss: 0.0120


                                                              

| Val Loss: 0.0189 | Val CCC: 0.2599 | Val PCC: 0.3094 | Val R2: 0.0389 |
Epoch 9/30 - Train Loss: 0.0115


                                                              

| Val Loss: 0.0201 | Val CCC: 0.2924 | Val PCC: 0.3179 | Val R2: -0.0215 |
Epoch 10/30 - Train Loss: 0.0110


                                                               

| Val Loss: 0.0204 | Val CCC: 0.2942 | Val PCC: 0.3194 | Val R2: -0.0368 |
Epoch 11/30 - Train Loss: 0.0105


                                                               

| Val Loss: 0.0182 | Val CCC: 0.2226 | Val PCC: 0.3021 | Val R2: 0.0724 |
Epoch 12/30 - Train Loss: 0.0106


                                                               

| Val Loss: 0.0206 | Val CCC: 0.2828 | Val PCC: 0.3160 | Val R2: -0.0486 |
Epoch 13/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0198 | Val CCC: 0.2334 | Val PCC: 0.2799 | Val R2: -0.0096 |
Epoch 14/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0189 | Val CCC: 0.2355 | Val PCC: 0.2929 | Val R2: 0.0375 |
Epoch 15/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0216 | Val CCC: 0.2835 | Val PCC: 0.2985 | Val R2: -0.0966 |
Epoch 16/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0226 | Val CCC: 0.2697 | Val PCC: 0.2812 | Val R2: -0.1513 |
Epoch 17/30 - Train Loss: 0.0100


                                                               

| Val Loss: 0.0186 | Val CCC: 0.2099 | Val PCC: 0.2841 | Val R2: 0.0550 |
Epoch 18/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0241 | Val CCC: 0.2891 | Val PCC: 0.2953 | Val R2: -0.2254 |
Epoch 19/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0210 | Val CCC: 0.2718 | Val PCC: 0.2915 | Val R2: -0.0706 |
Epoch 20/30 - Train Loss: 0.0094


                                                               

| Val Loss: 0.0225 | Val CCC: 0.2831 | Val PCC: 0.2960 | Val R2: -0.1461 |
Epoch 21/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0211 | Val CCC: 0.2708 | Val PCC: 0.2906 | Val R2: -0.0736 |
Epoch 22/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0208 | Val CCC: 0.2571 | Val PCC: 0.2818 | Val R2: -0.0558 |
Epoch 23/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0206 | Val CCC: 0.2217 | Val PCC: 0.2583 | Val R2: -0.0456 |
Epoch 24/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0205 | Val CCC: 0.2724 | Val PCC: 0.2970 | Val R2: -0.0409 |
Epoch 25/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0222 | Val CCC: 0.2575 | Val PCC: 0.2728 | Val R2: -0.1277 |
Epoch 26/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0239 | Val CCC: 0.2938 | Val PCC: 0.3009 | Val R2: -0.2142 |
Epoch 27/30 - Train Loss: 0.0091


                                                               

| Val Loss: 0.0229 | Val CCC: 0.2734 | Val PCC: 0.2822 | Val R2: -0.1665 |
Epoch 28/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0197 | Val CCC: 0.2525 | Val PCC: 0.2923 | Val R2: -0.0009 |
Epoch 29/30 - Train Loss: 0.0098


                                                               

| Val Loss: 0.0202 | Val CCC: 0.2604 | Val PCC: 0.2900 | Val R2: -0.0286 |
Epoch 30/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0227 | Val CCC: 0.2697 | Val PCC: 0.2823 | Val R2: -0.1548 |

--- Fold 4 / 5 ---
Epoch 1/30 - Train Loss: 0.0188


                                                              

| Val Loss: 0.0182 | Val CCC: 0.1316 | Val PCC: 0.2254 | Val R2: 0.0342 |
Epoch 2/30 - Train Loss: 0.0179


                                                              

| Val Loss: 0.0185 | Val CCC: 0.2446 | Val PCC: 0.3350 | Val R2: 0.0172 |
Epoch 3/30 - Train Loss: 0.0163


                                                              

| Val Loss: 0.0181 | Val CCC: 0.2242 | Val PCC: 0.3657 | Val R2: 0.0396 |
Epoch 4/30 - Train Loss: 0.0153


                                                              

| Val Loss: 0.0164 | Val CCC: 0.2410 | Val PCC: 0.3868 | Val R2: 0.1311 |
Epoch 5/30 - Train Loss: 0.0143


                                                              

| Val Loss: 0.0168 | Val CCC: 0.2989 | Val PCC: 0.4002 | Val R2: 0.1086 |
Epoch 6/30 - Train Loss: 0.0132


                                                              

| Val Loss: 0.0180 | Val CCC: 0.3114 | Val PCC: 0.3885 | Val R2: 0.0465 |
Epoch 7/30 - Train Loss: 0.0124


                                                              

| Val Loss: 0.0164 | Val CCC: 0.3199 | Val PCC: 0.3915 | Val R2: 0.1278 |
Epoch 8/30 - Train Loss: 0.0119


                                                              

| Val Loss: 0.0185 | Val CCC: 0.3669 | Val PCC: 0.3841 | Val R2: 0.0162 |
Epoch 9/30 - Train Loss: 0.0117


                                                              

| Val Loss: 0.0175 | Val CCC: 0.3484 | Val PCC: 0.3843 | Val R2: 0.0731 |
Epoch 10/30 - Train Loss: 0.0113


                                                               

| Val Loss: 0.0184 | Val CCC: 0.3708 | Val PCC: 0.3916 | Val R2: 0.0248 |
Epoch 11/30 - Train Loss: 0.0110


                                                               

| Val Loss: 0.0173 | Val CCC: 0.3390 | Val PCC: 0.3770 | Val R2: 0.0804 |
Epoch 12/30 - Train Loss: 0.0112


                                                               

| Val Loss: 0.0167 | Val CCC: 0.3184 | Val PCC: 0.3817 | Val R2: 0.1145 |
Epoch 13/30 - Train Loss: 0.0105


                                                               

| Val Loss: 0.0180 | Val CCC: 0.3497 | Val PCC: 0.3749 | Val R2: 0.0414 |
Epoch 14/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0179 | Val CCC: 0.3032 | Val PCC: 0.3444 | Val R2: 0.0469 |
Epoch 15/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0178 | Val CCC: 0.3373 | Val PCC: 0.3731 | Val R2: 0.0535 |
Epoch 16/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0170 | Val CCC: 0.3011 | Val PCC: 0.3623 | Val R2: 0.0976 |
Epoch 17/30 - Train Loss: 0.0103


                                                               

| Val Loss: 0.0187 | Val CCC: 0.3466 | Val PCC: 0.3658 | Val R2: 0.0088 |
Epoch 18/30 - Train Loss: 0.0100


                                                               

| Val Loss: 0.0176 | Val CCC: 0.3208 | Val PCC: 0.3607 | Val R2: 0.0645 |
Epoch 19/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0186 | Val CCC: 0.3357 | Val PCC: 0.3592 | Val R2: 0.0133 |
Epoch 20/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0184 | Val CCC: 0.3347 | Val PCC: 0.3662 | Val R2: 0.0209 |
Epoch 21/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0183 | Val CCC: 0.3363 | Val PCC: 0.3650 | Val R2: 0.0265 |
Epoch 22/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0218 | Val CCC: 0.3551 | Val PCC: 0.3620 | Val R2: -0.1557 |
Epoch 23/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0190 | Val CCC: 0.3522 | Val PCC: 0.3691 | Val R2: -0.0094 |
Epoch 24/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0175 | Val CCC: 0.2961 | Val PCC: 0.3673 | Val R2: 0.0721 |
Epoch 25/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0195 | Val CCC: 0.3313 | Val PCC: 0.3646 | Val R2: -0.0341 |
Epoch 26/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0206 | Val CCC: 0.3495 | Val PCC: 0.3565 | Val R2: -0.0944 |
Epoch 27/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0186 | Val CCC: 0.3311 | Val PCC: 0.3549 | Val R2: 0.0125 |
Epoch 28/30 - Train Loss: 0.0094


                                                               

| Val Loss: 0.0211 | Val CCC: 0.3498 | Val PCC: 0.3554 | Val R2: -0.1191 |
Epoch 29/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0177 | Val CCC: 0.2925 | Val PCC: 0.3473 | Val R2: 0.0611 |
Epoch 30/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0180 | Val CCC: 0.3170 | Val PCC: 0.3541 | Val R2: 0.0432 |

--- Fold 5 / 5 ---
Epoch 1/30 - Train Loss: 0.0190


                                                              

| Val Loss: 0.0171 | Val CCC: 0.0652 | Val PCC: 0.2620 | Val R2: 0.0319 |
Epoch 2/30 - Train Loss: 0.0172


                                                              

| Val Loss: 0.0167 | Val CCC: 0.2035 | Val PCC: 0.2942 | Val R2: 0.0513 |
Epoch 3/30 - Train Loss: 0.0161


                                                              

| Val Loss: 0.0165 | Val CCC: 0.1000 | Val PCC: 0.2718 | Val R2: 0.0659 |
Epoch 4/30 - Train Loss: 0.0152


                                                              

| Val Loss: 0.0181 | Val CCC: 0.2116 | Val PCC: 0.3013 | Val R2: -0.0258 |
Epoch 5/30 - Train Loss: 0.0137


                                                              

| Val Loss: 0.0161 | Val CCC: 0.2237 | Val PCC: 0.3495 | Val R2: 0.0853 |
Epoch 6/30 - Train Loss: 0.0125


                                                              

| Val Loss: 0.0184 | Val CCC: 0.2992 | Val PCC: 0.3198 | Val R2: -0.0415 |
Epoch 7/30 - Train Loss: 0.0119


                                                              

| Val Loss: 0.0179 | Val CCC: 0.3090 | Val PCC: 0.3367 | Val R2: -0.0153 |
Epoch 8/30 - Train Loss: 0.0115


                                                              

| Val Loss: 0.0168 | Val CCC: 0.2669 | Val PCC: 0.3176 | Val R2: 0.0488 |
Epoch 9/30 - Train Loss: 0.0113


                                                              

| Val Loss: 0.0181 | Val CCC: 0.2614 | Val PCC: 0.2929 | Val R2: -0.0275 |
Epoch 10/30 - Train Loss: 0.0115


                                                               

| Val Loss: 0.0171 | Val CCC: 0.2835 | Val PCC: 0.3316 | Val R2: 0.0287 |
Epoch 11/30 - Train Loss: 0.0114


                                                               

| Val Loss: 0.0184 | Val CCC: 0.2935 | Val PCC: 0.3152 | Val R2: -0.0434 |
Epoch 12/30 - Train Loss: 0.0103


                                                               

| Val Loss: 0.0171 | Val CCC: 0.2566 | Val PCC: 0.3039 | Val R2: 0.0306 |
Epoch 13/30 - Train Loss: 0.0102


                                                               

| Val Loss: 0.0178 | Val CCC: 0.2260 | Val PCC: 0.2928 | Val R2: -0.0093 |
Epoch 14/30 - Train Loss: 0.0105


                                                               

| Val Loss: 0.0189 | Val CCC: 0.2293 | Val PCC: 0.2693 | Val R2: -0.0726 |
Epoch 15/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0173 | Val CCC: 0.2363 | Val PCC: 0.2901 | Val R2: 0.0191 |
Epoch 16/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0186 | Val CCC: 0.2428 | Val PCC: 0.2810 | Val R2: -0.0525 |
Epoch 17/30 - Train Loss: 0.0101


                                                               

| Val Loss: 0.0184 | Val CCC: 0.2662 | Val PCC: 0.2911 | Val R2: -0.0450 |
Epoch 18/30 - Train Loss: 0.0103


                                                               

| Val Loss: 0.0175 | Val CCC: 0.2739 | Val PCC: 0.3103 | Val R2: 0.0112 |
Epoch 19/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0193 | Val CCC: 0.2614 | Val PCC: 0.2819 | Val R2: -0.0953 |
Epoch 20/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0182 | Val CCC: 0.2689 | Val PCC: 0.2966 | Val R2: -0.0284 |
Epoch 21/30 - Train Loss: 0.0099


                                                               

| Val Loss: 0.0189 | Val CCC: 0.2563 | Val PCC: 0.2882 | Val R2: -0.0715 |
Epoch 22/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0200 | Val CCC: 0.2776 | Val PCC: 0.2916 | Val R2: -0.1351 |
Epoch 23/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0202 | Val CCC: 0.2785 | Val PCC: 0.2966 | Val R2: -0.1436 |
Epoch 24/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0203 | Val CCC: 0.2758 | Val PCC: 0.2935 | Val R2: -0.1502 |
Epoch 25/30 - Train Loss: 0.0097


                                                               

| Val Loss: 0.0181 | Val CCC: 0.2532 | Val PCC: 0.2850 | Val R2: -0.0239 |
Epoch 26/30 - Train Loss: 0.0093


                                                               

| Val Loss: 0.0203 | Val CCC: 0.2670 | Val PCC: 0.2803 | Val R2: -0.1525 |
Epoch 27/30 - Train Loss: 0.0092


                                                               

| Val Loss: 0.0180 | Val CCC: 0.2669 | Val PCC: 0.2973 | Val R2: -0.0189 |
Epoch 28/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0224 | Val CCC: 0.2570 | Val PCC: 0.2779 | Val R2: -0.2713 |
Epoch 29/30 - Train Loss: 0.0096


                                                               

| Val Loss: 0.0210 | Val CCC: 0.2791 | Val PCC: 0.2900 | Val R2: -0.1880 |
Epoch 30/30 - Train Loss: 0.0095


                                                               

| Val Loss: 0.0188 | Val CCC: 0.2521 | Val PCC: 0.2756 | Val R2: -0.0676 |

--- K-Fold Cross Validation Results ---
Fold 1 | Val Loss: 0.0200 | CCC: 0.2879 | PCC: 0.3191 | R2: -0.0324
Fold 2 | Val Loss: 0.0183 | CCC: 0.3025 | PCC: 0.3275 | R2: -0.0796
Fold 3 | Val Loss: 0.0227 | CCC: 0.2697 | PCC: 0.2823 | R2: -0.1548
Fold 4 | Val Loss: 0.0180 | CCC: 0.3170 | PCC: 0.3541 | R2: 0.0432
Fold 5 | Val Loss: 0.0188 | CCC: 0.2521 | PCC: 0.2756 | R2: -0.0676




# MLP on features and Merge with Xattention features on features 

In [None]:
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=1, batch_first=True)

    def forward(self, query, context):
        query = query.unsqueeze(1)    # [B, 1, D]
        context = context.unsqueeze(1)  # [B, 1, D]
        attn_out, _ = self.attn(query, context, context)  # [B, 1, D]
        return attn_out.squeeze(1)     # [B, D]


class ListenerSpeakerHybridFusion(nn.Module):
    def __init__(self, config, listener_input_dim=329, listener_seq_len=50,
                 speaker_input_dim=768, reduced_speaker_dim=32):
        super().__init__()
        self.config = config

        # Listener: [B, 50, 329] to [B, 50]
        self.listener_pool = nn.Sequential(
            nn.Linear(listener_input_dim, 1),
            nn.Flatten(start_dim=1)
        )

        # Speaker: [B, 768, 424] to [B, 32]
        self.speaker_cnn = nn.Sequential(
        nn.Conv1d(speaker_input_dim, reduced_speaker_dim, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.AdaptiveAvgPool1d(1),
        nn.Flatten(),
        # nn.Linear(64, reduced_speaker_dim)     
        )        

        self.mlp = nn.Sequential(
            nn.Linear(listener_seq_len + reduced_speaker_dim, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU()
        )
        self.mlp_norm = nn.LayerNorm(config.hidden_size)

        self.listener_proj = nn.Linear(listener_seq_len, config.hidden_size)
        self.speaker_proj = nn.Linear(reduced_speaker_dim, config.hidden_size)

        self.cross_attention = CrossAttention(dim=config.hidden_size)
        self.attn_norm = nn.LayerNorm(config.hidden_size)

        self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
        self.activation_fn = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "Tanh": nn.Tanh()
        }[config.activation_fn]

    def forward(self, listener_feats, speaker_feats):

        listener_x = self.listener_pool(listener_feats)     # [B, 50]
        speaker_y = self.speaker_cnn(speaker_feats)         # [B, 32]


        fused = torch.cat([listener_x, speaker_y], dim=1)   # [B, 82]
        mlp_hidden = self.mlp(fused)                        # [B, hidden_size]
        mlp_hidden = self.mlp_norm(mlp_hidden)

        listener_proj_out = self.listener_proj(listener_x)  # [B, hidden_size]
        speaker_proj_out = self.speaker_proj(speaker_y)     # [B, hidden_size]

        listener_attn = self.cross_attention(listener_proj_out, speaker_proj_out)  # [B, hidden_size]
        listener_attn = self.attn_norm(listener_attn)

        # final_rep = torch.cat([mlp_hidden, listener_attn], dim=1)  # [B, hidden_size * 2]
        final_rep  = mlp_hidden+ listener_attn
        final_rep = self.classifier(final_rep)  # [B, num_labels]
        # print("final_rep : ", final_rep)
        # print("Logits before tanh:", final_rep.mean().item(), final_rep.min().item(), final_rep.max().item())
        final_rep = self.activation_fn(final_rep)
    
        return final_rep


In [None]:


# class ListenerSpeakerHybridFusion(nn.Module):
#     def __init__(self, config, listener_input_dim=329, speaker_input_dim=768):
#         super().__init__()
#         self.config = config

#         self.listener_pool = nn.AdaptiveAvgPool1d(1)
#         self.speaker_pool = nn.AdaptiveAvgPool1d(1)

#         self.mlp_fc1 = nn.Linear(listener_input_dim + speaker_input_dim, config.hidden_size)
#         self.mlp_fc2 = nn.Linear(config.hidden_size, config.hidden_size)

#         self.listener_proj = nn.Linear(listener_input_dim, config.hidden_size)
#         self.speaker_proj = nn.Linear(speaker_input_dim, config.hidden_size)

#         self.cross_attention = CrossAttention(dim=config.hidden_size)

#         self.classifier = nn.Linear(config.hidden_size, 1)
#         self.activation_fn = nn.Tanh()  # <-- Final Tanh activation

#     def forward(self, listener_feats, speaker_feats):
#         listener_vec = self.listener_pool(listener_feats.permute(0, 2, 1)).squeeze(-1)
#         speaker_vec = self.speaker_pool(speaker_feats).squeeze(-1)

#         # MLP path
#         fused = torch.cat([listener_vec, speaker_vec], dim=1)
#         mlp_hidden = torch.relu(self.mlp_fc1(fused))
#         mlp_hidden = torch.relu(self.mlp_fc2(mlp_hidden))

#         # Cross-attention path
#         listener_proj = self.listener_proj(listener_vec)
#         speaker_proj = self.speaker_proj(speaker_vec)
#         attn_out = self.cross_attention(listener_proj, speaker_proj)

#         # Combine and regress
#         combined = mlp_hidden + attn_out
#         output = self.classifier(combined)
#         return self.activation_fn(output) 


In [16]:
config = { 'activation_fn': 'Tanh',
           'extra_dropout':0 ,
             'hidden_size':64,
             'dropout': 0.4,
             'num_labels': 1}

config = SimpleNamespace(**config)
sample = dataset[0]
speaker_feat, listener_feat,_ = sample["features"]
model = ListenerSpeakerHybridFusion(config=config , listener_input_dim=listener_feat.shape[1], speaker_input_dim=speaker_feat.shape[0]).to(device)


In [47]:
model.train()

ListenerSpeakerHybridFusion(
  (listener_pool): AdaptiveAvgPool1d(output_size=1)
  (speaker_pool): AdaptiveAvgPool1d(output_size=1)
  (mlp_fc1): Linear(in_features=393, out_features=64, bias=True)
  (mlp_fc2): Linear(in_features=64, out_features=64, bias=True)
  (listener_proj): Linear(in_features=64, out_features=64, bias=True)
  (speaker_proj): Linear(in_features=329, out_features=64, bias=True)
  (cross_attention): CrossAttention(
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
    )
  )
  (classifier): Linear(in_features=64, out_features=1, bias=True)
  (activation_fn): Tanh()
)

## Experiment 3: No Validation

In [17]:
import numpy as np
from tqdm import tqdm

epochs = 20
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        speaker_feat, listener_feat,_ = batch["features"]
        engagement = batch["score"].view(-1, 1).to(device)

        speaker_feat = speaker_feat.to(device)
        listener_feat = listener_feat.to(device)

        optimizer.zero_grad()
        output = model(listener_feat, speaker_feat)
        loss = criterion(output, engagement)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f}")

    model.eval()
    val_loss = 0.0
    val_preds, val_targets = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
            speaker_feat, listener_feat,_= batch["features"]
            target = batch["score"].view(-1, 1).to(device)

            speaker_feat = speaker_feat.to(device)
            listener_feat = listener_feat.to(device)

            output = model(listener_feat, speaker_feat)
            loss = criterion(output, target)

            val_loss += loss.item() * target.size(0)
            val_preds.append(output.cpu().numpy())
            val_targets.append(target.cpu().numpy())

            # print("Output: ", output.shape)
            # print("Target: ", target.shape)
            # print(compute_ccc_batched(output.detach().cpu().numpy(), target.detach().cpu().numpy()))

    val_loss /= len(val_loader.dataset)
    val_preds = np.concatenate(val_preds)
    val_targets = np.concatenate(val_targets)

    val_ccc = compute_ccc_batched(val_preds, val_targets)
    val_pcc = compute_pearson_correlation_batched(val_preds, val_targets)

    print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f}")


Epoch 1/20 | Train Loss: 0.0182


                                                              

| Val Loss: 0.0176 | Val CCC: 0.1390 | Val PCC: 0.2459
Epoch 2/20 | Train Loss: 0.0157


                                                              

| Val Loss: 0.0184 | Val CCC: 0.1563 | Val PCC: 0.2579
Epoch 3/20 | Train Loss: 0.0151


                                                              

| Val Loss: 0.0181 | Val CCC: 0.1899 | Val PCC: 0.2753
Epoch 4/20 | Train Loss: 0.0141


                                                              

| Val Loss: 0.0193 | Val CCC: 0.1934 | Val PCC: 0.2789
Epoch 5/20 | Train Loss: 0.0135


                                                              

| Val Loss: 0.0184 | Val CCC: 0.2182 | Val PCC: 0.2828
Epoch 6/20 | Train Loss: 0.0132


                                                              

| Val Loss: 0.0182 | Val CCC: 0.2334 | Val PCC: 0.3095
Epoch 7/20 | Train Loss: 0.0122


                                                              

| Val Loss: 0.0176 | Val CCC: 0.2443 | Val PCC: 0.3036
Epoch 8/20 | Train Loss: 0.0115


                                                              

| Val Loss: 0.0188 | Val CCC: 0.2702 | Val PCC: 0.3022
Epoch 9/20 | Train Loss: 0.0117


                                                              

| Val Loss: 0.0187 | Val CCC: 0.2351 | Val PCC: 0.2842
Epoch 10/20 | Train Loss: 0.0108


                                                               

| Val Loss: 0.0192 | Val CCC: 0.2544 | Val PCC: 0.2850
Epoch 11/20 | Train Loss: 0.0107


                                                               

| Val Loss: 0.0208 | Val CCC: 0.1931 | Val PCC: 0.2499
Epoch 12/20 | Train Loss: 0.0105


                                                               

| Val Loss: 0.0186 | Val CCC: 0.2438 | Val PCC: 0.2834
Epoch 13/20 | Train Loss: 0.0102


                                                               

| Val Loss: 0.0196 | Val CCC: 0.2799 | Val PCC: 0.3005
Epoch 14/20 | Train Loss: 0.0098


                                                               

| Val Loss: 0.0205 | Val CCC: 0.2646 | Val PCC: 0.2840
Epoch 15/20 | Train Loss: 0.0100


                                                               

| Val Loss: 0.0208 | Val CCC: 0.2808 | Val PCC: 0.2930
Epoch 16/20 | Train Loss: 0.0099


                                                               

| Val Loss: 0.0200 | Val CCC: 0.2481 | Val PCC: 0.2743
Epoch 17/20 | Train Loss: 0.0096


                                                               

| Val Loss: 0.0204 | Val CCC: 0.2452 | Val PCC: 0.2757
Epoch 18/20 | Train Loss: 0.0097


                                                               

| Val Loss: 0.0214 | Val CCC: 0.2121 | Val PCC: 0.2379
Epoch 19/20 | Train Loss: 0.0095


                                                               

| Val Loss: 0.0205 | Val CCC: 0.2708 | Val PCC: 0.2862
Epoch 20/20 | Train Loss: 0.0093


                                                               

| Val Loss: 0.0200 | Val CCC: 0.2343 | Val PCC: 0.2584




In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total trainable parameters: {count_parameters(model):,}")


Total trainable parameters: 127,553


# Experiment 4: To do - add corss-validation

In [9]:
## todo

# Feature Pooling-- no AveragePooling layer

In [98]:
import torch
print(torch.__version__)
import numpy
print(numpy.__version__)
# !pip install numpy==1.24.4



2.1.0
1.24.4


In [22]:
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from tqdm import tqdm
from SEMPIDataLoader import ListenerSpeakerFeatureDataset
import torch
import torch.nn as nn
import sys, os
import numpy as np
from types import SimpleNamespace

In [None]:
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=1, batch_first=True)

    def forward(self, query, context):
        query = query.unsqueeze(1)    # [B, 1, D]
        context = context.unsqueeze(1)  # [B, 1, D]
        attn_out, _ = self.attn(query, context, context)  # [B, 1, D]
        return attn_out.squeeze(1)     # [B, D]

class ListenerSpeakerHybridFusion_2(nn.Module):
    def __init__(self, config, listener_input_dim=64, listener_seq_len=1024,
                 speaker_input_dim=64, speaker_seq_len=768, reduced_speaker_dim=32):
        super().__init__()
        self.config = config

        # Pool across sequence dimension
        self.listener_pool = nn.AdaptiveAvgPool1d(1)  # keep seq_len = 1024
        self.speaker_pool = nn.AdaptiveAvgPool1d(1)  # squeeze speaker to [B, feature_dim, 1]

        self.mlp = nn.Sequential(
            nn.Linear(1792, config.hidden_size),  
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU()
        )
        self.mlp_norm = nn.LayerNorm(config.hidden_size)

        self.listener_proj = nn.Linear(listener_seq_len, config.hidden_size)
        self.speaker_proj = nn.Linear(speaker_input_dim, config.hidden_size)

        self.cross_attention = CrossAttention(dim=config.hidden_size)
        self.attn_norm = nn.LayerNorm(config.hidden_size)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.activation_fn = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "tanh": nn.Tanh()

        }[config.activation_fn.lower()] 

    def forward(self, listener_feats, speaker_feats):
        # listener_feats: [B, 1024, 64]
        # speaker_feats:  [B, 768, 64]

        # [B, feature_dim, seq_len]
        listener_feats = listener_feats.transpose(1, 2)  # [B, 64, 1024]
        speaker_feats = speaker_feats.transpose(1, 2)    # [B, 64, 768]

        listener_x = listener_feats.mean(dim=1)  # [B, 1024]
        speaker_y = speaker_feats.mean(dim=1) # [B, 768]

        # print(f"listener x shape: {listener_x.shape}") 
        # print(f"speaker y shape: {speaker_y.shape}") 

        # print(f"Listener feats shape: {listener_feats.shape}")
        # print(f"Speaker feats shape: {speaker_feats.shape}")
        
        # listener_x = self.listener_pool(listener_feats).squeeze(2)  # [B, 64]
        # speaker_y = self.speaker_pool(speaker_feats).squeeze(2)     # [B, 64]

    
        fused = torch.cat([listener_x, speaker_y], dim=1)    # [B, 1792] 
        # print(f"Fused shape: {fused.shape}")  
        # print(f"Fused shape: {fused.shape}")

        mlp_hidden = self.mlp(fused)                      
        # mlp_hidden = self.mlp_norm(mlp_hidden)
        # print(f"mlp_hidden shape: {mlp_hidden.shape}")  # [B, 64]
        # print(listener_x.shape)
        listener_proj_out = self.listener_proj(listener_x)  # [B, hidden_size]
        speaker_proj_out = self.speaker_proj(speaker_y)     # [B, hidden_size]

        # print(f"listener_proj_out shape: {listener_proj_out.shape}")
        # print(f"speaker_proj_out shape: {speaker_proj_out.shape}")
        listener_attn = self.cross_attention(listener_proj_out, speaker_proj_out)  # [B, hidden_size]
        # listener_attn = self.attn_norm(listener_attn)

        # print(f"listener_attn shape: {listener_attn.shape}")
        final_rep = mlp_hidden + listener_attn  # [B, hidden_size]
        # print(f"final_rep shape: {final_rep.shape}")
        final_rep = self.classifier(final_rep)  # [B, num_labels]
        # print(f"Logits before tanh: {final_rep.mean().item()}, {final_rep.min().item()}, {final_rep.max().item()}")
        final_rep = self.activation_fn(final_rep)


#torch.Size([32, 768, 64]) torch.Size([32, 1024, 64])
# Listener feats shape: torch.Size([32, 64, 1024])
# Speaker feats shape: torch.Size([32, 64, 768])
# Listener x shape: torch.Size([32, 64])
# Speaker y shape: torch.Size([32, 64])
# Fused shape: torch.Size([32, 128])

        return final_rep


In [None]:
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from SEMPIDataLoader import ListenerSpeakerFeatureDataset


dataset = ListenerSpeakerFeatureDataset(
    csv_path="AudioVideo_Feature_Path_v2.csv",
    frame_length=64,
    root_dir="./",
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)

config = { 'activation_fn': 'Tanh',
           'extra_dropout':0 ,
             'hidden_size':64,
             'dropout': 0.4,
             'num_labels': 1}

config = SimpleNamespace(**config)
sample = dataset[9]
speaker_feat, listener_feat,_ = sample["features"]
model = ListenerSpeakerHybridFusion_2(config=config , listener_input_dim=listener_feat.shape[1], speaker_input_dim=speaker_feat.shape[0]).to(device)
listener_feat.shape,speaker_feat.shape

# Experiment 5: no validation

In [None]:
epochs = 20

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        speaker_feat, listener_feat,_ = batch["features"]
        # print(speaker_feat.shape, listener_feat.shape)
        engagement = batch["score"]

        speaker_feat = speaker_feat.to(device)
        listener_feat = listener_feat.to(device)
        engagement = engagement.to(device)
        engagement = engagement.view(-1, 1)
 
        optimizer.zero_grad()
        output = model(listener_feat,speaker_feat)
        loss = criterion(output, engagement)
        # print(output)
        # print(engagement)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss / len(train_loader):.4f}")
    model.eval()
    val_loss = 0
    val_correct = 0
    val_preds = []
    val_targets = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False):
                speaker, listener,_ = batch["features"]
                # print(speaker_feat.shape, listener_feat.shape)
                target = batch["score"].to(device)
                target = target.view(-1, 1)
                speaker = speaker.to(device)
                listener = listener.to(device)

                output = model(listener,speaker )
                loss = criterion(output, target)

                val_loss += loss.item() * target.size(0)
                val_preds.append(output.cpu())
                val_targets.append(target.cpu())
                # print(compute_ccc_batched(output.cpu(), target.cpu()))

    val_loss /= len(val_loader.dataset)
    val_ccc = compute_ccc_batched(
            np.concatenate(val_preds),
            np.concatenate(val_targets)
            
        )
    val_pcc = compute_pearson_correlation_batched(
             np.concatenate(val_preds),
            np.concatenate(val_targets)
           
        )
    val_preds = torch.cat(val_preds).numpy()
    val_targets = torch.cat(val_targets).numpy()
      # if epoch == 5:
        #     print("Val targets and preds:")
        #     print(val_targets , val_preds)
        
    print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f}")


# listener x shape: torch.Size([32, 1024])
# speaker y shape: torch.Size([32, 768])
# Fused shape: torch.Size([32, 1792])
# mlp_hidden shape: torch.Size([32, 64])
# listener_proj_out shape: torch.Size([32, 64])
# speaker_proj_out shape: torch.Size([32, 64])
# listener_attn shape: torch.Size([32, 64])
# final_rep shape: torch.Size([32, 64])
# Logits before tanh: 0.025505496188998222, -0.016170799732208252, 0.11344294995069504


# Experiment 6: cross- validation

In [None]:
from sklearn.model_selection import KFold
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset

epochs = 30
k_folds = 5
batch_size = 32
learning_rate = 1e-3

kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)


criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
fold_results = {}

sample = dataset[0]

for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
    print(f'\n--- Fold {fold+1} / {k_folds} ---')

    # Create data loaders for this fold
    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    # model = model(config).to(device)
    speaker, listener ,_= sample["features"]
    # print("Listener feature shape: ", listener_feat.shape)
    # print("Speaker feature shape: ", speaker_feat.shape)
    model = ListenerSpeakerHybridFusion_2(config=config , listener_input_dim=listener.shape[1], speaker_input_dim=speaker.shape[0]).to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            speaker_feat, listener_feat, _ = batch["features"]
            engagement = batch["score"]

            speaker_feat = speaker_feat.to(device)
            listener_feat = listener_feat.to(device)
            engagement = engagement.to(device)
            engagement = engagement.view(-1, 1)

            optimizer.zero_grad()
            # print("before model", listener_feat.shape , speaker_feat.shape)
            output = model(listener_feat, speaker_feat)
            loss = criterion(output, engagement)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
                speaker, listener, _ = batch["features"]
                target = batch["score"].to(device)
                target = target.view(-1, 1)
                speaker = speaker.to(device)
                listener = listener.to(device)

                output = model(listener, speaker)
                loss = criterion(output, target)

                val_loss += loss.item() * target.size(0)
                val_preds.append(output.cpu())
                val_targets.append(target.cpu())

        val_loss /= len(val_loader.dataset)

        val_preds_np = torch.cat(val_preds).numpy()
        val_targets_np = torch.cat(val_targets).numpy()

        val_ccc = compute_ccc_batched(val_preds_np, val_targets_np)
        val_pcc = compute_pearson_correlation_batched(val_preds_np, val_targets_np)
        val_r2 = compute_r2_score_batched(val_targets_np, val_preds_np)

        print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f} | Val R2: {val_r2:.4f} |")

    # Save the result for this fold
    fold_results[fold] = {
        'val_loss': val_loss,
        'val_ccc': val_ccc,
        'val_pcc': val_pcc,
        'val_r2': val_r2
    }

# After K-Folds
print("\n--- K-Fold Cross Validation Results ---")
for fold in fold_results:
    print(f"Fold {fold+1} | Val Loss: {fold_results[fold]['val_loss']:.4f} | CCC: {fold_results[fold]['val_ccc']:.4f} | PCC: {fold_results[fold]['val_pcc']:.4f} | R2: {fold_results[fold]['val_r2']:.4f}")


In [None]:
# ## adaptive pooling results:

# Epoch 1/20 - Loss: 0.0567
                                                              
# | Val Loss: 0.0195 | Val CCC: 0.0012 | Val PCC: 0.0148
# Epoch 2/20 - Loss: 0.0211
                                                              
# | Val Loss: 0.0221 | Val CCC: 0.0027 | Val PCC: 0.0416
# Epoch 3/20 - Loss: 0.0206
                                                              
# | Val Loss: 0.0207 | Val CCC: 0.0019 | Val PCC: 0.0326
# Epoch 4/20 - Loss: 0.0193
                                                              
# | Val Loss: 0.0196 | Val CCC: 0.0003 | Val PCC: 0.0049
# Epoch 5/20 - Loss: 0.0197
                                                              
# | Val Loss: 0.0234 | Val CCC: 0.0011 | Val PCC: 0.0269
# Epoch 6/20 - Loss: 0.0213
                                                              
# | Val Loss: 0.0204 | Val CCC: 0.0011 | Val PCC: 0.0267
# Epoch 7/20 - Loss: 0.0194
                                                              
# | Val Loss: 0.0208 | Val CCC: 0.0014 | Val PCC: 0.0301
# Epoch 8/20 - Loss: 0.0194
                                                              
# | Val Loss: 0.0199 | Val CCC: 0.0016 | Val PCC: 0.0310
# Epoch 9/20 - Loss: 0.0198
                                                              
# | Val Loss: 0.0198 | Val CCC: 0.0007 | Val PCC: 0.0139
# Epoch 10/20 - Loss: 0.0189
                                                              
# | Val Loss: 0.0199 | Val CCC: 0.0013 | Val PCC: 0.0293
# Epoch 11/20 - Loss: 0.0184
                                                               
# | Val Loss: 0.0196 | Val CCC: 0.0013 | Val PCC: 0.0238
# Epoch 12/20 - Loss: 0.0191
                                                               
# | Val Loss: 0.0200 | Val CCC: 0.0017 | Val PCC: 0.0333
# Epoch 13/20 - Loss: 0.0198
                                                               
# | Val Loss: 0.0205 | Val CCC: 0.0010 | Val PCC: 0.0220
# Epoch 14/20 - Loss: 0.0185
                                                               
# | Val Loss: 0.0196 | Val CCC: 0.0015 | Val PCC: 0.0248
# Epoch 15/20 - Loss: 0.0186
                                                               
# | Val Loss: 0.0222 | Val CCC: 0.0025 | Val PCC: 0.0437
# Epoch 16/20 - Loss: 0.0188
                                                               
# | Val Loss: 0.0201 | Val CCC: 0.0014 | Val PCC: 0.0220
# Epoch 17/20 - Loss: 0.0183
                                                               
# | Val Loss: 0.0196 | Val CCC: -0.0005 | Val PCC: -0.0064
# Epoch 18/20 - Loss: 0.0193
                                                               
# | Val Loss: 0.0202 | Val CCC: 0.0011 | Val PCC: 0.0186
# Epoch 19/20 - Loss: 0.0183
                                                               
# | Val Loss: 0.0208 | Val CCC: 0.0016 | Val PCC: 0.0284
# Epoch 20/20 - Loss: 0.0188
#                                                                | Val Loss: 0.0198 | Val CCC: 0.0024 | Val PCC: 0.0301

Epoch 1/20 - Loss: 0.0567


                                                              

| Val Loss: 0.0195 | Val CCC: 0.0012 | Val PCC: 0.0148
Epoch 2/20 - Loss: 0.0211


                                                              

| Val Loss: 0.0221 | Val CCC: 0.0027 | Val PCC: 0.0416
Epoch 3/20 - Loss: 0.0206


                                                              

| Val Loss: 0.0207 | Val CCC: 0.0019 | Val PCC: 0.0326
Epoch 4/20 - Loss: 0.0193


                                                              

| Val Loss: 0.0196 | Val CCC: 0.0003 | Val PCC: 0.0049
Epoch 5/20 - Loss: 0.0197


                                                              

| Val Loss: 0.0234 | Val CCC: 0.0011 | Val PCC: 0.0269
Epoch 6/20 - Loss: 0.0213


                                                              

| Val Loss: 0.0204 | Val CCC: 0.0011 | Val PCC: 0.0267
Epoch 7/20 - Loss: 0.0194


                                                              

| Val Loss: 0.0208 | Val CCC: 0.0014 | Val PCC: 0.0301
Epoch 8/20 - Loss: 0.0194


                                                              

| Val Loss: 0.0199 | Val CCC: 0.0016 | Val PCC: 0.0310
Epoch 9/20 - Loss: 0.0198


                                                              

| Val Loss: 0.0198 | Val CCC: 0.0007 | Val PCC: 0.0139
Epoch 10/20 - Loss: 0.0189


                                                              

| Val Loss: 0.0199 | Val CCC: 0.0013 | Val PCC: 0.0293
Epoch 11/20 - Loss: 0.0184


                                                               

| Val Loss: 0.0196 | Val CCC: 0.0013 | Val PCC: 0.0238
Epoch 12/20 - Loss: 0.0191


                                                               

| Val Loss: 0.0200 | Val CCC: 0.0017 | Val PCC: 0.0333
Epoch 13/20 - Loss: 0.0198


                                                               

| Val Loss: 0.0205 | Val CCC: 0.0010 | Val PCC: 0.0220
Epoch 14/20 - Loss: 0.0185


                                                               

| Val Loss: 0.0196 | Val CCC: 0.0015 | Val PCC: 0.0248
Epoch 15/20 - Loss: 0.0186


                                                               

| Val Loss: 0.0222 | Val CCC: 0.0025 | Val PCC: 0.0437
Epoch 16/20 - Loss: 0.0188


                                                               

| Val Loss: 0.0201 | Val CCC: 0.0014 | Val PCC: 0.0220
Epoch 17/20 - Loss: 0.0183


                                                               

| Val Loss: 0.0196 | Val CCC: -0.0005 | Val PCC: -0.0064
Epoch 18/20 - Loss: 0.0193


                                                               

| Val Loss: 0.0202 | Val CCC: 0.0011 | Val PCC: 0.0186
Epoch 19/20 - Loss: 0.0183


                                                               

| Val Loss: 0.0208 | Val CCC: 0.0016 | Val PCC: 0.0284
Epoch 20/20 - Loss: 0.0188


                                                               

| Val Loss: 0.0198 | Val CCC: 0.0024 | Val PCC: 0.0301




In [146]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total trainable parameters: {count_parameters(model):,}")


Total trainable parameters: 37,697
