In [1]:
import numpy as np 
import pandas as pd 
import os

In [2]:
# Transcript indexes for the two speaker sets. The files are tab-separated and,
# because explicit column names are passed, their first row is read as data.
male = pd.read_csv(
    "/kaggle/input/audiogen/line_index_male.tsv",
    names=['tag', 'content'],
    delimiter='\t',
)
female = pd.read_csv(
    "/kaggle/input/audiogen/line_index_female.tsv",
    names=['tag', 'content'],
    delimiter='\t',
)

In [3]:
# Fetch the WaveGAN PyTorch implementation; later cells overwrite some of its files.
!git clone https://github.com/mostafaelaraby/wavegan-pytorch.git

Cloning into 'wavegan-pytorch'...
remote: Enumerating objects: 78, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 78 (delta 2), reused 6 (delta 2), pack-reused 68[K
Unpacking objects: 100% (78/78), 2.14 MiB | 10.03 MiB/s, done.


In [4]:

import os

# Work from inside the cloned repository so the %%writefile cells below replace
# its params.py / utils.py / train.py. The guard makes the cell safe to re-run:
# once we are already inside the checkout, a second chdir would raise
# FileNotFoundError (or descend into an accidental nested clone).
if os.path.basename(os.getcwd()) != 'wavegan-pytorch':
    os.chdir('wavegan-pytorch')

In [5]:
# Install the dependencies listed by the cloned repository.
!pip3  install -r requirements.txt

Collecting pescador>=2.0.1
  Downloading pescador-2.1.0.tar.gz (20 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: pescador
  Building wheel for pescador (setup.py) ... [?25ldone
[?25h  Created wheel for pescador: filename=pescador-2.1.0-py3-none-any.whl size=21105 sha256=d87f2b49fa6b07e21c7a61c2227312007abf3163aa9d7c5a872b152c6722f1e6
  Stored in directory: /root/.cache/pip/wheels/f0/e3/c6/32d30d5eb5292dac352d2fca4ebf393aa94e09b9b8b4b0f341
Successfully built pescador
Installing collected packages: pescador
Successfully installed pescador-2.1.0
[0m

In [6]:
%%writefile params.py
import torch
import random
import numpy as np
import logging
import os

#############################
# DataSet Path
#############################
target_signals_dir = '/kaggle/input/audiogen/ta_in_male'

#############################
# Model Params
#############################
# Name of the model to be saved.
model_prefix = 'male'
n_iterations = 100000

# Generator / discriminator learning rates. The discriminator may use a larger
# rate than the generator instead of relying on n_critic updates
# (TTUR, https://arxiv.org/abs/1706.08500).
lr_g = 1e-4
lr_d = 3e-4

# Adam betas.
beta1 = 0.5
beta2 = 0.9

use_batchnorm = False
validate = False

# Linearly decay the learning rate until it reaches 0 at the last iteration.
decay_lr = False

# In some cases the generator is updated with a multiple of the batch size used
# by the discriminator (https://arxiv.org/abs/1706.08500).
generator_batch_size_factor = 1

# The generator is updated once every n_critic discriminator steps; when
# lr_g == lr_d the usual default is 5.
n_critic = 1

# Gradient penalty regularization factor.
p_coeff = 10

batch_size = 10

# Size of the sampling noise.
noise_latent_dim = 100

# Model capacity; can be reduced to 32 for the longer windows of 2 and 4 seconds.
model_capacity_size = 32

# How often (in iterations) validation and cost values are stored.
store_cost_every = 300
progress_bar_step_iter_size = 400
#############################
# Backup Params
#############################
take_backup = True            # save (and resume from) a training checkpoint
backup_every_n_iters = 1000   # checkpoint interval, in iterations
save_samples_every = 1000     # interval for writing generated samples
output_dir = 'output'
# exist_ok makes this a no-op when the folder is already there and removes the
# window between "check" and "create" that a separate isdir() test leaves open.
os.makedirs(output_dir, exist_ok=True)

#############################
# Audio Params
#############################
# One of [16384, 32768, 65536]; for a longer window set model_capacity_size to 32.
window_length = 65536
sampling_rate = 16000
normalize_audio = True
num_channels = 1

#############################
# Logger init
#############################
LOGGER = logging.getLogger('wavegan')
LOGGER.setLevel(logging.DEBUG)

#############################
# Torch Init and seed setting
#############################
# Prefer the first GPU when CUDA is usable, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")

# Seed every random number generator used by the project with the same value.
manual_seed = 2019
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)
if cuda:
    torch.cuda.manual_seed(manual_seed)
    torch.cuda.empty_cache()

Overwriting params.py


In [7]:
%%writefile utils.py
import os
import time
import math
import torch
import random
import librosa
import librosa.display
import numpy as np
from torch.utils import data
import matplotlib
import matplotlib.pyplot as plt
import glob
import pescador
import torch.nn as nn
from torch.autograd import Variable
from params import *
import soundfile as sf
#############################
# File Utils
#############################
def get_recursive_files(folderPath, ext):
    """Collect the paths of all files under ``folderPath`` whose name ends with ``ext``.

    Sub-directories are descended into recursively. The result follows the
    order in which ``os.listdir`` reports the entries.

    Args:
        folderPath: directory to scan.
        ext: suffix a file name must end with (plain ``str.endswith`` match).

    Returns:
        list of paths, each joined onto ``folderPath``.
    """
    matches = []
    for entry in os.listdir(folderPath):
        entry_path = os.path.join(folderPath, entry)
        if os.path.isdir(entry_path):
            matches.extend(get_recursive_files(entry_path, ext))
            continue
        if entry.endswith(ext):
            matches.append(entry_path)
    return matches


def make_path(output_path):
    """Ensure the directory ``output_path`` exists and return it.

    Missing parent directories are created as well. ``exist_ok=True`` replaces
    the previous isdir-then-create sequence, which could fail if the directory
    appeared between the check and the creation.

    Args:
        output_path: directory path to create.

    Returns:
        ``output_path``, unchanged.

    Raises:
        FileExistsError: if ``output_path`` exists but is not a directory.
    """
    os.makedirs(output_path, exist_ok=True)
    return output_path


#############################
# Plotting utils
#############################
def visualize_audio(audio_tensor, is_monphonic=False):
    """Plot every item of a batch and save the grid to visualization/interpolation.png.

    The first channel (``audio[0]``) of each batch item is drawn in its own
    subplot, two subplots per row. The output file is overwritten on each call.

    Args:
        audio_tensor: torch tensor laid out as (batch, channels, window length).
        is_monphonic: when True draw the waveform, otherwise a linear-frequency
            spectrogram in dB.
    """
    input_audios = audio_tensor.detach().cpu().numpy()
    # Keep the original 10x2 grid, but grow it when the batch has more than
    # 20 items so plt.subplot does not reject the index.
    n_rows = max(10, (len(input_audios) + 1) // 2)
    fig = plt.figure(figsize=(18, 50))
    try:
        for i, audio in enumerate(input_audios):
            plt.subplot(n_rows, 2, i + 1)
            if is_monphonic:
                plt.title("Monophonic %i" % (i + 1))
                librosa.display.waveshow(audio[0], sr=sampling_rate)
            else:
                D = librosa.amplitude_to_db(np.abs(librosa.stft(audio[0])), ref=np.max)
                librosa.display.specshow(D, y_axis="linear")
                plt.colorbar(format="%+2.0f dB")
                plt.title("Linear-frequency power spectrogram %i" % (i + 1))
        os.makedirs("visualization", exist_ok=True)
        fig.savefig("visualization/interpolation.png")
    finally:
        # This function is called periodically during training; without an
        # explicit close every call would leave a large figure alive in pyplot.
        plt.close(fig)


def visualize_loss(loss_1, loss_2, first_legend, second_legend, y_label):
    """Plot two loss curves together, save them to visualization/loss.png and show them.

    Args:
        loss_1: sequence of values for the first curve.
        loss_2: sequence of values for the second curve.
        first_legend: legend label of the first curve (also used in the title).
        second_legend: legend label of the second curve (also used in the title).
        y_label: label of the y axis.
    """
    fig = plt.figure(figsize=(10, 5))
    plt.title("{} and {} Loss During Training".format(first_legend, second_legend))
    plt.plot(loss_1, label=first_legend)
    plt.plot(loss_2, label=second_legend)
    plt.xlabel("iterations")
    plt.ylabel(y_label)
    plt.grid(True)
    plt.tight_layout()
    plt.legend()
    # Save BEFORE showing: with an interactive backend the figure is discarded
    # once the show() window is closed, and a later savefig writes a blank image.
    os.makedirs("visualization", exist_ok=True)
    fig.savefig("visualization/loss.png")
    plt.show()
    plt.close(fig)


def latent_space_interpolation(model, n_samples=10):
    """Generate audio along a straight line between two random latent vectors and plot it.

    Two noise vectors are drawn with ``sample_noise``; ``n_samples`` evenly
    spaced blends of them are passed through ``model`` without gradient
    tracking, and the result is handed to ``visualize_audio`` as waveforms.

    Args:
        model: callable mapping a batch of latent vectors to audio.
        n_samples: number of interpolation steps (endpoints included).
    """
    endpoints = sample_noise(2)
    first, second = endpoints[0], endpoints[1]
    with torch.no_grad():
        blended = torch.stack(
            [
                alpha * first + ((1 - alpha) * second)
                for alpha in np.linspace(0, 1, n_samples)
            ]
        )
        generated_audio = model(blended)
    visualize_audio(generated_audio, True)


#############################
# Wav files utils
#############################
# Fast loading used with wav files only of 8 bits
def load_wav(wav_file_path):
    """Load an audio file as a float32 array of at least ``window_length`` samples.

    The file is read with librosa at ``sampling_rate``. When ``normalize_audio``
    is set and the peak magnitude exceeds 1, the signal is divided by that peak.
    Signals shorter than ``window_length`` are zero-padded on both sides (the
    extra sample, if any, goes to the right).

    Args:
        wav_file_path: path of the audio file.

    Returns:
        1-D ``float32`` numpy array.

    Raises:
        Exception: any error raised while loading or normalizing is logged
            through ``LOGGER`` and re-raised.
    """
    try:
        audio_data, _ = librosa.load(wav_file_path, sr=sampling_rate)
        if normalize_audio:
            # Only rescale signals whose peak is outside [-1, 1].
            peak = np.max(np.abs(audio_data))
            if peak > 1:
                audio_data /= peak
    except Exception as e:
        LOGGER.error("Could not load {}: {}".format(wav_file_path, str(e)))
        raise

    missing = window_length - len(audio_data)
    if missing > 0:
        pad_left = missing // 2
        pad_right = missing - pad_left
        audio_data = np.pad(audio_data, (pad_left, pad_right), mode="constant")

    return audio_data.astype("float32")


def sample_audio(audio_data, start_idx=None, end_idx=None):
    """Cut a ``window_length`` slice out of ``audio_data``.

    When the signal is exactly ``window_length`` long it is returned whole and
    the indices are passed back unchanged. Otherwise the slice
    ``[start_idx:end_idx]`` is taken; if either index is missing, a start
    position is drawn uniformly from every valid position
    ``0 .. len(audio_data) - window_length`` (inclusive).

    The previous upper bound, ``(len - window_length) // 2``, raised
    ``ValueError`` when the signal was a single sample longer than the window
    and never sampled windows from the second half of the valid range.

    Args:
        audio_data: 1-D numpy array of samples.
        start_idx: optional first index of the window.
        end_idx: optional end index (exclusive) of the window.

    Returns:
        tuple ``(sample, start_idx, end_idx)`` with ``sample`` as float32.

    Raises:
        ValueError: if a random window is requested from a signal shorter than
            ``window_length``.
        AssertionError: if the selected window contains NaN.
    """
    audio_len = len(audio_data)
    if audio_len == window_length:
        # If we only have a single 1*window_length audio, just yield.
        sample = audio_data
    else:
        # Sample a random window from the audio
        if start_idx is None or end_idx is None:
            if audio_len < window_length:
                raise ValueError(
                    "audio of length {} is shorter than window_length {}".format(
                        audio_len, window_length
                    )
                )
            # randint's upper bound is exclusive, hence the +1.
            start_idx = np.random.randint(0, audio_len - window_length + 1)
            end_idx = start_idx + window_length
        sample = audio_data[start_idx:end_idx]
    sample = sample.astype("float32")
    assert not np.any(np.isnan(sample))
    return sample, start_idx, end_idx


def sample_buffer(buffer_data, start_idx=None, end_idx=None):
    """Cut a ``window_length``-sample slice out of a raw buffer of 4-byte samples.

    Indices are expressed in samples; the buffer itself is sliced at
    ``index * 4``. When the buffer holds exactly ``window_length`` samples it is
    returned whole and the indices are passed back unchanged. If either index is
    missing, the start is drawn uniformly from every valid position
    ``0 .. n_samples - window_length`` (inclusive).

    The previous upper bound, ``(n_samples - window_length) // 2``, raised
    ``ValueError`` when the buffer was a single sample longer than the window
    and never sampled windows from the second half of the valid range.

    Args:
        buffer_data: sliceable buffer whose length is counted in bytes.
        start_idx: optional first sample index of the window.
        end_idx: optional end sample index (exclusive) of the window.

    Returns:
        tuple ``(sample, start_idx, end_idx)``.

    Raises:
        ValueError: if a random window is requested from a buffer holding fewer
            than ``window_length`` samples.
    """
    audio_len = len(buffer_data) // 4
    if audio_len == window_length:
        # If we only have a single 1*window_length audio, just yield.
        sample = buffer_data
    else:
        # Sample a random window from the audio
        if start_idx is None or end_idx is None:
            if audio_len < window_length:
                raise ValueError(
                    "buffer of {} samples is shorter than window_length {}".format(
                        audio_len, window_length
                    )
                )
            # randint's upper bound is exclusive, hence the +1.
            start_idx = np.random.randint(0, audio_len - window_length + 1)
            end_idx = start_idx + window_length
        sample = buffer_data[start_idx * 4 : end_idx * 4]
    return sample, start_idx, end_idx


def wav_generator(file_path):
    """Endlessly yield random windows of one audio file.

    The file is loaded once with ``load_wav``; every iteration draws a new
    window with ``sample_audio`` and yields it as ``{"single": window}``.

    Args:
        file_path: path of the audio file to stream from.
    """
    waveform = load_wav(file_path)
    while True:
        window = sample_audio(waveform)[0]
        yield {"single": window}


def create_stream_reader(single_signal_file_list):
    """Build a batched pescador stream over the given audio files.

    Each file gets its own ``pescador.Streamer`` around ``wav_generator``; the
    streamers are combined with ``pescador.ShuffledMux`` and buffered into
    batches of ``batch_size``.

    Args:
        single_signal_file_list: iterable of audio file paths.

    Returns:
        the object returned by ``pescador.buffer_stream``.
    """
    streamers = [
        pescador.Streamer(wav_generator, audio_path)
        for audio_path in single_signal_file_list
    ]
    shuffled = pescador.ShuffledMux(streamers)
    return pescador.buffer_stream(shuffled, batch_size)


def save_samples(epoch_samples, epoch):
    """
    Write each sample's first channel to ``<output_dir>/<epoch>/<n>.wav``.

    Files are numbered from 1 in batch order and written with ``soundfile`` at
    ``sampling_rate``; the target folder is created when missing.
    """
    sample_dir = make_path(os.path.join(output_dir, str(epoch)))

    for number, sample in enumerate(epoch_samples, 1):
        output_path = os.path.join(sample_dir, "{}.wav".format(number))
        first_channel = sample[0]
        sf.write(output_path, first_channel, sampling_rate)

#############################
# Sampling from model
#############################
def sample_noise(size):
    """Draw ``size`` latent vectors of length ``noise_latent_dim`` from N(0, 1).

    Returns:
        float tensor of shape ``(size, noise_latent_dim)`` on ``device``.
    """
    latent = torch.FloatTensor(size, noise_latent_dim).to(device)
    # Fill in place once the tensor lives on the target device.
    return latent.normal_()


#############################
# Model Utils
#############################


def update_optimizer_lr(optimizer, lr, decay):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr * decay``."""
    for group in optimizer.param_groups:
        group["lr"] = lr * decay


def gradients_status(model, flag):
    """Switch gradient tracking on (``flag=True``) or off for all parameters of ``model``."""
    for param in model.parameters():
        param.requires_grad = flag


def weights_init(m):
    """Initialise one module; meant to be passed to ``nn.Module.apply``.

    * ``nn.Conv1d``: weights drawn from N(0, 0.02), bias (if any) zeroed.
    * ``nn.Linear``: bias (if any) zeroed, weights left untouched.
    * any other module: ignored.

    Layers built with ``bias=False`` have ``bias is None``; both branches check
    for that, where the previous code dereferenced the bias unconditionally and
    crashed with ``AttributeError``.

    NOTE(review): ``nn.ConvTranspose1d`` is not a subclass of ``nn.Conv1d``, so
    transposed convolutions (if the models use them) keep PyTorch's default
    initialisation - confirm this is intended.
    """
    if isinstance(m, nn.Conv1d):
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        if m.bias is not None:
            m.bias.data.fill_(0)


#############################
# Creating Data Loader and Sampler
#############################
class WavDataLoader:
    """Iterator that yields batches of audio windows as tensors on ``device``.

    All files below ``folder_path`` ending in ``audio_extension`` are collected
    and streamed through ``create_stream_reader``. ``len()`` reports the number
    of files found, not the number of batches.
    """

    def __init__(self, folder_path, audio_extension="wav"):
        self.signal_paths = get_recursive_files(folder_path, audio_extension)
        self.data_iter = None
        self.initialize_iterator()

    def initialize_iterator(self):
        """(Re)create the underlying batch iterator from the collected file paths."""
        self.data_iter = iter(create_stream_reader(self.signal_paths))

    def __len__(self):
        return len(self.signal_paths)

    def numpy_to_tensor(self, numpy_array):
        """Insert a channel axis at position 1 and move the batch to ``device``."""
        with_channel = numpy_array[:, None, :]
        return torch.Tensor(with_channel).to(device)

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.data_iter)
        return self.numpy_to_tensor(batch["single"])


if __name__ == "__main__":
    # For debugging purposes: time how long it takes to pull a few batches from
    # the training folder configured in params.py.
    train_loader = WavDataLoader(target_signals_dir, "wav")
    start = time.time()
    for _ in range(7):
        next(train_loader)
    print(time.time() - start)

Overwriting utils.py


In [8]:
%%writefile train.py
from params import *
from utils import *
from models import *
import torch.optim as optim
import torch
from torch.autograd import grad, Variable
from tqdm import tqdm


class WaveGan_GP(object):
    """WaveGAN trainer using a Wasserstein critic loss with gradient penalty.

    Builds a ``WaveGANGenerator`` / ``WaveGANDiscriminator`` pair and one Adam
    optimizer for each from the module-level settings in ``params``. ``train()``
    runs the optimisation loop; the cost histories are kept on the instance
    (``g_cost``, ``train_d_cost``, ``train_w_distance``, ``valid_g_cost``).
    """

    def __init__(self, train_loader, val_loader):
        """
        Args:
            train_loader: iterator whose ``next()`` returns a batch of real signals.
            val_loader: iterator of validation batches; only read when the
                ``validate`` setting is true.
        """
        super(WaveGan_GP, self).__init__()
        # Histories; one value is appended every `store_cost_every` iterations.
        self.g_cost = []
        self.train_d_cost = []
        self.train_w_distance = []
        # Starts with a placeholder so the progress bar always has a value to show.
        self.valid_g_cost = [-1]
        self.valid_reconstruction = []

        self.discriminator = WaveGANDiscriminator(
            slice_len=window_length,
            model_size=model_capacity_size,
            use_batch_norm=use_batchnorm,
            num_channels=num_channels,
        ).to(device)
        self.discriminator.apply(weights_init)

        self.generator = WaveGANGenerator(
            slice_len=window_length,
            model_size=model_capacity_size,
            use_batch_norm=use_batchnorm,
            num_channels=num_channels,
        ).to(device)
        self.generator.apply(weights_init)

        self.optimizer_g = optim.Adam(
            self.generator.parameters(), lr=lr_g, betas=(beta1, beta2)
        )  # Setup Adam optimizers for both G and D
        self.optimizer_d = optim.Adam(
            self.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2)
        )

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.validate = validate
        self.n_samples_per_batch = len(train_loader)

    def calculate_discriminator_loss(self, real, generated):
        """Compute the critic loss for one discriminator update.

        Args:
            real: batch of real signals, indexed as (batch, channels, length).
            generated: batch of generator outputs holding at least as many
                items as ``real``.

        Returns:
            tuple ``(cost, cost_wd)`` where ``cost_wd`` is
            ``mean(D(generated)) - mean(D(real))`` and ``cost`` adds the
            gradient penalty scaled by ``p_coeff``.
        """
        disc_out_gen = self.discriminator(generated)
        disc_out_real = self.discriminator(real)

        # One mixing coefficient per real sample, broadcast over channels and
        # time. The size comes from the batch itself so the penalty does not
        # depend on the loader yielding exactly `batch_size` items.
        n_real = real.size(0)
        alpha = torch.FloatTensor(n_real, 1, 1).uniform_(0, 1).to(device)
        alpha = alpha.expand(n_real, real.size(1), real.size(2))

        interpolated = (1 - alpha) * real.data + (alpha) * generated.data[:n_real]
        interpolated = Variable(interpolated, requires_grad=True)

        # calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)
        grad_inputs = interpolated
        ones = torch.ones(prob_interpolated.size()).to(device)
        # create_graph=True keeps the gradient differentiable so the penalty
        # itself can be back-propagated into the discriminator weights.
        gradients = grad(
            outputs=prob_interpolated,
            inputs=grad_inputs,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        # calculate gradient penalty: push each sample's gradient norm towards 1
        grad_penalty = (
            p_coeff
            * ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        )
        assert not (torch.isnan(grad_penalty))
        assert not (torch.isnan(disc_out_gen.mean()))
        assert not (torch.isnan(disc_out_real.mean()))
        cost_wd = disc_out_gen.mean() - disc_out_real.mean()
        cost = cost_wd + grad_penalty
        return cost, cost_wd

    def apply_zero_grad(self):
        """Clear the gradients of both networks and both optimizers."""
        self.generator.zero_grad()
        self.optimizer_g.zero_grad()

        self.discriminator.zero_grad()
        self.optimizer_d.zero_grad()

    def enable_disc_disable_gen(self):
        """Track gradients for the discriminator only."""
        gradients_status(self.discriminator, True)
        gradients_status(self.generator, False)

    def enable_gen_disable_disc(self):
        """Track gradients for the generator only."""
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, True)

    def disable_all(self):
        """Stop tracking gradients for both networks."""
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, False)

    def _save_checkpoint(self, path, iteration):
        """Write both networks, both optimizers and the cost histories to ``path``.

        ``iteration`` is stored under ``"n_iterations"`` and is the index from
        which ``train()`` resumes.
        """
        saving_dict = {
            "generator": self.generator.state_dict(),
            "discriminator": self.discriminator.state_dict(),
            "n_iterations": iteration,
            "optimizer_d": self.optimizer_d.state_dict(),
            "optimizer_g": self.optimizer_g.state_dict(),
            "train_d_cost": self.train_d_cost,
            "train_w_distance": self.train_w_distance,
            "valid_g_cost": self.valid_g_cost,
            "g_cost": self.g_cost,
        }
        torch.save(saving_dict, path)

    def train(self):
        """Run the training loop up to ``n_iterations``.

        If backups are enabled and ``gan_<model_prefix>.tar`` exists, training
        resumes from it. Every iteration performs ``n_critic`` discriminator
        updates followed by one generator update. Costs, generated samples and
        checkpoints are stored at the intervals configured in ``params``; a
        final checkpoint is written when the loop finishes. The generator is
        left in eval mode.
        """
        progress_bar = tqdm(total=n_iterations // progress_bar_step_iter_size)
        fixed_noise = sample_noise(batch_size).to(
            device
        )  # used to save samples every few epochs

        gan_model_name = "gan_{}.tar".format(model_prefix)

        first_iter = 0
        if take_backup and os.path.isfile(gan_model_name):
            if cuda:
                checkpoint = torch.load(gan_model_name)
            else:
                checkpoint = torch.load(gan_model_name, map_location="cpu")
            self.generator.load_state_dict(checkpoint["generator"])
            self.discriminator.load_state_dict(checkpoint["discriminator"])
            self.optimizer_d.load_state_dict(checkpoint["optimizer_d"])
            self.optimizer_g.load_state_dict(checkpoint["optimizer_g"])
            self.train_d_cost = checkpoint["train_d_cost"]
            self.train_w_distance = checkpoint["train_w_distance"]
            self.valid_g_cost = checkpoint["valid_g_cost"]
            self.g_cost = checkpoint["g_cost"]

            first_iter = checkpoint["n_iterations"]
            # Fast-forward the progress bar to where the previous run stopped.
            for i in range(0, first_iter, progress_bar_step_iter_size):
                progress_bar.update()
            self.generator.eval()
            with torch.no_grad():
                fake = self.generator(fixed_noise).detach().cpu().numpy()
            save_samples(fake, first_iter)
        self.generator.train()
        self.discriminator.train()
        for iter_indx in range(first_iter, n_iterations):
            #############################
            # (1) Update D network n_critic times
            #############################
            self.enable_disc_disable_gen()
            for _ in range(n_critic):
                real_signal = next(self.train_loader)

                # need to add mixed signal and flag
                noise = sample_noise(batch_size * generator_batch_size_factor)
                generated = self.generator(noise)
                #############################
                # Calculating discriminator loss and updating discriminator
                #############################
                self.apply_zero_grad()
                disc_cost, disc_wd = self.calculate_discriminator_loss(
                    real_signal.data, generated.data
                )
                assert not (torch.isnan(disc_cost))
                disc_cost.backward()
                self.optimizer_d.step()

            if self.validate and iter_indx % store_cost_every == 0:
                self.disable_all()
                val_data = next(self.val_loader)
                val_real = val_data
                with torch.no_grad():
                    val_discriminator_output = self.discriminator(val_real)
                    val_generator_cost = val_discriminator_output.mean()
                    self.valid_g_cost.append(val_generator_cost.item())

            #############################
            # (2) Update G network once the critic steps are done
            #############################
            self.apply_zero_grad()
            self.enable_gen_disable_disc()
            noise = sample_noise(batch_size * generator_batch_size_factor)
            generated = self.generator(noise)
            discriminator_output_fake = self.discriminator(generated)
            generator_cost = -discriminator_output_fake.mean()
            generator_cost.backward()
            self.optimizer_g.step()
            self.disable_all()

            if iter_indx % store_cost_every == 0:
                # Both values are stored with their sign flipped.
                self.g_cost.append(generator_cost.item() * -1)
                self.train_d_cost.append(disc_cost.item())
                self.train_w_distance.append(disc_wd.item() * -1)

                progress_updates = {
                    "Loss_D WD": str(self.train_w_distance[-1]),
                    "Loss_G": str(self.g_cost[-1]),
                    "Val_G": str(self.valid_g_cost[-1]),
                }
                progress_bar.set_postfix(progress_updates)

            if iter_indx % progress_bar_step_iter_size == 0:
                progress_bar.update()
            # lr decay
            if decay_lr:
                decay = max(0.0, 1.0 - (iter_indx * 1.0 / n_iterations))
                # update the learning rate
                update_optimizer_lr(self.optimizer_d, lr_d, decay)
                update_optimizer_lr(self.optimizer_g, lr_g, decay)

            if iter_indx % save_samples_every == 0:
                with torch.no_grad():
                    latent_space_interpolation(self.generator, n_samples=2)
                    fake = self.generator(fixed_noise).detach().cpu().numpy()
                save_samples(fake, iter_indx)

            if take_backup and iter_indx % backup_every_n_iters == 0:
                self._save_checkpoint(gan_model_name, iter_indx)

        # The periodic backup only fires on multiples of `backup_every_n_iters`,
        # so without this the iterations after the last multiple would never
        # reach disk. Skipped when the loop had nothing left to do.
        if take_backup and first_iter < n_iterations:
            self._save_checkpoint(gan_model_name, n_iterations)

        progress_bar.close()
        self.generator.eval()


if __name__ == "__main__":
    # Train on the folder configured in params.py rather than a second
    # hard-coded copy of the same path; the other speaker set is only read when
    # the `validate` setting is enabled.
    train_loader = WavDataLoader(target_signals_dir)
    val_loader = WavDataLoader("/kaggle/input/audiogen/ta_in_female")

    wave_gan = WaveGan_GP(train_loader, val_loader)
    wave_gan.train()
    # Plot the stored generator cost against the validation cost, then render
    # a latent-space walk with the trained generator.
    visualize_loss(
        wave_gan.g_cost, wave_gan.valid_g_cost, "Train", "Val", "Negative Critic Loss"
    )
    latent_space_interpolation(wave_gan.generator, n_samples=5)

Overwriting train.py


In [None]:

# Launch training as a separate process, using the train.py written above.
!python3  train.py