# Orthogonal loss
This model works well.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random

In [2]:
# Set up logging: INFO-and-above records are appended to a text file with a
# "YYYY-MM-DD HH:MM:SS message" layout.
import logging
# NOTE(review): per the standard-library documentation, basicConfig() does nothing
# when the root logger already has handlers (e.g. when this cell is re-run in a
# live kernel) unless force=True is passed — confirm the file is actually written.
logging.basicConfig(
    filename="training_log_500.txt",
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# Dataloader

In [3]:
import os
import librosa
import numpy as np
from glob import glob
from torch.utils.data import TensorDataset
import torch
import soundfile as sf
import matplotlib.pyplot as plt

import os
import librosa
import numpy as np
from glob import glob
import torch
import soundfile as sf
from torch.utils.data import Dataset, DataLoader


def load_flac(file_path, target_sr=16000):
    """Read an audio file with librosa at ``target_sr`` Hz and peak-normalise it.

    Args:
        file_path: Path of the audio file to read.
        target_sr: Sample rate passed to ``librosa.load`` as ``sr``.

    Returns:
        The loaded samples divided by their largest absolute value, so the
        peak magnitude is 1. An all-zero signal is returned unscaled.
    """
    samples, _ = librosa.load(file_path, sr=target_sr)
    peak = np.abs(samples).max()
    # Skip the division for silent clips to avoid dividing by zero.
    if peak > 0:
        samples = samples / peak
    return samples


def pad_or_trim(audio, target_length=64000):
    """Force ``audio`` to exactly ``target_length`` samples.

    Shorter inputs are right-padded with zeros (``np.pad`` default mode);
    inputs of equal or greater length are cut to the first ``target_length``
    samples.
    """
    shortfall = target_length - len(audio)
    if shortfall > 0:
        return np.pad(audio, (0, shortfall))
    return audio[:target_length]


def preprocess_and_save_dataset(root_dir, output_dir, target_sr=16000, target_length=64000):
    """Cut every FLAC file under ``root_dir`` into fixed-length WAV clips.

    Each file is loaded with ``load_flac`` and split into consecutive chunks of
    ``target_length`` samples, written as ``audio_<n>.wav`` in ``output_dir``.
    Whatever is left over from a file (including files shorter than one chunk)
    is collected in a "sink"; the sink fragments are concatenated in order and
    cut into further clips, and the final partial clip is zero-padded.

    Args:
        root_dir: Directory searched recursively for ``*.flac`` files.
        output_dir: Directory receiving the WAV clips (created if missing).
        target_sr: Sample rate used for loading and for writing.
        target_length: Number of samples in every written clip.

    Raises:
        ValueError: If ``target_length`` is not positive, or if no ``.flac``
            file is found under ``root_dir``.
    """
    if target_length <= 0:
        # A non-positive length would make the chunking loops below never terminate.
        raise ValueError("target_length must be a positive number of samples.")

    # Sorted so that the audio_<n>.wav numbering is the same on every run;
    # glob() itself returns paths in arbitrary order.
    flac_files = sorted(glob(os.path.join(root_dir, '**', '*.flac'), recursive=True))
    if not flac_files:
        raise ValueError("No .flac files found. Please check the root_dir path.")

    os.makedirs(output_dir, exist_ok=True)

    counter = 0

    def write_clip(samples):
        """Write one clip under the next free index and advance the counter."""
        nonlocal counter
        sf.write(os.path.join(output_dir, f"audio_{counter}.wav"), samples, target_sr)
        counter += 1

    sink = []
    for file in flac_files:
        try:
            audio = load_flac(file, target_sr)
            # Zero full chunks for short files, one for exact-length files, etc.
            num_full_chunks = len(audio) // target_length
            for i in range(num_full_chunks):
                write_clip(audio[i * target_length: (i + 1) * target_length])

            remainder = audio[num_full_chunks * target_length:]
            if len(remainder) > 0:
                sink.append(remainder)
        except Exception as e:
            # Best-effort: one unreadable file must not abort the whole run.
            print(f"Error processing file {file}: {e}")

    current_audio = np.array([], dtype=np.float32)
    num_additional_samples = 0

    for fragment in sink:
        current_audio = np.concatenate((current_audio, fragment))
        while len(current_audio) >= target_length:
            write_clip(current_audio[:target_length])
            num_additional_samples += 1
            current_audio = current_audio[target_length:]

    if len(current_audio) > 0:
        # Last partial clip: zero-pad up to the target length.
        write_clip(pad_or_trim(current_audio, target_length))
        num_additional_samples += 1

    print(f"Number of additional samples created from sink fragments: {num_additional_samples}")
    print(f"Preprocessed data saved to {output_dir}. Total samples: {counter}")


class AudioDataset(Dataset):
    """Map-style dataset over the ``*.wav`` files directly inside a directory.

    Files are listed once at construction time in sorted path order. Each item
    is loaded with librosa at the file's own sample rate (``sr=None``),
    optionally passed through ``transform``, and returned as a float32 tensor.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        wav_pattern = os.path.join(data_dir, '*.wav')
        self.audio_files = sorted(glob(wav_pattern))

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        waveform, _ = librosa.load(self.audio_files[idx], sr=None)
        if self.transform:
            waveform = self.transform(waveform)
        return torch.tensor(waveform, dtype=torch.float32)




def save_waveform_to_audio(waveform, sample_rate, filename):
    """Peak-normalise a waveform and write it to ``filename`` with soundfile.

    Args:
        waveform: A torch tensor or array-like of samples. Tensors are detached
            and moved to the CPU first; singleton dimensions are squeezed away.
        sample_rate: Sample rate handed to ``sf.write``.
        filename: Destination path.
    """
    samples = waveform.detach().cpu().numpy() if isinstance(waveform, torch.Tensor) else waveform
    samples = np.squeeze(samples)

    peak = np.max(np.abs(samples))
    # Silent signals are written as-is to avoid dividing by zero.
    if peak > 0:
        samples = samples / peak

    sf.write(filename, samples, sample_rate)




In [4]:
# Usage example: slice the raw FLAC files under ./data into fixed-length WAV clips.
# The call below re-reads and re-writes every file each time this cell is run.
root_dir = "./data"
processed_data = './preprocessed_data'
audio_length = 64000  # Clip length in samples (64000 samples = 4 s at the 16000 Hz rate below)

target_sr = 16000  # Sample rate used both for loading and for the written clips
preprocess_and_save_dataset(root_dir, processed_data, target_sr=target_sr, target_length=audio_length)


Number of additional samples created from sink fragments: 1422
Preprocessed data saved to ./preprocessed_data. Total samples: 4850


In [None]:
# Dataset over the WAV clips written to `processed_data` by the preprocessing cell.
train_dataset = AudioDataset(processed_data)

# Model definition

In [6]:
import math
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed applied to every random number generator below.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and switch the auto-tuner off,
    # since the auto-tuner may pick different algorithms from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# img_size = 28
# img_channels = 1

def get_dim_for_each_layer(z_dim, total_l, l, output_dim):
    """
    Linearly interpolate a layer width between ``z_dim`` and ``output_dim``.

    Parameters:
    - z_dim: int, width returned for the first layer.
    - total_l: int, total number of layers.
    - l: int, 1-indexed layer whose width is requested.
    - output_dim: int, width returned for the last layer.

    Returns:
    - int: ``output_dim`` for the last layer, ``z_dim`` for the first, and
      otherwise the evenly spaced intermediate value rounded up.

    Raises:
    - ValueError: if ``l`` is outside ``[1, total_l]``.
    """
    if not 1 <= l <= total_l:
        raise ValueError("Layer index 'l' must be in the range [1, total_l].")

    # The last layer is checked first, so a single-layer network yields output_dim.
    if l == total_l:
        return output_dim
    if l == 1:
        return z_dim

    # Equal increments between consecutive layers.
    step = (output_dim - z_dim) / (total_l - 1)
    return math.ceil(z_dim + (l - 1) * step)
    

In [7]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps a latent vector to a single-channel waveform in [-1, 1].

    A linear layer expands the latent vector to a ``(512, audio_length // 256)``
    feature map; four transposed convolutions (kernel 25, stride 4, padding 11,
    output_padding 1) then each multiply the length by 4 while reducing the
    channels 512 -> 256 -> 128 -> 64 -> 1. The final activation is ``tanh``.
    """

    def __init__(self, z_dim, audio_length):
        super().__init__()

        # Four x4 upsampling stages = x256 overall (e.g. 64000 / 256 = 250).
        self.initial_length = audio_length // 256

        layers = [
            nn.Linear(z_dim, 512 * self.initial_length),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (512, self.initial_length)),
        ]
        channels = (512, 256, 128, 64, 1)
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(
                nn.ConvTranspose1d(in_ch, out_ch, kernel_size=25, stride=4, padding=11, output_padding=1)
            )
            # Hidden stages use LeakyReLU; the single-channel output stage uses tanh.
            layers.append(nn.Tanh() if out_ch == 1 else nn.LeakyReLU(0.2))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return waveforms of shape ``(batch, 1, initial_length * 256)`` for latents ``x``."""
        return self.model(x)









class Discriminator(nn.Module):
    """Convolutional critic that scores a waveform with one unbounded scalar.

    Four stride-4 convolutions (kernel 25, padding 11) widen the channels
    1 -> 64 -> 128 -> 256 -> 512, each followed by LeakyReLU(0.2). The result
    is average-pooled over time and projected to a single value by a linear
    layer; no sigmoid is applied. ``audio_length`` is accepted but not used.
    """

    def __init__(self, audio_length):
        super().__init__()

        layers = []
        channels = (1, 64, 128, 256, 512)
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=25, stride=4, padding=11))
            layers.append(nn.LeakyReLU(0.2))

        # Collapse the time axis, then map the 512 pooled features to one score.
        layers.extend([
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x``, which must already carry a channel axis: ``(batch, 1, length)``."""
        return self.model(x)





class Encoder(nn.Module):
    """Embeds a waveform into a 64-dimensional feature vector.

    Three stride-2 convolutions (kernel 4, padding 1) widen the channels
    1 -> 32 -> 64 -> 128 with LeakyReLU(0.2) in between; the result is
    average-pooled over time and projected to 64 features. ``audio_length``
    is accepted but not used.
    """

    def __init__(self, audio_length):
        super().__init__()

        layers = []
        channels = (1, 32, 64, 128)
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))

        # Pool over time, then project the 128 pooled features to 64.
        layers.extend([
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, 64),
        ])
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``, which must already carry a channel axis: ``(batch, 1, length)``."""
        return self.encoder(x)





# Loss function
# NOTE(review): `criterion` is not referenced by any other cell visible in this
# notebook; the training functions compute Wasserstein losses directly.
# Presumably a leftover — confirm before removing.
criterion = nn.BCELoss()

In [8]:
def generate_noise(batch_size, z_dim, device):
    """Draw a ``(batch_size, z_dim)`` standard-normal latent batch and move it to ``device``."""
    latent = torch.randn(batch_size, z_dim)
    return latent.to(device)

# Orthogonal loss function
def orthogonal_loss(feature1, feature2):
    """Mean squared cosine similarity between paired rows of two feature batches.

    The value is 0 when every pair of rows is orthogonal and grows towards 1 as
    rows become parallel or anti-parallel, so minimising it pushes the two
    feature sets towards orthogonality.
    """
    dot = (feature1 * feature2).sum(dim=1)
    # The small constant keeps the division finite for zero-norm rows.
    scale = feature1.norm(dim=1) * feature2.norm(dim=1) + 1e-8
    return (dot / scale).pow(2).mean()


In [9]:
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """WGAN-GP penalty: mean of ``(||grad critic(x_hat)||_2 - 1)^2`` over the batch.

    ``x_hat`` is a per-sample random interpolation between ``real_samples`` and
    ``fake_samples``, both expected with three dimensions so the
    ``(batch, 1, 1)`` mixing weights can be expanded to their shape.

    Note: this notebook defines ``compute_gradient_penalty`` again in a later
    cell; once that cell has run, the later definition is the one in effect.
    """
    n = real_samples.size(0)

    # One mixing weight in [0, 1) per sample, broadcast over the remaining axes.
    mix = torch.rand(n, 1, 1, device=device).expand_as(real_samples)
    x_hat = (mix * real_samples + (1 - mix) * fake_samples).requires_grad_(True)
    scores = critic(x_hat)

    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return (per_sample_norm - 1).pow(2).mean()


In [10]:
import matplotlib.pyplot as plt

def visualize_and_save_generated_waveforms(generators, z_dim, num_waveforms, device, epoch, sample_rate=16000, output_dir='generated_audio'):
    """Sample audio from each generator and write it to WAV files.

    The same latent batch is fed to every generator, so files that share a
    sample index can be compared across generators. Despite the name, nothing
    is plotted: each waveform is saved as
    ``epoch<E>_gen<G>_sample<S>.wav`` (all 1-based) and the path is logged.

    Args:
        generators: Iterable of generator modules.
        z_dim: Latent dimension passed to ``generate_noise``.
        num_waveforms: Number of samples drawn per generator.
        device: Device on which the latent batch is placed.
        epoch: Zero-based epoch index; ``epoch + 1`` appears in the file name.
        sample_rate: Sample rate of the written files.
        output_dir: Destination directory (created if missing).
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    noise = generate_noise(num_waveforms, z_dim, device)
    for idx, gen in enumerate(generators):
        # Sampling only: skip autograd bookkeeping instead of building a graph
        # and discarding it with detach().
        with torch.no_grad():
            fake_waveforms = gen(noise).cpu().numpy()

        for i in range(num_waveforms):
            filename = f'epoch{epoch+1}_gen{idx+1}_sample{i+1}.wav'
            filepath = os.path.join(output_dir, filename)
            save_waveform_to_audio(fake_waveforms[i], sample_rate, filepath)
            logging.info(f"Saved {filepath}")



In [11]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

def pretrain_single_generator(num_epochs, z_dim, lr_gen, lr_disc, batch_size, seed, audio_length, output_dir, train_dataset):
    """Pretrain one generator against one critic with the WGAN-GP objective.

    After every epoch, sample waveforms and a resumable checkpoint
    (``checkpoint.pth``) are written to ``output_dir``. If that checkpoint
    already exists when the function is called, models, optimizers and loss
    histories are restored and training continues from the following epoch.
    At the end the generator weights (``pretrained_generator.pth``) and a
    learning-curve plot (``learning_curves.png``) are saved.

    Args:
        num_epochs: Total number of epochs, counting any already completed
            epochs restored from a checkpoint.
        z_dim: Latent dimension of the generator.
        lr_gen: Adam learning rate for the generator.
        lr_disc: Adam learning rate for the critic.
        batch_size: Mini-batch size used by the data loader.
        seed: Seed for the PyTorch CPU (and, if available, CUDA) RNG.
        audio_length: Waveform length in samples.
        output_dir: Directory for samples, checkpoint, weights and plot.
        train_dataset: Dataset yielding waveforms of shape ``(audio_length,)``
            or ``(1, audio_length)``, either bare or as the first element of a
            tuple/list.

    Returns:
        The trained ``Generator``.
    """
    # Set the seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Created up front so checkpoint and weight files never depend on another
    # function having made the directory first.
    os.makedirs(output_dir, exist_ok=True)

    generator = Generator(z_dim, audio_length).to(device)
    discriminator = Discriminator(audio_length).to(device)

    optimizer_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.9))
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    # Honour the caller's batch_size instead of a hard-coded value.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    lambda_gp = 10  # Gradient penalty coefficient
    num_critic = 5  # Number of critic updates per generator update

    loss_disc_history = []
    loss_gen_history = []

    # Resume training if a checkpoint exists
    checkpoint_path = os.path.join(output_dir, "checkpoint.pth")
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        logging.info("Loading checkpoint...")
        # map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
        optimizer_disc.load_state_dict(checkpoint['optimizer_disc_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        loss_disc_history = checkpoint['loss_disc_history']
        loss_gen_history = checkpoint['loss_gen_history']
        logging.info(f"Resumed from epoch {start_epoch}")

    # Start at start_epoch so a resumed run does not repeat finished epochs.
    for epoch in range(start_epoch, num_epochs):
        torch.cuda.empty_cache()  # Clear unused memory
        loss_disc_epoch = 0
        loss_gen_epoch = 0

        for batch_idx, batch in enumerate(train_loader):
            # Accept both bare tensors and (tensor,) batches, e.g. from a TensorDataset.
            real = batch[0] if isinstance(batch, (list, tuple)) else batch
            real = real.to(device)
            if real.dim() == 2:
                real = real.unsqueeze(1)  # (batch, length) -> (batch, 1, length)
            # The last batch of an epoch may be smaller than batch_size.
            current_batch_size = real.size(0)

            # Train the critic several times per generator step
            for _ in range(num_critic):
                optimizer_disc.zero_grad()

                noise = generate_noise(current_batch_size, z_dim, device)
                # The generator is not updated here, so no graph is needed for it.
                with torch.no_grad():
                    fake = generator(noise)

                disc_real = discriminator(real)
                disc_fake = discriminator(fake)

                # Wasserstein critic loss: push real scores up, fake scores down.
                loss_disc = torch.mean(disc_fake) - torch.mean(disc_real)

                gradient_penalty = compute_gradient_penalty(discriminator, real, fake, device)
                loss_disc = loss_disc + lambda_gp * gradient_penalty

                loss_disc.backward()
                optimizer_disc.step()

            # Only the last critic iteration of each batch is recorded.
            loss_disc_epoch += loss_disc.item()

            # Train the generator
            optimizer_gen.zero_grad()
            noise = generate_noise(current_batch_size, z_dim, device)
            fake = generator(noise)
            disc_fake = discriminator(fake)
            loss_gen = -torch.mean(disc_fake)
            loss_gen.backward()
            optimizer_gen.step()

            loss_gen_epoch += loss_gen.item()

        avg_loss_disc = loss_disc_epoch / len(train_loader)
        avg_loss_gen = loss_gen_epoch / len(train_loader)

        loss_disc_history.append(avg_loss_disc)
        loss_gen_history.append(avg_loss_gen)

        logging.info(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {avg_loss_disc:.4f}, Loss G: {avg_loss_gen:.4f}")

        # Write sample waveforms for this epoch
        visualize_and_save_generated_waveforms(
            [generator], z_dim, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
        )

        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'loss_disc_history': loss_disc_history,
            'loss_gen_history': loss_gen_history,
        }
        torch.save(checkpoint, checkpoint_path)

    # Save the generator model
    torch.save(generator.state_dict(), os.path.join(output_dir, "pretrained_generator.pth"))
    logging.info(f"Pretrained generator model saved to {os.path.join(output_dir, 'pretrained_generator.pth')}")

    # Plot the learning curves
    plt.figure(figsize=(10, 5))
    plt.plot(loss_disc_history, label="Discriminator Loss")
    plt.plot(loss_gen_history, label="Generator Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Learning Curves")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "learning_curves.png"))
    plt.show()

    return generator


In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import librosa
import librosa.display
import soundfile as sf
import os
from torch.nn.utils import spectral_norm
import torch.autograd as autograd
# Wasserstein loss
def wasserstein_loss(y_pred, y_true):
    """Mean of the element-wise product of critic scores and their +/-1 labels."""
    signed_scores = y_pred * y_true
    return signed_scores.mean()

def compute_gradient_penalty(discriminator, real_data, fake_data, device):
    """Gradient penalty of WGAN-GP, averaged over the batch.

    For random per-sample interpolations between ``real_data`` and
    ``fake_data`` (both expected as three-dimensional tensors), the L2 norm of
    the critic's gradient with respect to the interpolated input is pushed
    towards 1 via the squared deviation.
    """
    batch_size = real_data.size(0)

    # Per-sample mixing weight of shape (batch, 1, 1), expanded to the data shape.
    weight = torch.rand(batch_size, 1, 1, device=device).expand_as(real_data)
    mixed = (weight * real_data + (1 - weight) * fake_data).requires_grad_(True)

    critic_out = discriminator(mixed)

    # Gradient of the critic output w.r.t. its input; create_graph=True so the
    # penalty itself can be back-propagated into the critic parameters.
    (input_grad,) = torch.autograd.grad(
        outputs=critic_out,
        inputs=mixed,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    norms = input_grad.view(batch_size, -1).norm(2, dim=1)
    deviation = norms - 1
    return (deviation ** 2).mean()


In [13]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import os

def train_gan_with_pretrained_generators(
    pretrained_generator, num_epochs, z_dim, lr_gen, lr_disc, batch_size, train_dataset,
    num_generators, seed, audio_length, output_dir, lambda_gp=10, lambda_ortho=0.1, num_critic=5,
    checkpoint_dir='checkpoints', resume=True
):
    """Train several generators against one critic with WGAN-GP and an orthogonality term.

    Every generator starts from the weights of ``pretrained_generator``. The
    critic sees the real batch and one fake batch per generator. When
    ``lambda_ortho > 0`` and there is more than one generator, an encoder
    embeds the generated audio and each generator is additionally penalised by
    the mean squared cosine similarity between its features and those of every
    other generator. A checkpoint is written after each epoch and when the run
    is interrupted with Ctrl-C / kernel interrupt.

    Args:
        pretrained_generator: Generator whose weights initialise every copy.
        num_epochs: Total number of epochs, counting epochs already completed
            in a resumed checkpoint.
        z_dim: Latent dimension.
        lr_gen: Adam learning rate for each generator.
        lr_disc: Adam learning rate for the critic and the encoder.
        batch_size: Mini-batch size used by the data loader.
        train_dataset: Dataset yielding waveforms of shape ``(audio_length,)``
            or ``(1, audio_length)``, either bare or as the first element of a
            tuple/list (e.g. a ``TensorDataset``).
        num_generators: Number of generator copies (at least 1).
        seed: Seed for the PyTorch CPU (and, if available, CUDA) RNG.
        audio_length: Waveform length in samples.
        output_dir: Directory receiving the per-epoch audio samples.
        lambda_gp: Gradient-penalty weight.
        lambda_ortho: Weight of the orthogonality term; 0 disables it.
        num_critic: Critic updates per generator update.
        checkpoint_dir: Directory holding ``checkpoint.pth``.
        resume: Whether to continue from an existing checkpoint.

    Returns:
        The list of trained generators.

    Raises:
        ValueError: If ``num_generators`` is smaller than 1.
    """
    if num_generators < 1:
        raise ValueError("num_generators must be at least 1.")

    # Set the seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize multiple generators with the pretrained generator's weights
    generators = []
    for _ in range(num_generators):
        gen = Generator(z_dim, audio_length).to(device)
        gen.load_state_dict(pretrained_generator.state_dict())
        generators.append(gen)

    discriminator = Discriminator(audio_length).to(device)

    optimizer_gens = [optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.5, 0.9)) for gen in generators]
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.9))

    # The orthogonality term compares a generator with the *other* generators,
    # so it is undefined (division by zero) when there is only one.
    use_ortho = lambda_ortho > 0 and num_generators > 1
    if use_ortho:
        encoder = Encoder(audio_length).to(device)
        optimizer_encoder = optim.Adam(encoder.parameters(), lr=lr_disc, betas=(0.5, 0.9))
    else:
        encoder = None
        optimizer_encoder = None

    # Honour the caller's batch_size instead of a hard-coded value.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    start_epoch = 0
    loss_disc_history = []
    loss_gens_history = [[] for _ in range(num_generators)]

    if resume and os.path.exists(checkpoint_path):
        logging.info("Resuming from checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Earlier per-epoch checkpoints were written with the misspelled key
        # 'epo  ch'; accept it so those files stay loadable.
        saved_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else checkpoint['epo  ch']
        start_epoch = saved_epoch + 1  # Next epoch to start from

        # Load models
        for idx, gen in enumerate(generators):
            gen.load_state_dict(checkpoint['generator_state_dicts'][idx])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        if use_ortho and 'encoder_state_dict' in checkpoint:
            encoder.load_state_dict(checkpoint['encoder_state_dict'])

        # Load optimizers
        for idx, optimizer_gen in enumerate(optimizer_gens):
            optimizer_gen.load_state_dict(checkpoint['optimizer_gens_state_dicts'][idx])
        optimizer_disc.load_state_dict(checkpoint['optimizer_disc_state_dict'])
        if use_ortho and 'optimizer_encoder_state_dict' in checkpoint:
            optimizer_encoder.load_state_dict(checkpoint['optimizer_encoder_state_dict'])

        # Load loss histories
        loss_disc_history = checkpoint['loss_disc_history']
        loss_gens_history = checkpoint['loss_gens_history']
    else:
        logging.info("Starting training from scratch.")

    def save_checkpoint(completed_epoch):
        """Write the full training state, recording ``completed_epoch`` as the last finished epoch."""
        state = {
            'epoch': completed_epoch,
            'generator_state_dicts': [gen.state_dict() for gen in generators],
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_gens_state_dicts': [opt.state_dict() for opt in optimizer_gens],
            'optimizer_disc_state_dict': optimizer_disc.state_dict(),
            'loss_disc_history': loss_disc_history,
            'loss_gens_history': loss_gens_history
        }
        if use_ortho:
            state['encoder_state_dict'] = encoder.state_dict()
            state['optimizer_encoder_state_dict'] = optimizer_encoder.state_dict()

        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state, checkpoint_path)

    # Tracked separately from `epoch` so an interrupt never marks a partially
    # trained epoch as finished (its losses are not in the histories yet).
    last_completed_epoch = start_epoch - 1
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, num_epochs):
            torch.cuda.empty_cache()
            loss_disc_epoch = 0
            loss_gens_epoch = [0] * num_generators

            for batch_idx, batch in enumerate(train_loader):
                # Accept both bare tensors (AudioDataset) and (tensor,) batches (TensorDataset).
                real = batch[0] if isinstance(batch, (list, tuple)) else batch
                real = real.to(device)
                if real.dim() == 2:
                    real = real.unsqueeze(1)  # (batch, length) -> (batch, 1, length)
                # The last batch of an epoch may be smaller than batch_size.
                current_batch_size = real.size(0)
                real_label = -torch.ones(current_batch_size, 1, device=device)
                fake_label = torch.ones(current_batch_size, 1, device=device)

                # Small instance noise on the real samples
                real = real + 0.001 * torch.randn_like(real)

                # Train the critic several times per generator step
                for _ in range(num_critic):
                    optimizer_disc.zero_grad()

                    disc_real = discriminator(real)

                    noises = [generate_noise(current_batch_size, z_dim, device) for _ in range(num_generators)]
                    # Generators are frozen during critic updates, so no graph is needed.
                    with torch.no_grad():
                        fakes = [gen(noises[idx]) for idx, gen in enumerate(generators)]

                    # Matching instance noise on the fake samples
                    for idx in range(num_generators):
                        fakes[idx] = fakes[idx] + 0.001 * torch.randn_like(fakes[idx])

                    disc_fakes = [discriminator(fake) for fake in fakes]

                    # Average the fake losses over the generators
                    loss_disc_fake = sum(wasserstein_loss(disc_fake, fake_label) for disc_fake in disc_fakes) / num_generators
                    loss_disc_real = wasserstein_loss(disc_real, real_label)
                    loss_disc = loss_disc_real + loss_disc_fake

                    # Average gradient penalty over the generators
                    gradient_penalty = sum(
                        compute_gradient_penalty(discriminator, real, fake, device) for fake in fakes
                    ) / num_generators
                    loss_disc = loss_disc + lambda_gp * gradient_penalty

                    loss_disc.backward()
                    optimizer_disc.step()

                # Only the last critic iteration of each batch is recorded.
                loss_disc_epoch += loss_disc.item()

                # Train the generators one at a time
                for idx, gen in enumerate(generators):
                    optimizer_gens[idx].zero_grad()
                    if use_ortho:
                        optimizer_encoder.zero_grad()

                    noise = generate_noise(current_batch_size, z_dim, device)
                    fake = gen(noise)
                    disc_fake = discriminator(fake)

                    loss_gen = wasserstein_loss(disc_fake, real_label)

                    if use_ortho:
                        gen_feature = encoder(fake)
                        ortho_loss_total = 0
                        for other_idx, other_gen in enumerate(generators):
                            if other_idx == idx:
                                continue
                            other_noise = generate_noise(current_batch_size, z_dim, device)
                            # Only generator `idx` is stepped here, so the other
                            # generators' samples need no autograd history; the
                            # encoder still receives gradients through its own pass.
                            with torch.no_grad():
                                other_fake = other_gen(other_noise)
                            other_feature = encoder(other_fake)
                            ortho_loss_total = ortho_loss_total + orthogonal_loss(gen_feature, other_feature)

                        ortho_loss_total = ortho_loss_total / (num_generators - 1)
                        total_loss_gen = loss_gen + lambda_ortho * ortho_loss_total
                    else:
                        total_loss_gen = loss_gen

                    total_loss_gen.backward()
                    optimizer_gens[idx].step()

                    if use_ortho:
                        optimizer_encoder.step()

                    loss_gens_epoch[idx] += total_loss_gen.item()

            avg_loss_disc = loss_disc_epoch / len(train_loader)
            avg_loss_gens = [loss / len(train_loader) for loss in loss_gens_epoch]

            logging.info(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {avg_loss_disc:.4f}")
            for idx in range(num_generators):
                logging.info(f"Loss G{idx+1}: {avg_loss_gens[idx]:.4f}")
            logging.info('-' * 50)

            # Record the losses
            loss_disc_history.append(avg_loss_disc)
            for idx in range(num_generators):
                loss_gens_history[idx].append(avg_loss_gens[idx])

            # Write sample waveforms for this epoch
            visualize_and_save_generated_waveforms(
                generators, z_dim, num_waveforms=5, device=device, epoch=epoch, sample_rate=16000, output_dir=output_dir
            )

            # Save checkpoint after each epoch
            save_checkpoint(epoch)
            last_completed_epoch = epoch
            logging.info(f"Checkpoint saved at epoch {epoch}")

    except KeyboardInterrupt:
        logging.info(f"Training interrupted at epoch {epoch}. Saving checkpoint...")
        # Keep the weights trained so far, but record the last *finished* epoch
        # so a resumed run repeats the interrupted one.
        save_checkpoint(last_completed_epoch)
        logging.info(f"Checkpoint saved at epoch {last_completed_epoch}")
        logging.info("Exiting training early.")
        return generators

    return generators


In [14]:

def initialize_multiple_generators(pretrained_generator, num_generators, z_dim, audio_length=None):
    """Create independent copies of a pretrained generator.

    Args:
        pretrained_generator: Generator whose weights are copied.
        num_generators: Number of copies to create.
        z_dim: Latent dimension used to build each copy.
        audio_length: Waveform length used to build each copy. When omitted it
            is derived from ``pretrained_generator.initial_length``.

    Returns:
        A list of ``num_generators`` new ``Generator`` instances on the same
        device as ``pretrained_generator``, each loaded with its weights.
    """
    if audio_length is None:
        # Generator keeps initial_length = audio_length // 256 and its layer
        # shapes depend only on that value, so this rebuilds the same architecture.
        audio_length = pretrained_generator.initial_length * 256

    # Generator stores its layers in `.model`; ask the parameters for the device
    # rather than reaching into a specific layer.
    device = next(pretrained_generator.parameters()).device
    source_state = pretrained_generator.state_dict()

    generators = []
    for _ in range(num_generators):
        new_generator = Generator(z_dim, audio_length).to(device)
        new_generator.load_state_dict(source_state)  # Copies the values into the new module
        generators.append(new_generator)
    return generators

In [None]:
# Pretrain a single generator/critic pair for 20 epochs. Audio samples, the
# resumable checkpoint, the final weights and the learning-curve plot are all
# written to 'waveform_pre_500'.
pretrained_generator = pretrain_single_generator(num_epochs=20, 
                                                 z_dim=100, lr_disc=0.0002,lr_gen=0.0002, batch_size=64, seed=42, audio_length=64000, 
                                                 output_dir='waveform_pre_500', train_dataset=train_dataset)


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


KeyboardInterrupt: 

In [None]:
def load_pretrained_generator(z_dim, weight_path, device, audio_length):
    """
    Build a generator and restore previously saved weights into it.

    Args:
        z_dim (int): Latent space dimension.
        weight_path (str): Path to the saved ``state_dict`` file.
        device (torch.device): Device the model and weights are placed on.
        audio_length (int): Waveform length the generator is built for.

    Returns:
        Generator: The restored generator, switched to evaluation mode.
    """
    model = Generator(z_dim, audio_length=audio_length)
    model = model.to(device)

    # map_location remaps the stored tensors onto the requested device.
    weights = torch.load(weight_path, map_location=device)
    model.load_state_dict(weights)
    model.eval()

    logging.info(f"Pretrained generator model loaded from {weight_path}")
    return model


In [None]:
import torch

# Release cached GPU memory that PyTorch is holding but not currently using,
# before starting the main training run.
torch.cuda.empty_cache()


# Actual training code (resumable)

In [None]:
# Train five generators initialised from the pretrained one.
# `train_dataset` is the AudioDataset built from the preprocessed clips earlier
# in this notebook; the previous argument `dataset` is not defined anywhere in
# the notebook and raised NameError on a fresh kernel.
train_gan_with_pretrained_generators(
    pretrained_generator,
    num_epochs=200,
    z_dim=100,
    lr_gen=0.0002,
    lr_disc=0.0002,
    batch_size=24,
    train_dataset=train_dataset,
    num_generators=5,
    seed=42,
    audio_length=64000,
    output_dir='waveform_mal_2',
    checkpoint_dir='my_checkpoints',
    resume=True
)
