In [1]:
# Cell 1: Imports
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, GRU, Dense, Dropout, BatchNormalization, LayerNormalization, Reshape, Permute, Bidirectional, Add, Attention, Flatten, TimeDistributed, Conv2DTranspose, Conv2D, Layer, Concatenate, Multiply, AdditiveAttention # Added Multiply, AdditiveAttention for potential SelfAttention implementation
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam, RMSprop # Added RMSprop as an option for WGAN
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras import backend as K
from sklearn.metrics import confusion_matrix, f1_score, roc_curve # Moved confusion_matrix here
import librosa
import soundfile as sf
# import noisereduce as nr # Consider if still needed/effective
import matplotlib.pyplot as plt
from scipy.signal import butter, sosfilt
import logging
from tqdm.notebook import tqdm # Use tqdm.notebook for better Jupyter integration
import random
import seaborn as sns # Added for confusion matrix plotting
from scipy.interpolate import interp1d # Added for EER/t-DCF

# WGAN-GP specific
from tensorflow import GradientTape

# Reproducibility: augmentation (np.random), shuffling and latent-noise sampling
# (tf.random) are all stochastic, so seed every RNG the notebook uses up front.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Create directory for saving figures
FIGURES_DIR = 'training_figures_wgan_sa' # Changed dir name
os.makedirs(FIGURES_DIR, exist_ok=True)

# Configure logging: only ERROR and above are written to the log file, so any
# logging.warning(...) calls made later in the notebook do not reach it.
logging.basicConfig(filename='audio_errors_wgan_sa.log', level=logging.ERROR,
                    format='%(asctime)s - %(levelname)s - %(message)s')

print(f"TensorFlow version: {tf.__version__}")

2025-04-03 22:22:08.138790: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-03 22:22:08.149608: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743699128.163480  120890 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743699128.167670  120890 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743699128.178689  120890 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

TensorFlow version: 2.19.0


In [2]:
# Cell 2: Force GPU usage
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("No GPU devices found. Falling back to CPU.")
else:
    try:
        # Expose only the first GPU to TensorFlow and enable memory growth on it.
        tf.config.set_visible_devices(physical_devices[0], 'GPU')
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        print("GPU is available and configured.")
    except RuntimeError as e:
        # These settings must be applied before the GPUs have been initialized;
        # report the failure and carry on with whatever configuration is active.
        print(f"Error configuring GPU: {e}")

# Set mixed precision policy if desired (can speed up training on compatible GPUs)
#from tensorflow.keras import mixed_precision
#policy = mixed_precision.Policy('mixed_float16')
#mixed_precision.set_global_policy(policy)
#print('Mixed precision enabled')

GPU is available and configured.


In [3]:
# Cell 3: Audio Loading and Preprocessing Function
def _fit_to_length(audio, target_len):
    """Return `audio` zero-padded at the end, or truncated, to exactly `target_len` samples."""
    if len(audio) < target_len:
        return np.pad(audio, (0, target_len - len(audio)), mode='constant')
    return audio[:target_len]


def load_and_preprocess_audio(file_path, sr=16000, duration=4, augment=True):
    """Load a clip, bring it to a fixed rate and length, optionally augment it, and peak-normalise.

    Args:
        file_path: Path of the audio file to read.
        sr: Target sample rate; the clip is resampled when its native rate differs.
        duration: Clip length in seconds. At most this much audio is read, and the
            result is padded/truncated to ``int(sr * duration)`` samples.
        augment: When True (the default, matching the previous behaviour) one random
            augmentation (additive noise, pitch shift or time stretch) is applied
            with 50% probability. Pass False for validation/evaluation data so
            that metrics are computed on unmodified audio.

    Returns:
        1-D array of ``int(sr * duration)`` samples scaled so the peak absolute value
        is 1 (left unscaled when the clip is effectively silent), or None if loading
        or processing failed. Failures are logged and printed rather than raised.
    """
    try:
        # Load at the file's native rate (sr=None); resampling is done explicitly below.
        audio, current_sr = librosa.load(file_path, sr=None, duration=duration)

        if current_sr != sr:
            audio = librosa.resample(audio, orig_sr=current_sr, target_sr=sr)

        # int(): a fractional `duration` would otherwise produce a float sample count,
        # which np.pad does not accept.
        target_len = int(sr * duration)

        # Fix the length *before* augmentation/normalization.
        audio = _fit_to_length(audio, target_len)

        # Data augmentation (applied *before* normalization).
        if augment and np.random.random() < 0.5: # 50% chance
            augmentation_type = np.random.choice(['noise', 'pitch', 'speed'])
            if augmentation_type == 'noise':
                # Scale the noise relative to the clip's own peak amplitude.
                noise_amp = 0.005 * np.random.uniform(0.1, 1.0) * np.max(np.abs(audio))
                noise = np.random.randn(len(audio)) * noise_amp
                audio = audio + noise
            elif augmentation_type == 'pitch':
                pitch_shift_steps = np.random.uniform(-2.5, 2.5)
                audio = librosa.effects.pitch_shift(audio, sr=sr, n_steps=pitch_shift_steps)
            else: # speed (time stretch)
                speed_rate = np.random.uniform(0.85, 1.15)
                audio = librosa.effects.time_stretch(audio, rate=speed_rate)
                # Time stretching changes the length, so fix it again.
                audio = _fit_to_length(audio, target_len)

        # Peak normalization keeps amplitudes consistent across clips.
        max_amp = np.max(np.abs(audio))
        if max_amp > 1e-6: # Avoid division by zero on (near-)silent clips
            audio = audio / max_amp

        return audio
    except Exception as e:
        # Deliberate best-effort: callers skip clips for which None is returned.
        logging.error(f"Error loading/preprocessing {file_path}: {e}")
        print(f"Error loading/preprocessing {file_path}: {e}")
        return None

In [4]:
# Cell 4: Feature Extraction Function
def extract_features(audio, sr=16000, n_mels=80, n_fft=2048, hop_length=512):
    """Convert a waveform into a standardised log-mel spectrogram.

    Args:
        audio: 1-D waveform, or None (in which case None is returned immediately).
        sr: Sample rate of `audio`.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.

    Returns:
        2-D array of shape (n_mels, n_frames) with zero mean and, unless the
        spectrogram is (near-)constant, unit standard deviation; None when the
        extraction fails (the error is logged and printed).
    """
    if audio is None:
        return None
    try:
        power_spec = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
        )
        # Log scale (dB), referenced to the spectrogram's own maximum.
        spec_db = librosa.power_to_db(power_spec, ref=np.max)

        # Per-spectrogram standard scaling; a flat spectrogram is only centred
        # because dividing by a near-zero deviation would blow up.
        centred = spec_db - np.mean(spec_db)
        spread = np.std(spec_db)
        if spread > 1e-6:
            return centred / spread
        return centred

    except Exception as e:
        logging.error(f"Error extracting features: {e}")
        print(f"Error extracting features: {e}")
        return None

In [5]:
# Cell 5: Class Distribution Analysis Function
def analyze_class_distribution(data_path):
    """Count and report the .wav files under `<data_path>/real` and `<data_path>/fake`.

    Prints the per-class counts and percentages. A missing class folder counts
    as zero files. Any error is reported and results in zero counts.

    Returns:
        dict with integer counts under the keys 'real' and 'fake'.
    """
    def count_wavs(class_name):
        # Number of entries ending in '.wav' in the class folder (0 if it is absent).
        folder = os.path.join(data_path, class_name)
        if not os.path.exists(folder):
            return 0
        return sum(1 for entry in os.listdir(folder) if entry.endswith('.wav'))

    try:
        real_count = count_wavs('real')
        fake_count = count_wavs('fake')
        total = real_count + fake_count

        if total > 0:
            print(f"\nClass Distribution for {data_path}:")
            print(f"Real: {real_count} ({real_count/total*100:.2f}%)")
            print(f"Fake: {fake_count} ({fake_count/total*100:.2f}%)")
            return {'real': real_count, 'fake': fake_count}

        print(f"\nNo .wav files found in {data_path}")
        return {'real': 0, 'fake': 0}
    except FileNotFoundError:
        print(f"\nError: Data path not found - {data_path}")
        return {'real': 0, 'fake': 0}
    except Exception as e:
        logging.error(f"Error analyzing distribution for {data_path}: {e}")
        print(f"Error analyzing distribution for {data_path}: {e}")
        return {'real': 0, 'fake': 0}

In [6]:
# Cell 6: Data Generators

# Fixed time-axis length (in spectrogram frames) that both data generators below
# pad or truncate to, so classifier and GAN inputs share one shape.
# NOTE: the generators capture this value as the default of `target_frames` when
# they are defined; after changing it, re-run their definitions as well.
TARGET_FRAMES = 126 # For sr=16000, duration=4, hop_length=512 (presumably 64000 // 512 + 1 frames with librosa's default framing - verify if any of these change)

# Data generator for STANDALONE CLASSIFIER training (Yields X, y, sample_weights)
def data_generator_classifier(data_path, batch_size=32, shuffle=True, target_frames=TARGET_FRAMES, sr=16000, duration=4, n_mels=80, n_fft=2048, hop_length=512):
    """Endlessly yield `(features, labels, sample_weights)` batches for classifier training.

    Files are read from `<data_path>/real` (label 1) and `<data_path>/fake` (label 0).
    Each file goes through `load_and_preprocess_audio` and `extract_features`, and the
    spectrogram is zero-padded or truncated along time to `target_frames`.

    Args:
        data_path: Directory containing the 'real' and 'fake' sub-folders of .wav files.
        batch_size: Maximum number of files per batch; a batch is smaller when files
            fail to load or when it is the last batch of an epoch.
        shuffle: Reshuffle the file order at the start of every pass over the data.
        target_frames: Number of time frames every spectrogram is forced to.
        sr, duration, n_mels, n_fft, hop_length: Forwarded to the loading and
            feature-extraction helpers.

    Yields:
        Tuple of float32 arrays:
        features `(batch, n_mels, target_frames, 1)`, labels `(batch,)` and
        inverse-frequency sample weights `(batch,)`.

    The generator ends (instead of yielding) when no .wav files are found, or when
    a complete pass over the data produced no usable sample.
    """
    real_files = []
    fake_files = []
    try:
        real_dir = os.path.join(data_path, 'real')
        fake_dir = os.path.join(data_path, 'fake')
        # sorted(): os.listdir returns entries in arbitrary order; a stable order makes
        # runs repeatable when shuffle=False or when NumPy's RNG is seeded.
        if os.path.exists(real_dir):
            real_files = [os.path.join(real_dir, f) for f in sorted(os.listdir(real_dir)) if f.endswith('.wav')]
        if os.path.exists(fake_dir):
            fake_files = [os.path.join(fake_dir, f) for f in sorted(os.listdir(fake_dir)) if f.endswith('.wav')]
    except FileNotFoundError as e:
        print(f"Error finding directories in {data_path}: {e}")
        return # Stop the generator

    all_files = real_files + fake_files
    labels = [1] * len(real_files) + [0] * len(fake_files) # Real=1, Fake=0

    if not all_files:
        print(f"No WAV files found in {data_path}. Generator stopping.")
        return

    total_samples = len(all_files)
    # Inverse-frequency class weights compensate for class imbalance; an absent
    # class falls back to weight 1.0 to avoid dividing by zero.
    class_weights = {
        1: total_samples / (2 * len(real_files)) if len(real_files) > 0 else 1.0,
        0: total_samples / (2 * len(fake_files)) if len(fake_files) > 0 else 1.0,
    }
    print(f"Using class weights: {class_weights} for path {data_path}")

    indices = np.arange(total_samples)
    while True:
        if shuffle:
            np.random.shuffle(indices)

        yielded_this_epoch = False
        for i in range(0, total_samples, batch_size):
            batch_indices = indices[i:i+batch_size]

            batch_x = []
            batch_y = []
            batch_sample_weights = []

            for k in batch_indices:
                file_path, label = all_files[k], labels[k]

                audio = load_and_preprocess_audio(file_path, sr=sr, duration=duration)
                if audio is None:
                    continue # Skip failed loads

                features = extract_features(audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length) # Shape (n_mels, n_frames)
                if features is None:
                    continue # Skip failed feature extraction

                # Force the time axis to exactly target_frames columns.
                current_frames = features.shape[1]
                if current_frames < target_frames:
                    features = np.pad(features, ((0, 0), (0, target_frames - current_frames)), mode='constant')
                elif current_frames > target_frames:
                    features = features[:, :target_frames]

                batch_x.append(features)
                batch_y.append(label)
                batch_sample_weights.append(class_weights[label])

            if batch_x: # Only yield if batch is not empty
                yielded_this_epoch = True
                # Add channel dimension for Conv2D models
                batch_x_4d = np.expand_dims(np.array(batch_x), axis=-1).astype(np.float32)
                batch_y_arr = np.array(batch_y).astype(np.float32)
                batch_weights_arr = np.array(batch_sample_weights).astype(np.float32)
                yield batch_x_4d, batch_y_arr, batch_weights_arr

        if not yielded_this_epoch:
            # Every file failed to load: without this guard the outer loop would spin
            # forever without ever yielding, hanging whatever consumes the generator.
            message = f"No usable audio could be loaded from {data_path}. Generator stopping."
            logging.error(message)
            print(message)
            return


# Data generator for WGAN training (yields feature batches X only, taken from the dataset's 'fake' class; no labels)
def data_generator_gan(data_path, batch_size=32, shuffle=True, target_frames=TARGET_FRAMES, sr=16000, duration=4, n_mels=80, n_fft=2048, hop_length=512):
    """Endlessly yield feature batches built from the files in `<data_path>/fake`.

    These are the "real" samples from the GAN's point of view: spectrograms of the
    dataset's fake-class recordings. Each file goes through
    `load_and_preprocess_audio` and `extract_features`, and the spectrogram is
    zero-padded or truncated along time to `target_frames`.

    Args:
        data_path: Directory containing the 'fake' sub-folder of .wav files.
        batch_size: Maximum number of files per batch; a batch is smaller when files
            fail to load or when it is the last batch of an epoch.
        shuffle: Reshuffle the file order at the start of every pass over the data.
        target_frames: Number of time frames every spectrogram is forced to.
        sr, duration, n_mels, n_fft, hop_length: Forwarded to the loading and
            feature-extraction helpers.

    Yields:
        float32 array of shape `(batch, n_mels, target_frames, 1)`.

    The generator ends (instead of yielding) when no .wav files are found, or when
    a complete pass over the data produced no usable sample.
    """
    fake_files = []
    try:
        fake_dir = os.path.join(data_path, 'fake')
        # sorted(): os.listdir returns entries in arbitrary order; a stable order makes
        # runs repeatable when shuffle=False or when NumPy's RNG is seeded.
        if os.path.exists(fake_dir):
            fake_files = [os.path.join(fake_dir, f) for f in sorted(os.listdir(fake_dir)) if f.endswith('.wav')]
    except FileNotFoundError as e:
        print(f"Error finding fake directory in {data_path}: {e}")
        return

    if not fake_files:
        print(f"No fake WAV files found in {data_path}. GAN generator stopping.")
        return

    total_samples = len(fake_files)
    indices = np.arange(total_samples)

    while True:
        if shuffle:
            np.random.shuffle(indices)

        yielded_this_epoch = False
        for i in range(0, total_samples, batch_size):
            batch_x = []

            for k in indices[i:i+batch_size]:
                audio = load_and_preprocess_audio(fake_files[k], sr=sr, duration=duration)
                if audio is None:
                    continue # Skip failed loads

                features = extract_features(audio, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
                if features is None:
                    continue # Skip failed feature extraction

                # Force the time axis to exactly target_frames columns.
                current_frames = features.shape[1]
                if current_frames < target_frames:
                    features = np.pad(features, ((0, 0), (0, target_frames - current_frames)), mode='constant')
                elif current_frames > target_frames:
                    features = features[:, :target_frames]

                batch_x.append(features)

            if batch_x:
                yielded_this_epoch = True
                # Add channel dimension
                batch_x_4d = np.expand_dims(np.array(batch_x), axis=-1).astype(np.float32)
                yield batch_x_4d # Yield only the features (4D array)

        if not yielded_this_epoch:
            # Every file failed to load: without this guard the outer loop would spin
            # forever without ever yielding, hanging the training loop that consumes it.
            message = f"No usable fake audio could be loaded from {data_path}. GAN generator stopping."
            logging.error(message)
            print(message)
            return

In [7]:
# Cell 7: Self-Attention Layer (SAGAN style)

class SelfAttention(Layer):
    """
    Self-attention layer based on SAGAN.

    1x1 convolutions produce query, key and value maps; a softmax over all
    height*width locations forms the attention map; the attended values pass
    through a final 1x1 convolution and are added back to the input, scaled by a
    learnable `gamma` that is initialised to zero (so the layer starts out as an
    identity mapping).

    Input shape: (batch, height, width, channels)
    Output shape: (batch, height, width, channels_out) where channels_out is typically channels

    Args:
        channels_out: Channel count of the value and output projections. Defaults
            to the input channel count. Because the attention branch is added to
            the input, it has to be compatible with the input's channel count.

    Note:
        The attention map holds (height*width)^2 scores per sample, so memory
        grows quadratically with the feature-map area.
    """
    def __init__(self, channels_out=None, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels_out = channels_out

    def build(self, input_shape):
        self.input_channels = input_shape[-1]
        if self.channels_out is None:
            self.channels_out = self.input_channels

        # Query/key width is channels/8 as in SAGAN, but never below 1: an input
        # with fewer than 8 channels would otherwise request a 0-filter convolution.
        self.attention_channels = max(1, self.input_channels // 8)

        # Convolution layers for query, key, value
        # Use 1x1 convolutions to reduce/transform channels
        self.f = Conv2D(self.attention_channels, kernel_size=1, strides=1, padding='same', name='conv_f') # Query
        self.g = Conv2D(self.attention_channels, kernel_size=1, strides=1, padding='same', name='conv_g') # Key
        self.h = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_h')        # Value

        # Final 1x1 convolution
        self.out_conv = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_out')

        # Learnable scale parameter
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

        # Build the sub-layers now so that all weights exist as soon as this layer
        # is built, rather than only after the first forward pass.
        value_shape = tuple(input_shape[:-1]) + (self.channels_out,)
        self.f.build(input_shape)
        self.g.build(input_shape)
        self.h.build(input_shape)
        self.out_conv.build(value_shape)

        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        dims = tf.shape(x)
        batch_size, height, width = dims[0], dims[1], dims[2]
        location_num = height * width

        # Query (f), Key (g), Value (h) projections
        f_proj = self.f(x) # Shape: (batch, h, w, c/8)
        g_proj = self.g(x) # Shape: (batch, h, w, c/8)
        h_proj = self.h(x) # Shape: (batch, h, w, c_out)

        # Flatten the spatial axes so attention is a batched matrix product
        f_flat = tf.reshape(f_proj, shape=(batch_size, location_num, self.attention_channels)) # (batch, h*w, c/8)
        g_flat = tf.reshape(g_proj, shape=(batch_size, location_num, self.attention_channels)) # (batch, h*w, c/8)
        h_flat = tf.reshape(h_proj, shape=(batch_size, location_num, self.channels_out))       # (batch, h*w, c_out)

        # Attention scores: (batch, h*w, c/8) x (batch, c/8, h*w) -> (batch, h*w, h*w)
        g_flat_t = tf.transpose(g_flat, perm=[0, 2, 1])
        attention_score = tf.matmul(f_flat, g_flat_t)
        attention_prob = tf.nn.softmax(attention_score, axis=-1) # Apply softmax across locations

        # Apply attention map to value projection
        # (batch, h*w, h*w) x (batch, h*w, c_out) -> (batch, h*w, c_out)
        attention_output = tf.matmul(attention_prob, h_flat)

        # Reshape back to image format
        attention_output_reshaped = tf.reshape(attention_output, shape=(batch_size, height, width, self.channels_out))

        # Apply final 1x1 convolution and scale by gamma
        o = self.out_conv(attention_output_reshaped)
        y = self.gamma * o + x # Additive skip connection

        return y

    def compute_output_shape(self, input_shape):
        # channels_out is only resolved in build(); fall back to the input's channel
        # count when this is queried earlier. tuple() also accepts list-like shapes.
        channels = self.channels_out if self.channels_out is not None else input_shape[-1]
        return tuple(input_shape[:-1]) + (channels,)

    def get_config(self):
        # Include the constructor argument so a saved model can re-create this layer.
        config = super(SelfAttention, self).get_config()
        config.update({'channels_out': self.channels_out})
        return config

In [8]:
# Cell 8: Generator Model (with Self-Attention)

def create_generator(latent_dim, output_shape): # output_shape (n_mels, n_frames) e.g., (80, 126)
    """Creates the generator model with Self-Attention.

    The network maps a latent vector to a single-channel spectrogram image through
    a dense projection, three stride-2 transposed convolutions (8x upsampling in
    total) and a self-attention layer before the last upsampling step. When a
    target dimension is not a multiple of 8, the initial feature map is rounded up
    and one extra valid-padding convolution trims the surplus rows/columns.

    Args:
        latent_dim: Length of the latent input vector.
        output_shape: `(n_mels, n_frames)` of the spectrogram to generate.

    Returns:
        A Keras `Sequential` model named 'generator' whose output has shape
        `(batch, n_mels, n_frames, 1)`.

    Raises:
        ValueError: If either entry of `output_shape` is not positive.
    """
    n_mels, n_frames = output_shape
    if n_mels < 1 or n_frames < 1:
        raise ValueError(f"output_shape must contain positive sizes, got {output_shape}")

    upsample_factor = 8 # Three stride-2 upsamples: 2*2*2
    init_c = 128 # Initial channels

    # Ceiling division on BOTH axes: the upsampled map is then at least as large as
    # the target in height and width, and the surplus can be trimmed at the end.
    init_h = -(-n_mels // upsample_factor)
    init_w = -(-n_frames // upsample_factor)
    out_h, out_w = init_h * upsample_factor, init_w * upsample_factor

    needs_trim = out_h != n_mels or out_w != n_frames
    if needs_trim:
        print(f"Warning: Output shape {output_shape} might not be perfectly reached with 3 strides of 2. Adjusting initial size or layers.")

    nodes = init_h * init_w * init_c

    model = Sequential(name='generator')
    model.add(Input(shape=(latent_dim,)))

    # Dense layer and reshape
    model.add(Dense(nodes))
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))
    model.add(Reshape((init_h, init_w, init_c))) # e.g., (10, 16, 128)

    # Upsample 1: (10, 16, 128) -> (20, 32, 64)
    model.add(Conv2DTranspose(init_c // 2, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False)) # Bias is redundant before a norm layer
    model.add(LayerNormalization()) # Using LayerNorm instead of BatchNorm
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))

    # Upsample 2: (20, 32, 64) -> (40, 64, 32)
    model.add(Conv2DTranspose(init_c // 4, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(LayerNormalization())
    model.add(tf.keras.layers.LeakyReLU(negative_slope=0.2))

    # Self-attention on the 40x64 feature map. Attention is computationally
    # expensive, so it is applied once, before the final upsample.
    model.add(SelfAttention(channels_out=init_c // 4)) # Keep channels the same

    # Upsample 3: (40, 64, 32) -> (80, 128, 1)
    # NOTE(review): tanh bounds the output to [-1, 1], whereas extract_features
    # z-scores the real spectrograms, which are not bounded to that range - confirm
    # this range mismatch is intended.
    model.add(Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), padding='same', activation='tanh'))

    # Final adjustment layer if needed (e.g., width 128 -> 126)
    if needs_trim:
        # A 'valid' convolution with kernel size K shrinks an axis by K - 1,
        # so K = size_in - size_out + 1 (always between 1 and 8 here).
        kernel_h = out_h - n_mels + 1
        kernel_w = out_w - n_frames + 1
        if out_h == n_mels:
            print(f"Generator adding final Conv2D to adjust width from {out_w} to {n_frames}")
        else:
            print(f"Generator adding final Conv2D to adjust size from ({out_h}, {out_w}) to ({n_mels}, {n_frames})")
        model.add(Conv2D(1, kernel_size=(kernel_h, kernel_w), padding='valid', activation='tanh'))

    return model

In [9]:
# Cell 9: Critic (Discriminator) Model (with Self-Attention)

def create_critic(input_shape): # input_shape (n_mels, n_frames) e.g., (80, 126)
    """Creates the Critic model for WGAN-GP with Self-Attention.

    Three stride-2 convolution stages (64, 128 and 256 filters), with a
    self-attention layer after the second stage, followed by a flatten and a
    single linear unit that outputs the unbounded critic score.

    Args:
        input_shape: `(n_mels, n_frames)` of the spectrograms to score; a trailing
            channel axis of size 1 is added for the model input.

    Returns:
        A Keras `Sequential` model named 'critic' producing one score per sample.
    """
    n_mels, n_frames = input_shape
    leaky_slope = 0.2
    drop_rate = 0.25

    stack = (
        Input(shape=(n_mels, n_frames, 1)), # Expects e.g. (80, 126, 1)

        # Stage 1: no normalisation on the critic's first layer
        Conv2D(64, kernel_size=(4, 4), strides=(2, 2), padding='same'),
        tf.keras.layers.LeakyReLU(negative_slope=leaky_slope),
        Dropout(drop_rate),

        # Stage 2: LayerNorm rather than BatchNorm in the critic
        Conv2D(128, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False),
        LayerNormalization(),
        tf.keras.layers.LeakyReLU(negative_slope=leaky_slope),
        Dropout(drop_rate),

        # Self-attention on the stage-2 feature map, channel count unchanged
        SelfAttention(channels_out=128),

        # Stage 3
        Conv2D(256, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False),
        LayerNormalization(),
        tf.keras.layers.LeakyReLU(negative_slope=leaky_slope),
        Dropout(drop_rate),

        # Score head: linear output, no activation
        Flatten(),
        Dense(1),
    )

    model = Sequential(name='critic')
    for layer in stack:
        model.add(layer)
    return model

In [10]:
# Cell 10: Define the GAN Model (Not used for WGAN-GP training loop)
# def create_gan(generator, discriminator, latent_dim):
#     """Creates the combined GAN model."""
#     # Make discriminator non-trainable
#     discriminator.trainable = False
#
#     # Stack generator and discriminator
#     gan_input = Input(shape=(latent_dim,))
#     gan_output = discriminator(generator(gan_input))
#     gan = Model(gan_input, gan_output)
#
#     return gan

In [11]:
# Cell 11: Data Path and Parameters

# Data Paths
# Each split directory is expected to hold 'real' and 'fake' sub-folders of .wav
# files; that is the layout the data generators and analyze_class_distribution read.
train_data_path = 'datasetNEW/train'
dev_data_path = 'datasetNEW/dev'
eval_data_path = 'datasetNEW/eval'

# Define the fixed number of frames
# NOTE(review): TARGET_FRAMES is also assigned in the data-generator cell, and the
# generators captured that earlier value as their default `target_frames`.
# Keep the two assignments equal.
TARGET_FRAMES = 126
N_MELS = 80
mel_spectrogram_shape = (N_MELS, TARGET_FRAMES) # (n_mels, n_frames), as create_generator/create_critic expect

# WGAN-GP specific parameters
latent_dim = 100     # Length of the noise vector fed to the generator
n_critic = 5         # Train critic 5 times per generator update
gp_weight = 10.0     # Gradient penalty weight
gan_epochs = 75      # WGAN often needs more epochs, adjust as needed
gan_batch_size = 8  # Adjust based on GPU memory (WGAN-GP can be memory intensive)

# Optimizers (Typical WGAN settings: lower LR, betas=(0.5, 0.9))
# Critic and generator each get their own optimizer instance.
critic_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
generator_optimizer = Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
# critic_optimizer = RMSprop(learning_rate=0.00005) # Alternative optimizer
# generator_optimizer = RMSprop(learning_rate=0.00005)

# Create instances
generator = create_generator(latent_dim, mel_spectrogram_shape)
critic = create_critic(mel_spectrogram_shape)

# Diagnostic Code: Verify Output Shape
# Push one random latent vector through generator and critic, so that a shape
# mismatch between the two models surfaces here rather than during training.
test_noise = tf.random.normal((1, latent_dim))
generated_image = generator(test_noise, training=False) # Use tf.random and call model directly
critic_output = critic(generated_image, training=False)
print("Shape of generated image (Generator Output):", generated_image.shape)
print("Shape of critic output:", critic_output.shape)

# The generator must emit exactly what the critic consumes: (n_mels, n_frames, 1).
critic_input_shape_expected = (mel_spectrogram_shape[0], mel_spectrogram_shape[1], 1)
print("Expected shape for Critic Input:", critic_input_shape_expected)
assert generated_image.shape[1:] == critic_input_shape_expected, "Generator output shape mismatch!"
# The critic must return one scalar score per sample: shape (batch, 1).
assert len(critic_output.shape) == 2 and critic_output.shape[1] == 1, "Critic output shape mismatch!"


# Report the models
print("\n--- Generator Summary ---")
generator.summary()
print("\n--- Critic Summary ---")
critic.summary()

# Parameters for standalone classifier training (can be different)
classifier_batch_size = 8 # Keep smaller for classifier fine-tuning?
classifier_epochs = 50

I0000 00:00:1743699130.652512  120890 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1904 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Generator adding final Conv2D to adjust width from 128 to 126


I0000 00:00:1743699131.351937  120890 cuda_dnn.cc:529] Loaded cuDNN version 90300


Shape of generated image (Generator Output): (1, 80, 126, 1)
Shape of critic output: (1, 1)
Expected shape for Critic Input: (80, 126, 1)

--- Generator Summary ---



--- Critic Summary ---


In [None]:
# Cell 12: WGAN-GP Training Loop (Corrected for Mixed Precision Compatibility)

# --- Loss Functions ---
def critic_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals.

    Both score tensors are cast to float32 first so the loss is computed in full
    precision even if the critic ran in a lower-precision dtype.
    """
    mean_fake_score = tf.reduce_mean(tf.cast(fake_output, tf.float32))
    mean_real_score = tf.reduce_mean(tf.cast(real_output, tf.float32))
    return mean_fake_score - mean_real_score

def generator_loss(fake_output):
    """Wasserstein generator loss: the negated mean critic score on generated samples.

    The scores are cast to float32 before averaging.
    """
    scores = tf.cast(fake_output, tf.float32)
    mean_score = tf.reduce_mean(scores)
    return -mean_score

# --- Gradient Penalty Function (Corrected) ---
def gradient_penalty(batch_size, real_images, fake_images):
    """Calculates the gradient penalty loss for WGAN GP, handling mixed precision.

    Random per-sample interpolations between real and generated images are scored
    by the module-level `critic`; the penalty is the mean squared deviation of the
    per-sample gradient norm from 1, multiplied by the module-level `gp_weight`.

    Args:
        batch_size: Unused; kept so existing callers keep working. The batch size
            is read from `real_images`.
        real_images: Batch of real samples.
        fake_images: Batch of generated samples with the same shape.

    Returns:
        Scalar float32 tensor holding the weighted penalty (0.0 if no gradient
        could be computed).
    """
    # Interpolation needs matching dtypes; follow fake_images, which may be float16
    # when mixed precision is enabled.
    if real_images.dtype != fake_images.dtype:
        real_images = tf.cast(real_images, fake_images.dtype)

    # One interpolation coefficient per sample, shaped to broadcast over all other axes.
    alpha_shape = [tf.shape(real_images)[0]] + [1] * (len(real_images.shape) - 1)
    alpha = tf.random.uniform(shape=alpha_shape, minval=0., maxval=1., dtype=real_images.dtype)

    # Interpolate images
    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        # 1. Get the critic output for this interpolated image.
        pred = critic(interpolated, training=True)
        # Compute the penalty in float32 even if the critic ran in float16.
        pred = tf.cast(pred, tf.float32)

    # 2. Calculate the gradients w.r.t to this interpolated image.
    grads = gp_tape.gradient(pred, [interpolated]) # grads is a list
    if grads is None or grads[0] is None:
        # Happens when `pred` is not connected to `interpolated` in the graph.
        logging.warning("Gradients are None in gradient_penalty. Returning 0 penalty.")
        print("Warning: Gradients are None in gradient_penalty. Returning 0 penalty.")
        return tf.constant(0.0, dtype=tf.float32)
    grads = grads[0] # Get the tensor from the list

    # Cast gradients to float32 before norm calculation for stability
    grads = tf.cast(grads, tf.float32)

    # 3. Calculate the norm of the gradients. Sum over H, W, C axes.
    squared_norm = tf.reduce_sum(tf.square(grads), axis=tf.range(1, tf.rank(grads)))
    # The epsilon keeps the derivative of sqrt finite when a sample's gradient is
    # exactly zero; otherwise back-propagating through this penalty yields NaNs.
    norm = tf.sqrt(squared_norm + 1e-12)
    gp = tf.reduce_mean((norm - 1.0) ** 2)

    return gp * gp_weight # gp_weight should be float


# --- Data Generator and Steps Calculation ---
# Create GAN data generator instance
train_gen_gan = data_generator_gan(train_data_path, batch_size=gan_batch_size)

# Count the .wav files under <train>/fake: that count fixes the epoch length.
fake_files_count = 0
try:
    fake_dir = os.path.join(train_data_path, 'fake')
    if os.path.exists(fake_dir):
        fake_files_count = sum(1 for f in os.listdir(fake_dir) if f.endswith('.wav'))
except FileNotFoundError:
    fake_files_count = 0

if gan_batch_size <= 0:
    raise ValueError("GAN Batch size must be positive.")

if fake_files_count > 0:
    # Round up so a final, partially filled batch still counts as a step.
    gan_steps_per_epoch = int(np.ceil(fake_files_count / float(gan_batch_size)))
    print(f"Calculated {gan_steps_per_epoch} GAN steps per epoch.")
else:
    print("Warning: No fake training files found for GAN. Setting GAN steps to 0.")
    gan_steps_per_epoch = 0


# --- Training Step Function (Decorated with tf.function) ---
@tf.function
def train_step(real_images):
    """Run one WGAN-GP update: `n_critic` critic steps, then one generator step.

    Args:
        real_images: Batch the critic should score as "real". The training loop
            passes a float32 tensor; the batch size is read from its leading
            dimension.

    Returns:
        Tuple `(total_crit_loss, gen_loss)`: the critic loss (base critic loss
        plus gradient penalty) from the LAST critic iteration, and the
        generator loss.
    """
    current_batch_size = tf.shape(real_images)[0]

    # Train Critic (n_critic times).
    # AutoGraph converts the tf.range loop into a graph loop, so the reported
    # loss has to exist before the loop in order to be readable after it.
    total_crit_loss = tf.constant(0.0, dtype=tf.float32)
    for _ in tf.range(n_critic):
        # Draw a fresh latent batch for every critic update. Reusing a single
        # noise tensor would make the critic see the same latent points
        # n_critic times in a row instead of new samples from the generator.
        noise = tf.random.normal([current_batch_size, latent_dim])

        with tf.GradientTape() as crit_tape:
            # Generate fake images (potentially float16)
            fake_images = generator(noise, training=True)
            # Scores for the real batch and for the generated batch
            real_output = critic(real_images, training=True)
            fake_output = critic(fake_images, training=True)

            # Base critic loss plus the gradient penalty term
            crit_loss = critic_loss(real_output, fake_output)
            gp = gradient_penalty(current_batch_size, real_images, fake_images)
            total_crit_loss = crit_loss + gp  # value from this iteration

        # Calculate gradients for critic and apply them
        crit_gradients = crit_tape.gradient(total_crit_loss, critic.trainable_variables)

        # Drop variables that received no gradient. This filtering runs in
        # Python while the function is traced, not on every step.
        valid_grads_and_vars = [
            (g, v) for g, v in zip(crit_gradients, critic.trainable_variables) if g is not None
        ]
        if len(valid_grads_and_vars) < len(critic.trainable_variables):
            tf.print("Warning: Some critic gradients are None.")  # tf.print works inside tf.function
        if len(valid_grads_and_vars) > 0:
            critic_optimizer.apply_gradients(valid_grads_and_vars)

    # Train Generator (once per train_step call, after the critic updates),
    # again on freshly sampled noise.
    gen_noise = tf.random.normal([current_batch_size, latent_dim])
    with tf.GradientTape() as gen_tape:
        fake_images_gen = generator(gen_noise, training=True)
        # Critic's score on these fake images
        fake_output_gen = critic(fake_images_gen, training=True)
        gen_loss = generator_loss(fake_output_gen)

    # Calculate gradients for generator and apply them
    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)

    # Same None-filtering as for the critic
    valid_grads_and_vars_gen = [
        (g, v) for g, v in zip(gen_gradients, generator.trainable_variables) if g is not None
    ]
    if len(valid_grads_and_vars_gen) < len(generator.trainable_variables):
        tf.print("Warning: Some generator gradients are None.")
    if len(valid_grads_and_vars_gen) > 0:
        generator_optimizer.apply_gradients(valid_grads_and_vars_gen)

    # Critic loss from the LAST critic iteration and the generator loss
    return total_crit_loss, gen_loss


# --- WGAN-GP Training Loop ---
print("\nBegin WGAN-GP training!")

# Training needs both a data generator and at least one step per epoch;
# otherwise the loop and the final loss plot are skipped entirely.
if train_gen_gan is not None and gan_steps_per_epoch != 0:
    # Per-epoch mean losses (NaN for epochs in which no batch was trained).
    c_loss_history, g_loss_history = [], []

    for epoch in range(gan_epochs):
        print(f"\nEpoch {epoch+1}/{gan_epochs}")
        epoch_pbar = tqdm(range(gan_steps_per_epoch), desc=f"Epoch {epoch+1}")

        # Running sums over the batches that trained successfully this epoch.
        epoch_c_loss, epoch_g_loss = 0.0, 0.0
        batches_processed = 0

        for batch_idx in epoch_pbar:
            try:
                # The critic's "real" data are the dataset's own spoofed clips;
                # force float32 before handing the batch to train_step.
                real_spoof_samples_np = next(train_gen_gan)
                real_spoof_samples = tf.convert_to_tensor(real_spoof_samples_np, dtype=tf.float32)

                if tf.shape(real_spoof_samples)[0] == 0:
                    print(f"Warning: Skipped empty batch at step {batch_idx}")
                    continue

                c_loss, g_loss = train_step(real_spoof_samples)

                # .numpy() is only used out here, outside the tf.function.
                epoch_c_loss += c_loss.numpy()
                epoch_g_loss += g_loss.numpy()
                batches_processed += 1

                # Show the current batch losses on the progress bar.
                epoch_pbar.set_postfix({"C Loss": f"{c_loss.numpy():.4f}", "G Loss": f"{g_loss.numpy():.4f}"})

            except StopIteration:
                print(f"\nGAN Generator exhausted prematurely at batch {batch_idx}. Moving to next epoch.")
                # Start a fresh generator so the next epoch has data again,
                # then abandon the remainder of this epoch.
                train_gen_gan = data_generator_gan(train_data_path, batch_size=gan_batch_size)
                break

            except Exception as e:
                # Best effort: log the failure and move on to the next batch.
                logging.error(f"Error during WGAN training epoch {epoch+1}, batch {batch_idx}: {e}", exc_info=True)
                print(f"\nError during WGAN training epoch {epoch+1}, batch {batch_idx}: {e}")
                continue

        # --- End-of-epoch actions ---
        if batches_processed == 0:
            print(f"Epoch {epoch+1} finished. No batches processed.")
            # NaN keeps the histories aligned with the epoch index for plotting.
            c_loss_history.append(np.nan)
            g_loss_history.append(np.nan)
        else:
            avg_c_loss = epoch_c_loss / batches_processed
            avg_g_loss = epoch_g_loss / batches_processed
            c_loss_history.append(avg_c_loss)
            g_loss_history.append(avg_g_loss)
            print(f"Epoch {epoch+1} finished. Avg C Loss: {avg_c_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")

        # Save both networks every 5 epochs and after the final epoch.
        on_save_interval = (epoch + 1) % 5 == 0
        if on_save_interval or (epoch + 1) == gan_epochs:
            try:
                generator.save(f'generator_wgan_sa_epoch_{epoch+1}.keras')
                critic.save(f'critic_wgan_sa_epoch_{epoch+1}.keras')
                print(f"Saved generator and critic models for epoch {epoch+1}")
            except Exception as e:
                print(f"Error saving models at epoch {epoch+1}: {e}")

    print("\nWGAN-GP training finished.")

    # Final loss plot: one curve per network, x axis in 1-based epochs.
    plt.figure(figsize=(10, 5))
    for loss_curve, curve_label in (
        (c_loss_history, 'Avg Critic Loss per Epoch'),
        (g_loss_history, 'Avg Generator Loss per Epoch'),
    ):
        plt.plot(range(1, len(loss_curve) + 1), loss_curve, label=curve_label)
    plt.title('WGAN-GP Training Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(FIGURES_DIR, 'gan_loss_final.png'))
    plt.show()
else:
    print("Skipping WGAN-GP training: Generator not initialized or no steps per epoch.")

Calculated 2850 GAN steps per epoch.

Begin WGAN-GP training!

Epoch 1/75


Epoch 1:   0%|          | 0/2850 [00:00<?, ?it/s]

E0000 00:00:1743699138.234367  120890 meta_optimizer.cc:967] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inwhile/body/_1/while/critic_1/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


In [None]:
# Cell 13: Train Standalone Classifier (using the trained Critic)

print("\nSetting up Standalone Classifier training...")

# --- Classifier data streams ---
# The training stream uses the generator's default ordering; the dev stream is
# requested with shuffle=False.
train_gen_clf = data_generator_classifier(
    train_data_path,
    batch_size=classifier_batch_size,
)
dev_gen_clf = data_generator_classifier(
    dev_data_path,
    batch_size=classifier_batch_size,
    shuffle=False,
)

# --- Calculate steps for Classifier Training ---
def count_total_files(path):
    """Return how many '.wav' files sit in `path`/real and `path`/fake combined.

    A class sub-directory that does not exist contributes 0. Only the direct
    entries of each directory are examined, and the '.wav' suffix match is
    case-sensitive.

    Args:
        path: Dataset split directory expected to contain 'real' and 'fake'.

    Returns:
        int: total number of matching files across both sub-directories.
    """
    # NOTE: an earlier comment here said an equivalent helper exists in Cell 5;
    # it is defined again so this cell does not depend on that one.
    real_dir = os.path.join(path, 'real')
    fake_dir = os.path.join(path, 'fake')

    def _wav_count(class_dir):
        # Number of direct entries of class_dir whose name ends in '.wav'.
        if not os.path.exists(class_dir):
            return 0
        return sum(1 for entry in os.listdir(class_dir) if entry.endswith('.wav'))

    return _wav_count(real_dir) + _wav_count(fake_dir)

train_samples_count = count_total_files(train_data_path)
dev_samples_count = count_total_files(dev_data_path)

if classifier_batch_size <= 0:
    raise ValueError("Classifier batch size must be positive.")

# Steps = ceil(samples / batch_size); an empty split yields 0 steps.
if train_samples_count > 0:
    clf_steps_per_epoch = int(np.ceil(train_samples_count / float(classifier_batch_size)))
else:
    clf_steps_per_epoch = 0

if dev_samples_count > 0:
    clf_validation_steps = int(np.ceil(dev_samples_count / float(classifier_batch_size)))
else:
    clf_validation_steps = 0

print(f"Classifier Train Steps/Epoch: {clf_steps_per_epoch}, Validation Steps: {clf_validation_steps}")


# --- Build the Spoof Detector Model from the Critic ---
# The critic is used with whatever weights it currently holds in this kernel.
# To start from a saved critic instead, load its weights into `critic` before
# this point.

# Leave the transferred layers trainable so they are fine-tuned, not frozen.
critic.trainable = True

# Re-use every critic layer except the last one (the original comment
# describes that last layer as the Dense(1) output; the critic itself is
# defined outside this cell), then attach a sigmoid classification head.
spoof_detector = Sequential(name='spoof_detector')
critic_body_layers = critic.layers[:-1]
for transferred_layer in critic_body_layers:
    spoof_detector.add(transferred_layer)

spoof_detector.add(Dense(1, activation='sigmoid', name='classifier_output'))

print("\n--- Spoof Detector (Classifier) Summary ---")
spoof_detector.summary()

# --- Callbacks for Classifier Training ---
# Learning-rate decay and early stopping both watch the validation loss.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# Weight checkpoints go to their own directory, one file name per epoch.
checkpoint_dir_clf = './training_checkpoints_spoof_detector_wgan_sa'
os.makedirs(checkpoint_dir_clf, exist_ok=True)
checkpoint_prefix_clf = os.path.join(checkpoint_dir_clf, "ckpt_clf_{epoch:02d}.weights.h5")

# Unlike the two callbacks above, checkpointing tracks validation AUC and only
# writes a file when that metric reaches a new maximum.
checkpoint_callback_clf = ModelCheckpoint(
    filepath=checkpoint_prefix_clf,
    monitor='val_auc',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# PlotTrainingHistory is a custom callback class defined elsewhere in the notebook.
plot_training_callback = PlotTrainingHistory(model_name='spoof_detector_wgan_sa')


# --- Compile and Train the Classifier ---
classifier_optimizer = Adam(learning_rate=1e-5) # Use a smaller LR for fine-tuning
spoof_detector.compile(
    optimizer=classifier_optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

print("\nStarting Standalone Classifier training...")

if clf_steps_per_epoch > 0 and clf_validation_steps > 0:
    # Remember which weight files already exist (and their modification times)
    # so that files left over from an earlier run are not mistaken for this
    # run's best checkpoint. The name pattern mirrors checkpoint_prefix_clf.
    ckpt_name_prefix, ckpt_name_suffix = 'ckpt_clf_', '.weights.h5'
    weights_before_fit = {
        fname: os.path.getmtime(os.path.join(checkpoint_dir_clf, fname))
        for fname in os.listdir(checkpoint_dir_clf)
        if fname.startswith(ckpt_name_prefix) and fname.endswith(ckpt_name_suffix)
    }

    history = spoof_detector.fit(
        train_gen_clf,
        steps_per_epoch=clf_steps_per_epoch,
        epochs=classifier_epochs,
        validation_data=dev_gen_clf,
        validation_steps=clf_validation_steps,
        callbacks=[reduce_lr, early_stopping, checkpoint_callback_clf, plot_training_callback],
        # Note: Keras `fit` uses the third element of the generator yield as
        # sample_weight if present, so data_generator_classifier should yield
        # (batch_x, batch_y, batch_weights) for the weights to take effect.
    )

    # Save the final trained classifier model
    spoof_detector.save('spoof_detector_final_wgan_sa.keras')
    print("\nClassifier training complete. Model saved as spoof_detector_final_wgan_sa.keras")

    # --- Load best weights based on checkpoint ---
    # ModelCheckpoint wrote Keras '.weights.h5' files. tf.train.latest_checkpoint
    # only resolves TF-format checkpoints listed in a 'checkpoint' state file,
    # so it cannot find them. Because the callback uses save_best_only=True,
    # every file it writes improves on the previous one, which makes the most
    # recently written file of this run the best-val_auc checkpoint.
    weights_written_this_run = []
    for fname in os.listdir(checkpoint_dir_clf):
        if not (fname.startswith(ckpt_name_prefix) and fname.endswith(ckpt_name_suffix)):
            continue
        fpath = os.path.join(checkpoint_dir_clf, fname)
        if weights_before_fit.get(fname) != os.path.getmtime(fpath):
            weights_written_this_run.append(fpath)

    best_checkpoint_path = (
        max(weights_written_this_run, key=os.path.getmtime) if weights_written_this_run else None
    )
    if best_checkpoint_path:
        print(f"Loading best classifier weights from: {best_checkpoint_path}")
        spoof_detector.load_weights(best_checkpoint_path)
        spoof_detector.save('spoof_detector_best_val_auc_wgan_sa.keras')
        print("Saved classifier with best validation AUC weights.")
    else:
        print("Could not find best checkpoint weights to load. Final weights saved.")

else:
    print("Skipping classifier training because steps_per_epoch or validation_steps is zero.")

In [None]:
# Cell 14: Evaluation

print("\nEvaluating the final Spoof Detector...")

# --- Load the best or final classifier model ---
# On any loading failure spoof_detector_eval is set to None, and the
# evaluation below is skipped.
try:
    # Prefer the model saved with the best validation-AUC weights...
    detector_model_path = 'spoof_detector_best_val_auc_wgan_sa.keras'
    if not os.path.exists(detector_model_path):
        # ...and fall back to the end-of-training model otherwise.
        detector_model_path = 'spoof_detector_final_wgan_sa.keras'

    print(f"Loading model for evaluation: {detector_model_path}")
    # Custom layers must be passed explicitly so they can be deserialized.
    custom_objects = {'SelfAttention': SelfAttention}
    spoof_detector_eval = tf.keras.models.load_model(
        detector_model_path,
        custom_objects=custom_objects
    )
    # Re-compile with the loss/metrics needed by evaluate(); no optimizer is
    # specified because the model is only evaluated here.
    spoof_detector_eval.compile(
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )
except Exception as e:
    print(f"Error loading spoof detector model: {e}")
    print("Evaluation cannot proceed.")
    spoof_detector_eval = None


# Explicit None check: the decision must not depend on how a model object
# happens to evaluate in a boolean context.
if spoof_detector_eval is not None:
    # --- Create Evaluation Generator ---
    eval_gen_clf = data_generator_classifier(eval_data_path, batch_size=classifier_batch_size, shuffle=False) # No shuffle

    # --- Calculate Evaluation Steps ---
    eval_samples_count = count_total_files(eval_data_path)
    if classifier_batch_size <= 0:
        raise ValueError("Classifier batch size must be positive.")
    # ceil(samples / batch_size), or 0 when the evaluation split is empty
    eval_steps = int(np.ceil(eval_samples_count / float(classifier_batch_size))) if eval_samples_count > 0 else 0
    print(f"Using {eval_steps} steps for evaluation.")

    # --- Evaluate ---
    if eval_steps > 0:
        results = spoof_detector_eval.evaluate(
            eval_gen_clf,
            steps=eval_steps,
            verbose=1
        )

        print("\n--- Evaluation Results ---")
        # Expected order follows compile(): loss, accuracy, AUC.
        if len(results) >= 3:
            print(f"Loss: {results[0]:.4f}")
            print(f"Accuracy: {results[1]:.4f}")
            print(f"AUC: {results[2]:.4f}")
        else:
            print(f"Results: {results}") # Print raw results if format is unexpected
        print("--------------------------")
    else:
        print("Skipping evaluation because eval_steps is zero.")

In [None]:
# Cell 15: Reporting (Confusion Matrix, EER, t-DCF)

print("\nGenerating Final Reports (F1, CM, EER, t-DCF)...")

# Ensure the evaluation model is loaded from Cell 14
if 'spoof_detector_eval' not in locals() or spoof_detector_eval is None:
    print("Spoof detector model not loaded. Cannot generate reports.")
else:
    # --- Parameters for t-DCF ---
    p_target = 0.05  # Prior probability of target (real=1) - Adjust based on ASVspoof challenge or use case
    c_miss = 1       # Cost of miss (classifying real as fake - FN)
    c_false_alarm = 1 # Cost of false alarm (classifying fake as real - FP)

    # --- Regenerate Predictions ---
    # A fresh, unshuffled generator so scores and labels come from one pass.
    eval_gen_report = data_generator_classifier(eval_data_path, batch_size=classifier_batch_size, shuffle=False)

    # Same step computation as in Cell 14: ceil(samples / batch_size)
    eval_samples_count_report = count_total_files(eval_data_path)
    eval_steps_report = int(np.ceil(eval_samples_count_report / float(classifier_batch_size))) if eval_samples_count_report > 0 else 0

    y_pred_scores = []
    y_true_labels = []

    if eval_steps_report > 0:
        print(f"Generating predictions using {eval_steps_report} steps...")
        for _ in tqdm(range(eval_steps_report), desc="Predicting"):
            try:
                # Generator yields (batch_x, batch_y, batch_weights)
                batch_x, batch_y, _ = next(eval_gen_report)
                if batch_x.size == 0: continue # Skip empty batches if they occur

                # One batch at a time inside a Python loop: predict_on_batch
                # avoids the per-call setup cost that Model.predict pays.
                batch_scores = np.asarray(spoof_detector_eval.predict_on_batch(batch_x)).flatten()
                batch_labels = list(batch_y)

                # Extend both lists only after both values exist, so a failure
                # above cannot leave scores and labels out of step.
                y_pred_scores.extend(batch_scores)
                y_true_labels.extend(batch_labels)
            except StopIteration:
                print("Evaluation generator stopped.")
                break
            except Exception as e:
                print(f"Error during prediction generation: {e}")
                continue

        y_pred_scores = np.array(y_pred_scores).astype(np.float32)
        y_true_labels = np.array(y_true_labels).astype(np.int32)

        # Ensure lengths match (might be off if last batch was incomplete and generator didn't handle it perfectly)
        min_len = min(len(y_pred_scores), len(y_true_labels))
        if min_len == 0:
            print("No predictions or labels collected. Cannot generate reports.")
        else:
            y_pred_scores = y_pred_scores[:min_len]
            y_true_labels = y_true_labels[:min_len]

            # --- Calculations ---
            # Binary predictions for F1 and CM
            y_pred_binary = (y_pred_scores > 0.5).astype(int)

            # F1 Score
            f1 = f1_score(y_true_labels, y_pred_binary)
            print(f"\nF1 Score (threshold 0.5): {f1:.4f}")

            # Confusion Matrix. labels=[0, 1] keeps the matrix 2x2 (matching
            # the tick labels) even if one class never occurs.
            cm = confusion_matrix(y_true_labels, y_pred_binary, labels=[0, 1])
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Fake (0)', 'Real (1)'], yticklabels=['Fake (0)', 'Real (1)'])
            plt.title('Confusion Matrix (Counts)')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(os.path.join(FIGURES_DIR, 'confusion_matrix_counts_wgan_sa.png'))
            plt.show()

            # Confusion Matrix (Percentages). A true class with no samples
            # gets a row of zeros instead of a division by zero.
            cm_sum = np.sum(cm, axis=1, keepdims=True)
            cm_perc = np.divide(cm, cm_sum, out=np.zeros(cm.shape, dtype=float), where=cm_sum != 0) * 100
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm_perc, annot=True, fmt='.2f', cmap='Greens', xticklabels=['Fake (0)', 'Real (1)'], yticklabels=['Fake (0)', 'Real (1)'])
            plt.title('Confusion Matrix (Row Percentages)')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(os.path.join(FIGURES_DIR, 'confusion_matrix_perc_wgan_sa.png'))
            plt.show()


            # EER Calculation: the ROC operating point where the false
            # negative rate and the false positive rate are closest.
            fpr, tpr, thresholds_roc = roc_curve(y_true_labels, y_pred_scores, pos_label=1)
            fnr = 1 - tpr
            eer_index = np.nanargmin(np.abs(fnr - fpr))
            eer_threshold = thresholds_roc[eer_index]
            eer = fpr[eer_index] # Or use (fpr[eer_index] + fnr[eer_index]) / 2
            print(f"EER: {eer:.4f} at threshold {eer_threshold:.4f}")

            # Area under the plotted curve (trapezoidal rule), computed here so
            # the legend does not depend on `results` left behind by Cell 14.
            roc_auc = float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))

            # Plot ROC Curve
            plt.figure(figsize=(8, 6))
            plt.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.4f})')
            plt.plot(fpr, fnr, label='FN Rate')
            plt.plot([0, 1], [0, 1], 'k--') # Random guess line
            plt.scatter(fpr[eer_index], tpr[eer_index], color='red', zorder=5, label=f'EER = {eer:.4f}')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate (FPR)')
            plt.ylabel('True Positive Rate (TPR)')
            plt.title('Receiver Operating Characteristic (ROC)')
            plt.legend(loc="lower right")
            plt.grid(True)
            plt.savefig(os.path.join(FIGURES_DIR, 'roc_curve_wgan_sa.png'))
            plt.show()


            # t-DCF Calculation
            # Define function (can be moved to utils if used often)
            def calculate_t_dcf(y_true, y_scores, p_target, c_miss, c_fa, thresholds):
                """Return the minimum normalized detection cost over `thresholds`.

                A sample is accepted as real when `score >= threshold`. For each
                threshold the cost is

                    c_miss * P_miss * p_target + c_fa * P_fa * (1 - p_target)

                divided by `min(c_miss * p_target, c_fa * (1 - p_target))` when
                that normalizer is positive. Note that, despite the name, only
                this detector's own miss / false-alarm rates enter the formula.

                Args:
                    y_true: Labels, 1 = real (target) and 0 = fake.
                    y_scores: Detector scores, higher meaning "more real".
                    p_target: Prior probability of the real class.
                    c_miss: Cost of rejecting a real sample.
                    c_fa: Cost of accepting a fake sample.
                    thresholds: Candidate decision thresholds to sweep.

                Returns:
                    Tuple `(min_dcf, threshold_at_min)`. When a class has no
                    samples or `thresholds` is empty, `(np.inf, np.nan)` is
                    returned so callers can always unpack two values.
                """
                y_true = np.asarray(y_true).ravel()
                y_scores = np.asarray(y_scores).ravel()
                thresholds = np.asarray(thresholds, dtype=float).ravel()

                real_scores = y_scores[y_true == 1]
                fake_scores = y_scores[y_true == 0]
                num_real = real_scores.size
                num_fake = fake_scores.size

                if num_real == 0 or num_fake == 0:
                    print("Warning: Cannot calculate t-DCF with zero samples in a class.")
                    return np.inf, np.nan
                if thresholds.size == 0:
                    print("Warning: Cannot calculate t-DCF without any thresholds.")
                    return np.inf, np.nan

                def _accepted_counts(class_scores):
                    # For every threshold, how many scores satisfy score >= thr.
                    # NaN scores never satisfy the comparison, so drop them.
                    ordered = np.sort(class_scores[~np.isnan(class_scores)])
                    return ordered.size - np.searchsorted(ordered, thresholds, side='left')

                # Miss: real sample not accepted. False alarm: fake sample accepted.
                p_miss = (num_real - _accepted_counts(real_scores)) / num_real
                p_fa = _accepted_counts(fake_scores) / num_fake

                cost = (c_miss * p_miss * p_target) + (c_fa * p_fa * (1 - p_target))

                # Normalize the cost (skipped when the normalizer is not positive)
                dcf_norm = min(c_miss * p_target, c_fa * (1 - p_target))
                dcf_values = cost / dcf_norm if dcf_norm > 0 else cost

                # First threshold achieving the minimum cost
                min_dcf_index = int(np.argmin(dcf_values))
                return float(dcf_values[min_dcf_index]), float(thresholds[min_dcf_index])

            # Calculate min t-DCF over a range of thresholds (more robust than just EER threshold)
            # Use ROC thresholds, but filter out +/- inf if present
            valid_thresholds = thresholds_roc[np.isfinite(thresholds_roc)]
            min_tdcf, min_tdcf_thresh = calculate_t_dcf(y_true_labels, y_pred_scores, p_target, c_miss, c_false_alarm, valid_thresholds)

            print(f"Minimum t-DCF: {min_tdcf:.4f} at threshold {min_tdcf_thresh:.4f}")

    else:
         print("Skipping report generation because eval_steps_report is zero.")