# Install PyTorch and related libraries 

In [None]:
!pip install torch torchvision torchaudio

#  Setup and Imports

- Sets the PyTorch CUDA allocator option `expandable_segments:True`, which reduces memory fragmentation (and the out-of-memory errors it can cause) when tensor sizes vary between allocations.
- Imports essential libraries for:
  -  **Audio processing**: `librosa`, `soundfile` (`torchaudio` is installed above but not imported)
  -  **Visualization**: `matplotlib`, `librosa.display`
  -  **Deep learning**: `torch`, `torch.nn`, `torch.optim`, `amp` for mixed-precision training
  -  **Model evaluation**: `mean_squared_error`, `cosine` distance
  -  **Data management**: `Dataset`, `DataLoader`, `train_test_split`
  -  **Others**: `tqdm` for progress bars, `PIL` for image operations, `warnings` to silence all Python warnings


In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import checkpoint
from torch.utils.data import Dataset, DataLoader
import torch.cuda.amp as amp
import soundfile as sf
from sklearn.metrics import mean_squared_error
from scipy.spatial.distance import cosine
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

 # GeluGatedLayer

In [None]:
class GeluGatedLayer(nn.Module):
    """Affine projection over the last dimension followed by a GELU activation.

    Despite the name there is no separate gate branch: the output is simply
    ``GELU(W @ src + b)``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        bias: Whether the projection has a bias term.
    """

    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # These attribute names appear in state_dict keys; keep them stable.
        self.input_linear = nn.Linear(input_dim, output_dim, bias=bias)
        self.activation = nn.GELU()

    def forward(self, src):
        projected = self.input_linear(src)
        return self.activation(projected)

# AudioProcessor: Audio Feature Utility Class

This helper class handles key audio preprocessing and visualization tasks.

**Core Methods:**
- `mel_spectrogram_to_audio`:  
  🔁 Converts a Mel spectrogram (in dB) back into waveform audio using:
  - `librosa.db_to_power()` → converts dB to power
  - `mel_to_stft()` → approximates the linear-frequency STFT magnitude
  - `griffinlim()` → reconstructs waveform via iterative phase estimation

- `plot_mel_spectrogram`:  
  📊 Plots the given Mel spectrogram using `librosa.display.specshow()`  
  - Adds axes, color bar, and title for easy interpretation

In [None]:
# AudioProcessor class
class AudioProcessor:
    """Utilities for turning Mel spectrograms back into audio and for plotting them.

    Args:
        sr: Sample rate used for inversion and for the plot's time axis.
        n_mels: Number of Mel bands (stored for reference; inversion infers it
            from the input's shape).
        n_fft: FFT size used when mapping Mel bins back to linear frequency.
        hop_length: Hop size used by Griffin-Lim and by the plot's time axis.
    """

    def __init__(self, sr=22050, n_mels=128, n_fft=2048, hop_length=512):
        self.sr, self.n_mels = sr, n_mels
        self.n_fft, self.hop_length = n_fft, hop_length
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def mel_spectrogram_to_audio(self, mel_spec_db):
        """Invert a dB-scaled Mel spectrogram to a waveform via Griffin-Lim phase estimation."""
        power_mel = librosa.db_to_power(mel_spec_db)
        # Approximate linear-frequency STFT magnitude from the Mel power spectrogram.
        magnitude = librosa.feature.inverse.mel_to_stft(power_mel, sr=self.sr, n_fft=self.n_fft)
        return librosa.griffinlim(magnitude, hop_length=self.hop_length)

    def plot_mel_spectrogram(self, mel_spec_db, title="Mel Spectrogram"):
        """Display ``mel_spec_db`` with time/Mel axes and a dB colour bar."""
        plt.figure(figsize=(12, 6))
        librosa.display.specshow(
            mel_spec_db,
            sr=self.sr,
            hop_length=self.hop_length,
            x_axis='time',
            y_axis='mel',
        )
        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.tight_layout()
        plt.show()

# AudioDataset: PyTorch Dataset Wrapper for Mel Spectrograms

In [None]:
# AudioDataset class
class AudioDataset(Dataset):
    """Map-style dataset that yields each stored spectrogram as a float32 tensor.

    Args:
        mel_spectrograms: Any indexable, sized collection of spectrograms
            (for example the ``(N, 1, H, W)`` array built during preprocessing).
    """

    def __init__(self, mel_spectrograms):
        super().__init__()
        self.mel_spectrograms = mel_spectrograms

    def __len__(self):
        return len(self.mel_spectrograms)

    def __getitem__(self, idx):
        sample = self.mel_spectrograms[idx]
        return torch.FloatTensor(sample)

# ImprovedAudioVAE: Audio VAE with GELU-Gated Layers

This version of `ImprovedAudioVAE` enhances the encoder with **GELU-gated layers** (via `GeluGatedLayer`, a linear projection followed by a GELU activation).

---

###  Initialization (`__init__`)
- Initializes VAE parameters and builds:
  - `Encoder` with GELU gating
  - `Bottleneck`
  - `Decoder` using transposed convolutions

In [None]:
class ImprovedAudioVAE(nn.Module):
    """Convolutional VAE for single-channel spectrogram images.

    Encoder: for each entry of ``conv_filters`` a ``Conv2d -> ReLU -> BatchNorm2d ->
    Dropout`` block followed by a channel-wise ``GeluGatedLayer`` (Linear + GELU over the
    channel axis), then ``Flatten`` and two linear heads producing ``mu`` and ``logvar``.
    Decoder: ``Linear -> LayerNorm -> Unflatten``, a mirrored stack of ``ConvTranspose2d``
    blocks, a ``Sigmoid`` output, and a bilinear resize back to ``input_shape``.

    Args:
        input_shape: ``(H, W)`` of the input spectrograms.
        conv_filters: Output channels of each encoder conv block.
        conv_kernels: Kernel size of each conv block (same length as ``conv_filters``).
        conv_strides: Stride of each conv block (same length as ``conv_filters``).
        latent_dim: Size of the latent vector.
        dropout_rate: Dropout probability after every conv block.
    """

    def __init__(self, input_shape, conv_filters=(16, 32, 64, 64), conv_kernels=(3, 3, 3, 3),
                 conv_strides=(1, 2, 2, 2), latent_dim=128, dropout_rate=0.2):
        super(ImprovedAudioVAE, self).__init__()

        # Plain Python ints so the value is accepted by torch.zeros and F.interpolate.
        self.input_shape = tuple(int(d) for d in input_shape)
        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.latent_dim = latent_dim
        self.dropout_rate = dropout_rate
        self.num_conv_layers = len(conv_filters)
        self.shape_before_bottleneck = None

        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for conv / linear layers."""
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear, GeluGatedLayer)):
            if hasattr(module, 'weight') and module.weight is not None:
                torch.nn.init.xavier_uniform_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                torch.nn.init.constant_(module.bias, 0)

    @staticmethod
    def _apply_channel_gelu(x, gelu_layer):
        """Apply ``gelu_layer`` across the channel axis of an ``(N, C, H, W)`` tensor."""
        # nn.Linear acts on the last dimension, so move channels last and back again.
        return gelu_layer(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def _run_conv_stack(self, x, modules):
        """Run ``modules`` in order, inserting the matching GELU layer after each conv block."""
        gelu_index = {position: k for k, position in enumerate(self.gelu_positions)}
        for i, layer in enumerate(modules):
            x = layer(x)
            k = gelu_index.get(i)
            if k is not None:
                x = self._apply_channel_gelu(x, self.gelu_layers[k])
        return x

    def _build_encoder(self):
        """Build the conv encoder, probe its output shape, and create the mu/logvar heads."""
        layers = []
        # Index in `layers` of the last module of each conv block; a GELU layer follows it.
        self.gelu_positions = []
        self.gelu_layers = nn.ModuleList()
        in_channels = 1

        for filters, kernel, stride in zip(self.conv_filters, self.conv_kernels, self.conv_strides):
            layers.extend([
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=filters,
                    kernel_size=kernel,
                    stride=stride,
                    padding=kernel // 2,
                ),
                nn.ReLU(),
                nn.BatchNorm2d(filters),
                nn.Dropout(self.dropout_rate)
            ])
            self.gelu_positions.append(len(layers) - 1)
            self.gelu_layers.append(GeluGatedLayer(filters, filters))
            in_channels = filters

        # Probe the encoder output shape with a dummy input. Run in eval mode so BatchNorm
        # does not fold the all-zeros dummy batch into its running statistics, and Dropout
        # is a no-op; restore training mode afterwards.
        for module in layers:
            module.eval()
        self.gelu_layers.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, *self.input_shape)
            probe = self._run_conv_stack(probe, layers)
        for module in layers:
            module.train()
        self.gelu_layers.train()
        self.shape_before_bottleneck = probe.shape[1:]

        flat_dim = math.prod(self.shape_before_bottleneck)
        layers.append(nn.Flatten())
        self.mu = nn.Linear(flat_dim, self.latent_dim)
        self.logvar = nn.Linear(flat_dim, self.latent_dim)

        self.encoder_core = nn.Sequential(*layers)
        return self.encoder_core

    def _build_decoder(self):
        """Build the transposed-conv decoder mirroring the encoder."""
        layers = []
        num_neurons = math.prod(self.shape_before_bottleneck)
        layers.extend([
            nn.Linear(self.latent_dim, num_neurons),
            nn.LayerNorm(num_neurons),
            nn.Unflatten(1, self.shape_before_bottleneck)
        ])

        # With output_padding=0, kernel 3 and stride 2 map size n -> 2n - 1, so the result
        # is generally not the input size; `decode` resizes it to `input_shape`.
        in_channels = self.conv_filters[-1]
        for i in reversed(range(1, self.num_conv_layers)):
            layers.extend([
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=self.conv_filters[i-1],
                    kernel_size=self.conv_kernels[i],
                    stride=self.conv_strides[i],
                    padding=self.conv_kernels[i] // 2,
                    output_padding=0
                ),
                nn.ReLU(),
                nn.BatchNorm2d(self.conv_filters[i-1]),
                nn.Dropout(self.dropout_rate)
            ])
            in_channels = self.conv_filters[i-1]

        layers.extend([
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=self.conv_kernels[0],
                stride=self.conv_strides[0],
                padding=self.conv_kernels[0] // 2,
                output_padding=0
            ),
            nn.Sigmoid()
        ])

        return nn.Sequential(*layers)

    def _encoder_forward(self, x):
        """Conv stack (with GELU layers) followed by Flatten: ``(N, 1, H, W) -> (N, flat_dim)``."""
        return self._run_conv_stack(x, self.encoder)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch; 2-D/3-D inputs get channel/batch axes added."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        if torch.is_grad_enabled():
            # Non-reentrant checkpointing trades compute for activation memory and still
            # propagates gradients to the encoder weights when `x` itself does not require
            # grad (the reentrant variant leaves those weights without gradients).
            # NOTE: BatchNorm running stats are updated again during recomputation.
            h = checkpoint.checkpoint(self._encoder_forward, x, use_reentrant=False)
        else:
            h = self._encoder_forward(x)
        mu = self.mu(h)
        logvar = self.logvar(h)
        return mu, logvar

    def decode(self, z):
        """Map latent vectors to ``(N, 1, H, W)`` outputs in [0, 1] resized to ``input_shape``."""
        x = self.decoder(z)
        if x.dim() == 3:
            x = x.unsqueeze(1)
        recon_x = torch.nn.functional.interpolate(x, size=self.input_shape, mode='bilinear', align_corners=False)
        return recon_x

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent, decode; returns ``(recon_x, mu, logvar)``."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

#  AudioDeepfakeGenerator: End-to-End Deepfake Audio Generator Using VAE


This class orchestrates the full workflow for generating deepfake audio using a **Variational Autoencoder (VAE)** trained on Mel spectrograms.

---

### Initialization (`__init__`)
- Sets up:
  - `AudioProcessor` (Mel spectrogram utils)
  - VAE model placeholder
  - `GradScaler` for mixed-precision training
- Target input shape defaults to 128x128

---

### Preprocessing (`preprocess_mel_spectrograms`)
- Pads or crops Mel spectrograms to target shape
- Normalizes values to the range [0, 1] using robust min-max scaling (the 1st and 99th percentiles serve as bounds, and values outside are clipped)
- Stacks valid spectrograms into a 4D NumPy array of shape `(N, 1, H, W)` for training

---

### VAE Training (`train_vae`)
- Initializes and trains an `ImprovedAudioVAE` model:
  - Customizable epochs, learning rate, batch size, KL-β weight
  - Uses gradient accumulation + mixed precision
  - Saves best + periodic model checkpoints
  - Applies early stopping when the average training loss stops improving

---

### VAE Loss (`improved_vae_loss`)
- Combines:
  - Squared-error reconstruction loss (summed over pixels, averaged over the batch)
  - KL divergence loss (scaled by `β`)

---

### Visualization (`plot_training_curve`)
- Plots training loss curve over epochs

---

### Deepfake Generation (`generate_deepfake`)
- Uses a trained VAE to reconstruct + modify original Mel spectrograms
- Adds controllable Gaussian noise in latent space
- Returns reconstructed Mel spectrograms for each noise level

---

### Evaluation (`evaluate_model`)
- Computes the average loss on a test set using the same VAE loss
- Useful for monitoring model generalization

---

### Comparison (`compare_audio`)
- Compares original and fake Mel spectrograms:
  - MSE (Mean Squared Error)
  - Cosine similarity
  - Pearson correlation

---

### Visualization (`visualize_comparison`)
- Plots:
  - Original Mel
  - Generated Mel
  - Absolute difference between them

---

### Audio Output (`save_audio`)
- Converts a Mel spectrogram back to waveform
- Saves output using `soundfile.write`

---


In [None]:

# AudioDeepfakeGenerator class
class AudioDeepfakeGenerator:
    """Orchestrates preprocessing, VAE training, deepfake generation and comparison.

    Every spectrogram is cropped/padded to ``(height, width)`` and robustly min-max
    scaled to [0, 1] before it reaches the VAE.

    Args:
        height: Target number of rows (Mel bins) per spectrogram.
        width: Target number of columns (time frames) per spectrogram.
    """

    def __init__(self, height=128, width=128):
        self.processor = AudioProcessor(n_mels=height, sr=22050, n_fft=2048, hop_length=512)
        self.vae = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        self.target_shape = (height, width)
        # Loss scaling only matters for fp16 autocast on CUDA; disabled on CPU.
        self.scaler = amp.GradScaler(enabled=self.device.type == 'cuda')

    def _fit_to_target_shape(self, mel_spec):
        """Crop or pad a 2-D spectrogram to ``self.target_shape``.

        Rows are handled first, then columns. Surplus rows/columns are dropped from the
        bottom/right; missing ones are appended at the bottom/right and filled with the
        spectrogram's current minimum.
        """
        target_h, target_w = self.target_shape
        if mel_spec.shape[0] > target_h:
            mel_spec = mel_spec[:target_h, :]
        elif mel_spec.shape[0] < target_h:
            pad_height = target_h - mel_spec.shape[0]
            mel_spec = np.pad(mel_spec, ((0, pad_height), (0, 0)), 'constant', constant_values=mel_spec.min())

        if mel_spec.shape[1] > target_w:
            mel_spec = mel_spec[:, :target_w]
        elif mel_spec.shape[1] < target_w:
            pad_width = target_w - mel_spec.shape[1]
            mel_spec = np.pad(mel_spec, ((0, 0), (0, pad_width)), 'constant', constant_values=mel_spec.min())
        return mel_spec

    def preprocess_mel_spectrograms(self, file_data):
        """Crop/pad and normalize spectrograms for training.

        Args:
            file_data: Sized iterable of ``(file_path, mel_spec)`` pairs, ``mel_spec`` 2-D.

        Returns:
            ``np.ndarray`` of shape ``(N, 1, H, W)`` with values in [0, 1].

        Raises:
            ValueError: If no entry could be processed.
        """
        mel_spectrograms = []
        print(f"Processing {len(file_data)} mel spectrogram image files...")

        for file_path, mel_spec in file_data:
            try:
                mel_spec = self._fit_to_target_shape(mel_spec)

                if mel_spec.shape != self.target_shape:
                    print(f"Warning: Shape mismatch for {os.path.basename(file_path)}: "
                          f"got {mel_spec.shape}, expected {self.target_shape}. Skipping.")
                    continue

                # Robust min-max scaling: 1st/99th percentiles act as bounds so a few
                # outlier pixels do not compress the rest of the range.
                mel_min, mel_max = np.percentile(mel_spec, [1, 99])
                if mel_max <= mel_min:
                    print(f"Warning: Invalid mel spectrogram range for {os.path.basename(file_path)}. Skipping.")
                    continue

                mel_norm = np.clip((mel_spec - mel_min) / (mel_max - mel_min + 1e-8), 0, 1)
                mel_spectrograms.append(mel_norm)

            except Exception as e:
                # Best effort: one bad input should not abort the whole dataset.
                print(f"Error processing {file_path}: {e}")
                continue

        if len(mel_spectrograms) > 0:
            mel_spectrograms = np.stack(mel_spectrograms)[:, None, :, :]
            print(f"Successfully processed {len(mel_spectrograms)} files")
            print(f"Final data shape: {mel_spectrograms.shape}")
            if mel_spectrograms.shape[2:] != self.target_shape:
                raise ValueError(f"Stacked spectrograms have incorrect shape: {mel_spectrograms.shape[2:]}, expected {self.target_shape}")
            return mel_spectrograms
        else:
            raise ValueError("No mel spectrogram image files could be processed successfully!")

    def train_vae(self, mel_spectrograms, epochs=5, batch_size=1, learning_rate=1e-4, beta=0.1, early_stop_patience=60, accum_steps=2):
        """Train a fresh ``ImprovedAudioVAE`` and keep the weights with the best training loss.

        Args:
            mel_spectrograms: Array of shape ``(N, 1, H, W)`` from ``preprocess_mel_spectrograms``.
            epochs: Maximum number of epochs.
            batch_size: Samples per micro-batch.
            learning_rate: Initial AdamW learning rate (cosine-annealed to 1e-6).
            beta: Weight applied to the KL term.
            early_stop_patience: Epochs without improvement of the average training loss
                before stopping.
            accum_steps: Micro-batches accumulated per optimizer step.

        Writes checkpoints under ``models/`` and leaves ``self.vae`` holding the best weights
        saved during this run.
        """
        torch.cuda.empty_cache()
        input_shape_vae = mel_spectrograms.shape[2:]
        print(f"VAE input shape (H, W): {input_shape_vae}")

        # Create models directory early
        os.makedirs("models", exist_ok=True)
        best_path = 'models/best_vae_model.pth'
        # Tracks whether *this* run wrote best_path, so a stale file left by an earlier run
        # (possibly for a different architecture) is never reloaded.
        saved_best_this_run = False

        self.vae = ImprovedAudioVAE(
            input_shape=input_shape_vae,
            conv_filters=(16, 32, 64, 256),
            conv_kernels=(3, 3, 3, 3),
            conv_strides=(1, 2, 2, 2),
            latent_dim=128,
            dropout_rate=0.2
        ).to(self.device)

        dataset = AudioDataset(mel_spectrograms)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

        optimizer = optim.AdamW(self.vae.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
        best_loss = float('inf')
        patience_counter = 0

        self.vae.train()
        train_losses = []

        for epoch in range(epochs):
            total_loss = 0
            total_recon_loss = 0
            total_kl_loss = 0
            progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}')
            optimizer.zero_grad(set_to_none=True)

            for batch_idx, data in enumerate(progress_bar):
                data = data.to(self.device, non_blocking=True)

                with amp.autocast(enabled=self.device.type == 'cuda'):
                    recon_batch, mu, logvar = self.vae(data)
                    recon_loss, kl_loss = self.improved_vae_loss(recon_batch, data, mu, logvar, beta=beta)
                    # Scale down so accumulated gradients match one large-batch step.
                    loss = (recon_loss + kl_loss) / accum_steps

                self.scaler.scale(loss).backward()

                # Step every `accum_steps` micro-batches, and on the last (possibly partial) group.
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(dataloader):
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                    optimizer.zero_grad(set_to_none=True)

                total_loss += loss.item() * accum_steps
                total_recon_loss += recon_loss.item()
                total_kl_loss += kl_loss.item()

                progress_bar.set_postfix({
                    'loss': f'{loss.item() * accum_steps:.4f}',
                    'recon': f'{recon_loss.item():.4f}',
                    'kl': f'{kl_loss.item():.4f}'
                })

                del data, recon_batch, mu, logvar, loss
                torch.cuda.empty_cache()

            avg_loss = total_loss / len(dataloader)
            avg_recon_loss = total_recon_loss / len(dataloader)
            avg_kl_loss = total_kl_loss / len(dataloader)

            train_losses.append(avg_loss)
            scheduler.step()

            # Log loss details to diagnose issues
            print(f'Epoch {epoch+1:3d} | Avg Loss: {avg_loss:.4f} | '
                  f'Recon Loss: {avg_recon_loss:.4f} | KL Loss: {avg_kl_loss:.4f} | '
                  f'Best Loss: {best_loss:.4f}')

            # Save checkpoint if the average training loss improves
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                try:
                    torch.save(self.vae.state_dict(), best_path)
                    saved_best_this_run = True
                    print(f"Saved best model checkpoint at epoch {epoch+1}")
                except Exception as e:
                    print(f"Error saving checkpoint: {e}")
            else:
                patience_counter += 1

            # Save periodic checkpoint every 2 epochs
            if (epoch + 1) % 2 == 0:
                try:
                    torch.save(self.vae.state_dict(), f'models/vae_model_epoch_{epoch+1}.pth')
                    print(f"Saved periodic checkpoint at epoch {epoch+1}")
                except Exception as e:
                    print(f"Error saving periodic checkpoint: {e}")

            if epoch % 2 == 0 or patience_counter == 0:
                print(f'Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | '
                      f'Recon: {avg_recon_loss:.4f} | KL: {avg_kl_loss:.4f} | '
                      f'LR: {optimizer.param_groups[0]["lr"]:.6f} | '
                      f'Patience: {patience_counter}/{early_stop_patience}')

            if patience_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

            torch.cuda.empty_cache()

        # No best checkpoint from this run (e.g. every loss was NaN, or saving failed):
        # store the final weights instead of trusting whatever file may already exist.
        if not saved_best_this_run:
            try:
                torch.save(self.vae.state_dict(), best_path)
                saved_best_this_run = True
                print("Saved final model checkpoint")
            except Exception as e:
                print(f"Error saving final checkpoint: {e}")

        # Load the best model written by this run, otherwise continue with current state
        if saved_best_this_run:
            try:
                state_dict = torch.load(best_path, map_location=self.device, weights_only=True)
                self.vae.load_state_dict(state_dict)
                print("Best VAE model loaded.")
            except Exception as e:
                print(f"Error loading checkpoint: {e}. Continuing with current model state.")
        else:
            print("No checkpoint found. Using current model state.")

        print("VAE training completed!")
        self.plot_training_curve(train_losses)
        torch.cuda.empty_cache()

    def improved_vae_loss(self, recon_x, x, mu, logvar, beta):
        """Return ``(recon_loss, beta * kl_loss)``, each summed per sample and averaged over the batch."""
        recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum') / x.size(0)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        return recon_loss, beta * kl_loss

    def plot_training_curve(self, train_losses):
        """Display the per-epoch training loss, then release the figure."""
        fig = plt.figure(figsize=(10, 6))
        plt.plot(train_losses, label='Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('VAE Training Progress')
        plt.legend()
        plt.grid(True)
        plt.show()
        plt.close(fig)

    def generate_deepfake(self, original_mel_spec, noise_levels=(0.0, 0.1)):
        """Reconstruct ``original_mel_spec`` through the VAE with latent-space noise.

        The input is cropped/padded to ``self.target_shape`` exactly like training data,
        normalized with its own 1st/99th percentiles, encoded to ``mu``, perturbed with
        ``N(0, noise_level^2)`` noise, decoded, and mapped back to the input's value range.

        Args:
            original_mel_spec: 2-D spectrogram.
            noise_levels: Iterable of latent noise standard deviations.

        Returns:
            List of ``(fake_mel, noise_level)`` tuples, or ``None`` when the model is not
            trained or the input has no usable value range.
        """
        if self.vae is None:
            print("Error: VAE model not trained. Please train the model first.")
            return None

        self.vae.eval()
        deepfake_results = []

        with torch.no_grad():
            if original_mel_spec.shape != self.target_shape:
                print(f"Warning: Input mel shape {original_mel_spec.shape} does not match VAE target shape {self.target_shape}. Cropping/padding.")
                original_mel_spec = self._fit_to_target_shape(original_mel_spec)

            mel_min, mel_max = np.percentile(original_mel_spec, [1, 99])
            if mel_max <= mel_min:
                print("Error: Invalid mel spectrogram range for generation.")
                return None

            mel_norm = np.clip((original_mel_spec - mel_min) / (mel_max - mel_min + 1e-8), 0, 1)
            input_tensor = torch.FloatTensor(mel_norm[None, None, :, :]).to(self.device, non_blocking=True)

            # Use the posterior mean (no sampling) so noise_level=0 is a deterministic reconstruction.
            mu, _ = self.vae.encode(input_tensor)
            for noise_level in noise_levels:
                z_modified = mu + torch.randn_like(mu) * noise_level
                fake_mel_norm = np.clip(self.vae.decode(z_modified).cpu().numpy().squeeze(), 0, 1)
                fake_mel = fake_mel_norm * (mel_max - mel_min) + mel_min
                deepfake_results.append((fake_mel, noise_level))
                torch.cuda.empty_cache()

            del input_tensor, mu
            torch.cuda.empty_cache()

        return deepfake_results

    def evaluate_model(self, test_spectrograms, beta=0.1):
        """Return the average per-sample VAE loss on ``test_spectrograms`` (``None`` if unavailable)."""
        if self.vae is None:
            print("Error: VAE model not trained for evaluation.")
            return None

        self.vae.eval()
        total_loss = 0
        total_samples = 0
        with torch.no_grad():
            dataloader = DataLoader(AudioDataset(test_spectrograms), batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
            for batch in dataloader:
                batch_tensor = batch.to(self.device, non_blocking=True)
                with amp.autocast(enabled=self.device.type == 'cuda'):
                    recon_batch, mu, logvar = self.vae(batch_tensor)
                    recon_loss, kl_loss = self.improved_vae_loss(recon_batch, batch_tensor, mu, logvar, beta=beta)
                # The loss is a per-sample mean, so weight it by batch size before summing.
                total_loss += (recon_loss + kl_loss).item() * len(batch)
                total_samples += len(batch)
                del batch_tensor, recon_batch, mu, logvar
                torch.cuda.empty_cache()

        if total_samples > 0:
            avg_loss = total_loss / total_samples
            print(f"Evaluation Loss: {avg_loss:.4f}")
            return avg_loss
        else:
            print("No test samples available for evaluation.")
            return None

    def compare_audio(self, original_mel, fake_mel):
        """Return MSE, cosine similarity and Pearson correlation over the overlapping region."""
        min_height = min(original_mel.shape[0], fake_mel.shape[0])
        min_width = min(original_mel.shape[1], fake_mel.shape[1])
        orig_flat = original_mel[:min_height, :min_width].flatten()
        fake_flat = fake_mel[:min_height, :min_width].flatten()

        mse = mean_squared_error(orig_flat, fake_flat)
        # Cosine/correlation are undefined for a constant vector.
        if np.all(orig_flat == orig_flat[0]) or np.all(fake_flat == fake_flat[0]):
            cosine_sim = np.nan
            correlation = np.nan
        else:
            cosine_sim = 1 - cosine(orig_flat, fake_flat)
            correlation = np.corrcoef(orig_flat, fake_flat)[0, 1]

        print("=== Audio Comparison Results ===")
        print(f"Mean Squared Error: {mse:.4f}")
        print(f"Cosine Similarity: {cosine_sim:.4f}")
        print(f"Correlation: {correlation:.4f}")
        return {'mse': mse, 'cosine_similarity': cosine_sim, 'correlation': correlation}

    def visualize_comparison(self, original_mel, fake_mel):
        """Show original, generated and absolute-difference spectrograms side by side."""
        min_height = min(original_mel.shape[0], fake_mel.shape[0])
        min_width = min(original_mel.shape[1], fake_mel.shape[1])
        original_mel_clipped = original_mel[:min_height, :min_width]
        fake_mel_clipped = fake_mel[:min_height, :min_width]

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        panels = [
            (original_mel_clipped, 'Original Mel Spectrogram', 'viridis'),
            (fake_mel_clipped, 'Generated Deepfake Mel Spectrogram', 'viridis'),
            (np.abs(original_mel_clipped - fake_mel_clipped), 'Absolute Difference', 'hot'),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            im = ax.imshow(image, aspect='auto', origin='lower', cmap=cmap)
            ax.set_title(title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Mel Frequency')
            plt.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    def save_audio(self, mel_spec, filename, sr=22050):
        """Invert ``mel_spec`` to a waveform and write it to ``filename``."""
        audio = self.processor.mel_spectrogram_to_audio(mel_spec)
        sf.write(filename, audio, sr)
        print(f"Audio saved as: {filename}")

# Main Function

In [None]:
def get_image_files(directory):
    """Load every PNG/JPEG under ``directory`` (recursively) as a grayscale float32 array.

    Files are visited in sorted order so downstream seeded splits are reproducible
    across filesystems. Unreadable images are reported and skipped.

    Args:
        directory: Root directory to search.

    Returns:
        List of ``(file_path, array)`` tuples; empty if ``directory`` does not exist.
    """
    image_files = []
    valid_extensions = ('.png', '.jpg', '.jpeg')
    if not os.path.exists(directory):
        print(f"Directory {directory} does not exist")
        return image_files

    for root, dirs, files in os.walk(directory):
        # Sorting `dirs` in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for file in sorted(files):
            if not file.lower().endswith(valid_extensions):
                continue
            file_path = os.path.join(root, file)
            try:
                # Context manager guarantees the file handle is closed.
                with Image.open(file_path) as img:
                    img_array = np.array(img.convert('L'), dtype=np.float32)
            except Exception as e:
                # Best effort: report corrupt/unreadable images and keep going.
                print(f"Error loading image {file_path}: {e}")
                continue
            image_files.append((file_path, img_array))
    return image_files

def main():
    """Run the pipeline: load images, split, preprocess, train, evaluate, generate and compare."""
    torch.cuda.empty_cache()
    generator = AudioDeepfakeGenerator(height=128, width=128)
    base_path = "/kaggle/input/vae128-ctrsdd"
    train_real_path = os.path.join(base_path, "spec_128")

    train_files = get_image_files(train_real_path)
    if not train_files:
        print("No image files found in the training directory.")
        return

    # Hold out 20% of the data, then split it half-and-half so the validation and test
    # sets are disjoint and the two evaluation steps measure different data.
    train_files, holdout_files = train_test_split(train_files, test_size=0.2, random_state=42)
    if len(holdout_files) >= 2:
        val_files, test_files = train_test_split(holdout_files, test_size=0.5, random_state=42)
    else:
        # Too few held-out samples to split further.
        val_files = test_files = holdout_files

    print(f"Found {len(train_files)} training files")
    print(f"Found {len(val_files)} validation files")
    print(f"Found {len(test_files)} testing files")

    print("\nStep 1: Processing training mel spectrograms...")
    try:
        train_spectrograms = generator.preprocess_mel_spectrograms(train_files)
    except ValueError as e:
        print(f"Error: {e}")
        return

    print("\nStep 2: Processing validation mel spectrograms...")
    try:
        val_spectrograms = generator.preprocess_mel_spectrograms(val_files)
    except ValueError as e:
        print(f"Error: {e}")
        return

    print("\nStep 3: Processing testing mel spectrograms...")
    try:
        test_spectrograms = generator.preprocess_mel_spectrograms(test_files)
    except ValueError as e:
        print(f"Error: {e}")
        return

    print(f"Training samples: {len(train_spectrograms)}")
    print(f"Validation samples: {len(val_spectrograms)}")
    print(f"Testing samples: {len(test_spectrograms)}")

    print("\nStep 4: Training convolutional VAE...")
    generator.train_vae(
        mel_spectrograms=train_spectrograms,
        epochs=250,
        batch_size=32,
        learning_rate=1e-4,
        beta=0.1,
        early_stop_patience=60,
        accum_steps=2
    )

    print("\nStep 5: Evaluating model on validation set...")
    generator.evaluate_model(val_spectrograms, beta=0.1)

    print("\nStep 6: Evaluating model on test set...")
    generator.evaluate_model(test_spectrograms, beta=0.1)

    target_h, target_w = generator.target_shape

    def fit_to_target(mel):
        """Crop/pad rows then columns to the generator's target shape, filling with the minimum."""
        if mel.shape[0] > target_h:
            mel = mel[:target_h, :]
        elif mel.shape[0] < target_h:
            pad_height = target_h - mel.shape[0]
            mel = np.pad(mel, ((0, pad_height), (0, 0)), 'constant', constant_values=mel.min())
        if mel.shape[1] > target_w:
            mel = mel[:, :target_w]
        elif mel.shape[1] < target_w:
            pad_width = target_w - mel.shape[1]
            mel = np.pad(mel, ((0, 0), (0, pad_width)), 'constant', constant_values=mel.min())
        return mel

    print("\nStep 7: Preparing reference for deepfake generation...")
    # Never index past the end of a small test set.
    num_references = min(11, len(test_files))
    saved_files = []
    reference_metrics = []
    num_generated = 0

    for j in range(num_references):
        reference_file, reference_mel = test_files[j]
        reference_mel = fit_to_target(reference_mel)

        print(f"Reference file: {os.path.basename(reference_file)}, shape: {reference_mel.shape}")
        generator.processor.plot_mel_spectrogram(reference_mel, "Original Reference Mel Spectrogram")

        print("\nStep 8: Generating high-quality deepfakes...")
        noise_levels = [0.0]
        deepfake_results = generator.generate_deepfake(reference_mel, noise_levels=noise_levels)
        # generate_deepfake may return None (untrained model / invalid input) instead of a list.
        if not isinstance(deepfake_results, list) or not deepfake_results:
            print(f"Skipping {os.path.basename(reference_file)}: no deepfakes were generated.")
            continue

        for i, (fake_mel, noise_level) in enumerate(deepfake_results):
            print(f"Generated deepfake variation {i+1}/{len(deepfake_results)} (noise={noise_level})...")
            fake_name = f"high_quality_deepfake_{j+1}_noise_{noise_level}.wav"
            generator.save_audio(fake_mel, fake_name)
            saved_files.append(fake_name)
            results = generator.compare_audio(reference_mel, fake_mel)
            print(f"  MSE: {results['mse']:.4f}, Cosine Sim: {results['cosine_similarity']:.4f}")
        num_generated += len(deepfake_results)

        best_idx = 0
        best_fake_mel, best_noise = deepfake_results[best_idx]
        generator.processor.plot_mel_spectrogram(best_fake_mel, "fake Reference Mel Spectrogram")
        print(f"\nStep 9: Detailed analysis of best deepfake (noise={best_noise})...")
        detailed_results = generator.compare_audio(reference_mel, best_fake_mel)
        reference_metrics.append((os.path.basename(reference_file), detailed_results))

        print("\nStep 10: Visualizing results...")
        generator.visualize_comparison(reference_mel, best_fake_mel)
        # One file per reference; a fixed name would be overwritten on every iteration.
        reference_name = f"original_reference_improved_{j+1}.wav"
        generator.save_audio(reference_mel, reference_name)
        saved_files.append(reference_name)

    print("\nImproved deepfake generation completed!")
    print(f"Generated {num_generated} high-quality variations")
    print("Files saved:")
    for name in saved_files:
        print(f"- {name}")

    if reference_metrics:
        print("\nPerformance Summary:")
        for ref_name, metrics in reference_metrics:
            print(f"Best deepfake metrics ({ref_name}):")
            print(f"  MSE: {metrics['mse']:.6f}")
            print(f"  Cosine Similarity: {metrics['cosine_similarity']:.6f}")
            print(f"  Correlation: {metrics['correlation']:.6f}")
    else:
        print("\nNo deepfakes were generated; no performance summary available.")

    torch.cuda.empty_cache()

if __name__ == "__main__":
    main()