In [1]:
# Cell 1: Imports
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, GRU, Dense, Dropout, BatchNormalization, LayerNormalization, Reshape, Permute, Bidirectional, Add, Attention, Flatten, TimeDistributed, Conv2DTranspose, Conv2D, Layer, Concatenate, Multiply, AdditiveAttention # Added Multiply, AdditiveAttention for potential SelfAttention implementation
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam, RMSprop # Added RMSprop as an option for WGAN
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras import backend as K
from sklearn.metrics import confusion_matrix, f1_score, roc_curve # Moved confusion_matrix here
import librosa
import soundfile as sf
# import noisereduce as nr # Consider if still needed/effective
import matplotlib.pyplot as plt
from scipy.signal import butter, sosfilt
import logging
from tqdm.notebook import tqdm # Use tqdm.notebook for better Jupyter integration
import random
import seaborn as sns # Added for confusion matrix plotting
# Switch off TensorFlow's XLA JIT optimizer setting for this session and say so in the
# cell output, so a reader of the logs knows which compilation mode the run used.
tf.config.optimizer.set_jit(False)
print("XLA JIT compilation disabled.")

# WGAN-GP specific
from tensorflow import GradientTape

# Create directory for saving figures (relative to the notebook's working directory;
# exist_ok=True makes re-running this cell harmless).
FIGURES_DIR = 'training_figures_wgan_sa' # Changed dir name
os.makedirs(FIGURES_DIR, exist_ok=True)

# Configure logging: messages go to a file, not to the cell output.
# Only ERROR and above are recorded, so any logging.warning(...) call made elsewhere in this
# notebook is dropped by this configuration.
logging.basicConfig(filename='audio_errors_wgan_sa.log', level=logging.ERROR,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Record the TensorFlow version in the cell output for reproducibility.
print(f"TensorFlow version: {tf.__version__}")

2025-04-04 17:45:15.033435: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-04 17:45:15.045755: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743768915.060522  640800 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743768915.064943  640800 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743768915.077012  640800 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

XLA JIT compilation disabled.
TensorFlow version: 2.19.0


In [2]:
# Cell 2: Force GPU usage
# Pin TensorFlow to the first visible GPU (if any) and let it allocate memory on demand.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("No GPU devices found. Falling back to CPU.")
else:
    primary_gpu = physical_devices[0]
    try:
        tf.config.set_visible_devices(primary_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(primary_gpu, True)
        print("GPU is available and configured.")
    except RuntimeError as gpu_err:
        # Memory growth must be set before GPUs have been initialized
        print(f"Error configuring GPU: {gpu_err}")

# Set mixed precision policy if desired (can speed up training on compatible GPUs)
#from tensorflow.keras import mixed_precision
#policy = mixed_precision.Policy('mixed_float16')
#mixed_precision.set_global_policy(policy)
#print('Mixed precision enabled')

GPU is available and configured.


In [3]:
# Cell 3: Audio Loading and Preprocessing Function
def _fit_audio_length(audio, target_len):
    """Zero-pad at the end, or truncate, a 1-D waveform to exactly `target_len` samples."""
    if len(audio) < target_len:
        return np.pad(audio, (0, target_len - len(audio)), mode='constant')
    return audio[:target_len]


def load_and_preprocess_audio(file_path, sr=16000, duration=4, augment=True):
    """Load one audio file and return a fixed-length, peak-normalised waveform.

    Pipeline: load at most `duration` seconds at the file's native rate, resample to `sr`
    if needed, pad/trim to the target length, optionally apply one random augmentation,
    then peak-normalise.

    Args:
        file_path: Path of the audio file to read.
        sr: Target sample rate in Hz.
        duration: Clip length in seconds. May be fractional; the target length is
            ``int(round(sr * duration))`` samples.
        augment: When True (default, the original behaviour) there is a 50% chance that
            one of additive noise / pitch shift / time stretch is applied. Pass False for
            deterministic output, e.g. on dev/eval data.

    Returns:
        1-D numpy array of ``int(round(sr * duration))`` samples, or None if anything in
        the pipeline raised (the error is logged and printed).
    """
    try:
        # Load at the native rate first; resampling is done explicitly below.
        audio, current_sr = librosa.load(file_path, sr=None, duration=duration)

        # Optional: Noise Reduction (experiment if needed)
        # audio = nr.reduce_noise(y=audio, sr=current_sr)

        if current_sr != sr:
            audio = librosa.resample(audio, orig_sr=current_sr, target_sr=sr)

        # An integer sample count is required by np.pad and slicing; `sr * duration` is a
        # float whenever `duration` is fractional.
        target_len = int(round(sr * duration))

        # Fix the length *before* augmentation/normalisation.
        audio = _fit_audio_length(audio, target_len)

        # Data augmentation (applied *before* normalisation), 50% chance.
        if augment and np.random.random() < 0.5:
            augmentation_type = np.random.choice(['noise', 'pitch', 'speed'])
            if augmentation_type == 'noise':
                # Noise amplitude is scaled relative to the clip's peak.
                noise_amp = 0.005 * np.random.uniform(0.1, 1.0) * np.max(np.abs(audio))
                noise = np.random.randn(len(audio)) * noise_amp
                # randn yields float64; cast so every branch returns the same dtype.
                audio = audio + noise.astype(audio.dtype)
            elif augmentation_type == 'pitch':
                pitch_shift_steps = np.random.uniform(-2.5, 2.5)
                audio = librosa.effects.pitch_shift(audio, sr=sr, n_steps=pitch_shift_steps)
            else:  # speed (time stretch)
                speed_rate = np.random.uniform(0.85, 1.15)
                audio = librosa.effects.time_stretch(audio, rate=speed_rate)
            # Time stretching changes the length; for the other branches this is a no-op.
            audio = _fit_audio_length(audio, target_len)

        # Peak normalisation; silent clips are left untouched to avoid dividing by ~0.
        max_amp = np.max(np.abs(audio))
        if max_amp > 1e-6:
            audio = audio / max_amp
        # Optional: RMS normalization instead
        # rms = np.sqrt(np.mean(audio**2))
        # if rms > 1e-6:
        #    audio = audio / rms * 0.5 # Scale to a target RMS

        return audio
    except Exception as e:
        # Best-effort by design: callers skip files for which None is returned.
        logging.error(f"Error loading/preprocessing {file_path}: {e}")
        print(f"Error loading/preprocessing {file_path}: {e}")
        return None

In [4]:
# Cell 4: Feature Extraction Function
def extract_features(audio, sr=16000, n_mels=80, n_fft=2048, hop_length=512):
    """Turn a waveform into a standard-scaled log-mel spectrogram.

    Args:
        audio: 1-D waveform, or None (in which case None is returned immediately).
        sr: Sample rate of `audio` in Hz.
        n_mels: Number of mel bands (rows of the result).
        n_fft: FFT window size passed to librosa.
        hop_length: Hop between frames passed to librosa.

    Returns:
        2-D array of shape (n_mels, n_frames) with zero mean and, unless the spectrogram
        is (near-)constant, unit standard deviation; None if extraction raised.
        For the notebook defaults the result is expected to be (80, 126) -- this relies on
        librosa's framing of a 64000-sample clip; verify if parameters change.
    """
    if audio is None:
        return None
    try:
        power_mel = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
        )
        # dB scale relative to the clip's own maximum.
        mel_db = librosa.power_to_db(power_mel, ref=np.max)

        # Per-spectrogram standard scaling; only centre when the spread is ~0.
        offset = np.mean(mel_db)
        spread = np.std(mel_db)
        centred = mel_db - offset
        return centred / spread if spread > 1e-6 else centred

    except Exception as exc:
        logging.error(f"Error extracting features: {exc}")
        print(f"Error extracting features: {exc}")
        return None

In [5]:
# Cell 5: Class Distribution Analysis Function
def analyze_class_distribution(data_path):
    """Count and print the number of `.wav` files under `<data_path>/real` and `<data_path>/fake`.

    Args:
        data_path: Directory expected to contain `real` and/or `fake` sub-directories.

    Returns:
        dict with keys 'real' and 'fake' holding the file counts. A missing sub-directory
        counts as 0, and any error yields {'real': 0, 'fake': 0} after being reported.
    """
    def _wav_count(directory):
        # A missing class folder simply contributes zero files.
        if not os.path.exists(directory):
            return 0
        return sum(1 for entry in os.listdir(directory) if entry.endswith('.wav'))

    try:
        real_dir = os.path.join(data_path, 'real')
        fake_dir = os.path.join(data_path, 'fake')
        real_count = _wav_count(real_dir)
        fake_count = _wav_count(fake_dir)
        total = real_count + fake_count

        if total == 0:
            print(f"\nNo .wav files found in {data_path}")
            return {'real': 0, 'fake': 0}

        print(f"\nClass Distribution for {data_path}:")
        print(f"Real: {real_count} ({real_count/total*100:.2f}%)")
        print(f"Fake: {fake_count} ({fake_count/total*100:.2f}%)")
        return {'real': real_count, 'fake': fake_count}
    except FileNotFoundError:
        print(f"\nError: Data path not found - {data_path}")
        return {'real': 0, 'fake': 0}
    except Exception as exc:
        logging.error(f"Error analyzing distribution for {data_path}: {exc}")
        print(f"Error analyzing distribution for {data_path}: {exc}")
        return {'real': 0, 'fake': 0}

In [6]:
# Cell 6: Data Generators

# Fixed number of spectrogram frames shared by the GAN and the classifier.
# 126 matches the (80, 126) shape noted in extract_features for sr=16000, duration=4,
# hop_length=512 -- recompute if any of those change.
# NOTE: the generators below capture this value as a default argument at definition time;
# re-assigning TARGET_FRAMES later (Cell 11 does so) does not change those defaults.
TARGET_FRAMES = 126


def _list_wav_files(directory):
    """Return the full paths of the `.wav` files directly inside `directory` ([] if it is missing)."""
    if not os.path.exists(directory):
        return []
    return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.wav')]


def _fit_frames(features, target_frames):
    """Zero-pad (on the right) or truncate axis 1 of a 2-D array to exactly `target_frames` columns."""
    current_frames = features.shape[1]
    if current_frames < target_frames:
        pad_width = target_frames - current_frames
        return np.pad(features, ((0, 0), (0, pad_width)), mode='constant')
    return features[:, :target_frames]


def _file_to_features(file_path, target_frames, sr, duration, n_mels, n_fft, hop_length):
    """Load one file and return its (n_mels, target_frames) feature matrix, or None if any stage failed."""
    audio = load_and_preprocess_audio(file_path, sr=sr, duration=duration)
    if audio is None:
        return None
    features = extract_features(audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    if features is None:
        return None
    return _fit_frames(features, target_frames)


def data_generator_classifier(data_path, batch_size=32, shuffle=True, target_frames=TARGET_FRAMES, sr=16000, duration=4, n_mels=80, n_fft=2048, hop_length=512):
    """Endless batch generator for STANDALONE CLASSIFIER training.

    Reads `.wav` files from `<data_path>/real` (label 1) and `<data_path>/fake` (label 0).

    Yields:
        (x, y, sample_weights) where x is float32 of shape (batch, n_mels, target_frames, 1),
        y is float32 of shape (batch,), and sample_weights holds the balancing class weight
        total / (2 * class_count) of each sample's class. Files that fail to load are
        skipped, so a batch may be smaller than `batch_size`.

    The generator ends (instead of looping forever) when no files are found, or when a
    complete pass over the data produced no usable sample at all.
    """
    try:
        real_files = _list_wav_files(os.path.join(data_path, 'real'))
        fake_files = _list_wav_files(os.path.join(data_path, 'fake'))
    except FileNotFoundError as e:
        print(f"Error finding directories in {data_path}: {e}")
        return

    all_files = real_files + fake_files
    labels = [1] * len(real_files) + [0] * len(fake_files)  # Real=1, Fake=0

    if not all_files:
        print(f"No WAV files found in {data_path}. Classifier generator stopping.")
        return

    total_samples = len(all_files)
    # Balanced weights; a class with no files gets a neutral weight of 1.0.
    class_weights = {
        1: total_samples / (2 * len(real_files)) if len(real_files) > 0 else 1.0,
        0: total_samples / (2 * len(fake_files)) if len(fake_files) > 0 else 1.0,
    }
    print(f"Using class weights: {class_weights} for path {data_path}")

    indices = np.arange(total_samples)
    while True:
        if shuffle:
            np.random.shuffle(indices)

        yielded_any = False
        for i in range(0, total_samples, batch_size):
            batch_x = []
            batch_y = []
            batch_sample_weights = []

            for k in indices[i:i + batch_size]:
                features = _file_to_features(all_files[k], target_frames, sr, duration, n_mels, n_fft, hop_length)
                if features is None:
                    continue
                batch_x.append(features)
                batch_y.append(labels[k])
                batch_sample_weights.append(class_weights[labels[k]])

            if not batch_x:
                print(f"Warning: Skipping empty batch yield in data_generator_classifier for path {data_path}")
                continue

            yielded_any = True
            batch_x_4d = np.expand_dims(np.array(batch_x), axis=-1).astype(np.float32)
            batch_y_arr = np.array(batch_y).astype(np.float32)
            batch_weights_arr = np.array(batch_sample_weights).astype(np.float32)
            yield batch_x_4d, batch_y_arr, batch_weights_arr

        if not yielded_any:
            # Every file failed: without this exit the consumer's next() would never return.
            print(f"No usable samples in {data_path} after a full pass. Classifier generator stopping.")
            return


def data_generator_gan(data_path, batch_size=32, shuffle=True, target_frames=TARGET_FRAMES, sr=16000, duration=4, n_mels=80, n_fft=2048, hop_length=512):
    """Endless batch generator for WGAN training.

    Reads only the `.wav` files in `<data_path>/fake`; these are the samples the critic
    sees as its "real" distribution.

    Yields:
        float32 array of shape (batch, n_mels, target_frames, 1) -- features only, no labels.
        Files that fail to load are skipped, so a batch may be smaller than `batch_size`.

    The generator ends (instead of looping forever) when no files are found, or when a
    complete pass over the data produced no usable sample at all.
    """
    try:
        fake_files = _list_wav_files(os.path.join(data_path, 'fake'))
    except FileNotFoundError as e:
        print(f"Error finding fake directory in {data_path}: {e}")
        return

    if not fake_files:
        print(f"No fake WAV files found in {data_path}. GAN generator stopping.")
        return

    total_samples = len(fake_files)
    indices = np.arange(total_samples)

    while True:
        if shuffle:
            np.random.shuffle(indices)

        yielded_any = False
        for i in range(0, total_samples, batch_size):
            batch_x = []
            for k in indices[i:i + batch_size]:
                features = _file_to_features(fake_files[k], target_frames, sr, duration, n_mels, n_fft, hop_length)
                if features is not None:
                    batch_x.append(features)

            if not batch_x:
                print(f"Warning: Skipping empty batch yield in data_generator_gan for path {data_path}")
                continue  # Skip yield if batch ended up empty

            yielded_any = True
            yield np.expand_dims(np.array(batch_x), axis=-1).astype(np.float32)

        if not yielded_any:
            # Every file failed: without this exit the consumer's next() would never return.
            print(f"No usable samples in {data_path} after a full pass. GAN generator stopping.")
            return

In [7]:
# Cell 7: Self-Attention Layer (SAGAN style)

class SelfAttention(Layer):
    """Self-attention layer in the style of SAGAN, with a learnable residual scale.

    Every spatial location attends over all locations of the same feature map; the
    attended features are scaled by `gamma` (initialised to zero, so the layer starts as
    the identity) and added back to the input.

    Input shape:  (batch, height, width, channels)
    Output shape: (batch, height, width, channels)

    Args:
        channels_out: Number of channels of the attended features. Because the result is
            added to the input, it must equal the input channel count; None (default)
            means "same as input".

    Note: the attention matrix has shape (batch, h*w, h*w), so memory grows
    quadratically with the number of spatial locations.
    """
    def __init__(self, channels_out=None, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels_out = channels_out

    def build(self, input_shape):
        self.input_channels = input_shape[-1]
        if self.channels_out is None:
            self.channels_out = self.input_channels
        if self.channels_out != self.input_channels:
            # The skip connection in call() adds the attention output to the input.
            raise ValueError(
                f"SelfAttention requires channels_out ({self.channels_out}) to equal the "
                f"input channel count ({self.input_channels}) because of its residual connection."
            )

        # Reduced width of the query/key projections; never below 1 so that inputs with
        # fewer than 8 channels still get a valid convolution.
        self.attn_channels = max(1, self.input_channels // 8)

        # 1x1 convolutions producing query, key and value maps.
        self.f = Conv2D(self.attn_channels, kernel_size=1, strides=1, padding='same', name='conv_f') # Query
        self.g = Conv2D(self.attn_channels, kernel_size=1, strides=1, padding='same', name='conv_g') # Key
        self.h = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_h')  # Value

        # Final 1x1 convolution applied to the attended features.
        self.out_conv = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_out')

        # Learnable scale of the attention branch; zero at start.
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        # Dynamic shape so the layer works with an unknown batch size.
        dyn_shape = tf.shape(x)
        batch_size, height, width = dyn_shape[0], dyn_shape[1], dyn_shape[2]
        location_num = height * width

        # Query (f), Key (g), Value (h) projections
        f_proj = self.f(x) # (batch, h, w, attn_channels)
        g_proj = self.g(x) # (batch, h, w, attn_channels)
        h_proj = self.h(x) # (batch, h, w, c_out)

        # Flatten the spatial axes for matrix multiplication.
        f_flat = tf.reshape(f_proj, shape=(batch_size, location_num, self.attn_channels)) # (batch, h*w, attn_channels)
        g_flat = tf.reshape(g_proj, shape=(batch_size, location_num, self.attn_channels)) # (batch, h*w, attn_channels)
        h_flat = tf.reshape(h_proj, shape=(batch_size, location_num, self.channels_out))  # (batch, h*w, c_out)

        # Attention scores: (batch, h*w, attn_channels) x (batch, attn_channels, h*w) -> (batch, h*w, h*w)
        g_flat_t = tf.transpose(g_flat, perm=[0, 2, 1])
        attention_score = tf.matmul(f_flat, g_flat_t)
        # Each row is normalised over all locations.
        attention_prob = tf.nn.softmax(attention_score, axis=-1)

        # (batch, h*w, h*w) x (batch, h*w, c_out) -> (batch, h*w, c_out)
        attention_output = tf.matmul(attention_prob, h_flat)

        # Back to image layout.
        attention_output_reshaped = tf.reshape(attention_output, shape=(batch_size, height, width, self.channels_out))

        # Final 1x1 convolution, scaled by gamma, plus additive skip connection.
        o = self.out_conv(attention_output_reshaped)
        y = self.gamma * o + x

        return y

    def compute_output_shape(self, input_shape):
        # Accept lists as well as tuples, and work before build() has resolved channels_out.
        input_shape = tuple(input_shape)
        channels = self.channels_out if self.channels_out is not None else input_shape[-1]
        return input_shape[:-1] + (channels,)

    def get_config(self):
        # Expose the constructor argument so the layer can be re-created from its config.
        config = super(SelfAttention, self).get_config()
        config.update({'channels_out': self.channels_out})
        return config

In [8]:
# Cell 8: Generator Model (with Self-Attention)

def create_generator(latent_dim, output_shape): # output_shape (n_mels, n_frames) e.g., (80, 126)
    """Create the WGAN generator (with Self-Attention) mapping noise to a spectrogram.

    The network reshapes a dense projection to a small feature map, upsamples it three
    times by a factor of 2 (x8 overall) and, when the target size is not a multiple of 8
    on either axis, trims the surplus with a final 'valid' convolution.

    Args:
        latent_dim: Length of the input noise vector.
        output_shape: (n_mels, n_frames) of the spectrogram to generate, e.g. (80, 126).

    Returns:
        A Keras Sequential model named 'generator' whose output is intended to have shape
        (batch, n_mels, n_frames, 1).

    Raises:
        ValueError: if either target dimension is not a positive integer size.
    """
    n_mels, n_frames = output_shape
    if n_mels < 1 or n_frames < 1:
        raise ValueError(f"output_shape must contain positive sizes, got {output_shape}")

    upsample_factor = 8  # three stride-2 transposed convolutions: 2 * 2 * 2
    init_c = 128         # channels of the initial feature map

    if n_mels % upsample_factor != 0 or n_frames % upsample_factor != 0:
        print(f"Warning: Output shape {output_shape} might not be perfectly reached with 3 strides of 2. Adjusting initial size or layers.")

    # Ceiling division on BOTH axes so the upsampled map is never smaller than the
    # target; the excess is trimmed at the end (e.g. 126 frames -> 16 -> 128 -> 126).
    init_h = -(-n_mels // upsample_factor)
    init_w = -(-n_frames // upsample_factor)

    nodes = init_h * init_w * init_c

    model = Sequential(name='generator')
    model.add(Input(shape=(latent_dim,)))

    # Dense projection and reshape to the initial feature map, e.g. (10, 16, 128)
    model.add(Dense(nodes))
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))
    model.add(Reshape((init_h, init_w, init_c)))

    # Upsample 1: e.g. (10, 16, 128) -> (20, 32, 64). No bias: a normalisation layer follows.
    model.add(Conv2DTranspose(init_c // 2, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(LayerNormalization()) # LayerNorm instead of BatchNorm
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))

    # Upsample 2: e.g. (20, 32, 64) -> (40, 64, 32)
    model.add(Conv2DTranspose(init_c // 4, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(LayerNormalization())
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))

    # Self-attention on the intermediate map (channel count unchanged).
    # Attention is computationally expensive, hence applied before the last upsample.
    model.add(SelfAttention(channels_out=init_c // 4))

    # Upsample 3: e.g. (40, 64, 32) -> (80, 128, 1)
    # NOTE(review): tanh bounds the output to [-1, 1], while extract_features standard-scales
    # real spectrograms (zero mean, unit std), which is not confined to that range -- confirm
    # this is intended.
    model.add(Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), padding='same', activation='tanh'))

    # Trim any surplus rows/columns. For 'valid' padding: kernel = size_in - size_out + 1.
    current_height = init_h * upsample_factor
    current_width = init_w * upsample_factor
    if current_width != n_frames:
        print(f"Generator adding final Conv2D to adjust width from {current_width} to {n_frames}")
    if current_height != n_mels:
        print(f"Generator adding final Conv2D to adjust height from {current_height} to {n_mels}")
    if (current_height, current_width) != (n_mels, n_frames):
        kernel_h = current_height - n_mels + 1
        kernel_w = current_width - n_frames + 1
        model.add(Conv2D(1, kernel_size=(kernel_h, kernel_w), padding='valid', activation='tanh'))

    return model

In [9]:
# Cell 9: Critic (Discriminator) Model (with Self-Attention) - NO DROPOUT

def create_critic(input_shape): # input_shape (n_mels, n_frames) e.g., (80, 126)
    """Build the WGAN-GP critic: three stride-2 conv stages with Self-Attention, no dropout.

    Args:
        input_shape: (n_mels, n_frames) of the spectrograms to score; the model itself
            takes inputs of shape (n_mels, n_frames, 1).

    Returns:
        A Keras Sequential model named 'critic' ending in a single linear unit
        (an unbounded score, not a probability).
    """
    mel_bins, frame_count = input_shape

    def _leaky():
        return tf.keras.layers.LeakyReLU(negative_slope=0.2)

    # Settings shared by every downsampling convolution.
    downsample = dict(kernel_size=(4, 4), strides=(2, 2), padding='same')

    critic_net = Sequential(name='critic')
    critic_net.add(Input(shape=(mel_bins, frame_count, 1)))

    # Stage 1: plain convolution (keeps its bias, no normalisation)
    critic_net.add(Conv2D(64, **downsample))
    critic_net.add(_leaky())

    # Stage 2: bias dropped because LayerNormalization follows
    critic_net.add(Conv2D(128, use_bias=False, **downsample))
    critic_net.add(LayerNormalization())
    critic_net.add(_leaky())

    # Self-attention over the 128-channel feature map
    critic_net.add(SelfAttention(channels_out=128))

    # Stage 3
    critic_net.add(Conv2D(256, use_bias=False, **downsample))
    critic_net.add(LayerNormalization())
    critic_net.add(_leaky())

    # Flatten and emit one score per sample
    critic_net.add(Flatten())
    critic_net.add(Dense(1))

    return critic_net

In [10]:
# Cell 10: Define the GAN Model (Not used for WGAN-GP training loop)
# def create_gan(generator, discriminator, latent_dim):
#     """Creates the combined GAN model."""
#     # Make discriminator non-trainable
#     discriminator.trainable = False
#
#     # Stack generator and discriminator
#     gan_input = Input(shape=(latent_dim,))
#     gan_output = discriminator(generator(gan_input))
#     gan = Model(gan_input, gan_output)
#
#     return gan

In [11]:
# Cell 11: Data Path and Parameters

# Data Paths (relative to the notebook's working directory).
# The data generators and analyze_class_distribution look for 'real' and 'fake'
# sub-folders underneath each of these.
train_data_path = 'datasetNEW/train'
dev_data_path = 'datasetNEW/dev'
eval_data_path = 'datasetNEW/eval'

# Define the fixed number of frames
# NOTE(review): TARGET_FRAMES is also assigned in Cell 6, where the data generators capture it
# as the default value of `target_frames` at definition time. Changing the value here does not
# update those defaults -- keep the two assignments in sync.
TARGET_FRAMES = 126
N_MELS = 80
# (n_mels, n_frames): the shape handed to create_generator and create_critic below.
mel_spectrogram_shape = (N_MELS, TARGET_FRAMES)

# WGAN-GP specific parameters (read as globals by the training cell)
latent_dim = 100     # Length of the generator's input noise vector
n_critic = 5         # Train critic 5 times per generator update
gp_weight = 10.0     # Gradient penalty weight
gan_epochs = 40      # WGAN often needs more epochs, adjust as needed
gan_batch_size = 8  # Adjust based on GPU memory (WGAN-GP can be memory intensive)

# Optimizers (Typical WGAN settings: lower LR, betas=(0.5, 0.9))
# Separate optimizer instances so the critic and the generator keep independent state.
critic_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
generator_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
# critic_optimizer = RMSprop(learning_rate=0.00005) # Alternative optimizer
# generator_optimizer = RMSprop(learning_rate=0.00005)

# Create instances
generator = create_generator(latent_dim, mel_spectrogram_shape)
critic = create_critic(mel_spectrogram_shape)

# Diagnostic Code: Verify Output Shape
# Push one noise vector through generator -> critic to check that the two models fit together.
test_noise = tf.random.normal((1, latent_dim))
generated_image = generator(test_noise, training=False) # Use tf.random and call model directly
critic_output = critic(generated_image, training=False)
print("Shape of generated image (Generator Output):", generated_image.shape)
print("Shape of critic output:", critic_output.shape)

# The generator output (minus the batch axis) must equal the critic's input shape,
# and the critic must return exactly one score per sample.
critic_input_shape_expected = (mel_spectrogram_shape[0], mel_spectrogram_shape[1], 1)
print("Expected shape for Critic Input:", critic_input_shape_expected)
assert generated_image.shape[1:] == critic_input_shape_expected, "Generator output shape mismatch!"
assert len(critic_output.shape) == 2 and critic_output.shape[1] == 1, "Critic output shape mismatch!"


# Report the models
print("\n--- Generator Summary ---")
generator.summary()
print("\n--- Critic Summary ---")
critic.summary()

# Parameters for standalone classifier training (can be different)
classifier_batch_size = 8 # Keep smaller for classifier fine-tuning?
classifier_epochs = 60

I0000 00:00:1743768917.534606  640800 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1904 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Generator adding final Conv2D to adjust width from 128 to 126


I0000 00:00:1743768918.184631  640800 cuda_dnn.cc:529] Loaded cuDNN version 90300


Shape of generated image (Generator Output): (1, 80, 126, 1)
Shape of critic output: (1, 1)
Expected shape for Critic Input: (80, 126, 1)

--- Generator Summary ---



--- Critic Summary ---


In [None]:
# Cell 12: WGAN-GP Training Loop (Saving Critic Weights Only)

# --- Imports ---
import os
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm # Use tqdm.notebook for better Jupyter integration
import matplotlib.pyplot as plt
import logging

# Assuming necessary variables (train_data_path, gan_batch_size, latent_dim, n_critic,
# gp_weight, gan_epochs, critic_optimizer, generator_optimizer, generator, critic,
# FIGURES_DIR) and functions (data_generator_gan) are defined and accessible
# from previous cells.

# --- Loss Functions ---
def critic_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on generated samples minus mean score on real ones.

    Both inputs are cast to float32 first so the reduction is done in full precision.
    """
    mean_fake_score = tf.reduce_mean(tf.cast(fake_output, tf.float32))
    mean_real_score = tf.reduce_mean(tf.cast(real_output, tf.float32))
    return mean_fake_score - mean_real_score

def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score of generated samples.

    The scores are cast to float32 before averaging.
    """
    mean_score = tf.reduce_mean(tf.cast(fake_output, tf.float32))
    return -mean_score

# --- Gradient Penalty Function (Corrected) ---
def gradient_penalty(batch_size, real_images, fake_images):
    """Compute the weighted WGAN-GP gradient penalty on random real/fake interpolations.

    Uses the module-level `critic` model and `gp_weight` factor.

    Args:
        batch_size: Unused; kept so existing callers keep working (the batch size is
            read from `real_images` instead).
        real_images: Batch of real samples.
        fake_images: Batch of generated samples with the same shape as `real_images`.

    Returns:
        Scalar float32 tensor: gp_weight * mean((||grad||_2 - 1)^2), where the gradient
        is that of the critic output with respect to the interpolated samples. Returns 0
        (and logs an error) if no gradient could be computed.
    """
    # Interpolation needs both batches in the same dtype (relevant under mixed precision).
    if real_images.dtype != fake_images.dtype:
        real_images = tf.cast(real_images, fake_images.dtype)

    # One interpolation coefficient per sample, broadcast over all remaining axes.
    alpha_shape = [tf.shape(real_images)[0]] + [1] * (len(real_images.shape) - 1)
    alpha = tf.random.uniform(shape=alpha_shape, minval=0., maxval=1., dtype=real_images.dtype)

    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as gp_tape:
        # `interpolated` is a plain tensor, so it has to be watched explicitly.
        gp_tape.watch(interpolated)
        pred = critic(interpolated, training=True)
        pred = tf.cast(pred, tf.float32) # float32 for a stable penalty calculation

    grads = gp_tape.gradient(pred, [interpolated])
    if grads is None or grads[0] is None:
        # Logged at ERROR level: the notebook's logging config discards warnings, and a
        # silently-zero penalty would remove the Lipschitz constraint without any trace.
        logging.error("Gradients are None in gradient_penalty. Returning 0 penalty.")
        return tf.constant(0.0, dtype=tf.float32)
    grads = tf.cast(grads[0], tf.float32) # float32 before taking the norm

    # Per-sample L2 norm over every non-batch axis. The epsilon keeps the derivative of
    # sqrt finite when a gradient is exactly zero (otherwise the critic update can be NaN).
    squared_norm = tf.reduce_sum(tf.square(grads), axis=tf.range(1, tf.rank(grads)))
    norm = tf.sqrt(squared_norm + 1e-12)
    gp = tf.reduce_mean((norm - 1.0) ** 2)

    return gp * gp_weight # gp_weight should be float


# --- Data Generator and Steps Calculation ---
# Endless batch source for the WGAN; nothing is read from disk until the first next().
train_gen_gan = data_generator_gan(train_data_path, batch_size=gan_batch_size)

# Count the training files the GAN generator will cycle through (a missing folder counts as 0).
fake_files_count = 0
try:
    fake_dir = os.path.join(train_data_path, 'fake')
    if os.path.exists(fake_dir):
        fake_files_count = sum(1 for entry in os.listdir(fake_dir) if entry.endswith('.wav'))
except FileNotFoundError:
    fake_files_count = 0

if gan_batch_size <= 0:
    raise ValueError("GAN Batch size must be positive.")

# One epoch = enough batches to cover every file once (last batch may be partial).
if fake_files_count > 0:
    gan_steps_per_epoch = int(np.ceil(fake_files_count / float(gan_batch_size)))
    print(f"Calculated {gan_steps_per_epoch} GAN steps per epoch.")
else:
    print("Warning: No fake training files found for GAN. Setting GAN steps to 0.")
    gan_steps_per_epoch = 0


# --- Training Step Function (Decorated with tf.function) ---
@tf.function
def train_step(real_images):
    """Run one WGAN-GP step: `n_critic` critic updates followed by one generator update.

    Uses the module-level `generator`, `critic`, their optimizers, `latent_dim` and
    `n_critic`.

    Args:
        real_images: One batch of real samples. The same batch is reused for every critic
            update of this step; fresh latent noise is drawn for each critic update and
            for the generator update.

    Returns:
        (critic_loss_including_penalty_of_the_last_critic_update, generator_loss),
        both scalar float32 tensors.
    """
    current_batch_size = tf.shape(real_images)[0]

    # Train Critic (n_critic times)
    total_crit_loss = tf.constant(0.0, dtype=tf.float32)
    for _ in tf.range(n_critic):
        # New noise every iteration so the critic does not see the same fakes n_critic times.
        noise = tf.random.normal([current_batch_size, latent_dim], dtype=tf.float32)
        # Generated outside the tape: only critic variables are differentiated here, so
        # recording the generator's forward pass would just cost memory.
        fake_images = generator(noise, training=True)

        with tf.GradientTape() as crit_tape:
            real_output = critic(real_images, training=True)
            fake_output = critic(fake_images, training=True)
            crit_loss = critic_loss(real_output, fake_output)
            # Must be computed inside the tape: the penalty itself depends on critic weights.
            gp = gradient_penalty(current_batch_size, real_images, fake_images)
            total_crit_loss = crit_loss + gp

        crit_gradients = crit_tape.gradient(total_crit_loss, critic.trainable_variables)
        valid_grads_and_vars = [(g, v) for g, v in zip(crit_gradients, critic.trainable_variables) if g is not None]
        if len(valid_grads_and_vars) < len(critic.trainable_variables):
            tf.print("Warning: Some critic gradients are None.")
        if len(valid_grads_and_vars) > 0:
            critic_optimizer.apply_gradients(valid_grads_and_vars)

    # Train Generator (fresh noise, independent of the critic updates above)
    gen_noise = tf.random.normal([current_batch_size, latent_dim], dtype=tf.float32)
    with tf.GradientTape() as gen_tape:
        fake_images_gen = generator(gen_noise, training=True)
        fake_output_gen = critic(fake_images_gen, training=True)
        gen_loss = generator_loss(fake_output_gen)

    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    valid_grads_and_vars_gen = [(g, v) for g, v in zip(gen_gradients, generator.trainable_variables) if g is not None]
    if len(valid_grads_and_vars_gen) < len(generator.trainable_variables):
        tf.print("Warning: Some generator gradients are None.")
    if len(valid_grads_and_vars_gen) > 0:
        generator_optimizer.apply_gradients(valid_grads_and_vars_gen)

    return total_crit_loss, gen_loss


# --- WGAN-GP Training Loop ---
# Runs `gan_epochs` epochs of `gan_steps_per_epoch` batches, tracks the average
# critic/generator loss per epoch, saves the generator and critic weights
# periodically, and plots the loss history at the end.
print("\nBegin WGAN-GP training!")

if train_gen_gan is None or gan_steps_per_epoch == 0:
    print("Skipping WGAN-GP training: Generator not initialized or no steps per epoch.")
else:
    c_loss_history = []
    g_loss_history = []
    # Set as soon as a batch yields a NaN/Inf loss. The offending loss is never
    # added to the epoch accumulators, so an explicit flag is needed to stop the
    # epoch loop as well as the batch loop.
    non_finite_loss_detected = False

    for epoch in range(gan_epochs): # Use gan_epochs defined in Cell 11
        print(f"\nEpoch {epoch+1}/{gan_epochs}")
        epoch_pbar = tqdm(range(gan_steps_per_epoch), desc=f"Epoch {epoch+1}")

        epoch_c_loss = 0.0
        epoch_g_loss = 0.0
        batches_processed = 0

        for batch_idx in epoch_pbar:
            try:
                real_spoof_samples_np = next(train_gen_gan)
                # Ensure input to train_step is float32
                real_spoof_samples = tf.convert_to_tensor(real_spoof_samples_np, dtype=tf.float32)

                if tf.shape(real_spoof_samples)[0] == 0:
                    continue

                c_loss, g_loss = train_step(real_spoof_samples)
                # Pull each loss to the host once and reuse the Python floats.
                c_loss_value = float(c_loss.numpy())
                g_loss_value = float(g_loss.numpy())

                # Check for NaN or Inf losses (important for stability)
                if not (np.isfinite(c_loss_value) and np.isfinite(g_loss_value)):
                    print(f"\nError: NaN or Inf loss detected at epoch {epoch+1}, batch {batch_idx}. Stopping training.")
                    logging.error(f"NaN/Inf loss detected: C Loss={c_loss_value}, G Loss={g_loss_value}. Epoch {epoch+1}, Batch {batch_idx}")
                    # Optional: save generator/critic weights here before stopping.
                    non_finite_loss_detected = True
                    break

                epoch_c_loss += c_loss_value
                epoch_g_loss += g_loss_value
                batches_processed += 1

                epoch_pbar.set_postfix({"C Loss": f"{c_loss_value:.4f}", "G Loss": f"{g_loss_value:.4f}"})

            except StopIteration:
                print(f"\nGAN Generator exhausted prematurely at batch {batch_idx}. Moving to next epoch.")
                train_gen_gan = data_generator_gan(train_data_path, batch_size=gan_batch_size)
                break

            except Exception as e:
                # Best effort: log the failing batch and move on to the next one.
                logging.error(f"Error during WGAN training epoch {epoch+1}, batch {batch_idx}: {e}", exc_info=True)
                print(f"\nError during WGAN training epoch {epoch+1}, batch {batch_idx}: {e}")
                continue

        # A non-finite loss aborts the whole training run, not just this epoch.
        if non_finite_loss_detected:
            print("Training stopped due to NaN/Inf loss.")
            break

        # --- End-of-epoch actions ---
        if batches_processed > 0:
            avg_c_loss = epoch_c_loss / batches_processed
            avg_g_loss = epoch_g_loss / batches_processed
            c_loss_history.append(avg_c_loss)
            g_loss_history.append(avg_g_loss)
            print(f"Epoch {epoch+1} finished. Avg C Loss: {avg_c_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")
        else:
            print(f"Epoch {epoch+1} finished. No batches processed.")
            # Keep the histories aligned with the epoch index.
            c_loss_history.append(np.nan)
            g_loss_history.append(np.nan)

        # --- Save Models/Weights Periodically ---
        # Save every 5 epochs and at the very last epoch
        if (epoch + 1) % 5 == 0 or (epoch + 1) == gan_epochs:
            try:
                # Save generator as full model
                generator.save(f'generator_wgan_sa_epoch_{epoch+1}.keras')

                # Save critic weights ONLY
                critic_weights_filename = f'critic_wgan_sa_epoch_{epoch+1}.weights.h5'
                critic.save_weights(critic_weights_filename)

                print(f"Saved generator model and critic weights ({critic_weights_filename}) for epoch {epoch+1}")
            except Exception as e:
                logging.error(f"Error saving models/weights at epoch {epoch+1}: {e}", exc_info=True)
                print(f"Error saving models/weights at epoch {epoch+1}: {e}")

    print("\nWGAN-GP training finished (or stopped early).")

    # Final loss plot
    if c_loss_history and g_loss_history: # Plot only if history exists
        # Drop epochs that processed no batches (NaN placeholders) but keep the
        # real epoch number of every remaining point on the x-axis.
        plotted_points = [
            (epoch_number, c_value, g_value)
            for epoch_number, (c_value, g_value) in enumerate(zip(c_loss_history, g_loss_history), start=1)
            if not (np.isnan(c_value) or np.isnan(g_value))
        ]
        epochs_ran = [point[0] for point in plotted_points]

        plt.figure(figsize=(10, 5))
        plt.plot(epochs_ran, [point[1] for point in plotted_points], label='Avg Critic Loss per Epoch')
        plt.plot(epochs_ran, [point[2] for point in plotted_points], label='Avg Generator Loss per Epoch')
        plt.title('WGAN-GP Training Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.legend()
        plt.grid(True)
        plot_filename = os.path.join(FIGURES_DIR, 'gan_loss_final.png')
        try:
            plt.savefig(plot_filename)
            print(f"Saved final loss plot to {plot_filename}")
        except Exception as e:
            logging.error(f"Failed to save final loss plot: {e}", exc_info=True)
            print(f"Failed to save final loss plot: {e}")
        plt.show()
    else:
        print("No loss history recorded, skipping final plot.")

Calculated 2850 GAN steps per epoch.

Begin WGAN-GP training!

Epoch 1/40


Epoch 1:   0%|          | 0/2850 [00:00<?, ?it/s]

In [None]:
# Cell 13: Train Standalone Classifier (using Critic Weights)

# --- Imports needed specifically for this cell's logic ---
import fnmatch # For finding model files
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import Model # Need Model for functional API approach
from tensorflow.keras.layers import Dense, Input # Need Dense and Input
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import logging

# Assuming other necessary variables (train_data_path, dev_data_path, classifier_batch_size,
# mel_spectrogram_shape=(N_MELS, TARGET_FRAMES), classifier_epochs, FIGURES_DIR) and functions
# (data_generator_classifier, count_total_files, SelfAttention, create_critic) are defined and
# accessible from previous cells. create_critic is now essential.


# --- Custom Callback for Plotting Training History ---
# Records accuracy, loss and AUC after every epoch and saves a three-panel history figure to FIGURES_DIR.
class PlotTrainingHistory(Callback):
    """Keras callback that records per-epoch metrics and saves a history figure.

    After each epoch the ``loss``, ``accuracy`` and ``auc`` entries of ``logs``
    (and their ``val_`` counterparts) are appended to in-memory lists. The full
    history so far is then drawn as three side-by-side panels and written to
    ``FIGURES_DIR/<model_name>_epoch_<N>.png``.
    """

    def __init__(self, model_name='model'):
        super().__init__()
        self.model_name = model_name
        # One list per tracked metric; entries are None when Keras did not report it.
        self.acc, self.val_acc = [], []
        self.loss, self.val_loss = [], []
        self.auc, self.val_auc = [], []

    @staticmethod
    def _has_values(series):
        """Return True when at least one entry of ``series`` is not None."""
        return any(value is not None for value in series)

    def _draw_panel(self, position, train_series, val_series, metric_label):
        """Draw the training/validation curves of one metric in subplot ``position``."""
        plt.subplot(1, 3, position)
        drew_curve = False
        for series, prefix in ((train_series, 'Training'), (val_series, 'Validation')):
            if self._has_values(series):
                plt.plot(series, label=f'{prefix} {metric_label}')
                drew_curve = True
        plt.title(f'Model {metric_label}')
        plt.xlabel('Epoch')
        plt.ylabel(metric_label)
        # A legend is only added when at least one curve exists.
        if drew_curve:
            plt.legend()

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's metrics and save the updated history figure."""
        logs = logs or {}
        tracked = (
            (self.loss, 'loss'), (self.val_loss, 'val_loss'),
            (self.acc, 'accuracy'), (self.val_acc, 'val_accuracy'),
            (self.auc, 'auc'), (self.val_auc, 'val_auc'),
        )
        for history, log_key in tracked:
            history.append(logs.get(log_key))

        plt.figure(figsize=(18, 5))
        panels = (
            (self.acc, self.val_acc, 'Accuracy'),
            (self.loss, self.val_loss, 'Loss'),
            (self.auc, self.val_auc, 'AUC'),
        )
        for position, (train_series, val_series, metric_label) in enumerate(panels, start=1):
            self._draw_panel(position, train_series, val_series, metric_label)

        plt.tight_layout()
        filepath = os.path.join(FIGURES_DIR, f'{self.model_name}_epoch_{epoch+1}.png')
        try:
            plt.savefig(filepath)
        except Exception as e:
            logging.error(f"Error saving plot to {filepath}: {e}")
            print(f"\nError saving plot to {filepath}: {e}")
        finally:
            # Always release the figure so per-epoch plots do not accumulate.
            plt.close()
# --- END OF CLASS DEFINITION ---


print("\nSetting up Standalone Classifier training...")

# --- Create Data Generators for Classifier ---
# The dev generator is created with shuffle=False.
train_gen_clf = data_generator_classifier(train_data_path, batch_size=classifier_batch_size)
dev_gen_clf = data_generator_classifier(dev_data_path, shuffle=False, batch_size=classifier_batch_size)

# --- Calculate steps for Classifier Training ---
train_samples_count = count_total_files(train_data_path)
dev_samples_count = count_total_files(dev_data_path)

if classifier_batch_size <= 0:
    raise ValueError("Classifier batch size must be positive.")

# Steps = ceil(samples / batch size); an empty split yields zero steps.
clf_steps_per_epoch = 0
if train_samples_count > 0:
    clf_steps_per_epoch = int(np.ceil(train_samples_count / float(classifier_batch_size)))

clf_validation_steps = 0
if dev_samples_count > 0:
    clf_validation_steps = int(np.ceil(dev_samples_count / float(classifier_batch_size)))

print(f"Classifier Train Steps/Epoch: {clf_steps_per_epoch}, Validation Steps: {clf_validation_steps}")


# --- Recreate Critic Structure and Load Weights ---
# Builds a fresh critic with create_critic() and, when a saved weights file from
# the WGAN-GP stage is present in the working directory, loads it into that critic.
critic_weights_path = None
critic_for_classifier = None # Use a new variable name

# 1. Recreate the critic structure
try:
    print(f"Recreating critic structure using shape: {mel_spectrogram_shape}")
    # create_critic function must be available from an earlier cell
    critic_for_classifier = create_critic(mel_spectrogram_shape)
    print("Critic structure recreated successfully.")
except NameError as ne:
    logging.error(f"Failed to recreate critic structure: {ne}. 'create_critic' or 'mel_spectrogram_shape' not defined?", exc_info=True)
    raise NameError(f"Failed to recreate critic structure: {ne}. Ensure 'create_critic' and 'mel_spectrogram_shape' are defined.")
except Exception as ce:
    logging.error(f"Failed to recreate critic structure: {ce}", exc_info=True)
    raise RuntimeError(f"Failed to recreate critic structure: {ce}")


# 2. Find the latest critic weights file
try:
    critic_weights_pattern = 'critic_wgan_sa_epoch_*.weights.h5'
    list_of_critic_weights_files = [f for f in os.listdir('.') if fnmatch.fnmatch(f, critic_weights_pattern)]
    if list_of_critic_weights_files:
        # Rank by last-modification time. getctime is creation time on Windows
        # (unchanged when a file is overwritten) and metadata-change time on
        # Unix, so it does not reliably identify the most recently written file.
        latest_weights_file = max(list_of_critic_weights_files, key=os.path.getmtime)
        critic_weights_path = latest_weights_file
        print(f"Found latest critic weights file: {critic_weights_path}")
    else:
        print("Warning: No critic weights file found matching pattern.")
        print("Proceeding with randomly initialized critic weights for the classifier.")
except Exception as e:
    logging.error(f"Error finding critic weights files: {e}", exc_info=True)
    print(f"Error finding critic weights files: {e}")
    print("Proceeding with randomly initialized critic weights for the classifier.")


# 3. Load weights if found
if critic_weights_path:
    try:
        print(f"Loading weights from {critic_weights_path} into recreated critic structure...")
        critic_for_classifier.load_weights(critic_weights_path)
        print("Successfully loaded critic weights.")
    except Exception as e:
        # Best effort: a failed load leaves the critic with its initial weights.
        logging.error(f"Error loading weights from {critic_weights_path}: {e}", exc_info=True)
        print(f"Error loading weights from {critic_weights_path}: {e}")
        print("Classifier training will proceed with initialized weights.")
# If weights not found or failed to load, critic_for_classifier will just have its initial random weights


# --- Build the Spoof Detector Model using the Critic with Loaded Weights ---
# Re-chains the critic's layers (all but its last one) on a new Input and puts a
# sigmoid Dense(1) head on top. This assumes the critic is a purely sequential
# stack of layers -- TODO confirm against create_critic.
print("Building classifier using critic layers...")

# Make the critic layers trainable for fine-tuning
critic_for_classifier.trainable = True

# 1. Get the expected input shape
try:
    if len(mel_spectrogram_shape) == 2:
        # A 2-D shape gets a trailing channel dimension of 1.
        critic_input_shape = (mel_spectrogram_shape[0], mel_spectrogram_shape[1], 1)
    else:
        critic_input_shape = mel_spectrogram_shape
    print(f"Using input shape for new model: {critic_input_shape}")
except NameError:
    logging.error("Variable 'mel_spectrogram_shape' is not defined.")
    raise NameError("Variable 'mel_spectrogram_shape' is not defined. Cannot determine input shape.")

# 2. Create a new Input layer
new_input = Input(shape=critic_input_shape, name="classifier_input")

# 3. Sequentially connect the layers from the recreated critic (excluding the original output layer)
x = new_input
processed_layers_count = 0
for layer in critic_for_classifier.layers[:-1]: # Iterate through all layers EXCEPT the last one
    # A Functional critic lists its InputLayer first. `new_input` already plays
    # that role, and an InputLayer is not meant to be called on a tensor, so skip it.
    if isinstance(layer, tf.keras.layers.InputLayer):
        continue
    try:
        layer.trainable = True # Ensure layers are trainable in the new model
        x = layer(x)
        processed_layers_count += 1
    except Exception as layer_e:
        logging.error(f"Error connecting layer '{layer.name}': {layer_e}", exc_info=True)
        print(f"Error connecting layer '{layer.name}': {layer_e}")
        raise RuntimeError(f"Failed to rebuild model at layer: {layer.name}")

print(f"Connected {processed_layers_count} layers from the critic.")
if processed_layers_count == 0:
    raise ValueError("No layers were processed from the critic. Cannot build classifier.")

# 4. Add the final binary classification layer
classifier_output = Dense(1, activation='sigmoid', name='classifier_output')(x)

# 5. Create the final Functional API model
spoof_detector = Model(inputs=new_input, outputs=classifier_output, name='spoof_detector')


print("\n--- Spoof Detector (Classifier) Summary ---")
spoof_detector.summary()


# --- Callbacks for Classifier Training ---
# Learning-rate decay and early stopping both watch the validation loss.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.2,
                              patience=5,
                              min_lr=1e-7,
                              verbose=1)
early_stopping = EarlyStopping(monitor='val_loss',
                               patience=10,
                               restore_best_weights=True,
                               verbose=1)

# Weights-only checkpoints, written whenever validation AUC improves.
checkpoint_dir_clf = './training_checkpoints_spoof_detector_wgan_sa'
os.makedirs(checkpoint_dir_clf, exist_ok=True)
checkpoint_prefix_clf = os.path.join(checkpoint_dir_clf, "ckpt_clf_{epoch:02d}.weights.h5")
checkpoint_callback_clf = ModelCheckpoint(filepath=checkpoint_prefix_clf,
                                          monitor='val_auc',
                                          mode='max',
                                          save_best_only=True,
                                          save_weights_only=True,
                                          verbose=1)

# Per-epoch accuracy/loss/AUC figures.
plot_training_callback = PlotTrainingHistory(model_name='spoof_detector_wgan_sa')


# --- Compile and Train the Classifier ---
# The AUC metric is named 'auc' so that Keras reports 'val_auc' for the checkpoint monitor.
classifier_optimizer = Adam(learning_rate=1e-5)
spoof_detector.compile(
    loss='binary_crossentropy',
    optimizer=classifier_optimizer,
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)

print("\nStarting Standalone Classifier training...")

# `history` stays None when training is skipped or model.fit raises.
history = None # Initialize history
if clf_steps_per_epoch > 0 and clf_validation_steps > 0:
    try:
        history = spoof_detector.fit(
            train_gen_clf,
            steps_per_epoch=clf_steps_per_epoch,
            epochs=classifier_epochs,
            validation_data=dev_gen_clf,
            validation_steps=clf_validation_steps,
            callbacks=[reduce_lr, early_stopping, checkpoint_callback_clf, plot_training_callback],
        )
    except Exception as fit_e:
        logging.error(f"Error during classifier training (model.fit): {fit_e}", exc_info=True)
        print(f"\n--- ERROR DURING CLASSIFIER TRAINING ---")
        print(f"Error: {fit_e}")
        print("----------------------------------------")


    # --- Save the Final Trained Classifier Model ---
    # Try the full model first; fall back to weights only if that fails.
    if history is not None:
        final_model_path = 'spoof_detector_final_wgan_sa.keras'
        try:
            # `custom_objects` is a load_model() argument, not a save() one.
            # Passing it to save() raises and would always force the weights-only
            # fallback; the custom SelfAttention layer is resolved at load time.
            spoof_detector.save(final_model_path)
            print(f"\nClassifier training complete or stopped. Final model saved as {final_model_path}")
        except Exception as e:
            logging.error(f"Error saving final classifier model (full): {e}", exc_info=True)
            print(f"\nError saving final classifier model (full): {e}. Saving weights only as fallback.")
            spoof_detector.save_weights('spoof_detector_final_wgan_sa.weights.h5')

        # --- Optional: Load Best Weights and Save Best Model ---
        # Bound before the try so the except handler can always read it.
        best_checkpoint_path = None
        try:
            # ModelCheckpoint writes plain '.weights.h5' files, which
            # tf.train.latest_checkpoint cannot discover (it looks for a TF
            # checkpoint state file). With save_best_only=True the most recently
            # written checkpoint is the best one, so pick it by modification time.
            checkpoint_names = fnmatch.filter(os.listdir(checkpoint_dir_clf), 'ckpt_clf_*.weights.h5')
            if checkpoint_names:
                best_checkpoint_path = max(
                    (os.path.join(checkpoint_dir_clf, name) for name in checkpoint_names),
                    key=os.path.getmtime,
                )

            if best_checkpoint_path:
                print(f"Loading best classifier weights from checkpoint: {best_checkpoint_path}")
                # Load the best weights into the current model structure
                spoof_detector.load_weights(best_checkpoint_path)

                # Save the full model with best weights
                best_model_path = 'spoof_detector_best_val_auc_wgan_sa.keras'
                spoof_detector.save(best_model_path)
                print(f"Saved classifier with best validation AUC weights as {best_model_path}")
            else:
                print("Could not find best checkpoint weights. The final model saved corresponds to the last epoch.")
        except Exception as e:
            logging.error(f"Error loading best weights or saving best model: {e}", exc_info=True)
            print(f"Error loading best weights or saving best model: {e}")
            # Save just the weights if a checkpoint was found but the full save failed
            if best_checkpoint_path:
                try:
                    spoof_detector.save_weights('spoof_detector_best_val_auc_wgan_sa.weights.h5')
                    print("Saved best weights only as spoof_detector_best_val_auc_wgan_sa.weights.h5")
                except Exception as e_ws:
                    print(f"Failed to save best weights only: {e_ws}")
    else:
        print("\nClassifier training did not run or failed. No final model saved.")

else:
    print("Skipping classifier training because steps_per_epoch or validation_steps is zero.")

In [None]:
# Cell 14: Evaluation (Improved Error Handling)

import os
import tensorflow as tf
import numpy as np # Needed for count_total_files if redefined
from tensorflow.keras.models import load_model # Ensure load_model is imported

# --- Ensure SelfAttention Class Definition is Available ---
# Copy the class definition from Cell 7 here for robustness,
# especially if the kernel might have been restarted.

class SelfAttention(tf.keras.layers.Layer):
    """
    Self-attention layer based on SAGAN.

    Input shape: (batch, height, width, channels)
    Output shape: (batch, height, width, channels_out) where channels_out is typically channels

    The layer returns ``gamma * attention(x) + x``. ``gamma`` is a trainable scalar
    initialised to zero, so the layer starts out as an identity mapping. Because of
    the residual addition, ``channels_out`` must be broadcast-compatible with the
    input channel count.

    Args:
        channels_out: Number of channels of the attention output. Defaults to the
            number of input channels (resolved in ``build``).
    """
    def __init__(self, channels_out=None, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels_out = channels_out

    def build(self, input_shape):
        """Create the 1x1 projection convolutions and the ``gamma`` weight."""
        self.input_channels = input_shape[-1]
        if self.channels_out is None:
            self.channels_out = self.input_channels

        # Query/key width is channels // 8, clamped to at least 1 so that inputs
        # with fewer than 8 channels do not request a zero-filter convolution.
        self.qk_channels = max(1, self.input_channels // 8)

        # Convolution layers for query, key, value
        self.f = tf.keras.layers.Conv2D(self.qk_channels, kernel_size=1, strides=1, padding='same', name='conv_f') # Query
        self.g = tf.keras.layers.Conv2D(self.qk_channels, kernel_size=1, strides=1, padding='same', name='conv_g') # Key
        self.h = tf.keras.layers.Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_h') # Value
        self.out_conv = tf.keras.layers.Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_out')
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)
        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        """Apply attention over all height*width locations and add the residual."""
        input_dims = tf.shape(x)
        batch_size, height, width = input_dims[0], input_dims[1], input_dims[2]
        location_num = height * width

        f_proj = self.f(x)
        g_proj = self.g(x)
        h_proj = self.h(x)

        # Flatten the spatial grid so every location becomes one sequence position.
        f_flat = tf.reshape(f_proj, shape=(batch_size, location_num, self.qk_channels))
        g_flat = tf.reshape(g_proj, shape=(batch_size, location_num, self.qk_channels))
        h_flat = tf.reshape(h_proj, shape=(batch_size, location_num, self.channels_out))

        # (batch, locations, locations) scores, softmax-normalised over the key axis.
        g_flat_t = tf.transpose(g_flat, perm=[0, 2, 1])
        attention_score = tf.matmul(f_flat, g_flat_t)
        attention_prob = tf.nn.softmax(attention_score, axis=-1)

        attention_output = tf.matmul(attention_prob, h_flat)
        attention_output_reshaped = tf.reshape(attention_output, shape=(batch_size, height, width, self.channels_out))
        o = self.out_conv(attention_output_reshaped)
        y = self.gamma * o + x
        return y

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last dimension replaced by ``channels_out``."""
        return input_shape[:-1] + (self.channels_out,)

    # get_config lets the layer be re-created from a saved model configuration.
    def get_config(self):
        config = super(SelfAttention, self).get_config()
        config.update({"channels_out": self.channels_out})
        return config

# --- END OF SelfAttention DEFINITION ---


print("\nEvaluating the final Spoof Detector...")

# Candidate model files, in order of preference: best-val-AUC model, then last-epoch model.
best_model_filename = 'spoof_detector_best_val_auc_wgan_sa.keras'
final_model_filename = 'spoof_detector_final_wgan_sa.keras'
detector_model_path = None
spoof_detector_eval = None # Stays None when no file is found or loading fails

# --- Check for Model Files and Attempt Loading ---
print("Checking for model files...")
if os.path.exists(best_model_filename):
    detector_model_path = best_model_filename
    print(f"Found best model file: {best_model_filename}")
elif os.path.exists(final_model_filename):
    detector_model_path = final_model_filename
    print(f"Best model not found. Found final model file: {final_model_filename}")
else:
    print(f"Error: Could not find '{best_model_filename}' or '{final_model_filename}' in the current directory.")
    print("Please ensure Cell 13 completed successfully and saved the model.")

# Load and compile only when one of the candidate files exists.
if detector_model_path:
    try:
        print(f"\nAttempting to load model: {detector_model_path}")
        # The custom layer class must be supplied so the saved model can be rebuilt.
        custom_objects = {'SelfAttention': SelfAttention}
        spoof_detector_eval = load_model(detector_model_path, custom_objects=custom_objects)
        print("Model loaded successfully.")

        # Compile with the loss and metrics needed by evaluate(); the optimizer
        # mirrors the training settings.
        print("Compiling loaded model...")
        eval_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
        eval_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
        spoof_detector_eval.compile(
            optimizer=eval_optimizer,
            loss='binary_crossentropy',
            metrics=eval_metrics,
        )
        print("Model compiled successfully.")

    except Exception as e:
        print(f"\n--- ERROR DURING MODEL LOADING OR COMPILATION ---")
        print(f"Error details: {e}")
        import traceback
        traceback.print_exc() # Print detailed traceback
        print("----------------------------------------------------")
        print("Evaluation cannot proceed due to loading/compilation error.")
        # Make sure a half-initialised model is never evaluated.
        spoof_detector_eval = None

# --- Perform Evaluation if Model Loaded ---
if not spoof_detector_eval:
    # Reached when no model file was found OR when loading/compiling failed.
    print("\nSkipping evaluation because the spoof detector model could not be loaded.")
else:
    print("\nProceeding with evaluation...")

    # --- Create Evaluation Generator ---
    eval_gen_clf = data_generator_classifier(eval_data_path, shuffle=False, batch_size=classifier_batch_size)

    # --- Calculate Evaluation Steps ---
    eval_samples_count = count_total_files(eval_data_path)
    if classifier_batch_size <= 0:
        raise ValueError("Classifier batch size must be positive.")
    # Steps = ceil(samples / batch size); no samples means no steps.
    eval_steps = 0
    if eval_samples_count > 0:
        eval_steps = int(np.ceil(eval_samples_count / float(classifier_batch_size)))
    print(f"Using {eval_steps} steps for evaluation.")

    # --- Evaluate ---
    if eval_steps <= 0:
        print("Skipping evaluation because eval_steps is zero (no evaluation data found?).")
    else:
        print("Running model.evaluate...")
        results = spoof_detector_eval.evaluate(eval_gen_clf, steps=eval_steps, verbose=1)

        print("\n--- Evaluation Results ---")
        # Expected layout is [loss, accuracy, auc], matching the compiled metrics.
        if isinstance(results, list) and len(results) >= 3:
            loss_value, accuracy_value, auc_value = results[:3]
            print(f"Loss: {loss_value:.4f}")
            print(f"Accuracy: {accuracy_value:.4f}")
            print(f"AUC: {auc_value:.4f}")
        else:
            print(f"Raw Results: {results}") # Print raw results if format is unexpected
        print("--------------------------")

# --- End of Cell 14 ---

In [None]:
# Cell 15: Reporting (Confusion Matrix, EER, t-DCF)

print("\nGenerating Final Reports (F1, CM, EER, t-DCF)...")

# Ensure the evaluation model is loaded from Cell 14
if 'spoof_detector_eval' not in locals() or spoof_detector_eval is None:
    print("Spoof detector model not loaded. Cannot generate reports.")
else:
    # --- Parameters for t-DCF ---
    p_target = 0.05  # Prior probability of target (real=1) - Adjust based on ASVspoof challenge or use case
    c_miss = 1       # Cost of miss (classifying real as fake - FN)
    c_false_alarm = 1 # Cost of false alarm (classifying fake as real - FP)

    # --- Regenerate Predictions ---
    # Reset the eval generator
    eval_gen_report = data_generator_classifier(eval_data_path, batch_size=classifier_batch_size, shuffle=False)

    # Recalculate eval_steps if needed (should match Cell 14)
    eval_samples_count_report = count_total_files(eval_data_path)
    eval_steps_report = int(np.ceil(eval_samples_count_report / float(classifier_batch_size))) if eval_samples_count_report > 0 else 0

    y_pred_scores = []
    y_true_labels = []

    if eval_steps_report > 0:
        print(f"Generating predictions using {eval_steps_report} steps...")
        for _ in tqdm(range(eval_steps_report), desc="Predicting"):
            try:
                # Generator yields (batch_x, batch_y, batch_weights)
                batch_x, batch_y, _ = next(eval_gen_report)
                if batch_x.size == 0: continue # Skip empty batches if they occur

                # Use predict, not predict_on_batch for potentially better performance over many batches
                batch_pred = spoof_detector_eval.predict(batch_x, verbose=0)
                y_pred_scores.extend(batch_pred.flatten())
                y_true_labels.extend(batch_y)
            except StopIteration:
                print("Evaluation generator stopped.")
                break
            except Exception as e:
                print(f"Error during prediction generation: {e}")
                continue

        y_pred_scores = np.array(y_pred_scores).astype(np.float32)
        y_true_labels = np.array(y_true_labels).astype(np.int32)

        # Scores and labels are expected to have equal length; clip both to the
        # shorter one as a safeguard so they stay index-aligned.
        min_len = min(len(y_pred_scores), len(y_true_labels))
        if min_len == 0:
            print("No predictions or labels collected. Cannot generate reports.")
        else:
            y_pred_scores = y_pred_scores[:min_len]
            y_true_labels = y_true_labels[:min_len]

            # --- Calculations ---
            # Binary predictions for F1 and CM
            y_pred_binary = (y_pred_scores > 0.5).astype(int)

            # F1 Score
            f1 = f1_score(y_true_labels, y_pred_binary)
            print(f"\nF1 Score (threshold 0.5): {f1:.4f}")

            # Confusion Matrix
            # Pin labels=[0, 1] so the matrix is always 2x2 and matches the two tick
            # labels, even when only one class shows up in the labels/predictions.
            class_tick_labels = ['Fake (0)', 'Real (1)']
            cm = confusion_matrix(y_true_labels, y_pred_binary, labels=[0, 1])
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_tick_labels, yticklabels=class_tick_labels)
            plt.title('Confusion Matrix (Counts)')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(os.path.join(FIGURES_DIR, 'confusion_matrix_counts_wgan_sa.png'))
            plt.show()

            # Confusion Matrix (Percentages)
            # Each row is normalised by its own total (true-class count). A row with
            # no samples is reported as 0% instead of dividing by zero.
            cm_sum = np.sum(cm, axis=1, keepdims=True)
            cm_perc = np.divide(cm, cm_sum, out=np.zeros(cm.shape, dtype=float), where=cm_sum != 0) * 100
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm_perc, annot=True, fmt='.2f', cmap='Greens', xticklabels=class_tick_labels, yticklabels=class_tick_labels)
            plt.title('Confusion Matrix (Row Percentages)')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(os.path.join(FIGURES_DIR, 'confusion_matrix_perc_wgan_sa.png'))
            plt.show()


            # EER Calculation
            # roc_curve is always evaluated so thresholds_roc stays defined for the
            # cost computation further down, even when the EER itself is undefined.
            fpr, tpr, thresholds_roc = roc_curve(y_true_labels, y_pred_scores, pos_label=1)
            fnr = 1 - tpr

            # With only one class present one of the ROC rates is undefined (NaN) and
            # np.nanargmin would raise on the all-NaN difference, so guard it.
            has_both_classes = bool(np.any(y_true_labels == 1) and np.any(y_true_labels != 1))
            if not has_both_classes:
                eer_index = None
                eer_threshold = np.nan
                eer = np.nan
                print("Warning: Only one class present in evaluation labels. Skipping EER and ROC curve.")
            else:
                # The EER operating point is where the false-negative and
                # false-positive rates are closest to each other.
                eer_index = np.nanargmin(np.abs(fnr - fpr))
                eer_threshold = thresholds_roc[eer_index]
                eer = fpr[eer_index] # Or use (fpr[eer_index] + fnr[eer_index]) / 2
                print(f"EER: {eer:.4f} at threshold {eer_threshold:.4f}")

                # Plot ROC Curve
                plt.figure(figsize=(8, 6))
                # NOTE(review): results[2] is assumed to be the AUC returned by an
                # earlier evaluate() call - confirm the metric order.
                plt.plot(fpr, tpr, label=f'ROC curve (AUC = {results[2]:.4f})') # Use AUC from evaluate
                plt.plot(fpr, fnr, label='FN Rate')
                plt.plot([0, 1], [0, 1], 'k--') # Random guess line
                plt.scatter(fpr[eer_index], tpr[eer_index], color='red', zorder=5, label=f'EER = {eer:.4f}')
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.05])
                plt.xlabel('False Positive Rate (FPR)')
                plt.ylabel('True Positive Rate (TPR)')
                plt.title('Receiver Operating Characteristic (ROC)')
                plt.legend(loc="lower right")
                plt.grid(True)
                plt.savefig(os.path.join(FIGURES_DIR, 'roc_curve_wgan_sa.png'))
                plt.show()


            # t-DCF Calculation
            # Define function (can be moved to utils if used often)
            def calculate_t_dcf(y_true, y_scores, p_target, c_miss, c_fa, thresholds):
                """Return the minimum normalized detection cost over ``thresholds``.

                For each candidate threshold a sample is predicted as class 1 when its
                score is >= the threshold. The cost at that threshold is
                ``c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)``, divided by
                ``min(c_miss * p_target, c_fa * (1 - p_target))`` when that normaliser
                is positive.

                NOTE(review): this formula uses only this detector's miss and
                false-alarm rates; confirm it is the intended "t-DCF" definition.

                Args:
                    y_true: Array-like of labels; 1 is counted as real, 0 as fake.
                    y_scores: Array-like of scores, one per label.
                    p_target: Prior probability weight of the class-1 (real) samples.
                    c_miss: Cost weight of a miss (class 1 predicted as 0).
                    c_fa: Cost weight of a false alarm (class 0 predicted as 1).
                    thresholds: Array-like of candidate decision thresholds.

                Returns:
                    Tuple ``(min_dcf, min_dcf_threshold)``. When either class has no
                    samples, or no thresholds are given, ``(np.inf, np.nan)`` is
                    returned so callers can always unpack two values.
                """
                # Flatten so column-vector labels cannot broadcast against the scores.
                labels = np.asarray(y_true).ravel()
                scores = np.asarray(y_scores).ravel()
                candidate_thresholds = np.asarray(thresholds).ravel()

                # Class masks do not depend on the threshold; build them once.
                is_real = labels == 1
                is_fake = labels == 0
                num_real = np.count_nonzero(is_real)
                num_fake = np.count_nonzero(is_fake)

                if num_real == 0 or num_fake == 0:
                    print("Warning: Cannot calculate t-DCF with zero samples in a class.")
                    return np.inf, np.nan

                if candidate_thresholds.size == 0:
                    print("Warning: Cannot calculate t-DCF without any candidate thresholds.")
                    return np.inf, np.nan

                # The normaliser is the same for every threshold.
                dcf_norm = min(c_miss * p_target, c_fa * (1 - p_target))

                dcf_values = np.zeros(candidate_thresholds.size)
                for i, thr in enumerate(candidate_thresholds):
                    accepted = scores >= thr

                    p_miss = np.count_nonzero(is_real & ~accepted) / num_real  # Miss Rate (FN / Total Real)
                    p_fa = np.count_nonzero(is_fake & accepted) / num_fake     # False Alarm Rate (FP / Total Fake)

                    cost = (c_miss * p_miss * p_target) + (c_fa * p_fa * (1 - p_target))

                    # Fall back to the raw cost when the normaliser is not positive.
                    dcf_values[i] = cost / dcf_norm if dcf_norm > 0 else cost

                # First threshold achieving the lowest cost.
                min_dcf_index = np.argmin(dcf_values)
                min_dcf = dcf_values[min_dcf_index]
                min_dcf_threshold = candidate_thresholds[min_dcf_index]

                return min_dcf, min_dcf_threshold

            # Calculate min t-DCF over a range of thresholds (more robust than just EER threshold)
            # Use ROC thresholds, but filter out +/- inf if present
            valid_thresholds = thresholds_roc[np.isfinite(thresholds_roc)]
            min_tdcf, min_tdcf_thresh = calculate_t_dcf(y_true_labels, y_pred_scores, p_target, c_miss, c_false_alarm, valid_thresholds)

            print(f"Minimum t-DCF: {min_tdcf:.4f} at threshold {min_tdcf_thresh:.4f}")

    else:
         print("Skipping report generation because eval_steps_report is zero.")