<a href="https://colab.research.google.com/github/tsokawing/spec2spec/blob/main/training/spec2spec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This notebook is an attempt to train a spec2spec model, which transforms audio from one style to another using their spectrogram images. Some of the code here is based on the TensorFlow pix2pix tutorial notebook.

# Setup

Import the required libraries:

In [None]:
import os
import time
import datetime
import librosa
import librosa.display
import IPython

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


from pathlib import Path
from google.colab import drive

Mount Google Drive:

In [None]:
# We will load training data from Google Drive.
# INPUT_PATH and GROUNDTRUTH_PATH below point inside this mount point.
drive.mount('/content/drive')

Utility functions:

In [None]:
def plot_spec(audio, samplerate, title='Mel-frequency Spectrogram', save=False,
              fmax=None):
  """Plot a mel power spectrogram in dB, optionally saving it as a PNG.

  Args:
    audio: mel power spectrogram (n_mels x frames). Despite the name this is
      a spectrogram, not a waveform.
    samplerate: sampling rate the spectrogram was computed at.
    title: plot title; also used as the file name when saving.
    save: if True, write the figure to /content/output/{title}.png, creating
      the directory first if necessary.
    fmax: highest frequency of the mel filterbank used to build `audio`.
      Defaults to samplerate / 2, which is what librosa.feature.melspectrogram
      uses when no fmax is given (as in load_audio), so the mel axis ticks
      line up with the actual bins. Pass 8000 only for spectrograms that were
      built with fmax=8000.
  """
  if fmax is None:
    fmax = samplerate / 2.0

  fig, ax = plt.subplots()
  audio_db = librosa.power_to_db(audio, ref=np.max)
  img = librosa.display.specshow(audio_db, sr=samplerate, fmax=fmax, ax=ax,
                                 x_axis='time', y_axis='mel')

  fig.colorbar(img, ax=ax, format='%+2.0f dB')
  ax.set(title=title)

  if save:
    os.makedirs('/content/output', exist_ok=True)
    fig.savefig(f'/content/output/{title}.png')

  # Display immediately: without this, figures created inside a long-running
  # cell (e.g. the previews in fit()) only render when the cell finishes.
  # Closing afterwards keeps repeated calls from accumulating open figures.
  plt.show()
  plt.close(fig)


def recover_audio(melspec, samplerate=None):
  """Turn a mel-spectrogram back to audio using the Griffin-Lim algorithm.

  Args:
    melspec: mel power spectrogram (n_mels x frames), e.g. the output of
      tensor_to_spec.
    samplerate: sampling rate the spectrogram was computed at. Defaults to
      SAMPLING_RATE. It is looked up at call time, because SAMPLING_RATE is
      defined in a later cell than this function.

  Returns:
    The reconstructed waveform as a 1-D numpy array at `samplerate`.
  """
  if samplerate is None:
    samplerate = SAMPLING_RATE
  # Undo the mel filterbank to get an approximate linear-frequency STFT
  # magnitude, then estimate the missing phase iteratively.
  stft = librosa.feature.inverse.mel_to_stft(melspec, sr=samplerate)
  return librosa.griffinlim(stft)

# Dataset Preparation

In this section, we prepare a TensorFlow Dataset object that can be fed to the model for training. To do so, we load two parallel lists of audio clips: one contains the inputs, and the other contains the corresponding groundtruths to be generated. The loaded WAV audio is then converted to mel-spectrograms and rescaled with (x - 127.5) / 127.5, which maps the range [0, 255] onto [-1, 1].

In [None]:
# Paths to the training data. Judging by the directory names, the inputs are
# piano renditions and the groundtruths violin renditions. prepare_dataset
# requires both directories to hold the same number of WAV files and pairs
# them by position.
INPUT_PATH = '/content/drive/MyDrive/cuhk/fyp/jsb/piano_wav'
GROUNDTRUTH_PATH = '/content/drive/MyDrive/cuhk/fyp/jsb/violin_wav'

In [None]:
# Audio I/O settings, values obtained from manual experiment.
SAMPLING_RATE = 22050
# Clip length in seconds. 5.94 s at 22050 Hz is ~130977 samples. With
# librosa's default hop length (512) and centred frames, that yields 256
# spectrogram frames, i.e. the 256-wide image the models expect.
DURATION = 5.94

In [None]:
# Our training set consists of 305 images. This is used as the shuffle buffer
# size; a buffer at least as large as the dataset gives a uniform shuffle.
BUFFER_SIZE = 305
# Value of 1 gives better results for U-Net in original pix2pix experiment.
BATCH_SIZE = 1
# Each spectrogram image is 256x256x1 in size: 256 mel bins (height, first
# axis of the melspectrogram) by 256 time frames (width) by 1 channel.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# The generator outputs a single-channel spectrogram, like its input.
OUTPUT_CHANNELS = 1

In [None]:
def load_all(dir, loader=None):
  """Load all WAV files inside a directory, in sorted file-name order.

  The sort matters: prepare_dataset pairs inputs with groundtruths purely by
  list position, and Path.glob yields files in arbitrary (filesystem) order.
  Without sorting, the two directories could come back in different orders
  and produce misaligned training pairs.

  Args:
    dir: directory scanned (non-recursively) for `*.wav` files.
    loader: callable applied to each file path. Defaults to load_audio,
      which turns a file into a normalised training spectrogram image.

  Returns:
    A list with one `loader(path)` result per WAV file, in sorted path order.
  """
  if loader is None:
    loader = load_audio
  return [loader(path) for path in sorted(Path(dir).glob('*.wav'))]


def load_audio(path, duration=DURATION, samplerate=SAMPLING_RATE):
  """Load a WAV file into a normalised training spectrogram image.

  The waveform is fixed to exactly `duration` seconds. librosa truncates
  longer files; shorter files are zero-padded here. This gives every file a
  spectrogram with the same number of frames, which
  tf.data.Dataset.from_tensor_slices needs in order to stack them.

  Args:
    path: path of the WAV file.
    duration: seconds of audio to keep; None loads the whole file unchanged.
    samplerate: target sampling rate passed to librosa.load.

  Returns:
    A (n_mels, frames, 1) array, (256, 256, 1) with the notebook defaults.
  """
  audio, samplerate = librosa.load(path, duration=duration, sr=samplerate)
  if duration is not None:
    # Pad short clips with silence and trim any resampling off-by-one, so the
    # frame count is identical for every file.
    n_samples = int(round(duration * samplerate))
    audio = np.pad(audio, (0, max(0, n_samples - audio.shape[-1])))[:n_samples]
  spec = librosa.feature.melspectrogram(y=audio, sr=samplerate, n_mels=256)
  normalized = normalize(spec)
  expanded = np.expand_dims(normalized, axis=2)  # (256, 256, 1)
  return expanded


def normalize(spec):
  """Affinely rescale values so that 0 -> -1, 127.5 -> 0 and 255 -> 1.

  NOTE(review): load_audio feeds this a raw mel power spectrogram (librosa's
  default power=2.0), whose values are not confined to [0, 255]. The result
  is therefore not guaranteed to lie in [-1, 1], while the generator's tanh
  output can only cover that interval. Consider normalising in the dB domain
  instead; this needs verifying before changing the model's data scale.
  """
  half_range = 127.5
  centred = spec - half_range
  return centred / half_range


def denormalize(spec):
  """Invert normalize: map -1 -> 0, 0 -> 127.5 and 1 -> 255."""
  half_range = 127.5
  scaled = spec * half_range
  return scaled + half_range


def tensor_to_spec(tensor):
  """Convert a batched model tensor back into a plain spectrogram array.

  Takes the first item of the batch, e.g. shape (1, 256, 256, 1). It then
  drops every size-1 axis (giving e.g. (256, 256)) and undoes normalize.
  """
  first_item = np.array(tensor[0, ...])
  return denormalize(np.squeeze(first_item))


def prepare_dataset(test_ratio=10):
  """
  Prepare a pair of training and testing TensorFlow Dataset objects.

  The first `test_ratio` percent of the (input, groundtruth) pairs, in the
  order returned by load_all, form the test set. The remaining pairs are
  shuffled and used for training. Both datasets are batched by BATCH_SIZE.

  Args:
    test_ratio (int): percentage split for test set, 10 stands for 10%.

  Returns:
    Two TensorFlow Dataset objects, one for training another for testing.

  Raises:
    ValueError: if the input and groundtruth directories hold different
      numbers of WAV files, or if test_ratio is outside [0, 100].
  """
  if not 0 <= test_ratio <= 100:
    raise ValueError(f'test_ratio must be a percentage in [0, 100], got {test_ratio}')

  # Load files from input and groundtruth directories respectively.
  input_specs = load_all(INPUT_PATH)
  print(f'{len(input_specs)} input wav files loaded.')
  groundtruth_specs = load_all(GROUNDTRUTH_PATH)
  print(f'{len(groundtruth_specs)} groundtruth wav files loaded.')
  if len(input_specs) != len(groundtruth_specs):
    raise ValueError(
        f'Input/groundtruth count mismatch: {len(input_specs)} vs '
        f'{len(groundtruth_specs)}; files are paired by position.')
  dataset = tf.data.Dataset.from_tensor_slices((input_specs, groundtruth_specs))

  # Treat test_ratio as a percentage (20 -> 20%), as documented. For the
  # default of 10 this gives the same count as len // 10.
  n_test = int(len(input_specs) * test_ratio // 100)

  # Split train-test sets from original dataset.
  test_dataset = dataset.take(n_test).batch(BATCH_SIZE)
  train_dataset = dataset.skip(n_test).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_dataset, test_dataset

In [None]:
# Load every WAV pair from Drive and build the batched train/test datasets.
# All spectrograms are computed up front, so this cell can take a while.
train_dataset, test_dataset = prepare_dataset()

# Training

## Generator

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  """Build an encoder block that halves the spatial resolution.

  The block is Conv2D (stride 2, 'same' padding, no bias, N(0, 0.02) kernel
  init), then an optional BatchNormalization, then LeakyReLU. The layers are
  wrapped in a Sequential model.

  Args:
    filters: number of output channels.
    size: convolution kernel size.
    apply_batchnorm: whether to insert BatchNormalization after the conv.
  """
  init = tf.random_normal_initializer(0., 0.02)
  stack = [
    tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                           kernel_initializer=init, use_bias=False),
  ]
  if apply_batchnorm:
    stack.append(tf.keras.layers.BatchNormalization())
  stack.append(tf.keras.layers.LeakyReLU())
  return tf.keras.Sequential(stack)

In [None]:
def upsample(filters, size, apply_dropout=False):
  """Build a decoder block that doubles the spatial resolution.

  The block is Conv2DTranspose (stride 2, 'same' padding, no bias, N(0, 0.02)
  kernel init), then BatchNormalization, then an optional Dropout(0.5), then
  ReLU. The layers are wrapped in a Sequential model.

  Args:
    filters: number of output channels.
    size: transposed-convolution kernel size.
    apply_dropout: whether to insert Dropout(0.5) before the activation.
  """
  init = tf.random_normal_initializer(0., 0.02)
  stack = [
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=init,
                                    use_bias=False),
    tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    stack.append(tf.keras.layers.Dropout(0.5))
  stack.append(tf.keras.layers.ReLU())
  return tf.keras.Sequential(stack)

In [None]:
def Generator():
  """Build the U-Net generator that maps one spectrogram image to another.

  The encoder is eight stride-2 downsampling blocks, reducing the
  (IMG_HEIGHT, IMG_WIDTH, 1) input to 1x1. The decoder is seven upsampling
  blocks plus a final transposed convolution, restoring full resolution.
  After each upsampling block, the matching encoder output is concatenated
  back in (skip connection). The final tanh layer gives OUTPUT_CHANNELS
  channels with values in [-1, 1].

  IMG_HEIGHT and IMG_WIDTH must be multiples of 256 (2**8) for the skip
  connections to line up. The shape comments below assume 256x256.
  """
  input_channels = 1  # single-channel spectrogram image
  inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, input_channels])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
    downsample(128, 4),                        # (batch_size, 64, 64, 128)
    downsample(256, 4),                        # (batch_size, 32, 32, 256)
    downsample(512, 4),                        # (batch_size, 16, 16, 512)
    downsample(512, 4),                        # (batch_size, 8, 8, 512)
    downsample(512, 4),                        # (batch_size, 4, 4, 512)
    downsample(512, 4),                        # (batch_size, 2, 2, 512)
    downsample(512, 4),                        # (batch_size, 1, 1, 512)
  ]

  # Channel counts below are after concatenating the skip connection.
  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
    upsample(512, 4),                      # (batch_size, 16, 16, 1024)
    upsample(256, 4),                      # (batch_size, 32, 32, 512)
    upsample(128, 4),                      # (batch_size, 64, 64, 256)
    upsample(64, 4),                       # (batch_size, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, OUTPUT_CHANNELS)

  x = inputs

  # Downsampling through the model, remembering every encoder output.
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The innermost (1x1) output is the bottleneck rather than a skip source.
  # Pair the remaining encoder outputs with decoder stages from the inside out.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections.
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
generator = Generator()
# Render the generator architecture as a diagram (needs pydot and graphviz).
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# Test the generator: run the (still untrained) model in inference mode on
# one test example and plot the resulting spectrogram.
for example_input, example_target in test_dataset.take(1):
  pred = generator(example_input, training=False)
  plot_spec(tensor_to_spec(pred), samplerate=SAMPLING_RATE)

In [None]:
# Weight of the L1 (MAE) term relative to the adversarial term in generator_loss.
LAMBDA = 100

In [None]:
# from_logits=True because the discriminator's final Conv2D has no activation,
# so it outputs raw logits rather than probabilities.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
  """Compute the pix2pix generator objective.

  The adversarial term rewards generated images that the discriminator
  labels as real (all-ones targets). The L1 term is the mean absolute error
  against the groundtruth, weighted by LAMBDA.

  Returns:
    (total_loss, adversarial_loss, l1_loss)
  """
  adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  mae = tf.reduce_mean(tf.abs(target - gen_output))
  total = adversarial + (LAMBDA * mae)
  return total, adversarial, mae

## Discriminator

In [None]:
def Discriminator():
  """Build the PatchGAN discriminator.

  It takes an (input, target-or-generated) pair of (256, 256, 1) spectrograms
  and concatenates them along the channel axis. It returns a (30, 30, 1) map
  of raw logits, one real/fake score per image patch.
  """
  init = tf.random_normal_initializer(0., 0.02)

  source = tf.keras.layers.Input(shape=[256, 256, 1], name='input_image')
  candidate = tf.keras.layers.Input(shape=[256, 256, 1], name='target_image')
  h = tf.keras.layers.concatenate([source, candidate])      # (batch_size, 256, 256, channels*2)

  # 256 -> 128 -> 64 -> 32; the first block skips batch normalisation.
  for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
    h = downsample(n_filters, 4, use_batchnorm)(h)           # ends at (batch_size, 32, 32, 256)

  h = tf.keras.layers.ZeroPadding2D()(h)                     # (batch_size, 34, 34, 256)
  h = tf.keras.layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=init,
                             use_bias=False)(h)              # (batch_size, 31, 31, 512)
  h = tf.keras.layers.BatchNormalization()(h)
  h = tf.keras.layers.LeakyReLU()(h)
  h = tf.keras.layers.ZeroPadding2D()(h)                     # (batch_size, 33, 33, 512)
  logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=init)(h)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[source, candidate], outputs=logits)

In [None]:
discriminator = Discriminator()
# Render the discriminator architecture as a diagram (needs pydot and graphviz).
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Binary cross-entropy for telling real pairs from generated ones.

  Real-pair logits are pushed towards 1 and generated-pair logits towards 0.
  The two losses are summed.
  """
  loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  return loss_on_real + loss_on_fake

## Training Loop

In [None]:
# Adam with learning rate 2e-4 and beta_1 = 0.5 for both networks.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoints track both models and both optimizers; fit() saves one every
# 5000 steps. The relative path resolves under the notebook's working
# directory (/content on Colab). NOTE(review): that is presumably not
# persistent across Colab sessions; copy to Drive if the checkpoints matter.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
def generate_images(model, input, groundtruth, epoch):
  """Plot the input, groundtruth and predicted spectrograms for one batch.

  The model is called with training=True on purpose, following the pix2pix
  tutorial: batch normalisation then uses the current batch's statistics.
  NOTE(review): this also keeps the generator's Dropout layers active, so
  repeated calls give different predictions.

  Args:
    model: the generator.
    input: batched input tensor (only the first item is plotted).
    groundtruth: batched target tensor matching `input`.
    epoch: label used only in the prediction plot's title (fit passes the
      step number).

  Returns:
    (input_spec, truth_spec, pred_spec) as produced by tensor_to_spec.
  """
  prediction = model(input, training=True)

  titled_specs = {
    f'Input': tensor_to_spec(input),
    f'Ground Truth': tensor_to_spec(groundtruth),
    f'Epoch {epoch} Prediction': tensor_to_spec(prediction),
  }
  for title, spec in titled_specs.items():
    plot_spec(spec, title=title, samplerate=SAMPLING_RATE)

  return tuple(titled_specs.values())

In [None]:
# Preview the untrained generator on one test example (labelled epoch 0).
for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target, 0)

In [None]:
@tf.function
def train_step(input_image, target, step):
  """Run one optimisation step for both the generator and the discriminator.

  Both networks are updated from the same forward pass, each through its own
  gradient tape. The four loss values are then written to TensorBoard.

  Args:
    input_image: batch of normalised input spectrograms.
    target: batch of matching normalised groundtruth spectrograms.
    step: step counter. fit() passes the int64 tensor produced by
      Dataset.enumerate, so the tf.function is not retraced every step.

  Note: uses the module-level `summary_writer`, which is created in a later
  cell and must exist before the first call.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    # Score both the (input, real target) and (input, generated) pairs; the
    # input image conditions the discriminator in both cases.
    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  # Each tape differentiates only its own network's loss with respect to
  # that network's variables.
  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  # NOTE(review): step//1000 makes every step in a 1000-step window log to
  # the same TensorBoard step; confirm that is the intended granularity.
  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

In [None]:
def fit(train_ds, test_ds, steps):
  """Run `steps` training steps over `train_ds`, repeating it as needed.

  Every 1000 steps, it clears the cell output, reports the elapsed time and
  plots a preview of the generator on the first test example. It also
  prints a progress dot every 10 steps and saves a checkpoint every 5000
  steps.

  Args:
    train_ds: batched training dataset of (input, target) pairs.
    test_ds: batched test dataset; only its first batch is used for previews.
    steps: total number of training steps (batches) to run.
  """
  preview_input, preview_target = next(iter(test_ds.take(1)))
  start = time.time()

  # enumerate() yields `step` as a scalar int64 tensor, not a Python int.
  batches = train_ds.repeat().take(steps).enumerate()
  for step, (input_image, target) in batches:
    if step % 1000 == 0:
      IPython.display.clear_output(wait=True)
      print(f"Step: {step//1000}k")
      if step != 0:
        print(f'Time taken for last 1000 steps: {time.time()-start:.2f} sec\n')

      start = time.time()
      generate_images(generator, preview_input, preview_target, step)

    train_step(input_image, target, step)

    completed = step + 1
    if completed % 10 == 0:
      print('.', end='', flush=True)
    if completed % 5000 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# TensorBoard logs; each run gets its own timestamped subdirectory. train_step
# writes through summary_writer, so this cell must run before fit().
log_dir="logs/"
summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
# Start TensorBoard inside the notebook, reading logs from log_dir.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [None]:
# Train for 40k steps; fit() repeats the training set as often as needed.
fit(train_dataset, test_dataset, steps=40000)

# Results

In [None]:
# List all saved checkpoints (shell command via IPython's `!`).
!ls {checkpoint_dir}

In [None]:
# Restore the latest checkpoint saved in checkpoint_dir.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# To restore a specific checkpoint instead, uncomment the line below.
# checkpoint.restore('/content/training_checkpoints/ckpt-7')

In [None]:
# Run the trained model on one example from the test set (take(1)) and turn
# the predicted spectrogram back into audio. The 39000 only labels the plot.
for inp, tar in test_dataset.take(1):
  _, _, pred = generate_images(generator, inp, tar, 39000)
  recovered = recover_audio(pred)

In [None]:
# Let's hear the recovered audio of the predicted spectrogram.
# `recovered` comes from the single example processed in the cell above.
IPython.display.Audio(data=recovered, rate=SAMPLING_RATE)

# Exports

In [None]:
# Directory (relative to the working directory) for the exported SavedModel.
EXPORT_PATH = 'export'

In [None]:
# Export only the generator; the discriminator is needed only for training.
tf.saved_model.save(generator, EXPORT_PATH)