In [85]:
import torch
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from tqdm import tqdm
from SEMPIDataLoader import ListenerSpeakerFeatureDataset
import torch
import torch.nn as nn
import sys, os
import numpy as np


In [86]:
sys.path.append(os.path.abspath("code"))
from metr import compute_ccc_batched , compute_pearson_correlation_batched

In [5]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed for the train/validation partition. Without it, random_split draws
# from the global RNG and the split (and every validation metric) changes on
# each kernel restart.
SPLIT_SEED = 42

dataset = ListenerSpeakerFeatureDataset(
    csv_path="AudioVideo_Feature_Paths.csv",
    frame_length=64,
    root_dir="./",
)

# 80/20 split; the validation size is the remainder so both sizes always sum
# to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are shuffled and the final incomplete batch is dropped;
# validation keeps every sample in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)



In [6]:
# Pull the first example and look at its two feature tensors.
# The "features" entry unpacks as (speaker, listener), in that order.
sample = dataset[0]
speaker_feat, listener_feat = sample["features"]

# Size of the leading axis of each tensor.
listener_dim = listener_feat.shape[0]
speaker_dim = speaker_feat.shape[0]

print(f"Listener feature shape: {listener_feat.shape}")
print(f"Speaker feature shape: {speaker_feat.shape}")

Listener feature shape: torch.Size([329, 64])
Speaker feature shape: torch.Size([329, 64])


In [7]:

class MLPClassifier(nn.Module):
    """MLP head: ``fc1`` -> dropout -> activation -> activation -> ``fc3``.

    Args:
        in_size: Number of input features.
        hidden_size: Width of the hidden layer(s).
        dropout: Dropout probability shared by every dropout call.
        num_classes: Number of output units of ``fc3``.
        config: Object exposing ``activation_fn`` (one of ``"relu"``,
            ``"leaky_relu"``, ``"tanh"``; matched case-insensitively so that
            ``"Tanh"``, as used by the other models in this notebook, is
            accepted too) and ``extra_dropout`` (truthy to apply dropout to
            the output as well).

    Raises:
        KeyError: If ``config.activation_fn`` names an unknown activation.
    """

    def __init__(self, in_size, hidden_size, dropout, num_classes, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(in_size, hidden_size)
        # fc2 is not used by forward() (its call is commented out there); it is
        # still created so the module's parameter set stays unchanged.
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

        activations = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "tanh": nn.Tanh(),
        }
        # Normalise the key so "Tanh" and "tanh" select the same module.
        self.activation_fn = activations[str(config.activation_fn).lower()]

    def forward(self, x):
        """Map ``x`` (last dim ``in_size``) to outputs with last dim ``num_classes``."""
        x = self.dropout(self.fc1(x))
        x = self.activation_fn(x)
        # if not (self.config.expnum in [1, 3, 4, 10]):
        #     x = self.dropout(self.fc2(x))
        # NOTE(review): with the fc2 branch disabled the activation is applied
        # twice in a row; kept as-is to preserve existing behaviour.
        x = self.activation_fn(x)
        x = self.fc3(x)
        if self.config.extra_dropout:
            x = self.dropout(x)
        return x




In [18]:
# class ListenerSpeakerFusion(nn.Module):
#     def __init__(self, config, listener_input_dim=329, listener_seq_len=50,
#                  speaker_input_dim=768, speaker_seq_len=424,
#                  reduced_speaker_dim=32):
#         super().__init__()
#         self.config = config
#         self.reduced_speaker_dim = reduced_speaker_dim
#         self.listener_dim = listener_seq_len
#         # self.ablation = config.ablation

#         self.listener_pool = nn.Sequential(
#             nn.Linear(listener_input_dim, 1),
#             nn.Flatten(start_dim=1)
#         )

#         self.speaker_cnn = nn.Sequential(
#             nn.Conv1d(speaker_input_dim, 64, kernel_size=5, padding=2),
#             nn.ReLU(),
#             nn.AdaptiveAvgPool1d(1),
#             nn.Flatten(),
#             nn.Linear(64, reduced_speaker_dim)
#         )

#         clf_input_size = 0
#         # if self.ablation != 1: 
#         #     clf_input_size += self.listener_dim
#         # if self.ablation != 2: 
#         #     clf_input_size += reduced_speaker_dim

#         self.out = MLPClassifier(
#             in_size=clf_input_size,
#             hidden_size=config.hidden_size,
#             dropout=config.dropout,
#             num_classes=config.num_labels,
#             config=config
#         )

#     def forward(self, listener_feats, speaker_feats):
#         # listener_feats: [B, 50, 329]
#         # speaker_feats: [B, 768, 424]
#         fusion_vec = []
        
#         listener_x = self.listener_pool(listener_feats)  # [B, 50]
#         fusion_vec.append(listener_x)
#         speaker_y = self.speaker_cnn(speaker_feats)  # [B, 32]
#         fusion_vec.append(speaker_y)

#         # if self.ablation != 1:
#         #     listener_x = self.listener_pool(listener_feats)  # [B, 50]
#         #     fusion_vec.append(listener_x)

#         # if self.ablation != 2:
#         #     speaker_y = self.speaker_cnn(speaker_feats)  # [B, 32]
#         #     fusion_vec.append(speaker_y)

#         x = torch.cat(fusion_vec, dim=1)
#         return self.out(x)


import torch
import torch.nn as nn


class MLPUpToHidden(nn.Module):
    """Two-layer MLP that stops at the hidden representation (no output layer).

    Args:
        in_size: Number of input features.
        hidden_size: Width of both linear layers' outputs.
        dropout: Dropout probability applied after each linear layer.
        config: Object exposing ``activation_fn`` (one of ``"relu"``,
            ``"leaky_relu"``, ``"tanh"``; matched case-insensitively, so the
            historical spelling ``"Tanh"`` keeps working).

    Raises:
        KeyError: If ``config.activation_fn`` names an unknown activation.
    """

    def __init__(self, in_size, hidden_size, dropout, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(in_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        activations = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "tanh": nn.Tanh(),
        }
        # Normalise the key so "Tanh" and "tanh" select the same module,
        # matching the spelling MLPClassifier uses for its table.
        self.activation_fn = activations[str(config.activation_fn).lower()]

    def forward(self, x):
        """Return the hidden features: linear -> dropout -> activation, twice."""
        x = self.dropout(self.fc1(x))
        x = self.activation_fn(x)
        x = self.dropout(self.fc2(x))
        x = self.activation_fn(x)
        return x  # no final output layer here


class CrossAttention(nn.Module):
    """Single-head attention between two vectors.

    Each ``[B, D]`` input is treated as a sequence of length one, so the
    attention is computed over exactly one key/value position per example.
    """

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, 1, batch_first=True)

    def forward(self, q, k):
        """Attend from ``q`` to ``k`` (both ``[B, D]``) and return ``[B, D]``."""
        # Insert a length-1 sequence axis: [B, D] -> [B, 1, D].
        query_seq = q.unsqueeze(dim=1)
        memory_seq = k.unsqueeze(dim=1)
        # Keys double as values; the attention weights (second item) are unused.
        attended = self.attn(query_seq, memory_seq, memory_seq)[0]
        # Remove the sequence axis again: [B, 1, D] -> [B, D].
        return attended.squeeze(dim=1)


class ListenerSpeakerFusion(nn.Module):
    """Early-fusion model: pool listener, convolve speaker, MLP, attention, linear head.

    Args:
        config: Object exposing ``hidden_size``, ``dropout``, ``num_labels``
            and the ``activation_fn`` consumed by ``MLPUpToHidden``.
        listener_input_dim: Size of the last axis of the listener tensor; it is
            reduced to a single value per position by a linear layer.
        listener_seq_len: Size of the middle axis of the listener tensor, i.e.
            the width of the pooled listener vector.
        speaker_input_dim: Channel count of the speaker tensor (Conv1d input).
        speaker_seq_len: Accepted for interface compatibility; not used, since
            adaptive pooling handles any speaker length.
        reduced_speaker_dim: Width of the pooled speaker vector.
    """

    def __init__(self, config, listener_input_dim=329, listener_seq_len=50,
                 speaker_input_dim=768, speaker_seq_len=424,
                 reduced_speaker_dim=32):
        super().__init__()
        self.config = config
        conv_channels = 64

        # [B, listener_seq_len, listener_input_dim] -> [B, listener_seq_len]
        listener_layers = [nn.Linear(listener_input_dim, 1), nn.Flatten(start_dim=1)]
        self.listener_pool = nn.Sequential(*listener_layers)

        # [B, speaker_input_dim, T] -> [B, reduced_speaker_dim]
        speaker_layers = [
            nn.Conv1d(speaker_input_dim, conv_channels, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(conv_channels, reduced_speaker_dim),
        ]
        self.speaker_cnn = nn.Sequential(*speaker_layers)

        # The MLP consumes the concatenation of both pooled vectors.
        self.mlp_input_size = listener_seq_len + reduced_speaker_dim
        self.mlp = MLPUpToHidden(
            in_size=self.mlp_input_size,
            hidden_size=config.hidden_size,
            dropout=config.dropout,
            config=config,
        )

        self.cross_attention = CrossAttention(dim=config.hidden_size)
        self.final_fc = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, listener_feats, speaker_feats):
        """Return predictions of shape ``[B, num_labels]``."""
        pooled_listener = self.listener_pool(listener_feats)
        pooled_speaker = self.speaker_cnn(speaker_feats)

        # Fuse the two pooled vectors and lift them to the hidden size.
        joint = torch.cat((pooled_listener, pooled_speaker), dim=1)
        hidden_repr = self.mlp(joint)

        # The attention block receives the fused features as both query and key.
        attended = self.cross_attention(hidden_repr, hidden_repr)

        return self.final_fc(attended)


In [19]:
from types import SimpleNamespace

# Hyper-parameters, exposed as attributes (config.hidden_size, ...) because the
# model classes read them with dot access.
hyperparams = {
    'activation_fn': 'Tanh',
    'extra_dropout': 0,
    'hidden_size': 128,
    'dropout': 0.1,
    'num_labels': 1,
}
config = SimpleNamespace(**hyperparams)

# Size the model from a real example: axis 0 of each feature tensor feeds the
# listener "seq_len" / speaker channel arguments, axis 1 feeds the listener
# linear input dim / speaker length arguments.
sample = dataset[0]
speaker_feat, listener_feat = sample["features"]

model = ListenerSpeakerFusion(
    config=config,
    listener_input_dim=listener_feat.shape[1],
    listener_seq_len=listener_feat.shape[0],
    speaker_input_dim=speaker_feat.shape[0],
    speaker_seq_len=speaker_feat.shape[1],
    reduced_speaker_dim=32,
).to(device)


In [20]:
# Switch the model to training mode; as the cell's last expression, the
# returned module is what gets displayed below.
model.train()

ListenerSpeakerFusion(
  (listener_pool): Sequential(
    (0): Linear(in_features=64, out_features=1, bias=True)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
  (speaker_cnn): Sequential(
    (0): Conv1d(329, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): ReLU()
    (2): AdaptiveAvgPool1d(output_size=1)
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=64, out_features=32, bias=True)
  )
  (mlp): MLPUpToHidden(
    (dropout): Dropout(p=0.1, inplace=False)
    (fc1): Linear(in_features=361, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (activation_fn): Tanh()
  )
  (cross_attention): CrossAttention(
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
  )
  (final_fc): Linear(in_features=128, out_features=1, bias=True)
)

In [80]:
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Single-head attention from one vector (``query``) to another (``context``).

    Both inputs are ``[B, D]`` and are treated as sequences of length one.
    """

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, 1, batch_first=True)

    def forward(self, query, context):
        """Return the attended representation of shape ``[B, D]``."""
        # [B, D] -> [B, 1, D] for both operands.
        query_seq = query.unsqueeze(dim=1)
        context_seq = context.unsqueeze(dim=1)
        # Context supplies keys and values; attention weights are discarded.
        attended = self.attn(query_seq, context_seq, context_seq)[0]
        # [B, 1, D] -> [B, D]
        return attended.squeeze(dim=1)


class ListenerSpeakerHybridFusion(nn.Module):
    """Hybrid fusion model combining an early-fusion MLP with an attention branch.

    Two representations are built from the pooled listener and speaker vectors
    and concatenated before the final linear layer:

    * an MLP over the concatenated pooled vectors (layer-normalised), and
    * an attention output where the projected listener vector is the query and
      the projected speaker vector is the key/value (layer-normalised).

    The configured activation is applied to the final output.

    Args:
        config: Object exposing ``hidden_size``, ``dropout``, ``num_labels``
            and ``activation_fn`` (one of ``"relu"``, ``"leaky_relu"``,
            ``"tanh"``; matched case-insensitively, so ``"Tanh"`` keeps working).
        listener_input_dim: Size of the last axis of the listener tensor.
        listener_seq_len: Size of the middle axis of the listener tensor, i.e.
            the width of the pooled listener vector.
        speaker_input_dim: Channel count of the speaker tensor (Conv1d input).
        speaker_seq_len: Accepted for interface compatibility; not used, since
            adaptive pooling handles any speaker length.
        reduced_speaker_dim: Conv1d output channels, i.e. the width of the
            pooled speaker vector.

    Raises:
        KeyError: If ``config.activation_fn`` names an unknown activation.
    """

    def __init__(self, config, listener_input_dim=329, listener_seq_len=50,
                 speaker_input_dim=768, speaker_seq_len=424, reduced_speaker_dim=32):
        super().__init__()
        self.config = config

        # Listener: [B, listener_seq_len, listener_input_dim] -> [B, listener_seq_len]
        self.listener_pool = nn.Sequential(
            nn.Linear(listener_input_dim, 1),
            nn.Flatten(start_dim=1)
        )

        # Speaker: [B, speaker_input_dim, T] -> [B, reduced_speaker_dim]
        self.speaker_cnn = nn.Sequential(
            nn.Conv1d(speaker_input_dim, reduced_speaker_dim, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
        )

        # MLP over the concatenated pooled vectors (early fusion).
        self.mlp = nn.Sequential(
            nn.Linear(listener_seq_len + reduced_speaker_dim, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU()
        )
        self.mlp_norm = nn.LayerNorm(config.hidden_size)

        # Project both pooled vectors to the hidden size before attention.
        self.listener_proj = nn.Linear(listener_seq_len, config.hidden_size)
        self.speaker_proj = nn.Linear(reduced_speaker_dim, config.hidden_size)

        # Attention with the listener as query and the speaker as key/value.
        self.cross_attention = CrossAttention(dim=config.hidden_size)
        self.attn_norm = nn.LayerNorm(config.hidden_size)

        # Final linear layer over [mlp branch ; attention branch].
        self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)

        activations = {
            "relu": nn.ReLU(),
            "leaky_relu": nn.LeakyReLU(),
            "tanh": nn.Tanh(),
        }
        # Normalise the key so "Tanh" and "tanh" select the same module,
        # matching the lowercase spelling used by MLPClassifier.
        self.activation_fn = activations[str(config.activation_fn).lower()]

    def forward(self, listener_feats, speaker_feats):
        """Return activated predictions of shape ``[B, num_labels]``."""
        # Per-modality pooling.
        listener_x = self.listener_pool(listener_feats)     # [B, listener_seq_len]
        speaker_y = self.speaker_cnn(speaker_feats)         # [B, reduced_speaker_dim]

        # Early-fusion branch.
        fused = torch.cat([listener_x, speaker_y], dim=1)
        mlp_hidden = self.mlp_norm(self.mlp(fused))         # [B, hidden_size]

        # Attention branch.
        listener_proj_out = self.listener_proj(listener_x)  # [B, hidden_size]
        speaker_proj_out = self.speaker_proj(speaker_y)     # [B, hidden_size]

        # NOTE(review): CrossAttention feeds length-1 sequences, so the softmax
        # runs over a single key per example. Confirm this is the intended
        # design rather than attention over the time axis.
        listener_attn = self.cross_attention(listener_proj_out, speaker_proj_out)
        listener_attn = self.attn_norm(listener_attn)       # [B, hidden_size]

        # Late fusion of both branches, then the output layer.
        final_rep = torch.cat([mlp_hidden, listener_attn], dim=1)  # [B, hidden_size * 2]
        final_rep = self.classifier(final_rep)                     # [B, num_labels]

        # The configured activation is applied to the output itself.
        final_rep = self.activation_fn(final_rep)
        return final_rep


In [81]:
from types import SimpleNamespace

# Hyper-parameters for the hybrid model, exposed as attributes because the
# model reads them with dot access (config.hidden_size, ...).
hyperparams = {
    'activation_fn': 'Tanh',
    'extra_dropout': 0,
    'hidden_size': 128,
    'dropout': 0.4,
    'num_labels': 1,
}
config = SimpleNamespace(**hyperparams)

# Size the model from a real example: axis 0 of each feature tensor feeds the
# listener "seq_len" / speaker channel arguments, axis 1 feeds the listener
# linear input dim / speaker length arguments.
sample = dataset[0]
speaker_feat, listener_feat = sample["features"]

model = ListenerSpeakerHybridFusion(
    config=config,
    listener_input_dim=listener_feat.shape[1],
    listener_seq_len=listener_feat.shape[0],
    speaker_input_dim=speaker_feat.shape[0],
    speaker_seq_len=speaker_feat.shape[1],
    reduced_speaker_dim=32,
).to(device)


In [None]:
epochs = 20

criterion = torch.nn.MSELoss()  # regression loss on the engagement score
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE: this stays a flat script on purpose. Later cells display the `output`
# and `engagement` variables that the loops leave in the notebook namespace.
for epoch in range(epochs):
    # ---------------- training pass ----------------
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        speaker_feat, listener_feat = batch["features"]
        speaker_feat = speaker_feat.to(device)
        listener_feat = listener_feat.to(device)
        # Targets are reshaped to [B, 1] to line up with the model output.
        engagement = batch["score"].to(device).view(-1, 1)

        optimizer.zero_grad()
        output = model(listener_feat, speaker_feat)
        loss = criterion(output, engagement)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of per-batch losses; max(..., 1) avoids a ZeroDivisionError when
    # drop_last leaves the loader without any full batch.
    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss / max(len(train_loader), 1):.4f}")

    # ---------------- validation pass ----------------
    model.eval()
    val_loss = 0.0
    val_preds = []
    val_targets = []

    with torch.no_grad():
        # The progress bar uses the same 1-based epoch number as the log line.
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
            speaker, listener = batch["features"]
            speaker = speaker.to(device)
            listener = listener.to(device)
            target = batch["score"].to(device).view(-1, 1)

            output = model(listener, speaker)
            loss = criterion(output, target)

            # Weight each batch by its size so the final figure is a
            # per-sample mean even when the last batch is smaller.
            val_loss += loss.item() * target.size(0)
            val_preds.append(output.cpu())
            val_targets.append(target.cpu())

    val_loss /= len(val_loader.dataset)

    # Concatenate once and hand the same NumPy arrays to both metrics.
    val_preds = torch.cat(val_preds).numpy()
    val_targets = torch.cat(val_targets).numpy()
    val_ccc = compute_ccc_batched(val_preds, val_targets)
    val_pcc = compute_pearson_correlation_batched(val_preds, val_targets)

    print(f"| Val Loss: {val_loss:.4f} | Val CCC: {val_ccc:.4f} | Val PCC: {val_pcc:.4f}")


    

Epoch 1/20 - Loss: 0.0197


                                                              

| Val Loss: 0.0157 | Val CCC: 0.0216 | Val PCC: 0.1383
Epoch 2/20 - Loss: 0.0200


                                                              

| Val Loss: 0.0156 | Val CCC: 0.0163 | Val PCC: 0.1358
Epoch 3/20 - Loss: 0.0194


                                                              

| Val Loss: 0.0156 | Val CCC: 0.0298 | Val PCC: 0.1328
Epoch 4/20 - Loss: 0.0196


                                                              

| Val Loss: 0.0157 | Val CCC: 0.0214 | Val PCC: 0.1340
Epoch 5/20 - Loss: 0.0194


                                                              

| Val Loss: 0.0156 | Val CCC: 0.0177 | Val PCC: 0.1248
Epoch 6/20 - Loss: 0.0192


                                                              

| Val Loss: 0.0160 | Val CCC: 0.0383 | Val PCC: 0.1324
Epoch 7/20 - Loss: 0.0190


                                                              

| Val Loss: 0.0156 | Val CCC: 0.0266 | Val PCC: 0.1195
Epoch 8/20 - Loss: 0.0194


                                                              

| Val Loss: 0.0167 | Val CCC: 0.0141 | Val PCC: 0.0940
Epoch 9/20 - Loss: 0.0193


                                                              

| Val Loss: 0.0165 | Val CCC: 0.0183 | Val PCC: 0.1083
Epoch 10/20 - Loss: 0.0191


                                                              

| Val Loss: 0.0157 | Val CCC: 0.0222 | Val PCC: 0.1148
Epoch 11/20 - Loss: 0.0193


                                                               

| Val Loss: 0.0161 | Val CCC: 0.0414 | Val PCC: 0.1250
Epoch 12/20 - Loss: 0.0192


                                                               

| Val Loss: 0.0159 | Val CCC: 0.0199 | Val PCC: 0.0864
Epoch 13/20 - Loss: 0.0192


                                                               

| Val Loss: 0.0160 | Val CCC: 0.0230 | Val PCC: 0.0903
Epoch 14/20 - Loss: 0.0188


                                                               

| Val Loss: 0.0162 | Val CCC: 0.0333 | Val PCC: 0.1295
Epoch 15/20 - Loss: 0.0192


                                                               

| Val Loss: 0.0167 | Val CCC: 0.0335 | Val PCC: 0.1189
Epoch 16/20 - Loss: 0.0192


                                                               

| Val Loss: 0.0155 | Val CCC: 0.0410 | Val PCC: 0.1337
Epoch 17/20 - Loss: 0.0188


                                                               

| Val Loss: 0.0159 | Val CCC: 0.0444 | Val PCC: 0.1431
Epoch 18/20 - Loss: 0.0187


                                                               

| Val Loss: 0.0162 | Val CCC: 0.0519 | Val PCC: 0.1623
Epoch 19/20 - Loss: 0.0189


                                                               

| Val Loss: 0.0162 | Val CCC: 0.0236 | Val PCC: 0.0882
Epoch 20/20 - Loss: 0.0187


                                                               

| Val Loss: 0.0156 | Val CCC: 0.0720 | Val PCC: 0.1549




In [83]:
# Display whichever `output` tensor the training cell last left in the kernel
# namespace (relies on state leaked from that cell's loops).
output

tensor([[-0.0287],
        [-0.0269],
        [-0.0198],
        [-0.0165],
        [-0.0175],
        [-0.0072],
        [-0.0110],
        [-0.0126],
        [-0.0389],
        [-0.0078],
        [-0.0236],
        [-0.0324],
        [-0.0485],
        [ 0.0146],
        [ 0.0105],
        [ 0.0082],
        [-0.0203],
        [-0.0431],
        [-0.0268],
        [-0.0392],
        [ 0.0051],
        [-0.0114],
        [-0.0369],
        [-0.0213],
        [ 0.0049],
        [-0.0012],
        [-0.0248],
        [ 0.0065],
        [ 0.0475],
        [-0.0084],
        [ 0.0036],
        [-0.0120]], grad_fn=<TanhBackward0>)

In [84]:
# Display the `engagement` target tensor left over from the training cell's
# last training batch (relies on state leaked from that cell's loop).
engagement

tensor([[ 1.0000e-02],
        [ 6.3333e-02],
        [ 1.6333e-01],
        [ 2.1000e-01],
        [-2.7000e-01],
        [-2.6667e-02],
        [-9.0000e-02],
        [ 9.6667e-02],
        [-2.4333e-01],
        [ 1.1333e-01],
        [-1.5667e-01],
        [ 8.0000e-02],
        [-1.6667e-01],
        [-1.1000e-01],
        [ 1.9333e-01],
        [ 1.9333e-01],
        [ 9.0000e-02],
        [ 1.9333e-01],
        [ 1.7667e-01],
        [-2.3667e-01],
        [ 5.6667e-02],
        [-3.4694e-18],
        [-1.1000e-01],
        [ 1.3333e-01],
        [-2.0000e-02],
        [ 7.3333e-02],
        [ 1.9667e-01],
        [ 1.3333e-02],
        [ 1.9000e-01],
        [ 1.0000e-02],
        [ 1.3333e-02],
        [-2.5333e-01]])