## Import Dependencies

In [None]:
# misc
import math
import time
from pathlib import Path

# scientific
import numpy as np
import beatbrain
from beatbrain import utils

# visualization
from IPython import display
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Tensorflow
import tensorflow as tf

from tensorflow.keras import Model, Sequential, Input, optimizers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import plot_model

from tensorflow.keras.layers import (
    Conv2D,
    Conv2DTranspose,
    MaxPool2D,
    Dense,
    Lambda,
    Reshape,
    Flatten,
    Layer,
    concatenate,
)
from tensorflow.keras.callbacks import (
    Callback,
    TensorBoard,
    ReduceLROnPlateau,
    EarlyStopping,
    ModelCheckpoint,
    TerminateOnNaN,
    CSVLogger,
    LambdaCallback,
)

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. Every visible GPU is configured (TensorFlow expects the setting to be
# consistent across devices), and iterating instead of indexing device 0 keeps
# this cell from raising IndexError on CPU-only machines.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Apply seaborn's default theme, then switch the axes style to "white".
sns.set()
sns.set_style("white")
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## Load Datasets

In [None]:
# Input
# Root directory holding the "train", "val" and "test" sub-directories that
# are loaded in the next cell.
DATA_ROOT = Path("../data/fma/image")
# Per-example input shape handed to the model builder
# (presumably [freq, time, channels] — TODO confirm against the dataset).
IMAGE_DIMS = [512, 640, 1]
# Batch size used for the train/val datasets and for the model's fixed-batch inputs.
BATCH_SIZE = 2

In [None]:
# Training and validation splits are batched with the notebook-wide batch size.
train_dataset = utils.load_dataset(
    DATA_ROOT / "train",
    batch_size=BATCH_SIZE,
    parallel=False,
)
val_dataset = utils.load_dataset(
    DATA_ROOT / "val",
    batch_size=BATCH_SIZE,
    parallel=False,
)
# The test split is served one example per batch with a zero-sized shuffle
# buffer (presumably this disables shuffling — verify in `utils.load_dataset`).
test_dataset = utils.load_dataset(
    DATA_ROOT / "test",
    batch_size=1,
    parallel=False,
    shuffle_buffer=0,
)

## Define Model Architecture

In [None]:
def build_cvae(
    latent_dim,
    input_shape,
    start_filters=32,
    num_conv=3,
    num_inception=3,
    num_deconv=3,
    batch_size=1,
    learning_rate=1e-4,
):
    """Build and compile a convolutional variational autoencoder.

    Encoder: `num_conv` 3x3 stride-1 ReLU convolutions (filter count doubling
    from `start_filters`), a max-pool, `num_inception` stages of
    `Inception2D(32)` + max-pool, then dense heads producing `z_mean`,
    `z_log_var` and a reparameterized sample `z`.

    Decoder: a dense layer reshaped to the input shape downscaled by
    `2 ** num_deconv`, `num_deconv` stride-2 transposed convolutions, and a
    final single-filter sigmoid transposed convolution.

    The training loss (summed per-pixel MSE plus a single-sample estimate of
    the KL divergence, `log q(z|x) - log p(z)`) is attached with `add_loss`,
    so the model is fit on inputs only.

    Args:
        latent_dim: Size of the latent vector.
        input_shape: Per-example input shape (height, width, channels).
        start_filters: Filters of the first encoder convolution.
        num_conv: Number of plain encoder convolutions (may be 0).
        num_inception: Number of Inception2D + max-pool stages.
        num_deconv: Number of stride-2 transposed convolutions in the decoder.
        batch_size: Fixed batch size declared on the encoder/decoder inputs.
        learning_rate: Learning rate of the Adam optimizer.

    Returns:
        Tuple `(model, encoder, decoder)`; `encoder` outputs
        `[z_mean, z_log_var, z]`.

    Raises:
        ValueError: If the spatial input dimensions are not divisible by
            `2 ** num_deconv`, in which case the decoder cannot reproduce
            the input shape.
    """
    # Each decoder stage exactly doubles height and width, so the decoder can
    # only land back on the input size when both divide evenly.
    upsampling = 2 ** num_deconv
    if input_shape[0] % upsampling or input_shape[1] % upsampling:
        raise ValueError(
            f"Spatial dims {tuple(input_shape[:2])} must be divisible by "
            f"2 ** num_deconv = {upsampling}"
        )

    def reparam(args):
        z_mean, z_log_var = args
        # Draw the noise with the runtime shape of `z_mean` so the sample
        # always matches the actual batch instead of the declared batch size.
        eps = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return eps * tf.exp(z_log_var * 0.5) + z_mean

    encoder_input = Input(shape=input_shape, batch_size=batch_size, name="encoder_input")
    e = encoder_input  # keeps the encoder well-defined even when num_conv == 0
    for i in range(num_conv):
        e = Conv2D(
            filters=start_filters * (2 ** i),
            kernel_size=3,
            strides=1,
            padding="SAME",
            activation="relu",
        )(e)
    e = MaxPool2D()(e)
    for i in range(num_inception):
        e = Inception2D(32)(e)
        e = MaxPool2D()(e)
    e = Flatten()(e)
    e = Dense(latent_dim * 2)(e)
    z_mean = Dense(latent_dim, name="z_mean")(e)
    z_log_var = Dense(latent_dim, name="z_log_var")(e)
    z = Lambda(reparam, output_shape=(latent_dim,), name="z")(
        [z_mean, z_log_var]
    )

    decoder_input_shape = [
        batch_size,
        input_shape[0] // 2 ** num_deconv,
        input_shape[1] // 2 ** num_deconv,
        input_shape[2],
    ]  # shape: [?, freq, time, channels]
    decoder_input = Input(shape=(latent_dim,), batch_size=batch_size, name="decoder_input")
    d = Dense(
        decoder_input_shape[1] * decoder_input_shape[2] * decoder_input_shape[3],
        activation="relu",
    )(decoder_input)
    d = Reshape(
        target_shape=(
            decoder_input_shape[1],
            decoder_input_shape[2],
            decoder_input_shape[3],
        )
    )(d)
    for i in range(num_deconv):
        d = Conv2DTranspose(
            filters=input_shape[0] // 2 ** i,
            kernel_size=3,
            strides=2,
            padding="SAME",
            activation="relu",
        )(d)
        print(d.shape)  # progress output: shape after each upsampling stage
    decoder_output = Conv2DTranspose(
        filters=1,
        kernel_size=3,
        strides=1,
        padding="SAME",
        activation="sigmoid",  # Changed this from RELU
        name="decoder_output",
    )(d)

    encoder = Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
    decoder = Model(decoder_input, decoder_output, name="decoder")
    # Run the encoder once and reuse its outputs for both the decoder and the
    # KL term, so the loss is computed on the very sample that was decoded.
    enc_mean, enc_log_var, enc_z = encoder(encoder_input)
    model_output = decoder(enc_z)
    model = Model(encoder_input, model_output, name="vae")

    # Catches remaining mismatches (e.g. a channel count other than 1).
    assert encoder_input.shape == model_output.shape
    # `mse` averages over channels; summing axes 1 and 2 gives one value per example.
    reconstruction_loss = tf.losses.mse(encoder_input, model_output)
    reconstruction_loss = tf.reduce_sum(reconstruction_loss, axis=[1, 2])
    logpz = log_normal_pdf(enc_z, 0.0, 0.0)
    logqz_x = log_normal_pdf(enc_z, enc_mean, enc_log_var)
    kl_loss = logqz_x - logpz
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)

    model.add_loss(vae_loss)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return model, encoder, decoder


class Inception2D(Layer):
    """Inception-style block of four parallel branches concatenated on channels.

    Branches (each producing `filters // 4` channels):
      * 1x1 convolution
      * 1x1 convolution followed by a 3x3 convolution
      * 1x1 convolution followed by a 5x5 convolution
      * 3x3 stride-1 max-pool followed by a 1x1 convolution

    All branches use "SAME" padding, so spatial dimensions are preserved.

    Args:
        filters: Total channel budget; each branch gets `filters // 4`.
        *args, **kwargs: Standard `Layer` arguments (e.g. `name`, `dtype`,
            `trainable`), forwarded to the base class.
    """

    def __init__(self, filters, *args, **kwargs):
        # Initialise the base Layer first and forward its standard arguments,
        # so options such as `name` take effect and the dict returned by
        # `get_config` can be passed straight back to this constructor.
        super().__init__(*args, **kwargs)
        self._filters = filters

    def build(self, input_shape):
        """Create the branch sub-model for the given (batched) input shape."""
        filters = self._filters // 4
        inputs = Input(shape=input_shape[1:])  # drop the batch dimension
        conv1 = Conv2D(filters=filters, kernel_size=1, padding="SAME")(inputs)

        conv3 = Conv2D(filters=filters, kernel_size=1, padding="SAME")(inputs)
        conv3 = Conv2D(filters=filters, kernel_size=3, padding="SAME")(conv3)

        conv5 = Conv2D(filters=filters, kernel_size=1, padding="SAME")(inputs)
        conv5 = Conv2D(filters=filters, kernel_size=5, padding="SAME")(conv5)

        pool = MaxPool2D(pool_size=3, strides=1, padding="SAME")(inputs)
        pool = Conv2D(filters=filters, kernel_size=1, padding="SAME")(pool)

        concat = concatenate([conv1, conv3, conv5, pool], axis=-1)
        self._model = Model(inputs, concat)
        super().build(input_shape)

    def call(self, x):
        """Apply the four branches to `x` and return their concatenation."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate shape inference to the inner branch model."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base layer config extended with `filters`."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
        })
        return config


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of `sample` under a diagonal Gaussian, summed over `raxis`.

    The Gaussian is parameterized by `mean` and the log-variance `logvar`.
    """
    log_two_pi = tf.math.log(2.0 * np.pi)
    squared_dev = (sample - mean) ** 2.0
    per_dim = -0.5 * (squared_dev * tf.exp(-logvar) + logvar + log_two_pi)
    return tf.reduce_sum(per_dim, axis=raxis)


@tf.function
def sample(latent_dim, decoder, eps=None):
    """Decode latent vectors into sigmoid-activated outputs.

    When `eps` is omitted, 100 standard-normal latent vectors of size
    `latent_dim` are drawn. A sigmoid is always applied to the decoder output.
    """
    latent = tf.random.normal(shape=(100, latent_dim)) if eps is None else eps
    return decode(decoder, latent, apply_sigmoid=True)


def encode(encoder, x):
    """Run `encoder` on `x` and split its output in half along axis 1.

    Returns the two halves as `(mean, logvar)`.
    """
    stats = encoder(x)
    mean, logvar = tf.split(stats, 2, axis=1)
    return mean, logvar


def reparameterize(mean, logvar):
    """Sample `mean + noise * exp(logvar / 2)` with standard-normal noise."""
    noise = tf.random.normal(shape=mean.shape)
    std = tf.exp(logvar * 0.5)
    return noise * std + mean


def decode(decoder, z, apply_sigmoid=False):
    """Call `decoder` on `z`, optionally passing the result through a sigmoid."""
    output = decoder(z)
    return tf.sigmoid(output) if apply_sigmoid else output

## Hyperparameters, Model Output, and Logging

In [None]:
# Hyperparameters
# Latent size, epoch count, layer counts and learning rate that are passed to
# `build_cvae` / the training call further down.
LATENT_DIM = 256
EPOCHS = 50
NUM_CONV = 2
NUM_INCEPTION = 3
NUM_DECONV = 5
LEARNING_RATE = 1e-5

# Outputs
# Checkpoints go to MODEL_DIR / MODEL_NAME; logs and visualizations go to
# LOG_DIR. Both directories are created immediately if missing.
MODEL_NAME = "cvae_2d_inception"
MODEL_DIR = Path("../models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR = Path("../logs") / MODEL_NAME
LOG_DIR.mkdir(exist_ok=True, parents=True)
# Used below as the TensorBoard `update_freq` and as the visualization
# interval in batches.
LOG_FREQUENCY = 200

## Define Training Callbacks

In [None]:
class VisualizeCallback(Callback):
    """Keras callback that saves reconstruction and generation snapshots.

    At the configured frequency it (a) reconstructs a fixed set of validation
    examples with the model being trained and (b) decodes a fixed set of
    latent vectors with the model's "decoder" layer, saving each result as a
    PNG figure and as a raw `.exr` image under `log_dir`.

    Args:
        log_dir: Root directory for the `png/` and `raw/` output trees.
        latent_dim: Size of the latent vectors used for generation.
        validation_data: Batched dataset; the first `n_examples` unbatched
            elements are kept as reconstruction inputs.
        n_examples: Number of reconstruction examples / generated vectors.
        random_vectors: Optional fixed latent vectors; drawn from a standard
            normal with shape `[n_examples, latent_dim]` when omitted.
        heatmap: Use the "magma" colormap if true, otherwise "Greys".
        frequency: `"epoch"` to visualize at the start of every epoch, or an
            int N to visualize every N training batches.
        verbose: Stored but not used by this class.
    """

    def __init__(
        self,
        log_dir,
        latent_dim,
        validation_data,
        n_examples=4,
        random_vectors=None,
        heatmap=True,
        frequency="epoch",
        verbose=False,
    ):
        super().__init__()  # lets the Keras base class set up its own state
        self.log_dir = Path(log_dir)
        self.latent_dim = latent_dim
        self.n_examples = n_examples
        self.cmap = "magma" if heatmap else "Greys"
        self.frequency = frequency
        self.verbose = verbose
        self.total_batch = 0
        # Explicit None check: truth-testing a multi-element array or tensor
        # (as `random_vectors or ...` would) raises instead of selecting it.
        if random_vectors is None:
            random_vectors = tf.random.normal(shape=[n_examples, latent_dim])
        self.random_vectors = random_vectors
        self.fig = plt.figure()
        self.samples = list(validation_data.unbatch().take(self.n_examples))

        self.recon_raw = self.log_dir / "raw" / "reconstructed"
        self.recon_png = self.log_dir / "png" / "reconstructed"
        self.gen_raw = self.log_dir / "raw" / "generated"
        self.gen_png = self.log_dir / "png" / "generated"

    @staticmethod
    def _step_label(batch=None, epoch=None):
        """Return "epoch_<n>" or "batch_<n>" for output names (epoch wins)."""
        # Compare against None, not truthiness, so epoch 0 is labelled
        # "epoch_0" rather than falling through to the batch branch.
        if epoch is not None:
            return f"epoch_{epoch}"
        if batch is not None:
            return f"batch_{batch}"
        raise ValueError("Either `batch` or `epoch` must be provided")

    def on_train_begin(self, logs=None):
        """Create all output directories before the first visualization."""
        self.recon_raw.mkdir(exist_ok=True, parents=True)
        self.recon_png.mkdir(exist_ok=True, parents=True)
        self.gen_raw.mkdir(exist_ok=True, parents=True)
        self.gen_png.mkdir(exist_ok=True, parents=True)

    def _visualize_reconstruction(self, batch=None, epoch=None):
        """Save original-vs-reconstructed figures for the stored examples."""
        step = self._step_label(batch=batch, epoch=epoch)
        fig = plt.figure(self.fig.number)  # reuse the callback's figure
        fig.set_size_inches(9, 4)
        for i, example in enumerate(self.samples):
            example = example[None, :]  # add a leading batch dimension
            fig.add_subplot(121)
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(example[0, ..., 0].numpy()),
                title="Original",
                cmap=self.cmap,
            )
            fig.add_subplot(122)
            reconstructed = self.model(example)
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(reconstructed[0, ..., 0].numpy()),
                title="Reconstructed",
                cmap=self.cmap,
            )
            fig.tight_layout()
            title = f"recon_{i + 1}@{step}"
            fig.suptitle(title)
            fig.savefig(self.recon_png / f"{title}.png")
            utils.save_image(
                reconstructed[0, ..., 0], self.recon_raw / f"{title}.exr",
            )
            fig.clear()

    def _visualize_generation(self, batch=None, epoch=None):
        """Decode the fixed latent vectors and save one figure per result."""
        step = self._step_label(batch=batch, epoch=epoch)
        decoder = self.model.get_layer("decoder")
        generated = decoder(self.random_vectors)
        fig = plt.figure(self.fig.number)  # reuse the callback's figure
        fig.set_size_inches(5, 4)
        for i, gen in enumerate(generated):
            gen = gen[None, :]  # add a leading batch dimension
            title = f"gen_{i + 1}@{step}"
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(gen[0, ..., 0].numpy()),
                title=title,
                cmap=self.cmap,
            )
            fig.tight_layout()
            fig.savefig(self.gen_png / f"{title}.png")
            utils.save_image(gen[0, ..., 0], self.gen_raw / f"{title}.exr")
            fig.clear()

    def on_epoch_begin(self, epoch, logs=None):
        """Visualize at the start of each epoch when frequency is "epoch"."""
        if self.frequency == "epoch":
            self._visualize_reconstruction(epoch=epoch)
            self._visualize_generation(epoch=epoch)

    def on_train_batch_begin(self, batch, logs=None):
        """Visualize every `frequency` batches when frequency is an int."""
        # `total_batch` counts batches across epochs, unlike Keras' `batch`.
        if isinstance(self.frequency, int) and (self.total_batch % self.frequency == 0):
            self._visualize_reconstruction(batch=self.total_batch)
            self._visualize_generation(batch=self.total_batch)

    def on_train_batch_end(self, batch, logs=None):
        """Advance the cross-epoch batch counter."""
        self.total_batch += 1

In [None]:
# TensorBoard logging every LOG_FREQUENCY batches, with profiling switched off.
tensorboard = TensorBoard(
    log_dir=LOG_DIR,
    update_freq=LOG_FREQUENCY,
    profile_batch=0,
)
# Scale the learning rate by 0.1 after 2 stagnant epochs, never below 1e-6.
reduce_lr = ReduceLROnPlateau(
    patience=2,
    factor=0.1,
    min_lr=1e-6,
    verbose=1,
)
# Stop after 5 stagnant epochs (defined here; not passed to the training
# calls in this notebook).
early_stop = EarlyStopping(
    patience=5,
    verbose=1,
)
# Keep only the best checkpoint at MODEL_DIR / MODEL_NAME.
model_saver = ModelCheckpoint(
    str(MODEL_DIR / MODEL_NAME),
    save_best_only=True,
    verbose=1,
)
# Periodic reconstruction/generation snapshots drawn from the validation set.
visualizer = VisualizeCallback(
    LOG_DIR,
    LATENT_DIM,
    val_dataset,
    frequency=LOG_FREQUENCY,
)

## Instantiate and Train Model

### Train on FMA

In [None]:
# Build the VAE and its standalone encoder/decoder sub-models for the FMA
# image shape, using the hyperparameters defined earlier in the notebook.
model, encoder, decoder = build_cvae(
    LATENT_DIM,
    IMAGE_DIMS,
    num_conv=NUM_CONV,
    num_inception=NUM_INCEPTION,
    num_deconv=NUM_DECONV,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
)
# Print layer summaries for the full model and for each half.
model.summary()
encoder.summary()
decoder.summary()
# Write an architecture diagram (nested models expanded, shapes shown) to the log directory.
plot_model(model, to_file=str(LOG_DIR / "model.png"), expand_nested=True, show_shapes=True)

In [None]:
# Train model
# `Model.fit` consumes `tf.data.Dataset` objects directly; `fit_generator` is
# deprecated and has been removed from newer TensorFlow releases.
model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=[tensorboard, model_saver, reduce_lr, visualizer],
    validation_data=val_dataset,
)
# OR - Overfit model on a single sample
# single_sample = tf.data.Dataset.from_tensor_slices(list(train_dataset.take(1)))
# model.fit(
#     single_sample,
#     epochs=400,
#     callbacks=[VisualizeCallback(LOG_DIR, LATENT_DIM, single_sample, frequency="epoch"),],
#     validation_data=single_sample,
# )

### Train on MNIST

In [None]:
# Load MNIST data
(train_images, _), (val_images, _) = tf.keras.datasets.mnist.load_data()
# Add a channel axis and scale pixel values to [0, 1].
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32")
val_images = val_images.reshape(val_images.shape[0], 28, 28, 1).astype("float32")
train_images /= 255.0
val_images /= 255.0
mnist_train = (
    tf.data.Dataset.from_tensor_slices(train_images).shuffle(50000).batch(BATCH_SIZE)
)
mnist_val = (
    tf.data.Dataset.from_tensor_slices(val_images).shuffle(50000).batch(BATCH_SIZE)
)

# Callbacks
visualizer = VisualizeCallback(LOG_DIR, LATENT_DIM, mnist_val, frequency=100)

# Define and train model
# The decoder starts from the input size divided by 2 ** num_deconv and
# doubles it num_deconv times. With the default of 3, 28 // 8 * 8 == 24, which
# cannot match the 28x28 input; 2 stages give 7 -> 14 -> 28.
model, encoder, decoder = build_cvae(
    LATENT_DIM,
    (28, 28, 1),
    num_conv=1,
    num_deconv=2,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
)
# `Model.fit` consumes `tf.data.Dataset` objects directly; `fit_generator` is
# deprecated and has been removed from newer TensorFlow releases.
model.fit(
    mnist_train,
    epochs=200,
    callbacks=[tensorboard, visualizer],
    validation_data=mnist_val,
)

## Visualize Reconstruction

In [None]:
# Optional - restore the checkpoint written to MODEL_DIR / MODEL_NAME.
# NOTE(review): the saved model contains the custom `Inception2D` layer;
# loading may need `custom_objects` — confirm for the TF version in use.
model = tf.keras.models.load_model(str(MODEL_DIR / MODEL_NAME))
# model = tf.keras.models.load_model(f"../models/cvae.h5")

In [None]:
# Take one batch from the test set and reconstruct it once.
sample = next(iter(test_dataset.take(1)))
# `predict` returns a NumPy array, so it is plotted without `.numpy()`.
reconstructed = model.predict(sample)
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(sample[0, ..., 0].numpy()), title="Original"
)
plt.show()
# Plot the result computed above instead of running a second forward pass:
# the model samples its latent vector, so a second pass would display a
# different reconstruction than the one stored in `reconstructed`.
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(reconstructed[0, ..., 0]),
    title="Reconstructed",
)
plt.show()

## Visualize Unconditioned Generation

In [None]:
# Number of latent vectors to decode and number of interpolation steps used
# by the cells below; generated images are written under OUTPUT_DIR.
EXAMPLES_TO_GENERATE = 16
INTERPOLATION_POINTS = 9
OUTPUT_DIR = Path("../data/output/images")

# Keep the random latent vectors constant across generation (prediction) runs
# so that improvements are easier to compare visually.
random_vector_for_generation = tf.random.normal(
    shape=[EXAMPLES_TO_GENERATE, LATENT_DIM]
)

In [None]:
# A Dataset is iterable but not an iterator, so wrap it in `iter` before `next`.
sample = next(iter(train_dataset.take(1)))
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(sample[0, ..., 0].numpy()), title="Original"
)
plt.show()
# Distribution of the values in this batch. Datasets have no `flatten`, so the
# sampled batch is used (assumed to be the intent — TODO confirm).
sns.distplot(sample.numpy().flatten())
plt.show()

In [None]:
# Run the model once on the sampled batch (a Dataset object cannot be passed
# to the model directly) and reuse the result for the plot and the histogram.
prediction = model(sample)
beatbrain.display.show_spec(
    utils.denormalize_spectrogram(prediction[0, ..., 0].numpy()), title="Predicted"
)
plt.show()
sns.distplot(prediction.numpy().flatten())

In [None]:
# Parameters passed to `utils.spectrogram_to_audio` (as n_fft, hop_length and
# sr) and to the audio player in the cells below.
N_FFT = 4096
HOP_LENGTH = 256
SAMPLE_RATE = 32768

In [None]:
# Convert the first spectrogram of the sampled batch back to audio.
sample_recon = utils.spectrogram_to_audio(sample[0, ..., 0], denormalize=True, n_fft=N_FFT, hop_length=HOP_LENGTH, sr=SAMPLE_RATE)

In [None]:
# Show an inline audio player for the audio recovered from the original spectrogram.
display.Audio(sample_recon, rate=SAMPLE_RATE)

In [None]:
# Convert the first spectrogram of the model's prediction back to audio.
prediction_recon = utils.spectrogram_to_audio(prediction[0, ..., 0], denormalize=True, n_fft=N_FFT, hop_length=HOP_LENGTH, sr=SAMPLE_RATE)

In [None]:
# Show an inline audio player for the audio recovered from the predicted spectrogram.
display.Audio(prediction_recon, rate=SAMPLE_RATE)

In [None]:
def generate_and_save_images(decoder, epoch, test_input):
    """Decode `test_input`, show the results in a grid and save each as TIFF.

    Args:
        decoder: Decoder passed to `sample` together with `test_input`.
        epoch: Accepted for interface compatibility; not used, so each call
            overwrites the previous "spec.tiff" files.
        test_input: Latent vectors to decode; `len(test_input)` determines
            the grid size.

    Each decoded image `i` is written to
    `OUTPUT_DIR / "progress" / str(i) / "spec.tiff"`.
    """
    # Smallest square grid that fits every decoded image.
    num_plots = math.ceil(math.sqrt(len(test_input)))
    predictions = sample(LATENT_DIM, decoder, eps=test_input)
    fig = plt.figure(figsize=(12, 12))
    fig.subplots_adjust(hspace=0, wspace=0)
    for i, pred in enumerate(predictions):
        plt.subplot(num_plots, num_plots, i + 1)
        plt.imshow(pred[:, :, 0], cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
        # pathlib replaces the former `os.path` / `os.makedirs` calls, which
        # relied on an `os` import that this notebook never performs.
        output_dir = Path(OUTPUT_DIR) / 'progress' / str(i)
        output_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): `Image` (presumably `PIL.Image`) is not imported in
        # the visible notebook — confirm it is available before running.
        image = Image.fromarray(pred[:, :, 0].numpy(), mode='F')
        image.save(str(output_dir / "spec.tiff"))
    plt.show()

In [None]:
# Interpolate between the latent encodings of two audio clips and save the
# decoded intermediate spectrograms.
# NOTE: this rebinds the name `time` from the module imported at the top of
# the notebook to the `time.time` function.
from time import time
import librosa

SR = 32768
N_FFT = 4096
N_MELS = 512
HOP_LENGTH = 256
DURATION = 5
A = "../data/fma/audio/000002.mp3"
B = "../data/fma/audio/000005.mp3"

# One output directory per run, named by the current Unix timestamp.
interp_dir = Path(f"interpolation/{int(time())}")
interp_dir.mkdir(exist_ok=True, parents=True)

# Load DURATION seconds of each clip and compute mel spectrograms.
x, _ = librosa.load(A, sr=SR, duration=DURATION)
y, _ = librosa.load(B, sr=SR, duration=DURATION)
x = librosa.feature.melspectrogram(x, sr=SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
y = librosa.feature.melspectrogram(y, sr=SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
# NOTE(review): `x` and `y` are indexed differently here (`x[0, None]` keeps
# only the first row of `x`), so they end up with different ranks, and
# neither matches the [batch, freq, time, channel] layout used by the
# 4-D indexing further down — confirm the intended expansion.
x, y = x[0, None], y[None, :, None]
print(x.shape)
# NOTE(review): the model returned by `build_cvae` in this notebook is a plain
# Keras `Model`; `encode`, `reparameterize` and `sample` exist here only as
# module-level functions. These method calls presumably target an earlier
# model class — confirm before running.
x_mean, x_logvar = model.encode(x)
y_mean, y_logvar = model.encode(y)

# Reconstruction
# Side-by-side plot of the original `x` and one decoded sample of its encoding.
fig = plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(x[0, ..., 0], cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.subplot(122)
x_recon = model.sample(model.reparameterize(x_mean, x_logvar))
plt.imshow(x_recon[0, ..., 0], cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.show()

# Interpolation
# `fractions` is a column of INTERPOLATION_POINTS weights from 0 to 1; the
# means and log-variances are blended linearly with those weights.
fractions = np.linspace(0, 1, num=INTERPOLATION_POINTS)[:, None]
means = (x_mean * (1 - fractions)) + (y_mean * fractions)  # Interpolated latent vectors
logvars = (x_logvar * (1 - fractions)) + (y_logvar * fractions)  # Interpolated latent vectors
points = model.reparameterize(means, logvars)
interpolated = model.sample(points)

# Plot the interpolation on a square grid. The first and last cells show the
# original inputs instead of the decoded endpoints; every decoded point is
# still saved as `<i>.tiff`.
# NOTE(review): `Image` (presumably `PIL.Image`) is not imported in the visible
# notebook, and `.numpy()` is called on `x` / `y`, which come from librosa and
# NumPy indexing above — confirm both before running.
num_plots = math.ceil(math.sqrt(INTERPOLATION_POINTS))
fig = plt.figure(figsize=(12, 12))
fig.subplots_adjust(hspace=0, wspace=0.03)
for i, pred in enumerate(interpolated):
    Image.fromarray(pred[:, :, 0].numpy(), mode='F').save(interp_dir.joinpath(f'{i}.tiff'))
    plt.subplot(num_plots, num_plots, i + 1)
    if i == 0:
        Image.fromarray(x[0, :, :, 0].numpy(), mode='F').save(interp_dir.joinpath(f'a.tiff'))
        plt.imshow(x[0, ..., 0], cmap='gray', vmin=0, vmax=1)
    elif i == INTERPOLATION_POINTS - 1:
        Image.fromarray(y[0, :, :, 0].numpy(), mode='F').save(interp_dir.joinpath(f'b.tiff'))
        plt.imshow(y[0, ..., 0], cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(pred[:, :, 0], cmap='gray', vmin=0, vmax=1)
    plt.axis('off')

plt.savefig(interp_dir.joinpath('interpolation.png'), bbox_inches='tight')