In [None]:
%load_ext nb_black

## Import Dependencies

In [None]:
# misc
import math
import time
from pathlib import Path

# scientific
import numpy as np
import beatbrain
from beatbrain import utils

# visualization
from IPython import display
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Tensorflow
import tensorflow as tf

from tensorflow.keras import Model, Input
from tensorflow.keras import backend
from tensorflow.keras import optimizers

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Lambda, Reshape
from tensorflow.keras.callbacks import (
    Callback,
    TensorBoard,
    ReduceLROnPlateau,
    EarlyStopping,
    ModelCheckpoint,
    TerminateOnNaN,
    CSVLogger,
    LambdaCallback,
)

In [None]:
# Apply seaborn's default theme, then switch the axes style to "white".
sns.set()
sns.set_style("white")
# Show matplotlib figures inline in the notebook output.
%matplotlib inline

## Load Datasets

In [None]:
# Input
# Root folder holding the "train", "val" and "test" sub-folders loaded below.
DATA_ROOT = Path("../data/fma/image")
# Passed to build_cvae as `input_shape`.
IMAGE_DIMS = [512, 640, 1]
# Batch size for the train/val datasets and for build_cvae.
BATCH_SIZE = 4

In [None]:
# Load the three splits from the sub-folders of DATA_ROOT.
train_dataset = utils.load_dataset(
    DATA_ROOT / "train", batch_size=BATCH_SIZE, parallel=False
)
val_dataset = utils.load_dataset(
    DATA_ROOT / "val", batch_size=BATCH_SIZE, parallel=False,
)
# The test split is loaded one example per batch, with shuffle_buffer=0.
test_dataset = utils.load_dataset(
    DATA_ROOT / "test", batch_size=1, parallel=False, shuffle_buffer=0,
)

## Define Model Architecture

In [None]:
def build_cvae(latent_dim, input_shape, num_conv=3, batch_size=1, learning_rate=1e-3):
    """Build and compile a convolutional VAE together with its encoder and decoder.

    The encoder applies one 32-filter and ``num_conv`` 64-filter stride-2
    convolutions, flattens, and produces ``z_mean`` / ``z_log_var`` through
    dense layers. The decoder mirrors it with ``num_conv + 1`` stride-2
    transposed convolutions followed by a single-filter output layer. The
    training loss (per-pixel MSE summed over the image, plus a single-sample
    estimate of ``log q(z|x) - log p(z)``) is attached with ``add_loss``, so
    the model is compiled without a ``loss`` argument and trains on inputs only.

    Args:
        latent_dim: Size of the latent vector.
        input_shape: ``(height, width, channels)`` of one input. The decoder
            upsamples from ``height // 2 ** (num_conv + 1)`` by the same factor
            and emits one channel, so the shape assertion below only holds when
            height and width are divisible by ``2 ** (num_conv + 1)`` and
            ``channels == 1``.
        num_conv: Number of 64-filter stride-2 (transposed) convolutions.
        batch_size: Batch size fixed on the encoder's ``Input`` layer.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        Tuple ``(model, encoder, decoder)``. ``encoder`` maps an input to
        ``[z_mean, z_log_var, z]``; ``decoder`` maps a latent vector to an
        image; ``model`` chains them and carries the VAE loss.
    """
    downsample_factor = 2 ** (num_conv + 1)

    def reparam(args):
        """Reparameterization trick: z = mean + sigma * eps, eps ~ N(0, I)."""
        z_mean, z_log_var = args
        # Use the batch size of the tensor actually flowing through the layer
        # instead of the build-time `batch_size`; otherwise a batch of any
        # other size is silently broadcast to `batch_size` samples.
        batch = tf.shape(z_mean)[0]
        dim = tf.keras.backend.int_shape(z_mean)[1]
        eps = tf.keras.backend.random_normal(shape=(batch, dim))
        return eps * tf.exp(z_log_var * 0.5) + z_mean

    # ---- Encoder ----
    encoder_input = tf.keras.Input(shape=input_shape, batch_size=batch_size)
    e = tf.keras.layers.Conv2D(
        filters=32, kernel_size=3, strides=(2, 2), activation="relu"
    )(encoder_input)
    for _ in range(num_conv):
        e = tf.keras.layers.Conv2D(
            filters=64, kernel_size=3, strides=(2, 2), activation="relu"
        )(e)
    # The convolutions use the default "valid" padding, so a 512x640 input with
    # num_conv=3 ends up as (31, 39, 64) here. This is deliberately not reused
    # for the decoder, which starts from the floor-divided shape instead.
    e = tf.keras.layers.Flatten()(e)
    # NOTE(review): 16-unit linear bottleneck ahead of the latent heads; it is
    # narrower than latent_dim=256 used below - confirm this is intended.
    e = tf.keras.layers.Dense(16)(e)
    z_mean = tf.keras.layers.Dense(latent_dim)(e)
    z_log_var = tf.keras.layers.Dense(latent_dim)(e)
    z = tf.keras.layers.Lambda(reparam, output_shape=(latent_dim,))([z_mean, z_log_var])
    encoder = tf.keras.Model(encoder_input, [z_mean, z_log_var, z], name="encoder")

    # ---- Decoder ----
    # Start from input_shape / 2**(num_conv + 1) so that num_conv + 1 stride-2
    # "SAME" transposed convolutions land exactly on the input height/width.
    # For a 512x640 input with num_conv=3 this is (32, 40, 64).
    decoder_feature_shape = (
        input_shape[0] // downsample_factor,
        input_shape[1] // downsample_factor,
        64,
    )

    decoder_input = tf.keras.Input(shape=(latent_dim,))
    d = tf.keras.layers.Dense(
        decoder_feature_shape[0] * decoder_feature_shape[1] * decoder_feature_shape[2],
        activation=tf.nn.relu,
    )(decoder_input)
    d = tf.keras.layers.Reshape(target_shape=decoder_feature_shape)(d)
    for _ in range(num_conv):
        d = tf.keras.layers.Conv2DTranspose(
            filters=64,
            kernel_size=3,
            strides=(2, 2),
            padding="SAME",
            activation="relu",
        )(d)
    d = tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu",
    )(d)
    # No activation on the output layer: the raw output is what the MSE term
    # below compares against the input.
    decoder_output = tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=3, strides=(1, 1), padding="SAME",
    )(d)

    decoder = tf.keras.Model(decoder_input, decoder_output, name="decoder")
    outputs = decoder(encoder(encoder_input)[2])
    model = tf.keras.Model(encoder_input, outputs, name="vae")

    # ---- Loss ----
    assert encoder_input.shape == outputs.shape
    # mse() averages over the channel axis; summing over height and width then
    # gives one reconstruction loss value per example.
    reconstruction_loss = tf.losses.mse(encoder_input, outputs)
    reconstruction_loss = tf.reduce_sum(reconstruction_loss, axis=[1, 2])
    # Single-sample estimate of KL(q(z|x) || p(z)) using the sampled z.
    logpz = log_normal_pdf(z, 0.0, 0.0)
    logqz_x = log_normal_pdf(z, z_mean, z_log_var)
    kl_loss = logqz_x - logpz
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)

    model.add_loss(vae_loss)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return model, encoder, decoder


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of ``sample`` under a diagonal Gaussian, summed over ``raxis``.

    Args:
        sample: Points at which to evaluate the density.
        mean: Gaussian mean (broadcast against ``sample``).
        logvar: Log-variance of the Gaussian (broadcast against ``sample``).
        raxis: Axis over which the per-element log-densities are summed.

    Returns:
        Tensor of summed log-densities with ``raxis`` reduced away.
    """
    log_two_pi = tf.math.log(2.0 * np.pi)
    squared_error = (sample - mean) ** 2.0
    per_element = -0.5 * (squared_error * tf.exp(-logvar) + logvar + log_two_pi)
    return tf.reduce_sum(per_element, axis=raxis)


@tf.function
def sample(latent_dim, decoder, eps=None, num_samples=100, apply_sigmoid=True):
    """Decode latent vectors into images.

    Args:
        latent_dim: Size of the latent vector; only used when ``eps`` is None.
        decoder: Callable mapping a batch of latent vectors to decoder outputs.
        eps: Latent vectors to decode. When None, ``num_samples`` vectors are
            drawn from a standard normal distribution.
        num_samples: Number of vectors to draw when ``eps`` is None
            (defaults to 100, the previously hard-coded value).
        apply_sigmoid: Whether to pass the decoder output through a sigmoid
            (defaults to True, the previous behaviour).

    Returns:
        The decoder output for ``eps``, after the optional sigmoid.
    """
    if eps is None:
        eps = tf.random.normal(shape=(num_samples, latent_dim))
    return decode(decoder, eps, apply_sigmoid=apply_sigmoid)


def encode(encoder, x):
    """Run ``encoder`` on ``x`` and return the posterior ``(mean, logvar)``.

    Two encoder output layouts are supported:

    * a list/tuple whose first two entries are the mean and the log-variance,
      which is what the encoder returned by ``build_cvae`` produces
      (``[z_mean, z_log_var, z]``);
    * a single tensor with mean and log-variance concatenated along axis 1,
      which is split into two equal halves.

    Args:
        encoder: Callable applied to ``x``.
        x: Input batch forwarded to ``encoder`` unchanged.

    Returns:
        Tuple ``(mean, logvar)``.
    """
    inference = encoder(x)
    if isinstance(inference, (list, tuple)):
        # Multi-output encoder: splitting the list itself with tf.split would
        # stack the outputs and cut along the wrong axis.
        return inference[0], inference[1]
    mean, logvar = tf.split(inference, num_or_size_splits=2, axis=1)
    return mean, logvar


def reparameterize(mean, logvar):
    """Sample ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.

    Args:
        mean: Mean of the Gaussian.
        logvar: Log-variance of the Gaussian, broadcastable against ``mean``.

    Returns:
        A sample with the same shape as ``mean``.
    """
    # tf.shape() is evaluated at run time, so this also works when the static
    # shape of `mean` has unknown dimensions (e.g. inside a traced function),
    # where `mean.shape` cannot be used as a shape argument.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * 0.5) + mean


def decode(decoder, z, apply_sigmoid=False):
    """Apply ``decoder`` to ``z``, optionally squashing the result with a sigmoid.

    Args:
        decoder: Callable mapping latent vectors to decoder outputs.
        z: Latent vectors, forwarded to ``decoder`` unchanged.
        apply_sigmoid: When true, return ``tf.sigmoid`` of the decoder output;
            otherwise return the decoder output as-is.

    Returns:
        The decoder output, after the optional sigmoid.
    """
    raw_output = decoder(z)
    return tf.sigmoid(raw_output) if apply_sigmoid else raw_output

## Hyperparameters, Model Output, and Logging

In [None]:
# Hyperparameters
# LATENT_DIM, NUM_CONV and LEARNING_RATE are passed to build_cvae; EPOCHS is
# passed to the training call.
LATENT_DIM = 256
EPOCHS = 50
NUM_CONV = 3
LEARNING_RATE = 1e-3

# Callback options
# NOTE(review): presumably the patience values meant for EarlyStopping and
# ReduceLROnPlateau - confirm they are wired into the callbacks.
EARLY_STOP_PATIENCE = 3
REDUCE_LR_PATIENCE = 1

# Outputs
# MODEL_NAME names both the checkpoint under MODEL_DIR and the log folder under
# LOG_DIR. MODEL_DIR is created here if missing.
MODEL_NAME = "2d-CVAE"
MODEL_DIR = Path("../models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR = Path("../logs")

## Define Training Callbacks

In [None]:
class VisualizeCallback(Callback):
    """Save example reconstructions and decoder samples at the start of every epoch.

    At each ``on_epoch_begin`` the callback

    * runs the model on ``n_examples`` fixed validation examples and saves an
      "Original" / "Reconstructed" figure per example, and
    * decodes a fixed set of latent vectors with the model's ``"decoder"``
      layer and saves one figure per vector.

    Figures are written as PNG under ``log_dir/png/{reconstructed,generated}``
    and the corresponding arrays via ``utils.save_image`` as ``.exr`` files
    under ``log_dir/raw/{reconstructed,generated}``.

    Args:
        log_dir: Root folder for the output folders described above.
        latent_dim: Size of the latent vectors drawn when ``random_vectors``
            is not given.
        validation_data: Batched dataset; it is unbatched and its first
            ``n_examples`` elements are kept for the reconstructions.
        n_examples: Number of validation examples and of random latent vectors.
        random_vectors: Optional latent vectors to decode each epoch. When
            None, ``n_examples`` vectors are drawn once from a standard normal.
    """

    def __init__(
        self, log_dir, latent_dim, validation_data, n_examples=4, random_vectors=None,
    ):
        super().__init__()
        self.log_dir = Path(log_dir)
        self.latent_dim = latent_dim
        self.n_examples = n_examples

        # Explicit None check: `random_vectors or ...` would ask for the truth
        # value of a multi-element tensor/array, which raises.
        if random_vectors is None:
            random_vectors = tf.random.normal(shape=[n_examples, latent_dim])
        self.random_vectors = random_vectors
        # Materialize the examples once so every epoch shows the same inputs.
        self.data = list(validation_data.unbatch().take(self.n_examples))

        self.recon_raw = self.log_dir / "raw" / "reconstructed"
        self.recon_png = self.log_dir / "png" / "reconstructed"
        self.gen_raw = self.log_dir / "raw" / "generated"
        self.gen_png = self.log_dir / "png" / "generated"

    def on_train_begin(self, logs=None):
        """Create the four output folders if they do not exist yet."""
        for folder in (self.recon_raw, self.recon_png, self.gen_raw, self.gen_png):
            folder.mkdir(exist_ok=True, parents=True)

    def _visualize_reconstruction(self, epoch):
        """Save an original-vs-reconstruction figure for each stored example."""
        for i, example in enumerate(self.data):
            example = example[None, :]  # add the batch axis back
            fig, axes = plt.subplots(1, 2, figsize=(14, 7))
            plt.sca(axes[0])
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(example[0, ..., 0].numpy()),
                title="Original",
            )
            plt.sca(axes[1])
            reconstructed = self.model(example)
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(reconstructed[0, ..., 0].numpy()),
                title="Reconstructed",
            )
            fig.tight_layout()
            title = f"recon_{i + 1}@epoch_{epoch}"
            fig.suptitle(title)
            fig.savefig(self.recon_png / f"{title}.png")
            utils.save_image(
                reconstructed[0, ..., 0], self.recon_raw / f"{title}.exr",
            )
            # Close this figure explicitly so figures do not pile up over epochs.
            plt.close(fig)

    def _visualize_generation(self, epoch):
        """Decode the fixed latent vectors and save one figure per vector."""
        decoder = self.model.get_layer("decoder")
        generated = decoder(self.random_vectors)
        for i, gen in enumerate(generated):
            gen = gen[None, :]
            fig = plt.figure()
            title = f"gen_{i + 1}@epoch_{epoch}"
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(gen[0, ..., 0].numpy()), title=title,
            )
            fig.tight_layout()
            fig.savefig(self.gen_png / f"{title}.png")
            utils.save_image(gen[0, ..., 0], self.gen_raw / f"{title}.exr")
            plt.close(fig)

    def on_epoch_begin(self, epoch, logs=None):
        """Write reconstruction and generation figures before the epoch trains."""
        self._visualize_reconstruction(epoch)
        self._visualize_generation(epoch)

In [None]:
# All logs for this run go under LOG_DIR / MODEL_NAME. The directory is passed
# as a string, as already done for ModelCheckpoint below.
tensorboard = TensorBoard(
    log_dir=str(LOG_DIR / MODEL_NAME), update_freq=256, profile_batch=0,
)
# Patience values come from the "Callback options" constants instead of
# repeating the numbers here.
reduce_lr = ReduceLROnPlateau(
    patience=REDUCE_LR_PATIENCE, factor=0.1, min_lr=1e-5, verbose=1,
)
early_stop = EarlyStopping(patience=EARLY_STOP_PATIENCE, verbose=1,)
# Only the best model seen so far is kept at MODEL_DIR / MODEL_NAME.
model_saver = ModelCheckpoint(
    str(MODEL_DIR / MODEL_NAME), save_best_only=True, verbose=1,
)
csv_logger = CSVLogger(f"{LOG_DIR / MODEL_NAME / 'log'}.csv")
visualizer = VisualizeCallback(LOG_DIR / MODEL_NAME, LATENT_DIM, val_dataset,)

## Instantiate and Train Model

In [None]:
# Build the VAE from the constants defined above and print its layer summary.
model, encoder, decoder = build_cvae(
    LATENT_DIM,
    IMAGE_DIMS,
    num_conv=NUM_CONV,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
)
model.summary()

In [None]:
# Train model
# `fit` accepts datasets directly; `fit_generator` is deprecated.
# NOTE(review): `.take(5)` limits training and validation to five batches
# each - presumably a smoke-test setting; remove it for a full run.
model.fit(
    train_dataset.take(5),
    epochs=EPOCHS,
    # csv_logger comes after visualizer: the visualizer's on_train_begin
    # creates folders below LOG_DIR / MODEL_NAME, so the CSV's parent directory
    # exists by the time the logger opens its file.
    callbacks=[
        tensorboard,
        reduce_lr,
        early_stop,
        model_saver,
        visualizer,
        csv_logger,
    ],
    validation_data=val_dataset.take(5),
)

## Visualize Reconstruction

In [None]:
# Optional - reload model from disk
# model = tf.keras.models.load_model(f"{MODEL_DIR / MODEL_NAME}")
# Nothing in this notebook writes cvae.h5, so only load it when it is actually
# present; otherwise keep the model trained above so the notebook still runs
# top to bottom.
pretrained_path = Path("../models/cvae.h5")
if pretrained_path.exists():
    model = tf.keras.models.load_model(str(pretrained_path))
    print(f"Loaded model from {pretrained_path}")
else:
    print(f"{pretrained_path} not found - using the model trained in this notebook")

In [None]:
# Named `test_batch` rather than `sample` so this cell does not rebind the
# `sample()` helper defined earlier.
test_batch = next(iter(test_dataset.take(1)))
# Run the (stochastic) model once and plot that same reconstruction, instead
# of discarding this result and sampling a second one for the figure.
reconstructed = model.predict(test_batch)
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(test_batch[0, ..., 0].numpy()), title="Original"
)
plt.show()
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(reconstructed[0, ..., 0]), title="Reconstructed",
)
plt.show()

## Visualize Unconditioned Generation

In [None]:
# Number of latent vectors drawn below.
EXAMPLES_TO_GENERATE = 16
# Number of interpolation steps used by the interpolation cell.
INTERPOLATION_POINTS = 9
# Root folder used by generate_and_save_images for its output.
OUTPUT_DIR = Path("../data/output/images")

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
    shape=[EXAMPLES_TO_GENERATE, LATENT_DIM]
)

In [None]:
# A dataset is iterable but not an iterator, so it has to go through iter()
# before next(). The name `sample` is kept because the following cells read it;
# note that it shadows the `sample()` helper from here on.
sample = next(iter(train_dataset.take(1)))
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(sample[0, ..., 0].numpy()), title="Original"
)
plt.show()
# Distribution of the values in the batch (datasets have no `flatten`).
sns.distplot(sample.numpy().flatten())
plt.show()

In [None]:
# The model is called on the batch, not on the dataset object. The single
# result is reused for both plots, because each call samples a new latent
# vector and would give a different prediction.
prediction = model(sample)
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(prediction[0, ..., 0].numpy()), title="Predicted"
)
plt.show()
sns.distplot(prediction.numpy().flatten())
plt.show()

In [None]:
# Parameters passed to utils.spectrogram_to_audio (n_fft, hop_length, sr) and
# to display.Audio (rate) in the cells below.
N_FFT = 4096
HOP_LENGTH = 256
SAMPLE_RATE = 32768

In [None]:
# Convert the first spectrogram of the input batch to audio.
sample_recon = utils.spectrogram_to_audio(
    sample[0, ..., 0],
    denormalize=True,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    sr=SAMPLE_RATE,
)

In [None]:
# Audio player for the audio recovered from the input spectrogram.
display.Audio(sample_recon, rate=SAMPLE_RATE)

In [None]:
# Convert the first spectrogram of the model's prediction to audio.
prediction_recon = utils.spectrogram_to_audio(
    prediction[0, ..., 0],
    denormalize=True,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    sr=SAMPLE_RATE,
)

In [None]:
# Audio player for the audio recovered from the predicted spectrogram.
display.Audio(prediction_recon, rate=SAMPLE_RATE)

In [None]:
def generate_and_save_images(decoder, epoch, test_input):
    """Decode ``test_input``, show the results in a grid and save each as TIFF.

    Image ``i`` is written to ``OUTPUT_DIR/progress/<i>/spec.tiff`` as a
    32-bit float (mode ``"F"``) image, overwriting any previous file.

    Args:
        decoder: Callable mapping latent vectors to decoder outputs.
        epoch: Currently unused; the output file name does not depend on it.
        test_input: Batch of latent vectors to decode.
    """
    # The body relies on Pillow's Image API, which is not imported anywhere
    # else in the notebook, so the dependency is made explicit here.
    from PIL import Image

    num_plots = math.ceil(math.sqrt(len(test_input)))
    # Call decode() directly: by the time this function runs, the notebook has
    # rebound the name `sample` to a data batch, so the `sample()` helper is no
    # longer reachable under that name.
    # NOTE(review): the sigmoid is kept from the original call, but build_cvae
    # trains the raw decoder output with MSE - confirm it is wanted here.
    predictions = decode(decoder, test_input, apply_sigmoid=True)
    fig = plt.figure(figsize=(12, 12))
    fig.subplots_adjust(hspace=0, wspace=0)
    for i, pred in enumerate(predictions):
        plt.subplot(num_plots, num_plots, i + 1)
        plt.imshow(pred[:, :, 0], cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
        output_dir = Path(OUTPUT_DIR) / 'progress' / str(i)
        output_dir.mkdir(parents=True, exist_ok=True)
        image = Image.fromarray(pred[:, :, 0].numpy(), mode='F')
        image.save(str(output_dir / "spec.tiff"))
    plt.show()

In [None]:
import librosa
from PIL import Image  # Image.fromarray is used below but was never imported

# Reuse the audio settings defined earlier (SAMPLE_RATE, N_FFT, HOP_LENGTH)
# instead of keeping a second copy of the same numbers in this cell.
SR = SAMPLE_RATE
N_MELS = 512
DURATION = 5
A = "../data/fma/audio/000002.mp3"
B = "../data/fma/audio/000005.mp3"

# `time` is the module imported at the top of the notebook; the previous
# `from time import time` rebound that name to the function.
interp_dir = Path(f"interpolation/{int(time.time())}")
interp_dir.mkdir(exist_ok=True, parents=True)


def _load_mel_batch(path):
    """Load ``path`` and return its mel spectrogram as a (1, H, W, 1) float32 batch."""
    audio, _ = librosa.load(path, sr=SR, duration=DURATION)
    mel = librosa.feature.melspectrogram(
        y=audio, sr=SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS
    )
    # With centered frames, DURATION seconds give one frame more than
    # IMAGE_DIMS[1]; crop to the model's input height/width.
    mel = mel[: IMAGE_DIMS[0], : IMAGE_DIMS[1]].astype(np.float32)
    # Add the batch and channel axes.
    return mel[None, ..., None]


# NOTE(review): the training images appear to be normalized (they go through
# utils.denormalize_spectrogram before display) - confirm whether x and y need
# the matching normalization before being fed to the model.
x = _load_mel_batch(A)
y = _load_mel_batch(B)
assert x.shape == y.shape == (1, *IMAGE_DIMS), (x.shape, y.shape)

# The Keras model has no encode/reparameterize/sample methods; use its
# "encoder" / "decoder" sub-models and the helper functions defined earlier.
encoder_net = model.get_layer("encoder")
decoder_net = model.get_layer("decoder")
x_mean, x_logvar, _ = encoder_net(x)
y_mean, y_logvar, _ = encoder_net(y)

# Reconstruction
# The decoder output is used as-is (no sigmoid), as in VisualizeCallback:
# build_cvae trains the raw decoder output against the input with MSE.
x_recon = decode(decoder_net, reparameterize(x_mean, x_logvar))
fig = plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(x[0, ..., 0], cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.subplot(122)
plt.imshow(x_recon[0, ..., 0], cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.show()

# Interpolation
# One row per interpolation step; float32 to match the encoder outputs.
fractions = np.linspace(0, 1, num=INTERPOLATION_POINTS, dtype=np.float32)[:, None]
means = (x_mean * (1 - fractions)) + (y_mean * fractions)  # Interpolated latent means
logvars = (x_logvar * (1 - fractions)) + (y_logvar * fractions)  # Interpolated log-variances
points = reparameterize(means, logvars)
interpolated = decode(decoder_net, points)

num_plots = math.ceil(math.sqrt(INTERPOLATION_POINTS))
fig = plt.figure(figsize=(12, 12))
fig.subplots_adjust(hspace=0, wspace=0.03)
for i, pred in enumerate(interpolated):
    Image.fromarray(pred[:, :, 0].numpy(), mode='F').save(interp_dir.joinpath(f'{i}.tiff'))
    plt.subplot(num_plots, num_plots, i + 1)
    if i == 0:
        # x and y are NumPy arrays already, so there is no .numpy() to call.
        Image.fromarray(x[0, :, :, 0], mode='F').save(interp_dir.joinpath('a.tiff'))
        plt.imshow(x[0, ..., 0], cmap='gray', vmin=0, vmax=1)
    elif i == INTERPOLATION_POINTS - 1:
        Image.fromarray(y[0, :, :, 0], mode='F').save(interp_dir.joinpath('b.tiff'))
        plt.imshow(y[0, ..., 0], cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(pred[:, :, 0], cmap='gray', vmin=0, vmax=1)
    plt.axis('off')

plt.savefig(interp_dir.joinpath('interpolation.png'), bbox_inches='tight')