In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np

import torchvision.utils as vutils
import os

import librosa
import numpy as np
from glob import glob
from torch.utils.data import DataLoader, TensorDataset  

In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
manualSeed = 6789
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(manualSeed)

# Generator shared layers
class GeneratorSharedLayers(nn.Module):
    def __init__(self, ngf, nc):
        super(GeneratorSharedLayers, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

# Generator with unique input layer and shared layers
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc, shared_layers, mel_bins, time_frames):
        super(Generator, self).__init__()
        self.ngf = ngf
        self.mel_bins = mel_bins
        self.time_frames = time_frames
        self.input_layer = nn.Sequential(
            nn.Linear(nz, ngf * 8 * (mel_bins // 8) * (time_frames // 8)),
            nn.BatchNorm1d(ngf * 8 * (mel_bins // 8) * (time_frames // 8)),
            nn.ReLU(True)
        )
        self.shared_layers = shared_layers

    def forward(self, input):
        x = self.input_layer(input)
        x = x.view(-1, self.ngf * 8, self.mel_bins // 8, self.time_frames // 8)
        x = self.shared_layers(x)
        return x

# Discriminator shared layers
class DiscriminatorSharedLayers(nn.Module):
    def __init__(self, ndf, nc):
        super(DiscriminatorSharedLayers, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, input):
        return self.main(input)

# Discriminator with shared layers and unique output layers
class Discriminator(nn.Module):
    def __init__(self, ndf, nc, shared_layers, num_gens, mel_bins, time_frames):
        super(Discriminator, self).__init__()
        self.shared_layers = shared_layers
        self.output_bin = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, (mel_bins // 16, time_frames // 16), 1, 0, bias=False),
            nn.Sigmoid()
        )
        self.output_mul = nn.Sequential(
            nn.Conv2d(ndf * 8, num_gens, (mel_bins // 16, time_frames // 16), 1, 0, bias=False)
        )

    def forward(self, input):
        x = self.shared_layers(input)
        output_bin = self.output_bin(x).view(-1, 1).squeeze(1)
        output_mul = self.output_mul(x).squeeze()
        return output_bin, output_mul

# MGAN class encapsulating the training loop
class MGAN:
    def __init__(self, num_z, beta, num_gens, batch_size, z_prior, learning_rate,
                 num_epochs, img_size, num_gen_feature_maps, num_dis_feature_maps,
                 sample_dir, device):
        self.num_z = num_z
        self.beta = beta
        self.num_gens = num_gens
        self.batch_size = batch_size
        self.z_prior = z_prior
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.img_size = img_size
        self.ngf = num_gen_feature_maps
        self.ndf = num_dis_feature_maps
        self.sample_dir = sample_dir
        self.device = device

        self.mel_bins, self.num_frames, self.num_channels = img_size
        self.history = {'d_loss': [], 'g_loss': []}

        self._build_model()

    def _build_model(self):
        mel_bins, time_frames, num_channels = self.img_size

        self.shared_gen_layers = GeneratorSharedLayers(self.ngf, num_channels).to(self.device)
        self.generators = nn.ModuleList([
            Generator(self.num_z, self.ngf, num_channels, self.shared_gen_layers, mel_bins, time_frames).to(self.device)
            for _ in range(self.num_gens)
        ])
        self.shared_dis_layers = DiscriminatorSharedLayers(self.ndf, num_channels).to(self.device)
        self.discriminator = Discriminator(self.ndf, num_channels, self.shared_dis_layers, self.num_gens, mel_bins, time_frames).to(self.device)

        self.optimizerD = optim.Adam(
            list(self.discriminator.parameters()) + list(self.shared_dis_layers.parameters()),
            lr=self.learning_rate, betas=(0.5, 0.999)
        )
        gen_params = [param for gen in self.generators for param in gen.input_layer.parameters()]
        gen_params += list(self.shared_gen_layers.parameters())
        self.optimizerG = optim.Adam(gen_params, lr=self.learning_rate, betas=(0.5, 0.999))

        self.criterion_bin = nn.BCELoss()
        self.criterion_mul = nn.CrossEntropyLoss()

    def fit(self, trainloader):
        fixed_noise = self._sample_z(self.num_gens * 16).to(self.device)

        real_label = 1.0
        fake_label = 0.0

        for epoch in range(self.num_epochs):
            for i, data in enumerate(trainloader):
                real_images = data[0].to(self.device)
                b_size = real_images.size(0)
                label_real = torch.full((b_size,), real_label, device=self.device)
                label_fake = torch.full((b_size,), fake_label, device=self.device)

                output_bin_real, _ = self.discriminator(real_images)
                d_bin_real_loss = self.criterion_bin(output_bin_real, label_real)

                # Generate fake images and labels
                fake_images = []
                gen_labels = []
                for idx, gen in enumerate(self.generators):
                    z = self._sample_z(b_size // self.num_gens).to(self.device)
                    fake_imgs = gen(z)
                    fake_images.append(fake_imgs)
                    gen_labels.append(torch.full((fake_imgs.size(0),), idx, dtype=torch.long, device=self.device))

                fake_images = torch.cat(fake_images, 0)
                gen_labels = torch.cat(gen_labels, 0)

                # Get discriminator output for fake images
                output_bin_fake, output_mul_fake = self.discriminator(fake_images.detach())

                # Create labels with the correct size
                output_size = output_bin_fake.size()  # Match the shape of the output
                label_real = torch.full(output_size, real_label, device=self.device, dtype=torch.float)
                label_fake = torch.full(output_size, fake_label, device=self.device, dtype=torch.float)

                # Compute losses
                d_bin_fake_loss = self.criterion_bin(output_bin_fake, label_fake)
                d_mul_loss = self.criterion_mul(output_mul_fake.view(-1, self.num_gens), gen_labels)
                self.optimizerD.step()

                for gen in self.generators:
                    gen.zero_grad()
                self.shared_gen_layers.zero_grad()

                label_real = torch.full((fake_images.size(0),), real_label, device=self.device)
                output_bin_fake, output_mul_fake = self.discriminator(fake_images)
                g_bin_loss = self.criterion_bin(output_bin_fake, label_real)
                g_mul_loss = self.criterion_mul(output_mul_fake.view(-1, self.num_gens), gen_labels) * self.beta

                g_loss = g_bin_loss + g_mul_loss
                g_loss.backward()
                self.optimizerG.step()

                self.history['d_loss'].append(d_loss.item())
                self.history['g_loss'].append(g_loss.item())

            print(f"[{epoch+1}/{self.num_epochs}] d_loss: {d_loss.item():.4f} | g_loss: {g_loss.item():.4f}")
            if (epoch+1) % 5 == 0:
                self._save_samples(epoch+1, fixed_noise)

        self._plot_history()

    def _sample_z(self, size):
        if self.z_prior == "uniform":
            return torch.rand(size, self.num_z) * 2 - 1
        return torch.randn(size, self.num_z)

    def _save_samples(self, epoch, fixed_noise):
        with torch.no_grad():
            fake_images = []
            for idx, gen in enumerate(self.generators):
                noise = fixed_noise[idx * 16:(idx + 1) * 16].to(self.device)
                gen.eval()
                fake_imgs = gen(noise)
                gen.train()
                fake_images.append(fake_imgs)

            fake_images = torch.cat(fake_images, 0)
            fake_images = (fake_images + 1) / 2.0
            os.makedirs(self.sample_dir, exist_ok=True)
            sample_path = os.path.join(self.sample_dir, f"epoch_{epoch:04d}.png")
            vutils.save_image(fake_images, sample_path, nrow=16, padding=2, normalize=True)
            print(f"Saved samples to {sample_path}")

    def _plot_history(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.history['d_loss'], label="D Loss")
        plt.plot(self.history['g_loss'], label="G Loss")
        plt.title("Loss During Training")
        plt.xlabel("Iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()


In [7]:
def librispeech_to_mel(root_dir, target_sr=16000, n_mels=64, n_fft=1024, hop_length=512, target_length=128):
    """
    Convert raw LibriSpeech audio files to Mel spectrograms and return a TensorDataset.
    """
    from glob import glob

    flac_files = glob(os.path.join(root_dir, '**', '*.flac'), recursive=True)
    mel_spectrograms = []  # Initialize list to store Mel spectrograms

    if len(flac_files) == 0:
        raise ValueError(f"No FLAC files found in directory: {root_dir}")

    for file in flac_files:
        try:
            # Load the audio file
            audio, sr = librosa.load(file, sr=target_sr)
            # Normalize audio to range [-1, 1]
            audio = audio / np.max(np.abs(audio))
            
            # Convert to Mel spectrogram
            mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
            
            # Convert to dB scale
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
            
            # Ensure the Mel spectrogram has a fixed length
            if mel_spec_db.shape[1] < target_length:
                # Pad if shorter
                pad_width = target_length - mel_spec_db.shape[1]
                mel_spec_db = np.pad(mel_spec_db, ((0, 0), (0, pad_width)), mode='constant')
            else:
                # Trim if longer
                mel_spec_db = mel_spec_db[:, :target_length]

            # Add channel dimension and append to list
            mel_tensor = torch.tensor(mel_spec_db, dtype=torch.float32).unsqueeze(0)  # Add channel dim
            mel_spectrograms.append(mel_tensor)
        
        except Exception as e:
            print(f"Error processing file {file}: {e}")

    if len(mel_spectrograms) == 0:
        raise ValueError("No valid audio files were processed into Mel spectrograms.")

    # Stack into a single tensor
    mel_spectrograms = torch.stack(mel_spectrograms)  # Shape: [num_samples, 1, n_mels, target_length]
    return TensorDataset(mel_spectrograms)


In [8]:
import torch
from torch.utils.data import DataLoader
from mgan_model import MGAN  # Import the MGAN class
from your_dataset_loader import librispeech_to_mel  # Your dataset preprocessing function

def main():
    # Hyperparameters
    num_z = 100
    beta = 0.5
    num_gens = 10
    batch_size = 32
    z_prior = "gaussian"
    learning_rate = 0.0002
    num_epochs = 50

    # Spectrogram dimensions
    mel_bins = 64
    num_frames = 128
    num_channels = 1
    img_size = (mel_bins, num_frames, num_channels)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset preprocessing
    data_dir = r'data\LibriSpeech\LibriSpeech\dev-clean'
    print("Preprocessing dataset into Mel spectrograms...")
    mel_dataset = librispeech_to_mel(data_dir, target_sr=16000, n_mels=mel_bins, target_length=num_frames)
    print(f"Dataset size: {len(mel_dataset)}")

    # DataLoader
    dataloader = DataLoader(mel_dataset, batch_size=batch_size, shuffle=True)

    # Initialize MGAN
    mgan_model = MGAN(
        num_z=num_z,
        beta=beta,
        num_gens=num_gens,
        batch_size=batch_size,
        z_prior=z_prior,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        img_size=img_size,
        num_gen_feature_maps=64,
        num_dis_feature_maps=64,
        sample_dir="samples",
        device=device
    )

    # Train the model
    print("Starting training...")
    mgan_model.fit(dataloader)
    print("Training completed.")

if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'mgan_model'