#  orthogonal loss 
This model work well

In [1]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import librosa
import soundfile as sf
from glob import glob
from phonemizer import phonemize
from torch.nn.utils import spectral_norm
import torch.nn.functional as F


# DATA prepare

In [2]:
# Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
# Load a FLAC audio file and resample it to the target sample rate
def load_flac(file_path, target_sr=16000):
    audio, sr = librosa.load(file_path, sr=target_sr)
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = audio / max_val  # Normalize to [-1, 1]
    return audio

# Pad or trim audio to the target length
def pad_or_trim(audio, target_length=64000):
    if len(audio) < target_length:
        audio = np.pad(audio, (0, target_length - len(audio)))
    else:
        audio = audio[:target_length]
    return audio

In [4]:
# Retrieve the transcription for a given audio file from the corresponding .trans.txt file
def get_transcription(file_path):
    dir_path = os.path.dirname(file_path)
    base_name = os.path.basename(file_path)
    file_id = os.path.splitext(base_name)[0]

    # Locate the transcription file
    transcription_file = None
    for file in os.listdir(dir_path):
        if file.endswith(".trans.txt"):
            transcription_file = os.path.join(dir_path, file)
            break

    if not transcription_file:
        raise FileNotFoundError(f"No transcription file found in {dir_path}")

    # Read the transcription file and find the transcription for the current audio file
    with open(transcription_file, "r") as f:
        for line in f:
            parts = line.strip().split(" ", 1)
            if parts[0] == file_id:
                transcription = parts[1]
                return transcription

    raise ValueError(f"No transcription found for file {file_id}")

# Extract phonetic features from transcription
def get_phonetic_features(transcription, max_length=100):
    phonemes = phonemize(transcription, backend="espeak", language="en-us")
    phoneme_to_id = {char: idx for idx, char in enumerate(sorted(set(phonemes)))}
    phonetic_features = [phoneme_to_id[p] for p in phonemes]

    # Convert to tensor and pad/truncate
    phonetic_features = torch.tensor(phonetic_features, dtype=torch.float32)
    if len(phonetic_features) < max_length:
        phonetic_features = nn.functional.pad(phonetic_features, (0, max_length - len(phonetic_features)))
    else:
        phonetic_features = phonetic_features[:max_length]
    return phonetic_features


In [5]:
# Preprocess the dataset and create a TensorDataset
def preprocess_dataset(root_dir, target_sr=16000, target_length=64000, feature_length=100):
    flac_files = glob(os.path.join(root_dir, '**', '*.flac'), recursive=True)
    print(f"Found {len(flac_files)} .flac files in {root_dir}.")
    if len(flac_files) == 0:
        print("No .flac files found. Please check the root_dir path.")
    audio_dataset = []
    feature_dataset = []
    for file in flac_files:
        try:
            audio = load_flac(file, target_sr)
            audio = pad_or_trim(audio, target_length)
            transcription = get_transcription(file)
            phonetic_features = get_phonetic_features(transcription, max_length=feature_length)
            audio_dataset.append(audio)
            feature_dataset.append(phonetic_features)
        except Exception as e:
            print(f"Error processing file {file}: {e}")
    audio_dataset = torch.tensor(audio_dataset, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
    feature_dataset = torch.stack(feature_dataset)  # Stack tensors
    return TensorDataset(audio_dataset, feature_dataset)

In [6]:
# Save a waveform to an audio file
def save_waveform_to_audio(waveform, sample_rate, filename):
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()
    waveform = np.squeeze(waveform)
    max_val = np.max(np.abs(waveform))
    if max_val > 0:
        waveform = waveform / max_val
    sf.write(filename, waveform, sample_rate)

# Verify waveform-to-audio conversion using preprocessed dataset
def verify_waveform_to_audio(root_dir, sample_rate=16000, target_length=64000, output_dir="verified_audio"):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    dataset = preprocess_dataset(root_dir, sample_rate, target_length)
    num_samples_to_verify = min(5, len(dataset))
    for idx in range(num_samples_to_verify):
        waveform = dataset[idx][0]  # Access audio data
        filename = os.path.join(output_dir, f"example_waveform_{idx+1}.wav")
        save_waveform_to_audio(waveform, sample_rate, filename)
        print(f"Waveform saved to {filename}")
        # Plot the waveform
        plt.figure(figsize=(12, 4))
        plt.plot(waveform.numpy().squeeze())
        plt.title(f"Waveform {idx+1}")
        plt.xlabel("Sample Index")
        plt.ylabel("Amplitude")
        plt.show()

In [7]:
# Generate noise for the generator
def generate_noise(batch_size, z_dim, device):
    return torch.randn(batch_size, z_dim).to(device)

# Orthogonal loss function
def orthogonal_loss(feature1, feature2):
    inner_product = torch.sum(feature1 * feature2, dim=1)
    norm1 = torch.norm(feature1, dim=1)
    norm2 = torch.norm(feature2, dim=1)
    cosine_similarity = inner_product / (norm1 * norm2 + 1e-8)
    return torch.mean(cosine_similarity**2)  # Minimize the cosine similarity to make vectors orthogonal

# Compute gradient penalty for WGAN-GP
def compute_gradient_penalty(discriminator, real_samples, fake_samples, conditions, device):
    batch_size = real_samples.size(0)
    epsilon = torch.rand(batch_size, 1, 1, device=device)  # Shape: [batch_size, 1, 1]
    epsilon = epsilon.expand_as(real_samples)  # Expand to match `real_samples` shape

    # Interpolate between real and fake samples
    interpolates = (epsilon * real_samples + (1 - epsilon) * fake_samples).requires_grad_(True)

    # Pass interpolates, fake_samples, and conditions to the discriminator
    real_outputs, fake_outputs = discriminator(interpolates, fake_samples, conditions)

    # Compute gradient penalty for each output
    gradient_penalties = []
    for real_output in real_outputs:
        grad_outputs = torch.ones_like(real_output, device=device)
        gradients = torch.autograd.grad(
            outputs=real_output,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]

        # Reshape gradients and compute L2 norm
        gradients = gradients.view(batch_size, -1)  # Flatten
        gradient_norm = gradients.norm(2, dim=1)  # L2 norm
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
        gradient_penalties.append(gradient_penalty)

    # Average all gradient penalties
    return sum(gradient_penalties) / len(gradient_penalties)




In [8]:
def visualize_and_save_generated_waveforms(generators, z_dim, features, num_waveforms, device, epoch, sample_rate=16000, output_dir='generated_audio'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    for idx, gen in enumerate(generators):
        gen.eval()
        with torch.no_grad():
            noise = generate_noise(num_waveforms, z_dim, device)
            fake_waveforms = gen(features[:num_waveforms].to(device), noise).cpu()
            
            # Adjust num_waveforms to the actual size of fake_waveforms
            num_available_waveforms = fake_waveforms.size(0)
            if num_available_waveforms < num_waveforms:
                print(f"Warning: Requested {num_waveforms} waveforms, but generator produced {num_available_waveforms}")
                num_waveforms = num_available_waveforms

            for i in range(num_waveforms):
                waveform = fake_waveforms[i]
                filename = f'epoch{epoch+1}_gen{idx+1}_sample{i+1}.wav'
                filepath = os.path.join(output_dir, filename)
                save_waveform_to_audio(waveform, sample_rate, filepath)
                print(f"Saved {filepath}")


# model definition

In [9]:
class Generator(nn.Module):
    
    def __init__(self, in_channels=1, z_channels=128):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.z_channels = z_channels

        self.preprocess = nn.Conv1d(1, 768, kernel_size=3, padding=1)

        self.gblocks = nn.ModuleList([
            GBlock(768, 768, z_channels, 5),  # Upsample by 4
            GBlock(768, 768, z_channels, 4),  # Upsample by 4
            GBlock(768, 384, z_channels, 4),  # Upsample by 4
            GBlock(384, 384, z_channels, 4),  # Upsample by 2
            GBlock(384, 192, z_channels, 2),  # Upsample by 2
        ])
        self.postprocess = nn.Sequential(
            nn.Conv1d(192, 1, kernel_size=3, padding=1),
            nn.Tanh()
        )


    def forward(self, inputs, z):
        # print(f"Input shape: {inputs.shape}")  # Debug input shape
        inputs = self.preprocess(inputs)
        outputs = inputs
        for i, layer in enumerate(self.gblocks):
            outputs = layer(outputs, z)
            # print(f"After GBlock {i}: {outputs.shape}")  # Debug output shape
        outputs = self.postprocess(outputs)
        return outputs


In [10]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class GBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, z_channels, upsample_factor):
        super(GBlock, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.z_channels = z_channels
        self.upsample_factor = upsample_factor

        # GroupNorm for memory efficiency
        self.condition_norm1 = nn.GroupNorm(32, in_channels)
        self.first_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1)
        )
        self.condition_norm2 = nn.GroupNorm(32, hidden_channels)
        self.second_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=1, padding=1)
        )
        self.residual1 = nn.Sequential(
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        )
        self.condition_norm3 = nn.GroupNorm(32, hidden_channels)
        self.third_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=2, padding=2)
        )
        self.condition_norm4 = nn.GroupNorm(32, hidden_channels)
        self.fourth_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=2, padding=2)
        )

    def forward(self, condition, z):
        def run_forward(inputs):
            outputs = self.condition_norm1(inputs)
            outputs = self.first_stack(outputs)
            outputs = self.condition_norm2(outputs)
            outputs = self.second_stack(outputs)
            residual_outputs = self.residual1(inputs) + outputs
            outputs = self.condition_norm3(residual_outputs)
            outputs = self.third_stack(outputs)
            outputs = self.condition_norm4(outputs)
            outputs = self.fourth_stack(outputs)
            outputs = outputs + residual_outputs
            return outputs

        # Gradient checkpointing to save memory
        outputs = checkpoint.checkpoint(run_forward, condition)
        return outputs


In [11]:
class UpsampleNet(nn.Module):
    def __init__(self, input_size, output_size, upsample_factor):
        super(UpsampleNet, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.upsample_factor = upsample_factor

        layer = nn.ConvTranspose1d(input_size, output_size, upsample_factor * 2,
                                   upsample_factor, padding=upsample_factor // 2)
        nn.init.orthogonal_(layer.weight)
        self.layer = spectral_norm(layer)

    def forward(self, inputs):
        outputs = self.layer(inputs)
        outputs = outputs[:, :, : inputs.size(-1) * self.upsample_factor]
        return outputs

In [12]:
class ConditionalBatchNorm1d(nn.Module):
    """Conditional Batch Normalization"""
    def __init__(self, num_features, z_channels=128):
        super().__init__()
        self.num_features = num_features
        self.z_channels = z_channels
        self.batch_norm = nn.BatchNorm1d(num_features, affine=False)
        self.layer = spectral_norm(nn.Linear(z_channels, num_features * 2))
        self.layer.weight.data.normal_(1, 0.02)
        self.layer.bias.data.zero_()

    def forward(self, inputs, noise):
        outputs = self.batch_norm(inputs)
        gamma, beta = self.layer(noise).chunk(2, 1)
        gamma = gamma.view(-1, self.num_features, 1)
        beta = beta.view(-1, self.num_features, 1)
        outputs = gamma * outputs + beta
        return outputs

In [13]:
# Discriminator architecture using Conv1d layers
class Multiple_Random_Window_Discriminators(nn.Module):
    def __init__(self, lc_channels, window_size=(2, 4, 8, 16, 30), upsample_factor=120):
        super(Multiple_Random_Window_Discriminators, self).__init__()
        self.lc_channels = lc_channels
        self.window_size = window_size
        self.upsample_factor = upsample_factor

        self.udiscriminators = nn.ModuleList([
            UnConditionalDBlocks(in_channels=1, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=2, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=4, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=8, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=15, factors=(2, 2), out_channels=(128, 256)),
        ])

        self.discriminators = nn.ModuleList([
            ConditionalDBlocks(in_channels=1, lc_channels=lc_channels,
                               factors=(5, 3, 2, 2, 2), out_channels=(128, 128, 256, 256)),
            ConditionalDBlocks(in_channels=2, lc_channels=lc_channels,
                               factors=(5, 3, 2, 2), out_channels=(128, 256, 256)),
            ConditionalDBlocks(in_channels=4, lc_channels=lc_channels,
                               factors=(5, 3, 2), out_channels=(128, 256)),
            ConditionalDBlocks(in_channels=8, lc_channels=lc_channels,
                               factors=(5, 3), out_channels=(256,)),
            ConditionalDBlocks(in_channels=15, lc_channels=lc_channels,
                               factors=(2, 2, 2), out_channels=(128, 256)),
        ])

    def forward(self, real_samples, fake_samples, conditions):
        real_outputs, fake_outputs = [], []

        # Unconditional discriminator
        for (size, layer) in zip(self.window_size, self.udiscriminators):
            size = size * self.upsample_factor
            if real_samples.size(-1) < size:
                raise ValueError(f"Window size {size} is too large for input with length {real_samples.size(-1)}")

            index = np.random.randint(0, real_samples.size(-1) - size + 1)
            #print(f"Unconditional index: {index}, size: {size}, slice: {index}:{index + size}")

            real_slice = real_samples[:, :, index: index + size]
            fake_slice = fake_samples[:, :, index: index + size]

            if real_slice.size(-1) == 0 or fake_slice.size(-1) == 0:
                raise ValueError(f"Generated slice has zero length: real_slice.shape={real_slice.shape}, fake_slice.shape={fake_slice.shape}")

            real_output = layer(real_slice)
            fake_output = layer(fake_slice)

            real_outputs.append(real_output)
            fake_outputs.append(fake_output)

        # Conditional discriminator
        for (size, layer) in zip(self.window_size, self.discriminators):
            lc_index = np.random.randint(0, conditions.size(-1) - size + 1)
            sample_index = lc_index * self.upsample_factor

            if real_samples.size(-1) < (lc_index + size) * self.upsample_factor:
                raise ValueError(f"Window size exceeds input size for conditional discriminator.")

            real_x = real_samples[:, :, sample_index: (lc_index + size) * self.upsample_factor]
            fake_x = fake_samples[:, :, sample_index: (lc_index + size) * self.upsample_factor]
            lc = conditions[:, :, lc_index: lc_index + size]

            if real_x.size(-1) == 0 or fake_x.size(-1) == 0:
                raise ValueError(f"Generated slice has zero length: real_x.shape={real_x.shape}, fake_x.shape={fake_x.shape}")

            real_output = layer(real_x, lc)
            fake_output = layer(fake_x, lc)

            real_outputs.append(real_output)
            fake_outputs.append(fake_output)

        return real_outputs, fake_outputs


In [14]:
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample_factor):
        super(DBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample_factor = downsample_factor
        self.layers = nn.Sequential(
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2)
        )
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor)
        )

    def forward(self, inputs):
        outputs = self.layers(inputs) + self.residual(inputs)
        return outputs

In [15]:
class CondDBlock(nn.Module):
    def __init__(self, in_channels, lc_channels, upsample_factor):
        super(CondDBlock, self).__init__()
        self.lc_conv1d = nn.Conv1d(lc_channels, in_channels, kernel_size=1)
        self.start = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.end = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.residual = nn.Conv1d(in_channels, in_channels, kernel_size=1)

    def forward(self, inputs, conditions):
        # Ensure conditions are processed to match inputs
        conditions = self.lc_conv1d(conditions)
        # print(f"Conditions shape after lc_conv1d: {conditions.shape}")  # Debug
        outputs = self.start(inputs) + conditions
        outputs = self.end(outputs)
        residual_outputs = self.residual(inputs)
        return outputs + residual_outputs


In [16]:
class ConditionalDBlocks(nn.Module):
    def __init__(self, in_channels, lc_channels, factors=(2, 2, 2), out_channels=(128, 256)):
        super(ConditionalDBlocks, self).__init__()
        assert len(factors) == len(out_channels) + 1
        self.in_channels = in_channels
        self.lc_channels = lc_channels
        self.factors = factors
        self.out_channels = out_channels
        self.layers = nn.ModuleList()
        self.layers.append(DBlock(in_channels, 64, 1))
        in_channels = 64
        for i, channel in enumerate(out_channels):
            self.layers.append(DBlock(in_channels, channel, factors[i]))
            in_channels = channel
        self.cond_layer = CondDBlock(in_channels, lc_channels, factors[-1])
        
        # New adjustment layer to match post_process input
        self.adjust_channels = nn.Conv1d(in_channels, 512, kernel_size=1)

        self.post_process = nn.ModuleList([
            DBlock(512, 512, 1),
            DBlock(512, 512, 1)
        ])

    def forward(self, inputs, conditions):
        batch_size = inputs.size(0)
        # print(f"Initial inputs shape: {inputs.shape}, conditions shape: {conditions.shape}")  # Debug

        # Reshape inputs
        outputs = inputs.view(batch_size, self.in_channels, -1)
        # print(f"After view: {outputs.shape}")  # Debug

        # Process through layers
        for i, layer in enumerate(self.layers):
            outputs = layer(outputs)
            # print(f"After layer {i}: {outputs.shape}")  # Debug

        # Adjust conditions to match temporal dimension
        conditions = F.adaptive_avg_pool1d(conditions, output_size=1)  # Reduce temporal dimension to 1
        # print(f"Conditions shape after pooling: {conditions.shape}")  # Debug

        conditions = conditions.expand(-1, self.lc_channels, outputs.size(-1))  # Expand to match outputs
        # print(f"Conditions shape after expand: {conditions.shape}")  # Debug

        # Apply cond_layer
        outputs = self.cond_layer(outputs, conditions)
        #print(f"After cond_layer: {outputs.shape}")  # Debug

        # Adjust channels
        outputs = self.adjust_channels(outputs)
        #print(f"After adjust_channels: {outputs.shape}")  # Debug

        # Post-process
        for i, layer in enumerate(self.post_process):
            outputs = layer(outputs)
            #print(f"After post_process layer {i}: {outputs.shape}")  # Debug

        return outputs


In [17]:
class UnConditionalDBlocks(nn.Module):
    def __init__(self, in_channels, factors=(5, 3), out_channels=(128, 256)):
        super(UnConditionalDBlocks, self).__init__()
        self.in_channels = in_channels
        self.factors = factors
        self.out_channels = out_channels
        self.layers = nn.ModuleList()
        self.layers.append(DBlock(in_channels, 64, 1))
        in_channels = 64
        for (i, factor) in enumerate(factors):
            self.layers.append(DBlock(in_channels, out_channels[i], factor))
            in_channels = out_channels[i]
        self.layers.append(DBlock(in_channels, in_channels, 1))
        self.layers.append(DBlock(in_channels, in_channels, 1))

    def forward(self, inputs):
        batch_size = inputs.size()[0]
        outputs = inputs.view(batch_size, self.in_channels, -1)
        for layer in self.layers:
            outputs = layer(outputs)
        return outputs

In [18]:
import os
import torch
import torch.optim as optim
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

class Encoder(nn.Module):
    def __init__(self, audio_length):
        super(Encoder, self).__init__()
        self.audio_length = audio_length
        self.model = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=5, stride=2, padding=2),  # Downsample
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),  # Downsample
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),  # Downsample
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool1d(1)  # Reduce to (batch_size, 128, 1)
        )

    def forward(self, x):
        return self.model(x).squeeze(-1)  # Shape: (batch_size, 128)


def orthogonal_loss(features1, features2):
    """Compute orthogonal loss between two feature sets."""
    return torch.mean(torch.sum(features1 * features2, dim=1)**2)

# train loop

In [None]:
def train_gan_with_pretrained_generators(
    pretrained_generator, num_epochs, z_dim, lr_gen, lr_disc, batch_size, train_dataset,
    num_generators, seed, audio_length, output_dir, lambda_gp=10, lambda_ortho=0.1, num_critic=5,
    checkpoint_path="checkpoint.pth", resume=False
):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize multiple generators with pretrained weights
    generators = []
    for _ in range(num_generators):
        gen = Generator(z_dim, audio_length).to(device)
        gen.load_state_dict(pretrained_generator.state_dict())
        generators.append(gen)

    # Initialize Discriminator and Encoder
    discriminator = Multiple_Random_Window_Discriminators(lc_channels=1).to(device)
    encoder = Encoder(audio_length).to(device)

    # Optimizers
    optimizer_gens = [optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.5, 0.9)) for gen in generators]
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))
    optimizer_encoder = optim.Adam(encoder.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    # Mixed precision scalers
    scaler_gens = [GradScaler() for _ in range(num_generators)]
    scaler_disc = GradScaler()
    scaler_encoder = GradScaler()

    # Load checkpoint if resuming
    start_epoch = 0
    if resume and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch']
        for idx, gen in enumerate(generators):
            gen.load_state_dict(checkpoint['generator_state_dicts'][idx])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        for idx, opt_gen in enumerate(optimizer_gens):
            opt_gen.load_state_dict(checkpoint['optimizer_gen_state_dicts'][idx])
        optimizer_disc.load_state_dict(checkpoint['optimizer_disc_state_dict'])
        optimizer_encoder.load_state_dict(checkpoint['optimizer_encoder_state_dict'])
        print(f"Resumed training from epoch {start_epoch}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for epoch in range(start_epoch, num_epochs):
        torch.cuda.empty_cache()
        for batch_idx, (real_audio, features) in enumerate(train_loader):
            real_audio = real_audio.to(device)
            features = features.unsqueeze(1).to(device)
            batch_size = real_audio.size(0)

            # Train Discriminator
            for _ in range(num_critic):
                optimizer_disc.zero_grad()
                noises = [generate_noise(batch_size, z_dim, device) for _ in range(num_generators)]

                with autocast(device_type='cuda'):
                    fakes = [gen(features, noises[idx]).detach() for idx, gen in enumerate(generators)]
                    real_outputs, fake_outputs = discriminator(real_audio, torch.stack(fakes), features)
                    loss_disc = sum(torch.mean(fake) - torch.mean(real) for real, fake in zip(real_outputs, fake_outputs))
                    gradient_penalty = compute_gradient_penalty(discriminator, real_audio, fakes[0], features, device)
                    loss_disc += lambda_gp * gradient_penalty

                scaler_disc.scale(loss_disc).backward()
                scaler_disc.step(optimizer_disc)
                scaler_disc.update()

            # Train Generators and Encoder
            for idx, gen in enumerate(generators):
                optimizer_gens[idx].zero_grad()
                optimizer_encoder.zero_grad()

                noise = generate_noise(batch_size, z_dim, device)
                with autocast(device_type='cuda'):
                    fake = gen(features, noise)
                    fake_outputs = discriminator(fake, fake, features)[1]
                    loss_gen = -torch.mean(fake_outputs)

                    # Orthogonal loss
                    gen_feature = encoder(fake)
                    ortho_loss = 0
                    for other_idx, other_gen in enumerate(generators):
                        if idx != other_idx:
                            other_noise = generate_noise(batch_size, z_dim, device)
                            other_fake = other_gen(features, other_noise)
                            other_feature = encoder(other_fake)
                            ortho_loss += orthogonal_loss(gen_feature, other_feature)
                    ortho_loss /= (num_generators - 1)
                    total_loss_gen = loss_gen + lambda_ortho * ortho_loss

                scaler_gens[idx].scale(total_loss_gen).backward()
                scaler_gens[idx].step(optimizer_gens[idx])
                scaler_gens[idx].update()
                scaler_encoder.step(optimizer_encoder)
                scaler_encoder.update()

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'generator_state_dicts': [gen.state_dict() for gen in generators],
            'discriminator_state_dict': discriminator.state_dict(),
            'encoder_state_dict': encoder.state_dict(),
            'optimizer_gen_state_dicts': [opt.state_dict() for opt in optimizer_gens],
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

    print("Training complete.")
    return generators

In [20]:
def pad_tensors_to_match(tensor_list):
    # Find the maximum size in each dimension
    max_shape = list(max(tensor.size(dim) for tensor in tensor_list) for dim in range(tensor_list[0].dim()))

    padded_tensors = []
    for tensor in tensor_list:
        # Calculate the padding needed for each dimension
        pad = []
        for dim, max_dim in enumerate(max_shape[::-1]):
            pad.extend([0, max_dim - tensor.size(dim)])
        pad = pad[::-1]  # Reverse padding list
        # Apply padding
        padded_tensor = torch.nn.functional.pad(tensor, pad)
        padded_tensors.append(padded_tensor)
    
    return torch.stack(padded_tensors, dim=0)


In [32]:
from torch.amp import GradScaler, autocast
def pretrain_single_generator(
    num_epochs, z_dim, lr_gen, lr_disc, batch_size, seed,
    audio_length, output_dir, train_dataset, checkpoint_path="checkpoint.pth",
    resume=False
):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define models and move to device
    generator = Generator(in_channels=train_dataset[0][1].shape[0], z_channels=z_dim).to(device)
    discriminator = Multiple_Random_Window_Discriminators(lc_channels=train_dataset[0][1].shape[0]).to(device)

    optimizer_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.9))
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    # Mixed precision scaler
    scaler_gen = GradScaler()
    scaler_disc = GradScaler()

    start_epoch = 0
    if resume and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
        optimizer_disc.load_state_dict(checkpoint['optimizer_disc_state_dict'])
        scaler_gen.load_state_dict(checkpoint['scaler_gen'])
        scaler_disc.load_state_dict(checkpoint['scaler_disc'])
        start_epoch = checkpoint['epoch']
        print(f"Resuming training from epoch {start_epoch}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for epoch in range(start_epoch, num_epochs):
        torch.cuda.empty_cache()
        generator.train()
        discriminator.train()

        for batch_idx, (real_audio, features) in enumerate(train_loader):
            real_audio = real_audio.to(device)
            features = features.unsqueeze(1).to(device)
            batch_size = real_audio.size(0)

            # Train Discriminator
            for _ in range(5):
                optimizer_disc.zero_grad()
                noise = generate_noise(batch_size, z_dim, device)

                with autocast(device_type='cuda'):
                    # Detach fake_audio to avoid grad on generator when training disc
                    fake_audio = generator(features, noise).detach()
                    real_outputs, fake_outputs = discriminator(real_audio, fake_audio, features)
                    loss_disc = sum(torch.mean(fake) - torch.mean(real) for real, fake in zip(real_outputs, fake_outputs))
                    gradient_penalty = compute_gradient_penalty(discriminator, real_audio, fake_audio, features, device)
                    loss_disc += 10 * gradient_penalty

                scaler_disc.scale(loss_disc).backward()
                scaler_disc.step(optimizer_disc)
                scaler_disc.update()

            # Train Generator
            optimizer_gen.zero_grad()
            noise = generate_noise(batch_size, z_dim, device)

            with autocast(device_type='cuda'):
                fake_audio = generator(features, noise)
                fake_outputs = discriminator(fake_audio, fake_audio, features)[1]
                # Adjust if needed, e.g. if fake_outputs is a list of tensors
                # Ensure pad_tensors_to_match is defined and imports are correct
                fake_outputs = pad_tensors_to_match(fake_outputs)
                loss_gen = -torch.mean(fake_outputs)

            scaler_gen.scale(loss_gen).backward()
            scaler_gen.step(optimizer_gen)
            scaler_gen.update()

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        # Save model output
        visualize_and_save_generated_waveforms(
            [generator], z_dim, features, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
        )

        # Save checkpoint
        checkpoint = {
            'epoch': epoch+1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'scaler_gen': scaler_gen.state_dict(),
            'scaler_disc': scaler_disc.state_dict(),
            'z_dim': z_dim,
            'lr_gen': lr_gen,
            'lr_disc': lr_disc,
            'batch_size': batch_size,
            'seed': seed,
            'audio_length': audio_length,
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

        torch.save(generator.state_dict(), os.path.join(output_dir, f"pretrained_generator_epoch{epoch+1}.pth"))

    print("Pretraining complete.")
    return generator


In [22]:
# def pretrain_single_generator(num_epochs, z_dim, lr_gen, lr_disc, batch_size, seed, audio_length, output_dir, train_dataset):
#     set_seed(seed)
#     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # # Define the single generator and discriminator
#     # generator = Generator(in_channels=1, z_channels=z_dim).to(device)

#     # discriminator = Multiple_Random_Window_Discriminators(lc_channels=1)

# # In CondDBlock (and related blocks), ensure lc_conv1d is defined as:
    


    

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
#     # Move models to device
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     generator = Generator(in_channels=train_dataset[0][1].shape[0], z_channels=z_dim).to(device)
#     discriminator = Multiple_Random_Window_Discriminators(lc_channels=train_dataset[0][1].shape[0]).to(device)
#     optimizer_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.9))
#     optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

# # Training loop

#     for epoch in range(num_epochs):
#         torch.cuda.empty_cache()
#         for batch_idx, (real_audio, features) in enumerate(train_loader):
#             real_audio = real_audio.to(device)
#             features = features.unsqueeze(1).to(device)
#             print(f"real_audio.shape: {real_audio.shape}")
#             print(f"features.shape: {features.shape}")
            


#             batch_size = real_audio.size(0)

#             # Train Discriminator
#             for _ in range(5):
#                 optimizer_disc.zero_grad()
#                 noise = generate_noise(batch_size, z_dim, device)
#                 fake_audio = generator(features, noise).detach()
#                 real_outputs, fake_outputs = discriminator(real_audio, fake_audio, features)
#                 loss_disc = sum([torch.mean(fake) - torch.mean(real) for real, fake in zip(real_outputs, fake_outputs)])
#                 gradient_penalty = compute_gradient_penalty(discriminator, real_audio, fake_audio, features, device)
#                 loss_disc += 10 * gradient_penalty
#                 loss_disc.backward()
#                 optimizer_disc.step()

#             # Train Generator
#             optimizer_gen.zero_grad()
#             noise = generate_noise(batch_size, z_dim, device)
#             fake_audio = generator(features, noise)
#             print(f"fake_audio.shape: {fake_audio.shape}")
#             fake_outputs = discriminator(fake_audio, fake_audio, features)[1]
#             fake_outputs = pad_tensors_to_match(fake_outputs)  # Ensure uniform tensor size
#             loss_gen = -torch.mean(fake_outputs)
#             loss_gen.backward()
#             optimizer_gen.step()

#         print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

#         visualize_and_save_generated_waveforms(
#             [generator], z_dim, features, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
#         )
#         torch.save(generator.state_dict(), os.path.join(output_dir, f"pretrained_generator_epoch{epoch+1}.pth"))

#     print("Pretraining complete.")
#     return generator


In [23]:
set_seed(42)
audio_length = 64000
z_dim = 128
lr_gen = 0.0002
lr_disc = 0.0002
batch_size = 4
num_epochs = 50
root_dir = "./data"
sample_rate = 16000
num_generators = 5
output_dir = 'generated_audio'

In [24]:
train_dataset = preprocess_dataset(root_dir, target_sr=sample_rate, target_length=audio_length)

Found 2703 .flac files in ./data.


  audio_dataset = torch.tensor(audio_dataset, dtype=torch.float32).unsqueeze(1)  # Add channel dimension


In [28]:
import torch
import gc

def clear_all_memory():
    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Clear the cache
        torch.cuda.synchronize()  # Ensure all CUDA operations are finished
        print("Cleared GPU memory cache.")

    # Delete all tensors
    gc.collect()  # Run garbage collection to free up Python memory
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
                del obj
        except Exception as e:
            pass  # Skip over objects that can't be deleted

    # Re-run garbage collection and clear the cache again
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("Final GPU cleanup complete.")

# Call the function to clear all memory
clear_all_memory()


Cleared GPU memory cache.


  return isinstance(obj, torch.Tensor)
  if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):


Final GPU cleanup complete.


In [33]:
pretrained_generator = pretrain_single_generator(
        num_epochs=20,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        seed=42,
        audio_length=audio_length,
        output_dir='waveform_pre',
        train_dataset=train_dataset,
        checkpoint_path="gan_single_check.pth"
    )

  return fn(*args, **kwargs)


Epoch [1/20] Loss D: 0.0220, Loss G: 0.0042




Saved waveform_pre/epoch1_gen1_sample1.wav
Saved waveform_pre/epoch1_gen1_sample2.wav
Saved waveform_pre/epoch1_gen1_sample3.wav
Checkpoint saved at gan_single_check.pth
Epoch [2/20] Loss D: 0.0197, Loss G: -0.0023
Saved waveform_pre/epoch2_gen1_sample1.wav
Saved waveform_pre/epoch2_gen1_sample2.wav
Saved waveform_pre/epoch2_gen1_sample3.wav
Checkpoint saved at gan_single_check.pth
Epoch [3/20] Loss D: 0.0056, Loss G: -0.0001
Saved waveform_pre/epoch3_gen1_sample1.wav
Saved waveform_pre/epoch3_gen1_sample2.wav
Saved waveform_pre/epoch3_gen1_sample3.wav
Checkpoint saved at gan_single_check.pth
Epoch [4/20] Loss D: 0.0110, Loss G: -0.0026


OSError: [Errno 5] Input/output error: 'waveform_pre'

In [None]:
train_gan_with_pretrained_generators(
        pretrained_generator,
        num_epochs=num_epochs,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        train_dataset=train_dataset,
        num_generators=num_generators,
        seed=42,
        audio_length=audio_length,
        output_dir=output_dir,
        checkpoint_dir='my_checkpoints',
        resume=False
    )