#  orthogonal loss 
This model works well.

In [29]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import librosa
import soundfile as sf
from glob import glob
from phonemizer import phonemize
from torch.nn.utils import spectral_norm

In [30]:
# Set random seed for reproducibility
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (plus all CUDA devices
    when CUDA is available), then puts cuDNN into deterministic,
    non-benchmarking mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [31]:
# Load a FLAC audio file and resample it to the target sample rate
def load_flac(file_path, target_sr=16000):
    """Load an audio file through librosa at ``target_sr`` and peak-normalise it.

    The waveform is divided by its largest absolute sample value; a waveform
    whose peak is zero is returned unchanged.
    """
    waveform, _ = librosa.load(file_path, sr=target_sr)
    peak = np.abs(waveform).max()
    return waveform / peak if peak > 0 else waveform

# Pad or trim audio to the target length
def pad_or_trim(audio, target_length=64000):
    """Return ``audio`` with exactly ``target_length`` samples.

    Inputs that are too short are zero-padded at the end; all other inputs are
    cut down to their first ``target_length`` samples.
    """
    missing = target_length - len(audio)
    if missing > 0:
        return np.pad(audio, (0, missing))
    return audio[:target_length]

In [32]:
# Retrieve the transcription for a given audio file from the corresponding .trans.txt file
def get_transcription(file_path):
    """Look up the transcription of an audio file in its folder's ``.trans.txt``.

    The transcript file is read as one utterance per line in the form
    ``<file id> <text>``, where the file id is the audio file's name without
    its extension.

    Args:
        file_path: Path to the audio file.

    Returns:
        The text that follows the file id on the matching line (an empty
        string when the line holds nothing but the id).

    Raises:
        FileNotFoundError: No ``*.trans.txt`` file sits next to the audio file.
        ValueError: The transcript file has no line for this audio file.
    """
    # A bare file name has an empty dirname; fall back to the current directory.
    dir_path = os.path.dirname(file_path) or "."
    file_id = os.path.splitext(os.path.basename(file_path))[0]

    # sorted() keeps the choice deterministic if a folder holds several transcripts.
    transcription_file = next(
        (
            os.path.join(dir_path, name)
            for name in sorted(os.listdir(dir_path))
            if name.endswith(".trans.txt")
        ),
        None,
    )
    if transcription_file is None:
        raise FileNotFoundError(f"No transcription file found in {dir_path}")

    # Decode explicitly so the result does not depend on the platform's default encoding.
    with open(transcription_file, "r", encoding="utf-8") as f:
        for line in f:
            # partition() copes with an id-only line, where split()[1] would raise IndexError.
            utterance_id, _, text = line.strip().partition(" ")
            if utterance_id == file_id:
                return text

    raise ValueError(f"No transcription found for file {file_id}")

# Extract phonetic features from transcription
def get_phonetic_features(transcription, max_length=100, phoneme_to_id=None):
    """Turn a transcription into a fixed-length vector of phoneme-symbol IDs.

    The text is phonemized with the espeak backend (``en-us``) and every
    element of the result (presumably one character of a string) is mapped to
    an integer ID.

    Args:
        transcription: Text to phonemize.
        max_length: Length of the returned vector. Shorter sequences are
            right-padded with zeros, longer ones are truncated.
        phoneme_to_id: Optional ``dict`` mapping symbol -> ID that is shared
            between calls. Symbols missing from it are added in place with
            ``len(phoneme_to_id)`` as their ID, so a symbol keeps the same ID
            in every utterance. When omitted, the mapping is rebuilt from this
            utterance alone (the historical behaviour), which means IDs are
            NOT comparable between utterances.

    Returns:
        A ``float32`` tensor of shape ``(max_length,)``.
    """
    phonemes = phonemize(transcription, backend="espeak", language="en-us")
    if phoneme_to_id is None:
        # Per-utterance vocabulary: IDs follow the sorted order of this utterance's symbols.
        phoneme_to_id = {char: idx for idx, char in enumerate(sorted(set(phonemes)))}

    # setdefault() only inserts for symbols a shared vocabulary has not seen yet.
    ids = [phoneme_to_id.setdefault(p, len(phoneme_to_id)) for p in phonemes]
    ids = ids[:max_length]

    # NOTE(review): 0 is used both as padding and as a regular symbol ID.
    phonetic_features = torch.zeros(max_length, dtype=torch.float32)
    phonetic_features[:len(ids)] = torch.tensor(ids, dtype=torch.float32)
    return phonetic_features


In [33]:
# Preprocess the dataset and create a TensorDataset
def preprocess_dataset(root_dir, target_sr=16000, target_length=64000, feature_length=100):
    """Build a ``TensorDataset`` of (audio, phonetic features) from ``.flac`` files.

    Every ``.flac`` file found recursively under ``root_dir`` is loaded,
    padded/trimmed to ``target_length`` samples and paired with the phonetic
    features of its transcription. Files that fail at any step are reported
    on stdout and skipped.

    Args:
        root_dir: Directory searched recursively for ``*.flac`` files.
        target_sr: Sample rate passed to ``load_flac``.
        target_length: Number of samples per clip.
        feature_length: Length of each phonetic feature vector.

    Returns:
        ``TensorDataset`` whose items are ``(audio, features)`` with audio of
        shape ``(1, target_length)`` and the feature tensor produced by
        ``get_phonetic_features``.

    Raises:
        RuntimeError: No file could be processed, so there is nothing to stack.
    """
    # Sorted so the dataset order does not depend on file-system enumeration order.
    flac_files = sorted(glob(os.path.join(root_dir, '**', '*.flac'), recursive=True))
    print(f"Found {len(flac_files)} .flac files in {root_dir}.")
    if not flac_files:
        print("No .flac files found. Please check the root_dir path.")

    audio_dataset = []
    feature_dataset = []
    for file in flac_files:
        try:
            audio = pad_or_trim(load_flac(file, target_sr), target_length)
            transcription = get_transcription(file)
            phonetic_features = get_phonetic_features(transcription, max_length=feature_length)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the whole run.
            print(f"Error processing file {file}: {e}")
            continue
        # Append only once both halves exist so the two lists stay aligned.
        audio_dataset.append(audio)
        feature_dataset.append(phonetic_features)

    if not audio_dataset:
        raise RuntimeError(
            f"None of the {len(flac_files)} .flac files under {root_dir} could be processed."
        )

    # Stack in NumPy first: building a tensor from a list of arrays is very slow.
    audio_dataset = torch.tensor(np.stack(audio_dataset), dtype=torch.float32).unsqueeze(1)  # Add channel dimension
    feature_dataset = torch.stack(feature_dataset)
    return TensorDataset(audio_dataset, feature_dataset)

In [34]:
# Save a waveform to an audio file
def save_waveform_to_audio(waveform, sample_rate, filename):
    """Peak-normalise a waveform and write it to ``filename`` with soundfile.

    Tensors are detached and moved to the CPU first. Singleton dimensions are
    squeezed away, and the signal is divided by its largest absolute value
    unless that value is zero.
    """
    samples = waveform.detach().cpu().numpy() if isinstance(waveform, torch.Tensor) else waveform
    samples = np.squeeze(samples)
    peak = np.abs(samples).max()
    if peak > 0:
        samples = samples / peak
    sf.write(filename, samples, sample_rate)

# Verify waveform-to-audio conversion using preprocessed dataset
def verify_waveform_to_audio(root_dir, sample_rate=16000, target_length=64000, output_dir="verified_audio"):
    """Write the first preprocessed clips (at most five) to WAV files and plot them.

    The dataset is rebuilt from ``root_dir`` with ``preprocess_dataset``. Each
    selected clip is saved under ``output_dir`` and shown as a matplotlib
    line plot.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    dataset = preprocess_dataset(root_dir, sample_rate, target_length)
    for number in range(1, min(5, len(dataset)) + 1):
        clip = dataset[number - 1][0]  # element 0 of every item is the audio tensor
        wav_path = os.path.join(output_dir, f"example_waveform_{number}.wav")
        save_waveform_to_audio(clip, sample_rate, wav_path)
        print(f"Waveform saved to {wav_path}")

        # One figure per clip.
        plt.figure(figsize=(12, 4))
        plt.plot(clip.numpy().squeeze())
        plt.title(f"Waveform {number}")
        plt.xlabel("Sample Index")
        plt.ylabel("Amplitude")
        plt.show()

In [35]:
# Generate noise for the generator
def generate_noise(batch_size, z_dim, device):
    """Draw a ``(batch_size, z_dim)`` standard-normal batch and move it to ``device``."""
    latent = torch.randn(batch_size, z_dim)
    return latent.to(device)

# Orthogonal loss function
def orthogonal_loss(feature1, feature2):
    """Mean squared cosine similarity between paired rows of two feature batches.

    Dot products and norms are taken along dimension 1. The value is zero
    when every pair is orthogonal, so minimising it pushes the pairs apart.
    """
    dot = (feature1 * feature2).sum(dim=1)
    scale = feature1.norm(dim=1) * feature2.norm(dim=1) + 1e-8  # epsilon guards zero vectors
    cosine = dot / scale
    return (cosine ** 2).mean()

# Compute gradient penalty for WGAN-GP
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """WGAN-GP penalty evaluated on random blends of real and fake samples.

    ``discriminator`` is called with a single tensor. The mixing coefficient
    is drawn with shape ``(batch, 1, 1)``, so the samples are expected to be
    3-D. Returns the batch mean of ``(||grad||_2 - 1) ** 2``, with the graph
    kept so the penalty itself can be back-propagated.
    """
    n = real_samples.size(0)
    mix = torch.rand(n, 1, 1, device=device).expand_as(real_samples)
    blended = mix * real_samples + (1 - mix) * fake_samples
    blended.requires_grad_(True)

    scores = discriminator(blended)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


In [36]:
# Visualize and save generated waveforms
def visualize_and_save_generated_waveforms(generators, z_dim, features, num_waveforms, device, epoch, sample_rate=16000, output_dir='generated_audio'):
    """Generate a few waveforms with every generator and save them as WAV files.

    Despite the name, nothing is plotted; the samples are only written to disk
    as ``epoch{E}_gen{G}_sample{S}.wav`` (all three numbers 1-based).

    Args:
        generators: Iterable of generator modules called as ``gen(features, noise)``.
        z_dim: Size of the noise vector.
        features: Conditioning batch; only its first ``num_waveforms`` rows are used.
        num_waveforms: Maximum number of samples per generator. It is reduced
            to the number of rows in ``features`` when that is smaller.
        device: Device on which generation runs.
        epoch: Zero-based epoch index used in the file names.
        sample_rate: Sample rate written to the audio files.
        output_dir: Destination directory, created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # The last batch of an epoch can hold fewer rows than requested.
    num_waveforms = min(num_waveforms, len(features))
    conditions = features[:num_waveforms].to(device)

    for idx, gen in enumerate(generators):
        was_training = gen.training
        gen.eval()
        try:
            with torch.no_grad():
                noise = generate_noise(num_waveforms, z_dim, device)
                fake_waveforms = gen(conditions, noise).cpu()
        finally:
            # Restore the previous mode; otherwise training would continue in
            # eval mode after the first call.
            gen.train(was_training)

        for i, waveform in enumerate(fake_waveforms):
            filename = f'epoch{epoch+1}_gen{idx+1}_sample{i+1}.wav'
            filepath = os.path.join(output_dir, filename)
            save_waveform_to_audio(waveform, sample_rate, filepath)
            print(f"Saved {filepath}")

In [37]:
# Generator architecture using Conv1d layers
class Generator(nn.Module):
    """Conditional waveform generator built from a stack of ``GBlock`` modules.

    A 3-tap convolution lifts the conditioning input to 768 channels, seven
    noise-conditioned blocks follow (their upsampling factors 1, 1, 2, 2, 2,
    3, 5 multiply to 120), and a final convolution with ``tanh`` yields a
    single output channel.
    """

    def __init__(self, in_channels=100, z_channels=128):
        super().__init__()
        self.in_channels = in_channels
        self.z_channels = z_channels

        self.preprocess = nn.Conv1d(in_channels, 768, kernel_size=3, padding=1)

        # (input channels, output channels, upsampling factor) for each block, in order.
        block_specs = [
            (768, 768, 1),
            (768, 768, 1),
            (768, 384, 2),
            (384, 384, 2),
            (384, 384, 2),
            (384, 192, 3),
            (192, 96, 5),
        ]
        self.gblocks = nn.ModuleList(
            [GBlock(c_in, c_out, z_channels, factor) for c_in, c_out, factor in block_specs]
        )

        self.postprocess = nn.Sequential(
            nn.Conv1d(96, 1, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, inputs, z):
        """Map conditioning ``inputs`` and noise ``z`` to the generated signal."""
        hidden = self.preprocess(inputs)
        for block in self.gblocks:
            hidden = block(hidden, z)
        return self.postprocess(hidden)

In [38]:
class GBlock(nn.Module):
    """Residual generator block with noise-conditioned batch norm and upsampling.

    The block has two halves. The first upsamples and maps ``in_channels`` to
    ``hidden_channels``, with a 1x1 shortcut that is upsampled the same way.
    The second applies two more dilated convolutions (dilation 4 and 8) and
    adds an identity shortcut.
    """

    def __init__(self, in_channels, hidden_channels, z_channels, upsample_factor):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.z_channels = z_channels
        self.upsample_factor = upsample_factor

        def dilated_stack(dilation):
            """ReLU followed by a length-preserving 3-tap dilated convolution."""
            return nn.Sequential(
                nn.ReLU(inplace=False),
                nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3,
                          dilation=dilation, padding=dilation),
            )

        # Modules are created in the same order as before so that parameter
        # initialisation consumes the RNG identically.
        self.condition_batchnorm1 = ConditionalBatchNorm1d(in_channels, z_channels)
        self.first_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1),
        )
        self.condition_batchnorm2 = ConditionalBatchNorm1d(hidden_channels, z_channels)
        self.second_stack = dilated_stack(2)
        self.residual1 = nn.Sequential(
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=1),
        )
        self.condition_batchnorm3 = ConditionalBatchNorm1d(hidden_channels, z_channels)
        self.third_stack = dilated_stack(4)
        self.condition_batchnorm4 = ConditionalBatchNorm1d(hidden_channels, z_channels)
        self.fourth_stack = dilated_stack(8)

    def forward(self, condition, z):
        """Run both residual halves on ``condition`` using noise ``z``."""
        hidden = self.first_stack(self.condition_batchnorm1(condition, z))
        hidden = self.second_stack(self.condition_batchnorm2(hidden, z))
        skip = self.residual1(condition) + hidden

        hidden = self.third_stack(self.condition_batchnorm3(skip, z))
        hidden = self.fourth_stack(self.condition_batchnorm4(hidden, z))
        return hidden + skip

In [39]:
class UpsampleNet(nn.Module):
    """Spectrally-normalised transposed convolution that upsamples by a fixed factor.

    The kernel is twice the stride wide and initialised orthogonally. The
    output is cropped to exactly ``input_length * upsample_factor`` samples.
    """

    def __init__(self, input_size, output_size, upsample_factor):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.upsample_factor = upsample_factor

        deconv = nn.ConvTranspose1d(
            input_size,
            output_size,
            kernel_size=upsample_factor * 2,
            stride=upsample_factor,
            padding=upsample_factor // 2,
        )
        nn.init.orthogonal_(deconv.weight)
        self.layer = spectral_norm(deconv)

    def forward(self, inputs):
        """Upsample ``inputs`` along the last dimension."""
        target_length = inputs.size(-1) * self.upsample_factor
        # Some factors leave one surplus trailing sample; drop it.
        return self.layer(inputs)[:, :, :target_length]

In [40]:
class ConditionalBatchNorm1d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a noise vector.

    The input goes through a non-affine ``BatchNorm1d``. A spectrally
    normalised linear layer then maps the noise to ``2 * num_features``
    values, which are split into a per-channel scale and shift.
    """

    def __init__(self, num_features, z_channels=128):
        super().__init__()
        self.num_features = num_features
        self.z_channels = z_channels
        self.batch_norm = nn.BatchNorm1d(num_features, affine=False)
        self.layer = spectral_norm(nn.Linear(z_channels, num_features * 2))
        # NOTE(review): after spectral_norm the trainable tensor is presumably
        # ``weight_orig`` and ``weight`` is recomputed from it, so this
        # initialisation may have no lasting effect -- confirm. The call is
        # kept because it consumes random numbers.
        self.layer.weight.data.normal_(1, 0.02)
        self.layer.bias.data.zero_()

    def forward(self, inputs, noise):
        """Normalise ``inputs`` and apply the scale/shift derived from ``noise``."""
        normalized = self.batch_norm(inputs)
        scale, shift = self.layer(noise).chunk(2, 1)
        # Add a trailing axis so the per-channel values broadcast over the last dimension.
        scale = scale.view(-1, self.num_features, 1)
        shift = shift.view(-1, self.num_features, 1)
        return scale * normalized + shift

In [41]:
# Discriminator architecture using Conv1d layers
class Multiple_Random_Window_Discriminators(nn.Module):
    """Ensemble of critics that each score a randomly placed window of the signal.

    There is one unconditional and one conditional critic per entry of
    ``window_size``. A window of ``w`` conditioning frames covers
    ``w * upsample_factor`` signal samples. Real and fake signals are always
    cut at the same position.
    """

    def __init__(self, lc_channels, window_size=(2, 4, 8, 16, 30), upsample_factor=120):
        super().__init__()
        self.lc_channels = lc_channels
        self.window_size = window_size
        self.upsample_factor = upsample_factor

        # (input channels, downsampling factors) of the unconditional critics.
        unconditional_specs = [
            (1, (5, 3)),
            (2, (5, 3)),
            (4, (5, 3)),
            (8, (5, 3)),
            (15, (2, 2)),
        ]
        self.udiscriminators = nn.ModuleList([
            UnConditionalDBlocks(in_channels=channels, factors=factors, out_channels=(128, 256))
            for channels, factors in unconditional_specs
        ])

        # (input channels, downsampling factors, output channels) of the conditional critics.
        conditional_specs = [
            (1, (5, 3, 2, 2, 2), (128, 128, 256, 256)),
            (2, (5, 3, 2, 2), (128, 256, 256)),
            (4, (5, 3, 2), (128, 256)),
            (8, (5, 3), (256,)),
            (15, (2, 2, 2), (128, 256)),
        ]
        self.discriminators = nn.ModuleList([
            ConditionalDBlocks(in_channels=channels, lc_channels=lc_channels,
                               factors=factors, out_channels=widths)
            for channels, factors, widths in conditional_specs
        ])

    def forward(self, real_samples, fake_samples, conditions):
        """Score matching random windows of the real and fake signals.

        Returns two lists (real scores, fake scores), unconditional critics
        first and conditional critics second.
        """
        real_outputs, fake_outputs = [], []
        factor = self.upsample_factor

        # Unconditional critics: the window start is drawn on the sample axis.
        for frames, critic in zip(self.window_size, self.udiscriminators):
            width = frames * factor
            start = np.random.randint(0, real_samples.size(-1) - width + 1)
            stop = start + width
            real_outputs.append(critic(real_samples[:, :, start:stop]))
            fake_outputs.append(critic(fake_samples[:, :, start:stop]))

        # Conditional critics: the start is drawn on the conditioning axis and
        # scaled up, so signal and conditioning windows cover the same span.
        for frames, critic in zip(self.window_size, self.discriminators):
            first_frame = np.random.randint(0, conditions.size(-1) - frames + 1)
            last_frame = first_frame + frames
            start, stop = first_frame * factor, last_frame * factor
            lc = conditions[:, :, first_frame:last_frame]
            real_outputs.append(critic(real_samples[:, :, start:stop], lc))
            fake_outputs.append(critic(fake_samples[:, :, start:stop], lc))

        return real_outputs, fake_outputs


In [42]:

class DBlock(nn.Module):
    """Residual downsampling block used by the critics.

    Main path: average pooling, ReLU, 3-tap convolution, ReLU, dilated 3-tap
    convolution. Shortcut: 1x1 convolution followed by the same pooling.
    """

    def __init__(self, in_channels, out_channels, downsample_factor):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample_factor = downsample_factor

        main_path = [
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
        ]
        self.layers = nn.Sequential(*main_path)

        shortcut = [
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
        ]
        self.residual = nn.Sequential(*shortcut)

    def forward(self, inputs):
        """Return the sum of the main path and the shortcut."""
        return self.layers(inputs) + self.residual(inputs)

In [43]:
class CondDBlock(nn.Module):
    """Downsampling residual block that mixes in a conditioning signal.

    The pooled and convolved input is summed with a 1x1 projection of the
    conditioning before the second convolution. The channel count doubles
    from ``in_channels`` to ``2 * in_channels``.
    """

    def __init__(self, in_channels, lc_channels, downsample_factor):
        super().__init__()
        self.in_channels = in_channels
        self.lc_channels = lc_channels
        self.downsample_factor = downsample_factor

        doubled = in_channels * 2
        self.start = nn.Sequential(
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, doubled, kernel_size=3, padding=1),
        )
        self.lc_conv1d = nn.Conv1d(lc_channels, doubled, kernel_size=1)
        self.end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(doubled, doubled, kernel_size=3, dilation=2, padding=2),
        )
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, doubled, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
        )

    def forward(self, inputs, conditions):
        """Combine ``inputs`` with ``conditions`` and add the pooled shortcut."""
        # ``conditions`` must already have the pooled length for this sum to work.
        mixed = self.start(inputs) + self.lc_conv1d(conditions)
        main = self.end(mixed)
        return main + self.residual(inputs)

In [44]:
class ConditionalDBlocks(nn.Module):
    """Conditional critic: a ``DBlock`` stack, one ``CondDBlock``, two more ``DBlock`` layers.

    ``factors`` must hold exactly one more entry than ``out_channels``; the
    extra (last) factor belongs to the conditional block.
    """

    def __init__(self, in_channels, lc_channels, factors=(2, 2, 2), out_channels=(128, 256)):
        super().__init__()
        assert len(factors) == len(out_channels) + 1
        self.in_channels = in_channels
        self.lc_channels = lc_channels
        self.factors = factors
        self.out_channels = out_channels

        # The first block maps to 64 channels without downsampling.
        stages = [DBlock(in_channels, 64, 1)]
        width = 64
        for position, channel in enumerate(out_channels):
            stages.append(DBlock(width, channel, factors[position]))
            width = channel
        self.layers = nn.ModuleList(stages)

        self.cond_layer = CondDBlock(width, lc_channels, factors[-1])

        # The conditional block doubles the channel count.
        final_width = width * 2
        self.post_process = nn.ModuleList([
            DBlock(final_width, final_width, 1),
            DBlock(final_width, final_width, 1),
        ])

    def forward(self, inputs, conditions):
        """Score ``inputs`` given ``conditions``."""
        # Fold the signal into ``in_channels`` channels before the first block.
        hidden = inputs.view(inputs.size(0), self.in_channels, -1)
        for block in self.layers:
            hidden = block(hidden)
        hidden = self.cond_layer(hidden, conditions)
        for block in self.post_process:
            hidden = block(hidden)
        return hidden

In [45]:
class UnConditionalDBlocks(nn.Module):
    """Unconditional critic: a plain stack of ``DBlock`` layers.

    Layout: one block to 64 channels, one downsampling block per entry of
    ``factors``, then two blocks that keep both width and length.
    """

    def __init__(self, in_channels, factors=(5, 3), out_channels=(128, 256)):
        super().__init__()
        self.in_channels = in_channels
        self.factors = factors
        self.out_channels = out_channels

        stages = [DBlock(in_channels, 64, 1)]
        width = 64
        for position, factor in enumerate(factors):
            stages.append(DBlock(width, out_channels[position], factor))
            width = out_channels[position]
        stages.append(DBlock(width, width, 1))
        stages.append(DBlock(width, width, 1))
        self.layers = nn.ModuleList(stages)

    def forward(self, inputs):
        """Score ``inputs`` after folding them into ``in_channels`` channels."""
        hidden = inputs.view(inputs.size(0), self.in_channels, -1)
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [46]:
def train_gan_with_pretrained_generators(
    pretrained_generator, num_epochs, z_dim, lr_gen, lr_disc, batch_size, train_dataset,
    num_generators, seed, audio_length, output_dir, lambda_gp=10, lambda_ortho=0.1, num_critic=5,
    checkpoint_dir='checkpoints', resume=True
):
    """Fine-tune several copies of a pretrained generator against one shared critic.

    Training follows the WGAN-GP recipe: ``num_critic`` critic updates per
    batch, then one update per generator. Each generator is additionally
    pushed to produce waveforms orthogonal to those of the other generators
    (weighted by ``lambda_ortho``). After every epoch a few samples and all
    model weights are written to ``output_dir``.

    Args:
        pretrained_generator: Generator whose weights initialise every copy.
        num_epochs: Number of passes over ``train_dataset``.
        z_dim: Size of the noise vector.
        lr_gen: Adam learning rate of the generators.
        lr_disc: Adam learning rate of the critic.
        batch_size: Batch size of the data loader.
        train_dataset: Dataset yielding ``(audio, features)`` pairs.
        num_generators: Number of generator copies to train.
        seed: Seed passed to ``set_seed``.
        audio_length: Unused; kept for interface compatibility.
        output_dir: Directory for generated samples and weights.
        lambda_gp: Weight of the gradient penalty.
        lambda_ortho: Weight of the orthogonality term (0 disables it).
        num_critic: Critic updates per generator update.
        checkpoint_dir: Unused; kept for interface compatibility.
        resume: Unused; kept for interface compatibility.
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir, exist_ok=True)

    # Batches reach the networks as (batch, 1, feature_length) -- see the
    # unsqueeze(1) below -- so the conditioning signal has exactly one channel.
    # Sizing the networks with the feature length made the first convolution
    # reject every batch.
    cond_channels = 1

    # Initialize multiple generators with the pretrained generator
    generators = []
    for _ in range(num_generators):
        gen = Generator(in_channels=cond_channels, z_channels=z_dim).to(device)
        gen.load_state_dict(pretrained_generator.state_dict())
        generators.append(gen)

    # Initialize Discriminator
    discriminator = Multiple_Random_Window_Discriminators(lc_channels=cond_channels).to(device)

    optimizer_gens = [optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.5, 0.9)) for gen in generators]
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    def mean_over_windows(outputs):
        """Average critic outputs window by window (their shapes differ, so they cannot be stacked)."""
        return torch.stack([out.mean() for out in outputs]).mean()

    def per_sample_scores(samples, conditions):
        """Collapse the critic's window outputs into one score per sample."""
        outputs = discriminator(samples, samples, conditions)[1]
        return torch.stack([out.flatten(1).mean(dim=1) for out in outputs], dim=1).sum(dim=1)

    def crop_to_common_length(real, fake):
        """Trim both signals to the shorter one so they can be windowed and blended together."""
        length = min(real.size(-1), fake.size(-1))
        return real[..., :length], fake[..., :length]

    # Training loop
    for epoch in range(num_epochs):
        torch.cuda.empty_cache()
        for batch_idx, (real_audio, features) in enumerate(train_loader):
            real_audio = real_audio.to(device)
            features = features.unsqueeze(1).to(device)  # (batch, 1, feature_length)
            current_batch = real_audio.size(0)

            # Train Discriminator
            for _ in range(num_critic):
                optimizer_disc.zero_grad()
                # The critic step must not build generator graphs.
                with torch.no_grad():
                    fakes = [gen(features, generate_noise(current_batch, z_dim, device)) for gen in generators]

                loss_disc = torch.zeros((), device=device)
                for fake in fakes:
                    real, fake = crop_to_common_length(real_audio, fake)
                    real_outputs, fake_outputs = discriminator(real, fake, features)
                    wasserstein = sum(torch.mean(f) - torch.mean(r) for r, f in zip(real_outputs, fake_outputs))
                    # compute_gradient_penalty calls its critic with a single tensor,
                    # hence the adapter around the three-argument discriminator.
                    gradient_penalty = compute_gradient_penalty(
                        lambda x: per_sample_scores(x, features), real, fake, device
                    )
                    critic_loss = (wasserstein + lambda_gp * gradient_penalty) / num_generators
                    # Back-propagate per generator so only one graph is alive at a time.
                    critic_loss.backward()
                    loss_disc = loss_disc + critic_loss.detach()
                optimizer_disc.step()

            # Reference outputs for the orthogonality term; no gradient flows into them.
            references = []
            if lambda_ortho and len(generators) > 1:
                with torch.no_grad():
                    references = [gen(features, generate_noise(current_batch, z_dim, device)) for gen in generators]

            # Train Generators
            for idx, gen in enumerate(generators):
                optimizer_gens[idx].zero_grad()
                noise = generate_noise(current_batch, z_dim, device)
                fake = gen(features, noise)
                fake_outputs = discriminator(fake, fake, features)[1]
                loss_gen = -mean_over_windows(fake_outputs)

                others = [ref for other_idx, ref in enumerate(references) if other_idx != idx]
                if others:
                    # Flatten to (batch, samples) so the cosine is taken over the whole waveform.
                    flat_fake = fake.flatten(1)
                    ortho = torch.stack([orthogonal_loss(flat_fake, ref.flatten(1)) for ref in others]).mean()
                    loss_gen = loss_gen + lambda_ortho * ortho

                loss_gen.backward()
                optimizer_gens[idx].step()

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        # Save generated samples and models
        visualize_and_save_generated_waveforms(
            generators, z_dim, features, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
        )
        for idx, gen in enumerate(generators):
            torch.save(gen.state_dict(), os.path.join(output_dir, f"generator_{idx}_epoch{epoch+1}.pth"))
        torch.save(discriminator.state_dict(), os.path.join(output_dir, f"discriminator_epoch{epoch+1}.pth"))

    print("Training complete.")


In [47]:
# Pretrain a single generator
def pretrain_single_generator(num_epochs, z_dim, lr_gen, lr_disc, batch_size, seed, audio_length, output_dir, train_dataset,
                              lambda_gp=10, num_critic=5):
    """Pretrain one generator against the window critics with WGAN-GP.

    Args:
        num_epochs: Number of passes over ``train_dataset``.
        z_dim: Size of the noise vector.
        lr_gen: Adam learning rate of the generator.
        lr_disc: Adam learning rate of the critic.
        batch_size: Batch size of the data loader.
        seed: Seed passed to ``set_seed``.
        audio_length: Unused; kept for interface compatibility.
        output_dir: Directory for generated samples and weights.
        train_dataset: Dataset yielding ``(audio, features)`` pairs.
        lambda_gp: Weight of the gradient penalty (previously fixed at 10).
        num_critic: Critic updates per generator update (previously fixed at 5).

    Returns:
        The trained generator.
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir, exist_ok=True)

    # Batches reach the networks as (batch, 1, feature_length) -- see the
    # unsqueeze(1) below -- so the conditioning signal has exactly one channel.
    # Sizing the networks with the feature length made the first convolution
    # reject every batch.
    cond_channels = 1

    # Define the single generator and discriminator
    generator = Generator(in_channels=cond_channels, z_channels=z_dim).to(device)
    discriminator = Multiple_Random_Window_Discriminators(lc_channels=cond_channels).to(device)

    optimizer_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.9))
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    def mean_over_windows(outputs):
        """Average critic outputs window by window (their shapes differ, so they cannot be stacked)."""
        return torch.stack([out.mean() for out in outputs]).mean()

    def per_sample_scores(samples, conditions):
        """Collapse the critic's window outputs into one score per sample."""
        outputs = discriminator(samples, samples, conditions)[1]
        return torch.stack([out.flatten(1).mean(dim=1) for out in outputs], dim=1).sum(dim=1)

    def crop_to_common_length(real, fake):
        """Trim both signals to the shorter one so they can be windowed and blended together."""
        length = min(real.size(-1), fake.size(-1))
        return real[..., :length], fake[..., :length]

    # Training loop
    for epoch in range(num_epochs):
        torch.cuda.empty_cache()
        for batch_idx, (real_audio, features) in enumerate(train_loader):
            real_audio = real_audio.to(device)
            features = features.unsqueeze(1).to(device)  # (batch, 1, feature_length)
            current_batch = real_audio.size(0)

            # Train Discriminator
            for _ in range(num_critic):
                optimizer_disc.zero_grad()
                noise = generate_noise(current_batch, z_dim, device)
                fake_audio = generator(features, noise).detach()
                real, fake_audio = crop_to_common_length(real_audio, fake_audio)
                real_outputs, fake_outputs = discriminator(real, fake_audio, features)
                loss_disc = sum(torch.mean(f) - torch.mean(r) for r, f in zip(real_outputs, fake_outputs))
                # compute_gradient_penalty calls its critic with a single tensor,
                # hence the adapter around the three-argument discriminator.
                gradient_penalty = compute_gradient_penalty(
                    lambda x: per_sample_scores(x, features), real, fake_audio, device
                )
                loss_disc = loss_disc + lambda_gp * gradient_penalty
                loss_disc.backward()
                optimizer_disc.step()

            # Train Generator
            optimizer_gen.zero_grad()
            noise = generate_noise(current_batch, z_dim, device)
            fake_audio = generator(features, noise)
            fake_outputs = discriminator(fake_audio, fake_audio, features)[1]
            loss_gen = -mean_over_windows(fake_outputs)
            loss_gen.backward()
            optimizer_gen.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        # Save generated samples and models
        visualize_and_save_generated_waveforms(
            [generator], z_dim, features, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
        )
        torch.save(generator.state_dict(), os.path.join(output_dir, f"pretrained_generator_epoch{epoch+1}.pth"))

    print("Pretraining complete.")
    return generator

In [48]:
# Global experiment configuration shared by the cells below.
set_seed(42)
audio_length = 64000  # samples per clip (4 s at the 16 kHz sample rate below)
z_dim = 128  # size of the generator noise vector
lr_gen = 0.0002  # Adam learning rate for the generator(s)
lr_disc = 0.0002  # Adam learning rate for the discriminator
batch_size = 16
num_epochs = 50  # epochs of the multi-generator training stage
root_dir = "./data"  # searched recursively for .flac files
sample_rate = 16000
num_generators = 5  # generator copies trained after pretraining
output_dir = 'generated_audio'  # samples and weights of the multi-generator stage

In [49]:
# Build the (audio, phonetic features) dataset once; both training stages reuse it.
train_dataset = preprocess_dataset(root_dir, target_sr=sample_rate, target_length=audio_length)


Found 2703 .flac files in ./data.
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0000.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0001.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0002.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0003.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0004.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0005.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0006.flac: module 'numpy' has no attribute 'round'
Error processing file ./data\LibriSpeech\dev-clean\1272\128104\1272-128104-0007.flac

KeyboardInterrupt: 

In [None]:
# Stage 1: pretrain a single generator for 20 epochs; its samples and weights
# are written to 'waveform_pre'.
pretrained_generator = pretrain_single_generator(
        num_epochs=20,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        seed=42,
        audio_length=audio_length,
        output_dir='waveform_pre',
        train_dataset=train_dataset
    )

In [None]:
# Stage 2: train `num_generators` copies of the pretrained generator together.
# NOTE(review): checkpoint_dir and resume are accepted by
# train_gan_with_pretrained_generators but its body never reads them.
train_gan_with_pretrained_generators(
        pretrained_generator,
        num_epochs=num_epochs,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        train_dataset=train_dataset,
        num_generators=num_generators,
        seed=42,
        audio_length=audio_length,
        output_dir=output_dir,
        checkpoint_dir='my_checkpoints',
        resume=False
    )