In [2]:
pip install kaggle numpy pandas librosa torchaudio torch torchvision scikit-learn tqdm matplotlib seaborn

Note: you may need to restart the kernel to use updated packages.


# Directories clean-up

In [3]:
import os
import shutil

# Directory to clean
output_dir = "/kaggle/working/preprocessed_t2_test_chunks"

# Check if the directory exists and delete its contents
if os.path.exists(output_dir):
    print(f"Clearing existing files in {output_dir}...")
    shutil.rmtree(output_dir)
    print(f"Deleted {output_dir} and all its contents.")
else:
    print(f"Directory {output_dir} does not exist yet.")

# Recreate the directory
os.makedirs(output_dir, exist_ok=True)
print(f"Recreated empty directory: {output_dir}")

# Verification
files = os.listdir(output_dir)
print(f"Files in {output_dir} after clearing: {files}")

Directory /kaggle/working/preprocessed_t2_test_chunks does not exist yet.
Recreated empty directory: /kaggle/working/preprocessed_t2_test_chunks
Files in /kaggle/working/preprocessed_t2_test_chunks after clearing: []


# Data Preprocessing: MFCC, LFCC, Chroma-STFT

In [6]:
import os
import numpy as np
import torch
import torchaudio
from torchaudio.transforms import MFCC, LFCC, Spectrogram
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0)}")

# Paths
DATASET_PATHS = {
    "fake_or_real": "/kaggle/input/the-fake-or-real-dataset/",
    "wild_deepfake": "/kaggle/input/in-the-wild-audio-deepfake/"
}

# Function to validate if a .wav file can be loaded
def validate_wav_file(file_path):
    try:
        waveform, sr = torchaudio.load(file_path)
        if waveform.size(1) == 0:
            print(f"Skipping empty file: {file_path}")
            return False
        return True
    except Exception as e:
        print(f"Skipping invalid file: {file_path}, Error: {e}")
        return False

# Function to collect audio files and assign labels
def get_audio_files_and_labels():
    audio_files = []
    labels = []
    counts = {"fake": 0, "real": 0}

    # Process In-the-Wild dataset
    wild_path = DATASET_PATHS["wild_deepfake"]
    wild_real_path = os.path.join(wild_path, "release_in_the_wild", "real")
    wild_fake_path = os.path.join(wild_path, "release_in_the_wild", "fake")

    for file in os.listdir(wild_real_path):
        if file.endswith(".wav"):
            file_path = os.path.join(wild_real_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(0)  # Real
                counts["real"] += 1

    for file in os.listdir(wild_fake_path):
        if file.endswith(".wav"):
            file_path = os.path.join(wild_fake_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(1)  # Fake
                counts["fake"] += 1

    print(f"In-the-Wild Counts: Real: {counts['real']}, Fake: {counts['fake']}")

    # Process Fake-or-Real dataset
    for_counts = {"fake": 0, "real": 0}
    for_base_path = DATASET_PATHS["fake_or_real"]
    for_folders = ["for-2sec", "for-norm", "for-original", "for-rerec"]
    for_subfolders = {
        "for-2sec": "for-2seconds",
        "for-norm": "for-norm",
        "for-original": "for-original",
        "for-rerec": "for-rerecorded"
    }
    splits = ["training", "validation", "testing"]

    for folder in for_folders:
        subfolder = for_subfolders[folder]
        for split in splits:
            real_path = os.path.join(for_base_path, folder, subfolder, split, "real")
            if os.path.exists(real_path):
                for file in os.listdir(real_path):
                    if file.endswith(".wav"):
                        file_path = os.path.join(real_path, file)
                        if validate_wav_file(file_path):
                            audio_files.append(file_path)
                            labels.append(0)  # Real
                            for_counts["real"] += 1

            fake_path = os.path.join(for_base_path, folder, subfolder, split, "fake")
            if os.path.exists(fake_path):
                for file in os.listdir(fake_path):
                    if file.endswith(".wav"):
                        file_path = os.path.join(fake_path, file)
                        if validate_wav_file(file_path):
                            audio_files.append(file_path)
                            labels.append(1)  # Fake
                            for_counts["fake"] += 1

    print(f"Fake-or-Real Counts: Real: {for_counts['real']}, Fake: {for_counts['fake']}")
    total_counts = {"real": counts['real'] + for_counts['real'], "fake": counts['fake'] + for_counts['fake']}
    print(f"Total Counts: Real: {total_counts['real']}, Fake: {total_counts['fake']}")
    return audio_files, labels

# Collect files and labels
all_files, all_labels = get_audio_files_and_labels()
print(f"\nTotal files before processing: {len(all_files)}")
print(f"Label Distribution before processing: Real (0): {sum(1 for label in all_labels if label == 0)} ({sum(1 for label in all_labels if label == 0)/len(all_files)*100:.2f}%), Fake (1): {sum(1 for label in all_labels if label == 1)} ({sum(1 for label in all_labels if label == 1)/len(all_files)*100:.2f}%)")

# Custom Chroma-STFT implementation using torchaudio
def compute_chroma_stft(waveforms, sample_rate=16000, n_fft=2048, hop_length=512, n_chroma=12):
    # Compute spectrogram
    spectrogram_transform = Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        power=2.0
    ).to(device)
    spec = spectrogram_transform(waveforms)  # (batch, channels, freq, time)
    
    # Squeeze the channel dimension (since mono audio)
    spec = spec.squeeze(1)  # (batch, freq, time)
    
    # Create chroma filter bank
    freqs = torch.linspace(0, sample_rate / 2, steps=spec.shape[1]).to(device)  # Frequency bins
    chroma_freqs = torch.tensor([31.25 * (2 ** (i / 12)) for i in range(12 * 4)], device=device)  # MIDI 21 to 68 (A0 to C5)
    chroma_bins = torch.zeros((n_chroma, spec.shape[1]), device=device)
    
    for i in range(n_chroma):
        center = chroma_freqs[i::12]  # All octaves of this note
        for cf in center:
            mask = (freqs >= cf / 1.06) & (freqs <= cf * 1.06)  # Approximate triangular filter
            chroma_bins[i] += mask.float()
    
    # Normalize filter bank
    chroma_bins /= chroma_bins.sum(dim=1, keepdim=True).clamp(min=1e-10)
    
    # Apply filter bank to spectrogram
    chroma = torch.einsum('cf,bft->bct', chroma_bins, spec)  # (batch, n_chroma, time)
    return chroma

# Feature Extraction
def extract_features_batch(file_paths, labels, batch_size=64, max_length=16000):
    mfcc_results = []
    lfcc_results = []
    chroma_results = []
    valid_labels = []
    
    # Define feature extractors
    mfcc_transform = MFCC(
        sample_rate=16000,
        n_mfcc=40,
        melkwargs={"n_fft": 2048, "hop_length": 512, "n_mels": 128}
    ).to(device)
    
    lfcc_transform = LFCC(
        sample_rate=16000,
        n_lfcc=40,
        f_min=0,
        f_max=8000,
        n_filter=128,
        speckwargs={"n_fft": 2048, "hop_length": 512}
    ).to(device)
    
    total_batches = (len(file_paths) + batch_size - 1) // batch_size
    log_interval = max(1, total_batches // 100)  # Logging every 1% progress to check for errors
    
    for i in tqdm(range(0, len(file_paths), batch_size), desc="Processing Batches", total=total_batches):
        if i % (log_interval * batch_size) == 0:
            print(f"Processed {i // batch_size}/{total_batches} batches ({(i / len(file_paths)) * 100:.1f}%)")
            
        batch_files = file_paths[i:i + batch_size]
        batch_labels = labels[i:i + batch_size]
        waveforms = []
        valid_indices = []
        
        # Load and preprocess waveforms
        for idx, (file_path, label) in enumerate(zip(batch_files, batch_labels)):
            try:
                waveform, sr = torchaudio.load(file_path)
                # Convert to mono by averaging channels
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)  # (channels, samples) -> (1, samples)
                if sr != 16000:
                    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
                
                if waveform.size(1) > max_length:
                    waveform = waveform[:, :max_length]
                elif waveform.size(1) < max_length:
                    pad_size = max_length - waveform.size(1)
                    waveform = torch.nn.functional.pad(waveform, (0, pad_size))
                
                waveforms.append(waveform)
                valid_indices.append(idx)
            except Exception as e:
                continue
        
        if not waveforms:
            continue
        
        try:
            waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True).to(device)
            
            # Extract MFCC
            mfccs = mfcc_transform(waveforms)  # (batch, n_mfcc, time)
            
            # Extract LFCC
            lfccs = lfcc_transform(waveforms)  # (batch, n_lfcc, time)
            
            # Extract Chroma-STFT
            chromas = compute_chroma_stft(
                waveforms,
                sample_rate=16000,
                n_fft=2048,
                hop_length=512,
                n_chroma=12
            )  # (batch, n_chroma, time)
            
            # Standardize time dimension (trim/pad to 32 frames for consistency)
            target_frames = 32
            for idx in valid_indices:
                # MFCC
                mfcc = mfccs[idx]  # (n_mfcc, time)
                if mfcc.shape[1] > target_frames:
                    mfcc = mfcc[:, :target_frames]
                elif mfcc.shape[1] < target_frames:
                    mfcc = torch.nn.functional.pad(mfcc, (0, target_frames - mfcc.shape[1]))
                mfcc_results.append(mfcc.cpu().numpy())
                
                # LFCC
                lfcc = lfccs[idx]  # (n_lfcc, time)
                if lfcc.shape[1] > target_frames:
                    lfcc = lfcc[:, :target_frames]
                elif lfcc.shape[1] < target_frames:
                    lfcc = torch.nn.functional.pad(lfcc, (0, target_frames - lfcc.shape[1]))
                lfcc_results.append(lfcc.cpu().numpy())
                
                # Chroma-STFT
                chroma = chromas[idx]  # (n_chroma, time)
                if chroma.shape[1] > target_frames:
                    chroma = chroma[:, :target_frames]
                elif chroma.shape[1] < target_frames:
                    chroma = torch.nn.functional.pad(chroma, (0, target_frames - chroma.shape[1]))
                chroma_results.append(chroma.cpu().numpy())
                
                valid_labels.append(batch_labels[idx])
        
        except Exception as e:
            print(f"Error processing batch {i // batch_size}: {e}")
            continue
    
    return mfcc_results, lfcc_results, chroma_results, valid_labels

# Process files in batches
mfcc_features, lfcc_features, chroma_features, y = extract_features_batch(all_files, all_labels, batch_size=64, max_length=16000)

# Convert to numpy arrays
mfcc_features = np.array(mfcc_features)  # (num_samples, n_mfcc, time)
lfcc_features = np.array(lfcc_features)  # (num_samples, n_lfcc, time)
chroma_features = np.array(chroma_features)  # (num_samples, n_chroma, time)
y = np.array(y)

# Validate alignment
if not (len(mfcc_features) == len(lfcc_features) == len(chroma_features) == len(y)):
    raise ValueError(f"Mismatch between features and labels: "
                     f"MFCC: {len(mfcc_features)}, LFCC: {len(lfcc_features)}, "
                     f"Chroma: {len(chroma_features)}, y: {len(y)}")

# Verify total samples
total_samples = len(y)
print(f"Total samples after feature extraction: {total_samples}")
if total_samples != 173128:
    print(f"Warning: Expected 173,128 samples, but got {total_samples} samples. Some files may have been skipped due to errors.")

# Save in chunks (268 chunks of 646 samples each)
num_chunks = 268
chunk_size = total_samples // num_chunks  # Should be 646
total_samples_used = num_chunks * chunk_size
print(f"Total samples: {total_samples}, Number of chunks: {num_chunks}, Chunk size: {chunk_size}")
print(f"Total samples used: {total_samples_used}")

# Adjustments if total_samples is not exactly divisible by chunk_size
if total_samples != total_samples_used:  # Shouldnt be, dont waste data
    print(f"Warning: {total_samples - total_samples_used} samples will be dropped due to chunking.")

output_dir = "/kaggle/working/preprocessed_chunks_268"
os.makedirs(output_dir, exist_ok=True)

# Clear existing chunks
for f in os.listdir(output_dir):
    os.remove(os.path.join(output_dir, f))

# Save the data in chunks
for i in range(0, total_samples_used, chunk_size):
    chunk_idx = i // chunk_size
    mfcc_chunk = mfcc_features[i:i + chunk_size]  # (chunk_size, n_mfcc, time)
    lfcc_chunk = lfcc_features[i:i + chunk_size]  # (chunk_size, n_lfcc, time)
    chroma_chunk = chroma_features[i:i + chunk_size]  # (chunk_size, n_chroma, time)
    y_chunk = y[i:i + chunk_size]  # (chunk_size,)
    
    if not (len(mfcc_chunk) == len(lfcc_chunk) == len(chroma_chunk) == len(y_chunk)):
        raise ValueError(f"Mismatch in chunk {chunk_idx}: "
                         f"MFCC: {len(mfcc_chunk)}, LFCC: {len(lfcc_chunk)}, "
                         f"Chroma: {len(chroma_chunk)}, y: {len(y_chunk)}")
    
    np.save(os.path.join(output_dir, f"mfcc_chunk_{chunk_idx}.npy"), mfcc_chunk)
    np.save(os.path.join(output_dir, f"lfcc_chunk_{chunk_idx}.npy"), lfcc_chunk)
    np.save(os.path.join(output_dir, f"chroma_chunk_{chunk_idx}.npy"), chroma_chunk)
    np.save(os.path.join(output_dir, f"y_chunk_{chunk_idx}.npy"), y_chunk)
    print(f"Saved chunk {chunk_idx}: "
          f"MFCC shape {mfcc_chunk.shape}, LFCC shape {lfcc_chunk.shape}, "
          f"Chroma shape {chroma_chunk.shape}, y shape {y_chunk.shape}")

print(f"Processed and saved {total_samples_used} audio samples in {output_dir}!")

# Verify saved files
print("\nVerifying all saved files...")
chunk_files = [f for f in os.listdir(output_dir) if f.endswith('.npy')]
print(f"Found {len(chunk_files)} .npy files in {output_dir}:")
for chunk_file in sorted(chunk_files):
    file_path = os.path.join(output_dir, chunk_file)
    size_mb = os.path.getsize(file_path) / (1024 * 1024)
    print(f"{chunk_file}: {size_mb:.2f} MB")

Using device: cuda, GPU: Tesla T4
In-the-Wild Counts: Real: 19963, Fake: 11816
Skipping invalid file: /kaggle/input/the-fake-or-real-dataset/for-norm/for-norm/training/real/file15440.wav_16k.wav_norm.wav_mono.wav_silence.wav, Error: Failed to decode audio.
Skipping invalid file: /kaggle/input/the-fake-or-real-dataset/for-norm/for-norm/training/real/file11064.wav_16k.wav_norm.wav_mono.wav_silence.wav, Error: Failed to decode audio.
Fake-or-Real Counts: Real: 84756, Fake: 56593
Total Counts: Real: 104719, Fake: 68409

Total files before processing: 173128
Label Distribution before processing: Real (0): 104719 (60.49%), Fake (1): 68409 (39.51%)


Processing Batches:   0%|          | 0/2706 [00:00<?, ?it/s]

Processed 0/2706 batches (0.0%)


Processing Batches:   1%|          | 27/2706 [00:07<11:10,  3.99it/s]

Processed 27/2706 batches (1.0%)


Processing Batches:   2%|▏         | 54/2706 [00:14<11:37,  3.80it/s]

Processed 54/2706 batches (2.0%)


Processing Batches:   3%|▎         | 81/2706 [00:21<10:36,  4.12it/s]

Processed 81/2706 batches (3.0%)


Processing Batches:   4%|▍         | 108/2706 [00:27<10:59,  3.94it/s]

Processed 108/2706 batches (4.0%)


Processing Batches:   5%|▍         | 135/2706 [00:34<10:30,  4.08it/s]

Processed 135/2706 batches (5.0%)


Processing Batches:   6%|▌         | 162/2706 [00:40<10:03,  4.21it/s]

Processed 162/2706 batches (6.0%)


Processing Batches:   7%|▋         | 189/2706 [00:53<27:16,  1.54it/s]

Processed 189/2706 batches (7.0%)


Processing Batches:   8%|▊         | 216/2706 [01:11<29:06,  1.43it/s]

Processed 216/2706 batches (8.0%)


Processing Batches:   9%|▉         | 243/2706 [01:30<28:22,  1.45it/s]

Processed 243/2706 batches (9.0%)


Processing Batches:  10%|▉         | 270/2706 [01:47<25:11,  1.61it/s]

Processed 270/2706 batches (10.0%)


Processing Batches:  11%|█         | 297/2706 [02:03<24:59,  1.61it/s]

Processed 297/2706 batches (11.0%)


Processing Batches:  12%|█▏        | 324/2706 [02:21<27:13,  1.46it/s]

Processed 324/2706 batches (12.0%)


Processing Batches:  13%|█▎        | 351/2706 [02:40<28:19,  1.39it/s]

Processed 351/2706 batches (13.0%)


Processing Batches:  14%|█▍        | 378/2706 [02:58<26:18,  1.48it/s]

Processed 378/2706 batches (14.0%)


Processing Batches:  15%|█▍        | 405/2706 [03:17<28:02,  1.37it/s]

Processed 405/2706 batches (15.0%)


Processing Batches:  16%|█▌        | 432/2706 [03:36<25:57,  1.46it/s]

Processed 432/2706 batches (16.0%)


Processing Batches:  17%|█▋        | 459/2706 [03:55<24:57,  1.50it/s]

Processed 459/2706 batches (17.0%)


Processing Batches:  18%|█▊        | 486/2706 [04:13<26:27,  1.40it/s]

Processed 486/2706 batches (18.0%)


Processing Batches:  19%|█▉        | 513/2706 [04:29<18:01,  2.03it/s]

Processed 513/2706 batches (19.0%)


Processing Batches:  20%|█▉        | 540/2706 [04:43<17:54,  2.02it/s]

Processed 540/2706 batches (20.0%)


Processing Batches:  21%|██        | 567/2706 [04:56<17:34,  2.03it/s]

Processed 567/2706 batches (21.0%)


Processing Batches:  22%|██▏       | 594/2706 [05:10<18:13,  1.93it/s]

Processed 594/2706 batches (22.0%)


Processing Batches:  23%|██▎       | 621/2706 [05:23<17:14,  2.02it/s]

Processed 621/2706 batches (23.0%)


Processing Batches:  24%|██▍       | 648/2706 [05:37<16:59,  2.02it/s]

Processed 648/2706 batches (24.0%)


Processing Batches:  25%|██▍       | 675/2706 [05:51<17:44,  1.91it/s]

Processed 675/2706 batches (25.0%)


Processing Batches:  26%|██▌       | 702/2706 [06:05<18:45,  1.78it/s]

Processed 702/2706 batches (26.0%)


Processing Batches:  27%|██▋       | 729/2706 [06:20<17:29,  1.88it/s]

Processed 729/2706 batches (26.9%)


Processing Batches:  28%|██▊       | 756/2706 [06:34<17:45,  1.83it/s]

Processed 756/2706 batches (27.9%)


Processing Batches:  29%|██▉       | 783/2706 [06:50<27:44,  1.16it/s]

Processed 783/2706 batches (28.9%)


Processing Batches:  30%|██▉       | 810/2706 [07:16<29:47,  1.06it/s]

Processed 810/2706 batches (29.9%)


Processing Batches:  31%|███       | 837/2706 [07:41<29:23,  1.06it/s]

Processed 837/2706 batches (30.9%)


Processing Batches:  32%|███▏      | 864/2706 [08:07<29:19,  1.05it/s]

Processed 864/2706 batches (31.9%)


Processing Batches:  33%|███▎      | 891/2706 [08:32<28:27,  1.06it/s]

Processed 891/2706 batches (32.9%)


Processing Batches:  34%|███▍      | 918/2706 [08:57<27:32,  1.08it/s]

Processed 918/2706 batches (33.9%)


Processing Batches:  35%|███▍      | 945/2706 [09:23<27:40,  1.06it/s]

Processed 945/2706 batches (34.9%)


Processing Batches:  36%|███▌      | 972/2706 [09:48<26:58,  1.07it/s]

Processed 972/2706 batches (35.9%)


Processing Batches:  37%|███▋      | 999/2706 [10:13<27:07,  1.05it/s]

Processed 999/2706 batches (36.9%)


Processing Batches:  38%|███▊      | 1026/2706 [10:38<25:19,  1.11it/s]

Processed 1026/2706 batches (37.9%)


Processing Batches:  39%|███▉      | 1053/2706 [11:03<25:34,  1.08it/s]

Processed 1053/2706 batches (38.9%)


Processing Batches:  40%|███▉      | 1080/2706 [11:28<24:40,  1.10it/s]

Processed 1080/2706 batches (39.9%)


Processing Batches:  41%|████      | 1107/2706 [11:53<24:15,  1.10it/s]

Processed 1107/2706 batches (40.9%)


Processing Batches:  42%|████▏     | 1134/2706 [12:17<23:26,  1.12it/s]

Processed 1134/2706 batches (41.9%)


Processing Batches:  43%|████▎     | 1161/2706 [12:42<23:07,  1.11it/s]

Processed 1161/2706 batches (42.9%)


Processing Batches:  44%|████▍     | 1188/2706 [13:07<24:02,  1.05it/s]

Processed 1188/2706 batches (43.9%)


Processing Batches:  45%|████▍     | 1215/2706 [13:25<14:43,  1.69it/s]

Processed 1215/2706 batches (44.9%)


Processing Batches:  46%|████▌     | 1242/2706 [13:40<13:33,  1.80it/s]

Processed 1242/2706 batches (45.9%)


Processing Batches:  47%|████▋     | 1269/2706 [13:55<13:05,  1.83it/s]

Processed 1269/2706 batches (46.9%)


Processing Batches:  48%|████▊     | 1296/2706 [14:10<13:05,  1.79it/s]

Processed 1296/2706 batches (47.9%)


Processing Batches:  49%|████▉     | 1323/2706 [14:25<13:46,  1.67it/s]

Processed 1323/2706 batches (48.9%)


Processing Batches:  50%|████▉     | 1350/2706 [14:40<12:23,  1.82it/s]

Processed 1350/2706 batches (49.9%)


Processing Batches:  51%|█████     | 1377/2706 [14:55<12:29,  1.77it/s]

Processed 1377/2706 batches (50.9%)


Processing Batches:  52%|█████▏    | 1404/2706 [15:11<11:41,  1.86it/s]

Processed 1404/2706 batches (51.9%)


Processing Batches:  53%|█████▎    | 1431/2706 [15:26<12:03,  1.76it/s]

Processed 1431/2706 batches (52.9%)


Processing Batches:  54%|█████▍    | 1458/2706 [15:41<11:11,  1.86it/s]

Processed 1458/2706 batches (53.9%)


Processing Batches:  55%|█████▍    | 1485/2706 [15:56<11:38,  1.75it/s]

Processed 1485/2706 batches (54.9%)


Processing Batches:  56%|█████▌    | 1512/2706 [16:12<11:32,  1.72it/s]

Processed 1512/2706 batches (55.9%)


Processing Batches:  57%|█████▋    | 1539/2706 [16:28<11:33,  1.68it/s]

Processed 1539/2706 batches (56.9%)


Processing Batches:  58%|█████▊    | 1566/2706 [16:44<10:23,  1.83it/s]

Processed 1566/2706 batches (57.9%)


Processing Batches:  59%|█████▉    | 1593/2706 [16:59<10:46,  1.72it/s]

Processed 1593/2706 batches (58.9%)


Processing Batches:  60%|█████▉    | 1620/2706 [17:16<14:31,  1.25it/s]

Processed 1620/2706 batches (59.9%)


Processing Batches:  61%|██████    | 1647/2706 [17:42<16:50,  1.05it/s]

Processed 1647/2706 batches (60.9%)


Processing Batches:  62%|██████▏   | 1674/2706 [18:08<17:23,  1.01s/it]

Processed 1674/2706 batches (61.9%)


Processing Batches:  63%|██████▎   | 1701/2706 [18:33<15:51,  1.06it/s]

Processed 1701/2706 batches (62.9%)


Processing Batches:  64%|██████▍   | 1728/2706 [18:48<08:49,  1.85it/s]

Processed 1728/2706 batches (63.9%)


Processing Batches:  65%|██████▍   | 1755/2706 [19:03<09:09,  1.73it/s]

Processed 1755/2706 batches (64.9%)


Processing Batches:  66%|██████▌   | 1782/2706 [19:19<08:33,  1.80it/s]

Processed 1782/2706 batches (65.9%)


Processing Batches:  67%|██████▋   | 1809/2706 [19:38<11:20,  1.32it/s]

Processed 1809/2706 batches (66.9%)


Processing Batches:  68%|██████▊   | 1836/2706 [19:55<07:46,  1.86it/s]

Processed 1836/2706 batches (67.9%)


Processing Batches:  69%|██████▉   | 1863/2706 [20:14<17:40,  1.26s/it]

Processed 1863/2706 batches (68.9%)


Processing Batches:  70%|██████▉   | 1890/2706 [20:49<17:23,  1.28s/it]

Processed 1890/2706 batches (69.9%)


Processing Batches:  71%|███████   | 1917/2706 [21:23<16:56,  1.29s/it]

Processed 1917/2706 batches (70.9%)


Processing Batches:  72%|███████▏  | 1944/2706 [21:57<15:42,  1.24s/it]

Processed 1944/2706 batches (71.9%)


Processing Batches:  73%|███████▎  | 1971/2706 [22:32<16:20,  1.33s/it]

Processed 1971/2706 batches (72.9%)


Processing Batches:  74%|███████▍  | 1998/2706 [23:07<15:30,  1.31s/it]

Processed 1998/2706 batches (73.9%)


Processing Batches:  75%|███████▍  | 2025/2706 [23:42<14:13,  1.25s/it]

Processed 2025/2706 batches (74.9%)


Processing Batches:  76%|███████▌  | 2052/2706 [24:17<14:00,  1.29s/it]

Processed 2052/2706 batches (75.9%)


Processing Batches:  77%|███████▋  | 2079/2706 [24:52<13:15,  1.27s/it]

Processed 2079/2706 batches (76.9%)


Processing Batches:  78%|███████▊  | 2106/2706 [25:27<12:36,  1.26s/it]

Processed 2106/2706 batches (77.9%)


Processing Batches:  79%|███████▉  | 2133/2706 [26:01<12:26,  1.30s/it]

Processed 2133/2706 batches (78.9%)


Processing Batches:  80%|███████▉  | 2160/2706 [26:37<12:38,  1.39s/it]

Processed 2160/2706 batches (79.8%)


Processing Batches:  81%|████████  | 2187/2706 [27:13<11:54,  1.38s/it]

Processed 2187/2706 batches (80.8%)


Processing Batches:  82%|████████▏ | 2214/2706 [27:48<10:41,  1.30s/it]

Processed 2214/2706 batches (81.8%)


Processing Batches:  83%|████████▎ | 2241/2706 [28:26<11:36,  1.50s/it]

Processed 2241/2706 batches (82.8%)


Processing Batches:  84%|████████▍ | 2268/2706 [29:04<10:36,  1.45s/it]

Processed 2268/2706 batches (83.8%)


Processing Batches:  85%|████████▍ | 2295/2706 [29:36<06:37,  1.03it/s]

Processed 2295/2706 batches (84.8%)


Processing Batches:  86%|████████▌ | 2322/2706 [30:02<06:07,  1.05it/s]

Processed 2322/2706 batches (85.8%)


Processing Batches:  87%|████████▋ | 2349/2706 [30:35<07:58,  1.34s/it]

Processed 2349/2706 batches (86.8%)


Processing Batches:  88%|████████▊ | 2376/2706 [31:11<07:42,  1.40s/it]

Processed 2376/2706 batches (87.8%)


Processing Batches:  89%|████████▉ | 2403/2706 [31:46<06:19,  1.25s/it]

Processed 2403/2706 batches (88.8%)


Processing Batches:  90%|████████▉ | 2430/2706 [32:15<03:57,  1.16it/s]

Processed 2430/2706 batches (89.8%)


Processing Batches:  91%|█████████ | 2457/2706 [32:38<03:24,  1.21it/s]

Processed 2457/2706 batches (90.8%)


Processing Batches:  92%|█████████▏| 2484/2706 [33:03<03:33,  1.04it/s]

Processed 2484/2706 batches (91.8%)


Processing Batches:  93%|█████████▎| 2511/2706 [33:23<01:22,  2.36it/s]

Processed 2511/2706 batches (92.8%)


Processing Batches:  94%|█████████▍| 2538/2706 [33:34<01:12,  2.33it/s]

Processed 2538/2706 batches (93.8%)


Processing Batches:  95%|█████████▍| 2565/2706 [33:46<00:57,  2.44it/s]

Processed 2565/2706 batches (94.8%)


Processing Batches:  96%|█████████▌| 2592/2706 [33:57<00:46,  2.45it/s]

Processed 2592/2706 batches (95.8%)


Processing Batches:  97%|█████████▋| 2619/2706 [34:11<00:48,  1.79it/s]

Processed 2619/2706 batches (96.8%)


Processing Batches:  98%|█████████▊| 2646/2706 [34:27<00:33,  1.77it/s]

Processed 2646/2706 batches (97.8%)


Processing Batches:  99%|█████████▉| 2673/2706 [34:42<00:18,  1.75it/s]

Processed 2673/2706 batches (98.8%)


Processing Batches: 100%|█████████▉| 2700/2706 [34:58<00:03,  1.73it/s]

Processed 2700/2706 batches (99.8%)


Processing Batches: 100%|██████████| 2706/2706 [35:01<00:00,  1.29it/s]


Total samples after feature extraction: 173128
Total samples: 173128, Number of chunks: 268, Chunk size: 646
Total samples used: 173128
Saved chunk 0: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 1: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 2: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 3: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 4: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 5: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 6: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 12, 32), y shape (646,)
Saved chunk 7: MFCC shape (646, 1, 32, 32), LFCC sha

# Model Training and Testing - Datasets aggregated

In [8]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, roc_curve
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import psutil

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

epochs = 20  # Keep epochs with early stopping

# Dataset Class
class AudioDataset(Dataset):
    def __init__(self, chunk_dir, chunk_indices, dataset_type="train"):
        self.chunk_dir = chunk_dir
        self.dataset_type = dataset_type
        
        # Verify that the directory exists
        if not os.path.exists(chunk_dir):
            raise FileNotFoundError(f"Chunk directory {chunk_dir} does not exist.")
        
        # Look for MFCC chunk files with the correct names
        self.chunk_files = sorted([f for f in os.listdir(chunk_dir) if f.startswith("mfcc_chunk_") and f.endswith(".npy")])
        
        # Check if any chunk files were found; debugging for previous errors
        if not self.chunk_files:
            raise FileNotFoundError(f"No MFCC chunk files found in {chunk_dir}. Expected files like 'mfcc_chunk_0.npy'.")
        
        print(f"Found {len(self.chunk_files)} chunk files in {chunk_dir}. First few: {self.chunk_files[:5]}")
        
        self.chunk_size = 646  # Fixed chunk size from preprocessing
        self.num_chunks = len(self.chunk_files)
        self.total_samples = self.num_chunks * self.chunk_size
        
        # Filter indices to ensure they are within bounds
        self.indices = [idx for idx in chunk_indices if idx < self.total_samples]
        
        # Check if indices are empty
        if not self.indices:
            raise ValueError(f"No valid indices for {self.dataset_type} dataset. "
                            f"Total samples: {self.total_samples}, but chunk_indices range is {min(chunk_indices)} to {max(chunk_indices)}.")
        
        print(f"Dataset ({self.dataset_type}) initialized with {len(self.indices)} samples. "
              f"Index range: {self.indices[0]} to {self.indices[-1]}")
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        if idx >= len(self.indices):
            raise IndexError(f"Index {idx} is out of bounds for dataset with length {len(self.indices)}")
        
        # Map global index to chunk and local index
        global_idx = self.indices[idx]
        chunk_idx = global_idx // self.chunk_size
        local_idx = global_idx % self.chunk_size
        
        # Load the chunks with the correct file names
        mfcc_chunk = np.load(os.path.join(self.chunk_dir, f"mfcc_chunk_{chunk_idx}.npy"), mmap_mode='r')
        lfcc_chunk = np.load(os.path.join(self.chunk_dir, f"lfcc_chunk_{chunk_idx}.npy"), mmap_mode='r')
        chroma_chunk = np.load(os.path.join(self.chunk_dir, f"chroma_chunk_{chunk_idx}.npy"), mmap_mode='r')
        y_chunk = np.load(os.path.join(self.chunk_dir, f"y_chunk_{chunk_idx}.npy"), mmap_mode='r')
        
        # Extract the sample
        mfcc = torch.FloatTensor(mfcc_chunk[local_idx])  # Shape: (1, 32, 32)
        lfcc = torch.FloatTensor(lfcc_chunk[local_idx])  # Shape: (1, 32, 32)
        chroma = torch.FloatTensor(chroma_chunk[local_idx]).unsqueeze(0)  # Shape: (12, 32) -> (1, 12, 32)
        y = torch.LongTensor([y_chunk[local_idx]])[0]  # Single label (0 or 1)
        
        # Normalize features (zero mean, unit variance)
        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)
        lfcc = (lfcc - lfcc.mean()) / (lfcc.std() + 1e-8)
        chroma = (chroma - chroma.mean()) / (chroma.std() + 1e-8)
        
        # Add noise to features during training for data augmentation
        if self.dataset_type == "train":
            noise_factor = 0.1
            mfcc += torch.randn_like(mfcc) * noise_factor
            lfcc += torch.randn_like(lfcc) * noise_factor
            chroma += torch.randn_like(chroma) * noise_factor
        
        return mfcc, lfcc, chroma, y

# MFAAN Model with 1D Convolutions
class MFAAN(nn.Module):
    def __init__(self):
        super(MFAAN, self).__init__()
        # MFCC Path: Input shape (batch, 32, 32)
        self.path_mfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),  # Reduced to 16
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),  # (batch, 16, 16)
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # (batch, 32, 16)
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)  # (batch, 32, 8)
        )
        
        # LFCC Path: Input shape (batch, 32, 32)
        self.path_lfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),  # Reduced to 16
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),  # (batch, 16, 16)
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # (batch, 32, 16)
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)  # (batch, 32, 8)
        )
        
        # Chroma Path: Input shape (batch, 12, 32)
        self.path_chroma = nn.Sequential(
            nn.Conv1d(in_channels=12, out_channels=16, kernel_size=3, stride=1, padding=1),  # Reduced to 16
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),  # (batch, 16, 16)
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # (batch, 32, 16)
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)  # (batch, 32, 8)
        )
        
        # Fully connected layers after fusion
        # Each path outputs (batch, 32, 8), so after flattening and concatenation: 32 * 8 * 3 = 768
        self.fc1 = nn.Linear(32 * 8 * 3, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 2)  # Binary classification (real vs fake)
    
    def forward(self, mfcc, lfcc, chroma):
        mfcc = mfcc.squeeze(1)  # (batch, 1, 32, 32) -> (batch, 32, 32)
        lfcc = lfcc.squeeze(1)  # (batch, 1, 32, 32) -> (batch, 32, 32)
        chroma = chroma.squeeze(1)  # (batch, 1, 12, 32) -> (batch, 12, 32)
        
        # Process each feature through their path
        mfcc_out = self.path_mfcc(mfcc)  # (batch, 32, 8)
        lfcc_out = self.path_lfcc(lfcc)  # (batch, 32, 8)
        chroma_out = self.path_chroma(chroma)  # (batch, 32, 8)
        
        # Flatten and concatenate - finally
        mfcc_flat = mfcc_out.view(mfcc_out.size(0), -1)  # (batch, 32 * 8)
        lfcc_flat = lfcc_out.view(lfcc_out.size(0), -1)  # (batch, 32 * 8)
        chroma_flat = chroma_out.view(chroma_out.size(0), -1)  # (batch, 32 * 8)
        fused = torch.cat((mfcc_flat, lfcc_flat, chroma_flat), dim=1)  # (batch, 32 * 8 * 3)
        
        # Pass through fully connected layers
        x = F.relu(self.fc1(fused))
        x = self.dropout(x)
        output = self.fc2(x)
        return output

# Function to compute Equal Error Rate (EER)
def compute_eer(labels, scores):
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    eer_threshold = thresholds[np.nanargmin(np.absolute(fnr - fpr))]
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    return eer

chunk_dir = "/kaggle/working/preprocessed_chunks_268"
total_samples = 173128

# For stratified split
all_labels = []
for chunk_idx in range(268):
    y_chunk = np.load(os.path.join(chunk_dir, f"y_chunk_{chunk_idx}.npy"))
    all_labels.extend(y_chunk)
all_labels = np.array(all_labels)

indices = list(range(len(all_labels)))

# Stratified split: train (80%), val (10%), test (10%)
train_indices, temp_indices, train_labels, temp_labels = train_test_split(
    indices, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)
val_indices, test_indices, val_labels, test_labels = train_test_split(
    temp_indices, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
)

print(f"Train samples: {len(train_indices)}, Val samples: {len(val_indices)}, Test samples: {len(test_indices)}")
print("Train class distribution:", np.bincount(train_labels))
print("Val class distribution:", np.bincount(val_labels))
print("Test class distribution:", np.bincount(test_labels))

train_dataset = AudioDataset(chunk_dir, train_indices, dataset_type="train")
val_dataset = AudioDataset(chunk_dir, val_indices, dataset_type="val")
test_dataset = AudioDataset(chunk_dir, test_indices, dataset_type="test")

# Debug - print dataset lengths
print(f"Train dataset length: {len(train_dataset)}")
print(f"Val dataset length: {len(val_dataset)}")
print(f"Test dataset length: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=0)

# Check for class imbalance and apply class weights
class_counts = np.bincount(all_labels)
class_weights = 1.0 / class_counts
class_weights = torch.FloatTensor(class_weights / class_weights.sum()).to(device)
print(f"Class weights: {class_weights}")

# Model, Optimizer, and Loss
model = MFAAN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-3)
criterion = nn.CrossEntropyLoss(weight=class_weights)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1, verbose=True)

# Training loop - early stopping based on eer
def train_model(epochs=epochs, patience=5):
    best_val_eer = float('inf')
    patience_counter = 0
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch_idx, (mfcc, lfcc, chroma, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(mfcc, lfcc, chroma)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            running_loss += loss.item()
        
        # Validation
        model.eval()
        val_preds = []
        val_scores = []
        val_labels = []
        with torch.no_grad():
            for mfcc, lfcc, chroma, labels in val_loader:
                mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
                outputs = model(mfcc, lfcc, chroma)
                probs = F.softmax(outputs, dim=1)[:, 1]  # Probability of class 1 (fake)
                _, predicted = torch.max(outputs, 1)
                val_preds.extend(predicted.cpu().numpy())
                val_scores.extend(probs.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
        
        val_acc = accuracy_score(val_labels, val_preds)
        val_eer = compute_eer(val_labels, val_scores)
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, "
              f"Validation Accuracy: {val_acc * 100:.2f}%, Validation EER: {val_eer * 100:.2f}%")
        
        # Test after first epoch - see how it is
        if epoch == 0:
            test_preds = []
            test_scores = []
            test_labels = []
            with torch.no_grad():
                for batch_idx, (mfcc, lfcc, chroma, labels) in enumerate(test_loader):
                    mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
                    outputs = model(mfcc, lfcc, chroma)
                    probs = F.softmax(outputs, dim=1)[:, 1]
                    _, predicted = torch.max(outputs, 1)
                    test_preds.extend(predicted.cpu().numpy())
                    test_scores.extend(probs.cpu().numpy())
                    test_labels.extend(labels.cpu().numpy())
            test_acc = accuracy_score(test_labels, test_preds)
            test_eer = compute_eer(test_labels, test_scores)
            print(f"Test Accuracy after Epoch 1: {test_acc * 100:.2f}%, Test EER: {test_eer * 100:.2f}%")
        
        # Early Stopping based on eer
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            patience_counter = 0
            torch.save(model.state_dict(), "/kaggle/working/mfaan_best.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered!")
                break
        
        scheduler.step(val_eer)

train_model(epochs=epochs, patience=5)

process = psutil.Process()
mem_info = process.memory_info()
print(f"Memory usage after training: {mem_info.rss / 1024**2:.2f} MB")

# Final Test Evaluation with Per-Class Accuracy
model.load_state_dict(torch.load("/kaggle/working/mfaan_best.pth"))
model.eval()
test_preds = []
test_scores = []
test_labels = []
with torch.no_grad():
    for mfcc, lfcc, chroma, labels in test_loader:
        mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
        outputs = model(mfcc, lfcc, chroma)
        probs = F.softmax(outputs, dim=1)[:, 1]
        _, predicted = torch.max(outputs, 1)
        test_preds.extend(predicted.cpu().numpy())
        test_scores.extend(probs.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(test_labels, test_preds)
test_eer = compute_eer(test_labels, test_scores)

# Compute per-class accuracy - fake and real
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(test_labels, test_preds)
per_class_acc = cm.diagonal() / cm.sum(axis=1)
print(f"\nFinal Test Accuracy: {test_acc * 100:.2f}%, Final Test EER: {test_eer * 100:.2f}%")
print(f"Per-class accuracy (real, fake): {per_class_acc * 100}")

Using device: cuda
Train samples: 138502, Val samples: 17313, Test samples: 17313
Train class distribution: [83775 54727]
Val class distribution: [10472  6841]
Test class distribution: [10472  6841]
Found 268 chunk files in /kaggle/working/preprocessed_chunks_268. First few: ['mfcc_chunk_0.npy', 'mfcc_chunk_1.npy', 'mfcc_chunk_10.npy', 'mfcc_chunk_100.npy', 'mfcc_chunk_101.npy']
Dataset (train) initialized with 138502 samples. Index range: 125083 to 68149
Found 268 chunk files in /kaggle/working/preprocessed_chunks_268. First few: ['mfcc_chunk_0.npy', 'mfcc_chunk_1.npy', 'mfcc_chunk_10.npy', 'mfcc_chunk_100.npy', 'mfcc_chunk_101.npy']
Dataset (val) initialized with 17313 samples. Index range: 112480 to 130832
Found 268 chunk files in /kaggle/working/preprocessed_chunks_268. First few: ['mfcc_chunk_0.npy', 'mfcc_chunk_1.npy', 'mfcc_chunk_10.npy', 'mfcc_chunk_100.npy', 'mfcc_chunk_101.npy']
Dataset (test) initialized with 17313 samples. Index range: 168466 to 60314
Train dataset length: 

Epoch 1/20: 100%|██████████| 4329/4329 [02:52<00:00, 25.17it/s]


Epoch 1/20, Loss: 0.2952, Validation Accuracy: 94.73%, Validation EER: 4.22%
Test Accuracy after Epoch 1: 94.70%, Test EER: 4.27%


Epoch 2/20: 100%|██████████| 4329/4329 [02:57<00:00, 24.44it/s]


Epoch 2/20, Loss: 0.1641, Validation Accuracy: 96.52%, Validation EER: 3.28%


Epoch 3/20: 100%|██████████| 4329/4329 [02:53<00:00, 24.89it/s]


Epoch 3/20, Loss: 0.1368, Validation Accuracy: 97.05%, Validation EER: 2.44%


Epoch 4/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.82it/s]


Epoch 4/20, Loss: 0.1264, Validation Accuracy: 97.23%, Validation EER: 1.87%


Epoch 5/20: 100%|██████████| 4329/4329 [02:55<00:00, 24.70it/s]


Epoch 5/20, Loss: 0.1215, Validation Accuracy: 98.09%, Validation EER: 1.87%


Epoch 6/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.79it/s]


Epoch 6/20, Loss: 0.1158, Validation Accuracy: 98.30%, Validation EER: 1.75%


Epoch 7/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.79it/s]


Epoch 7/20, Loss: 0.1104, Validation Accuracy: 98.23%, Validation EER: 1.77%


Epoch 8/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.87it/s]


Epoch 8/20, Loss: 0.1062, Validation Accuracy: 98.46%, Validation EER: 1.55%


Epoch 9/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.83it/s]


Epoch 9/20, Loss: 0.1023, Validation Accuracy: 98.32%, Validation EER: 1.70%


Epoch 10/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.78it/s]


Epoch 10/20, Loss: 0.1039, Validation Accuracy: 98.24%, Validation EER: 1.63%


Epoch 11/20: 100%|██████████| 4329/4329 [02:53<00:00, 24.91it/s]


Epoch 11/20, Loss: 0.0905, Validation Accuracy: 98.63%, Validation EER: 1.37%


Epoch 12/20: 100%|██████████| 4329/4329 [02:55<00:00, 24.66it/s]


Epoch 12/20, Loss: 0.0877, Validation Accuracy: 98.82%, Validation EER: 1.27%


Epoch 13/20: 100%|██████████| 4329/4329 [02:55<00:00, 24.65it/s]


Epoch 13/20, Loss: 0.0882, Validation Accuracy: 98.78%, Validation EER: 1.23%


Epoch 14/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.84it/s]


Epoch 14/20, Loss: 0.0885, Validation Accuracy: 98.90%, Validation EER: 1.19%


Epoch 15/20: 100%|██████████| 4329/4329 [02:55<00:00, 24.68it/s]


Epoch 15/20, Loss: 0.0901, Validation Accuracy: 98.53%, Validation EER: 1.33%


Epoch 16/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.88it/s]


Epoch 16/20, Loss: 0.0891, Validation Accuracy: 98.67%, Validation EER: 1.28%


Epoch 17/20: 100%|██████████| 4329/4329 [02:54<00:00, 24.85it/s]


Epoch 17/20, Loss: 0.0887, Validation Accuracy: 98.80%, Validation EER: 1.21%


Epoch 18/20: 100%|██████████| 4329/4329 [02:53<00:00, 24.93it/s]


Epoch 18/20, Loss: 0.0883, Validation Accuracy: 98.86%, Validation EER: 1.17%


Epoch 19/20: 100%|██████████| 4329/4329 [02:52<00:00, 25.04it/s]


Epoch 19/20, Loss: 0.0898, Validation Accuracy: 98.73%, Validation EER: 1.27%


Epoch 20/20: 100%|██████████| 4329/4329 [02:53<00:00, 24.89it/s]


Epoch 20/20, Loss: 0.0883, Validation Accuracy: 98.79%, Validation EER: 1.22%
Memory usage after training: 4603.22 MB

Final Test Accuracy: 98.99%, Final Test EER: 1.04%
Per-class accuracy (real, fake): [99.26470588 98.5674609 ]


In [None]:
import numpy as np

#to check for calss imbalance - it exists - need to use stratified
train_labels = [train_dataset[i][-1].item() for i in range(len(train_dataset))]
val_labels = [val_dataset[i][-1].item() for i in range(len(val_dataset))]
test_labels = [test_dataset[i][-1].item() for i in range(len(test_dataset))]
print("Train class distribution:", np.bincount(train_labels))
print("Val class distribution:", np.bincount(val_labels))
print("Test class distribution:", np.bincount(test_labels))

# Multi-Lingual Audio Deepfake Detection Corpus (MLADDC):
Dataset with halftruths. 

Initially, testing FoR & InTheWild trained model on T2 - international languages - dataset.

In [9]:
import os
import numpy as np
import torch
import torchaudio
from torchaudio.transforms import MFCC, LFCC, Spectrogram
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# T2 dataset path
T2_PATH = "/kaggle/input/mladdc-t2/T2/"

# Function to validate if a .wav file can be loaded
def validate_wav_file(file_path):
    try:
        waveform, sr = torchaudio.load(file_path)
        if waveform.size(1) == 0:
            print(f"Skipping empty file: {file_path}")
            return False
        return True
    except Exception as e:
        print(f"Skipping invalid file: {file_path}, Error: {e}")
        return False

# Function to collect T2 test audio files and labels
def get_t2_audio_files_and_labels(data_split="test"):
    audio_files = []
    labels = []
    counts = {"real": 0, "fake": 0}

    real_path = os.path.join(T2_PATH, data_split, "real")
    fake_path = os.path.join(T2_PATH, data_split, "deepfake")

    if not os.path.exists(real_path):
        raise FileNotFoundError(f"Real directory not found at {real_path}")
    if not os.path.exists(fake_path):
        raise FileNotFoundError(f"Fake directory not found at {fake_path}")

    for file in os.listdir(real_path):
        if file.endswith(".wav"):
            file_path = os.path.join(real_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(0)  # Real
                counts["real"] += 1

    for file in os.listdir(fake_path):
        if file.endswith(".wav"):
            file_path = os.path.join(fake_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(1)  # Fake
                counts["fake"] += 1

    print(f"T2 {data_split} Counts: Real: {counts['real']}, Fake: {counts['fake']}")
    return audio_files, labels

# Custom Chroma-STFT implementation
def compute_chroma_stft(waveforms, sample_rate=16000, n_fft=2048, hop_length=512, n_chroma=12):
    spectrogram_transform = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2.0).to(device)
    spec = spectrogram_transform(waveforms)
    spec = spec.squeeze(1)
    freqs = torch.linspace(0, sample_rate / 2, steps=spec.shape[1]).to(device)
    chroma_freqs = torch.tensor([31.25 * (2 ** (i / 12)) for i in range(12 * 4)], device=device)
    chroma_bins = torch.zeros((n_chroma, spec.shape[1]), device=device)
    for i in range(n_chroma):
        center = chroma_freqs[i::12]
        for cf in center:
            mask = (freqs >= cf / 1.06) & (freqs <= cf * 1.06)
            chroma_bins[i] += mask.float()
    chroma_bins /= chroma_bins.sum(dim=1, keepdim=True).clamp(min=1e-10)
    chroma = torch.einsum('cf,bft->bct', chroma_bins, spec)
    return chroma

# Feature Extraction for T2 test set
def extract_t2_features_batch(file_paths, labels, batch_size=64, max_length=16000):
    mfcc_results = []
    lfcc_results = []
    chroma_results = []
    valid_labels = []

    # Use 32 coefficients for MFCC and LFCC
    mfcc_transform = MFCC(
        sample_rate=16000,
        n_mfcc=32,
        melkwargs={"n_fft": 2048, "hop_length": 512, "n_mels": 128}
    ).to(device)
    lfcc_transform = LFCC(
        sample_rate=16000,
        n_lfcc=32,
        f_min=0,
        f_max=8000,
        n_filter=128,
        speckwargs={"n_fft": 2048, "hop_length": 512}
    ).to(device)

    total_batches = (len(file_paths) + batch_size - 1) // batch_size
    log_interval = max(1, total_batches // 100)

    for i in tqdm(range(0, len(file_paths), batch_size), desc="Processing T2 Test Batches", total=total_batches):
        if i % (log_interval * batch_size) == 0:
            print(f"Processed {i // batch_size}/{total_batches} batches ({(i / len(file_paths)) * 100:.1f}%)")

        batch_files = file_paths[i:i + batch_size]
        batch_labels = labels[i:i + batch_size]
        waveforms = []
        valid_indices = []

        for idx, (file_path, label) in enumerate(zip(batch_files, batch_labels)):
            try:
                waveform, sr = torchaudio.load(file_path)
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                if sr != 16000:
                    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
                if waveform.size(1) > max_length:
                    waveform = waveform[:, :max_length]
                elif waveform.size(1) < max_length:
                    pad_size = max_length - waveform.size(1)
                    waveform = torch.nn.functional.pad(waveform, (0, pad_size))
                waveforms.append(waveform)
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        if not waveforms:
            continue

        waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True).to(device)
        mfccs = mfcc_transform(waveforms)  # Expected (batch, 32, time)
        lfccs = lfcc_transform(waveforms)  # Expected (batch, 32, time)
        chromas = compute_chroma_stft(waveforms, sample_rate=16000, n_fft=2048, hop_length=512, n_chroma=12)  # Expected (batch, 12, time)

        target_frames = 32
        for idx in valid_indices:
            # Process MFCC
            mfcc = mfccs[idx]  # Expected (32, time)
            if len(mfcc.shape) == 3:  # If (1, 32, time), remove extra dim
                mfcc = mfcc.squeeze(0)
            if mfcc.shape[1] > target_frames:
                mfcc = mfcc[:, :target_frames]
            elif mfcc.shape[1] < target_frames:
                mfcc = torch.nn.functional.pad(mfcc, (0, target_frames - mfcc.shape[1]))
            if mfcc.shape != (32, 32):
                print(f"Warning: MFCC shape after resize at index {idx}: {mfcc.shape}")
                continue
            mfcc = mfcc.unsqueeze(0)  # (32, 32) -> (1, 32, 32)
            if mfcc.shape != (1, 32, 32):
                print(f"Warning: Final MFCC shape mismatch at index {idx}: {mfcc.shape}")
                continue
            mfcc_results.append(mfcc.cpu().numpy())

            # Process LFCC
            lfcc = lfccs[idx]
            if len(lfcc.shape) == 3:
                lfcc = lfcc.squeeze(0)
            if lfcc.shape[1] > target_frames:
                lfcc = lfcc[:, :target_frames]
            elif lfcc.shape[1] < target_frames:
                lfcc = torch.nn.functional.pad(lfcc, (0, target_frames - lfcc.shape[1]))
            if lfcc.shape != (32, 32):
                print(f"Warning: LFCC shape after resize at index {idx}: {lfcc.shape}")
                continue
            lfcc = lfcc.unsqueeze(0)  # (32, 32) -> (1, 32, 32)
            if lfcc.shape != (1, 32, 32):
                print(f"Warning: Final LFCC shape mismatch at index {idx}: {lfcc.shape}")
                continue
            lfcc_results.append(lfcc.cpu().numpy())

            # Process Chroma
            chroma = chromas[idx]
            if len(chroma.shape) == 3:
                chroma = chroma.squeeze(0)
            if chroma.shape[1] > target_frames:
                chroma = chroma[:, :target_frames]
            elif chroma.shape[1] < target_frames:
                chroma = torch.nn.functional.pad(chroma, (0, target_frames - chroma.shape[1]))
            if chroma.shape != (12, 32):
                print(f"Warning: Chroma shape after resize at index {idx}: {chroma.shape}")
                continue
            chroma = chroma.unsqueeze(0)  # (12, 32) -> (1, 12, 32)
            if chroma.shape != (1, 12, 32):
                print(f"Warning: Final Chroma shape mismatch at index {idx}: {chroma.shape}")
                continue
            chroma_results.append(chroma.cpu().numpy())

            valid_labels.append(batch_labels[idx])

    return mfcc_results, lfcc_results, chroma_results, valid_labels

# Preprocess T2 test files
t2_test_files, t2_test_labels = get_t2_audio_files_and_labels("test")
print(f"\nTotal T2 test files before processing: {len(t2_test_files)}")
print(f"Label Distribution: Real (0): {sum(1 for label in t2_test_labels if label == 0)} "
      f"({sum(1 for label in t2_test_labels if label == 0)/len(t2_test_files)*100:.2f}%), "
      f"Fake (1): {sum(1 for label in t2_test_labels if label == 1)} "
      f"({sum(1 for label in t2_test_labels if label == 1)/len(t2_test_files)*100:.2f}%)")

# Extract features
t2_test_mfcc_features, t2_test_lfcc_features, t2_test_chroma_features, t2_test_y = extract_t2_features_batch(
    t2_test_files, t2_test_labels, batch_size=64, max_length=16000
)

# Convert to numpy arrays
t2_test_mfcc_features = np.array(t2_test_mfcc_features)  # (num_samples, 1, 32, 32)
t2_test_lfcc_features = np.array(t2_test_lfcc_features)  # (num_samples, 1, 32, 32)
t2_test_chroma_features = np.array(t2_test_chroma_features)  # (num_samples, 1, 12, 32)
t2_test_y = np.array(t2_test_y)  # (num_samples,)

# Debugging - validate shapes 
print(f"MFCC features shape: {t2_test_mfcc_features.shape}")
print(f"LFCC features shape: {t2_test_lfcc_features.shape}")
print(f"Chroma features shape: {t2_test_chroma_features.shape}")
print(f"Labels shape: {t2_test_y.shape}")

# Validate alignment
if not (len(t2_test_mfcc_features) == len(t2_test_lfcc_features) == len(t2_test_chroma_features) == len(t2_test_y)):
    raise ValueError(f"Mismatch between T2 test features and labels: "
                     f"MFCC: {len(t2_test_mfcc_features)}, LFCC: {len(t2_test_lfcc_features)}, "
                     f"Chroma: {len(t2_test_chroma_features)}, y: {len(t2_test_y)}")

# Verify total samples
total_t2_test_samples = len(t2_test_y)
print(f"Total T2 test samples after feature extraction: {total_t2_test_samples}")

# Save T2 test chunks
output_dir_t2_test = "/kaggle/working/preprocessed_t2_test_chunks"
os.makedirs(output_dir_t2_test, exist_ok=True)

# Clear existing T2 test chunks
for f in os.listdir(output_dir_t2_test):
    os.remove(os.path.join(output_dir_t2_test, f))

# Use chunk size of 646 to match training data
chunk_size_t2_test = 646
num_chunks_t2_test = (total_t2_test_samples + chunk_size_t2_test - 1) // chunk_size_t2_test
total_t2_test_samples_used = num_chunks_t2_test * chunk_size_t2_test

print(f"Total T2 test samples: {total_t2_test_samples}, Number of T2 test chunks: {num_chunks_t2_test}, "
      f"T2 Test Chunk size: {chunk_size_t2_test}")
if total_t2_test_samples != total_t2_test_samples_used:
    print(f"Warning: {total_t2_test_samples_used - total_t2_test_samples} samples will be padded due to chunking.")

# Pad features and labels if needed
if total_t2_test_samples < total_t2_test_samples_used:
    pad_size = total_t2_test_samples_used - total_t2_test_samples
    mfcc_pad = np.zeros((pad_size, 1, 32, 32), dtype=np.float32)
    lfcc_pad = np.zeros((pad_size, 1, 32, 32), dtype=np.float32)
    chroma_pad = np.zeros((pad_size, 1, 12, 32), dtype=np.float32)
    y_pad = np.zeros(pad_size, dtype=np.int64)  # Pad with real labels (0)

    t2_test_mfcc_features = np.concatenate([t2_test_mfcc_features, mfcc_pad], axis=0)
    t2_test_lfcc_features = np.concatenate([t2_test_lfcc_features, lfcc_pad], axis=0)
    t2_test_chroma_features = np.concatenate([t2_test_chroma_features, chroma_pad], axis=0)
    t2_test_y = np.concatenate([t2_test_y, y_pad], axis=0)

# Validate padded shapes
print(f"Padded MFCC features shape: {t2_test_mfcc_features.shape}")
print(f"Padded LFCC features shape: {t2_test_lfcc_features.shape}")
print(f"Padded Chroma features shape: {t2_test_chroma_features.shape}")
print(f"Padded Labels shape: {t2_test_y.shape}")

# Save chunks
for i in range(0, total_t2_test_samples_used, chunk_size_t2_test):
    chunk_idx = i // chunk_size_t2_test
    t2_test_mfcc_chunk = t2_test_mfcc_features[i:i + chunk_size_t2_test]
    t2_test_lfcc_chunk = t2_test_lfcc_features[i:i + chunk_size_t2_test]
    t2_test_chroma_chunk = t2_test_chroma_features[i:i + chunk_size_t2_test]
    t2_test_y_chunk = t2_test_y[i:i + chunk_size_t2_test]

    if not (len(t2_test_mfcc_chunk) == len(t2_test_lfcc_chunk) == len(t2_test_chroma_chunk) == len(t2_test_y_chunk)):
        raise ValueError(f"Mismatch in T2 test chunk {chunk_idx}: "
                         f"MFCC: {len(t2_test_mfcc_chunk)}, LFCC: {len(t2_test_lfcc_chunk)}, "
                         f"Chroma: {len(t2_test_chroma_chunk)}, y: {len(t2_test_y_chunk)}")

    np.save(os.path.join(output_dir_t2_test, f"t2_test_mfcc_chunk_{chunk_idx}.npy"), t2_test_mfcc_chunk)
    np.save(os.path.join(output_dir_t2_test, f"t2_test_lfcc_chunk_{chunk_idx}.npy"), t2_test_lfcc_chunk)
    np.save(os.path.join(output_dir_t2_test, f"t2_test_chroma_chunk_{chunk_idx}.npy"), t2_test_chroma_chunk)
    np.save(os.path.join(output_dir_t2_test, f"t2_test_y_chunk_{chunk_idx}.npy"), t2_test_y_chunk)
    print(f"Saved T2 test chunk {chunk_idx}: "
          f"MFCC shape {t2_test_mfcc_chunk.shape}, LFCC shape {t2_test_lfcc_chunk.shape}, "
          f"Chroma shape {t2_test_chroma_chunk.shape}, y shape {t2_test_y_chunk.shape}")

print(f"Processed and saved {total_t2_test_samples_used} T2 test audio samples in {output_dir_t2_test}!")

Using device: cuda, GPU: Tesla T4
T2 test Counts: Real: 5600, Fake: 11200

Total T2 test files before processing: 16800
Label Distribution: Real (0): 5600 (33.33%), Fake (1): 11200 (66.67%)


Processing T2 Test Batches:   0%|          | 0/263 [00:00<?, ?it/s]

Processed 0/263 batches (0.0%)


Processing T2 Test Batches:   1%|          | 2/263 [00:01<03:24,  1.27it/s]

Processed 2/263 batches (0.8%)


Processing T2 Test Batches:   2%|▏         | 4/263 [00:03<03:22,  1.28it/s]

Processed 4/263 batches (1.5%)


Processing T2 Test Batches:   2%|▏         | 6/263 [00:04<03:25,  1.25it/s]

Processed 6/263 batches (2.3%)


Processing T2 Test Batches:   3%|▎         | 8/263 [00:06<03:22,  1.26it/s]

Processed 8/263 batches (3.0%)


Processing T2 Test Batches:   4%|▍         | 10/263 [00:07<03:24,  1.23it/s]

Processed 10/263 batches (3.8%)


Processing T2 Test Batches:   5%|▍         | 12/263 [00:09<03:17,  1.27it/s]

Processed 12/263 batches (4.6%)


Processing T2 Test Batches:   5%|▌         | 14/263 [00:11<03:19,  1.25it/s]

Processed 14/263 batches (5.3%)


Processing T2 Test Batches:   6%|▌         | 16/263 [00:12<03:15,  1.27it/s]

Processed 16/263 batches (6.1%)


Processing T2 Test Batches:   7%|▋         | 18/263 [00:14<03:11,  1.28it/s]

Processed 18/263 batches (6.9%)


Processing T2 Test Batches:   8%|▊         | 20/263 [00:15<03:09,  1.29it/s]

Processed 20/263 batches (7.6%)


Processing T2 Test Batches:   8%|▊         | 22/263 [00:17<03:07,  1.28it/s]

Processed 22/263 batches (8.4%)


Processing T2 Test Batches:   9%|▉         | 24/263 [00:18<03:05,  1.29it/s]

Processed 24/263 batches (9.1%)


Processing T2 Test Batches:  10%|▉         | 26/263 [00:20<03:04,  1.28it/s]

Processed 26/263 batches (9.9%)


Processing T2 Test Batches:  11%|█         | 28/263 [00:22<03:00,  1.30it/s]

Processed 28/263 batches (10.7%)


Processing T2 Test Batches:  11%|█▏        | 30/263 [00:23<03:01,  1.29it/s]

Processed 30/263 batches (11.4%)


Processing T2 Test Batches:  12%|█▏        | 32/263 [00:25<03:01,  1.28it/s]

Processed 32/263 batches (12.2%)


Processing T2 Test Batches:  13%|█▎        | 34/263 [00:26<02:58,  1.28it/s]

Processed 34/263 batches (13.0%)


Processing T2 Test Batches:  14%|█▎        | 36/263 [00:28<02:56,  1.29it/s]

Processed 36/263 batches (13.7%)


Processing T2 Test Batches:  14%|█▍        | 38/263 [00:29<02:56,  1.28it/s]

Processed 38/263 batches (14.5%)


Processing T2 Test Batches:  15%|█▌        | 40/263 [00:31<02:55,  1.27it/s]

Processed 40/263 batches (15.2%)


Processing T2 Test Batches:  16%|█▌        | 42/263 [00:32<02:51,  1.29it/s]

Processed 42/263 batches (16.0%)


Processing T2 Test Batches:  17%|█▋        | 44/263 [00:34<02:48,  1.30it/s]

Processed 44/263 batches (16.8%)


Processing T2 Test Batches:  17%|█▋        | 46/263 [00:36<02:50,  1.27it/s]

Processed 46/263 batches (17.5%)


Processing T2 Test Batches:  18%|█▊        | 48/263 [00:37<02:48,  1.28it/s]

Processed 48/263 batches (18.3%)


Processing T2 Test Batches:  19%|█▉        | 50/263 [00:39<02:47,  1.27it/s]

Processed 50/263 batches (19.0%)


Processing T2 Test Batches:  20%|█▉        | 52/263 [00:40<02:48,  1.25it/s]

Processed 52/263 batches (19.8%)


Processing T2 Test Batches:  21%|██        | 54/263 [00:42<02:45,  1.26it/s]

Processed 54/263 batches (20.6%)


Processing T2 Test Batches:  21%|██▏       | 56/263 [00:43<02:40,  1.29it/s]

Processed 56/263 batches (21.3%)


Processing T2 Test Batches:  22%|██▏       | 58/263 [00:45<02:40,  1.27it/s]

Processed 58/263 batches (22.1%)


Processing T2 Test Batches:  23%|██▎       | 60/263 [00:47<02:35,  1.30it/s]

Processed 60/263 batches (22.9%)


Processing T2 Test Batches:  24%|██▎       | 62/263 [00:48<02:33,  1.31it/s]

Processed 62/263 batches (23.6%)


Processing T2 Test Batches:  24%|██▍       | 64/263 [00:50<02:32,  1.30it/s]

Processed 64/263 batches (24.4%)


Processing T2 Test Batches:  25%|██▌       | 66/263 [00:51<02:30,  1.31it/s]

Processed 66/263 batches (25.1%)


Processing T2 Test Batches:  26%|██▌       | 68/263 [00:53<02:28,  1.31it/s]

Processed 68/263 batches (25.9%)


Processing T2 Test Batches:  27%|██▋       | 70/263 [00:54<02:28,  1.30it/s]

Processed 70/263 batches (26.7%)


Processing T2 Test Batches:  27%|██▋       | 72/263 [00:56<02:26,  1.31it/s]

Processed 72/263 batches (27.4%)


Processing T2 Test Batches:  28%|██▊       | 74/263 [00:57<02:26,  1.29it/s]

Processed 74/263 batches (28.2%)


Processing T2 Test Batches:  29%|██▉       | 76/263 [00:59<02:25,  1.28it/s]

Processed 76/263 batches (29.0%)


Processing T2 Test Batches:  30%|██▉       | 78/263 [01:00<02:26,  1.27it/s]

Processed 78/263 batches (29.7%)


Processing T2 Test Batches:  30%|███       | 80/263 [01:02<02:25,  1.26it/s]

Processed 80/263 batches (30.5%)


Processing T2 Test Batches:  31%|███       | 82/263 [01:04<02:22,  1.27it/s]

Processed 82/263 batches (31.2%)


Processing T2 Test Batches:  32%|███▏      | 84/263 [01:05<02:23,  1.25it/s]

Processed 84/263 batches (32.0%)


Processing T2 Test Batches:  33%|███▎      | 86/263 [01:07<02:19,  1.27it/s]

Processed 86/263 batches (32.8%)


Processing T2 Test Batches:  33%|███▎      | 88/263 [01:08<02:17,  1.27it/s]

Processed 88/263 batches (33.5%)


Processing T2 Test Batches:  34%|███▍      | 90/263 [01:10<02:20,  1.23it/s]

Processed 90/263 batches (34.3%)


Processing T2 Test Batches:  35%|███▍      | 92/263 [01:12<02:17,  1.24it/s]

Processed 92/263 batches (35.0%)


Processing T2 Test Batches:  36%|███▌      | 94/263 [01:13<02:15,  1.25it/s]

Processed 94/263 batches (35.8%)


Processing T2 Test Batches:  37%|███▋      | 96/263 [01:15<02:11,  1.27it/s]

Processed 96/263 batches (36.6%)


Processing T2 Test Batches:  37%|███▋      | 98/263 [01:16<02:09,  1.27it/s]

Processed 98/263 batches (37.3%)


Processing T2 Test Batches:  38%|███▊      | 100/263 [01:18<02:09,  1.26it/s]

Processed 100/263 batches (38.1%)


Processing T2 Test Batches:  39%|███▉      | 102/263 [01:20<02:09,  1.24it/s]

Processed 102/263 batches (38.9%)


Processing T2 Test Batches:  40%|███▉      | 104/263 [01:21<02:06,  1.25it/s]

Processed 104/263 batches (39.6%)


Processing T2 Test Batches:  40%|████      | 106/263 [01:23<02:03,  1.27it/s]

Processed 106/263 batches (40.4%)


Processing T2 Test Batches:  41%|████      | 108/263 [01:24<02:02,  1.26it/s]

Processed 108/263 batches (41.1%)


Processing T2 Test Batches:  42%|████▏     | 110/263 [01:26<02:00,  1.27it/s]

Processed 110/263 batches (41.9%)


Processing T2 Test Batches:  43%|████▎     | 112/263 [01:27<01:57,  1.29it/s]

Processed 112/263 batches (42.7%)


Processing T2 Test Batches:  43%|████▎     | 114/263 [01:29<01:57,  1.27it/s]

Processed 114/263 batches (43.4%)


Processing T2 Test Batches:  44%|████▍     | 116/263 [01:31<01:55,  1.28it/s]

Processed 116/263 batches (44.2%)


Processing T2 Test Batches:  45%|████▍     | 118/263 [01:32<01:51,  1.30it/s]

Processed 118/263 batches (45.0%)


Processing T2 Test Batches:  46%|████▌     | 120/263 [01:34<01:49,  1.30it/s]

Processed 120/263 batches (45.7%)


Processing T2 Test Batches:  46%|████▋     | 122/263 [01:35<01:47,  1.31it/s]

Processed 122/263 batches (46.5%)


Processing T2 Test Batches:  47%|████▋     | 124/263 [01:37<01:46,  1.30it/s]

Processed 124/263 batches (47.2%)


Processing T2 Test Batches:  48%|████▊     | 126/263 [01:38<01:46,  1.29it/s]

Processed 126/263 batches (48.0%)


Processing T2 Test Batches:  49%|████▊     | 128/263 [01:40<01:44,  1.29it/s]

Processed 128/263 batches (48.8%)


Processing T2 Test Batches:  49%|████▉     | 130/263 [01:41<01:44,  1.27it/s]

Processed 130/263 batches (49.5%)


Processing T2 Test Batches:  50%|█████     | 132/263 [01:43<01:42,  1.27it/s]

Processed 132/263 batches (50.3%)


Processing T2 Test Batches:  51%|█████     | 134/263 [01:44<01:40,  1.28it/s]

Processed 134/263 batches (51.0%)


Processing T2 Test Batches:  52%|█████▏    | 136/263 [01:46<01:37,  1.30it/s]

Processed 136/263 batches (51.8%)


Processing T2 Test Batches:  52%|█████▏    | 138/263 [01:48<01:36,  1.30it/s]

Processed 138/263 batches (52.6%)


Processing T2 Test Batches:  53%|█████▎    | 140/263 [01:49<01:34,  1.31it/s]

Processed 140/263 batches (53.3%)


Processing T2 Test Batches:  54%|█████▍    | 142/263 [01:51<01:32,  1.31it/s]

Processed 142/263 batches (54.1%)


Processing T2 Test Batches:  55%|█████▍    | 144/263 [01:52<01:29,  1.32it/s]

Processed 144/263 batches (54.9%)


Processing T2 Test Batches:  56%|█████▌    | 146/263 [01:54<01:29,  1.31it/s]

Processed 146/263 batches (55.6%)


Processing T2 Test Batches:  56%|█████▋    | 148/263 [01:55<01:28,  1.30it/s]

Processed 148/263 batches (56.4%)


Processing T2 Test Batches:  57%|█████▋    | 150/263 [01:57<01:26,  1.30it/s]

Processed 150/263 batches (57.1%)


Processing T2 Test Batches:  58%|█████▊    | 152/263 [01:58<01:27,  1.27it/s]

Processed 152/263 batches (57.9%)


Processing T2 Test Batches:  59%|█████▊    | 154/263 [02:00<01:25,  1.27it/s]

Processed 154/263 batches (58.7%)


Processing T2 Test Batches:  59%|█████▉    | 156/263 [02:01<01:23,  1.28it/s]

Processed 156/263 batches (59.4%)


Processing T2 Test Batches:  60%|██████    | 158/263 [02:03<01:21,  1.28it/s]

Processed 158/263 batches (60.2%)


Processing T2 Test Batches:  61%|██████    | 160/263 [02:05<01:20,  1.28it/s]

Processed 160/263 batches (61.0%)


Processing T2 Test Batches:  62%|██████▏   | 162/263 [02:06<01:18,  1.28it/s]

Processed 162/263 batches (61.7%)


Processing T2 Test Batches:  62%|██████▏   | 164/263 [02:08<01:17,  1.28it/s]

Processed 164/263 batches (62.5%)


Processing T2 Test Batches:  63%|██████▎   | 166/263 [02:09<01:14,  1.30it/s]

Processed 166/263 batches (63.2%)


Processing T2 Test Batches:  64%|██████▍   | 168/263 [02:11<01:14,  1.28it/s]

Processed 168/263 batches (64.0%)


Processing T2 Test Batches:  65%|██████▍   | 170/263 [02:12<01:12,  1.28it/s]

Processed 170/263 batches (64.8%)


Processing T2 Test Batches:  65%|██████▌   | 172/263 [02:14<01:21,  1.12it/s]

Processed 172/263 batches (65.5%)


Processing T2 Test Batches:  66%|██████▌   | 174/263 [02:16<01:14,  1.20it/s]

Processed 174/263 batches (66.3%)


Processing T2 Test Batches:  67%|██████▋   | 176/263 [02:17<01:10,  1.24it/s]

Processed 176/263 batches (67.0%)


Processing T2 Test Batches:  68%|██████▊   | 178/263 [02:19<01:06,  1.28it/s]

Processed 178/263 batches (67.8%)


Processing T2 Test Batches:  68%|██████▊   | 180/263 [02:21<01:04,  1.28it/s]

Processed 180/263 batches (68.6%)


Processing T2 Test Batches:  69%|██████▉   | 182/263 [02:22<01:02,  1.29it/s]

Processed 182/263 batches (69.3%)


Processing T2 Test Batches:  70%|██████▉   | 184/263 [02:24<01:00,  1.30it/s]

Processed 184/263 batches (70.1%)


Processing T2 Test Batches:  71%|███████   | 186/263 [02:25<00:59,  1.30it/s]

Processed 186/263 batches (70.9%)


Processing T2 Test Batches:  71%|███████▏  | 188/263 [02:27<00:58,  1.29it/s]

Processed 188/263 batches (71.6%)


Processing T2 Test Batches:  72%|███████▏  | 190/263 [02:28<00:56,  1.29it/s]

Processed 190/263 batches (72.4%)


Processing T2 Test Batches:  73%|███████▎  | 192/263 [02:30<00:54,  1.30it/s]

Processed 192/263 batches (73.1%)


Processing T2 Test Batches:  74%|███████▍  | 194/263 [02:31<00:53,  1.30it/s]

Processed 194/263 batches (73.9%)


Processing T2 Test Batches:  75%|███████▍  | 196/263 [02:33<00:50,  1.32it/s]

Processed 196/263 batches (74.7%)


Processing T2 Test Batches:  75%|███████▌  | 198/263 [02:34<00:48,  1.34it/s]

Processed 198/263 batches (75.4%)


Processing T2 Test Batches:  76%|███████▌  | 200/263 [02:36<00:47,  1.33it/s]

Processed 200/263 batches (76.2%)


Processing T2 Test Batches:  77%|███████▋  | 202/263 [02:37<00:45,  1.34it/s]

Processed 202/263 batches (77.0%)


Processing T2 Test Batches:  78%|███████▊  | 204/263 [02:39<00:43,  1.36it/s]

Processed 204/263 batches (77.7%)


Processing T2 Test Batches:  78%|███████▊  | 206/263 [02:40<00:42,  1.36it/s]

Processed 206/263 batches (78.5%)


Processing T2 Test Batches:  79%|███████▉  | 208/263 [02:42<00:40,  1.36it/s]

Processed 208/263 batches (79.2%)


Processing T2 Test Batches:  80%|███████▉  | 210/263 [02:43<00:39,  1.34it/s]

Processed 210/263 batches (80.0%)


Processing T2 Test Batches:  81%|████████  | 212/263 [02:45<00:38,  1.34it/s]

Processed 212/263 batches (80.8%)


Processing T2 Test Batches:  81%|████████▏ | 214/263 [02:46<00:36,  1.33it/s]

Processed 214/263 batches (81.5%)


Processing T2 Test Batches:  82%|████████▏ | 216/263 [02:48<00:35,  1.34it/s]

Processed 216/263 batches (82.3%)


Processing T2 Test Batches:  83%|████████▎ | 218/263 [02:49<00:33,  1.33it/s]

Processed 218/263 batches (83.0%)


Processing T2 Test Batches:  84%|████████▎ | 220/263 [02:51<00:32,  1.31it/s]

Processed 220/263 batches (83.8%)


Processing T2 Test Batches:  84%|████████▍ | 222/263 [02:52<00:30,  1.33it/s]

Processed 222/263 batches (84.6%)


Processing T2 Test Batches:  85%|████████▌ | 224/263 [02:54<00:29,  1.33it/s]

Processed 224/263 batches (85.3%)


Processing T2 Test Batches:  86%|████████▌ | 226/263 [02:55<00:27,  1.33it/s]

Processed 226/263 batches (86.1%)


Processing T2 Test Batches:  87%|████████▋ | 228/263 [02:57<00:26,  1.33it/s]

Processed 228/263 batches (86.9%)


Processing T2 Test Batches:  87%|████████▋ | 230/263 [02:58<00:24,  1.34it/s]

Processed 230/263 batches (87.6%)


Processing T2 Test Batches:  88%|████████▊ | 232/263 [03:00<00:23,  1.34it/s]

Processed 232/263 batches (88.4%)


Processing T2 Test Batches:  89%|████████▉ | 234/263 [03:01<00:21,  1.33it/s]

Processed 234/263 batches (89.1%)


Processing T2 Test Batches:  90%|████████▉ | 236/263 [03:03<00:20,  1.33it/s]

Processed 236/263 batches (89.9%)


Processing T2 Test Batches:  90%|█████████ | 238/263 [03:04<00:18,  1.35it/s]

Processed 238/263 batches (90.7%)


Processing T2 Test Batches:  91%|█████████▏| 240/263 [03:06<00:17,  1.33it/s]

Processed 240/263 batches (91.4%)


Processing T2 Test Batches:  92%|█████████▏| 242/263 [03:07<00:16,  1.27it/s]

Processed 242/263 batches (92.2%)


Processing T2 Test Batches:  93%|█████████▎| 244/263 [03:09<00:14,  1.28it/s]

Processed 244/263 batches (93.0%)


Processing T2 Test Batches:  94%|█████████▎| 246/263 [03:11<00:13,  1.29it/s]

Processed 246/263 batches (93.7%)


Processing T2 Test Batches:  94%|█████████▍| 248/263 [03:12<00:11,  1.32it/s]

Processed 248/263 batches (94.5%)


Processing T2 Test Batches:  95%|█████████▌| 250/263 [03:14<00:09,  1.31it/s]

Processed 250/263 batches (95.2%)


Processing T2 Test Batches:  96%|█████████▌| 252/263 [03:15<00:08,  1.34it/s]

Processed 252/263 batches (96.0%)


Processing T2 Test Batches:  97%|█████████▋| 254/263 [03:17<00:06,  1.31it/s]

Processed 254/263 batches (96.8%)


Processing T2 Test Batches:  97%|█████████▋| 256/263 [03:18<00:05,  1.33it/s]

Processed 256/263 batches (97.5%)


Processing T2 Test Batches:  98%|█████████▊| 258/263 [03:20<00:03,  1.32it/s]

Processed 258/263 batches (98.3%)


Processing T2 Test Batches:  99%|█████████▉| 260/263 [03:21<00:02,  1.35it/s]

Processed 260/263 batches (99.0%)


Processing T2 Test Batches: 100%|█████████▉| 262/263 [03:23<00:00,  1.35it/s]

Processed 262/263 batches (99.8%)


Processing T2 Test Batches: 100%|██████████| 263/263 [03:23<00:00,  1.29it/s]


MFCC features shape: (16800, 1, 32, 32)
LFCC features shape: (16800, 1, 32, 32)
Chroma features shape: (16800, 1, 12, 32)
Labels shape: (16800,)
Total T2 test samples after feature extraction: 16800
Total T2 test samples: 16800, Number of T2 test chunks: 27, T2 Test Chunk size: 646
Padded MFCC features shape: (17442, 1, 32, 32)
Padded LFCC features shape: (17442, 1, 32, 32)
Padded Chroma features shape: (17442, 1, 12, 32)
Padded Labels shape: (17442,)
Saved T2 test chunk 0: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 1, 12, 32), y shape (646,)
Saved T2 test chunk 1: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 1, 12, 32), y shape (646,)
Saved T2 test chunk 2: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 1, 12, 32), y shape (646,)
Saved T2 test chunk 3: MFCC shape (646, 1, 32, 32), LFCC shape (646, 1, 32, 32), Chroma shape (646, 1, 12, 32), y shape (646,)
Saved T2 test chunk 4: MFCC shape (6

## Preprocessing for Training and Validation folders

In [10]:

import os
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import MFCC, LFCC, Spectrogram
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# T2 dataset path
T2_PATH = "/kaggle/input/mladdc-t2/T2/"

# Function to validate if a .wav file can be loaded
def validate_wav_file(file_path):
    try:
        waveform, sr = torchaudio.load(file_path)
        if waveform.size(1) == 0:
            print(f"Skipping empty file: {file_path}")
            return False
        return True
    except Exception as e:
        print(f"Skipping invalid file: {file_path}, Error: {e}")
        return False

# Function to collect T2 audio files and labels
def get_t2_audio_files_and_labels(data_split="train"):
    audio_files = []
    labels = []
    counts = {"real": 0, "fake": 0}

    real_path = os.path.join(T2_PATH, data_split, "real")
    fake_path = os.path.join(T2_PATH, data_split, "deepfake")

    if not os.path.exists(real_path):
        raise FileNotFoundError(f"Real directory not found at {real_path}")
    if not os.path.exists(fake_path):
        raise FileNotFoundError(f"Fake directory not found at {fake_path}")

    for file in os.listdir(real_path):
        if file.endswith(".wav"):
            file_path = os.path.join(real_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(0)  # Real
                counts["real"] += 1

    for file in os.listdir(fake_path):
        if file.endswith(".wav"):
            file_path = os.path.join(fake_path, file)
            if validate_wav_file(file_path):
                audio_files.append(file_path)
                labels.append(1)  # Fake
                counts["fake"] += 1

    print(f"T2 {data_split} Counts: Real: {counts['real']}, Fake: {counts['fake']}")
    return audio_files, labels

# Custom Chroma-STFT implementation
def compute_chroma_stft(waveforms, sample_rate=16000, n_fft=2048, hop_length=512, n_chroma=12):
    spectrogram_transform = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2.0).to(device)
    spec = spectrogram_transform(waveforms)
    spec = spec.squeeze(1)
    freqs = torch.linspace(0, sample_rate / 2, steps=spec.shape[1]).to(device)
    chroma_freqs = torch.tensor([31.25 * (2 ** (i / 12)) for i in range(12 * 4)], device=device)
    chroma_bins = torch.zeros((n_chroma, spec.shape[1]), device=device)
    for i in range(n_chroma):
        center = chroma_freqs[i::12]
        for cf in center:
            mask = (freqs >= cf / 1.06) & (freqs <= cf * 1.06)
            chroma_bins[i] += mask.float()
    chroma_bins /= chroma_bins.sum(dim=1, keepdim=True).clamp(min=1e-10)
    chroma = torch.einsum('cf,bft->bct', chroma_bins, spec)
    return chroma

# Feature Extraction for T2 train/val set
def extract_t2_features_batch(file_paths, labels, batch_size=64, max_length=16000, data_split="train"):
    mfcc_results = []
    lfcc_results = []
    chroma_results = []
    valid_labels = []

    mfcc_transform = MFCC(
        sample_rate=16000,
        n_mfcc=32,
        melkwargs={"n_fft": 2048, "hop_length": 512, "n_mels": 128}
    ).to(device)
    lfcc_transform = LFCC(
        sample_rate=16000,
        n_lfcc=32,
        f_min=0,
        f_max=8000,
        n_filter=128,
        speckwargs={"n_fft": 2048, "hop_length": 512}
    ).to(device)

    total_batches = (len(file_paths) + batch_size - 1) // batch_size
    log_interval = max(1, total_batches // 100)

    for i in tqdm(range(0, len(file_paths), batch_size), desc=f"Processing T2 {data_split} Batches", total=total_batches):
        if i % (log_interval * batch_size) == 0:
            print(f"Processed {i // batch_size}/{total_batches} batches ({(i / len(file_paths)) * 100:.1f}%)")

        batch_files = file_paths[i:i + batch_size]
        batch_labels = labels[i:i + batch_size]
        waveforms = []
        valid_indices = []

        for idx, (file_path, label) in enumerate(zip(batch_files, batch_labels)):
            try:
                waveform, sr = torchaudio.load(file_path)
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                if sr != 16000:
                    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
                if waveform.size(1) > max_length:
                    waveform = waveform[:, :max_length]
                elif waveform.size(1) < max_length:
                    pad_size = max_length - waveform.size(1)
                    waveform = torch.nn.functional.pad(waveform, (0, pad_size))
                waveforms.append(waveform)
                valid_indices.append(idx)
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        if not waveforms:
            continue

        waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True).to(device)
        mfccs = mfcc_transform(waveforms)
        lfccs = lfcc_transform(waveforms)
        chromas = compute_chroma_stft(waveforms, sample_rate=16000, n_fft=2048, hop_length=512, n_chroma=12)

        # Debugging - recurring shape mismatch issues
        target_frames = 32
        for idx in valid_indices:
            mfcc = mfccs[idx]
            if len(mfcc.shape) == 3:
                mfcc = mfcc.squeeze(0)
            if mfcc.shape[1] > target_frames:
                mfcc = mfcc[:, :target_frames]
            elif mfcc.shape[1] < target_frames:
                mfcc = torch.nn.functional.pad(mfcc, (0, target_frames - mfcc.shape[1]))
            if mfcc.shape != (32, 32):
                print(f"Warning: MFCC shape after resize at index {idx}: {mfcc.shape}")
                continue
            mfcc = mfcc.unsqueeze(0)
            if mfcc.shape != (1, 32, 32):
                print(f"Warning: Final MFCC shape mismatch at index {idx}: {mfcc.shape}")
                continue
            mfcc_results.append(mfcc.cpu().numpy())

            lfcc = lfccs[idx]
            if len(lfcc.shape) == 3:
                lfcc = lfcc.squeeze(0)
            if lfcc.shape[1] > target_frames:
                lfcc = lfcc[:, :target_frames]
            elif lfcc.shape[1] < target_frames:
                lfcc = torch.nn.functional.pad(lfcc, (0, target_frames - lfcc.shape[1]))
            if lfcc.shape != (32, 32):
                print(f"Warning: LFCC shape after resize at index {idx}: {lfcc.shape}")
                continue
            lfcc = lfcc.unsqueeze(0)
            if lfcc.shape != (1, 32, 32):
                print(f"Warning: Final LFCC shape mismatch at index {idx}: {lfcc.shape}")
                continue
            lfcc_results.append(lfcc.cpu().numpy())

            chroma = chromas[idx]
            if len(chroma.shape) == 3:
                chroma = chroma.squeeze(0)
            if chroma.shape[1] > target_frames:
                chroma = chroma[:, :target_frames]
            elif chroma.shape[1] < target_frames:
                chroma = torch.nn.functional.pad(chroma, (0, target_frames - chroma.shape[1]))
            if chroma.shape != (12, 32):
                print(f"Warning: Chroma shape after resize at index {idx}: {chroma.shape}")
                continue
            chroma = chroma.unsqueeze(0)
            if chroma.shape != (1, 12, 32):
                print(f"Warning: Final Chroma shape mismatch at index {idx}: {chroma.shape}")
                continue
            chroma_results.append(chroma.cpu().numpy())

            valid_labels.append(batch_labels[idx])

    return mfcc_results, lfcc_results, chroma_results, valid_labels

# MFAAN Model
class MFAAN(nn.Module):
    def __init__(self):
        super(MFAAN, self).__init__()
        self.path_mfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.path_lfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.path_chroma = nn.Sequential(
            nn.Conv1d(in_channels=12, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.fc1 = nn.Linear(32 * 8 * 3, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, mfcc, lfcc, chroma):
        mfcc = mfcc.squeeze(1)
        lfcc = lfcc.squeeze(1)
        chroma = chroma.squeeze(1)
        mfcc_out = self.path_mfcc(mfcc)
        lfcc_out = self.path_lfcc(lfcc)
        chroma_out = self.path_chroma(chroma)
        mfcc_flat = mfcc_out.view(mfcc_out.size(0), -1)
        lfcc_flat = lfcc_out.view(lfcc_out.size(0), -1)
        chroma_flat = chroma_out.view(chroma_out.size(0), -1)
        fused = torch.cat((mfcc_flat, lfcc_flat, chroma_flat), dim=1)
        x = F.relu(self.fc1(fused))
        x = self.dropout(x)
        output = self.fc2(x)
        return output

# Calculate and print model size
model = MFAAN()
total_params = sum(p.numel() for p in model.parameters())
param_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
param_size_mb = param_size_bytes / (1024 ** 2)
print(f"\nMFAAN Model Size:")
print(f"Total Parameters: {total_params:,}")
print(f"Parameter Size: {param_size_mb:.2f} MB")

# Preprocess T2 train and val sets
for data_split in ["train", "val"]:
    print(f"\nProcessing T2 {data_split} split...")
    files, labels = get_t2_audio_files_and_labels(data_split)
    print(f"Total T2 {data_split} files: {len(files)}")
    print(f"Label Distribution: Real (0): {sum(1 for label in labels if label == 0)} "
          f"({sum(1 for label in labels if label == 0)/len(files)*100:.2f}%), "
          f"Fake (1): {sum(1 for label in labels if label == 1)} "
          f"({sum(1 for label in labels if label == 1)/len(files)*100:.2f}%)")

    mfcc_features, lfcc_features, chroma_features, labels = extract_t2_features_batch(
        files, labels, batch_size=64, max_length=16000, data_split=data_split
    )
    mfcc_features = np.array(mfcc_features)
    lfcc_features = np.array(lfcc_features)
    chroma_features = np.array(chroma_features)
    labels = np.array(labels)

    print(f"T2 {data_split} MFCC shape: {mfcc_features.shape}")
    print(f"T2 {data_split} LFCC shape: {lfcc_features.shape}")
    print(f"T2 {data_split} Chroma shape: {chroma_features.shape}")
    print(f"T2 {data_split} Labels shape: {labels.shape}")

    # Validate alignment
    if not (len(mfcc_features) == len(lfcc_features) == len(chroma_features) == len(labels)):
        raise ValueError(f"Mismatch in T2 {data_split} features and labels: "
                         f"MFCC: {len(mfcc_features)}, LFCC: {len(lfcc_features)}, "
                         f"Chroma: {len(chroma_features)}, Labels: {len(labels)}")

    # Save preprocessed data
    output_dir = f"/kaggle/working/preprocessed_t2_{data_split}"
    os.makedirs(output_dir, exist_ok=True)
    for f in os.listdir(output_dir):
        os.remove(os.path.join(output_dir, f))  # Clear existing files
    np.save(os.path.join(output_dir, f"t2_{data_split}_mfcc.npy"), mfcc_features)
    np.save(os.path.join(output_dir, f"t2_{data_split}_lfcc.npy"), lfcc_features)
    np.save(os.path.join(output_dir, f"t2_{data_split}_chroma.npy"), chroma_features)
    np.save(os.path.join(output_dir, f"t2_{data_split}_labels.npy"), labels)
    print(f"Saved preprocessed T2 {data_split} data to {output_dir}")

Using device: cuda, GPU: Tesla T4

MFAAN Model Size:
Total Parameters: 206,066
Parameter Size: 0.79 MB

Processing T2 train split...
T2 train Counts: Real: 44800, Fake: 89600
Total T2 train files: 134400
Label Distribution: Real (0): 44800 (33.33%), Fake (1): 89600 (66.67%)


Processing T2 train Batches:   0%|          | 0/2100 [00:00<?, ?it/s]

Processed 0/2100 batches (0.0%)


Processing T2 train Batches:   1%|          | 21/2100 [00:24<40:15,  1.16s/it]

Processed 21/2100 batches (1.0%)


Processing T2 train Batches:   2%|▏         | 42/2100 [00:48<39:23,  1.15s/it]

Processed 42/2100 batches (2.0%)


Processing T2 train Batches:   3%|▎         | 63/2100 [01:13<39:08,  1.15s/it]

Processed 63/2100 batches (3.0%)


Processing T2 train Batches:   4%|▍         | 84/2100 [01:36<37:11,  1.11s/it]

Processed 84/2100 batches (4.0%)


Processing T2 train Batches:   5%|▌         | 105/2100 [02:01<36:53,  1.11s/it]

Processed 105/2100 batches (5.0%)


Processing T2 train Batches:   6%|▌         | 126/2100 [02:25<36:35,  1.11s/it]

Processed 126/2100 batches (6.0%)


Processing T2 train Batches:   7%|▋         | 147/2100 [02:49<37:12,  1.14s/it]

Processed 147/2100 batches (7.0%)


Processing T2 train Batches:   8%|▊         | 168/2100 [03:13<37:00,  1.15s/it]

Processed 168/2100 batches (8.0%)


Processing T2 train Batches:   9%|▉         | 189/2100 [03:36<34:37,  1.09s/it]

Processed 189/2100 batches (9.0%)


Processing T2 train Batches:  10%|█         | 210/2100 [04:00<36:05,  1.15s/it]

Processed 210/2100 batches (10.0%)


Processing T2 train Batches:  11%|█         | 231/2100 [04:24<36:14,  1.16s/it]

Processed 231/2100 batches (11.0%)


Processing T2 train Batches:  12%|█▏        | 252/2100 [04:48<35:54,  1.17s/it]

Processed 252/2100 batches (12.0%)


Processing T2 train Batches:  13%|█▎        | 273/2100 [05:13<34:40,  1.14s/it]

Processed 273/2100 batches (13.0%)


Processing T2 train Batches:  14%|█▍        | 294/2100 [05:37<34:47,  1.16s/it]

Processed 294/2100 batches (14.0%)


Processing T2 train Batches:  15%|█▌        | 315/2100 [06:02<34:10,  1.15s/it]

Processed 315/2100 batches (15.0%)


Processing T2 train Batches:  16%|█▌        | 336/2100 [06:27<33:14,  1.13s/it]

Processed 336/2100 batches (16.0%)


Processing T2 train Batches:  17%|█▋        | 357/2100 [06:51<35:24,  1.22s/it]

Processed 357/2100 batches (17.0%)


Processing T2 train Batches:  18%|█▊        | 378/2100 [07:15<33:16,  1.16s/it]

Processed 378/2100 batches (18.0%)


Processing T2 train Batches:  19%|█▉        | 399/2100 [07:39<32:08,  1.13s/it]

Processed 399/2100 batches (19.0%)


Processing T2 train Batches:  20%|██        | 420/2100 [08:04<32:35,  1.16s/it]

Processed 420/2100 batches (20.0%)


Processing T2 train Batches:  21%|██        | 441/2100 [08:28<30:59,  1.12s/it]

Processed 441/2100 batches (21.0%)


Processing T2 train Batches:  22%|██▏       | 462/2100 [08:52<32:50,  1.20s/it]

Processed 462/2100 batches (22.0%)


Processing T2 train Batches:  23%|██▎       | 483/2100 [09:16<29:23,  1.09s/it]

Processed 483/2100 batches (23.0%)


Processing T2 train Batches:  24%|██▍       | 504/2100 [09:40<30:50,  1.16s/it]

Processed 504/2100 batches (24.0%)


Processing T2 train Batches:  25%|██▌       | 525/2100 [10:05<31:57,  1.22s/it]

Processed 525/2100 batches (25.0%)


Processing T2 train Batches:  26%|██▌       | 546/2100 [10:30<29:08,  1.13s/it]

Processed 546/2100 batches (26.0%)


Processing T2 train Batches:  27%|██▋       | 567/2100 [10:54<29:39,  1.16s/it]

Processed 567/2100 batches (27.0%)


Processing T2 train Batches:  28%|██▊       | 588/2100 [11:18<29:17,  1.16s/it]

Processed 588/2100 batches (28.0%)


Processing T2 train Batches:  29%|██▉       | 609/2100 [11:42<28:47,  1.16s/it]

Processed 609/2100 batches (29.0%)


Processing T2 train Batches:  30%|███       | 630/2100 [12:07<28:01,  1.14s/it]

Processed 630/2100 batches (30.0%)


Processing T2 train Batches:  31%|███       | 651/2100 [12:31<29:20,  1.21s/it]

Processed 651/2100 batches (31.0%)


Processing T2 train Batches:  32%|███▏      | 672/2100 [12:56<27:47,  1.17s/it]

Processed 672/2100 batches (32.0%)


Processing T2 train Batches:  33%|███▎      | 693/2100 [13:20<26:14,  1.12s/it]

Processed 693/2100 batches (33.0%)


Processing T2 train Batches:  34%|███▍      | 714/2100 [13:44<27:11,  1.18s/it]

Processed 714/2100 batches (34.0%)


Processing T2 train Batches:  35%|███▌      | 735/2100 [14:09<27:12,  1.20s/it]

Processed 735/2100 batches (35.0%)


Processing T2 train Batches:  36%|███▌      | 756/2100 [14:33<25:25,  1.13s/it]

Processed 756/2100 batches (36.0%)


Processing T2 train Batches:  37%|███▋      | 777/2100 [14:58<25:48,  1.17s/it]

Processed 777/2100 batches (37.0%)


Processing T2 train Batches:  38%|███▊      | 798/2100 [15:22<25:23,  1.17s/it]

Processed 798/2100 batches (38.0%)


Processing T2 train Batches:  39%|███▉      | 819/2100 [15:47<25:26,  1.19s/it]

Processed 819/2100 batches (39.0%)


Processing T2 train Batches:  40%|████      | 840/2100 [16:12<25:27,  1.21s/it]

Processed 840/2100 batches (40.0%)


Processing T2 train Batches:  41%|████      | 861/2100 [16:37<24:33,  1.19s/it]

Processed 861/2100 batches (41.0%)


Processing T2 train Batches:  42%|████▏     | 882/2100 [17:02<23:33,  1.16s/it]

Processed 882/2100 batches (42.0%)


Processing T2 train Batches:  43%|████▎     | 903/2100 [17:26<23:01,  1.15s/it]

Processed 903/2100 batches (43.0%)


Processing T2 train Batches:  44%|████▍     | 924/2100 [17:50<23:33,  1.20s/it]

Processed 924/2100 batches (44.0%)


Processing T2 train Batches:  45%|████▌     | 945/2100 [18:14<21:22,  1.11s/it]

Processed 945/2100 batches (45.0%)


Processing T2 train Batches:  46%|████▌     | 966/2100 [18:38<21:01,  1.11s/it]

Processed 966/2100 batches (46.0%)


Processing T2 train Batches:  47%|████▋     | 987/2100 [19:02<20:36,  1.11s/it]

Processed 987/2100 batches (47.0%)


Processing T2 train Batches:  48%|████▊     | 1008/2100 [19:26<20:05,  1.10s/it]

Processed 1008/2100 batches (48.0%)


Processing T2 train Batches:  49%|████▉     | 1029/2100 [19:49<20:15,  1.14s/it]

Processed 1029/2100 batches (49.0%)


Processing T2 train Batches:  50%|█████     | 1050/2100 [20:13<19:58,  1.14s/it]

Processed 1050/2100 batches (50.0%)


Processing T2 train Batches:  51%|█████     | 1071/2100 [20:37<19:16,  1.12s/it]

Processed 1071/2100 batches (51.0%)


Processing T2 train Batches:  52%|█████▏    | 1092/2100 [21:01<19:34,  1.17s/it]

Processed 1092/2100 batches (52.0%)


Processing T2 train Batches:  53%|█████▎    | 1113/2100 [21:25<18:26,  1.12s/it]

Processed 1113/2100 batches (53.0%)


Processing T2 train Batches:  54%|█████▍    | 1134/2100 [21:49<18:45,  1.17s/it]

Processed 1134/2100 batches (54.0%)


Processing T2 train Batches:  55%|█████▌    | 1155/2100 [22:12<17:45,  1.13s/it]

Processed 1155/2100 batches (55.0%)


Processing T2 train Batches:  56%|█████▌    | 1176/2100 [22:36<17:11,  1.12s/it]

Processed 1176/2100 batches (56.0%)


Processing T2 train Batches:  57%|█████▋    | 1197/2100 [23:00<17:23,  1.16s/it]

Processed 1197/2100 batches (57.0%)


Processing T2 train Batches:  58%|█████▊    | 1218/2100 [23:24<16:35,  1.13s/it]

Processed 1218/2100 batches (58.0%)


Processing T2 train Batches:  59%|█████▉    | 1239/2100 [23:47<16:16,  1.13s/it]

Processed 1239/2100 batches (59.0%)


Processing T2 train Batches:  60%|██████    | 1260/2100 [24:11<16:42,  1.19s/it]

Processed 1260/2100 batches (60.0%)


Processing T2 train Batches:  61%|██████    | 1281/2100 [24:35<15:09,  1.11s/it]

Processed 1281/2100 batches (61.0%)


Processing T2 train Batches:  62%|██████▏   | 1302/2100 [24:59<15:05,  1.13s/it]

Processed 1302/2100 batches (62.0%)


Processing T2 train Batches:  63%|██████▎   | 1323/2100 [25:23<14:52,  1.15s/it]

Processed 1323/2100 batches (63.0%)


Processing T2 train Batches:  64%|██████▍   | 1344/2100 [25:47<14:15,  1.13s/it]

Processed 1344/2100 batches (64.0%)


Processing T2 train Batches:  65%|██████▌   | 1365/2100 [26:10<13:49,  1.13s/it]

Processed 1365/2100 batches (65.0%)


Processing T2 train Batches:  66%|██████▌   | 1386/2100 [26:33<12:44,  1.07s/it]

Processed 1386/2100 batches (66.0%)


Processing T2 train Batches:  67%|██████▋   | 1407/2100 [26:57<12:47,  1.11s/it]

Processed 1407/2100 batches (67.0%)


Processing T2 train Batches:  68%|██████▊   | 1428/2100 [27:20<12:25,  1.11s/it]

Processed 1428/2100 batches (68.0%)


Processing T2 train Batches:  69%|██████▉   | 1449/2100 [27:44<12:07,  1.12s/it]

Processed 1449/2100 batches (69.0%)


Processing T2 train Batches:  70%|███████   | 1470/2100 [28:08<11:53,  1.13s/it]

Processed 1470/2100 batches (70.0%)


Processing T2 train Batches:  71%|███████   | 1491/2100 [28:31<10:48,  1.07s/it]

Processed 1491/2100 batches (71.0%)


Processing T2 train Batches:  72%|███████▏  | 1512/2100 [28:54<11:17,  1.15s/it]

Processed 1512/2100 batches (72.0%)


Processing T2 train Batches:  73%|███████▎  | 1533/2100 [29:18<10:16,  1.09s/it]

Processed 1533/2100 batches (73.0%)


Processing T2 train Batches:  74%|███████▍  | 1554/2100 [29:41<09:43,  1.07s/it]

Processed 1554/2100 batches (74.0%)


Processing T2 train Batches:  75%|███████▌  | 1575/2100 [30:05<09:55,  1.13s/it]

Processed 1575/2100 batches (75.0%)


Processing T2 train Batches:  76%|███████▌  | 1596/2100 [30:28<09:12,  1.10s/it]

Processed 1596/2100 batches (76.0%)


Processing T2 train Batches:  77%|███████▋  | 1617/2100 [30:52<09:34,  1.19s/it]

Processed 1617/2100 batches (77.0%)


Processing T2 train Batches:  78%|███████▊  | 1638/2100 [31:16<08:42,  1.13s/it]

Processed 1638/2100 batches (78.0%)


Processing T2 train Batches:  79%|███████▉  | 1659/2100 [31:39<08:47,  1.20s/it]

Processed 1659/2100 batches (79.0%)


Processing T2 train Batches:  80%|████████  | 1680/2100 [32:04<08:24,  1.20s/it]

Processed 1680/2100 batches (80.0%)


Processing T2 train Batches:  81%|████████  | 1701/2100 [32:27<07:06,  1.07s/it]

Processed 1701/2100 batches (81.0%)


Processing T2 train Batches:  82%|████████▏ | 1722/2100 [32:50<07:13,  1.15s/it]

Processed 1722/2100 batches (82.0%)


Processing T2 train Batches:  83%|████████▎ | 1743/2100 [33:14<06:25,  1.08s/it]

Processed 1743/2100 batches (83.0%)


Processing T2 train Batches:  84%|████████▍ | 1764/2100 [33:36<05:57,  1.07s/it]

Processed 1764/2100 batches (84.0%)


Processing T2 train Batches:  85%|████████▌ | 1785/2100 [33:59<05:51,  1.12s/it]

Processed 1785/2100 batches (85.0%)


Processing T2 train Batches:  86%|████████▌ | 1806/2100 [34:23<05:14,  1.07s/it]

Processed 1806/2100 batches (86.0%)


Processing T2 train Batches:  87%|████████▋ | 1827/2100 [34:46<05:20,  1.17s/it]

Processed 1827/2100 batches (87.0%)


Processing T2 train Batches:  88%|████████▊ | 1848/2100 [35:10<04:45,  1.13s/it]

Processed 1848/2100 batches (88.0%)


Processing T2 train Batches:  89%|████████▉ | 1869/2100 [35:32<04:06,  1.07s/it]

Processed 1869/2100 batches (89.0%)


Processing T2 train Batches:  90%|█████████ | 1890/2100 [35:56<04:09,  1.19s/it]

Processed 1890/2100 batches (90.0%)


Processing T2 train Batches:  91%|█████████ | 1911/2100 [36:19<03:31,  1.12s/it]

Processed 1911/2100 batches (91.0%)


Processing T2 train Batches:  92%|█████████▏| 1932/2100 [36:43<03:06,  1.11s/it]

Processed 1932/2100 batches (92.0%)


Processing T2 train Batches:  93%|█████████▎| 1953/2100 [37:06<02:44,  1.12s/it]

Processed 1953/2100 batches (93.0%)


Processing T2 train Batches:  94%|█████████▍| 1974/2100 [37:30<02:18,  1.10s/it]

Processed 1974/2100 batches (94.0%)


Processing T2 train Batches:  95%|█████████▌| 1995/2100 [37:53<02:00,  1.15s/it]

Processed 1995/2100 batches (95.0%)


Processing T2 train Batches:  96%|█████████▌| 2016/2100 [38:17<01:31,  1.09s/it]

Processed 2016/2100 batches (96.0%)


Processing T2 train Batches:  97%|█████████▋| 2037/2100 [38:39<01:09,  1.10s/it]

Processed 2037/2100 batches (97.0%)


Processing T2 train Batches:  98%|█████████▊| 2058/2100 [39:03<00:47,  1.13s/it]

Processed 2058/2100 batches (98.0%)


Processing T2 train Batches:  99%|█████████▉| 2079/2100 [39:26<00:22,  1.08s/it]

Processed 2079/2100 batches (99.0%)


Processing T2 train Batches: 100%|██████████| 2100/2100 [39:49<00:00,  1.14s/it]


T2 train MFCC shape: (134400, 1, 32, 32)
T2 train LFCC shape: (134400, 1, 32, 32)
T2 train Chroma shape: (134400, 1, 12, 32)
T2 train Labels shape: (134400,)
Saved preprocessed T2 train data to /kaggle/working/preprocessed_t2_train

Processing T2 val split...
T2 val Counts: Real: 5600, Fake: 11200
Total T2 val files: 16800
Label Distribution: Real (0): 5600 (33.33%), Fake (1): 11200 (66.67%)


Processing T2 val Batches:   0%|          | 0/263 [00:00<?, ?it/s]

Processed 0/263 batches (0.0%)


Processing T2 val Batches:   1%|          | 2/263 [00:02<05:00,  1.15s/it]

Processed 2/263 batches (0.8%)


Processing T2 val Batches:   2%|▏         | 4/263 [00:04<05:00,  1.16s/it]

Processed 4/263 batches (1.5%)


Processing T2 val Batches:   2%|▏         | 6/263 [00:07<05:06,  1.19s/it]

Processed 6/263 batches (2.3%)


Processing T2 val Batches:   3%|▎         | 8/263 [00:09<04:53,  1.15s/it]

Processed 8/263 batches (3.0%)


Processing T2 val Batches:   4%|▍         | 10/263 [00:11<05:05,  1.21s/it]

Processed 10/263 batches (3.8%)


Processing T2 val Batches:   5%|▍         | 12/263 [00:14<05:02,  1.20s/it]

Processed 12/263 batches (4.6%)


Processing T2 val Batches:   5%|▌         | 14/263 [00:16<04:48,  1.16s/it]

Processed 14/263 batches (5.3%)


Processing T2 val Batches:   6%|▌         | 16/263 [00:18<04:44,  1.15s/it]

Processed 16/263 batches (6.1%)


Processing T2 val Batches:   7%|▋         | 18/263 [00:20<04:37,  1.13s/it]

Processed 18/263 batches (6.9%)


Processing T2 val Batches:   8%|▊         | 20/263 [00:23<04:27,  1.10s/it]

Processed 20/263 batches (7.6%)


Processing T2 val Batches:   8%|▊         | 22/263 [00:25<04:27,  1.11s/it]

Processed 22/263 batches (8.4%)


Processing T2 val Batches:   9%|▉         | 24/263 [00:27<04:36,  1.16s/it]

Processed 24/263 batches (9.1%)


Processing T2 val Batches:  10%|▉         | 26/263 [00:29<04:26,  1.13s/it]

Processed 26/263 batches (9.9%)


Processing T2 val Batches:  11%|█         | 28/263 [00:32<04:18,  1.10s/it]

Processed 28/263 batches (10.7%)


Processing T2 val Batches:  11%|█▏        | 30/263 [00:34<04:23,  1.13s/it]

Processed 30/263 batches (11.4%)


Processing T2 val Batches:  12%|█▏        | 32/263 [00:36<04:18,  1.12s/it]

Processed 32/263 batches (12.2%)


Processing T2 val Batches:  13%|█▎        | 34/263 [00:39<04:22,  1.15s/it]

Processed 34/263 batches (13.0%)


Processing T2 val Batches:  14%|█▎        | 36/263 [00:41<04:15,  1.12s/it]

Processed 36/263 batches (13.7%)


Processing T2 val Batches:  14%|█▍        | 38/263 [00:43<04:05,  1.09s/it]

Processed 38/263 batches (14.5%)


Processing T2 val Batches:  15%|█▌        | 40/263 [00:45<04:12,  1.13s/it]

Processed 40/263 batches (15.2%)


Processing T2 val Batches:  16%|█▌        | 42/263 [00:48<04:12,  1.14s/it]

Processed 42/263 batches (16.0%)


Processing T2 val Batches:  17%|█▋        | 44/263 [00:50<04:10,  1.14s/it]

Processed 44/263 batches (16.8%)


Processing T2 val Batches:  17%|█▋        | 46/263 [00:52<04:04,  1.12s/it]

Processed 46/263 batches (17.5%)


Processing T2 val Batches:  18%|█▊        | 48/263 [00:54<04:01,  1.12s/it]

Processed 48/263 batches (18.3%)


Processing T2 val Batches:  19%|█▉        | 50/263 [00:57<03:58,  1.12s/it]

Processed 50/263 batches (19.0%)


Processing T2 val Batches:  20%|█▉        | 52/263 [00:59<03:58,  1.13s/it]

Processed 52/263 batches (19.8%)


Processing T2 val Batches:  21%|██        | 54/263 [01:01<03:59,  1.15s/it]

Processed 54/263 batches (20.6%)


Processing T2 val Batches:  21%|██▏       | 56/263 [01:03<03:56,  1.14s/it]

Processed 56/263 batches (21.3%)


Processing T2 val Batches:  22%|██▏       | 58/263 [01:06<03:52,  1.13s/it]

Processed 58/263 batches (22.1%)


Processing T2 val Batches:  23%|██▎       | 60/263 [01:08<03:53,  1.15s/it]

Processed 60/263 batches (22.9%)


Processing T2 val Batches:  24%|██▎       | 62/263 [01:10<03:56,  1.17s/it]

Processed 62/263 batches (23.6%)


Processing T2 val Batches:  24%|██▍       | 64/263 [01:13<03:56,  1.19s/it]

Processed 64/263 batches (24.4%)


Processing T2 val Batches:  25%|██▌       | 66/263 [01:15<03:52,  1.18s/it]

Processed 66/263 batches (25.1%)


Processing T2 val Batches:  26%|██▌       | 68/263 [01:17<03:47,  1.17s/it]

Processed 68/263 batches (25.9%)


Processing T2 val Batches:  27%|██▋       | 70/263 [01:20<03:35,  1.12s/it]

Processed 70/263 batches (26.7%)


Processing T2 val Batches:  27%|██▋       | 72/263 [01:22<03:32,  1.11s/it]

Processed 72/263 batches (27.4%)


Processing T2 val Batches:  28%|██▊       | 74/263 [01:24<03:29,  1.11s/it]

Processed 74/263 batches (28.2%)


Processing T2 val Batches:  29%|██▉       | 76/263 [01:26<03:28,  1.12s/it]

Processed 76/263 batches (29.0%)


Processing T2 val Batches:  30%|██▉       | 78/263 [01:28<03:23,  1.10s/it]

Processed 78/263 batches (29.7%)


Processing T2 val Batches:  30%|███       | 80/263 [01:31<03:16,  1.08s/it]

Processed 80/263 batches (30.5%)


Processing T2 val Batches:  31%|███       | 82/263 [01:33<03:13,  1.07s/it]

Processed 82/263 batches (31.2%)


Processing T2 val Batches:  32%|███▏      | 84/263 [01:35<03:14,  1.09s/it]

Processed 84/263 batches (32.0%)


Processing T2 val Batches:  33%|███▎      | 86/263 [01:37<03:09,  1.07s/it]

Processed 86/263 batches (32.8%)


Processing T2 val Batches:  33%|███▎      | 88/263 [01:39<03:07,  1.07s/it]

Processed 88/263 batches (33.5%)


Processing T2 val Batches:  34%|███▍      | 90/263 [01:41<03:14,  1.13s/it]

Processed 90/263 batches (34.3%)


Processing T2 val Batches:  35%|███▍      | 92/263 [01:44<03:04,  1.08s/it]

Processed 92/263 batches (35.0%)


Processing T2 val Batches:  36%|███▌      | 94/263 [01:46<02:59,  1.06s/it]

Processed 94/263 batches (35.8%)


Processing T2 val Batches:  37%|███▋      | 96/263 [01:48<02:56,  1.06s/it]

Processed 96/263 batches (36.6%)


Processing T2 val Batches:  37%|███▋      | 98/263 [01:50<02:54,  1.06s/it]

Processed 98/263 batches (37.3%)


Processing T2 val Batches:  38%|███▊      | 100/263 [01:52<02:52,  1.06s/it]

Processed 100/263 batches (38.1%)


Processing T2 val Batches:  39%|███▉      | 102/263 [01:54<02:49,  1.05s/it]

Processed 102/263 batches (38.9%)


Processing T2 val Batches:  40%|███▉      | 104/263 [01:56<02:47,  1.05s/it]

Processed 104/263 batches (39.6%)


Processing T2 val Batches:  40%|████      | 106/263 [01:58<02:49,  1.08s/it]

Processed 106/263 batches (40.4%)


Processing T2 val Batches:  41%|████      | 108/263 [02:01<02:45,  1.07s/it]

Processed 108/263 batches (41.1%)


Processing T2 val Batches:  42%|████▏     | 110/263 [02:03<02:46,  1.09s/it]

Processed 110/263 batches (41.9%)


Processing T2 val Batches:  43%|████▎     | 112/263 [02:05<02:47,  1.11s/it]

Processed 112/263 batches (42.7%)


Processing T2 val Batches:  43%|████▎     | 114/263 [02:07<02:44,  1.10s/it]

Processed 114/263 batches (43.4%)


Processing T2 val Batches:  44%|████▍     | 116/263 [02:09<02:41,  1.10s/it]

Processed 116/263 batches (44.2%)


Processing T2 val Batches:  45%|████▍     | 118/263 [02:12<02:44,  1.13s/it]

Processed 118/263 batches (45.0%)


Processing T2 val Batches:  46%|████▌     | 120/263 [02:14<02:45,  1.16s/it]

Processed 120/263 batches (45.7%)


Processing T2 val Batches:  46%|████▋     | 122/263 [02:16<02:39,  1.13s/it]

Processed 122/263 batches (46.5%)


Processing T2 val Batches:  47%|████▋     | 124/263 [02:19<02:37,  1.13s/it]

Processed 124/263 batches (47.2%)


Processing T2 val Batches:  48%|████▊     | 126/263 [02:21<02:34,  1.12s/it]

Processed 126/263 batches (48.0%)


Processing T2 val Batches:  49%|████▊     | 128/263 [02:23<02:34,  1.15s/it]

Processed 128/263 batches (48.8%)


Processing T2 val Batches:  49%|████▉     | 130/263 [02:25<02:28,  1.12s/it]

Processed 130/263 batches (49.5%)


Processing T2 val Batches:  50%|█████     | 132/263 [02:28<02:26,  1.12s/it]

Processed 132/263 batches (50.3%)


Processing T2 val Batches:  51%|█████     | 134/263 [02:30<02:21,  1.10s/it]

Processed 134/263 batches (51.0%)


Processing T2 val Batches:  52%|█████▏    | 136/263 [02:32<02:18,  1.09s/it]

Processed 136/263 batches (51.8%)


Processing T2 val Batches:  52%|█████▏    | 138/263 [02:34<02:15,  1.09s/it]

Processed 138/263 batches (52.6%)


Processing T2 val Batches:  53%|█████▎    | 140/263 [02:36<02:14,  1.10s/it]

Processed 140/263 batches (53.3%)


Processing T2 val Batches:  54%|█████▍    | 142/263 [02:38<02:10,  1.08s/it]

Processed 142/263 batches (54.1%)


Processing T2 val Batches:  55%|█████▍    | 144/263 [02:41<02:10,  1.10s/it]

Processed 144/263 batches (54.9%)


Processing T2 val Batches:  56%|█████▌    | 146/263 [02:43<02:05,  1.07s/it]

Processed 146/263 batches (55.6%)


Processing T2 val Batches:  56%|█████▋    | 148/263 [02:45<02:02,  1.06s/it]

Processed 148/263 batches (56.4%)


Processing T2 val Batches:  57%|█████▋    | 150/263 [02:47<01:58,  1.05s/it]

Processed 150/263 batches (57.1%)


Processing T2 val Batches:  58%|█████▊    | 152/263 [02:49<01:57,  1.06s/it]

Processed 152/263 batches (57.9%)


Processing T2 val Batches:  59%|█████▊    | 154/263 [02:51<01:55,  1.06s/it]

Processed 154/263 batches (58.7%)


Processing T2 val Batches:  59%|█████▉    | 156/263 [02:53<01:53,  1.06s/it]

Processed 156/263 batches (59.4%)


Processing T2 val Batches:  60%|██████    | 158/263 [02:55<01:52,  1.07s/it]

Processed 158/263 batches (60.2%)


Processing T2 val Batches:  61%|██████    | 160/263 [02:58<01:50,  1.07s/it]

Processed 160/263 batches (61.0%)


Processing T2 val Batches:  62%|██████▏   | 162/263 [03:00<01:47,  1.07s/it]

Processed 162/263 batches (61.7%)


Processing T2 val Batches:  62%|██████▏   | 164/263 [03:02<01:46,  1.08s/it]

Processed 164/263 batches (62.5%)


Processing T2 val Batches:  63%|██████▎   | 166/263 [03:04<01:46,  1.10s/it]

Processed 166/263 batches (63.2%)


Processing T2 val Batches:  64%|██████▍   | 168/263 [03:07<01:56,  1.22s/it]

Processed 168/263 batches (64.0%)


Processing T2 val Batches:  65%|██████▍   | 170/263 [03:09<01:48,  1.16s/it]

Processed 170/263 batches (64.8%)


Processing T2 val Batches:  65%|██████▌   | 172/263 [03:11<01:44,  1.15s/it]

Processed 172/263 batches (65.5%)


Processing T2 val Batches:  66%|██████▌   | 174/263 [03:13<01:38,  1.11s/it]

Processed 174/263 batches (66.3%)


Processing T2 val Batches:  67%|██████▋   | 176/263 [03:16<01:37,  1.13s/it]

Processed 176/263 batches (67.0%)


Processing T2 val Batches:  68%|██████▊   | 178/263 [03:18<01:37,  1.15s/it]

Processed 178/263 batches (67.8%)


Processing T2 val Batches:  68%|██████▊   | 180/263 [03:20<01:32,  1.12s/it]

Processed 180/263 batches (68.6%)


Processing T2 val Batches:  69%|██████▉   | 182/263 [03:22<01:30,  1.12s/it]

Processed 182/263 batches (69.3%)


Processing T2 val Batches:  70%|██████▉   | 184/263 [03:25<01:26,  1.09s/it]

Processed 184/263 batches (70.1%)


Processing T2 val Batches:  71%|███████   | 186/263 [03:27<01:24,  1.10s/it]

Processed 186/263 batches (70.9%)


Processing T2 val Batches:  71%|███████▏  | 188/263 [03:29<01:21,  1.09s/it]

Processed 188/263 batches (71.6%)


Processing T2 val Batches:  72%|███████▏  | 190/263 [03:31<01:17,  1.07s/it]

Processed 190/263 batches (72.4%)


Processing T2 val Batches:  73%|███████▎  | 192/263 [03:33<01:14,  1.06s/it]

Processed 192/263 batches (73.1%)


Processing T2 val Batches:  74%|███████▍  | 194/263 [03:35<01:13,  1.06s/it]

Processed 194/263 batches (73.9%)


Processing T2 val Batches:  75%|███████▍  | 196/263 [03:37<01:11,  1.06s/it]

Processed 196/263 batches (74.7%)


Processing T2 val Batches:  75%|███████▌  | 198/263 [03:40<01:08,  1.05s/it]

Processed 198/263 batches (75.4%)


Processing T2 val Batches:  76%|███████▌  | 200/263 [03:42<01:07,  1.08s/it]

Processed 200/263 batches (76.2%)


Processing T2 val Batches:  77%|███████▋  | 202/263 [03:44<01:04,  1.06s/it]

Processed 202/263 batches (77.0%)


Processing T2 val Batches:  78%|███████▊  | 204/263 [03:46<01:01,  1.04s/it]

Processed 204/263 batches (77.7%)


Processing T2 val Batches:  78%|███████▊  | 206/263 [03:48<01:01,  1.09s/it]

Processed 206/263 batches (78.5%)


Processing T2 val Batches:  79%|███████▉  | 208/263 [03:50<00:58,  1.07s/it]

Processed 208/263 batches (79.2%)


Processing T2 val Batches:  80%|███████▉  | 210/263 [03:52<00:57,  1.08s/it]

Processed 210/263 batches (80.0%)


Processing T2 val Batches:  81%|████████  | 212/263 [03:55<00:55,  1.08s/it]

Processed 212/263 batches (80.8%)


Processing T2 val Batches:  81%|████████▏ | 214/263 [03:57<00:52,  1.08s/it]

Processed 214/263 batches (81.5%)


Processing T2 val Batches:  82%|████████▏ | 216/263 [03:59<00:49,  1.06s/it]

Processed 216/263 batches (82.3%)


Processing T2 val Batches:  83%|████████▎ | 218/263 [04:01<00:48,  1.07s/it]

Processed 218/263 batches (83.0%)


Processing T2 val Batches:  84%|████████▎ | 220/263 [04:03<00:47,  1.09s/it]

Processed 220/263 batches (83.8%)


Processing T2 val Batches:  84%|████████▍ | 222/263 [04:05<00:44,  1.09s/it]

Processed 222/263 batches (84.6%)


Processing T2 val Batches:  85%|████████▌ | 224/263 [04:08<00:43,  1.10s/it]

Processed 224/263 batches (85.3%)


Processing T2 val Batches:  86%|████████▌ | 226/263 [04:10<00:40,  1.10s/it]

Processed 226/263 batches (86.1%)


Processing T2 val Batches:  87%|████████▋ | 228/263 [04:12<00:38,  1.11s/it]

Processed 228/263 batches (86.9%)


Processing T2 val Batches:  87%|████████▋ | 230/263 [04:14<00:35,  1.08s/it]

Processed 230/263 batches (87.6%)


Processing T2 val Batches:  88%|████████▊ | 232/263 [04:16<00:33,  1.08s/it]

Processed 232/263 batches (88.4%)


Processing T2 val Batches:  89%|████████▉ | 234/263 [04:18<00:31,  1.09s/it]

Processed 234/263 batches (89.1%)


Processing T2 val Batches:  90%|████████▉ | 236/263 [04:21<00:30,  1.11s/it]

Processed 236/263 batches (89.9%)


Processing T2 val Batches:  90%|█████████ | 238/263 [04:23<00:27,  1.08s/it]

Processed 238/263 batches (90.7%)


Processing T2 val Batches:  91%|█████████▏| 240/263 [04:25<00:25,  1.09s/it]

Processed 240/263 batches (91.4%)


Processing T2 val Batches:  92%|█████████▏| 242/263 [04:27<00:23,  1.10s/it]

Processed 242/263 batches (92.2%)


Processing T2 val Batches:  93%|█████████▎| 244/263 [04:29<00:20,  1.09s/it]

Processed 244/263 batches (93.0%)


Processing T2 val Batches:  94%|█████████▎| 246/263 [04:32<00:18,  1.07s/it]

Processed 246/263 batches (93.7%)


Processing T2 val Batches:  94%|█████████▍| 248/263 [04:34<00:16,  1.07s/it]

Processed 248/263 batches (94.5%)


Processing T2 val Batches:  95%|█████████▌| 250/263 [04:36<00:13,  1.06s/it]

Processed 250/263 batches (95.2%)


Processing T2 val Batches:  96%|█████████▌| 252/263 [04:38<00:11,  1.06s/it]

Processed 252/263 batches (96.0%)


Processing T2 val Batches:  97%|█████████▋| 254/263 [04:40<00:09,  1.05s/it]

Processed 254/263 batches (96.8%)


Processing T2 val Batches:  97%|█████████▋| 256/263 [04:42<00:07,  1.05s/it]

Processed 256/263 batches (97.5%)


Processing T2 val Batches:  98%|█████████▊| 258/263 [04:44<00:05,  1.05s/it]

Processed 258/263 batches (98.3%)


Processing T2 val Batches:  99%|█████████▉| 260/263 [04:46<00:03,  1.05s/it]

Processed 260/263 batches (99.0%)


Processing T2 val Batches: 100%|█████████▉| 262/263 [04:48<00:01,  1.06s/it]

Processed 262/263 batches (99.8%)


Processing T2 val Batches: 100%|██████████| 263/263 [04:49<00:00,  1.10s/it]


T2 val MFCC shape: (16800, 1, 32, 32)
T2 val LFCC shape: (16800, 1, 32, 32)
T2 val Chroma shape: (16800, 1, 12, 32)
T2 val Labels shape: (16800,)
Saved preprocessed T2 val data to /kaggle/working/preprocessed_t2_val


In [14]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, roc_curve, confusion_matrix
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# MFAAN Model
class MFAAN(nn.Module):
    def __init__(self):
        super(MFAAN, self).__init__()
        self.path_mfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.path_lfcc = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.path_chroma = nn.Sequential(
            nn.Conv1d(in_channels=12, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.fc1 = nn.Linear(32 * 8 * 3, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, mfcc, lfcc, chroma):
        mfcc = mfcc.squeeze(1)
        lfcc = lfcc.squeeze(1)
        chroma = chroma.squeeze(1)
        mfcc_out = self.path_mfcc(mfcc)
        lfcc_out = self.path_lfcc(lfcc)
        chroma_out = self.path_chroma(chroma)
        mfcc_flat = mfcc_out.view(mfcc_out.size(0), -1)
        lfcc_flat = lfcc_out.view(lfcc_out.size(0), -1)
        chroma_flat = chroma_out.view(chroma_out.size(0), -1)
        fused = torch.cat((mfcc_flat, lfcc_flat, chroma_flat), dim=1)
        x = F.relu(self.fc1(fused))
        x = self.dropout(x)
        output = self.fc2(x)
        return output

model = MFAAN()
total_params = sum(p.numel() for p in model.parameters())
param_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
param_size_mb = param_size_bytes / (1024 ** 2)
print(f"\nMFAAN Model Size:")
print(f"Total Parameters: {total_params:,}")
print(f"Parameter Size: {param_size_mb:.2f} MB")

# Custom Dataset Class for Train/Val
class T2AudioDataset(Dataset):
    def __init__(self, mfcc_features, lfcc_features, chroma_features, labels):
        self.mfcc_features = mfcc_features
        self.lfcc_features = lfcc_features
        self.chroma_features = chroma_features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        mfcc = torch.FloatTensor(self.mfcc_features[idx])
        lfcc = torch.FloatTensor(self.lfcc_features[idx])
        chroma = torch.FloatTensor(self.chroma_features[idx])
        label = torch.LongTensor([self.labels[idx]])[0]

        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)
        lfcc = (lfcc - lfcc.mean()) / (lfcc.std() + 1e-8)
        chroma = (chroma - chroma.mean()) / (chroma.std() + 1e-8)

        return mfcc, lfcc, chroma, label

# Custom Dataset Class for Test
class AudioDataset(Dataset):
    def __init__(self, chunk_dir, chunk_indices, dataset_type="test"):
        self.chunk_dir = chunk_dir
        self.dataset_type = dataset_type
        if not os.path.exists(chunk_dir):
            raise FileNotFoundError(f"Chunk directory {chunk_dir} does not exist.")
        
        self.chunk_files = sorted([f for f in os.listdir(chunk_dir) if f.startswith("t2_test_mfcc_chunk_") and f.endswith(".npy")])
        if not self.chunk_files:
            raise FileNotFoundError(f"No MFCC chunk files found in {chunk_dir}.")
        
        print(f"Found {len(self.chunk_files)} chunk files in {chunk_dir}. First few: {self.chunk_files[:5]}")
        
        self.chunk_size = 646
        self.num_chunks = len(self.chunk_files)
        self.total_samples = self.num_chunks * self.chunk_size
        self.indices = [idx for idx in chunk_indices if idx < self.total_samples]
        
        if not self.indices:
            raise ValueError(f"No valid indices for {self.dataset_type} dataset. Total samples: {self.total_samples}.")
        
        print(f"Dataset ({self.dataset_type}) initialized with {len(self.indices)} samples.")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        if idx >= len(self.indices):
            raise IndexError(f"Index {idx} is out of bounds for dataset with length {len(self.indices)}")
        
        global_idx = self.indices[idx]
        chunk_idx = global_idx // self.chunk_size
        local_idx = global_idx % self.chunk_size
        
        mfcc_chunk = np.load(os.path.join(self.chunk_dir, f"t2_test_mfcc_chunk_{chunk_idx}.npy"), mmap_mode='r')
        lfcc_chunk = np.load(os.path.join(self.chunk_dir, f"t2_test_lfcc_chunk_{chunk_idx}.npy"), mmap_mode='r')
        chroma_chunk = np.load(os.path.join(self.chunk_dir, f"t2_test_chroma_chunk_{chunk_idx}.npy"), mmap_mode='r')
        y_chunk = np.load(os.path.join(self.chunk_dir, f"t2_test_y_chunk_{chunk_idx}.npy"), mmap_mode='r')
        
        mfcc = torch.FloatTensor(mfcc_chunk[local_idx])
        lfcc = torch.FloatTensor(lfcc_chunk[local_idx])
        chroma = torch.FloatTensor(chroma_chunk[local_idx])
        y = torch.LongTensor([y_chunk[local_idx]])[0]
        
        mfcc = (mfcc - mfcc.mean()) / (mfcc.std() + 1e-8)
        lfcc = (lfcc - lfcc.mean()) / (lfcc.std() + 1e-8)
        chroma = (chroma - chroma.mean()) / (chroma.std() + 1e-8)
        
        return mfcc, lfcc, chroma, y

# Compute eer
def compute_eer(labels, scores):
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    return eer

# Load preprocessed train and val data
try:
    train_dataset = T2AudioDataset(
        np.load("/kaggle/working/preprocessed_t2_train/t2_train_mfcc.npy"),
        np.load("/kaggle/working/preprocessed_t2_train/t2_train_lfcc.npy"),
        np.load("/kaggle/working/preprocessed_t2_train/t2_train_chroma.npy"),
        np.load("/kaggle/working/preprocessed_t2_train/t2_train_labels.npy")
    )
    val_dataset = T2AudioDataset(
        np.load("/kaggle/working/preprocessed_t2_val/t2_val_mfcc.npy"),
        np.load("/kaggle/working/preprocessed_t2_val/t2_val_lfcc.npy"),
        np.load("/kaggle/working/preprocessed_t2_val/t2_val_chroma.npy"),
        np.load("/kaggle/working/preprocessed_t2_val/t2_val_labels.npy")
    )
except FileNotFoundError as e:
    raise FileNotFoundError(f"Preprocessed train/val data not found. Run the preprocessing cell first: {e}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
print(f"Train dataset size: {len(train_dataset)}, Val dataset size: {len(val_dataset)}")

# Compute class weights for imbalanced data
train_labels = np.load("/kaggle/working/preprocessed_t2_train/t2_train_labels.npy")
class_counts = np.bincount(train_labels)
class_weights = torch.FloatTensor([1.0 / class_counts[i] for i in range(2)]).to(device)
class_weights = class_weights / class_weights.sum() * 2  # Normalize
print(f"Class weights (Real, Fake): {class_weights.tolist()}")

model = MFAAN().to(device)
try:
    model.load_state_dict(torch.load("/kaggle/working/mfaan_best.pth"))  # Load pre-trained weights
except FileNotFoundError:
    print("Pre-trained model not found. Starting fine-tuning from scratch.")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # Low lr for fine-tuning
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Fine-tuning loop
num_epochs = 20
best_val_eer = float("inf")
best_model_path = "/kaggle/working/mfaan_t2_finetuned.pth"

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_preds = []
    train_labels = []
    for mfcc, lfcc, chroma, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
        mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(mfcc, lfcc, chroma)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_preds.extend(predicted.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())

    train_acc = accuracy_score(train_labels, train_preds)
    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0.0
    val_preds = []
    val_scores = []
    val_labels = []
    with torch.no_grad():
        for mfcc, lfcc, chroma, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
            mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
            outputs = model(mfcc, lfcc, chroma)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            probs = F.softmax(outputs, dim=1)[:, 1]
            _, predicted = torch.max(outputs, 1)
            val_preds.extend(predicted.cpu().numpy())
            val_scores.extend(probs.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)
    val_eer = compute_eer(val_labels, val_scores)
    val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%, Val EER: {val_eer*100:.2f}%")

    if val_eer < best_val_eer:
        best_val_eer = val_eer
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved best model with Val EER: {best_val_eer*100:.2f}%")

print(f"Fine-tuning complete. Best model saved to {best_model_path}")

# Test the fine-tuned model
model.load_state_dict(torch.load(best_model_path))
model.eval()
print("\nLoaded fine-tuned model for testing.")

# Create test dataset and loader
chunk_dir = "/kaggle/working/preprocessed_t2_test_chunks"
total_t2_test_samples = 16800  # Actual samples, excluding padding
test_indices = list(range(total_t2_test_samples))
t2_test_dataset = AudioDataset(chunk_dir, test_indices, dataset_type="test")
t2_test_loader = DataLoader(t2_test_dataset, batch_size=32, shuffle=False, num_workers=0)

# Evaluate on T2 test set
test_preds = []
test_scores = []
test_labels = []
with torch.no_grad():
    for mfcc, lfcc, chroma, labels in tqdm(t2_test_loader, desc="Evaluating T2 Test Set"):
        mfcc, lfcc, chroma, labels = mfcc.to(device), lfcc.to(device), chroma.to(device), labels.to(device)
        outputs = model(mfcc, lfcc, chroma)
        probs = F.softmax(outputs, dim=1)[:, 1]
        _, predicted = torch.max(outputs, 1)
        test_preds.extend(predicted.cpu().numpy())
        test_scores.extend(probs.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(test_labels, test_preds)
test_eer = compute_eer(test_labels, test_scores)
cm = confusion_matrix(test_labels, test_preds)
per_class_acc = cm.diagonal() / cm.sum(axis=1)

# Print results
print(f"\nFinal T2 Test Accuracy: {test_acc * 100:.2f}%")
print(f"Final T2 Test EER: {test_eer * 100:.2f}%")
print(f"Per-class accuracy (Real, Fake): {per_class_acc * 100}")

# Compare with baseline
baseline_accuracy = 68.44
baseline_eer = 40.9
print(f"\nBaseline Accuracy (T2): {baseline_accuracy:.2f}%, Baseline EER: {baseline_eer:.2f}%")
print("Model vs. Baseline:")
print(f"Accuracy: {'Outperforms' if test_acc * 100 > baseline_accuracy else 'Does not outperform'} "
      f"(Model: {test_acc * 100:.2f}%, Baseline: {baseline_accuracy:.2f}%)")
print(f"EER: {'Outperforms' if test_eer * 100 < baseline_eer else 'Does not outperform'} "
      f"(Model: {test_eer * 100:.2f}%, Baseline: {baseline_eer:.2f}%)")

Using device: cuda, GPU: Tesla T4

MFAAN Model Size:
Total Parameters: 206,066
Parameter Size: 0.79 MB
Train dataset size: 134400, Val dataset size: 16800
Class weights (Real, Fake): [1.3333333730697632, 0.6666666865348816]


Epoch 1/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 97.09it/s] 
Epoch 1/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 150.88it/s]


Epoch 1/20:
Train Loss: 0.2703, Train Acc: 88.65%
Val Loss: 0.3135, Val Acc: 86.97%, Val EER: 10.20%
Saved best model with Val EER: 10.20%


Epoch 2/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.64it/s] 
Epoch 2/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 153.03it/s]


Epoch 2/20:
Train Loss: 0.1972, Train Acc: 92.13%
Val Loss: 0.2405, Val Acc: 90.53%, Val EER: 8.34%
Saved best model with Val EER: 8.34%


Epoch 3/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.04it/s]
Epoch 3/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 146.94it/s]


Epoch 3/20:
Train Loss: 0.1824, Train Acc: 92.63%
Val Loss: 0.2455, Val Acc: 90.49%, Val EER: 8.48%


Epoch 4/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.44it/s]
Epoch 4/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 150.50it/s]


Epoch 4/20:
Train Loss: 0.1722, Train Acc: 93.12%
Val Loss: 0.2355, Val Acc: 90.26%, Val EER: 7.64%
Saved best model with Val EER: 7.64%


Epoch 5/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.65it/s] 
Epoch 5/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 147.81it/s]


Epoch 5/20:
Train Loss: 0.1676, Train Acc: 93.28%
Val Loss: 0.2216, Val Acc: 91.61%, Val EER: 7.52%
Saved best model with Val EER: 7.52%


Epoch 6/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.96it/s] 
Epoch 6/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 148.86it/s]


Epoch 6/20:
Train Loss: 0.1626, Train Acc: 93.54%
Val Loss: 0.2127, Val Acc: 91.79%, Val EER: 7.57%


Epoch 7/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.19it/s] 
Epoch 7/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 151.03it/s]


Epoch 7/20:
Train Loss: 0.1586, Train Acc: 93.77%
Val Loss: 0.2026, Val Acc: 92.47%, Val EER: 7.16%
Saved best model with Val EER: 7.16%


Epoch 8/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.83it/s]
Epoch 8/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 147.26it/s]


Epoch 8/20:
Train Loss: 0.1547, Train Acc: 93.86%
Val Loss: 0.2948, Val Acc: 87.99%, Val EER: 8.52%


Epoch 9/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.48it/s] 
Epoch 9/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 149.77it/s]


Epoch 9/20:
Train Loss: 0.1508, Train Acc: 94.07%
Val Loss: 0.2307, Val Acc: 91.11%, Val EER: 7.68%


Epoch 10/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.24it/s] 
Epoch 10/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 147.90it/s]


Epoch 10/20:
Train Loss: 0.1499, Train Acc: 94.01%
Val Loss: 0.2355, Val Acc: 91.02%, Val EER: 7.75%


Epoch 11/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.79it/s] 
Epoch 11/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 151.83it/s]


Epoch 11/20:
Train Loss: 0.1459, Train Acc: 94.19%
Val Loss: 0.2140, Val Acc: 91.46%, Val EER: 7.14%
Saved best model with Val EER: 7.14%


Epoch 12/20 [Train]: 100%|██████████| 4200/4200 [00:42<00:00, 97.89it/s] 
Epoch 12/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 154.46it/s]


Epoch 12/20:
Train Loss: 0.1450, Train Acc: 94.26%
Val Loss: 0.2470, Val Acc: 90.02%, Val EER: 7.70%


Epoch 13/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.78it/s] 
Epoch 13/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 151.99it/s]


Epoch 13/20:
Train Loss: 0.1436, Train Acc: 94.30%
Val Loss: 0.2180, Val Acc: 91.54%, Val EER: 7.86%


Epoch 14/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 97.31it/s] 
Epoch 14/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 150.07it/s]


Epoch 14/20:
Train Loss: 0.1409, Train Acc: 94.34%
Val Loss: 0.2358, Val Acc: 90.49%, Val EER: 7.96%


Epoch 15/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 97.26it/s] 
Epoch 15/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 154.69it/s]


Epoch 15/20:
Train Loss: 0.1389, Train Acc: 94.42%
Val Loss: 0.2510, Val Acc: 90.04%, Val EER: 8.07%


Epoch 16/20 [Train]: 100%|██████████| 4200/4200 [00:42<00:00, 97.75it/s] 
Epoch 16/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 151.87it/s]


Epoch 16/20:
Train Loss: 0.1387, Train Acc: 94.48%
Val Loss: 0.2437, Val Acc: 90.27%, Val EER: 7.29%


Epoch 17/20 [Train]: 100%|██████████| 4200/4200 [00:42<00:00, 97.72it/s] 
Epoch 17/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 150.48it/s]


Epoch 17/20:
Train Loss: 0.1376, Train Acc: 94.48%
Val Loss: 0.2496, Val Acc: 89.43%, Val EER: 7.14%


Epoch 18/20 [Train]: 100%|██████████| 4200/4200 [00:42<00:00, 97.95it/s] 
Epoch 18/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 152.49it/s]


Epoch 18/20:
Train Loss: 0.1368, Train Acc: 94.51%
Val Loss: 0.2224, Val Acc: 90.85%, Val EER: 7.18%


Epoch 19/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.90it/s] 
Epoch 19/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 149.15it/s]


Epoch 19/20:
Train Loss: 0.1348, Train Acc: 94.60%
Val Loss: 0.2188, Val Acc: 91.32%, Val EER: 7.52%


Epoch 20/20 [Train]: 100%|██████████| 4200/4200 [00:43<00:00, 96.81it/s] 
Epoch 20/20 [Val]: 100%|██████████| 525/525 [00:03<00:00, 150.69it/s]


Epoch 20/20:
Train Loss: 0.1358, Train Acc: 94.55%
Val Loss: 0.2610, Val Acc: 89.70%, Val EER: 8.45%
Fine-tuning complete. Best model saved to /kaggle/working/mfaan_t2_finetuned.pth

Loaded fine-tuned model for testing.
Found 27 chunk files in /kaggle/working/preprocessed_t2_test_chunks. First few: ['t2_test_mfcc_chunk_0.npy', 't2_test_mfcc_chunk_1.npy', 't2_test_mfcc_chunk_10.npy', 't2_test_mfcc_chunk_11.npy', 't2_test_mfcc_chunk_12.npy']
Dataset (test) initialized with 16800 samples.


Evaluating T2 Test Set: 100%|██████████| 525/525 [00:16<00:00, 31.04it/s]


Final T2 Test Accuracy: 91.76%
Final T2 Test EER: 6.89%
Per-class accuracy (Real, Fake): [95.85714286 89.71428571]

Baseline Accuracy (T2): 68.44%, Baseline EER: 40.90%
Model vs. Baseline:
Accuracy: Outperforms (Model: 91.76%, Baseline: 68.44%)
EER: Outperforms (Model: 6.89%, Baseline: 40.90%)



