In [None]:
!pip install pydub
from google.colab import drive
drive.mount("/content/drive")
!pip install torch torchaudio transformers librosa scikit-learn
# Filesystem locations used by this notebook (Colab paths; the Drive paths
# require the drive.mount() call above to have succeeded).
# NOTE(review): the descriptions below are inferred from the names only -
# confirm against the preprocessing notebook.
# PHQ score table (presumably one row per participant).
PHQ_CSV_PATH = "/content/combined_PHQ_sorted.csv"
# Directory presumably holding the per-participant audio segments.
AUDIO_DIR    = "/content/drive/MyDrive/DAIC-WOZ/processed_patient_segments"
# Directory presumably holding the matching visual-feature segments.
VISUAL_DIR   = "/content/drive/MyDrive/DAIC-WOZ/processed_visual_segments"
# Destination of the generated dataset index CSV.
OUTPUT_CSV   = "/content/drive/MyDrive/DAIC-WOZ/dataset_info_all.csv"

# 1. First, the participant interview segments were split at random into training and test sets. Because the split is made per segment rather than per participant, data from the same participant can appear in both sets. This data leakage produces a suspiciously high accuracy of 99% as the epochs progress.

In [None]:
!pip install pydub

from google.colab import drive
drive.mount("/content/drive")

!pip install torch torchaudio transformers librosa scikit-learn

import os
import gc
import random
import pandas as pd
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2Config, Wav2Vec2Model, Wav2Vec2Processor
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import itertools

# -----------------------------------------------------
# Set random seed for reproducibility
# -----------------------------------------------------
SEED = 103


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA generators are only seeded when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)

# -----------------------------------------------------
# 1) AudioDataset (Returns participant_id as well)
# -----------------------------------------------------
class AudioDataset(Dataset):
    """Segment-level audio dataset driven by a metadata DataFrame.

    Every row must provide the columns ``audio_path``, ``label`` and
    ``participant_id``. An item is the tuple
    ``(idx, waveform, sample_rate, label, participant_id)``; when the audio
    cannot be loaded the item is ``None`` so that the collate function can
    drop it.
    """

    def __init__(self, df: pd.DataFrame, target_sr=16000, verbose=False):
        # A fresh 0..N-1 index keeps positional lookups aligned with `idx`.
        self.df = df.reset_index(drop=True)
        self.target_sr = target_sr
        self.verbose = verbose

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        audio_path = record["audio_path"]
        label = float(record["label"])  # 0 or 1
        participant_id = record["participant_id"]

        if self.verbose:
            print(f"[Dataset] idx={idx}, audio_path={audio_path}, label={label}, participant_id={participant_id}")

        try:
            waveform, sr = torchaudio.load(audio_path)

            # Multi-channel input: keep the first channel only, shape (1, T).
            if waveform.shape[0] > 1:
                waveform = waveform[:1, :]

            # Bring every segment to the common sampling rate.
            if sr != self.target_sr:
                waveform = torchaudio.functional.resample(waveform, sr, self.target_sr)
                sr = self.target_sr
        except Exception as e:
            # Best effort: report the broken sample and let the caller skip it.
            print(f"[WARN] Skipped sample idx={idx}, audio='{audio_path}' due to error: {e}")
            return None

        return (idx, waveform, sr, label, participant_id)

# -----------------------------------------------------
# 2) collate_fn: skip None
# -----------------------------------------------------
def collate_fn_audio(batch):
    """Collate dataset items into padded tensors, dropping failed samples.

    Args:
        batch: list of ``(idx, waveform, sr, label, participant_id)`` tuples,
            where ``None`` entries mark samples that could not be loaded.

    Returns:
        ``None`` when no valid sample is left, otherwise the tuple
        ``(sample_indices, padded_wav_3d, sr_tensor, labels_tensor,
        participant_ids)`` where ``padded_wav_3d`` has shape (B, 1, T_max)
        and shorter waveforms are zero-padded on the right.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None  # the training / evaluation loops skip such batches

    sample_indices = [sample[0] for sample in valid]
    participant_ids = [sample[4] for sample in valid]

    # Each waveform is (1, T_i): drop the channel axis, pad to (B, T_max),
    # then restore the channel axis to obtain (B, 1, T_max).
    flat_waves = [sample[1].squeeze(0) for sample in valid]
    padded_wav_3d = pad_sequence(
        flat_waves, batch_first=True, padding_value=0.0
    ).unsqueeze(1)

    sr_tensor = torch.stack([torch.tensor(sample[2]) for sample in valid], dim=0)
    labels_tensor = torch.stack([torch.tensor(sample[3]) for sample in valid], dim=0)

    return sample_indices, padded_wav_3d, sr_tensor, labels_tensor, participant_ids

# -----------------------------------------------------
# 3) Model Definition (Wav2Vec2 + mean pooling + classifier)
# -----------------------------------------------------
class Wav2Vec2AudioEncoder(nn.Module):
    """Pretrained Wav2Vec2 backbone that returns its last hidden states.

    Args:
        model_name: checkpoint identifier passed to ``from_pretrained``.
        output_hidden_states: value written to the backbone config's
            ``output_hidden_states`` field.
        freeze_feature_extractor: when True, gradients are disabled for every
            parameter of the backbone's ``feature_extractor`` sub-module.
    """

    def __init__(self,
                 model_name="facebook/wav2vec2-large-960h",
                 output_hidden_states=False,
                 freeze_feature_extractor=True):
        super().__init__()
        self.config = Wav2Vec2Config.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        # Exposed so downstream layers can size their input projection.
        self.hidden_size = self.config.hidden_size
        self.model.config.output_hidden_states = output_hidden_states
        if freeze_feature_extractor:
            self.model.feature_extractor.requires_grad_(False)

    def forward(self, input_values):
        """Return the backbone's ``last_hidden_state``, (B, T', hidden_size)."""
        return self.model(input_values=input_values).last_hidden_state

class AudioOnlyModel(nn.Module):
    """
    Audio-only classifier:
      1) Wav2Vec2 produces frame-level features
      2) the features are linearly projected, mean-pooled over time and
         layer-normalised
      3) a small MLP maps the pooled vector to class logits
    """
    def __init__(
        self,
        audio_model_name="facebook/wav2vec2-large-960h",
        freeze_feature_extractor=True,
        cross_hidden_dim=384,
        num_classes=2
    ):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.audio_encoder = Wav2Vec2AudioEncoder(
            model_name=audio_model_name,
            freeze_feature_extractor=freeze_feature_extractor
        )
        encoder_dim = self.audio_encoder.hidden_size  # 1024 for wav2vec2-large-960h

        # Projection into the working dimension, normalised after pooling.
        self.proj_audio = nn.Linear(encoder_dim, cross_hidden_dim)
        self.norm_audio = nn.LayerNorm(cross_hidden_dim)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(cross_hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
        """Map a batch of wav2vec2 input values to (B, num_classes) logits."""
        frame_features = self.audio_encoder(input_values)  # (B, T_a, hidden_size)
        projected = self.proj_audio(frame_features)        # (B, T_a, cross_hidden_dim)
        # The mean runs over every time step; no mask is applied, so frames
        # coming from padded input positions are averaged in as well.
        pooled = self.norm_audio(projected.mean(dim=1))    # (B, cross_hidden_dim)
        return self.classifier(pooled)

# -----------------------------------------------------
# 4) wave_to_wav2vec2_input
# -----------------------------------------------------
# Shared processor instance used by wave_to_wav2vec2_input below.
processor = Wav2Vec2Processor.from_pretrained(
    "facebook/wav2vec2-large-960h",
    return_attention_mask=False
)


def wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000):
    """
    Turn a padded waveform batch into wav2vec2 ``input_values``.

    wave_batch : (B,1,T) waveform tensor
    sr_batch   : (B,) sampling rate of each waveform
    target_sr  : rate handed to the processor; a waveform whose rate differs
                 is resampled first

    Returns a (B, T_wav2vec) tensor, zero-padded to the longest item.

    NOTE(review): each waveform is passed to the processor together with any
    zero padding already present in ``wave_batch``; if the processor
    normalises its input, the padding takes part in that - confirm this is
    intended.
    """
    encode_kwargs = dict(
        sampling_rate=target_sr,
        return_tensors="pt",
        padding=True,
        truncation=False,
        return_attention_mask=False,
    )

    encoded = []
    for i in range(wave_batch.size(0)):
        samples = wave_batch[i, 0, :].cpu().numpy()
        rate = sr_batch[i].item()
        if rate != target_sr:
            samples = torchaudio.functional.resample(
                torch.from_numpy(samples), rate, target_sr
            ).numpy()
        encoded.append(processor(samples, **encode_kwargs).input_values[0])

    # Pad in case the processed lengths differ.
    return pad_sequence(encoded, batch_first=True, padding_value=0.0)  # (B, T_wav2vec)

# -----------------------------------------------------
# 5) Segment-level Evaluation
# -----------------------------------------------------
def evaluate_one_epoch(
    model, data_loader, criterion, device, print_confusion=True
):
    """
    Evaluate at the *segment level*.

    Args:
        model: network mapping wav2vec2 input values to (B, num_classes) logits.
        data_loader: yields 5-tuples ``(sample_indices, wave_batch, sr_batch,
            lbl_batch, participant_ids)`` or ``None`` (skipped) per batch.
        criterion: loss, called as ``criterion(logits, labels.long())``.
        device: device on which the forward pass is run.
        print_confusion: when True, print the classification report and the
            confusion matrix. The flag controls printing only; the recall is
            computed either way.

    Returns:
        ``(avg_loss, accuracy, recall_1, all_preds, all_labels)`` where
        ``avg_loss`` and ``accuracy`` are per-segment averages, ``recall_1``
        is the recall of label 1 ("depressed") and the two lists hold the
        integer predictions / labels of every evaluated segment. All metrics
        are 0.0 when no segment was evaluated.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    correct = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            if batch is None:  # every sample of this batch failed to load
                continue

            _, wave_batch, sr_batch, lbl_batch, _ = batch

            wave_batch = wave_batch.to(device)
            sr_batch = sr_batch.to(device)
            targets = lbl_batch.to(device).long()

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000).to(device)
            logits = model(inputs)

            loss_val = criterion(logits, targets)

            # Weight the batch mean by the batch size so that the final
            # average is per segment even when the last batch is smaller.
            bs = wave_batch.size(0)
            total_loss += loss_val.item() * bs
            total_samples += bs

            preds = torch.argmax(logits, dim=1)
            correct += (preds == targets).sum().item()

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(targets.cpu().tolist())

    if total_samples == 0:
        return 0.0, 0.0, 0.0, all_preds, all_labels

    avg_loss = total_loss / total_samples
    accuracy = correct / total_samples

    # Recall of the positive class is derived from the confusion matrix so it
    # is available even when nothing is printed (previously it was silently
    # left at 0.0 whenever print_confusion was False).
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    n_positive = int(cm[1].sum())
    recall_1 = float(cm[1, 1]) / n_positive if n_positive > 0 else 0.0

    if print_confusion:
        target_names = ["not depressed", "depressed"]
        # labels=[0, 1] keeps the report well-defined when only one class
        # occurs in the evaluated data; zero_division=0 silences the
        # undefined-metric case instead of warning.
        cr_str = classification_report(
            all_labels, all_preds, labels=[0, 1],
            target_names=target_names, digits=6, zero_division=0
        )
        print("=== Classification Report (segment-level) ===")
        print(cr_str)
        print("Confusion Matrix ( [0,1] ) (segment-level):\n", cm)

    return avg_loss, accuracy, recall_1, all_preds, all_labels

# -----------------------------------------------------
# 6) Participant-level Evaluation
# -----------------------------------------------------
def evaluate_one_epoch_per_participant(
    model, data_loader, criterion, device, print_confusion=True
):
    """
    Evaluate at the *participant level* by aggregating all segment logits
    belonging to the same participant.

    Steps:
      1) For each segment in the dataloader, compute logits (like normal).
      2) Group the logits by participant_id.
      3) Average the logits of each participant, take the argmax as the
         participant prediction and compare it to the participant label.

    Note: all segments of one participant are expected to share the same
    ground-truth label; the label of the first segment seen for a participant
    is used.

    Args:
        model: network mapping wav2vec2 input values to (B, num_classes) logits.
        data_loader: yields 5-tuples ``(sample_indices, wave_batch, sr_batch,
            lbl_batch, participant_ids)`` or ``None`` (skipped) per batch.
        criterion: loss, called as ``criterion(logits, labels.long())``.
        device: device on which the forward pass is run.
        print_confusion: when True, print the classification report and the
            confusion matrix. The flag controls printing only; the recall is
            computed either way.

    Returns:
        ``(avg_loss_seg, participant_acc, recall_1)``: the segment-level
        average loss (reference only), the participant-level accuracy and the
        participant-level recall of label 1 ("depressed"). All values are 0.0
        when no segment was evaluated.
    """
    model.eval()
    total_loss = 0.0
    total_segments = 0
    # participant_id -> {"label": first label seen, "logits_list": [...]}
    participant_dict = {}

    with torch.no_grad():
        for batch in data_loader:
            if batch is None:
                continue

            _, wave_batch, sr_batch, lbl_batch, pid_list = batch

            wave_batch = wave_batch.to(device)
            sr_batch = sr_batch.to(device)
            targets = lbl_batch.to(device).long()

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000).to(device)
            logits = model(inputs)  # (B, num_classes)

            # Segment-level cross-entropy, accumulated for information only.
            bs = wave_batch.size(0)
            total_segments += bs
            total_loss += criterion(logits, targets).item() * bs

            logits_np = logits.cpu().numpy()
            labels_list = targets.cpu().tolist()
            for pid, label_i, logit_i in zip(pid_list, labels_list, logits_np):
                entry = participant_dict.setdefault(
                    pid, {"logits_list": [], "label": label_i}
                )
                entry["logits_list"].append(logit_i)

    if not participant_dict:
        return 0.0, 0.0, 0.0

    all_participant_preds = []
    all_participant_labels = []
    for entry in participant_dict.values():
        mean_logits = np.mean(entry["logits_list"], axis=0)  # (num_classes,)
        all_participant_preds.append(int(np.argmax(mean_logits)))
        all_participant_labels.append(entry["label"])

    avg_loss_seg = total_loss / total_segments  # segment-level average
    n_participants = len(all_participant_labels)
    correct_p = sum(
        pred == label
        for pred, label in zip(all_participant_preds, all_participant_labels)
    )
    participant_acc = correct_p / n_participants

    # Recall of the positive class is derived from the confusion matrix so it
    # is available even when nothing is printed (previously it was silently
    # left at 0.0 whenever print_confusion was False).
    cm = confusion_matrix(all_participant_labels, all_participant_preds, labels=[0, 1])
    n_positive = int(cm[1].sum())
    recall_1 = float(cm[1, 1]) / n_positive if n_positive > 0 else 0.0

    if print_confusion:
        target_names = ["not depressed", "depressed"]
        # labels=[0, 1] keeps the report well-defined when the participants
        # of this split all belong to (or are all predicted as) one class.
        cr_str = classification_report(
            all_participant_labels, all_participant_preds, labels=[0, 1],
            target_names=target_names, digits=6, zero_division=0
        )
        print("=== Classification Report (participant-level) ===")
        print(cr_str)
        print("Confusion Matrix ( [0,1] ) (participant-level):\n", cm)

    return avg_loss_seg, participant_acc, recall_1

# -----------------------------------------------------
# 7) Main
# -----------------------------------------------------
if __name__ == "__main__":
    # === Settings ===
    CSV_PATH = "/content/drive/MyDrive/DAIC-WOZ/dataset_info_all_text.csv"
    #
    # IMPORTANT:
    # Make sure your CSV has columns:
    #  - "audio_path"
    #  - "label" (0 or 1)
    #  - "participant_id" (string or int).
    #
    # Example minimal CSV structure:
    # participant_id, audio_path, label
    # P001, /path/to/P001_seg1.wav, 0
    # P001, /path/to/P001_seg2.wav, 0
    # P002, /path/to/P002_seg1.wav, 1
    # ... etc.

    # Small physical batches; gradients are accumulated over
    # `accumulation_steps` batches before each optimizer update.
    BATCH_SIZE = 2
    accumulation_steps = 16

    # Number of epochs
    EPOCHS = 20

    LEARNING_RATE = 1e-5

    # Class weights [1.0, 2.5] (example): errors on label 1 weigh more.
    class_weights = torch.FloatTensor([1.0, 2.5])

    # Load CSV
    df_all = pd.read_csv(CSV_PATH)
    print(f"Loaded CSV total rows: {len(df_all)}")

    # -------------------------------
    # train/dev/test split (6:2:2)
    # -------------------------------
    # WARNING: this split is made over CSV rows (segments), not participants,
    # so rows sharing a participant_id can end up in several splits.
    df_train, df_temp = train_test_split(df_all, test_size=0.4, random_state=SEED, shuffle=True)
    df_dev, df_test = train_test_split(df_temp, test_size=0.5, random_state=SEED, shuffle=True)
    print(f"Train: {len(df_train)}, Dev: {len(df_dev)}, Test: {len(df_test)}")

    # Create Dataset
    train_dataset = AudioDataset(df_train, target_sr=16000, verbose=False)
    dev_dataset   = AudioDataset(df_dev,   target_sr=16000, verbose=False)
    test_dataset  = AudioDataset(df_test,  target_sr=16000, verbose=False)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn_audio)
    dev_loader   = DataLoader(dev_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_audio)
    test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_audio)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}")

    # Model
    model = AudioOnlyModel(
        audio_model_name="facebook/wav2vec2-large-960h",
        freeze_feature_extractor=True,
        cross_hidden_dim=384,
        num_classes=2
    ).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    global_step = 0

    results = []

    for epoch in range(EPOCHS):
        model.train()
        total_loss_train = 0.0
        total_samples_train = 0

        optimizer.zero_grad()
        # Number of batches whose gradients are accumulated but not yet
        # applied. Only batches that were actually processed are counted, so
        # skipped (None) batches do not shorten an accumulation group.
        pending_batches = 0

        # -------------------------
        # Training loop (segment-level)
        # -------------------------
        for batch in train_loader:
            if batch is None:
                continue

            sample_indices, wave_batch, sr_batch, lbl_batch, pid_list = batch

            wave_batch = wave_batch.to(device)
            sr_batch = sr_batch.to(device)
            lbl_batch = lbl_batch.to(device)

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, 16000).to(device)
            logits = model(inputs)
            loss = criterion(logits, lbl_batch.long())

            # Unscaled batch loss, used for both the epoch average and the log.
            batch_loss = loss.item()
            bs = wave_batch.size(0)
            total_loss_train += batch_loss * bs
            total_samples_train += bs

            # Gradient accumulation: scale so that the summed gradients of a
            # full group correspond to the mean loss over the group.
            (loss / accumulation_steps).backward()
            pending_batches += 1

            if pending_batches == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending_batches = 0

            global_step += 1
            # Simple log every 50 steps
            if (global_step % 50) == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"[Train] epoch={epoch+1}, global_step={global_step}, loss={batch_loss:.4f}, lr={current_lr}")

        # Apply the gradients of a trailing, incomplete accumulation group;
        # otherwise they would be discarded by the zero_grad() that opens the
        # next epoch and those batches would never contribute to training.
        if pending_batches > 0:
            optimizer.step()
            optimizer.zero_grad()

        # Average training loss per epoch
        if total_samples_train > 0:
            epoch_train_loss = total_loss_train / total_samples_train
        else:
            epoch_train_loss = 0.0
        print(f"[Epoch {epoch+1}] train loss = {epoch_train_loss:.4f}")

        # -------------------------
        # Dev Evaluation (segment-level)
        # -------------------------
        print(f"=== [Epoch {epoch+1}] Dev Evaluation (Segment-Level) ===")
        dev_loss_eval, dev_acc_eval, dev_recall_1, _, _ = evaluate_one_epoch(
            model, dev_loader, criterion, device, print_confusion=False
        )
        print(f"Dev (segment-level): loss={dev_loss_eval:.4f}, acc={dev_acc_eval:.4f}, recall(depressed)={dev_recall_1:.4f}")

        # Dev Evaluation (participant-level)
        print(f"=== [Epoch {epoch+1}] Dev Evaluation (Participant-Level) ===")
        dev_loss_part, dev_acc_part, dev_recall_part = evaluate_one_epoch_per_participant(
            model, dev_loader, criterion, device, print_confusion=False
        )
        print(f"Dev (participant-level): [segment-based loss={dev_loss_part:.4f}], acc={dev_acc_part:.4f}, recall(depressed)={dev_recall_part:.4f}")

        # -------------------------
        # Test Evaluation (segment-level)
        # -------------------------
        print(f"=== [Epoch {epoch+1}] Test Evaluation (Segment-Level) ===")
        test_loss_eval, test_acc_eval, test_recall_1, _, _ = evaluate_one_epoch(
            model, test_loader, criterion, device, print_confusion=True
        )
        print(f"Test (segment-level): loss={test_loss_eval:.4f}, acc={test_acc_eval:.4f}, recall(depressed)={test_recall_1:.4f}")

        # Test Evaluation (participant-level)
        print(f"=== [Epoch {epoch+1}] Test Evaluation (Participant-Level) ===")
        test_loss_part, test_acc_part, test_recall_part = evaluate_one_epoch_per_participant(
            model, test_loader, criterion, device, print_confusion=True
        )
        print(f"Test (participant-level): [segment-based loss={test_loss_part:.4f}], acc={test_acc_part:.4f}, recall(depressed)={test_recall_part:.4f}")

        # Save the results (save any metrics as needed)
        results.append({
            'epoch': epoch+1,
            'train_loss': epoch_train_loss,

            'dev_loss_seg': dev_loss_eval,
            'dev_acc_seg': dev_acc_eval,
            'dev_recall_seg': dev_recall_1,

            'dev_loss_part': dev_loss_part,  # note: this is segment-based average loss
            'dev_acc_part': dev_acc_part,
            'dev_recall_part': dev_recall_part,

            'test_loss_seg': test_loss_eval,
            'test_acc_seg': test_acc_eval,
            'test_recall_seg': test_recall_1,

            'test_loss_part': test_loss_part,
            'test_acc_part': test_acc_part,
            'test_recall_part': test_recall_part
        })

        # Output report every 5 epochs
        if (epoch+1) % 5 == 0:
            df_report = pd.DataFrame(results)
            save_path = f"/content/drive/MyDrive/training_report_epoch_{epoch+1}.csv"
            df_report.to_csv(save_path, index=False)
            print(f"[Saved] {save_path}")

        print(f"--- End of epoch {epoch+1} ---\n")
        # Release Python garbage and cached GPU memory between epochs.
        gc.collect()
        torch.cuda.empty_cache()

    print("Training completed.")

    # Save the final report
    df_report_final = pd.DataFrame(results)
    save_path_final = "/content/drive/MyDrive/training_report_final.csv"
    df_report_final.to_csv(save_path_final, index=False)
    print(f"[Saved Final Report] {save_path_final}")

# 2. However, when participants are kept separate between the training and test sets, with no other changes, the model's accuracy drops significantly, to 60%.

In [None]:
!pip install pydub
from google.colab import drive
drive.mount("/content/drive")

!pip install torch torchaudio transformers librosa scikit-learn

import os
import gc
import random
import pandas as pd
import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2Config, Wav2Vec2Model, Wav2Vec2Processor
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import itertools
from collections import defaultdict

# -----------------------------------------------------
# 0) Set seed for reproducibility
# -----------------------------------------------------
SEED = 103
# Seed Python, NumPy and PyTorch (CPU) generators with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
# CUDA generators are only seeded when a GPU is actually available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# -----------------------------------------------------
# 1) AudioDataset
#    (Return None and skip if reading audio fails)
# -----------------------------------------------------
class AudioDataset(Dataset):
    """Segment-level audio dataset driven by a metadata DataFrame.

    Every row must provide the columns ``audio_path``, ``label`` and
    ``participant_id``. An item is the tuple
    ``(idx, waveform, sample_rate, label, participant_id)``; when the audio
    cannot be loaded the item is ``None`` so that the collate function can
    drop it.
    """

    def __init__(self, df: pd.DataFrame, target_sr=16000, verbose=False):
        # A fresh 0..N-1 index keeps positional lookups aligned with `idx`.
        self.df = df.reset_index(drop=True)
        self.target_sr = target_sr
        self.verbose = verbose

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        audio_path = record["audio_path"]
        label = float(record["label"])  # 0 or 1
        participant_id = record["participant_id"]

        if self.verbose:
            print(f"[Dataset] idx={idx}, audio_path={audio_path}, label={label}, pid={participant_id}")

        try:
            waveform, sr = torchaudio.load(audio_path)

            # Multi-channel input: keep the first channel only, shape (1, T).
            if waveform.shape[0] > 1:
                waveform = waveform[:1, :]

            # Bring every segment to the common sampling rate.
            if sr != self.target_sr:
                waveform = torchaudio.functional.resample(waveform, sr, self.target_sr)
                sr = self.target_sr
        except Exception as e:
            # Best effort: report the broken sample and let the caller skip it.
            print(f"[WARN] Skipped sample idx={idx}, audio='{audio_path}' due to error: {e}")
            return None

        return (idx, waveform, sr, label, participant_id)


# -----------------------------------------------------
# 2) collate_fn: skip None
# -----------------------------------------------------
def collate_fn_audio(batch):
    """Collate dataset items into padded tensors, dropping failed samples.

    Args:
        batch: list of ``(idx, waveform, sr, label, participant_id)`` tuples,
            where ``None`` entries mark samples that could not be loaded.

    Returns:
        ``None`` when no valid sample is left (callers skip such batches),
        otherwise the tuple ``(sample_indices, padded_wav_3d, sr_tensor,
        labels_tensor, participant_ids)`` where ``padded_wav_3d`` has shape
        (B, 1, T_max) and shorter waveforms are zero-padded on the right.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None

    # Transpose the list of 5-tuples into five parallel lists.
    sample_indices, waveforms, rates, targets, participant_ids = (
        list(column) for column in zip(*kept)
    )

    # (1, T_i) -> (T_i) -> zero-padded (B, T_max) -> (B, 1, T_max)
    padded_wav_2d = pad_sequence(
        [wave.squeeze(0) for wave in waveforms],
        batch_first=True,
        padding_value=0.0,
    )
    padded_wav_3d = padded_wav_2d.unsqueeze(1)

    labels_tensor = torch.stack([torch.tensor(target) for target in targets], dim=0)
    sr_tensor = torch.stack([torch.tensor(rate) for rate in rates], dim=0)

    return sample_indices, padded_wav_3d, sr_tensor, labels_tensor, participant_ids


# -----------------------------------------------------
# 3) Model Definition (Wav2Vec2 + mean pooling + classifier)
# -----------------------------------------------------
class Wav2Vec2AudioEncoder(nn.Module):
    """Pretrained Wav2Vec2 backbone that returns its last hidden states.

    Args:
        model_name: checkpoint identifier passed to ``from_pretrained``.
        output_hidden_states: value written to the backbone config's
            ``output_hidden_states`` field.
        freeze_feature_extractor: when True, gradients are disabled for every
            parameter of the backbone's ``feature_extractor`` sub-module.
    """

    def __init__(self,
                 model_name="facebook/wav2vec2-large-960h",
                 output_hidden_states=False,
                 freeze_feature_extractor=True):
        super().__init__()
        self.config = Wav2Vec2Config.from_pretrained(model_name)
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        # Exposed so downstream layers can size their input projection.
        self.hidden_size = self.config.hidden_size
        self.model.config.output_hidden_states = output_hidden_states
        if freeze_feature_extractor:
            self.model.feature_extractor.requires_grad_(False)

    def forward(self, input_values):
        """Return the backbone's ``last_hidden_state``, (B, T', hidden_size)."""
        encoded = self.model(input_values=input_values)
        return encoded.last_hidden_state


class AudioOnlyModel(nn.Module):
    """
    Audio-only classifier:
      1) Wav2Vec2 produces frame-level features
      2) the features are linearly projected, mean-pooled over time and
         layer-normalised
      3) a small MLP maps the pooled vector to class logits
    """
    def __init__(
        self,
        audio_model_name="facebook/wav2vec2-large-960h",
        freeze_feature_extractor=True,
        cross_hidden_dim=384,  # arbitrary
        num_classes=2
    ):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.audio_encoder = Wav2Vec2AudioEncoder(
            model_name=audio_model_name,
            freeze_feature_extractor=freeze_feature_extractor
        )
        encoder_dim = self.audio_encoder.hidden_size  # wav2vec2-large-960h => 1024

        # Projection into the working dimension, normalised after pooling.
        self.proj_audio = nn.Linear(encoder_dim, cross_hidden_dim)
        self.norm_audio = nn.LayerNorm(cross_hidden_dim)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(cross_hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_values):
        """Map a batch of wav2vec2 input values to (B, num_classes) logits."""
        frame_features = self.audio_encoder(input_values)  # (B, T_a, hidden_size)
        projected = self.proj_audio(frame_features)        # (B, T_a, cross_hidden_dim)
        # The mean runs over every time step; no mask is applied, so frames
        # coming from padded input positions are averaged in as well.
        time_averaged = projected.mean(dim=1)               # (B, cross_hidden_dim)
        normalised = self.norm_audio(time_averaged)
        return self.classifier(normalised)


# -----------------------------------------------------
# 4) wave_to_wav2vec2_input
# -----------------------------------------------------
# Shared processor instance used by wave_to_wav2vec2_input below.
processor = Wav2Vec2Processor.from_pretrained(
    "facebook/wav2vec2-large-960h",
    return_attention_mask=False
)


def wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000):
    """
    Turn a padded waveform batch into wav2vec2 ``input_values``.

    wave_batch : (B, 1, T) waveform tensor
    sr_batch   : (B,) sampling rate of each waveform
    target_sr  : rate handed to the processor; a waveform whose rate differs
                 is resampled first

    Returns a (B, T_wav2vec) tensor, zero-padded to the longest item.

    NOTE(review): each waveform is passed to the processor together with any
    zero padding already present in ``wave_batch``; if the processor
    normalises its input, the padding takes part in that - confirm this is
    intended.
    """
    def _encode(samples):
        # One waveform at a time; the processor returns a batch of size 1.
        features = processor(
            samples,
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
            truncation=False,
            return_attention_mask=False
        )
        return features.input_values[0]

    encoded = []
    for i in range(wave_batch.size(0)):
        samples = wave_batch[i, 0, :].cpu().numpy()
        rate = sr_batch[i].item()
        if rate != target_sr:
            resampled = torchaudio.functional.resample(
                torch.from_numpy(samples), rate, target_sr
            )
            samples = resampled.numpy()
        encoded.append(_encode(samples))

    # Pad if lengths differ
    return pad_sequence(encoded, batch_first=True, padding_value=0.0)  # (B, T_wav2vec)


# -----------------------------------------------------
# 5-A) Segment-level Evaluation
# -----------------------------------------------------
def evaluate_one_epoch_segment_level(model, data_loader, criterion, device, print_confusion=True):
    """
    Evaluate at the segment level (traditional approach).

    Args:
        model: network mapping wav2vec2 input values to (B, num_classes) logits.
        data_loader: yields 5-tuples ``(sample_indices, wave_batch, sr_batch,
            lbl_batch, participant_ids)`` or ``None`` (skipped) per batch.
        criterion: loss, called as ``criterion(logits, labels.long())``.
        device: device on which the forward pass is run.
        print_confusion: when True, print the classification report and the
            confusion matrix. The flag controls printing only; the recall is
            computed either way.

    Returns:
        ``(avg_loss, accuracy, recall_1)``: per-segment average loss,
        per-segment accuracy and the recall of label 1 ("depressed"). All
        values are 0.0 when no segment was evaluated.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    correct = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            if batch is None:  # every sample of this batch failed to load
                continue

            _, wave_batch, sr_batch, lbl_batch, _ = batch

            wave_batch = wave_batch.to(device)
            sr_batch   = sr_batch.to(device)
            targets    = lbl_batch.to(device).long()

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000).to(device)
            logits = model(inputs)

            loss_val = criterion(logits, targets)

            # Weight the batch mean by the batch size so that the final
            # average is per segment even when the last batch is smaller.
            bs = wave_batch.size(0)
            total_loss += loss_val.item() * bs
            total_samples += bs

            preds = torch.argmax(logits, dim=1)
            correct += (preds == targets).sum().item()

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(targets.cpu().tolist())

    if total_samples == 0:
        return 0.0, 0.0, 0.0

    avg_loss = total_loss / total_samples
    accuracy = correct / total_samples

    # Recall of the positive class is derived from the confusion matrix so it
    # is available even when nothing is printed (previously it was silently
    # left at 0.0 whenever print_confusion was False).
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    n_positive = int(cm[1].sum())
    recall_1 = float(cm[1, 1]) / n_positive if n_positive > 0 else 0.0

    if print_confusion:
        target_names = ["not depressed", "depressed"]
        # labels=[0, 1] keeps the report well-defined when only one class
        # occurs in the evaluated data; zero_division=0 silences the
        # undefined-metric case instead of warning.
        cr_str = classification_report(
            all_labels, all_preds, labels=[0, 1],
            target_names=target_names, digits=6, zero_division=0
        )
        print("=== [Segment-level] Classification Report ===")
        print(cr_str)
        print("[Segment-level] Confusion Matrix ( [0,1] ):\n", cm)

    return avg_loss, accuracy, recall_1


# -----------------------------------------------------
# 5-B) Participant-level Evaluation
# -----------------------------------------------------
def evaluate_one_epoch_participant_level(model, data_loader, criterion, device, print_confusion=True):
    """
    Evaluate the model at participant level.

    The logits of all segments that belong to one participant are averaged, and the
    argmax of that mean is used as the participant's predicted label.

    Args:
        model: Classifier called as ``model(inputs)``; expected to return logits of
            shape (batch, num_classes).
        data_loader: Iterable yielding either ``None`` (skipped) or a tuple
            ``(sample_indices, wave_batch, sr_batch, lbl_batch, participant_ids)``.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        print_confusion: When True, print the participant-level classification
            report and confusion matrix. Only printing is affected; the returned
            metrics are computed either way.

    Returns:
        Tuple ``(avg_loss, participant_level_acc, recall_1)`` where ``avg_loss`` is
        the per-segment loss averaged over all evaluated segments (kept for
        reference), ``participant_level_acc`` is the fraction of participants
        classified correctly, and ``recall_1`` is the recall of class 1
        ("depressed") over participants. All three are 0.0 when there is nothing
        to evaluate.
    """
    model.eval()
    # The loss is still a segment-level quantity: it is accumulated per segment
    # and averaged over every segment seen, not over participants.
    total_loss = 0.0
    total_segment_count = 0

    # participant_id -> list of per-segment logits / participant_id -> label
    participant_logits_map = defaultdict(list)
    participant_label_map = {}

    with torch.no_grad():
        for batch in data_loader:
            if batch is None:
                continue

            sample_indices, wave_batch, sr_batch, lbl_batch, participant_ids = batch

            wave_batch = wave_batch.to(device)
            sr_batch   = sr_batch.to(device)
            lbl_batch  = lbl_batch.to(device)

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, target_sr=16000).to(device)
            logits = model(inputs)  # (B, num_classes)

            # Segment-level loss (for reference)
            loss_val = criterion(logits, lbl_batch.long())
            bs = wave_batch.size(0)
            total_loss += loss_val.item() * bs
            total_segment_count += bs

            # Move the whole batch to host memory once instead of once per segment.
            logits_np = logits.cpu().numpy()
            labels_list = lbl_batch.long().cpu().tolist()

            for pid, seg_logits, seg_label in zip(participant_ids, logits_np, labels_list):
                # NOTE(review): the collate function is not visible here. If it returns
                # ids as tensor / numpy scalars, convert them to plain Python values:
                # tensors hash by identity, so segments of the same participant would
                # otherwise never be grouped together.
                if hasattr(pid, "item"):
                    pid = pid.item()
                participant_logits_map[pid].append(seg_logits)  # shape=(num_classes,)
                # The label should be the same for all segments of a participant,
                # so overwriting it on every segment is fine.
                participant_label_map[pid] = seg_label

    if total_segment_count > 0:
        avg_loss = total_loss / total_segment_count
    else:
        avg_loss = 0.0

    # Make final predictions at the participant level
    all_participant_preds = []
    all_participant_labels = []
    for pid, list_of_logits in participant_logits_map.items():
        # list_of_logits stacks to (num_segments, num_classes)
        avg_logit = np.mean(list_of_logits, axis=0)  # => (num_classes,)
        all_participant_preds.append(int(np.argmax(avg_logit)))
        all_participant_labels.append(participant_label_map[pid])

    # Evaluation metrics
    participant_count = len(all_participant_labels)
    if participant_count > 0:
        correct_count = sum(
            1 for p, t in zip(all_participant_preds, all_participant_labels) if p == t
        )
        participant_level_acc = correct_count / participant_count
    else:
        participant_level_acc = 0.0

    # Recall of the "depressed" class (label 1). Computed unconditionally so that
    # callers passing print_confusion=False still receive the real value.
    preds_for_positives = [
        p for p, t in zip(all_participant_preds, all_participant_labels) if t == 1
    ]
    if preds_for_positives:
        recall_1 = sum(1 for p in preds_for_positives if p == 1) / len(preds_for_positives)
    else:
        recall_1 = 0.0

    if print_confusion and participant_count > 0:
        target_names = ["not depressed", "depressed"]
        # labels=[0, 1] keeps target_names aligned even when only one class occurs
        # among the participants (otherwise scikit-learn raises ValueError);
        # zero_division=0 reports 0 for metrics that are undefined in that case.
        cr_str = classification_report(all_participant_labels, all_participant_preds,
                                       labels=[0, 1], target_names=target_names,
                                       digits=6, zero_division=0)
        print("=== [Participant-level] Classification Report ===")
        print(cr_str)

        cm = confusion_matrix(all_participant_labels, all_participant_preds, labels=[0,1])
        print("[Participant-level] Confusion Matrix ( [0,1] ):\n", cm)

    return avg_loss, participant_level_acc, recall_1


# -----------------------------------------------------
# 6) Main
# -----------------------------------------------------
if __name__ == "__main__":
    # Script entry point: load the segment CSV, split it by participant into
    # train/dev/test, fine-tune the audio-only model with gradient accumulation,
    # evaluate after every epoch and save the per-epoch metrics as CSV reports.

    # === Settings ===
    CSV_PATH = "/content/drive/MyDrive/DAIC-WOZ/dataset_info_all_text.csv"

    BATCH_SIZE = 2
    # Gradients of this many micro-batches are accumulated before each optimizer step.
    accumulation_steps = 16

    # Number of epochs
    EPOCHS = 15

    LEARNING_RATE = 1e-5

    # Class weights (example: [1.0, 2.5] for "not depressed", "depressed")
    class_weights = torch.FloatTensor([1.0, 2.5])

    # Load CSV
    df_all = pd.read_csv(CSV_PATH)
    print(f"Loaded CSV total rows: {len(df_all)}")

    # ===================================================
    # Split into train/dev/test by participant_id
    #   - 60% train, 20% dev, 20% test
    # Splitting on participant ids (not on rows) keeps all segments of one
    # participant inside a single subset.
    # ===================================================
    unique_ids = df_all["participant_id"].unique()

    # Split1: train (60%) vs. remain (40%)
    train_ids, remain_ids = train_test_split(
        unique_ids, test_size=0.40, random_state=SEED, shuffle=True
    )
    # Split2: dev (20%) vs. test (20%) => half each of the remaining 40%
    dev_ids, test_ids = train_test_split(
        remain_ids, test_size=0.5, random_state=SEED, shuffle=True
    )

    print(f"Unique IDs total: {len(unique_ids)}")
    print(f"Train IDs: {len(train_ids)}, Dev IDs: {len(dev_ids)}, Test IDs: {len(test_ids)}")

    # Filter rows
    train_df_raw = df_all[df_all["participant_id"].isin(train_ids)].copy()
    dev_df_raw   = df_all[df_all["participant_id"].isin(dev_ids)].copy()
    test_df_raw  = df_all[df_all["participant_id"].isin(test_ids)].copy()

    print(f"Train data rows: {len(train_df_raw)}, Dev rows: {len(dev_df_raw)}, Test rows: {len(test_df_raw)}")

    # Create Dataset
    train_dataset = AudioDataset(train_df_raw, target_sr=16000, verbose=False)
    dev_dataset   = AudioDataset(dev_df_raw,   target_sr=16000, verbose=False)
    test_dataset  = AudioDataset(test_df_raw,  target_sr=16000, verbose=False)

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn_audio)
    dev_loader   = DataLoader(dev_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_audio)
    test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_audio)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}")

    # Model
    model = AudioOnlyModel(
        audio_model_name="facebook/wav2vec2-large-960h",
        freeze_feature_extractor=True,
        cross_hidden_dim=384,
        num_classes=2
    ).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    global_step = 0

    # For logging metrics and saving
    epoch_stats = []

    for epoch in range(EPOCHS):
        # -------------------- Training --------------------
        model.train()
        total_loss_train = 0.0
        total_samples_train = 0

        optimizer.zero_grad()
        # Micro-batches whose gradients have been accumulated since the last
        # optimizer step. Counting processed batches (instead of using the loader
        # index) keeps skipped `None` batches from shortening an accumulation group.
        pending_micro_batches = 0

        for batch in train_loader:
            if batch is None:
                continue

            sample_indices, wave_batch, sr_batch, lbl_batch, participant_ids = batch

            wave_batch = wave_batch.to(device)
            sr_batch   = sr_batch.to(device)
            lbl_batch  = lbl_batch.to(device)

            inputs = wave_to_wav2vec2_input(wave_batch, sr_batch, 16000).to(device)
            logits = model(inputs)
            loss = criterion(logits, lbl_batch.long())
            # Unscaled loss of this micro-batch, used for both the epoch average
            # and the periodic log line.
            step_loss = loss.item()

            bs = wave_batch.size(0)
            total_loss_train += step_loss * bs
            total_samples_train += bs

            # Gradient accumulation: scale so the summed gradient is the group mean
            (loss / accumulation_steps).backward()
            pending_micro_batches += 1

            if pending_micro_batches == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending_micro_batches = 0

            global_step += 1

            # Simple log every 50 steps
            if (global_step % 50) == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"[Train] epoch={epoch+1}, global_step={global_step}, loss={step_loss:.4f}, lr={current_lr}")

        # Apply the gradients of the last, incomplete accumulation group; otherwise
        # they would be thrown away by the zero_grad() at the start of the next epoch.
        if pending_micro_batches > 0:
            optimizer.step()
            optimizer.zero_grad()

        # Average train loss
        if total_samples_train > 0:
            epoch_train_loss = total_loss_train / total_samples_train
        else:
            epoch_train_loss = 0.0

        print(f"[Epoch {epoch+1}] train loss = {epoch_train_loss:.4f}")

        # -------------------- Dev Evaluation (Segment-level) --------------------
        print(f"\n=== [Epoch {epoch+1}] Dev Evaluation (Segment-level) ===")
        dev_loss_seg, dev_acc_seg, dev_recall_seg = evaluate_one_epoch_segment_level(
            model, dev_loader, criterion, device, print_confusion=False
        )
        print(f"Dev (segment-level): loss={dev_loss_seg:.4f}, acc={dev_acc_seg:.4f}, recall(depressed)={dev_recall_seg:.4f}")

        # -------------------- Dev Evaluation (Participant-level) --------------------
        print(f"\n=== [Epoch {epoch+1}] Dev Evaluation (Participant-level) ===")
        dev_loss_part, dev_acc_part, dev_recall_part = evaluate_one_epoch_participant_level(
            model, dev_loader, criterion, device, print_confusion=False
        )
        print(f"Dev (participant-level): loss={dev_loss_part:.4f}, acc={dev_acc_part:.4f}, recall(depressed)={dev_recall_part:.4f}")

        # -------------------- Test Evaluation (Segment-level) -------------------
        print(f"\n=== [Epoch {epoch+1}] Test Evaluation (Segment-level) ===")
        test_loss_seg, test_acc_seg, test_recall_seg = evaluate_one_epoch_segment_level(
            model, test_loader, criterion, device, print_confusion=True
        )
        print(f"Test (segment-level): loss={test_loss_seg:.4f}, acc={test_acc_seg:.4f}, recall(depressed)={test_recall_seg:.4f}")

        # -------------------- Test Evaluation (Participant-level) -------------------
        print(f"\n=== [Epoch {epoch+1}] Test Evaluation (Participant-level) ===")
        test_loss_part, test_acc_part, test_recall_part = evaluate_one_epoch_participant_level(
            model, test_loader, criterion, device, print_confusion=True
        )
        print(f"Test (participant-level): loss={test_loss_part:.4f}, acc={test_acc_part:.4f}, recall(depressed)={test_recall_part:.4f}")

        # -------------------- Save metrics in memory -------------------
        epoch_stats.append({
            "epoch": epoch+1,
            "train_loss": epoch_train_loss,
            "dev_loss_segment_level": dev_loss_seg,
            "dev_acc_segment_level": dev_acc_seg,
            "dev_recall_depressed_segment_level": dev_recall_seg,
            "dev_loss_participant_level": dev_loss_part,
            "dev_acc_participant_level": dev_acc_part,
            "dev_recall_depressed_participant_level": dev_recall_part,
            "test_loss_segment_level": test_loss_seg,
            "test_acc_segment_level": test_acc_seg,
            "test_recall_depressed_segment_level": test_recall_seg,
            "test_loss_participant_level": test_loss_part,
            "test_acc_participant_level": test_acc_part,
            "test_recall_depressed_participant_level": test_recall_part
        })

        # Optional: Save CSV report periodically (e.g., every 5 epochs)
        if (epoch + 1) % 5 == 0:
            df_report = pd.DataFrame(epoch_stats)
            save_path = f"/content/drive/MyDrive/wav2vec_epoch_report_up_to_{epoch+1}.csv"
            df_report.to_csv(save_path, index=False)
            print(f"[Info] Saved epoch report at: {save_path}")

        print(f"--- End of epoch {epoch+1} ---\n")
        # Release Python garbage and cached GPU blocks between epochs.
        gc.collect()
        torch.cuda.empty_cache()

    print("Training completed.")

    # Optionally, save final report as well
    df_final = pd.DataFrame(epoch_stats)
    df_final.to_csv("/content/drive/MyDrive/wav2vec_final_epoch_report.csv", index=False)
    print("[Info] Saved final report at /content/drive/MyDrive/wav2vec_final_epoch_report.csv")