In [None]:

# NOTE: This file has been modified from the original version to improve training and accuracy.

import os
import copy
import random
import time
from typing import List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, models
from PIL import Image

from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder, label_binarize
from sklearn.metrics import confusion_matrix, roc_curve, auc, accuracy_score
from sklearn.metrics import classification_report, precision_recall_fscore_support
from tqdm import tqdm
import soundfile as sf
import librosa
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - required for 3D projection

# -----------------------------
# USER CONFIG - adjust paths
# -----------------------------
# Absolute paths to the paired-sample CSV and the three unimodal checkpoints.
CSV_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\meld dataset\MELD.Raw\self\fusion_dataset.csv"
IMAGE_MODEL_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\inceptionresnetv3_face_emotion.pth"
TEXT_MODEL_PATH  = r"K:\Code\Project\Research Paper\Emotion Detection\Code\bert_emotion_text_final.pth"
AUDIO_MODEL_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\best_transformer_speech_model.pth"

# Unimodal model settings (must match your trained models)
IMG_SIZE = 299  # images are resized to IMG_SIZE x IMG_SIZE before the image encoder
BERT_MODEL_NAME = "bert-base-multilingual-cased"
MAX_LEN = 64  # tokenizer pads/truncates every text to this many tokens
AUDIO_MAX_PAD = 174  # same as audio model; MFCC frames are padded/truncated to this length

# Fusion training hyperparams
BATCH_SIZE = 128
EPOCHS = 50
LR = 1e-4
WEIGHT_DECAY = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FREEZE_BASE_MODELS = True  # True: only fusion head trained. Set False to fine-tune bases.

# Where the best checkpoint (by validation accuracy) is written during training.
BEST_FUSION_PATH = "./best_early_fusion_model.pth"

# Seed every RNG the script uses (Python, NumPy, PyTorch CPU and CUDA).
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
# NOTE(review): cudnn.benchmark enables cuDNN algorithm auto-tuning, which may make GPU
# runs non-reproducible despite the seeding above - confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True

# -----------------------------
# UTIL: audio feature extraction (soundfile + librosa)
# -----------------------------
def load_audio(path, sr=22050):
    """Read an audio file and return ``(samples, sr)`` as mono audio at rate ``sr``.

    Multi-channel audio is averaged across channels; the signal is resampled
    only when the file's native rate differs from ``sr``.
    """
    samples, file_sr = sf.read(path)

    # Down-mix multi-channel recordings by averaging across the channel axis.
    mono = samples if samples.ndim <= 1 else np.mean(samples, axis=1)

    if file_sr == sr:
        return mono, sr
    return librosa.resample(mono, orig_sr=file_sr, target_sr=sr), sr

def extract_mfcc(file_path, sr=22050, n_mfcc=40, max_pad_len=AUDIO_MAX_PAD):
    """Return an ``(n_mfcc, max_pad_len)`` MFCC matrix for the audio at ``file_path``.

    Signals shorter than 2048 samples are zero-padded first; the MFCC frame
    axis is then zero-padded or truncated to exactly ``max_pad_len`` columns.
    """
    signal, sr = load_audio(file_path, sr)

    # Guarantee a minimum signal length before computing the spectrogram.
    shortfall = 2048 - len(signal)
    if shortfall > 0:
        signal = np.pad(signal, (0, shortfall))

    mfcc = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=n_mfcc, n_fft=1024, hop_length=512)

    n_frames = mfcc.shape[1]
    if n_frames >= max_pad_len:
        return mfcc[:, :max_pad_len]
    return np.pad(mfcc, ((0, 0), (0, max_pad_len - n_frames)), mode='constant')

# -----------------------------
# UNIMODAL MODEL WRAPPERS (EMBEDDING OUTPUTS)
# -----------------------------
class InceptionResNetV3_Embed(nn.Module):
    """Image encoder: torchvision Inception-v3 backbone plus a projection head.

    The backbone's final ``fc`` layer is replaced with an identity so it emits
    pooled features, which ``proj`` maps to an ``embed_dim``-sized embedding.
    """

    def __init__(self, embed_dim: int = 512, pretrained=True):
        super().__init__()
        # Use ImageNet weights only when requested; otherwise start from scratch.
        weights = models.Inception_V3_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.inception_v3(weights=weights, aux_logits=False)

        # Remember the classifier's input width before discarding the classifier.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.base = backbone

        self.proj = nn.Sequential(
            nn.Linear(feature_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Return the projected embedding for a batch of images ``x``."""
        return self.proj(self.base(x))

class BERTEncoder(nn.Module):
    """Text encoder: pretrained BERT followed by a projection to ``embed_dim``."""

    def __init__(self, model_name: str, embed_dim: int = 512, dropout: float = 0.2):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        bert_width = self.bert.config.hidden_size
        self.proj = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(bert_width, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, input_ids, attention_mask):
        """Return the projection of BERT's pooled output for the given tokens."""
        bert_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return self.proj(bert_out.pooler_output)

class TransformerLSTM_Embed(nn.Module):
    """Audio encoder: linear lift -> Transformer encoder -> LSTM -> projection.

    ``forward`` squeezes dim 1 of its input and swaps the last two axes before
    the 40-wide linear layer, so the 40 coefficients must sit on the
    second-to-last axis of the input.
    """

    def __init__(self, embed_dim=512, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        # Lift each 40-dimensional frame to the transformer width.
        self.feature_proj = nn.Linear(40, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=256,
            dropout=0.3,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.lstm = nn.LSTM(d_model, 128, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.proj = nn.Sequential(
            nn.Linear(128, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Return one ``embed_dim`` embedding per batch element."""
        # Drop the singleton dim 1, then put the frame axis before the feature axis.
        frames = x.squeeze(1).transpose(1, 2)
        encoded = self.transformer(self.feature_proj(frames))
        lstm_out, _ = self.lstm(encoded)
        # Summarise the sequence with the LSTM output at the final time step.
        last_step = self.dropout(lstm_out[:, -1])
        return self.proj(last_step)

# -----------------------------
# MULTIMODAL DATASET
# -----------------------------
class MultimodalDataset(Dataset):
    """Dataset yielding one aligned (image, text, audio, label) sample per CSV row.

    Each row of ``df`` must provide the columns ``image_path``, ``text``,
    ``audio_path`` and ``label`` (``label`` must be convertible with ``int``).

    Loading is best-effort: when an image or audio file cannot be read, a
    zero tensor of the expected shape is substituted so that training can
    continue. A warning is emitted for each failure so that systematic
    problems (wrong paths, missing codecs) do not go unnoticed.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: BertTokenizer, label_encoder: LabelEncoder,
                 img_transform, audio_pad=AUDIO_MAX_PAD):
        # Reset the index so that positional ``idx`` values map directly to rows.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.le = label_encoder  # stored for callers; not used inside this class
        self.img_transform = img_transform
        self.audio_pad = audio_pad

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with keys image, input_ids, attention_mask, audio, label."""
        import warnings  # local import keeps this block self-contained

        row = self.df.loc[idx]
        img_path = row['image_path']
        txt = str(row['text'])
        audio_path = row['audio_path']
        label = int(row['label'])

        # --- image -----------------------------------------------------
        try:
            # The context manager releases the file handle deterministically.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            img_t = self.img_transform(img)
        except Exception as exc:
            warnings.warn(f"Image load failed for {img_path!r} ({exc}); using zeros.")
            img_t = torch.zeros(3, IMG_SIZE, IMG_SIZE)

        # --- text ------------------------------------------------------
        enc = self.tokenizer(txt, padding='max_length', truncation=True,
                             max_length=MAX_LEN, return_tensors='pt')
        # The tokenizer returns a batch of one; drop that leading dimension.
        input_ids = enc['input_ids'].squeeze(0)
        attention_mask = enc['attention_mask'].squeeze(0)

        # --- audio -----------------------------------------------------
        try:
            mfcc = extract_mfcc(audio_path, max_pad_len=self.audio_pad)
            mfcc_t = torch.tensor(mfcc, dtype=torch.float32).unsqueeze(0)
        except Exception as exc:
            warnings.warn(f"Audio load failed for {audio_path!r} ({exc}); using zeros.")
            # 40 matches the default n_mfcc of extract_mfcc.
            mfcc_t = torch.zeros(1, 40, self.audio_pad)

        return {
            "image": img_t,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "audio": mfcc_t,
            "label": torch.tensor(label, dtype=torch.long)
        }

# -----------------------------
# PLOTTING HELPERS
# -----------------------------
def plot_accuracy(train_accs, val_accs):
    """Plot per-epoch training and validation accuracy on a single figure."""
    plt.figure(figsize=(7, 4))
    curves = ((train_accs, 'Train Acc'), (val_accs, 'Val Acc'))
    for values, legend_label in curves:
        # Epochs are numbered from 1 on the x-axis.
        epochs = range(1, len(values) + 1)
        plt.plot(epochs, values, marker='o', label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title("Train vs Val Acc")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_confusion(cm, labels):
    """Render confusion matrix ``cm`` as an annotated heat map with ``labels`` ticks."""
    fig, ax = plt.subplots(figsize=(7, 7))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(heatmap, ax=ax)

    tick_positions = np.arange(len(labels))
    ax.set(
        xticks=tick_positions,
        yticks=tick_positions,
        xticklabels=labels,
        yticklabels=labels,
        xlabel='Predicted',
        ylabel='True',
        title='Confusion Matrix',
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Cells darker than half the maximum count get white text for contrast.
    cutoff = cm.max() / 2.
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        count = cm[row, col]
        text_color = 'white' if count > cutoff else 'black'
        ax.text(col, row, format(count, 'd'), ha='center', va='center', color=text_color)

    plt.tight_layout()
    plt.show()

def plot_multiclass_roc(y_true, y_score, class_names):
    """Plot one-vs-rest ROC curves per class plus the micro-averaged curve.

    Args:
        y_true: integer class ids, one per sample.
        y_score: array of per-class scores with one column per class.
        class_names: display names; class ``i`` is assumed to be id ``i``.

    Any failure is reported on stdout instead of being raised, so that a
    plotting problem never aborts the surrounding training run.
    """
    try:
        n_classes = len(class_names)
        y_score = np.asarray(y_score)
        y_true_bin = label_binarize(y_true, classes=list(range(n_classes)))
        # label_binarize returns a single column for two classes (1 == second
        # class). Expand it to one column per class so that every column lines
        # up with the matching column of y_score.
        if n_classes == 2 and y_true_bin.shape[1] == 1:
            y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

        plt.figure(figsize=(8,6))
        for i in range(y_true_bin.shape[1]):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC={roc_auc:.2f})")
        # Micro-average: treat every (sample, class) pair as one binary decision.
        fpr, tpr, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
        plt.plot(fpr, tpr, label=f"micro (AUC={auc(fpr,tpr):.2f})", linestyle='--')
        # Chance-level diagonal for reference.
        plt.plot([0,1],[0,1],'k--', linewidth=0.6)
        plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title("Multi-class ROC")
        plt.legend(fontsize='small'); plt.grid(True); plt.tight_layout(); plt.show()
    except Exception as e:
        print("ROC plotting failed:", e)

# -----------------------------
# EARLY FUSION HEAD
# -----------------------------
class EarlyFusion(nn.Module):
    """Fusion head that combines image, text and audio embeddings into logits.

    Each embedding is optionally projected from ``embed_dim`` to ``d_model``
    and scaled by a learned, softmax-normalised modality weight. With
    ``use_transformer`` the three vectors are treated as a 3-token sequence,
    passed through a Transformer encoder and mean-pooled; otherwise they are
    concatenated. A small MLP then produces ``num_classes`` logits.
    """

    def __init__(self, num_classes, embed_dim=512, d_model=512, use_transformer=True, nhead=8, n_layers=1, dropout=0.2):
        super().__init__()
        self.use_transformer = use_transformer

        def input_projection():
            # A projection is only needed when the widths differ.
            if embed_dim != d_model:
                return nn.Linear(embed_dim, d_model)
            return nn.Identity()

        self.proj_img = input_projection()
        self.proj_txt = input_projection()
        self.proj_aud = input_projection()

        if use_transformer:
            hidden = d_model // 1
            layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 2,
                dropout=dropout,
                batch_first=True,
            )
            self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
            self.classifier = nn.Sequential(
                nn.Linear(d_model, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(d_model * 3, d_model),
                nn.ReLU(),
                nn.BatchNorm1d(d_model),
                nn.Dropout(dropout),
                nn.Linear(d_model, num_classes),
            )

        # One learnable weight per modality, normalised with softmax in forward().
        self.modality_scale = nn.Parameter(torch.ones(3))
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every linear layer's weight and zero its bias."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, e_img, e_txt, e_aud):
        """Return class logits for a batch of per-modality embeddings."""
        weights = torch.softmax(self.modality_scale, dim=0)
        img_tok = self.proj_img(e_img) * weights[0]
        txt_tok = self.proj_txt(e_txt) * weights[1]
        aud_tok = self.proj_aud(e_aud) * weights[2]

        if not self.use_transformer:
            fused = torch.cat([img_tok, txt_tok, aud_tok], dim=1)
            return self.classifier(fused)

        # Treat the modalities as a length-3 token sequence and mean-pool it.
        sequence = torch.stack([img_tok, txt_tok, aud_tok], dim=1)
        pooled = self.transformer(sequence).mean(dim=1)
        return self.classifier(pooled)

# -----------------------------
# MODEL LOADING UTIL
# -----------------------------
def _unwrap_state_dict(checkpoint):
    """Return the state dict nested under a well-known wrapper key, if any."""
    if isinstance(checkpoint, dict):
        for key in ('model', 'state_dict', 'model_state', 'model_state_dict'):
            inner = checkpoint.get(key)
            if isinstance(inner, dict):
                return inner
    return checkpoint


def _load_partial_checkpoint(model, path, device, drop_key, ok_msg, fail_msg):
    """Best-effort, non-strict load of the checkpoint at ``path`` into ``model``.

    ``drop_key`` is an optional predicate; checkpoint entries whose key it
    accepts are discarded before loading. Load errors are reported, not raised.
    """
    if not os.path.exists(path):
        print(f"[WARN] Checkpoint not found, keeping initial weights: {path}")
        return
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location=device)
    try:
        if isinstance(state, dict):
            state = _unwrap_state_dict(state)
            if drop_key is not None:
                state = {k: v for k, v in state.items() if not drop_key(k)}
        result = model.load_state_dict(state, strict=False)
        print(ok_msg)
        # strict=False ignores key mismatches silently, so report how much was
        # actually restored.
        total = len(model.state_dict())
        matched = total - len(result.missing_keys)
        print(f"[INFO]   matched {matched}/{total} tensors "
              f"({len(result.unexpected_keys)} checkpoint keys unused).")
        if matched == 0:
            print("[WARN]   no checkpoint tensor matched the model; it keeps its initial weights.")
    except Exception as e:
        print(fail_msg, e)


def load_unimodal_models(num_classes, image_model_path, text_model_path, audio_model_path, device, embed_dim=512):
    """Build the three unimodal encoders and restore their checkpoints if present.

    Args:
        num_classes: accepted for interface compatibility; not used here.
        image_model_path / text_model_path / audio_model_path: checkpoint files.
            A missing file leaves the corresponding encoder at its initial weights.
        device: device used for ``map_location`` and for the returned models.
        embed_dim: output width of every encoder.

    Returns:
        ``(img_model, txt_model, aud_model)``, each moved to ``device`` and put
        in eval mode.
    """
    def is_image_head_key(k):
        # Drop the classification-head entries of the image checkpoint.
        return k.startswith(("fc.", "classifier.")) or ('fc' in k and k.endswith('weight'))

    def is_text_head_key(k):
        return k.startswith("classifier.")

    img_model = InceptionResNetV3_Embed(embed_dim=embed_dim, pretrained=False)
    _load_partial_checkpoint(
        img_model, image_model_path, device, is_image_head_key,
        "[INFO] Loaded image checkpoint (partial load allowed).",
        "[WARN] Failed to load image checkpoint cleanly:",
    )
    img_model.to(device).eval()

    txt_model = BERTEncoder(BERT_MODEL_NAME, embed_dim=embed_dim)
    _load_partial_checkpoint(
        txt_model, text_model_path, device, is_text_head_key,
        "[INFO] Loaded text checkpoint (partial load).",
        "[WARN] Failed to load text checkpoint:",
    )
    txt_model.to(device).eval()

    aud_model = TransformerLSTM_Embed(embed_dim=embed_dim)
    _load_partial_checkpoint(
        aud_model, audio_model_path, device, None,
        "[INFO] Loaded audio checkpoint (partial load).",
        "[WARN] Failed to load audio checkpoint:",
    )
    aud_model.to(device).eval()

    return img_model, txt_model, aud_model

def set_requires_grad(model, requires_grad: bool):
    """Enable or disable gradient tracking on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = requires_grad

# -----------------------------
# EVAL / TRAIN LOOPS (early fusion)
# -----------------------------
def evaluate_full_early(model_components, fusion_head, loader, device, class_names):
    """Evaluate the fused model on ``loader`` and print its accuracy.

    Returns:
        ``(acc, cm, y_true, y_score)``: accuracy in percent, the confusion
        matrix over ``range(len(class_names))``, the true labels and the
        softmax probabilities for every sample.
    """
    img_model, txt_model, aud_model = model_components
    for net in (fusion_head, img_model, txt_model, aud_model):
        net.eval()

    truth_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            token_ids = batch['input_ids'].to(device)
            token_mask = batch['attention_mask'].to(device)
            mfcc = batch['audio'].to(device)
            targets = batch['label'].to(device)

            logits = fusion_head(
                img_model(images),
                txt_model(token_ids, token_mask),
                aud_model(mfcc),
            )
            class_probs = F.softmax(logits, dim=1)

            truth_chunks.append(targets.cpu().numpy())
            pred_chunks.append(class_probs.argmax(dim=1).cpu().numpy())
            prob_chunks.append(class_probs.cpu().numpy())

    y_true = np.concatenate(truth_chunks)
    y_pred = np.concatenate(pred_chunks)
    y_score = np.vstack(prob_chunks)

    acc = accuracy_score(y_true, y_pred) * 100.0
    print(f"[EVAL] Accuracy: {acc:.2f}%")

    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    return acc, cm, y_true, y_score

def validate_epoch_early(model_components, fusion_head, val_loader, device, num_classes):
    """Run one validation pass and return ``(acc, cm, y_true, None)``.

    ``acc`` is the accuracy in percent and ``cm`` the confusion matrix over
    ``range(num_classes)``. The last element is always ``None``: no class
    probabilities are collected here.
    """
    img_model, txt_model, aud_model = model_components
    for net in (fusion_head, img_model, txt_model, aud_model):
        net.eval()

    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            token_ids = batch['input_ids'].to(device)
            token_mask = batch['attention_mask'].to(device)
            mfcc = batch['audio'].to(device)
            targets = batch['label'].to(device)

            logits = fusion_head(
                img_model(images),
                txt_model(token_ids, token_mask),
                aud_model(mfcc),
            )
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(label_chunks)
    acc = (y_pred == y_true).mean() * 100.0
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    return acc, cm, y_true, None

def print_classification_metrics(y_true, y_pred, class_names):
    """Print the per-class classification report and weighted precision/recall/F1."""
    print("\n========== Classification Report ==========\n")
    report = classification_report(y_true, y_pred, target_names=class_names, digits=4)
    print(report)

    weighted = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    w_precision, w_recall, w_f1 = weighted[0], weighted[1], weighted[2]
    print(f"Weighted Precision: {w_precision:.4f}")
    print(f"Weighted Recall:    {w_recall:.4f}")
    print(f"Weighted F1-score:  {w_f1:.4f}")
    print("===========================================\n")

# -----------------------------
# 3D PSO VISUALIZATION
# -----------------------------
def visualize_pso_3d(particles, best_particle=None, metric_names=("Accuracy", "Precision", "Recall")):
    """Scatter-plot particle metrics in 3D, optionally highlighting the best one.

    Args:
        particles: sequence of 3-metric points. A nested per-iteration structure
            (iterations x particles x metrics), such as the list returned by
            ``run_simple_pso_evaluate``, is flattened so every particle of every
            iteration becomes one point. Empty input draws empty axes.
        best_particle: optional 3-metric point drawn as a large red marker.
        metric_names: axis labels for the x, y and z axes.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - registers the 3D projection
    import numpy as np

    points = np.asarray(particles, dtype=float)
    if points.size == 0:
        # Nothing to plot; keep a well-formed (0, 3) array so indexing works.
        points = points.reshape(0, 3)
    elif points.ndim > 2:
        # Collapse leading axes (e.g. iterations) so each row is one particle;
        # otherwise the column indexing below would select particles, not metrics.
        points = points.reshape(-1, points.shape[-1])

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Normal particle points
    ax.scatter(
        points[:, 0], points[:, 1], points[:, 2],
        s=65, alpha=0.7
    )

    # Highlight best performer
    if best_particle is not None:
        best_particle = np.array(best_particle)
        ax.scatter(
            best_particle[0], best_particle[1], best_particle[2],
            s=250, color='red', edgecolor='black', marker='o', label="Best Particle"
        )
        ax.legend()

    ax.set_xlabel(metric_names[0])
    ax.set_ylabel(metric_names[1])
    ax.set_zlabel(metric_names[2])
    ax.set_title("3D PSO Per-Iteration Optimization")

    plt.show()

# -----------------------------
# SIMPLE PSO-LIKE SEARCH (illustrative)
# -----------------------------
def run_simple_pso_evaluate(model_components, train_loader, val_loader, class_names,
                            swarm_size=6, iters=3, device=DEVICE, quick_batches=8):
    """Run a small PSO-style search over fusion-head hyperparameters.

    Each particle is a ``[lr, dropout, d_ratio]`` triple. For every particle a
    fresh ``EarlyFusion`` head is trained for ``quick_batches`` mini-batches on
    top of the (eval-mode, no-grad) unimodal encoders and scored on
    ``val_loader``. Particle positions are then moved with the usual
    inertia / cognitive / social update. A 3D scatter of the per-particle
    metrics is shown after every iteration.

    Args:
        model_components: ``(img_model, txt_model, aud_model)`` encoders.
        train_loader: loader supplying the quick training batches.
        val_loader: loader used to score each particle.
        class_names: class labels; their count sets the number of outputs.
        swarm_size: number of particles.
        iters: number of PSO iterations.
        device: device on which the fusion heads and batches are placed.
        quick_batches: training steps performed per particle.

    Returns:
        A list with one entry per iteration; each entry is a list of
        ``[val_acc, precision*100, recall*100]`` per particle.
    """

    print(f"[PSO] Starting simple PSO-like search: swarm={swarm_size}, iters={iters}")

    img_model, txt_model, aud_model = model_components
    all_iterations_metrics = []  

    # Random initial swarm: lr is sampled log-uniformly in [1e-5, 1e-2],
    # dropout uniformly in [0, 0.5] and d_ratio uniformly in [0.5, 1.0].
    particles = []
    velocities = []
    for _ in range(swarm_size):
        lr = 10 ** np.random.uniform(-5, -2)
        dropout = np.random.uniform(0.0, 0.5)
        d_ratio = np.random.uniform(0.5, 1.0)
        particles.append([lr, dropout, d_ratio])
        velocities.append([0.0, 0.0, 0.0])

    # Personal bests start at the initial positions; scores start below any
    # achievable accuracy so the first evaluation always replaces them.
    pbest = particles.copy()
    pbest_scores = [-1.0] * swarm_size
    gbest = None
    gbest_score = -1.0

    for it in range(iters):
        print(f"[PSO] Iteration {it+1}/{iters}")

        iter_metrics = []

        # NEW — track best in this iteration
        best_particle_metric = None
        best_particle_acc = -1  

        for i, p in enumerate(particles):
            lr, dropout, d_ratio = p
            # Scale the head width by d_ratio and round down to a multiple of 4
            # (the head below is built with nhead=4).
            d_model = int(512 * float(d_ratio))
            d_model = max(4, int(d_model // 4) * 4)

            fusion_head = EarlyFusion(num_classes=len(class_names), embed_dim=512, d_model=d_model,
                                      use_transformer=True, nhead=4, n_layers=1, dropout=dropout).to(device)

            opt = torch.optim.AdamW(fusion_head.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
            criterion = nn.CrossEntropyLoss()

            fusion_head.train()
            img_model.eval(); txt_model.eval(); aud_model.eval()

            # Take quick_batches steps, restarting the loader if it runs out.
            batch_iter = iter(train_loader)
            for b_idx in range(quick_batches):
                try:
                    batch = next(batch_iter)
                except StopIteration:
                    batch_iter = iter(train_loader)
                    batch = next(batch_iter)

                img = batch['image'].to(device)
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                audio = batch['audio'].to(device)
                labels = batch['label'].to(device)

                opt.zero_grad()
                # Encoders only provide features here; no gradients are tracked for them.
                with torch.no_grad():
                    e_img = img_model(img)
                    e_txt = txt_model(input_ids, attention_mask)
                    e_aud = aud_model(audio)

                out = fusion_head(e_img, e_txt, e_aud)
                loss = criterion(out, labels)
                loss.backward()
                opt.step()

            val_acc, _, y_true, y_score = evaluate_full_early(
                (img_model, txt_model, aud_model), 
                fusion_head, val_loader, device, class_names
            )

            y_pred = np.argmax(y_score, axis=1)
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average='macro', zero_division=0
            )

            print(f"[PSO] Particle {i} -> Acc: {val_acc:.2f} Prec: {precision:.4f} Rec: {recall:.4f}")

            # All three metrics are stored on a 0-100 scale for plotting.
            metric_vec = [val_acc, precision*100, recall*100]
            iter_metrics.append(metric_vec)

            # update best for THIS iteration
            if val_acc > best_particle_acc:
                best_particle_acc = val_acc
                best_particle_metric = metric_vec

            # update personal/global bests
            if val_acc > pbest_scores[i]:
                pbest_scores[i] = val_acc
                pbest[i] = p

            if val_acc > gbest_score:
                gbest_score = val_acc
                gbest = p

        # save iteration metrics
        all_iterations_metrics.append(iter_metrics)

        # visualize this iteration (with red best dot)
        print(f"[PSO] Visualizing iteration {it+1}")
        visualize_pso_3d(
            particles=iter_metrics,
            best_particle=best_particle_metric,
            metric_names=("Val Acc (%)", "Precision (%)", "Recall (%)")
        )

        # update particles
        for i in range(swarm_size):
            inertia = 0.5
            cognitive = 0.8
            social = 0.9
            r1 = np.random.rand(3)
            r2 = np.random.rand(3)
            v = np.array(velocities[i])
            pb = np.array(pbest[i])
            gb = np.array(gbest)
            pos = np.array(particles[i])
            # Velocity update: pull towards the personal and the global best.
            v = inertia * v + cognitive * r1 * (pb - pos) + social * r2 * (gb - pos)

            # Clamp each coordinate to its allowed range after the move.
            new_pos = pos + v
            particles[i] = [
                float(np.clip(new_pos[0], 1e-6, 1e-1)),
                float(np.clip(new_pos[1], 0.0, 0.8)),
                float(np.clip(new_pos[2], 0.3, 1.2))
            ]
            velocities[i] = v.tolist()

    print(f"[PSO] Done. Best acc found: {gbest_score:.2f} (particle hyperparams {gbest})")

    return all_iterations_metrics

# -----------------------------
# MAIN TRAINING WITH OPTIONAL PSO
# -----------------------------
def train_early_fusion(csv_path=CSV_PATH, image_model_path=IMAGE_MODEL_PATH,
                 text_model_path=TEXT_MODEL_PATH, audio_model_path=AUDIO_MODEL_PATH,
                 freeze_bases=FREEZE_BASE_MODELS, device=DEVICE,
                 run_pso=False):
    """Train the early-fusion head (and optionally the encoders) end to end.

    Steps: read the paired-sample CSV, split it 80/20 (stratified by label),
    build the datasets and loaders, load the unimodal encoders, train for
    ``EPOCHS`` epochs while checkpointing the best validation accuracy to
    ``BEST_FUSION_PATH``, then restore the best weights and produce the final
    report and plots.

    Args:
        csv_path: CSV with columns ``image_path``, ``text``, ``audio_path``, ``label``.
        image_model_path / text_model_path / audio_model_path: encoder checkpoints.
        freeze_bases: when True only the fusion head is optimised.
        device: device used for models and batches.
        run_pso: when True, run the illustrative PSO search after training.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
        RuntimeError: if the CSV lacks a required column.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV mapping not found: {csv_path}")
    df = pd.read_csv(csv_path)
    required_cols = {'image_path', 'text', 'audio_path', 'label'}
    if not required_cols.issubset(set(df.columns)):
        raise RuntimeError(f"CSV must have columns: {required_cols}")

    # String labels are encoded to integer ids; numeric labels are used as-is.
    if df['label'].dtype == object:
        le = LabelEncoder()
        df['label'] = le.fit_transform(df['label'].astype(str))
        class_names = list(le.classes_)
    else:
        le = None
        class_names = sorted(df['label'].unique().tolist())
        class_names = [str(int(x)) for x in class_names]
    num_classes = int(df['label'].nunique())
    print(f"Found {len(df)} paired samples. Classes: {num_classes} -> {class_names}")

    # Stratified 80/20 split over row positions.
    idxs = list(range(len(df)))
    from sklearn.model_selection import train_test_split
    train_idx, val_idx = train_test_split(idxs, test_size=0.2, random_state=SEED, stratify=df['label'])
    # The split indices are positions, so select rows positionally.
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME, use_fast=True)
    img_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])
    train_ds = MultimodalDataset(train_df, tokenizer, LabelEncoder() if le is None else le, img_transform)
    val_ds   = MultimodalDataset(val_df, tokenizer, LabelEncoder() if le is None else le, img_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

    EMBED_DIM = 512
    img_model, txt_model, aud_model = load_unimodal_models(num_classes, image_model_path, text_model_path, audio_model_path, device, embed_dim=EMBED_DIM)
    set_requires_grad(img_model, not (freeze_bases))
    set_requires_grad(txt_model, not (freeze_bases))
    set_requires_grad(aud_model, not (freeze_bases))

    fusion_head = EarlyFusion(num_classes, embed_dim=EMBED_DIM, d_model=512, use_transformer=True, nhead=8, n_layers=1, dropout=0.2).to(device)
    params = list(fusion_head.parameters())
    if not freeze_bases:
        params += list(img_model.parameters()) + list(txt_model.parameters()) + list(aud_model.parameters())
    optimizer = torch.optim.Adam(params, lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = -1.0
    best_state = None
    train_accs, val_accs = [], []
    for epoch in range(1, EPOCHS+1):
        t0 = time.time()
        fusion_head.train()
        if not freeze_bases:
            img_model.train(); txt_model.train(); aud_model.train()
        running_correct = 0
        running_total = 0
        # Defined up front so an empty loader cannot leave it unbound below.
        train_acc = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
        for batch in pbar:
            img = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)
            optimizer.zero_grad()

            e_img = img_model(img)
            e_txt = txt_model(input_ids, attention_mask)
            e_aud = aud_model(audio)
            out = fusion_head(e_img, e_txt, e_aud)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            # Running (cumulative) training accuracy for the progress bar.
            preds = out.argmax(dim=1)
            running_total += labels.size(0)
            running_correct += (preds == labels).sum().item()
            train_acc = 100.0 * running_correct / running_total
            pbar.set_postfix({"TrainAcc": f"{train_acc:.2f}%", "Loss": f"{loss.item():.4f}"})
        train_accs.append(train_acc)

        val_acc, cm, y_true_val, y_score_val = validate_epoch_early((img_model, txt_model, aud_model), fusion_head, val_loader, device, num_classes)
        val_accs.append(val_acc)
        print(f"Epoch {epoch}/{EPOCHS} - Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {(time.time()-t0):.1f}s")
        scheduler.step()

        # Keep an in-memory and an on-disk copy of the best model so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {
                "epoch": epoch,
                "fusion_state": copy.deepcopy(fusion_head.state_dict()),
                "img_state": copy.deepcopy(img_model.state_dict()),
                "txt_state": copy.deepcopy(txt_model.state_dict()),
                "aud_state": copy.deepcopy(aud_model.state_dict()),
                "val_acc": val_acc,
                "class_names": class_names
            }
            torch.save(best_state, BEST_FUSION_PATH)
            print(f"  Best early-fusion model saved -> {BEST_FUSION_PATH} (Val Acc {val_acc:.2f}%)")

    # final plots & eval
    plot_accuracy(train_accs, val_accs)
    if best_state is not None:
        print(f"Loaded best early fusion (epoch {best_state['epoch']}, val acc {best_state['val_acc']:.2f}%)")
        fusion_head.load_state_dict(best_state['fusion_state'])
        img_model.load_state_dict(best_state['img_state'], strict=False)
        txt_model.load_state_dict(best_state['txt_state'], strict=False)
        aud_model.load_state_dict(best_state['aud_state'], strict=False)
    final_acc, final_cm, y_true, y_score = evaluate_full_early((img_model, txt_model, aud_model), fusion_head, val_loader, device, class_names)
    print_classification_metrics(y_true, np.argmax(y_score, axis=1), class_names)
    plot_confusion(final_cm, class_names)
    try:
        plot_multiclass_roc(y_true, y_score, class_names)
    except Exception as e:
        print("ROC plot error:", e)

    # Optional PSO run (disabled by default)
    if run_pso:
        # Lightweight PSO: small swarm & iterations by default
        PSO_SWARM = 6
        PSO_ITERS = 3
        QUICK_BATCHES = 6   # number of quick mini-batch steps per particle
        particle_metrics = run_simple_pso_evaluate((img_model, txt_model, aud_model),
                                                   train_loader, val_loader, class_names,
                                                   swarm_size=PSO_SWARM, iters=PSO_ITERS, device=device,
                                                   quick_batches=QUICK_BATCHES)
        # The search returns one list of metric triples per iteration; flatten
        # it so the summary plot receives one (acc, precision, recall) point
        # per evaluated particle.
        flat_metrics = [point for iteration in particle_metrics for point in iteration]
        visualize_pso_3d(flat_metrics, metric_names=("Val Acc (%)", "Precision (pct)", "Recall (pct)"))
    print("[DONE]")

# -----------------------------
# Run
# -----------------------------
# Script entry point: train the early-fusion model with the module-level configuration.
if __name__ == "__main__":
    print("Starting Early Fusion Training")
    # run_pso=True also runs the simple PSO search after training (can be slow);
    # pass run_pso=False to skip it.
    train_early_fusion(run_pso=True)


In [None]:

# NOTE: This file has been modified from the original version to improve training and accuracy.
import os
import copy
import random
import time
from typing import List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, models
from PIL import Image

from transformers import BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder, label_binarize
from sklearn.metrics import confusion_matrix, roc_curve, auc, accuracy_score
from sklearn.metrics import classification_report, precision_recall_fscore_support
from tqdm import tqdm
import soundfile as sf
import librosa
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - required for 3D projection


# Input locations: the paired-sample CSV and the pretrained unimodal checkpoints.
# The CSV is expected to provide image_path, text, audio_path and label columns.
CSV_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\meld dataset\MELD.Raw\self\fusion_dataset.csv"
IMAGE_MODEL_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\efficientnetv3_face_emotion_final.pth"
TEXT_MODEL_PATH  = r"K:\Code\Project\Research Paper\Emotion Detection\Code\bert_emotion_text_final.pth"
AUDIO_MODEL_PATH = r"K:\Code\Project\Research Paper\Emotion Detection\Code\best_transformer_speech_model.pth"

# Unimodal model settings (must match your trained models)
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE before the image encoder
BERT_MODEL_NAME = "bert-base-multilingual-cased"
MAX_LEN = 64  # every text is padded/truncated to this many tokens

# Audio MFCC params (tweak to match your audio model)
AUDIO_MAX_PAD = 174     # same as audio model (kept); MFCC frames per clip after pad/crop
AUDIO_N_MFCC = 40       # number of MFCC coefficients per frame
AUDIO_N_FFT = 1024      # FFT window size passed to librosa
AUDIO_HOP = 512         # hop length passed to librosa
AUDIO_SR = 22050        # target sampling rate; audio is resampled to this rate

# Fusion training hyperparams
BATCH_SIZE = 64
EPOCHS = 100
LR_FUSION = 1e-3          # fusion head lr
LR_ENCODER = 2e-5        # small lr for encoders if unfrozen
WEIGHT_DECAY = 1e-5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FREEZE_BASE_MODELS = True  # default for train_early_fusion(freeze_bases=...)
# NOTE(review): freeze_bert_layers also keeps the BERT embeddings trainable, not only
# the last N layers, LayerNorms and pooler - confirm this is intended.
UNFREEZE_BERT_LAST_N = 4   # when freeze_bases=True, unfreeze last N BERT encoder layers + LayerNorms/pooler
GRAD_CLIP_NORM = 1.0       # presumably the max gradient norm for clipping - usage is outside this section
USE_AMP = True             # enables autocast (mixed precision) in validate_epoch_early

# Presumably where the best fusion checkpoint is written - usage is outside this section.
BEST_FUSION_PATH = "./best_early_fusion_model_fixed_preload.pth"

# Seed every RNG used by the pipeline (Python, NumPy, PyTorch CPU and CUDA).
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning trades strict run-to-run reproducibility for speed.
torch.backends.cudnn.benchmark = True


def load_audio(path, sr=AUDIO_SR):
    """Read an audio file, down-mix it to mono and resample it to ``sr``.

    Args:
        path: Location of the audio file, handed to ``soundfile.read``.
        sr: Target sampling rate.

    Returns:
        Tuple ``(samples, sr)``: the mono waveform and the target rate.

    Raises:
        RuntimeError: If the reader hands back ``None`` instead of samples.
    """
    samples, source_rate = sf.read(path)
    if samples is None:
        raise RuntimeError(f"Failed reading audio: {path}")

    # Multi-channel input is averaged across channels to obtain mono.
    is_multichannel = samples.ndim > 1
    waveform = np.mean(samples, axis=1) if is_multichannel else samples

    # Nothing to do when the file already uses the requested rate.
    if source_rate == sr:
        return waveform, sr
    return librosa.resample(waveform, orig_sr=source_rate, target_sr=sr), sr


def extract_mfcc(file_path, sr=AUDIO_SR, n_mfcc=AUDIO_N_MFCC, max_pad_len=AUDIO_MAX_PAD,
                 n_fft=AUDIO_N_FFT, hop_length=AUDIO_HOP):
    """Compute a fixed-size, per-coefficient normalized MFCC matrix for one file.

    Args:
        file_path: Audio file to analyse.
        sr: Sampling rate the audio is resampled to before analysis.
        n_mfcc: Number of MFCC coefficients (rows of the result).
        max_pad_len: Number of frames (columns) of the result.
        n_fft: FFT window size passed to librosa.
        hop_length: Hop length passed to librosa.

    Returns:
        ``float32`` array of shape ``(n_mfcc, max_pad_len)``. If the audio cannot
        be loaded, an all-zero array of that shape is returned instead.
    """
    min_samples = 2048

    try:
        waveform, sr = load_audio(file_path, sr)
    except Exception:
        # Best-effort: an unreadable file yields a silent (all-zero) feature map.
        return np.zeros((n_mfcc, max_pad_len), dtype=np.float32)

    # Very short clips are zero-padded up to a minimum number of samples.
    shortfall = min_samples - len(waveform)
    if shortfall > 0:
        waveform = np.pad(waveform, (0, shortfall))

    feats = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)

    # Force exactly max_pad_len frames: right-pad with zeros or crop the tail.
    n_frames = feats.shape[1]
    if n_frames < max_pad_len:
        feats = np.pad(feats, ((0, 0), (0, max_pad_len - n_frames)), mode='constant')
    else:
        feats = feats[:, :max_pad_len]

    # Standardize every coefficient row; near-constant rows keep a unit divisor
    # so the division never blows up.
    row_mean = feats.mean(axis=1, keepdims=True)
    row_std = feats.std(axis=1, keepdims=True)
    row_std[row_std < 1e-6] = 1.0
    normalized = (feats - row_mean) / row_std

    return normalized.astype(np.float32)


class EfficientNetV2_Embed(nn.Module):
    """Image encoder: a torchvision EfficientNet-V2 backbone plus a projection head.

    The backbone's classifier is replaced by ``nn.Identity`` and its features are
    mapped to ``embed_dim`` by ``Linear -> ReLU -> Dropout(0.3)``.
    """

    def __init__(self, embed_dim=512, pretrained=True, version="s"):
        """
        Args:
            embed_dim: Size of the produced embedding.
            pretrained: Load the ``IMAGENET1K_V1`` weights when true, otherwise
                start from random initialization.
            version: Backbone size, one of ``"s"``, ``"m"`` or ``"l"``; any other
                value raises ``KeyError``.
        """
        super().__init__()

        builders = {
            "s": models.efficientnet_v2_s,
            "m": models.efficientnet_v2_m,
            "l": models.efficientnet_v2_l,
        }
        build_backbone = builders[version]
        backbone = build_backbone(weights="IMAGENET1K_V1" if pretrained else None)

        # Remember the classifier's input width before dropping the classifier.
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.base = backbone

        projection_layers = [
            nn.Linear(feature_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        ]
        self.proj = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Return the projected backbone features for the image batch ``x``."""
        return self.proj(self.base(x))


class BERTEncoder(nn.Module):
    """Text encoder: BERT pooled output projected to a fixed-size embedding."""

    def __init__(self, model_name: str, embed_dim: int = 512, dropout: float = 0.2):
        """
        Args:
            model_name: Name or path handed to ``BertModel.from_pretrained``.
            embed_dim: Size of the produced embedding.
            dropout: Dropout applied to the pooled output before the linear layer.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        ]
        self.proj = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Encode a tokenized batch and return its projected pooled embedding."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        return self.proj(bert_out.pooler_output)

class TransformerLSTM_Embed(nn.Module):
    """Audio encoder: MFCC frames -> Transformer encoder -> LSTM -> embedding.

    The embedding is computed from the LSTM output at the last time step.
    """

    def __init__(self, embed_dim=512, d_model=128, nhead=4, num_layers=2):
        """
        Args:
            embed_dim: Size of the produced embedding.
            d_model: Width of the Transformer encoder.
            nhead: Number of attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
        """
        super().__init__()
        lstm_hidden = 128

        self.feature_proj = nn.Linear(AUDIO_N_MFCC, d_model)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=256,
            dropout=0.3,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.lstm = nn.LSTM(d_model, lstm_hidden, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.proj = nn.Sequential(
            nn.Linear(lstm_hidden, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Embed a batch of MFCC maps shaped like the dataset's (B, 1, n_mfcc, T) audio."""
        frames = x.squeeze(1).permute(0, 2, 1)      # (B, T, n_mfcc)
        frames = self.feature_proj(frames)          # (B, T, d_model)
        encoded = self.transformer(frames)          # (B, T, d_model)
        lstm_out, _ = self.lstm(encoded)            # (B, T, hidden)
        last_step = self.dropout(lstm_out[:, -1, :])  # (B, 128)
        return self.proj(last_step)                 # (B, embed_dim)


class MultimodalDataset(Dataset):
    """Paired image / text / audio samples for early-fusion training.

    Every row of ``df`` must provide ``image_path``, ``text``, ``audio_path`` and
    an integer-convertible ``label``. Each item is a dict with the keys
    ``image``, ``input_ids``, ``attention_mask``, ``audio`` and ``label``.

    Loading is best-effort: an unreadable image, a tokenizer failure or an MFCC
    failure is replaced by an all-zero tensor of the expected shape, so a single
    bad file does not abort a run. The preloaded and the lazy path share one
    loader (``_load_sample``) and therefore behave identically.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: BertTokenizer, label_encoder: LabelEncoder,
                 img_transform, audio_pad=AUDIO_MAX_PAD, preload: bool = False, preload_verbose: bool = True):
        """
        Args:
            df: Sample table; its index is reset so rows are addressed 0..len-1.
            tokenizer: Tokenizer applied to the ``text`` column.
            label_encoder: Stored as ``self.le``; not used by this class itself.
            img_transform: Callable turning a PIL RGB image into a tensor.
            audio_pad: Number of MFCC frames per sample.
            preload: Materialize every sample in RAM at construction time.
            preload_verbose: Show a progress bar and a summary while preloading.
        """
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.le = label_encoder
        self.img_transform = img_transform
        self.audio_pad = audio_pad
        self.preload = preload
        self.cache = None
        if self.preload:
            self._preload_to_ram(verbose=preload_verbose)

    def __len__(self):
        return len(self.df)

    def _load_sample(self, idx):
        """Build the tensor dict for row ``idx`` (zeros replace failed modalities)."""
        row = self.df.loc[idx]
        img_path = row['image_path']
        txt = str(row['text'])
        audio_path = row['audio_path']
        label = int(row['label'])

        # Image: the context manager releases the file handle right after decoding.
        try:
            with Image.open(img_path) as img:
                img_t = self.img_transform(img.convert("RGB"))
        except Exception:
            img_t = torch.zeros(3, IMG_SIZE, IMG_SIZE, dtype=torch.float32)

        # Text: fixed-length token ids and attention mask.
        try:
            enc = self.tokenizer(txt, padding='max_length', truncation=True,
                                 max_length=MAX_LEN, return_tensors='pt')
            input_ids = enc['input_ids'].squeeze(0)
            attention_mask = enc['attention_mask'].squeeze(0)
        except Exception:
            input_ids = torch.zeros(MAX_LEN, dtype=torch.long)
            attention_mask = torch.zeros(MAX_LEN, dtype=torch.long)

        # Audio: MFCC map with a leading channel axis.
        try:
            mfcc = extract_mfcc(audio_path, sr=AUDIO_SR, n_mfcc=AUDIO_N_MFCC, max_pad_len=self.audio_pad,
                                n_fft=AUDIO_N_FFT, hop_length=AUDIO_HOP)
            mfcc_t = torch.tensor(mfcc, dtype=torch.float32).unsqueeze(0)  # (1, n_mfcc, T)
        except Exception:
            mfcc_t = torch.zeros(1, AUDIO_N_MFCC, self.audio_pad, dtype=torch.float32)

        return {
            "image": img_t,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "audio": mfcc_t,
            "label": torch.tensor(label, dtype=torch.long)
        }

    def _preload_to_ram(self, verbose: bool = True):
        """Load images (transformed), tokenized text tensors and MFCCs into memory lists."""
        indices = range(len(self.df))
        if verbose:
            indices = tqdm(indices, desc="Preloading dataset into RAM", ncols=80)
        self.cache = [self._load_sample(idx) for idx in indices]
        if verbose:
            print(f"[PRELOAD] Finished preloading {len(self.df)} samples into RAM.")

    def __getitem__(self, idx):
        if self.preload and (self.cache is not None):
            # Hand out clones so downstream in-place edits cannot corrupt the cache.
            return {key: tensor.clone() for key, tensor in self.cache[idx].items()}
        return self._load_sample(idx)


def plot_accuracy(train_accs, val_accs):
    """Plot per-epoch training and validation accuracy on one figure and show it."""
    plt.figure(figsize=(7,4))
    for history, legend_name in ((train_accs, 'Train Acc'), (val_accs, 'Val Acc')):
        epochs = range(1, len(history) + 1)  # epochs are numbered from 1
        plt.plot(epochs, history, marker='o', label=legend_name)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title("Train vs Val Acc")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_confusion(cm, labels):
    """Draw an annotated confusion-matrix heat map (rows: true, columns: predicted)."""
    tick_positions = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(7,7))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(heatmap, ax=ax)
    ax.set(xticks=tick_positions, yticks=np.arange(len(labels)),
           xticklabels=labels, yticklabels=labels, xlabel='Predicted', ylabel='True', title='Confusion Matrix')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Cells darker than half the maximum count get white text for contrast.
    contrast_cutoff = cm.max() / 2.
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        count = cm[row, col]
        text_color = 'white' if count > contrast_cutoff else 'black'
        ax.text(col, row, format(count, 'd'), ha='center', va='center', color=text_color)

    plt.tight_layout()
    plt.show()


def plot_multiclass_roc(y_true, y_score, class_names):
    """Plot one-vs-rest ROC curves per class plus the micro-averaged curve.

    Any failure while computing or drawing the curves is reported on stdout
    instead of being raised.
    """
    try:
        class_ids = list(range(len(class_names)))
        onehot = label_binarize(y_true, classes=class_ids)

        plt.figure(figsize=(8,6))
        for col in range(onehot.shape[1]):
            fpr, tpr, _ = roc_curve(onehot[:, col], y_score[:, col])
            class_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f"{class_names[col]} (AUC={class_auc:.2f})")

        # Micro average: pool every (sample, class) decision into one curve.
        micro_fpr, micro_tpr, _ = roc_curve(onehot.ravel(), y_score.ravel())
        micro_auc = auc(micro_fpr, micro_tpr)
        plt.plot(micro_fpr, micro_tpr, label=f"micro (AUC={micro_auc:.2f})", linestyle='--')

        # Diagonal reference line.
        plt.plot([0,1],[0,1],'k--', linewidth=0.6)
        plt.xlabel("FPR")
        plt.ylabel("TPR")
        plt.title("Multi-class ROC")
        plt.legend(fontsize='small')
        plt.grid(True)
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print("ROC plotting failed:", e)


class EarlyFusion(nn.Module):
    """Fusion head combining image, text and audio embeddings into class logits.

    Each embedding is optionally projected to ``d_model`` and multiplied by a
    learnable, ReLU-clamped per-modality scale. With ``use_transformer`` the
    three vectors become a 3-token sequence (plus a modality-type embedding and
    a LayerNorm), pass through a Transformer encoder, are mean-pooled and
    classified. Otherwise they are concatenated and classified by an MLP.
    """

    def __init__(self, num_classes, embed_dim=512, d_model=512,
                 use_transformer=True, nhead=8, n_layers=1, dropout=0.2):
        """
        Args:
            num_classes: Number of output logits.
            embed_dim: Size of every incoming modality embedding.
            d_model: Internal width; a linear projection is added per modality
                when it differs from ``embed_dim``.
            use_transformer: Fuse with a Transformer encoder instead of concatenation.
            nhead: Attention heads of the fusion encoder.
            n_layers: Number of fusion encoder layers.
            dropout: Dropout used in the encoder and the classifier.

        Raises:
            ValueError: If ``use_transformer`` is set and ``d_model`` is not
                divisible by ``nhead``.
        """
        super().__init__()
        self.use_transformer = use_transformer

        # One projection per modality, registered in image/text/audio order.
        needs_projection = embed_dim != d_model
        for attr_name in ("proj_img", "proj_txt", "proj_aud"):
            layer = nn.Linear(embed_dim, d_model) if needs_projection else nn.Identity()
            setattr(self, attr_name, layer)

        self.modality_type_embed = nn.Embedding(3, d_model)
        self.pre_ln = nn.LayerNorm(d_model)

        if not use_transformer:
            head_layers = [
                nn.Linear(d_model * 3, d_model),
                nn.ReLU(),
                nn.LayerNorm(d_model),
                nn.Dropout(dropout),
                nn.Linear(d_model, num_classes),
            ]
        else:
            if d_model % nhead != 0:
                raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")

            layer_template = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 2,
                dropout=dropout,
                batch_first=True,
            )
            self.transformer = nn.TransformerEncoder(layer_template, num_layers=n_layers)
            head_layers = [
                nn.Linear(d_model, d_model),
                nn.LayerNorm(d_model),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model, num_classes),
            ]
        self.classifier = nn.Sequential(*head_layers)

        # Learnable importance weight per modality (image, text, audio).
        self.modality_scale = nn.Parameter(torch.ones(3))

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every linear layer and zero its bias."""
        linear_layers = [mod for mod in self.modules() if isinstance(mod, nn.Linear)]
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, e_img, e_txt, e_aud):
        """Return class logits of shape (B, num_classes) for the three embedding batches."""
        projected = (self.proj_img(e_img), self.proj_txt(e_txt), self.proj_aud(e_aud))

        # ReLU keeps the modality weights non-negative.
        gates = torch.relu(self.modality_scale)
        feat_img, feat_txt, feat_aud = (feat * gates[i] for i, feat in enumerate(projected))

        if not self.use_transformer:
            return self.classifier(torch.cat([feat_img, feat_txt, feat_aud], dim=1))

        tokens = torch.stack([feat_img, feat_txt, feat_aud], dim=1)  # (B, 3, d_model)

        # Tag every token with the id of the modality it came from.
        batch_size = tokens.size(0)
        type_ids = torch.arange(3, device=tokens.device).repeat(batch_size, 1)
        tokens = tokens + self.modality_type_embed(type_ids)

        encoded = self.transformer(self.pre_ln(tokens))
        fused = encoded.mean(dim=1)  # mean-pool over 3 modalities
        return self.classifier(fused)


# Keys under which training scripts commonly nest the real state dict.
_CHECKPOINT_WRAPPER_KEYS = ('model', 'state_dict', 'model_state', 'model_state_dict')


def _unwrap_state_dict(checkpoint):
    """Return the nested state dict if ``checkpoint`` wraps one, else ``checkpoint`` itself."""
    if isinstance(checkpoint, dict):
        for key in _CHECKPOINT_WRAPPER_KEYS:
            inner = checkpoint.get(key)
            if isinstance(inner, dict):
                return inner
    return checkpoint


def _is_image_head_key(key):
    """True for image-checkpoint entries belonging to the old classification head."""
    return key.startswith("fc.") or key.startswith("classifier.") or ('fc' in key and key.endswith('weight'))


def _is_text_head_key(key):
    """True for text-checkpoint entries belonging to the old classification head."""
    return key.startswith("classifier.")


def _load_partial_checkpoint(model, path, device, ok_msg, fail_msg, drop_key=None):
    """Best-effort, non-strict load of the checkpoint at ``path`` into ``model``.

    A missing file or a failing load leaves ``model`` usable with its current
    weights; both cases are reported instead of raised.
    """
    if not os.path.exists(path):
        print(f"[WARN] Checkpoint not found, keeping initial weights: {path}")
        return
    checkpoint = torch.load(path, map_location=device)
    try:
        state = _unwrap_state_dict(checkpoint)
        if drop_key is not None and isinstance(state, dict):
            state = {k: v for k, v in state.items() if not drop_key(k)}
        result = model.load_state_dict(state, strict=False)
        print(ok_msg)
    except Exception as e:
        print(fail_msg, e)
        return

    # strict=False never complains about name mismatches, so report how much
    # of the checkpoint actually landed in the model.
    n_missing = len(result.missing_keys)
    n_unexpected = len(result.unexpected_keys)
    n_matched = len(state) - n_unexpected
    print(f"[INFO]   matched={n_matched} missing={n_missing} unexpected={n_unexpected}")
    if n_matched == 0:
        print("[WARN]   No checkpoint tensors matched the model; its weights are unchanged.")


def load_unimodal_models(num_classes, image_model_path, text_model_path, audio_model_path, device, embed_dim=512):
    """Build the image, text and audio encoders and load their checkpoints.

    Checkpoints are loaded non-strictly: entries of the former classification
    heads are dropped for the image and text models, and tensors that do not
    match the encoder are ignored. A summary of matched / missing / unexpected
    tensors is printed per checkpoint.

    Args:
        num_classes: Accepted for interface compatibility; not used here.
        image_model_path: Checkpoint for the EfficientNet-V2 image encoder.
        text_model_path: Checkpoint for the BERT text encoder.
        audio_model_path: Checkpoint for the Transformer+LSTM audio encoder.
        device: Device the checkpoints are mapped to and the models moved to.
        embed_dim: Embedding size produced by every encoder.

    Returns:
        Tuple ``(img_model, txt_model, aud_model)``, all on ``device``.
    """
    img_model = EfficientNetV2_Embed(embed_dim=embed_dim, pretrained=True, version="s")
    _load_partial_checkpoint(
        img_model, image_model_path, device,
        ok_msg="[INFO] Loaded image checkpoint (partial load allowed).",
        fail_msg="[WARN] Failed to load image checkpoint cleanly:",
        drop_key=_is_image_head_key,
    )
    img_model.to(device)

    txt_model = BERTEncoder(BERT_MODEL_NAME, embed_dim=embed_dim)
    _load_partial_checkpoint(
        txt_model, text_model_path, device,
        ok_msg="[INFO] Loaded text checkpoint (partial load).",
        fail_msg="[WARN] Failed to load text checkpoint:",
        drop_key=_is_text_head_key,
    )
    txt_model.to(device)

    aud_model = TransformerLSTM_Embed(embed_dim=embed_dim)
    _load_partial_checkpoint(
        aud_model, audio_model_path, device,
        ok_msg="[INFO] Loaded audio checkpoint (partial load).",
        fail_msg="[WARN] Failed to load audio checkpoint:",
    )
    aud_model.to(device)

    return img_model, txt_model, aud_model


def set_requires_grad(model, requires_grad: bool):
    """Enable or disable gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(requires_grad)

def unfreeze_efficientnet(img_model, depth=2):
    """Freeze the image backbone, then re-enable its deepest candidate blocks.

    The candidates are ``img_model.base.features[2]`` .. ``features[5]``; the
    last ``depth`` of them become trainable together with ``img_model.proj``.
    Blocks outside that index range always stay frozen.

    Args:
        img_model: Model exposing ``base`` (with an indexable ``features``) and ``proj``.
        depth: How many candidate blocks to unfreeze. Clamped to ``0..4``, so
            ``depth=0`` trains the projection only.
    """
    for p in img_model.base.parameters():
        p.requires_grad = False

    features = img_model.base.features
    candidates = [features[i] for i in (2, 3, 4, 5)]

    # Slice from an explicit start index: ``candidates[-depth:]`` would select
    # *every* block for depth == 0 because ``-0`` is ``0``.
    depth = max(0, min(depth, len(candidates)))
    for block in candidates[len(candidates) - depth:]:
        for p in block.parameters():
            p.requires_grad = True

    # The projection head is always trained.
    for p in img_model.proj.parameters():
        p.requires_grad = True

    print(f"[INFO] Unfroze last {depth} EfficientNetV2 stages + projection layer.")

def _encoder_layer_index(param_name):
    """Return i for names of the form ``encoder.layer.<i>.…``, else ``None``."""
    prefix = "encoder.layer."
    if not param_name.startswith(prefix):
        return None
    index_text, dot, _ = param_name[len(prefix):].partition(".")
    if dot and index_text.isdigit():
        return int(index_text)
    return None


def freeze_bert_layers(bert_model: "BertModel", unfreeze_last_n: int = 4):
    """Freeze a BERT model except for a small trainable subset.

    After the call a parameter is trainable iff it belongs to the embeddings,
    the pooler, any LayerNorm, or one of the last ``unfreeze_last_n`` encoder
    layers; every other parameter is frozen.

    Args:
        bert_model: Model whose parameter names follow the ``embeddings.*`` /
            ``encoder.layer.<i>.*`` / ``pooler.*`` layout.
        unfreeze_last_n: Number of trailing encoder layers to keep trainable.
    """
    named_params = list(bert_model.named_parameters())
    layer_indices = [_encoder_layer_index(name) for name, _ in named_params]

    # Prefer the configured layer count; if the model carries no usable config,
    # infer it from the parameter names instead of failing.
    config = getattr(bert_model, "config", None)
    total = getattr(config, "num_hidden_layers", None)
    if total is None:
        seen = [i for i in layer_indices if i is not None]
        total = max(seen) + 1 if seen else 0
    first_trainable = total - unfreeze_last_n

    for (name, p), layer_idx in zip(named_params, layer_indices):
        always_trainable = (
            name.startswith("embeddings.")
            or name.startswith("pooler.")
            or "LayerNorm" in name
            or "layer_norm" in name
        )
        in_top_layers = layer_idx is not None and first_trainable <= layer_idx < total
        p.requires_grad = always_trainable or in_top_layers


def evaluate_full_early(model_components, fusion_head, loader, device, class_names):
    """Run the full pipeline over ``loader`` and collect evaluation outputs.

    Args:
        model_components: Tuple ``(img_model, txt_model, aud_model)``.
        fusion_head: Fusion module mapping the three embeddings to logits.
        loader: Iterable of batches with ``image``, ``input_ids``,
            ``attention_mask``, ``audio`` and ``label`` tensors.
        device: Device the batches are moved to.
        class_names: Class labels; only their count is used (confusion-matrix size).

    Returns:
        Tuple ``(acc, cm, y_true, y_score)``: accuracy in percent, the confusion
        matrix, the true labels and the per-class softmax probabilities.
    """
    img_model, txt_model, aud_model = model_components
    fusion_head.eval()
    for encoder in (img_model, txt_model, aud_model):
        encoder.eval()

    true_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            token_ids = batch['input_ids'].to(device)
            token_mask = batch['attention_mask'].to(device)
            mfccs = batch['audio'].to(device)
            targets = batch['label'].to(device)

            logits = fusion_head(img_model(images), txt_model(token_ids, token_mask), aud_model(mfccs))
            probabilities = F.softmax(logits, dim=1)
            predictions = probabilities.argmax(dim=1)

            true_chunks.append(targets.cpu().numpy())
            pred_chunks.append(predictions.cpu().numpy())
            prob_chunks.append(probabilities.cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    y_score = np.vstack(prob_chunks)

    acc = accuracy_score(y_true, y_pred) * 100.0
    print(f"[EVAL] Accuracy: {acc:.2f}%")
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    return acc, cm, y_true, y_score


def validate_epoch_early(model_components, fusion_head, val_loader, device, num_classes):
    """Compute validation accuracy and the confusion matrix for one epoch.

    Mixed precision follows the module-level ``USE_AMP`` switch but is only
    enabled when ``device`` is a CUDA device, so CPU runs do not request CUDA
    autocast.

    Args:
        model_components: Tuple ``(img_model, txt_model, aud_model)``.
        fusion_head: Fusion module mapping the three embeddings to logits.
        val_loader: Iterable of batches with ``image``, ``input_ids``,
            ``attention_mask``, ``audio`` and ``label`` tensors.
        device: Device the batches are moved to.
        num_classes: Number of classes (size of the confusion matrix).

    Returns:
        Tuple ``(acc, cm, y_true, None)``: accuracy in percent, the confusion
        matrix and the true labels. The fourth slot is always ``None`` because
        no probabilities are collected here.
    """
    img_model, txt_model, aud_model = model_components
    fusion_head.eval()
    img_model.eval(); txt_model.eval(); aud_model.eval()

    preds_list = []
    labels_list = []

    amp_enabled = bool(USE_AMP) and torch.device(device).type == 'cuda'

    with torch.no_grad():
        for batch in val_loader:
            img = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            with torch.amp.autocast('cuda', enabled=amp_enabled):
                e_img = img_model(img)
                e_txt = txt_model(input_ids, attention_mask)
                e_aud = aud_model(audio)
                out = fusion_head(e_img, e_txt, e_aud)

            # Keep results on the device; one transfer happens after the loop.
            preds_list.append(out.argmax(dim=1))
            labels_list.append(labels)

    y_pred = torch.cat(preds_list).cpu().numpy()
    y_true = torch.cat(labels_list).cpu().numpy()

    acc = (y_pred == y_true).mean() * 100.0
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    return acc, cm, y_true, None



def print_classification_metrics(y_true, y_pred, class_names):
    """Print the per-class report followed by weighted precision, recall and F1."""
    print("\n========== Classification Report ==========" )
    report = classification_report(y_true, y_pred, target_names=class_names, digits=4)
    print(report)

    weighted_scores = precision_recall_fscore_support(y_true, y_pred, average='weighted')[:3]
    templates = (
        "Weighted Precision: {:.4f}",
        "Weighted Recall:    {:.4f}",
        "Weighted F1-score:  {:.4f}",
    )
    for template, score in zip(templates, weighted_scores):
        print(template.format(score))
    print("===========================================\n")

def visualize_pso_3d(particles, best_particle=None, metric_names=("Accuracy", "Precision", "Recall")):
    """Scatter-plot PSO particles in a 3D metric space.

    Args:
        particles: Metric vectors with at least three entries each. Accepts a
            flat list of vectors (one iteration) as well as nested per-iteration
            lists such as the return value of ``run_simple_pso_evaluate``;
            nested input is flattened to one point per particle evaluation.
        best_particle: Optional metric vector highlighted in red.
        metric_names: Axis labels for the first three metrics.
    """
    points = np.asarray(particles, dtype=float)
    if points.size == 0:
        # Checked before creating a figure so no empty figure is left behind.
        print("[PSO VIS] no particles to display")
        return

    # Normalize to (n_points, n_metrics): without this, a (iters, swarm, 3)
    # array would be indexed per particle instead of per metric.
    if points.ndim == 1:
        points = points[np.newaxis, :]
    elif points.ndim > 2:
        points = points.reshape(-1, points.shape[-1])

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=65, alpha=0.7)
    if best_particle is not None:
        best_particle = np.array(best_particle)
        ax.scatter(best_particle[0], best_particle[1], best_particle[2],
                   s=250, color='red', edgecolor='black', marker='o', label="Best Particle")
        ax.legend()
    ax.set_xlabel(metric_names[0])
    ax.set_ylabel(metric_names[1])
    ax.set_zlabel(metric_names[2])
    ax.set_title("3D PSO Per-Iteration Optimization")
    plt.show()


def run_simple_pso_evaluate(model_components, train_loader, val_loader, class_names,
                            swarm_size=6, iters=3, device=DEVICE, quick_batches=8):
    """Run a small PSO-like search over fusion-head hyperparameters.

    Each particle is ``[lr, dropout, d_ratio]`` where ``d_ratio`` scales the
    fusion width relative to 512. For every particle a fresh ``EarlyFusion``
    head is trained for ``quick_batches`` mini-batches on top of the given
    encoders and then scored on ``val_loader``. The encoders are used purely as
    feature extractors: they run in eval mode without gradient tracking, and
    only the fusion head is optimized.

    Args:
        model_components: Tuple ``(img_model, txt_model, aud_model)``.
        train_loader: Loader supplying the quick training batches.
        val_loader: Loader used to score every particle.
        class_names: Class labels; their count sets the number of logits.
        swarm_size: Number of particles.
        iters: Number of PSO iterations.
        device: Device used for training and evaluation.
        quick_batches: Mini-batch steps per particle evaluation.

    Returns:
        One list per iteration holding, per particle,
        ``[val_acc, macro_precision * 100, macro_recall * 100]``.
    """
    print(f"[PSO] Starting simple PSO-like search: swarm={swarm_size}, iters={iters}")
    img_model, txt_model, aud_model = model_components
    all_iterations_metrics = []

    # Initial swarm: lr is sampled log-uniformly, dropout and d_ratio uniformly.
    particles = []
    velocities = []
    for _ in range(swarm_size):
        lr = 10 ** np.random.uniform(-5, -2)
        dropout = np.random.uniform(0.0, 0.5)
        d_ratio = np.random.uniform(0.5, 1.0)
        particles.append([lr, dropout, d_ratio])
        velocities.append([0.0, 0.0, 0.0])
    pbest = list(particles)
    pbest_scores = [-1.0] * swarm_size
    gbest = None
    gbest_score = -1.0

    # Velocity-update coefficients, constant for the whole run.
    inertia = 0.5
    cognitive = 0.8
    social = 0.9

    for it in range(iters):
        print(f"[PSO] Iteration {it+1}/{iters}")
        iter_metrics = []
        best_particle_metric = None
        best_particle_acc = -1
        for i, p in enumerate(particles):
            lr, dropout, d_ratio = p
            # Round the width down to a multiple of 4 (minimum 4), so it is
            # always divisible by 4 heads and sometimes by 8.
            d_model = max(4, (int(512 * float(d_ratio)) // 4) * 4)
            nhead = 8 if d_model % 8 == 0 else 4

            fusion_head = EarlyFusion(num_classes=len(class_names), embed_dim=512, d_model=d_model,
                                      use_transformer=True, nhead=nhead, n_layers=1, dropout=dropout).to(device)
            opt = torch.optim.AdamW(fusion_head.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
            criterion = nn.CrossEntropyLoss()
            fusion_head.train()
            img_model.eval(); txt_model.eval(); aud_model.eval()
            batch_iter = iter(train_loader)

            for b_idx in range(quick_batches):
                try:
                    batch = next(batch_iter)
                except StopIteration:
                    # Loader exhausted: restart it so every particle gets its steps.
                    batch_iter = iter(train_loader)
                    batch = next(batch_iter)
                img = batch['image'].to(device)
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                audio = batch['audio'].to(device)
                labels = batch['label'].to(device)
                opt.zero_grad()

                # The optimizer only owns the fusion head, so encoder gradients
                # would be wasted work and would pile up unused in ``.grad``.
                with torch.no_grad():
                    e_img = img_model(img)
                    e_txt = txt_model(input_ids, attention_mask)
                    e_aud = aud_model(audio)
                out = fusion_head(e_img, e_txt, e_aud)
                loss = criterion(out, labels)
                loss.backward()
                opt.step()

            val_acc, _, y_true, y_score = evaluate_full_early(
                (img_model, txt_model, aud_model),
                fusion_head, val_loader, device, class_names
            )
            y_pred = np.argmax(y_score, axis=1)
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average='macro', zero_division=0
            )
            print(f"[PSO] Particle {i} -> Acc: {val_acc:.2f} Prec: {precision:.4f} Rec: {recall:.4f}")
            metric_vec = [val_acc, precision*100, recall*100]
            iter_metrics.append(metric_vec)

            # Track the best of this iteration, per particle and overall.
            if val_acc > best_particle_acc:
                best_particle_acc = val_acc
                best_particle_metric = metric_vec
            if val_acc > pbest_scores[i]:
                pbest_scores[i] = val_acc
                pbest[i] = p
            if val_acc > gbest_score:
                gbest_score = val_acc
                gbest = p
        all_iterations_metrics.append(iter_metrics)
        print(f"[PSO] Visualizing iteration {it+1}")
        visualize_pso_3d(particles=iter_metrics, best_particle=best_particle_metric,
                         metric_names=("Val Acc (%)", "Precision (%)", "Recall (%)"))

        # Standard PSO move: inertia + pull towards personal and global bests.
        for i in range(swarm_size):
            r1 = np.random.rand(3)
            r2 = np.random.rand(3)
            v = np.array(velocities[i])
            pb = np.array(pbest[i])
            gb = np.array(gbest) if gbest is not None else np.array(particles[i])
            pos = np.array(particles[i])
            v = inertia * v + cognitive * r1 * (pb - pos) + social * r2 * (gb - pos)
            new_pos = pos + v
            # Keep every hyperparameter inside its allowed range.
            particles[i] = [
                float(np.clip(new_pos[0], 1e-6, 1e-1)),
                float(np.clip(new_pos[1], 0.0, 0.8)),
                float(np.clip(new_pos[2], 0.3, 1.2))
            ]
            velocities[i] = v.tolist()
    print(f"[PSO] Done. Best acc found: {gbest_score:.2f} (particle hyperparams {gbest})")
    return all_iterations_metrics

def train_early_fusion(csv_path=CSV_PATH, image_model_path=IMAGE_MODEL_PATH,
                 text_model_path=TEXT_MODEL_PATH, audio_model_path=AUDIO_MODEL_PATH,
                 freeze_bases=FREEZE_BASE_MODELS, device=DEVICE,
                 run_pso=False, preload_dataset: bool = True):
    """Train the early-fusion classifier on paired image/text/audio samples.

    Reads the CSV mapping, makes a stratified 80/20 train/validation split,
    loads the three unimodal encoders, trains them together with an
    ``EarlyFusion`` head, checkpoints the best validation epoch to
    ``BEST_FUSION_PATH``, then reports metrics and plots on the validation set.

    Args:
        csv_path: CSV file with columns ``image_path``, ``text``,
            ``audio_path`` and ``label``.
        image_model_path: Checkpoint path passed to ``load_unimodal_models``.
        text_model_path: Checkpoint path passed to ``load_unimodal_models``.
        audio_model_path: Checkpoint path passed to ``load_unimodal_models``.
        freeze_bases: When true, freeze the encoders and re-enable only the
            last two image stages, the last audio transformer layer + LSTM,
            the image/audio projections and the last
            ``UNFREEZE_BERT_LAST_N`` BERT layers. When false, every encoder
            parameter is trainable.
        device: ``torch.device`` on which the models and batches are placed.
        run_pso: When true, run the quick PSO hyper-parameter sweep after
            training and plot its particle metrics.
        preload_dataset: Forwarded to ``MultimodalDataset`` as ``preload``.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        RuntimeError: If the CSV lacks one of the required columns.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV mapping not found: {csv_path}")
    df = pd.read_csv(csv_path)
    required_cols = {'image_path', 'text', 'audio_path', 'label'}
    if not required_cols.issubset(set(df.columns)):
        raise RuntimeError(f"CSV must have columns: {required_cols}")

    # Label encoding: string labels are mapped to integer ids, numeric labels
    # are kept as-is and only stringified for display.
    if df['label'].dtype == object:
        le = LabelEncoder()
        df['label'] = le.fit_transform(df['label'].astype(str))
        class_names = list(le.classes_)
    else:
        le = None
        class_names = [str(int(x)) for x in sorted(df['label'].unique().tolist())]
    num_classes = int(df['label'].nunique())
    print(f"Found {len(df)} paired samples. Classes: {num_classes} -> {class_names}")

    from sklearn.model_selection import train_test_split
    idxs = list(range(len(df)))
    train_idx, val_idx = train_test_split(idxs, test_size=0.2, random_state=SEED, stratify=df['label'])
    # The split yields positions, so select by position rather than by index label.
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME, use_fast=True)
    img_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])

    train_ds = MultimodalDataset(train_df, tokenizer, le, img_transform, preload=preload_dataset)
    val_ds   = MultimodalDataset(val_df, tokenizer, le, img_transform, preload=preload_dataset)

    pin_mem = device.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_mem)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_mem)

    EMBED_DIM = 512
    img_model, txt_model, aud_model = load_unimodal_models(num_classes, image_model_path, text_model_path, audio_model_path, device, embed_dim=EMBED_DIM)

    # Honour the caller's `freeze_bases` argument (its default is the global flag).
    if freeze_bases:
        # Start from fully frozen encoders, then re-enable selected parts.
        set_requires_grad(img_model, False)
        set_requires_grad(txt_model, False)
        set_requires_grad(aud_model, False)

        # Image encoder: unfreeze the last two stages (features[4] and features[5]).
        for stage in (img_model.base.features[4], img_model.base.features[5]):
            for p in stage.parameters():
                p.requires_grad = True

        # projection head always trainable
        set_requires_grad(img_model.proj, True)

        print("[INFO] Unfroze EfficientNetV2 last 2 stages + projection.")

        for name, p in aud_model.named_parameters():
            if "transformer.layers.1" in name or "lstm" in name:
                p.requires_grad = True

        # audio projection always trainable
        set_requires_grad(aud_model.proj, True)

        print("[INFO] Unfroze audio last transformer layer + LSTM + projection.")

        freeze_bert_layers(txt_model.bert, UNFREEZE_BERT_LAST_N)
        print(f"[INFO] Unfroze last {UNFREEZE_BERT_LAST_N} BERT layers.")
    else:
        # No freezing at all
        set_requires_grad(img_model, True)
        set_requires_grad(txt_model, True)
        set_requires_grad(aud_model, True)
        print("[INFO] All encoders fully trainable.")

    fusion_head = EarlyFusion(num_classes, embed_dim=EMBED_DIM, d_model=512, use_transformer=True, nhead=8, n_layers=1, dropout=0.2).to(device)

    def _trainable(module):
        """Return the parameters of `module` that will receive gradients."""
        return [p for p in module.parameters() if p.requires_grad]

    # One parameter group per component, each with its own LR / weight decay.
    # Only trainable parameters are registered; frozen ones never get a gradient.
    optimizer = torch.optim.AdamW([
        {"params": _trainable(img_model),   "lr": 1e-4, "weight_decay": 0.01},  # image encoder (partial unfreeze)
        {"params": _trainable(txt_model),   "lr": 6e-5, "weight_decay": 0.01},  # BERT (last few layers unfrozen)
        {"params": _trainable(aud_model),   "lr": 1e-4, "weight_decay": 0.0},   # audio model
        {"params": _trainable(fusion_head), "lr": 5e-4, "weight_decay": 0.0},   # fusion layers
    ])
    # Parameters handed to gradient clipping; fixed for the whole run, so build once.
    clip_params = [p for g in optimizer.param_groups for p in g['params']]

    # Linear warm-up over the first 10% of optimizer steps, then linear decay to 0.
    # The schedule is expressed in batches, so it must be stepped once per batch.
    total_steps = int(EPOCHS * len(train_loader))
    warmup_steps = max(1, int(0.1 * total_steps))
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    criterion = nn.CrossEntropyLoss()
    best_val_acc = -1.0
    best_state = None
    train_accs, val_accs = [], []

    # With enabled=False the scaler is a pass-through (scale/unscale_/update are
    # no-ops and step() just calls optimizer.step()), so one code path serves both modes.
    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

    for epoch in range(1, EPOCHS+1):
        t0 = time.time()
        fusion_head.train()
        img_model.train(); txt_model.train(); aud_model.train()
        running_correct = 0
        running_total = 0
        train_acc = 0.0  # stays defined even if the loader yields no batches
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
        for batch in pbar:
            img = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            with torch.amp.autocast('cuda',enabled=USE_AMP):
                e_img = img_model(img)
                e_txt = txt_model(input_ids, attention_mask)
                e_aud = aud_model(audio)
                out = fusion_head(e_img, e_txt, e_aud)
                loss = criterion(out, labels)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(clip_params, GRAD_CLIP_NORM)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # A reduced scale means the scaler skipped optimizer.step() because of
            # inf/NaN gradients; do not advance the LR schedule for that batch.
            if scaler.get_scale() >= scale_before:
                scheduler.step()

            preds = out.argmax(dim=1)
            running_total += labels.size(0)
            running_correct += (preds == labels).sum().item()
            train_acc = 100.0 * running_correct / running_total
            pbar.set_postfix({"TrainAcc": f"{train_acc:.2f}%", "Loss": f"{loss.item():.4f}"})

        train_accs.append(train_acc)
        val_acc, _, _, _ = validate_epoch_early((img_model, txt_model, aud_model), fusion_head, val_loader, device, num_classes)
        val_accs.append(val_acc)
        print(f"Epoch {epoch}/{EPOCHS} - Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {(time.time()-t0):.1f}s")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot all four modules so the best epoch can be restored later.
            best_state = {
                "epoch": epoch,
                "fusion_state": copy.deepcopy(fusion_head.state_dict()),
                "img_state": copy.deepcopy(img_model.state_dict()),
                "txt_state": copy.deepcopy(txt_model.state_dict()),
                "aud_state": copy.deepcopy(aud_model.state_dict()),
                "val_acc": val_acc,
                "class_names": class_names
            }
            torch.save(best_state, BEST_FUSION_PATH)
            print(f" ✅ Best early-fusion model saved -> {BEST_FUSION_PATH} (Val Acc {val_acc:.2f}%)")

    plot_accuracy(train_accs, val_accs)
    if best_state is not None:
        # Restore the best-validation weights before the final evaluation.
        print(f"Loaded best early fusion (epoch {best_state['epoch']}, val acc {best_state['val_acc']:.2f}%)")
        fusion_head.load_state_dict(best_state['fusion_state'])
        img_model.load_state_dict(best_state['img_state'], strict=False)
        txt_model.load_state_dict(best_state['txt_state'], strict=False)
        aud_model.load_state_dict(best_state['aud_state'], strict=False)

    final_acc, final_cm, y_true, y_score = evaluate_full_early((img_model, txt_model, aud_model), fusion_head, val_loader, device, class_names)
    print_classification_metrics(y_true, np.argmax(y_score, axis=1), class_names)
    plot_confusion(final_cm, class_names)
    try:
        plot_multiclass_roc(y_true, y_score, class_names)
    except Exception as e:
        # Best-effort plot: a ROC failure must not abort the rest of the run.
        print("ROC plot error:", e)

    if run_pso:
        PSO_SWARM = 6
        PSO_ITERS = 3
        QUICK_BATCHES = 6
        particle_metrics = run_simple_pso_evaluate((img_model, txt_model, aud_model),
                                                   train_loader, val_loader, class_names,
                                                   swarm_size=PSO_SWARM, iters=PSO_ITERS, device=device,
                                                   quick_batches=QUICK_BATCHES)

        # Flatten the per-iteration lists into one list of particle metric vectors.
        flattened = [m for iter_batch in particle_metrics for m in iter_batch]
        visualize_pso_3d(flattened, metric_names=("Val Acc (%)", "Precision (pct)", "Recall (pct)"))

    print("[DONE]")


if __name__ == "__main__":
    # Script entry point: announce the run, then train with the dataset
    # preloaded and the PSO sweep enabled.
    print("Starting Early Fusion Training (fixed + preload)")
    train_early_fusion(preload_dataset=True, run_pso=True)




Starting Early Fusion Training (fixed + preload)
Found 13674 paired samples. Classes: 7 -> ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise']


Preloading dataset into RAM: 100%|████████| 10939/10939 [01:50<00:00, 98.85it/s]


[PRELOAD] Finished preloading 10939 samples into RAM.


Preloading dataset into RAM: 100%|██████████| 2735/2735 [00:29<00:00, 94.13it/s]


[PRELOAD] Finished preloading 2735 samples into RAM.
[INFO] Loaded image checkpoint (partial load allowed).


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


[INFO] Loaded text checkpoint (partial load).
[INFO] Loaded audio checkpoint (partial load).
[INFO] Unfroze EfficientNetV2 last 2 stages + projection.
[INFO] Unfroze audio last transformer layer + LSTM + projection.
[INFO] Unfroze last 4 BERT layers.


Epoch 1/100:   1%|          | 2/171 [00:06<07:18,  2.60s/it, TrainAcc=11.72%, Loss=2.4840]