In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
from sklearn.metrics import r2_score
from tqdm import tqdm, trange
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

In [None]:


class CrossAttention(nn.Module):
    """Multi-head attention where one sequence (the query) attends over another.

    Thin wrapper around ``nn.MultiheadAttention`` configured with
    ``batch_first=True``, i.e. tensors are laid out as (batch, seq, feature).
    """

    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: Embedding width shared by the query and key/value inputs.
            n_heads: Number of attention heads.
        """
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True,
        )

    def forward(self, query, key_value):
        """Attend from ``query`` over ``key_value``.

        Args:
            query: Sequence that issues the queries.
            key_value: Sequence used as both keys and values.

        Returns:
            The attended output only; the attention weights returned by
            ``nn.MultiheadAttention`` are discarded.
        """
        # The layer returns an (output, weights) pair; index 0 is the output.
        return self.cross_attn(query=query, key=key_value, value=key_value)[0]

class SelfAttention(nn.Module):
    """Multi-head self-attention over a single sequence.

    Thin wrapper around ``nn.MultiheadAttention`` configured with
    ``batch_first=True``, i.e. tensors are laid out as (batch, seq, feature).
    """

    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: Embedding width of the input sequence.
            n_heads: Number of attention heads.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            batch_first=True,
        )

    def forward(self, x):
        """Let every position of ``x`` attend over all positions of ``x``.

        Returns:
            The attended output only; the attention weights returned by
            ``nn.MultiheadAttention`` are discarded.
        """
        # The same tensor serves as query, key and value.
        return self.self_attn(query=x, key=x, value=x)[0]

class TransformerFusionModel(nn.Module):
    """Fuse audio, vision and text sequences with pairwise cross-attention.

    Pipeline (see ``forward``):
      1. Each modality is linearly projected to ``d_model``.
      2. Each modality queries the other two through dedicated cross-attention
         layers; the two results are concatenated (width ``2 * d_model``).
      3. Each concatenated stream goes through its own self-attention layer.
      4. The three streams are concatenated (width ``6 * d_model``), averaged
         over the sequence axis and passed to a two-layer MLP head.

    Args:
        d_model: Common embedding width after projection.
        n_heads: Number of attention heads (used for every attention layer;
            ``nn.MultiheadAttention`` requires the embedding width to be
            divisible by it).
        audio_dim: Feature width of the raw audio input.
        visual_dim: Feature width of the raw vision input.
        text_dim: Feature width of the raw text input.
        hidden_dim: Width of the hidden layer in the regression head.
        n_outputs: Number of regression targets produced per sample.

    The defaults reproduce the previously hard-coded configuration
    (1024 / 512 / 768 input widths, 512 hidden units, 5 outputs). Submodule
    names and creation order are unchanged, so existing checkpoints and
    seeded initialisation remain compatible.
    """

    def __init__(self, d_model=512, n_heads=8, audio_dim=1024, visual_dim=512,
                 text_dim=768, hidden_dim=512, n_outputs=5):
        super().__init__()

        # Per-modality projections into the shared embedding space.
        self.audio_proj = nn.Linear(audio_dim, d_model)
        self.visual_proj = nn.Linear(visual_dim, d_model)
        self.text_proj = nn.Linear(text_dim, d_model)

        # Naming convention: <query><key/value>_cross, e.g. ``va`` = vision
        # queries attending over audio.
        self.va_cross = CrossAttention(d_model, n_heads)
        self.vt_cross = CrossAttention(d_model, n_heads)
        self.av_cross = CrossAttention(d_model, n_heads)
        self.at_cross = CrossAttention(d_model, n_heads)
        self.tv_cross = CrossAttention(d_model, n_heads)
        self.ta_cross = CrossAttention(d_model, n_heads)

        # Self-attention runs on the concatenation of two cross-attention
        # outputs, hence twice the embedding width.
        self.visual_self = SelfAttention(d_model * 2, n_heads)
        self.audio_self = SelfAttention(d_model * 2, n_heads)
        self.text_self = SelfAttention(d_model * 2, n_heads)

        # Regression head over the three concatenated 2*d_model streams.
        self.fc = nn.Sequential(
            nn.Linear(6 * d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_outputs)
        )

    def forward(self, audio, vision, text):
        """Predict ``n_outputs`` values per sample.

        Args:
            audio: Tensor of shape (batch, seq, audio_dim).
            vision: Tensor of shape (batch, seq, visual_dim).
            text: Tensor of shape (batch, seq, text_dim).

        All three inputs must share the same batch size and sequence length,
        because the per-modality streams are concatenated along the feature
        axis before pooling.

        Returns:
            Tensor of shape (batch, n_outputs).
        """
        audio = self.audio_proj(audio)
        vision = self.visual_proj(vision)
        text = self.text_proj(text)

        # Every modality queries each of the other two.
        v_a = self.va_cross(vision, audio)
        v_t = self.vt_cross(vision, text)
        a_v = self.av_cross(audio, vision)
        a_t = self.at_cross(audio, text)
        t_v = self.tv_cross(text, vision)
        t_a = self.ta_cross(text, audio)

        vision_cat = torch.cat([v_a, v_t], dim=-1)
        audio_cat = torch.cat([a_v, a_t], dim=-1)
        text_cat = torch.cat([t_v, t_a], dim=-1)

        vision_final = self.visual_self(vision_cat)
        audio_final = self.audio_self(audio_cat)
        text_final = self.text_self(text_cat)

        # (batch, seq, 6*d_model) -> mean over the sequence axis -> (batch, 6*d_model)
        fused = torch.cat([vision_final, audio_final, text_final], dim=-1)
        fused = fused.mean(dim=1)

        return self.fc(fused)



class PersonalityDatasetGenTest(Dataset):
    """Samples pairing original audio with *generated* vision features.

    The pickle file must contain a dict mapping sample keys to per-sample
    dicts. Fields read per sample: ``'original_audio'``,
    ``'generated_vision'``, ``'original_text'`` and the five trait scores
    ``'openness'``, ``'conscientiousness'``, ``'extraversion'``,
    ``'agreeableness'``, ``'neuroticism'``.

    Samples whose audio/vision is missing or has an unexpected segment count
    are not dropped from the index; requesting one returns the next valid
    sample instead (wrapping around at the end), so ``len()`` still counts
    them.
    """

    # Required size of the segment axis of audio/vision, and the number of
    # times the text embedding is repeated to match it.
    NUM_SEGMENTS = 5

    def __init__(self, pkl_file):
        """Load the whole pickle file into memory.

        Args:
            pkl_file: Path to the pickle file.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        with open(pkl_file, 'rb') as f:
            self.data = pickle.load(f)
        self.keys = list(self.data.keys())

    def __len__(self):
        """Number of keys in the file (valid and invalid samples alike)."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(audio, vision, text, labels)`` for ``idx``.

        If the sample at ``idx`` is unusable, the following indices are tried
        in order (with wrap-around) until a usable one is found.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no sample in the dataset is usable.
        """
        total = len(self.keys)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for dataset of size {total}")

        # Bounded scan instead of unbounded recursion: each sample is tried at
        # most once, so a dataset without any valid sample fails fast.
        current = idx
        for _ in range(total):
            item = self._load_sample(current)
            if item is not None:
                return item
            current = (current + 1) % total
        raise RuntimeError("dataset contains no valid sample")

    def _load_sample(self, idx):
        """Build the tensors for one sample, or return None if it is unusable."""
        sample = self.data[self.keys[idx]]

        audio = sample.get('original_audio', None)
        if audio is None:
            print("audio for", idx, "is None")
            return None
        if audio.shape[-1] != self.NUM_SEGMENTS:
            return None
        # A 3-D audio tensor is averaged over its middle axis.
        # NOTE(review): presumably (feature, frames, segments) -> (feature,
        # segments); the previous inline comment said (5, 149, 1024), which
        # would not pass the last-axis check above — confirm the real layout.
        if audio.ndim == 3:
            audio = audio.mean(dim=1)
        audio = audio.float()
        audio = audio.transpose(0, 1)  # segment axis first

        vision = sample.get('generated_vision', None)
        if vision is None:
            print("vision for", idx, "is None")
            return None
        if vision.shape[0] != self.NUM_SEGMENTS:
            return None
        # NOTE(review): squeeze(1) only removes axis 1 when it has size 1
        # (e.g. (5, 1, 512) -> (5, 512)); other 3-D shapes stay 3-D — confirm.
        if vision.ndim == 3:
            vision = vision.squeeze(1)
        vision = vision.float()

        # One text embedding per sample, repeated once per segment.
        text = sample['original_text'].unsqueeze(0).repeat(self.NUM_SEGMENTS, 1).float()

        labels = torch.tensor([
            sample['openness'],
            sample['conscientiousness'],
            sample['extraversion'],
            sample['agreeableness'],
            sample['neuroticism']
        ]).float()

        return audio, vision, text, labels

class PersonalityDatasetGen(Dataset):
    """Samples pairing original audio with *reconstructed* vision features.

    The pickle file must contain a dict mapping sample keys to per-sample
    dicts. Fields read per sample: ``'original_audio'``,
    ``'reconstructed_visual'``, ``'original_text'`` and the five trait scores
    ``'openness'``, ``'conscientiousness'``, ``'extraversion'``,
    ``'agreeableness'``, ``'neuroticism'``.

    Samples whose audio/vision is missing or has an unexpected segment count
    are not dropped from the index; requesting one returns the next valid
    sample instead (wrapping around at the end), so ``len()`` still counts
    them.
    """

    # Required size of the segment axis of audio/vision, and the number of
    # times the text embedding is repeated to match it.
    NUM_SEGMENTS = 5

    def __init__(self, pkl_file):
        """Load the whole pickle file into memory.

        Args:
            pkl_file: Path to the pickle file.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        with open(pkl_file, 'rb') as f:
            self.data = pickle.load(f)
        self.keys = list(self.data.keys())

    def __len__(self):
        """Number of keys in the file (valid and invalid samples alike)."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(audio, vision, text, labels)`` for ``idx``.

        If the sample at ``idx`` is unusable, the following indices are tried
        in order (with wrap-around) until a usable one is found.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no sample in the dataset is usable.
        """
        total = len(self.keys)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for dataset of size {total}")

        # Bounded scan instead of unbounded recursion: each sample is tried at
        # most once, so a dataset without any valid sample fails fast.
        current = idx
        for _ in range(total):
            item = self._load_sample(current)
            if item is not None:
                return item
            current = (current + 1) % total
        raise RuntimeError("dataset contains no valid sample")

    def _load_sample(self, idx):
        """Build the tensors for one sample, or return None if it is unusable."""
        sample = self.data[self.keys[idx]]

        audio = sample.get('original_audio', None)
        if audio is None:
            print("audio for", idx, "is None")
            return None
        if audio.shape[-1] != self.NUM_SEGMENTS:
            return None
        # A 3-D audio tensor is averaged over its middle axis.
        # NOTE(review): presumably (feature, frames, segments) -> (feature,
        # segments) — confirm against the feature extractor.
        if audio.ndim == 3:
            audio = audio.mean(dim=1)
        audio = audio.float()
        audio = audio.transpose(0, 1)  # segment axis first

        vision = sample.get('reconstructed_visual', None)
        if vision is None:
            print("vision for", idx, "is None")
            return None
        if vision.shape[0] != self.NUM_SEGMENTS:
            return None
        # NOTE(review): squeeze(1) only removes axis 1 when it has size 1;
        # other 3-D shapes stay 3-D — confirm the expected layout.
        if vision.ndim == 3:
            vision = vision.squeeze(1)
        vision = vision.float()

        # One text embedding per sample, repeated once per segment.
        text = sample['original_text'].unsqueeze(0).repeat(self.NUM_SEGMENTS, 1).float()

        labels = torch.tensor([
            sample['openness'],
            sample['conscientiousness'],
            sample['extraversion'],
            sample['agreeableness'],
            sample['neuroticism']
        ]).float()

        return audio, vision, text, labels

class PersonalityDataset(Dataset):
    """Samples pairing original audio with original vision features.

    The pickle file must contain a dict mapping sample keys to per-sample
    dicts. Fields read per sample: ``'original_audio'``,
    ``'original_vision'``, ``'original_text'`` and the five trait scores
    ``'openness'``, ``'conscientiousness'``, ``'extraversion'``,
    ``'agreeableness'``, ``'neuroticism'``.

    Samples whose audio/vision has an unexpected segment count are not
    dropped from the index; requesting one returns the next valid sample
    instead (wrapping around at the end), so ``len()`` still counts them.
    """

    # Required size of the segment axis of audio/vision, and the number of
    # times the text embedding is repeated to match it.
    NUM_SEGMENTS = 5

    def __init__(self, pkl_file):
        """Load the whole pickle file into memory.

        Args:
            pkl_file: Path to the pickle file.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files from a trusted source.
        with open(pkl_file, 'rb') as f:
            self.data = pickle.load(f)
        self.keys = list(self.data.keys())

    def __len__(self):
        """Number of keys in the file (valid and invalid samples alike)."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(audio, vision, text, labels)`` for ``idx``.

        If the sample at ``idx`` is unusable, the following indices are tried
        in order (with wrap-around) until a usable one is found.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no sample in the dataset is usable.
        """
        total = len(self.keys)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for dataset of size {total}")

        # Bounded scan instead of unbounded recursion: each sample is tried at
        # most once, so a dataset without any valid sample fails fast.
        current = idx
        for _ in range(total):
            item = self._load_sample(current)
            if item is not None:
                return item
            current = (current + 1) % total
        raise RuntimeError("dataset contains no valid sample")

    def _load_sample(self, idx):
        """Build the tensors for one sample, or return None if it is unusable."""
        sample = self.data[self.keys[idx]]

        audio = sample['original_audio']
        if audio.shape[-1] != self.NUM_SEGMENTS:
            return None
        # A 3-D tensor is averaged over its middle axis.
        # NOTE(review): presumably (feature, frames, segments) -> (feature,
        # segments) — confirm against the feature extractor.
        if audio.ndim == 3:
            audio = audio.mean(dim=1)
        audio = audio.float()
        audio = audio.transpose(0, 1)  # segment axis first

        # Vision is handled exactly like audio here (segment axis last).
        vision = sample['original_vision']
        if vision.shape[-1] != self.NUM_SEGMENTS:
            return None
        if vision.ndim == 3:
            vision = vision.mean(dim=1)
        vision = vision.float()
        vision = vision.transpose(0, 1)

        # One text embedding per sample, repeated once per segment.
        text = sample['original_text'].unsqueeze(0).repeat(self.NUM_SEGMENTS, 1).float()

        labels = torch.tensor([
            sample['openness'],
            sample['conscientiousness'],
            sample['extraversion'],
            sample['agreeableness'],
            sample['neuroticism']
        ]).float()

        return audio, vision, text, labels

def concordance_correlation_coefficient(y_true, y_pred):
    """Concordance correlation coefficient (Lin's CCC) between two samples.

    Computed as ``2*cov / (var_true + var_pred + (mean_true - mean_pred)**2)``
    using population (biased) variance and covariance. A small epsilon
    (1e-8) is added to the denominator so constant, identical inputs do not
    divide by zero.

    Args:
        y_true: Array-like of reference values.
        y_pred: Array-like of predicted values with the same number of
            elements as ``y_true``.

    Returns:
        The CCC as a NumPy scalar.

    Raises:
        ValueError: the inputs do not contain the same number of elements.
    """
    # Flatten both inputs so that e.g. an (n, 1) column paired with an (n,)
    # vector is compared element-wise rather than broadcast to (n, n).
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )

    mean_true = np.mean(y_true)
    mean_pred = np.mean(y_pred)
    var_true = np.var(y_true)
    var_pred = np.var(y_pred)
    cov = np.mean((y_true - mean_true) * (y_pred - mean_pred))

    ccc = (2 * cov) / (var_true + var_pred + (mean_true - mean_pred) ** 2 + 1e-8)
    return ccc


def train(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(audio, vision, text)``.
        loader: Iterable of ``(audio, vision, text, labels)`` batches that
            also supports ``len()``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(preds, labels)``.
        device: Device every batch tensor is moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch in loader:
        audio, vision, text, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        batch_loss = criterion(model(audio, vision, text), labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` and print per-trait regression metrics.

    For every output column the following are printed: MAE, ACC (defined
    here as ``1 - MAE``), MSE, R2, Pearson correlation (PCC) and concordance
    correlation (CCC).

    Args:
        model: Module called as ``model(audio, vision, text)``.
        loader: Iterable of ``(audio, vision, text, labels)`` batches.
        device: Device the input tensors are moved to.

    Returns:
        Tuple ``(avg_ccc, avg_r2)``: CCC and R2 averaged over all columns.

    Raises:
        ValueError: ``loader`` yields no batches.
    """
    model.eval()
    true_batches = []
    pred_batches = []
    for audio, vision, text, labels in loader:
        audio, vision, text = audio.to(device), vision.to(device), text.to(device)
        preds = model(audio, vision, text)
        # Labels are only needed on the host, so they are never moved to the device.
        true_batches.append(labels.cpu().numpy())
        pred_batches.append(preds.cpu().numpy())

    if not true_batches:
        raise ValueError("evaluate() received an empty loader")

    y_true = np.vstack(true_batches)
    y_pred = np.vstack(pred_batches)

    trait_names = ["Openness", "Conscientiousness", "Extraversion", "Agreeableness", "Neuroticism"]
    # Derive the column count from the data; columns beyond the five named
    # traits get a generic label.
    n_traits = y_true.shape[1]
    names = trait_names[:n_traits] + [f"Trait {i}" for i in range(len(trait_names), n_traits)]

    ccc_scores = []
    r2_scores = []
    for i, name in enumerate(names):
        true_col, pred_col = y_true[:, i], y_pred[:, i]
        mse = mean_squared_error(true_col, pred_col)
        mae = mean_absolute_error(true_col, pred_col)
        acc = 1 - mae
        ccc = concordance_correlation_coefficient(true_col, pred_col)
        r2 = r2_score(true_col, pred_col)
        pcc, _ = pearsonr(true_col, pred_col)

        print(f">> {name}:  MAE: {mae:.4f}  ACC: {acc:.4f}  MSE: {mse:.4f}  R2: {r2:.4f}  PCC: {pcc:.4f}  CCC: {ccc:.4f}")

        # Keep the values so the averages do not recompute every metric.
        ccc_scores.append(ccc)
        r2_scores.append(r2)

    avg_ccc = np.mean(ccc_scores)
    avg_r2 = np.mean(r2_scores)
    return avg_ccc, avg_r2




In [None]:
def main(data_dir='/project/msoleyma_1026/personality_detection/first_impressions_v2_dataset',
         checkpoint_path='best_transformer_model_gen_vid_original_aud.pth',
         epochs=30, batch_size=32, lr=1e-4, patience=5):
    """Train the fusion model with early stopping, then report test metrics.

    The model is trained on ``reconstructed_cvae_train.pkl``, validated on
    ``reconstructed_cvae_val.pkl`` after every epoch and finally tested on
    ``cvae_test.pkl``; all three files are looked up in ``data_dir``. The
    weights with the best validation CCC are written to ``checkpoint_path``
    and restored before testing.

    Args:
        data_dir: Directory containing the three pickle files.
        checkpoint_path: Where the best model weights are saved.
        epochs: Maximum number of training epochs.
        batch_size: Batch size for all three loaders.
        lr: Adam learning rate.
        patience: Number of consecutive epochs without validation-CCC
            improvement after which training stops.

    All defaults reproduce the previously hard-coded configuration.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    train_pickle = f'{data_dir}/reconstructed_cvae_train.pkl'
    val_pickle = f'{data_dir}/reconstructed_cvae_val.pkl'
    test_pickle = f'{data_dir}/cvae_test.pkl'

    train_dataset = PersonalityDatasetGen(train_pickle)
    val_dataset = PersonalityDatasetGen(val_pickle)
    test_dataset = PersonalityDatasetGenTest(test_pickle)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = TransformerFusionModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_ccc = -np.inf
    wait = 0
    # Tracks whether THIS run wrote a checkpoint. A NaN validation CCC never
    # compares greater than the best value, so it is possible to finish
    # training without saving anything.
    checkpoint_saved = False

    for epoch in trange(epochs, desc="Epochs"):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_ccc, val_r2 = evaluate(model, val_loader, device)
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val CCC = {val_ccc:.4f}, Val R2 = {val_r2:.4f}")

        if val_ccc > best_val_ccc:
            best_val_ccc = val_ccc
            wait = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}. Best Val CCC = {best_val_ccc:.4f}")
                break

    if checkpoint_saved:
        # map_location keeps the restore working regardless of the device the
        # tensors were saved from.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Do not load a missing file or a stale checkpoint from an earlier run.
        print("No checkpoint was saved during this run; testing the final in-memory weights.")
    test_ccc, test_r2 = evaluate(model, test_loader, device)
    print(f"Test CCC = {test_ccc:.4f}, Test R2 = {test_r2:.4f}")

 

# Entry point: run the full pipeline only when this module is the program
# being executed, not when it is imported from elsewhere.
if __name__ == "__main__":
    main()
