In [38]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split


In [39]:
# ----------------------------
# 1. Load Speech Dataset
# ----------------------------
speech_dir = 'dataset/speech'
speech_data = []

for file in os.listdir(speech_dir):
    if file.endswith('.wav'):
        parts = file.split('_')
        if len(parts) == 3:
            word = parts[1]
            emotion = parts[2].replace('.wav', '')
            speech_data.append({
                'word': word,
                'emotion': emotion,
                'speech_path': os.path.join(speech_dir, file)
            })

speech_df = pd.DataFrame(speech_data)
speech_df.to_csv('speech_word_dataset.csv', index=False)

# ----------------------------
# 2. Load Text Dataset
# ----------------------------
def load_csvs_from_dir(directory):
    combined_df = pd.DataFrame()
    for file in os.listdir(directory):
        if file.endswith(".csv"):
            df = pd.read_csv(os.path.join(directory, file))
            combined_df = pd.concat([combined_df, df], ignore_index=True)
    return combined_df

text_train_df = load_csvs_from_dir("dataset/text/train")
text_val_df = load_csvs_from_dir("dataset/text/validation")
text_test_df = load_csvs_from_dir("dataset/text/test")
text_df = pd.concat([text_train_df, text_val_df, text_test_df], ignore_index=True)


In [40]:
# ----------------------------
# 3. Encode Labels (Shared)
# ----------------------------
label_encoder = LabelEncoder()
all_labels = pd.concat([speech_df['emotion'], text_df['label']], ignore_index=True)
label_encoder.fit(all_labels)

speech_df['label'] = label_encoder.transform(speech_df['emotion'])
text_df['label'] = label_encoder.transform(text_df['label'])


In [41]:
# ----------------------------
# 4. Tokenizer and BERT Model
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModel.from_pretrained("bert-base-uncased")

In [42]:
# ----------------------------
# 5. Feature Extraction Utils
# ----------------------------
def extract_mfcc(wav_path, max_len=100):
    y, sr = librosa.load(wav_path, sr=16000)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
    if mfcc.shape[1] < max_len:
        pad_width = max_len - mfcc.shape[1]
        mfcc = np.pad(mfcc, ((0,0), (0, pad_width)), mode='constant')
    else:
        mfcc = mfcc[:, :max_len]
    return mfcc.T
def extract_bert_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=32)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()


In [43]:
# ----------------------------
# 6. Early Fusion Dataset
# ----------------------------
class EarlyFusionDataset(Dataset):
    def __init__(self, speech_df, text_df):
        self.speech_df = speech_df.reset_index(drop=True)
        self.text_df = text_df.reset_index(drop=True)

    def __len__(self):
        return min(len(self.speech_df), len(self.text_df))

    def __getitem__(self, idx):
        speech_row = self.speech_df.iloc[idx]
        text_row = self.text_df.iloc[idx]

        mfcc = extract_mfcc(speech_row['speech_path'])  # [time, 40]
        label = torch.tensor(speech_row['label'], dtype=torch.long)

        # Tokenize text for BERT
        text_input = tokenizer(text_row['text'], return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        text_input = {k: v.squeeze(0) for k, v in text_input.items()}  # remove batch dim

        return torch.tensor(mfcc, dtype=torch.float32), text_input, label


In [44]:
# ----------------------------
# 7. Collate Function for Padding
# ----------------------------
def collate_fn(batch):
    mfccs, text_inputs, labels = zip(*batch)

    # Pad MFCC sequences
    mfccs = nn.utils.rnn.pad_sequence(mfccs, batch_first=True)

    # Stack BERT inputs
    input_ids = torch.stack([ti['input_ids'] for ti in text_inputs])
    attention_mask = torch.stack([ti['attention_mask'] for ti in text_inputs])
    text_input = {"input_ids": input_ids, "attention_mask": attention_mask}

    labels = torch.tensor(labels)
    return mfccs, text_input, labels


In [45]:
# ----------------------------
# 8. BERT Model
# ----------------------------
class BERTFusionModel(nn.Module):
    def __init__(self, bert_model, mfcc_dim=40, audio_pooling='avg', num_classes=6):
        super().__init__()
        self.bert = bert_model
        self.audio_pooling = audio_pooling
        
        # If using average pooling for MFCC
        self.audio_proj = nn.Linear(mfcc_dim, 128)  # compress MFCC features
        
        # Fusion Layer
        self.fusion = nn.Sequential(
            nn.Linear(768 + 128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, mfcc, text_input):
        # Get BERT features
        with torch.no_grad():
            bert_output = self.bert(**text_input)
            text_feat = bert_output.last_hidden_state[:, 0, :]  # CLS token: [batch, 768]

        # Aggregate MFCC (avg over time dimension)
        audio_feat = mfcc.mean(dim=1)  # [batch, 40]
        audio_feat = self.audio_proj(audio_feat)  # [batch, 128]

        # Combine
        fused = torch.cat((text_feat, audio_feat), dim=1)  # [batch, 896]
        out = self.fusion(fused)
        return out


In [46]:
# ----------------------------
# 9. Training and Evaluation
# ----------------------------

# Split dataset into train and validation
full_dataset = EarlyFusionDataset(speech_df, text_df)
train_indices, val_indices = train_test_split(list(range(len(full_dataset))), test_size=0.2, random_state=42)

train_subset = torch.utils.data.Subset(full_dataset, train_indices)
val_subset = torch.utils.data.Subset(full_dataset, val_indices)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTFusionModel(bert_model=bert_model, num_classes=len(label_encoder.classes_)).to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

def evaluate(model, loader):
    model.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0
    with torch.no_grad():
        for mfccs, text_inputs, labels in loader:
            mfccs, labels = mfccs.to(device), labels.to(device)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

            outputs = model(mfccs, text_inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

    accuracy = total_correct / total_samples
    avg_loss = total_loss / len(loader)
    return accuracy, avg_loss

# Training loop with metrics
for epoch in range(25):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for mfccs, text_inputs, labels in train_loader:
        mfccs, labels = mfccs.to(device), labels.to(device)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        optimizer.zero_grad()
        outputs = model(mfccs, text_inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

    train_acc = total_correct / total_samples
    train_loss = total_loss / len(train_loader)

    val_acc, val_loss = evaluate(model, val_loader)

    print(f"Epoch {epoch+1}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_acc*100:.2f}%, "
          f"Val Loss: {val_loss:.4f}, "
          f"Val Acc: {val_acc*100:.2f}%")


Epoch 1, Train Loss: 2.3470, Train Acc: 22.92%, Val Loss: 1.3296, Val Acc: 63.75%
Epoch 2, Train Loss: 1.3269, Train Acc: 49.27%, Val Loss: 0.9661, Val Acc: 95.42%
Epoch 3, Train Loss: 0.9584, Train Acc: 70.62%, Val Loss: 0.7116, Val Acc: 96.67%
Epoch 4, Train Loss: 0.6926, Train Acc: 82.40%, Val Loss: 0.5441, Val Acc: 95.83%
Epoch 5, Train Loss: 0.5322, Train Acc: 90.42%, Val Loss: 0.3934, Val Acc: 98.75%
Epoch 6, Train Loss: 0.3828, Train Acc: 94.06%, Val Loss: 0.3084, Val Acc: 98.75%
Epoch 7, Train Loss: 0.3125, Train Acc: 95.52%, Val Loss: 0.2289, Val Acc: 98.75%
Epoch 8, Train Loss: 0.2334, Train Acc: 97.81%, Val Loss: 0.1894, Val Acc: 97.50%
Epoch 9, Train Loss: 0.1973, Train Acc: 97.40%, Val Loss: 0.1546, Val Acc: 99.58%
Epoch 10, Train Loss: 0.1661, Train Acc: 98.33%, Val Loss: 0.1350, Val Acc: 99.17%
Epoch 11, Train Loss: 0.1400, Train Acc: 98.33%, Val Loss: 0.1130, Val Acc: 99.58%
Epoch 12, Train Loss: 0.1064, Train Acc: 99.38%, Val Loss: 0.0994, Val Acc: 99.58%
Epoch 13, Tra