# In [None]:
import glob
import os

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

from dataset_load import librispeech_to_mel  # Your dataset preprocessing function
from model import MGAN  # Import the MGAN class

def librispeech_to_mel(root_dir, target_sr=16000, n_mels=64, n_fft=1024, hop_length=512, target_length=128):
    """
    Convert raw LibriSpeech audio files to Mel spectrograms and return a TensorDataset.

    Every ``*.flac`` file found recursively under ``root_dir`` is loaded at
    ``target_sr``, peak-normalized, converted to a dB-scaled Mel spectrogram
    and padded or trimmed along the time axis to exactly ``target_length``
    frames. Files that are empty, silent, or fail to process are reported on
    stdout and skipped.

    Args:
        root_dir: Directory searched recursively for ``*.flac`` files.
        target_sr: Sample rate passed to ``librosa.load``.
        n_mels: Number of Mel bands.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.
        target_length: Number of time frames every spectrogram is padded or
            trimmed to.

    Returns:
        A ``TensorDataset`` wrapping a single float32 tensor of shape
        ``[num_samples, 1, n_mels, target_length]``.

    Raises:
        ValueError: If no FLAC files are found under ``root_dir``, or if none
            of them could be converted.
    """
    # Sorted so the dataset order does not depend on file-system listing order.
    flac_files = sorted(glob.glob(os.path.join(root_dir, '**', '*.flac'), recursive=True))

    if not flac_files:
        raise ValueError(f"No FLAC files found in directory: {root_dir}")

    mel_spectrograms = []  # One [1, n_mels, target_length] tensor per usable file

    for file in flac_files:
        # Best-effort: one unreadable file must not abort the whole conversion.
        try:
            # Load the audio file
            audio, sr = librosa.load(file, sr=target_sr)

            # Normalize audio to range [-1, 1]. An empty or all-zero signal has
            # no usable peak; dividing by it would fill the clip with NaNs.
            peak = np.max(np.abs(audio)) if audio.size else 0.0
            if peak == 0:
                print(f"Skipping empty or silent file {file}")
                continue
            audio = audio / peak

            # Convert to Mel spectrogram
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
            )

            # Convert to dB scale (relative to the clip's maximum, so 0 dB is the peak)
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

            # Ensure the Mel spectrogram has a fixed length
            n_frames = mel_spec_db.shape[1]
            if n_frames < target_length:
                # Pad with the clip's own floor value. Padding with 0 would
                # insert frames at 0 dB, i.e. the loudest possible level.
                mel_spec_db = np.pad(
                    mel_spec_db,
                    ((0, 0), (0, target_length - n_frames)),
                    mode='constant',
                    constant_values=mel_spec_db.min(),
                )
            else:
                # Trim if longer
                mel_spec_db = mel_spec_db[:, :target_length]

            # Add channel dimension and append to list
            mel_tensor = torch.tensor(mel_spec_db, dtype=torch.float32).unsqueeze(0)
            mel_spectrograms.append(mel_tensor)

        except Exception as e:
            print(f"Error processing file {file}: {e}")

    if not mel_spectrograms:
        raise ValueError("No valid audio files were processed into Mel spectrograms.")

    # Stack into a single tensor of shape [num_samples, 1, n_mels, target_length]
    mel_spectrograms = torch.stack(mel_spectrograms)
    return TensorDataset(mel_spectrograms)


# In [None]:
# Hyperparameters
num_z = 100
beta = 0.5
num_gens = 10
batch_size = 16
z_prior = "gaussian"
learning_rate = 0.0002
num_epochs = 50

# Spectrogram dimensions
mel_bins = 64
num_frames = 128
num_channels = 1
img_size = (mel_bins, num_frames, num_channels)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def main():
    """Preprocess LibriSpeech into Mel spectrograms, then train and save the MGAN.

    The workflow lives in a function because it relies on early ``return`` to
    stop after a reported failure; a bare ``return`` at module level is a
    syntax error. Failures in preprocessing, checkpoint loading and training
    are printed and end the run early; a failure while saving is only printed.

    Raises:
        ValueError: If preprocessing succeeds but yields an empty dataset.
    """
    # Dataset preprocessing
    data_dir = os.path.join("data", "LibriSpeech", "LibriSpeech", "dev-clean")
    print("Preprocessing dataset into Mel spectrograms...")
    try:
        mel_dataset = librispeech_to_mel(data_dir, target_sr=16000, n_mels=mel_bins, target_length=num_frames)
        print(f"Dataset size: {len(mel_dataset)}")
    except Exception as e:
        print(f"Error during dataset preprocessing: {e}")
        return

    # Check if dataset is empty
    if len(mel_dataset) == 0:
        raise ValueError("The dataset is empty. Please check your data directory and preprocessing function.")

    # Save visualizations of the first 5 Mel spectrograms
    # NOTE(review): save_mel_spectrograms is neither defined nor imported in
    # the visible part of this notebook - confirm another cell provides it.
    visualization_dir = "mel_spectrograms"
    print("Saving Mel spectrogram visualizations...")
    save_mel_spectrograms(mel_dataset, visualization_dir, num_to_visualize=5)

    # DataLoader
    dataloader = DataLoader(mel_dataset, batch_size=batch_size, shuffle=True)

    # Initialize MGAN
    sample_dir = "samples"
    os.makedirs(sample_dir, exist_ok=True)
    mgan_model = MGAN(
        num_z=num_z,
        beta=beta,
        num_gens=num_gens,
        batch_size=batch_size,
        z_prior=z_prior,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        img_size=img_size,
        num_gen_feature_maps=64,
        num_dis_feature_maps=64,
        sample_dir=sample_dir,
        device=device,
    )

    # Load pre-trained model if available
    model_path = "mgan_model.pth"
    if os.path.exists(model_path):
        try:
            mgan_model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Model loaded from {model_path}")
        except Exception as e:
            # Stop here so a checkpoint that failed to load is not overwritten
            # by the save step at the end of the run.
            print(f"Error loading model: {e}")
            return

    # Train the model
    print("Starting training...")
    try:
        mgan_model.fit(dataloader)
        print("Training completed.")
    except Exception as e:
        print(f"Error during training: {e}")
        return

    # Save the trained model
    try:
        torch.save(mgan_model.state_dict(), model_path)
        print(f"Model saved to {model_path}")
    except Exception as e:
        print(f"Error saving model: {e}")


if __name__ == "__main__":
    main()