## Import Dependencies

In [None]:
# misc
import math
import time
from pathlib import Path

# scientific
import numpy as np
import beatbrain
from beatbrain import utils

# visualization
from IPython import display
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Tensorflow
import tensorflow as tf

from tensorflow.keras import Model, Sequential, Input, optimizers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import plot_model

from tensorflow.keras.layers import (
    Conv1D,
    Conv2D,
    Conv2DTranspose,
    MaxPool1D,
    MaxPool2D,
    Dense,
    Lambda,
    Reshape,
    Flatten,
    Layer,
    concatenate,
)
from tensorflow.keras.callbacks import (
    Callback,
    TensorBoard,
    ReduceLROnPlateau,
    EarlyStopping,
    ModelCheckpoint,
    TerminateOnNaN,
    CSVLogger,
    LambdaCallback,
)

# Ask TensorFlow to grow GPU memory on demand. The loop covers every visible
# GPU (memory growth must be configured consistently across devices) and is a
# no-op on CPU-only machines, where the device list is empty and indexing
# element 0 would raise IndexError.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
# Apply seaborn's default theme, then switch the axes style to "white".
sns.set()
sns.set_style("white")
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Input
# Root directory of the image dataset; the "train", "val" and "test"
# subdirectories are loaded from it.
DATA_ROOT = Path("../data/fma/image")
# Shape of one example, passed to build_cvae as `input_shape`. build_cvae's own
# shape comments treat axis 0 as frequency and axis 1 as time; the trailing 1
# is the channel axis that reshape_1d squeezes away.
IMAGE_DIMS = [512, 640, 1]
# Examples per train/validation batch; also passed to build_cvae, which fixes
# it as the batch size of the model's Input layers.
BATCH_SIZE = 4

In [None]:
# Load the three dataset splits from their subdirectories of DATA_ROOT.
# NOTE(review): the semantics of `parallel` and `shuffle_buffer` are defined by
# beatbrain.utils.load_dataset, which is not visible here — presumably
# shuffle_buffer=0 disables shuffling for the test split; confirm.
train_dataset = utils.load_dataset(
    DATA_ROOT / "train", batch_size=BATCH_SIZE, parallel=False
)
val_dataset = utils.load_dataset(
    DATA_ROOT / "val", batch_size=BATCH_SIZE, parallel=False,
)
# The test split is requested with one example per batch.
test_dataset = utils.load_dataset(
    DATA_ROOT / "test", batch_size=1, parallel=False, shuffle_buffer=0,
)

In [None]:
def build_cvae(
    latent_dim,
    input_shape,
    start_filters=512,
    num_conv=3,
    num_inception=3,
    num_deconv=3,
    batch_size=1,
    learning_rate=1e-4,
):
    """Build and compile a 1D convolutional variational autoencoder.

    The encoder drops the channel axis and treats the input as a sequence over
    time with frequency bins as channels, applies `num_conv` Conv1D layers and
    `num_inception` Inception1D + MaxPool1D stages, and predicts the mean and
    log-variance of the latent distribution. The decoder expands a latent
    vector with a Dense layer and `num_deconv` stride-2 Conv1DTranspose layers
    back to the input shape.

    The loss is attached with `add_loss`: squared error summed over the two
    spatial axes, plus a single-sample estimate of the KL term
    `log q(z|x) - log p(z)` against a standard-normal prior, averaged over the
    batch. The model is therefore compiled without an external loss.

    Args:
        latent_dim: Size of the latent vector.
        input_shape: Per-example shape `[freq, time, 1]`; the last axis is
            squeezed by `reshape_1d`. The assertion on the output shape
            requires that the decoder reproduce this shape exactly.
        start_filters: Filters of the first Conv1D layer; halved per layer.
        num_conv: Number of plain Conv1D layers in the encoder.
        num_inception: Number of Inception1D + MaxPool1D stages in the encoder.
        num_deconv: Number of stride-2 Conv1DTranspose layers in the decoder.
        batch_size: Fixed batch size of the encoder and decoder Input layers.
        learning_rate: Learning rate of the Adam optimizer.

    Returns:
        Tuple `(model, encoder, decoder)`: the compiled end-to-end VAE, the
        encoder (outputs `[z_mean, z_log_var, z]`) and the decoder.
    """

    def reparam(args):
        """Sample z = mean + sigma * eps with eps drawn from N(0, I)."""
        z_mean, z_log_var = args
        # Use the runtime shape of z_mean rather than the static batch_size so
        # the noise always matches the batch that is actually being processed.
        eps = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        return eps * tf.exp(z_log_var * 0.5) + z_mean

    encoder_input = Input(shape=input_shape, batch_size=batch_size, name="encoder_input")
    e = Lambda(reshape_1d)(encoder_input)  # (batch_size, time, freq)
    for i in range(num_conv):
        e = Conv1D(
            filters=start_filters // (2 ** i),
            kernel_size=3,
            strides=1,
            padding="SAME",
            activation="relu",
        )(e)
    for i in range(num_inception):
        e = Inception1D(start_filters // 2 ** (num_conv + i))(e)
        e = MaxPool1D()(e)
    e = Flatten()(e)
    e = Dense(latent_dim * 2)(e)  # (batch_size, latent_dim*2)
    z_mean = Dense(latent_dim, name="z_mean")(e)
    z_log_var = Dense(latent_dim, name="z_log_var")(e)
    z = Lambda(reparam, output_shape=(latent_dim,), name="z")(
        [z_mean, z_log_var],
    )  # (batch_size, latent_dim)

    # Shape of the feature map the decoder starts from; each stride-2
    # transposed convolution then doubles the time axis.
    decoder_input_shape = [
        batch_size,
        input_shape[1] // 2 ** num_deconv,
        input_shape[0] // 2 ** num_deconv,
    ]  # shape: [?, time, freq]
    decoder_input = Input(shape=(latent_dim,), batch_size=batch_size, name="decoder_input")
    d = Dense(
        decoder_input_shape[1] * decoder_input_shape[2],
        activation="relu",
    )(decoder_input)
    d = Reshape(
        target_shape=(
            decoder_input_shape[1],
            decoder_input_shape[2],
        )
    )(d)
    for i in range(num_deconv):
        d = Conv1DTranspose(
            filters=input_shape[0] // 2 ** (num_deconv - i),
            kernel_size=3,
            strides=2,
            padding="SAME",
            activation="relu",
        )(d)
    decoder_output = Conv1DTranspose(
        filters=input_shape[0],
        kernel_size=3,
        strides=1,
        padding="SAME",
        activation="sigmoid",  # Changed this from RELU
        name="decoder_output",
    )(d)
    decoder_output = Lambda(reshape_2d)(decoder_output)

    encoder = Model(encoder_input, [z_mean, z_log_var, z], name="encoder")
    decoder = Model(decoder_input, decoder_output, name="decoder")
    model_output = decoder(encoder(encoder_input)[2])
    model = Model(encoder_input, model_output, name="vae")

    # The autoencoder must reproduce the exact input shape.
    assert encoder_input.shape == model_output.shape
    # mse reduces over the channel axis; the sum then covers both spatial axes.
    reconstruction_loss = tf.losses.mse(encoder_input, model_output)
    reconstruction_loss = tf.reduce_sum(reconstruction_loss, axis=[1, 2])  # To exclude channels or not to exclude?
    # Standard-normal prior: mean 0 and log-variance 0.
    logpz = log_normal_pdf(z, 0.0, 0.0)
    logqz_x = log_normal_pdf(z, z_mean, z_log_var)
    kl_loss = logqz_x - logpz
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)

    model.add_loss(vae_loss)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
    return model, encoder, decoder


class Inception1D(Layer):
    """Inception-style 1D block with four parallel branches.

    The input is fed through four branches whose outputs are concatenated along
    the channel axis, so the output has `4 * filters` channels:

    * a kernel-size-1 convolution,
    * a kernel-size-1 convolution followed by a kernel-size-3 convolution,
    * a kernel-size-1 convolution followed by a kernel-size-5 convolution,
    * a max-pool (size 3, stride 1) followed by a kernel-size-1 convolution.

    All branches use "SAME" padding, so the sequence length is preserved. No
    activation is specified on any of the convolutions.

    Args:
        filters: Number of filters of every convolution in the block.
        *args: Accepted for backward compatibility; stored but unused.
        **kwargs: Standard `Layer` options (e.g. `name`, `dtype`,
            `trainable`), forwarded to the base class.
    """

    def __init__(self, filters, *args, **kwargs):
        # Forward the keyword options to Layer so that `name`, `dtype` and
        # `trainable` are honoured (also when rebuilt through from_config)
        # instead of being silently dropped.
        super().__init__(**kwargs)
        self._filters = filters
        self._args, self._kwargs = args, kwargs

    def build(self, input_shape):
        """Create the inner functional model for the given input shape."""
        filters = self._filters
        inputs = Input(shape=input_shape[1:])
        conv1 = Conv1D(filters=filters, kernel_size=1, padding="SAME")(inputs)

        conv3 = Conv1D(filters=filters, kernel_size=1, padding="SAME")(inputs)
        conv3 = Conv1D(filters=filters, kernel_size=3, padding="SAME")(conv3)

        conv5 = Conv1D(filters=filters, kernel_size=1, padding="SAME")(inputs)
        conv5 = Conv1D(filters=filters, kernel_size=5, padding="SAME")(conv5)

        pool = MaxPool1D(pool_size=3, strides=1, padding="SAME")(inputs)
        pool = Conv1D(filters=filters, kernel_size=1, padding="SAME")(pool)

        concat = concatenate([conv1, conv3, conv5, pool], axis=-1)
        self._model = Model(inputs, concat)
        super().build(input_shape)

    def call(self, x):
        """Run the inner model on `x`."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate to the inner model (available once the layer is built)."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base layer config extended with `filters`."""
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
        })
        return config

        
class Conv1DTranspose(Layer):
    """1D transposed convolution implemented on top of `Conv2DTranspose`.

    The input gets a dummy axis of length 1 inserted at position 1, is passed
    through a `Conv2DTranspose` whose kernel and strides are 1 along that dummy
    axis, and the dummy axis is removed again.

    Args:
        filters: Number of output filters.
        kernel_size: Kernel length along the sequence axis.
        strides: Stride along the sequence axis.
        *args, **kwargs: Forwarded unchanged to the inner `Conv2DTranspose`
            (e.g. `padding`, `activation`, `name`).
    """

    def __init__(self, filters, kernel_size, strides=1, *args, **kwargs):
        self._filters = filters
        # Stored in 2D form: size 1 along the inserted dummy axis.
        self._kernel_size = (1, kernel_size)
        self._strides = (1, strides)
        self._args, self._kwargs = args, kwargs
        super().__init__()

    def build(self, input_shape):
        """Create the inner expand -> Conv2DTranspose -> squeeze model."""
        inp = Input(shape=input_shape[1:])
        reshaped = Lambda(lambda x: tf.expand_dims(x, 1), batch_input_shape=input_shape)(inp)
        conv = Conv2DTranspose(
            self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            *self._args,
            **self._kwargs
        )(reshaped)
        # Drop the dummy axis that was inserted before the 2D convolution.
        output = Lambda(lambda x: x[:, 0])(conv)
        self._model = Model(inputs=inp, outputs=output)
        super().build(input_shape)

    def call(self, x):
        """Run the inner model on `x`."""
        return self._model(x)

    def compute_output_shape(self, input_shape):
        """Delegate to the inner model (available once the layer is built)."""
        return self._model.compute_output_shape(input_shape)

    def get_config(self):
        """Return a config that `__init__` can consume again.

        `kernel_size` and `strides` are reported as the scalars the constructor
        accepts (not the internal `(1, k)` tuples), and the keyword arguments
        forwarded to the inner convolution are included so that options such as
        `padding` and `activation` survive a `from_config` round trip.
        """
        config = super().get_config().copy()
        config.update({
            "filters": self._filters,
            "kernel_size": self._kernel_size[1],
            "strides": self._strides[1],
        })
        # Constructor kwargs win over base-config keys of the same name so the
        # inner convolution is rebuilt with exactly the options it was given.
        config.update(dict(self._kwargs))
        return config


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of `sample` under a diagonal Gaussian, summed over `raxis`.

    Args:
        sample: Points at which the density is evaluated.
        mean: Mean of the Gaussian (broadcast against `sample`).
        logvar: Log-variance of the Gaussian (broadcast against `sample`).
        raxis: Axis over which the per-dimension log-densities are summed.

    Returns:
        The summed log-density, with `raxis` reduced away.
    """
    log_two_pi = tf.math.log(2.0 * np.pi)
    squared_error = (sample - mean) ** 2.0
    inverse_variance = tf.exp(-logvar)
    per_dimension = -0.5 * (squared_error * inverse_variance + logvar + log_two_pi)
    return tf.reduce_sum(per_dimension, axis=raxis)


@tf.function
def sample(latent_dim, decoder, eps=None):
    """Decode latent vectors into outputs, drawing the latents if none are given.

    Args:
        latent_dim: Size of each latent vector; used only when `eps` is None.
        decoder: Callable mapping latent vectors to decoder outputs.
        eps: Optional latent vectors; 100 standard-normal vectors are drawn
            when omitted.

    Returns:
        The decoder output passed through a sigmoid.
    """
    # NOTE(review): apply_sigmoid=True assumes the decoder emits logits. The
    # decoder built by build_cvae already ends in a sigmoid activation, so
    # confirm before using this helper with that decoder.
    latent = tf.random.normal(shape=(100, latent_dim)) if eps is None else eps
    return decode(decoder, latent, apply_sigmoid=True)


def encode(encoder, x):
    """Run `encoder` on `x` and return the latent mean and log-variance.

    Two encoder output conventions are supported:

    * a sequence whose first two items are the mean and the log-variance, such
      as the `[z_mean, z_log_var, z]` output of the encoder built by
      `build_cvae`;
    * a single tensor holding mean and log-variance concatenated along axis 1,
      which is split into two equal halves.

    Args:
        encoder: Callable encoder model.
        x: Input batch.

    Returns:
        Tuple `(mean, logvar)`.
    """
    inference = encoder(x)
    if isinstance(inference, (list, tuple)):
        # tf.split would stack a list of outputs into one tensor and cut it
        # along the batch axis, so multi-output encoders are unpacked here.
        mean, logvar = inference[0], inference[1]
        return mean, logvar
    mean, logvar = tf.split(inference, num_or_size_splits=2, axis=1)
    return mean, logvar


def reparameterize(mean, logvar):
    """Draw `z = mean + exp(logvar / 2) * eps` with `eps` from N(0, I).

    Args:
        mean: Mean of the latent Gaussian.
        logvar: Log-variance of the latent Gaussian, same shape as `mean`.

    Returns:
        A sample with the same shape as `mean`.
    """
    # tf.shape gives the runtime shape, so this also works when the static
    # shape of `mean` has unknown dimensions (e.g. a variable batch size).
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * 0.5) + mean


def decode(decoder, z, apply_sigmoid=False):
    """Run `decoder` on latent vectors `z`.

    Args:
        decoder: Callable mapping latent vectors to outputs.
        z: Latent vectors.
        apply_sigmoid: When True, pass the decoder output through a sigmoid
            before returning it.

    Returns:
        The raw decoder output, or its sigmoid when `apply_sigmoid` is True.
    """
    decoded = decoder(z)
    return tf.sigmoid(decoded) if apply_sigmoid else decoded


def reshape_1d(x, axis=-1):
    """Drop a singleton axis and swap the last two axes of the 3D result.

    Args:
        x: 4D tensor with a size-1 axis at position `axis`.
        axis: Axis to squeeze; defaults to the trailing (channel) axis.

    Returns:
        3D tensor with axes 1 and 2 of the squeezed tensor transposed.
    """
    # Honour the `axis` argument instead of always squeezing the last axis.
    x = tf.squeeze(x, axis=axis)
    x = tf.transpose(x, [0, 2, 1])
    return x


def reshape_2d(x, axis=-1):
    """Swap the last two axes of a 3D tensor and insert a singleton axis.

    Args:
        x: 3D tensor.
        axis: Position at which the new size-1 axis is inserted; defaults to
            the trailing (channel) position.

    Returns:
        4D tensor with axes 1 and 2 transposed and a new axis at `axis`.
    """
    x = tf.transpose(x, [0, 2, 1])
    # Honour the `axis` argument instead of always appending the new axis.
    x = tf.expand_dims(x, axis=axis)
    return x

## Hyperparameters, Model Output, and Logging

In [None]:
# Hyperparameters
# Size of the latent vector.
LATENT_DIM = 256
# Number of training epochs.
EPOCHS = 50
# Architecture depth settings corresponding to build_cvae's `num_conv`,
# `num_inception` and `num_deconv` parameters.
NUM_CONV = 4
NUM_INCEPTION = 4
NUM_DECONV = 5
# Learning rate handed to the optimizer.
LEARNING_RATE = 1e-5

# Outputs
# MODEL_NAME names both the checkpoint path (MODEL_DIR / MODEL_NAME) and the
# log subdirectory; both directories are created if missing.
MODEL_NAME = "cvae_1d_inception"
MODEL_DIR = Path("../models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR = Path("../logs") / MODEL_NAME
LOG_DIR.mkdir(exist_ok=True, parents=True)
# Passed as TensorBoard's `update_freq` and as VisualizeCallback's `frequency`
# (an int there means "every N training batches").
LOG_FREQUENCY = 200

## Define Training Callbacks

In [None]:
class VisualizeCallback(Callback):
    """Periodically save reconstructions and random generations during training.

    At the configured frequency the callback runs the model on a fixed set of
    validation examples and decodes a fixed set of latent vectors. Each result
    is written as a PNG figure and as a raw ``.exr`` image under ``log_dir``.

    Args:
        log_dir: Directory under which the ``png/`` and ``raw/`` trees are
            created.
        latent_dim: Size of the latent vectors fed to the decoder.
        validation_data: Batched dataset; it is unbatched and its first
            ``n_examples`` elements are kept as reconstruction inputs.
        n_examples: Number of examples to reconstruct and to generate.
        random_vectors: Optional latent vectors to decode. When None,
            standard-normal vectors of shape ``[n_examples, latent_dim]`` are
            drawn once and reused for every visualization.
        heatmap: Use the "magma" colormap when True, "Greys" otherwise.
        frequency: ``"epoch"`` to visualize at the start of every epoch, or an
            int N to visualize every N training batches.
        verbose: Stored but not used by this class.
    """

    def __init__(
        self,
        log_dir,
        latent_dim,
        validation_data,
        n_examples=4,
        random_vectors=None,
        heatmap=True,
        frequency="epoch",
        verbose=False,
    ):
        super().__init__()
        self.log_dir = Path(log_dir)
        self.latent_dim = latent_dim
        self.n_examples = n_examples
        self.cmap = "magma" if heatmap else "Greys"
        self.frequency = frequency
        self.verbose = verbose
        # Running count of training batches across epochs.
        self.total_batch = 0
        # Test against None explicitly: the truth value of a multi-element
        # tensor or array is ambiguous, so `random_vectors or ...` would raise.
        if random_vectors is None:
            random_vectors = tf.random.normal(shape=[n_examples, latent_dim])
        self.random_vectors = random_vectors
        self.fig = plt.figure()
        self.samples = list(validation_data.unbatch().take(self.n_examples))

        self.recon_raw = self.log_dir / "raw" / "reconstructed"
        self.recon_png = self.log_dir / "png" / "reconstructed"
        self.gen_raw = self.log_dir / "raw" / "generated"
        self.gen_png = self.log_dir / "png" / "generated"

    @staticmethod
    def _step_tag(batch=None, epoch=None):
        """Return "epoch_<n>" when an epoch is given, otherwise "batch_<n>"."""
        # Compare against None: epoch 0 is a valid step and must not be
        # mistaken for "no epoch given".
        if epoch is not None:
            return f"epoch_{epoch}"
        return f"batch_{batch}"

    def on_train_begin(self, logs=None):
        """Create the output directories."""
        self.recon_raw.mkdir(exist_ok=True, parents=True)
        self.recon_png.mkdir(exist_ok=True, parents=True)
        self.gen_raw.mkdir(exist_ok=True, parents=True)
        self.gen_png.mkdir(exist_ok=True, parents=True)

    def _visualize_reconstruction(self, batch=None, epoch=None):
        """Save original/reconstruction figures for the stored examples."""
        assert (batch is not None) or (epoch is not None)
        tag = self._step_tag(batch=batch, epoch=epoch)
        fig = plt.figure(self.fig.number)
        fig.set_size_inches(9, 4)
        for i, example in enumerate(self.samples):
            fig.add_subplot(121)
            example = example[None, :]  # add a leading batch axis
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(example[0, ..., 0].numpy()),
                title="Original",
                cmap=self.cmap,
            )
            fig.add_subplot(122)
            reconstructed = self.model(example)
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(reconstructed[0, ..., 0].numpy()),
                title="Reconstructed",
                cmap=self.cmap,
            )
            fig.tight_layout()
            title = f"recon_{i + 1}@{tag}"
            fig.suptitle(title)
            fig.savefig(self.recon_png / f"{title}.png")
            utils.save_image(
                reconstructed[0, ..., 0], self.recon_raw / f"{title}.exr",
            )
            fig.clear()

    def _visualize_generation(self, batch=None, epoch=None):
        """Decode the stored latent vectors and save one figure per output."""
        assert (batch is not None) or (epoch is not None)
        tag = self._step_tag(batch=batch, epoch=epoch)
        decoder = self.model.get_layer("decoder")
        generated = decoder(self.random_vectors)
        fig = plt.figure(self.fig.number)
        fig.set_size_inches(5, 4)
        for i, gen in enumerate(generated):
            gen = gen[None, :]  # add a leading batch axis
            title = f"gen_{i + 1}@{tag}"
            beatbrain.display.show_spec(
                utils.denormalize_spectrogram(gen[0, ..., 0].numpy()),
                title=title,
                cmap=self.cmap,
            )
            fig.tight_layout()
            fig.savefig(self.gen_png / f"{title}.png")
            utils.save_image(gen[0, ..., 0], self.gen_raw / f"{title}.exr")
            fig.clear()

    def on_epoch_begin(self, epoch, logs=None):
        """Visualize at the start of each epoch when frequency is "epoch"."""
        if self.frequency == "epoch":
            self._visualize_reconstruction(epoch=epoch)
            self._visualize_generation(epoch=epoch)

    def on_train_batch_begin(self, batch, logs=None):
        """Visualize every `frequency` batches when frequency is an int."""
        if isinstance(self.frequency, int) and (self.total_batch % self.frequency == 0):
            self._visualize_reconstruction(batch=self.total_batch)
            self._visualize_generation(batch=self.total_batch)

    def on_train_batch_end(self, batch, logs=None):
        """Advance the global batch counter."""
        self.total_batch += 1

In [None]:
# TensorBoard logging into LOG_DIR, updated every LOG_FREQUENCY batches.
tensorboard = TensorBoard(log_dir=LOG_DIR, update_freq=LOG_FREQUENCY, profile_batch=0,)
# Multiply the learning rate by 0.1 after 2 epochs without improvement, never
# going below 1e-6.
reduce_lr = ReduceLROnPlateau(patience=2, factor=0.1, min_lr=1e-6, verbose=1,)
# NOTE(review): early_stop is created here but is not in the callback list
# passed to training in this notebook — confirm whether that is intentional.
early_stop = EarlyStopping(patience=5, verbose=1,)
# Keep only the best model seen so far at MODEL_DIR / MODEL_NAME.
model_saver = ModelCheckpoint(
    str(MODEL_DIR / MODEL_NAME), save_best_only=True, verbose=1,
)
# Save reconstruction/generation images every LOG_FREQUENCY training batches,
# using examples taken from the validation set.
visualizer = VisualizeCallback(
    LOG_DIR, LATENT_DIM, val_dataset, frequency=LOG_FREQUENCY,
)

## Instantiate and Train Model

In [None]:
# Build the VAE from the hyperparameters declared above. NUM_INCEPTION and
# NUM_DECONV are passed explicitly; otherwise build_cvae falls back to its
# defaults (3 and 3) and the declared values are silently ignored.
model, encoder, decoder = build_cvae(
    LATENT_DIM,
    IMAGE_DIMS,
    num_conv=NUM_CONV,
    num_inception=NUM_INCEPTION,
    num_deconv=NUM_DECONV,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
)
model.summary()
encoder.summary()
decoder.summary()
# Save an architecture diagram next to the training logs.
plot_model(model, to_file=str(LOG_DIR / "model.png"), expand_nested=True, show_shapes=True)

In [None]:
# Train model
# Model.fit accepts tf.data datasets directly; fit_generator is deprecated in
# favour of it. The loss is attached inside build_cvae via add_loss, so no
# targets are passed here.
model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=[tensorboard, model_saver, reduce_lr, visualizer],
    validation_data=val_dataset,
)