## Import Statements

```markdown
# Import the libraries and modules required by the project
#
# This section imports the Python libraries and modules used for:
# - File and directory operations (os)
# - Random number generation (random, numpy)
# - Mathematical operations (math, numpy)
# - Deep learning with PyTorch (torch and its submodules)
# - Audio processing (soundfile, librosa)
# - File globbing (glob)
# - Text-to-phoneme conversion (phonemizer)
# - Data loading and processing (torch.utils.data)
# - Spectral normalization of layers (torch.nn.utils.spectral_norm)
# - Gradient checkpointing (torch.utils.checkpoint)
# - Plotting (matplotlib)
# - Garbage collection (gc)
# - Logging (logging)
# - Automatic mixed precision training (torch.amp)
```


In [1]:


import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import soundfile as sf
import librosa
from glob import glob
from phonemizer import phonemize
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils import spectral_norm
import torch.utils.checkpoint as checkpoint
import matplotlib.pyplot as plt
import gc
import logging
from torch.amp import GradScaler, autocast

###################################
# Utility and Setup Functions
###################################

In [2]:
def set_seed(seed):
    """Seed every RNG used by the project and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA
    devices when CUDA is available), then disables cuDNN autotuning so runs
    with the same seed pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def load_flac(file_path, target_sr=16000):
    """Load an audio file with ``librosa.load`` resampled to ``target_sr``.

    The samples are peak-normalised so the largest absolute value is 1; an
    all-zero (silent) signal is returned unchanged to avoid dividing by zero.
    """
    samples, _ = librosa.load(file_path, sr=target_sr)
    peak = np.max(np.abs(samples))
    return samples / peak if peak > 0 else samples

def pad_or_trim(audio, target_length=64000):
    """Return ``audio`` right-padded with zeros or truncated to ``target_length``."""
    shortfall = target_length - len(audio)
    if shortfall > 0:
        return np.pad(audio, (0, shortfall))
    return audio[:target_length]

def get_transcription(file_path):
    """Return the transcript text for the utterance stored in ``file_path``.

    The transcript is looked up in a ``*.trans.txt`` file in the same directory
    as ``file_path``. Each line of that file is ``<utterance-id> <text>``, and the
    utterance id is the audio file name without its extension. A line that has
    an id but no text yields ``""``. The caller decides whether an empty
    transcript is an error.

    Raises:
        FileNotFoundError: If the directory contains no ``*.trans.txt`` file.
        ValueError: If no line matches the utterance id.
    """
    # An empty dirname means the file is in the current working directory.
    dir_path = os.path.dirname(file_path) or "."
    file_id = os.path.splitext(os.path.basename(file_path))[0]

    # Sort so the choice is deterministic if a directory has more than one.
    candidates = sorted(
        name for name in os.listdir(dir_path) if name.endswith(".trans.txt")
    )
    if not candidates:
        raise FileNotFoundError(f"No transcription file found in {dir_path}")
    transcription_file = os.path.join(dir_path, candidates[0])

    with open(transcription_file, "r", encoding="utf-8") as f:
        for line in f:
            # partition() never raises, unlike indexing split()[1] on id-only lines.
            utterance_id, _, text = line.strip().partition(" ")
            if utterance_id == file_id:
                return text
    raise ValueError(f"No transcription found for file {file_id}")

def save_waveform_to_audio(waveform, sample_rate, filename):
    """Peak-normalise ``waveform`` and write it to ``filename`` with soundfile.

    ``waveform`` may be a torch tensor (it is detached and moved to the CPU) or
    an array-like. Singleton dimensions are squeezed away before writing. A
    silent signal is written unchanged.
    """
    data = waveform.detach().cpu().numpy() if isinstance(waveform, torch.Tensor) else waveform
    data = np.squeeze(data)
    peak = np.max(np.abs(data))
    if peak > 0:
        data = data / peak
    sf.write(filename, data, sample_rate)

def generate_noise(batch_size, z_dim, device):
    """Return a ``[batch_size, z_dim]`` standard-normal noise tensor on ``device``."""
    shape = (batch_size, z_dim)
    return torch.randn(shape, device=device)

def orthogonal_loss(feature1, feature2):
    """Mean squared cosine similarity between matching rows of two feature sets.

    Minimising this pushes each pair of row vectors towards orthogonality. The
    ``1e-8`` term in the denominator guards against zero-norm rows.
    """
    dots = (feature1 * feature2).sum(dim=1)
    denom = feature1.norm(dim=1) * feature2.norm(dim=1) + 1e-8
    return (dots / denom).pow(2).mean()

def compute_gradient_penalty(discriminator, real_samples, fake_samples, conditions, device):
    """WGAN-GP gradient penalty averaged over every critic of ``discriminator``.

    Random interpolates between ``real_samples`` and ``fake_samples`` are
    scored by each sub-critic (the first list returned by ``discriminator``).
    For each critic the penalty is ``mean((||d critic / d x||_2 - 1)^2)`` over the
    batch, and the critic penalties are averaged. Previously only the first
    critic was penalised, while the WGAN loss sums all of them, so the other
    critics were left unconstrained.

    Args:
        discriminator: Callable ``(real, fake, conditions) -> (real_outputs, fake_outputs)``
            where ``real_outputs`` is a list of per-critic tensors.
        real_samples / fake_samples: Tensors of identical shape; the mixing
            coefficient is drawn per example as ``[B, 1, 1]``, so 3-D input is assumed.
        conditions: Passed through to ``discriminator`` unchanged.
        device: Device for the random mixing coefficients.

    Returns:
        A scalar float32 tensor that stays in the autograd graph (``create_graph=True``).
    """
    batch_size = real_samples.size(0)
    # One mixing coefficient per example, broadcast over channels and time.
    epsilon = torch.rand(batch_size, 1, 1, device=device).expand_as(real_samples)
    interpolates = (epsilon * real_samples + (1 - epsilon) * fake_samples).requires_grad_(True)

    critic_outputs, _ = discriminator(interpolates, fake_samples, conditions)

    penalties = []
    for critic_output in critic_outputs:
        gradients = torch.autograd.grad(
            outputs=critic_output,
            inputs=interpolates,
            grad_outputs=torch.ones_like(critic_output),
            create_graph=True,   # the penalty itself must be differentiable
            retain_graph=True,   # the graph is reused for the next critic
        )[0]
        # Compute the norm in float32: under fp16 autocast the sum of squares
        # over a whole waveform can overflow.
        grad_norm = gradients.float().reshape(batch_size, -1).norm(2, dim=1)
        penalties.append(((grad_norm - 1) ** 2).mean())
    return torch.stack(penalties).mean()

def visualize_and_save_generated_waveforms(generators, z_dim, features_emb, num_waveforms, device, epoch, sample_rate=16000, output_dir='generated_audio'):
    """Generate up to ``num_waveforms`` samples per generator and save them as WAV files.

    The first ``num_waveforms`` rows of ``features_emb`` are used as conditions.
    Files are named ``epoch{epoch+1}_gen{g}_sample{s}.wav`` (1-based) inside
    ``output_dir``, which is created if needed. Each generator is put in eval
    mode for sampling and then restored to its previous mode, so calling this
    in the middle of training does not leave a generator in eval mode.

    Args:
        generators: Iterable of generator modules called as ``gen(conditions, noise)``.
        z_dim: Noise dimension passed to ``generate_noise``.
        features_emb: Embedded conditions, batch-first.
        num_waveforms: Maximum number of samples per generator.
        device: Device for the noise.
        epoch: 0-based epoch index (written as ``epoch+1`` in file names).
        sample_rate: Sample rate stored in the WAV files.
        output_dir: Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)
    conditions = features_emb[:num_waveforms]
    num_available = conditions.size(0)
    if num_available < num_waveforms:
        logger.info(f"Warning: Requested {num_waveforms} waveforms, got {num_available}")

    for idx, gen in enumerate(generators):
        was_training = gen.training
        gen.eval()
        try:
            with torch.no_grad():
                # Noise batch matches the number of conditions actually available.
                noise = generate_noise(num_available, z_dim, device)
                fake_waveforms = gen(conditions, noise).cpu()
        finally:
            gen.train(was_training)

        for i, waveform in enumerate(fake_waveforms):
            filename = f'epoch{epoch+1}_gen{idx+1}_sample{i+1}.wav'
            filepath = os.path.join(output_dir, filename)
            save_waveform_to_audio(waveform, sample_rate, filepath)
            logger.info(f"Saved {filepath}")

def pad_tensors_to_match(tensor_list):
    """Zero-pad 3-D tensors to a common channel count and length, then stack them.

    Each tensor is padded at the end of dim 1 (channels) and of the last dim
    (time) up to the maxima over ``tensor_list``. The results are stacked
    along a new leading dim. Batch sizes must already agree.
    """
    target_len = max(t.size(-1) for t in tensor_list)
    target_ch = max(t.size(1) for t in tensor_list)

    def _pad(t):
        extra_len = target_len - t.size(-1)
        extra_ch = target_ch - t.size(1)
        if extra_len <= 0 and extra_ch <= 0:
            return t
        # F.pad pairs apply from the last dim backwards: (time), then (channels).
        return F.pad(t, (0, extra_len, 0, extra_ch))

    return torch.stack([_pad(t) for t in tensor_list], dim=0)

def clear_all_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory.

    Collection runs first, so tensors that are only kept alive by reference
    cycles are freed before the CUDA caching allocator is asked to return
    unused blocks to the driver. The previous loop over ``gc.get_objects()``
    with ``del obj`` only unbound a loop variable, so it freed nothing and was
    removed. This can only release memory that is no longer referenced.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        logger.info("Cleared GPU memory cache.")


###################################
# Custom Dataset
###################################

In [3]:
class MyAudioDataset(Dataset):
    """Map-style dataset of ``(audio, phoneme_ids)`` pairs built from ``*.flac`` files.

    All ``*.flac`` files under ``root_dir`` are found recursively. Every
    transcript (see ``get_transcription``) is phonemized once in ``__init__``
    with the espeak backend. The result builds a symbol-to-id vocabulary and is
    cached, so ``__getitem__`` does not re-run the phonemizer for every sample
    in every epoch.

    Items:
        audio: float32 tensor ``[1, target_length]`` (peak-normalised, zero-padded
            or truncated; assumes ``load_flac`` returns 1-D samples).
        phoneme_ids: long tensor ``[feature_length]`` of per-symbol ids, zero-padded.

    NOTE(review): id 0 is both the padding value and the id of the first
    (sorted) vocabulary symbol, so padding is indistinguishable from that symbol.
    Reserving 0 for padding would change ``vocab_size`` and break existing
    checkpoints, so it is left as is.
    """

    def __init__(self, root_dir, target_sr=16000, target_length=64000, feature_length=100):
        self.root_dir = root_dir
        self.target_sr = target_sr
        self.target_length = target_length
        self.feature_length = feature_length
        self.flac_files = glob(os.path.join(root_dir, '**', '*.flac'), recursive=True)
        logger.info(f"Found {len(self.flac_files)} .flac files in {root_dir}.")
        if len(self.flac_files) == 0:
            raise ValueError("No .flac files found.")

        # file path -> phonemized transcript. Only non-empty transcripts are
        # cached, so empty ones still take the error path in __getitem__.
        self._phoneme_cache = {}
        all_phonemes = set()
        for f in self.flac_files:
            try:
                transcription = get_transcription(f)
                phonemes = phonemize(transcription, backend="espeak", language="en-us")
            except Exception as e:
                logger.warning(f"Error processing file {f}: {e}")
                continue
            # The phonemizer output is iterated symbol by symbol (one character each).
            all_phonemes.update(phonemes)
            if transcription.strip():
                self._phoneme_cache[f] = phonemes

        self.phoneme_to_id = {p: i for i, p in enumerate(sorted(all_phonemes))}
        logger.info(f"Phoneme vocabulary size: {len(self.phoneme_to_id)}")

    def __len__(self):
        return len(self.flac_files)

    def __getitem__(self, idx):
        file = self.flac_files[idx]
        try:
            audio = load_flac(file, self.target_sr)
            audio = pad_or_trim(audio, self.target_length)

            phonemes = self._phoneme_cache.get(file)
            if phonemes is None:
                # Not cached: the transcript failed or was empty during the scan.
                transcription = get_transcription(file)
                if not transcription or not transcription.strip():
                    raise ValueError(f"Empty or invalid transcription for file: {file}")
                phonemes = phonemize(transcription, backend="espeak", language="en-us")

            # A symbol missing from the vocabulary raises KeyError -> dummy sample below.
            phonetic_features = torch.tensor(
                [self.phoneme_to_id[p] for p in phonemes], dtype=torch.long
            )
            if len(phonetic_features) < self.feature_length:
                phonetic_features = F.pad(phonetic_features, (0, self.feature_length - len(phonetic_features)))
            else:
                phonetic_features = phonetic_features[:self.feature_length]

            audio = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # [1, target_length]
            return audio, phonetic_features
        except Exception as e:
            # Best effort: a broken file yields an all-zero sample instead of
            # killing the DataLoader worker.
            logger.warning(f"Error reading {file}: {e}")
            dummy_audio = torch.zeros((1, self.target_length), dtype=torch.float32)
            dummy_features = torch.zeros(self.feature_length, dtype=torch.long)
            return dummy_audio, dummy_features

###################################
# Models
###################################

In [4]:
class UpsampleNet(nn.Module):
    """Learned 1-D upsampler built on a spectrally normalised transposed convolution.

    Maps ``[B, input_size, T]`` to ``[B, output_size, T * upsample_factor]``.
    The transposed conv uses kernel ``2 * factor``, stride ``factor`` and
    padding ``factor // 2``. It can overshoot the target length by one sample
    (odd factors), so the output is trimmed.
    """

    def __init__(self, input_size, output_size, upsample_factor):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.upsample_factor = upsample_factor

        deconv = nn.ConvTranspose1d(
            input_size,
            output_size,
            2 * upsample_factor,
            upsample_factor,
            padding=upsample_factor // 2,
        )
        nn.init.orthogonal_(deconv.weight)
        self.layer = spectral_norm(deconv)

    def forward(self, inputs):
        target_len = self.upsample_factor * inputs.size(-1)
        upsampled = self.layer(inputs)
        return upsampled[:, :, :target_len]


class GBlock(nn.Module):
    """Residual upsampling block of the generator.

    Upsamples the time axis of ``condition`` by ``upsample_factor`` and maps
    ``in_channels`` to ``hidden_channels`` through two residual sub-blocks. The
    second sub-block uses dilated convolutions. Activations are recomputed
    during backward (gradient checkpointing) to save memory.

    ``nn.GroupNorm(32, ...)`` requires ``in_channels`` and ``hidden_channels``
    to be divisible by 32.

    NOTE(review): ``z`` / ``z_channels`` are accepted but never used, and the
    ``condition_norm*`` layers are plain GroupNorms that are not conditioned on
    anything. The output therefore depends only on ``condition``.
    """

    def __init__(self, in_channels, hidden_channels, z_channels, upsample_factor):
        super(GBlock, self).__init__()
        self.condition_norm1 = nn.GroupNorm(32, in_channels)
        self.first_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1)
        )
        self.condition_norm2 = nn.GroupNorm(32, hidden_channels)
        self.second_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=1, padding=1)
        )
        # Skip path: same upsampling, 1x1 projection to hidden_channels.
        self.residual1 = nn.Sequential(
            UpsampleNet(in_channels, in_channels, upsample_factor),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=1)
        )
        self.condition_norm3 = nn.GroupNorm(32, hidden_channels)
        self.third_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=2, padding=2)
        )
        self.condition_norm4 = nn.GroupNorm(32, hidden_channels)
        self.fourth_stack = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=2, padding=2)
        )

    def _run_forward(self, inputs):
        """Un-checkpointed forward computation of the block."""
        outputs = self.condition_norm1(inputs)
        outputs = self.first_stack(outputs)
        outputs = self.condition_norm2(outputs)
        outputs = self.second_stack(outputs)
        residual_outputs = self.residual1(inputs) + outputs
        outputs = self.condition_norm3(residual_outputs)
        outputs = self.third_stack(outputs)
        outputs = self.condition_norm4(outputs)
        outputs = self.fourth_stack(outputs)
        return outputs + residual_outputs

    def forward(self, condition, z):
        if not torch.is_grad_enabled():
            # Nothing will be backpropagated, so checkpointing would save nothing.
            return self._run_forward(condition)
        # Non-reentrant checkpointing lets parameter gradients flow even when
        # ``condition`` itself does not require grad. The reentrant variant drops
        # them. PyTorch asks for ``use_reentrant`` to be passed explicitly.
        return checkpoint.checkpoint(self._run_forward, condition, use_reentrant=False)

class Generator(nn.Module):
    """Phoneme-conditioned waveform generator.

    ``forward`` expects already-embedded conditions ``[B, embedding_dim, T]``.
    They go through a 3-tap conv to 768 channels, five ``GBlock``s (upsampling
    factors 5, 4, 4, 4, 2) and a final conv + Tanh to one channel. The output
    is presumably ``[B, 1, T * 640]``, assuming each GBlock upsamples by exactly
    its factor.

    NOTE(review): ``self.embedding`` is created (and stored in the state dict)
    but ``forward`` never uses it; callers embed the phoneme ids themselves.
    """

    # (in_channels, hidden_channels, upsample_factor) for each GBlock, in order.
    _GBLOCK_PLAN = (
        (768, 768, 5),
        (768, 768, 4),
        (768, 384, 4),
        (384, 384, 4),
        (384, 192, 2),
    )

    def __init__(self, vocab_size, embedding_dim=64, z_channels=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        nn.init.normal_(self.embedding.weight, 0.0, 0.1)

        self.preprocess = nn.Conv1d(embedding_dim, 768, kernel_size=3, padding=1)
        self.gblocks = nn.ModuleList(
            [GBlock(c_in, c_hidden, z_channels, factor) for c_in, c_hidden, factor in self._GBLOCK_PLAN]
        )
        self.postprocess = nn.Sequential(
            nn.Conv1d(192, 1, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, conditions_emb, z):
        hidden = self.preprocess(conditions_emb)
        for block in self.gblocks:
            hidden = block(hidden, z)
        return self.postprocess(hidden)

class DBlock(nn.Module):
    """Residual downsampling block used by the discriminators.

    Main path: average-pool by ``downsample_factor``, then two ReLU+Conv1d
    stages (the second dilated). Skip path: 1x1 conv, then the same average
    pool. Output is ``[B, out_channels, T // downsample_factor]``.
    """

    def __init__(self, in_channels, out_channels, downsample_factor):
        super().__init__()
        self.layers = nn.Sequential(
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
        )
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
        )

    def forward(self, inputs):
        main = self.layers(inputs)
        skip = self.residual(inputs)
        return main + skip

class CondDBlock(nn.Module):
    """Residual block that injects a conditioning signal into discriminator features.

    ``conditions`` (``lc_channels`` wide) is projected with a 1x1 conv to
    ``in_channels`` and added after the first conv. A 1x1-conv skip connection
    is added at the end. ``conditions`` must broadcast against the features in
    time. ``upsample_factor`` is accepted but unused.
    """

    def __init__(self, in_channels, lc_channels, upsample_factor):
        super().__init__()
        self.lc_conv1d = nn.Conv1d(lc_channels, in_channels, kernel_size=1)
        self.start = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.end = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.residual = nn.Conv1d(in_channels, in_channels, kernel_size=1)

    def forward(self, inputs, conditions):
        projected = self.lc_conv1d(conditions)
        mixed = self.start(inputs) + projected
        skip = self.residual(inputs)
        return self.end(mixed) + skip

class ConditionalDBlocks(nn.Module):
    """Window critic that also sees the phoneme conditioning, pooled over time.

    The input window is reshaped to ``[B, in_channels, -1]`` and passed through
    ``DBlock``s. The first block keeps the length and maps to 64 channels; then
    one block per entry of ``out_channels`` downsamples by ``factors[i]``. The
    conditioning is average-pooled to a single frame, broadcast over time, and
    mixed in by ``CondDBlock``. Features are then projected to 512 channels and
    refined by two more ``DBlock``s.
    """

    def __init__(self, in_channels, lc_channels, factors=(2,2,2), out_channels=(128,256)):
        super().__init__()
        self.in_channels = in_channels
        self.lc_channels = lc_channels

        widths = [64, *out_channels]
        self.layers = nn.ModuleList([DBlock(in_channels, 64, 1)])
        for step in range(len(out_channels)):
            self.layers.append(DBlock(widths[step], widths[step + 1], factors[step]))
        final_width = widths[-1]

        self.cond_layer = CondDBlock(final_width, lc_channels, factors[-1])
        self.adjust_channels = nn.Conv1d(final_width, 512, kernel_size=1)
        self.post_process = nn.ModuleList([DBlock(512, 512, 1) for _ in range(2)])

    def forward(self, inputs, conditions):
        hidden = inputs.view(inputs.size(0), self.in_channels, -1)
        for block in self.layers:
            hidden = block(hidden)
        # [B, lc_channels, T'] -> mean over time -> repeated to the feature length.
        pooled = F.adaptive_avg_pool1d(conditions, 1).expand(-1, self.lc_channels, hidden.size(-1))
        hidden = self.adjust_channels(self.cond_layer(hidden, pooled))
        for block in self.post_process:
            hidden = block(hidden)
        return hidden

class UnConditionalDBlocks(nn.Module):
    """Window critic that looks at audio only.

    The input window is reshaped to ``[B, in_channels, -1]``. A length-preserving
    ``DBlock`` maps it to 64 channels. Then, for each ``factors[i]``, a block
    downsamples by that factor to ``out_channels[i]`` channels. Two final
    length-preserving blocks keep the last width.
    """

    def __init__(self, in_channels, factors=(5, 3), out_channels=(128, 256)):
        super().__init__()
        self.in_channels = in_channels
        blocks = [DBlock(in_channels, 64, 1)]
        width = 64
        for idx, factor in enumerate(factors):
            blocks.append(DBlock(width, out_channels[idx], factor))
            width = out_channels[idx]
        blocks.extend(DBlock(width, width, 1) for _ in range(2))
        self.layers = nn.ModuleList(blocks)

    def forward(self, inputs):
        hidden = inputs.view(inputs.size(0), self.in_channels, -1)
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class Multiple_Random_Window_Discriminators(nn.Module):
    """Bank of random-window critics: 5 unconditional plus 5 conditional.

    On every call, each sub-discriminator scores one randomly placed window of
    the real and of the fake waveform. Real and fake share the same window
    position. Window lengths are ``window_size[i] * upsample_factor`` samples.
    The conditional critics also receive the matching ``window_size[i]``
    frames of ``conditions``.

    With the defaults, every window is 240 samples per input channel
    (2*120/1 = 4*120/2 = ... = 30*120/15). The sub-critics presumably reshape
    their window to that layout; confirm in their ``forward``.

    NOTE(review): window positions come from NumPy's global RNG
    (``np.random``), so results depend on that RNG's state and call order.
    NOTE(review): ``upsample_factor`` (120 samples per conditioning frame) does
    not match the generator's total upsampling (5*4*4*4*2 = 640 per frame), so
    a conditional window's audio is not the audio generated from its ``lc`` frames.
    """

    def __init__(self, lc_channels, window_size=(2,4,8,16,30), upsample_factor=120):
        super(Multiple_Random_Window_Discriminators, self).__init__()
        self.lc_channels = lc_channels
        self.window_size = window_size  # window lengths in conditioning frames
        self.upsample_factor = upsample_factor  # audio samples per conditioning frame

        # Unconditional critics, one per window size (zipped in order in forward).
        self.udiscriminators = nn.ModuleList([
            UnConditionalDBlocks(in_channels=1, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=2, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=4, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=8, factors=(5, 3), out_channels=(128, 256)),
            UnConditionalDBlocks(in_channels=15, factors=(2, 2), out_channels=(128, 256)),
        ])

        # Conditional critics, one per window size (zipped in order in forward).
        self.discriminators = nn.ModuleList([
            ConditionalDBlocks(in_channels=1, lc_channels=lc_channels,
                               factors=(5, 3, 2, 2, 2), out_channels=(128, 128, 256, 256)),
            ConditionalDBlocks(in_channels=2, lc_channels=lc_channels,
                               factors=(5, 3, 2, 2), out_channels=(128, 256, 256)),
            ConditionalDBlocks(in_channels=4, lc_channels=lc_channels,
                               factors=(5, 3, 2), out_channels=(128, 256)),
            ConditionalDBlocks(in_channels=8, lc_channels=lc_channels,
                               factors=(5, 3), out_channels=(256,)),
            ConditionalDBlocks(in_channels=15, lc_channels=lc_channels,
                               factors=(2, 2, 2), out_channels=(128, 256)),
        ])

    def forward(self, real_samples, fake_samples, conditions):
        """Score random windows of ``real_samples`` and ``fake_samples``.

        Returns ``(real_outputs, fake_outputs)``, two lists with one tensor per
        sub-critic: the 5 unconditional outputs first, then the 5 conditional ones.
        ``np.random.randint`` raises if a signal is shorter than its window.
        """
        real_outputs, fake_outputs = [], []
        # Unconditional
        for (size, layer) in zip(self.window_size, self.udiscriminators):
            size = size * self.upsample_factor  # frames -> samples
            index = np.random.randint(0, real_samples.size(-1) - size + 1)
            real_slice = real_samples[:, :, index: index + size]
            fake_slice = fake_samples[:, :, index: index + size]
            real_output = layer(real_slice)
            fake_output = layer(fake_slice)
            real_outputs.append(real_output)
            fake_outputs.append(fake_output)

        # Conditional
        for (size, layer) in zip(self.window_size, self.discriminators):
            # Pick the window in conditioning frames, then map it to samples.
            lc_index = np.random.randint(0, conditions.size(-1) - size + 1)
            sample_index = lc_index * self.upsample_factor
            real_x = real_samples[:, :, sample_index: (lc_index + size)*self.upsample_factor]
            fake_x = fake_samples[:, :, sample_index: (lc_index + size)*self.upsample_factor]
            lc = conditions[:, :, lc_index: lc_index + size]
            real_output = layer(real_x, lc)
            fake_output = layer(fake_x, lc)
            real_outputs.append(real_output)
            fake_outputs.append(fake_output)

        return real_outputs, fake_outputs

class Encoder(nn.Module):
    """Small convolutional encoder mapping ``[B, 1, T]`` audio to ``[B, 128]`` features.

    Three stride-2 Conv1d + LeakyReLU(0.2) stages (1 -> 32 -> 64 -> 128
    channels) are followed by global average pooling over time.
    ``audio_length`` is stored but not used by the layers.
    """

    def __init__(self, audio_length):
        super().__init__()
        self.audio_length = audio_length
        stages = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv1d(c_in, c_out, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.2),
            ]
        stages.append(nn.AdaptiveAvgPool1d(1))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        pooled = self.model(x)  # [B, 128, 1]
        return pooled.squeeze(-1)



###################################
# Training Functions
###################################

In [5]:

def pretrain_single_generator(
    num_epochs, z_dim, lr_gen, lr_disc, batch_size, seed,
    audio_length, output_dir, train_dataset, checkpoint_path="checkpoint.pth",
    resume=False, vocab_size=100, embedding_dim=64
):
    """Pretrain one ``Generator`` against the multi-window WGAN-GP critic.

    Each batch runs five critic updates (WGAN loss plus gradient penalty
    weighted by 10), followed by one generator update. After every epoch:
    samples are written to ``output_dir``, a resumable checkpoint is written to
    ``checkpoint_path``, and the generator weights are saved on their own.

    Args:
        num_epochs: Total number of epochs, including any already done when resuming.
        z_dim: Noise size; also passed as the generator's ``z_channels``.
        lr_gen: Adam learning rate for the generator.
        lr_disc: Adam learning rate for the critic.
        batch_size: DataLoader batch size.
        seed: Seed passed to ``set_seed``.
        audio_length: Only recorded in the checkpoint.
        output_dir: Directory for audio samples and per-epoch generator weights.
        train_dataset: Dataset yielding ``(audio, phoneme_ids)`` pairs.
        checkpoint_path: Where the resumable checkpoint is written and read.
        resume: Resume from ``checkpoint_path`` if it exists.
        vocab_size: Number of rows of the phoneme embedding table.
        embedding_dim: Width of the phoneme embedding.

    Returns:
        ``(generator, phoneme_embedding)``. The embedding is never added to an
        optimizer, so it keeps its seeded random initialisation.

    Raises:
        ValueError: If ``train_dataset`` yields no batches.
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision only on CUDA; on CPU everything stays float32.
    use_amp = device.type == "cuda"

    phoneme_embedding = nn.Embedding(vocab_size, embedding_dim).to(device)
    nn.init.normal_(phoneme_embedding.weight, 0.0, 0.1)

    generator = Generator(vocab_size=vocab_size, embedding_dim=embedding_dim, z_channels=z_dim).to(device)
    discriminator = Multiple_Random_Window_Discriminators(lc_channels=embedding_dim).to(device)

    optimizer_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.9))
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    scaler_gen = GradScaler("cuda", enabled=use_amp)
    scaler_disc = GradScaler("cuda", enabled=use_amp)

    start_epoch = 0
    if resume and os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(state['generator_state_dict'])
        discriminator.load_state_dict(state['discriminator_state_dict'])
        optimizer_gen.load_state_dict(state['optimizer_gen_state_dict'])
        optimizer_disc.load_state_dict(state['optimizer_disc_state_dict'])
        # A disabled (CPU) scaler saves an empty dict, which an enabled scaler
        # refuses to load; skip empty or missing states.
        if state.get('scaler_gen'):
            scaler_gen.load_state_dict(state['scaler_gen'])
        if state.get('scaler_disc'):
            scaler_disc.load_state_dict(state['scaler_disc'])
        start_epoch = state['epoch']
        logger.info(f"Resuming training from epoch {start_epoch}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    if len(train_loader) == 0:
        raise ValueError("train_dataset yielded no batches")

    os.makedirs(output_dir, exist_ok=True)
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        torch.cuda.empty_cache()
        generator.train()
        discriminator.train()

        for real_audio, features_ids in train_loader:
            real_audio = real_audio.to(device)      # [B, 1, T]
            features_ids = features_ids.to(device)  # [B, feature_length]
            # The embedding is not optimised. Looking it up without grad keeps
            # the backward passes below from reaching this graph, which is shared
            # by all critic steps; before, that needed retain_graph=True.
            with torch.no_grad():
                features_emb = phoneme_embedding(features_ids).transpose(1, 2)  # [B, emb_dim, feature_length]

            # ---- Critic updates ----
            for _ in range(5):
                optimizer_disc.zero_grad()
                noise = generate_noise(real_audio.size(0), z_dim, device)

                with autocast(device_type=device.type, enabled=use_amp):
                    # The critic step does not update the generator; skip its graph.
                    with torch.no_grad():
                        fake_audio = generator(features_emb, noise)
                    real_outputs, fake_outputs = discriminator(real_audio, fake_audio, features_emb)

                    # WGAN critic loss summed over every window critic.
                    loss_disc = sum(torch.mean(f) - torch.mean(r) for r, f in zip(real_outputs, fake_outputs))

                    gradient_penalty = compute_gradient_penalty(discriminator, real_audio, fake_audio, features_emb, device)
                    loss_disc = loss_disc + 10 * gradient_penalty

                try:
                    scaler_disc.scale(loss_disc).backward()
                except RuntimeError:
                    logger.exception("Backward failed")
                    raise

                scaler_disc.step(optimizer_disc)
                scaler_disc.update()

            # ---- Generator update ----
            optimizer_gen.zero_grad()
            noise = generate_noise(real_audio.size(0), z_dim, device)
            with autocast(device_type=device.type, enabled=use_amp):
                fake_audio = generator(features_emb, noise)
                fake_outputs = discriminator(fake_audio, fake_audio, features_emb)[1]
                # Maximise the same critic the discriminator is trained on (sum
                # of per-window means). Previously this was the mean of a
                # zero-padded stack, which re-weights critics and dilutes the loss.
                loss_gen = -sum(torch.mean(f) for f in fake_outputs)

            scaler_gen.scale(loss_gen).backward()
            scaler_gen.step(optimizer_gen)
            scaler_gen.update()

        logger.info(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        # Save samples conditioned on the last batch of the epoch.
        visualize_and_save_generated_waveforms([generator], z_dim, features_emb, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir)

        state = {
            'epoch': epoch+1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'scaler_gen': scaler_gen.state_dict(),
            'scaler_disc': scaler_disc.state_dict(),
            'z_dim': z_dim,
            'lr_gen': lr_gen,
            'lr_disc': lr_disc,
            'batch_size': batch_size,
            'seed': seed,
            'audio_length': audio_length,
        }
        torch.save(state, checkpoint_path)
        logger.info(f"Checkpoint saved at {checkpoint_path}")

        torch.save(generator.state_dict(), os.path.join(output_dir, f"pretrained_generator_epoch{epoch+1}.pth"))

    logger.info("Pretraining complete.")
    return generator, phoneme_embedding

def train_gan_with_pretrained_generators(
    pretrained_generator, phoneme_embedding,
    num_epochs, z_dim, lr_gen, lr_disc, batch_size, train_dataset,
    num_generators, seed, audio_length, output_dir, checkpoint_path="multi_gan_checkpoint.pth",
    resume=False
):
    """Fine-tune several copies of a pretrained generator against one shared critic.

    Every generator starts from ``pretrained_generator``'s weights. Each batch
    runs ``num_critic`` WGAN-GP critic updates, cycling through the generators
    to supply the fake samples. It then updates every generator with the
    adversarial loss plus an orthogonality penalty, which pushes the encoder
    features of its samples away from those of the other generators' samples.
    The encoder is trained jointly on that penalty.

    Args:
        pretrained_generator: Generator whose state dict initialises every copy.
        phoneme_embedding: Frozen ``nn.Embedding`` used to embed phoneme ids.
        num_epochs: Total number of epochs, including any already done when resuming.
        z_dim: Noise size; also passed as the generators' ``z_channels``.
        lr_gen: Adam learning rate for the generators.
        lr_disc: Adam learning rate for the critic and the encoder.
        batch_size: DataLoader batch size.
        train_dataset: Dataset yielding ``(audio, phoneme_ids)`` pairs.
        num_generators: Number of generator copies.
        seed: Seed passed to ``set_seed``.
        audio_length: Passed to ``Encoder``.
        output_dir: Created if missing (not otherwise written here).
        checkpoint_path: Where the resumable checkpoint is written and read.
        resume: Resume from ``checkpoint_path`` if it exists.

    Returns:
        The list of trained generators.

    Raises:
        ValueError: If ``train_dataset`` yields no batches.
    """
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision only on CUDA; on CPU everything stays float32.
    use_amp = device.type == "cuda"
    phoneme_embedding = phoneme_embedding.to(device)

    # Initialize multiple generators with pretrained weights
    generators = []
    for _ in range(num_generators):
        gen = Generator(
            vocab_size=phoneme_embedding.num_embeddings,
            embedding_dim=phoneme_embedding.embedding_dim,
            z_channels=z_dim,
        ).to(device)
        gen.load_state_dict(pretrained_generator.state_dict())
        generators.append(gen)

    discriminator = Multiple_Random_Window_Discriminators(lc_channels=phoneme_embedding.embedding_dim).to(device)
    encoder = Encoder(audio_length).to(device)

    optimizer_gens = [optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.5, 0.9)) for gen in generators]
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))
    optimizer_encoder = optim.Adam(encoder.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    # The encoder's gradients come from the same scaled backward pass as
    # generator ``idx``, so that generator's scaler must unscale and step them.
    # A separate scaler that never called ``scale()`` cannot do this.
    scaler_gens = [GradScaler("cuda", enabled=use_amp) for _ in range(num_generators)]
    scaler_disc = GradScaler("cuda", enabled=use_amp)

    start_epoch = 0
    if resume and os.path.exists(checkpoint_path):
        state = torch.load(checkpoint_path, map_location=device)
        start_epoch = state['epoch']
        for idx, gen in enumerate(generators):
            gen.load_state_dict(state['generator_state_dicts'][idx])
        discriminator.load_state_dict(state['discriminator_state_dict'])
        encoder.load_state_dict(state['encoder_state_dict'])
        for idx, opt_gen in enumerate(optimizer_gens):
            opt_gen.load_state_dict(state['optimizer_gen_state_dicts'][idx])
        optimizer_disc.load_state_dict(state['optimizer_disc_state_dict'])
        optimizer_encoder.load_state_dict(state['optimizer_encoder_state_dict'])
        # Scaler states are optional, so older checkpoints still load. Empty
        # dicts come from disabled (CPU) scalers and cannot be loaded.
        for scaler, scaler_state in zip(scaler_gens, state.get('scaler_gens', [])):
            if scaler_state:
                scaler.load_state_dict(scaler_state)
        if state.get('scaler_disc'):
            scaler_disc.load_state_dict(state['scaler_disc'])
        logger.info(f"Resumed training from epoch {start_epoch}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    if len(train_loader) == 0:
        raise ValueError("train_dataset yielded no batches")

    lambda_gp = 10
    lambda_ortho = 0.1
    num_critic = 5

    os.makedirs(output_dir, exist_ok=True)
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)

    critic_turn = 0  # round-robin counter selecting which generator feeds the critic
    for epoch in range(start_epoch, num_epochs):
        torch.cuda.empty_cache()
        for real_audio, features_ids in train_loader:
            real_audio = real_audio.to(device)
            features_ids = features_ids.to(device)
            # The embedding is not optimised. Without no_grad, the first critic
            # backward freed this shared graph and the second critic step failed.
            with torch.no_grad():
                features_emb = phoneme_embedding(features_ids).transpose(1, 2)  # [B, emb_dim, length]

            # ---- Critic updates ----
            for _ in range(num_critic):
                optimizer_disc.zero_grad()
                # Cycle through the generators so the critic sees every one of
                # them, not only generator 0.
                critic_gen = generators[critic_turn % num_generators]
                critic_turn += 1
                noise = generate_noise(real_audio.size(0), z_dim, device)

                with autocast(device_type=device.type, enabled=use_amp):
                    with torch.no_grad():
                        fake = critic_gen(features_emb, noise)
                    real_outputs, fake_outputs = discriminator(real_audio, fake, features_emb)
                    loss_disc = sum(torch.mean(f) - torch.mean(r) for r, f in zip(real_outputs, fake_outputs))
                    gradient_penalty = compute_gradient_penalty(discriminator, real_audio, fake, features_emb, device)
                    loss_disc = loss_disc + lambda_gp * gradient_penalty

                scaler_disc.scale(loss_disc).backward()
                scaler_disc.step(optimizer_disc)
                scaler_disc.update()

            # ---- Generator (and encoder) updates ----
            for idx, gen in enumerate(generators):
                optimizer_gens[idx].zero_grad()
                optimizer_encoder.zero_grad()
                others = [g for j, g in enumerate(generators) if j != idx]

                noise = generate_noise(real_audio.size(0), z_dim, device)
                with autocast(device_type=device.type, enabled=use_amp):
                    fake = gen(features_emb, noise)
                    fake_outputs = discriminator(fake, fake, features_emb)[1]
                    # Same critic the discriminator is trained on: sum of per-window means.
                    loss_gen = -sum(torch.mean(f) for f in fake_outputs)
                    total_loss_gen = loss_gen

                    if others:
                        gen_feature = encoder(fake)
                        ortho_terms = []
                        for other_gen in others:
                            other_noise = generate_noise(real_audio.size(0), z_dim, device)
                            # Only generator ``idx`` is updated here, so the other
                            # generators' samples need no graph of their own.
                            with torch.no_grad():
                                other_fake = other_gen(features_emb, other_noise)
                            ortho_terms.append(orthogonal_loss(gen_feature, encoder(other_fake)))
                        ortho_loss_val = sum(ortho_terms) / len(ortho_terms)
                        total_loss_gen = loss_gen + lambda_ortho * ortho_loss_val

                scaler = scaler_gens[idx]
                scaler.scale(total_loss_gen).backward()
                scaler.step(optimizer_gens[idx])
                if others:
                    # The encoder only has gradients when the orthogonal loss was used.
                    scaler.step(optimizer_encoder)
                scaler.update()

        logger.info(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        state = {
            'epoch': epoch + 1,
            'generator_state_dicts': [gen.state_dict() for gen in generators],
            'discriminator_state_dict': discriminator.state_dict(),
            'encoder_state_dict': encoder.state_dict(),
            'optimizer_gen_state_dicts': [opt.state_dict() for opt in optimizer_gens],
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
            'scaler_gens': [s.state_dict() for s in scaler_gens],
            'scaler_disc': scaler_disc.state_dict(),
        }
        torch.save(state, checkpoint_path)
        logger.info(f"Checkpoint saved at {checkpoint_path}")

    logger.info("Training complete.")
    return generators


###################################
# Main Execution
###################################

In [None]:
import logging

# Configure logging: records at INFO and above are appended to training_500.log.
# NOTE: basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    filename="training_500.log",  # Specify a file to write logs to
    filemode="a"  # Append mode
)

# Mirror log records to the console as well. The handler is named so that
# re-running this cell does not attach a second copy, which would print every
# message twice.
_CONSOLE_HANDLER_NAME = "training_console"
_root_logger = logging.getLogger('')
if not any(h.get_name() == _CONSOLE_HANDLER_NAME for h in _root_logger.handlers):
    console_handler = logging.StreamHandler()
    console_handler.set_name(_CONSOLE_HANDLER_NAME)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    _root_logger.addHandler(console_handler)

# Module-level logger used by the helpers and training functions.
logger = logging.getLogger(__name__)
logger.info("Logging system initialized")

# ---- Experiment configuration ----
set_seed(42)
audio_length = 64000      # samples per clip (4 s at sample_rate)
z_dim = 128
lr_gen = 0.0002
lr_disc = 0.0002
batch_size = 8
num_epochs = 50
root_dir = "./data"
sample_rate = 16000
num_generators = 5
output_dir = 'generated_audio'

# Audio is loaded lazily per item; only file lists and phoneme data are held in memory.
train_dataset = MyAudioDataset(root_dir, target_sr=sample_rate, target_length=audio_length)
vocab_size = len(train_dataset.phoneme_to_id)

2024-12-12 21:27:26,658 - INFO - Logging system initialized
2024-12-12 21:27:26,825 - INFO - Found 148688 .flac files in ./data.


In [None]:
if __name__ == "__main__":
    # Stage 1: pretrain a single generator against the window critics.
    clear_all_memory()
    pretrained_generator, phoneme_embedding = pretrain_single_generator(
        num_epochs=20,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        seed=42,
        audio_length=audio_length,
        output_dir='output/waveform_pre_linguistics_o1',
        train_dataset=train_dataset,
        checkpoint_path="output/gan_single_check.pth",
        embedding_dim=64,
        vocab_size=vocab_size
    )

    # Stage 2: fine-tune several copies of the pretrained generator. This stays
    # under the same guard because it needs the stage-1 results; a separate
    # cell whose first line is indented cannot run on its own.
    multi_gan_checkpoint = 'output/my_checkpoints/multi_gan_checkpoint.pth'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(multi_gan_checkpoint), exist_ok=True)
    train_gan_with_pretrained_generators(
        pretrained_generator, phoneme_embedding,
        num_epochs=num_epochs,
        z_dim=z_dim,
        lr_gen=lr_gen,
        lr_disc=lr_disc,
        batch_size=batch_size,
        train_dataset=train_dataset,
        num_generators=num_generators,
        seed=42,
        audio_length=audio_length,
        output_dir=output_dir,
        checkpoint_path=multi_gan_checkpoint,
        resume=True
    )