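#### Gaussian β-VAE experiments on linear spectrograms

This notebook trains `GaussianBetaVAE` models to predict linear spectrograms under several FFT-size / latent-size configurations. It then inspects a result by plotting a spectrogram and inverting it back to audio.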

In [None]:
import os

# Silence TensorFlow's C++ logging. Set before importing tensorflow so it
# also applies to the messages emitted while the library loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Import and set up TensorFlow.
import tensorflow as tf

# Sanity check: this notebook is pinned to an exact TensorFlow release.
# Fail fast, with a readable message, if a different one is installed.
EXPECTED_TF_VERSION = "2.4.1"
print("Using Tensorflow v. %s" % tf.__version__)
assert tf.__version__ == EXPECTED_TF_VERSION, (
    "Expected Tensorflow v. %s, found v. %s" % (EXPECTED_TF_VERSION, tf.__version__)
)

# Set up a strategy to distribute resources and training
# across all available GPUs on the system.
strategy = tf.distribute.MirroredStrategy()

# Report how many replicas (devices) the distribution scope uses.
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

In [None]:
# Import Keras modules
from tensorflow.keras import optimizers

# Import from custom modules
from data_utils.mg_sg_generator import TFE_FN_TEMPLATE
from data_utils.mg_sg_generator import prepare_dataset_for_training
from data_utils.mg_sg_generator import get_dataset_for, extract_dlen_from_tfr

from data_utils.sg_preprocessor import spectrogram2audio
from data_utils.sg_preprocessor import inverse_normalize_db_0_1, FNCOLS 

# Import custom tensorflow utilities
from tf_extensions.tf_custom.models import GaussianBetaVAE
from tf_extensions.tf_custom.models import make_cnn_vae_encoder, make_dense_vae_decoder

# Common utilities
import numpy as np

# Display utilities
import matplotlib.pyplot as plt
import IPython.display as ipd

### Prepare and Setup Data
Construct and/or load the datasets we need for the various experiments:
- Linear Spectrogram <b>(r128xc128)</b>, latent dim: 256
- Linear Spectrogram <b>(r512xc128)</b>, latent dim: 256
- Linear Spectrogram <b>(r512xc128)</b>, latent dim: 256x4 (2048). **TODO:** 256x4 is 1024; confirm which latent size is intended.

Here, rNUM indicates the number of FFT bins (spectrogram rows, determined by the FFT size) used to compute the spectrograms we are attempting to predict. cNUM is the number of STFT columns (frames), which follows directly from the <i>hop_size</i> we use.


#### Globals for all experiments: 
Some variables are held constant throughout the different models under experimentation.
The cell below defines them so that they can be reused throughout the notebook.

In [None]:
# Convenience accessors for the dataset descriptor returned by
# `get_dataset_for`. It is a dict keyed by split name ("train",
# "validation", "test"). Each value is a tuple laid out as
# (nexamples, ..., filepath).
TFR_TRAIN      = lambda tfr_dict: tfr_dict["train"]
TFR_VALIDATION = lambda tfr_dict: tfr_dict["validation"]
TFR_TEST       = lambda tfr_dict: tfr_dict["test"]
TFR_PATH       = lambda tfr_tuple: tfr_tuple[-1]

# Number of spectrogram columns to use for each nn input, given the hop
# length `hl`. Thin wrapper around `FNCOLS` (data_utils.sg_preprocessor).
# TODO: FINISH
NCOLS = lambda hl: FNCOLS(hl)

# Number of spectrogram rows (one-sided FFT bins) for a given FFT size:
# nfft // 2 + 1, e.g. 129 rows for nfft = 256.
NROWS = lambda nfft: nfft//2 + 1

# Fraction of the dataset reserved for training. The remainder is split
# evenly into validation/testing (e.g. 0.7 gives a 70-15-15 split).
TRAIN_SIZE = 0.7

# Passed as `overwrite_existing` to `get_dataset_for`. Set to True to
# recompute all the datasets.
DS_OVERWRITE = False

# Beta-value used to tune the VAE(s).
VAE_BETA = 1.

# Learning rate applied across all training conditions.
BASE_LR = 4e-04

# Default optimizer. `learning_rate` replaces the deprecated `lr` keyword.
# NOTE: an optimizer instance carries state (step counter, Adam moments).
# Prefer creating a fresh `optimizers.Adam(learning_rate=BASE_LR)` per model
# inside `strategy.scope()` rather than sharing this single instance.
OPTIMIZER = optimizers.Adam(learning_rate=BASE_LR)

# Batch size used for training.
BATCH_SIZE = 256

# Steps per epoch for a split tuple (nexamples, ..., filepath). This is the
# number of full batches, clamped to at least 1 so that a split smaller
# than one batch does not produce steps_per_epoch=0.
STEPS_PER_EPOCH = lambda tfr_tuple: max(1, tfr_tuple[0] // BATCH_SIZE)

#### Approach 1 (AP1): Linear spectrogram r128xc128, latent dim 256

In [None]:
# Define the FFT-size.
ap1_nfft = 256

# Define the hop length in samples (passed to the dataset builder as
# `overlap`): half the FFT size.
ap1_overlap = ap1_nfft // 2

# Define the latent dimension size.
ap1_latent_dim = 256

# Fetch a dataset that fulfils the requirements for this condition.
ap1_fout = get_dataset_for(
    nfft               = ap1_nfft,
    overlap            = ap1_overlap,
    train_size         = TRAIN_SIZE,
    overwrite_existing = DS_OVERWRITE
)

# Prepare the input/output shape(s). Training uses 1-D inputs of NROWS
# values; evaluation uses (NROWS, NCOLS) spectrograms.
ap1_inout_shape_train = (NROWS(ap1_nfft), )
ap1_inout_shape_eval  = (NROWS(ap1_nfft), NCOLS(ap1_overlap))

# Build the deep learning model (VAE) for this condition. Everything that
# owns variables (model AND optimizer) is created under the strategy scope.
with strategy.scope():
    ap1_vae = GaussianBetaVAE(
        N                   = TFR_TRAIN(ap1_fout)[0],
        M                   = BATCH_SIZE,
        beta                = VAE_BETA,
        input_dim           = ap1_inout_shape_train,
        latent_dim          = ap1_latent_dim,
        create_encoder_func = make_cnn_vae_encoder,
        create_decoder_func = make_dense_vae_decoder
    )

    # Create a fresh optimizer for this model inside the scope. Its
    # variables are then placed by the distribution strategy, and Adam's
    # state (step counter, moment estimates) is not shared with models
    # trained under other conditions through the global OPTIMIZER.
    ap1_optimizer = optimizers.Adam(learning_rate=BASE_LR)

    # Compile the model and make it ready for training.
    ap1_vae.custom_compile(optimizer=ap1_optimizer)

In [None]:
# Build the training and validation input pipelines. Both splits share the
# same batch size and dtype; only the source .tfrecords file differs.
ap1_ds_train, ap1_ds_validation = (
    prepare_dataset_for_training(
        filename     = TFR_PATH(tfr_split(ap1_fout)),
        batch_size   = BATCH_SIZE,
        cast_to_type = tf.float64
    )
    for tfr_split in (TFR_TRAIN, TFR_VALIDATION)
)

# Fit the VAE for a fixed number of epochs. The per-epoch step counts are
# derived from each split's example count.
ap1_history = ap1_vae.fit(
    ap1_ds_train,
    epochs           = 10,
    steps_per_epoch  = STEPS_PER_EPOCH(TFR_TRAIN(ap1_fout)),
    validation_data  = ap1_ds_validation,
    validation_steps = STEPS_PER_EPOCH(TFR_VALIDATION(ap1_fout))
)

#### Evaluate AP1 model 

In [None]:
import tensorflow_probability as tfp
from data_utils.mg_sg_generator import get_dataset_small

def sample_mvn(x_mu, x_logvar, seed=None):
    """Draw one sample per element from a diagonal Gaussian N(x_mu, exp(x_logvar)).

    This is equivalent to drawing a single sample from
    ``tfp.distributions.MultivariateNormalDiag(loc=x_mu,
    scale_diag=exp(0.5 * x_logvar))``. It is implemented with NumPy via the
    reparameterisation ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

    The previous version called ``mvn.sample(shape=[batch])``. ``shape`` is
    not a parameter of ``Distribution.sample`` (it is ``sample_shape``). Even
    with the right keyword, it would have drawn ``batch`` samples on top of a
    distribution whose batch shape is already ``(batch,)``, giving an extra
    leading axis.

    Args:
        x_mu: Array-like of means (e.g. shape ``(batch, dim)``).
        x_logvar: Array-like of log-variances, broadcast-compatible with
            ``x_mu``.
        seed: Optional seed (int) or ``np.random.Generator`` for
            reproducible draws. ``None`` draws fresh entropy.

    Returns:
        np.ndarray of samples with the broadcast shape of the inputs and
        the common float dtype of ``x_mu`` and ``sigma``.
    """
    x_mu = np.asarray(x_mu)
    x_sig = np.exp(0.5 * np.asarray(x_logvar))
    rng = np.random.default_rng(seed)
    out_shape = np.broadcast(x_mu, x_sig).shape
    # Cast the noise so float32 inputs are not silently promoted to float64.
    eps = rng.standard_normal(size=out_shape).astype(
        np.result_type(x_mu, x_sig), copy=False
    )
    return x_mu + x_sig * eps

def strip_file(filename):
    """Return the base name of ``filename`` without its directory or last extension.

    Example: ``"/data/tfr/train_a_b.tfrecords"`` -> ``"train_a_b"``.

    The previous version sliced after ``filename.rindex('/')``, which raises
    ValueError for bare file names with no directory part. ``os.path.basename``
    handles that case and uses the platform's path separator(s).
    """
    return os.path.splitext(os.path.basename(filename))[0]

def prepare_dataset_for_evaluation(filename, cast_to_type=None):
    """Load a .tfrecords file for evaluation, batched by a size encoded in its name.

    The batch size is not a free parameter; it is parsed from the file name.
    The stem (see ``strip_file``) is split on ``'_'``. The text after the
    last ``'x'`` in the 4th field must be an integer, and that integer is
    the batch size.
    NOTE(review): presumably one batch corresponds to one spectrogram's
    columns under the memory-reduction setup. TODO: explain the
    memory-reduction setup and confirm.

    Args:
        filename: Path to the .tfrecords file.
        cast_to_type: Optional dtype forwarded to ``get_dataset_small``.

    Returns:
        The dataset from ``get_dataset_small``, batched with the parsed size.

    Raises:
        ValueError: If the file name does not encode a positive integer
            batch size in the expected position.
    """
    fields = strip_file(filename).split('_')
    try:
        batch_size = int(fields[3].split('x')[-1])
    except (IndexError, ValueError) as exc:
        raise ValueError(
            "Cannot parse evaluation batch size from file name %r: expected an "
            "'_'-separated stem whose 4th field ends in an integer after its "
            "last 'x'." % (filename,)
        ) from exc
    if batch_size <= 0:
        raise ValueError(
            "Parsed a non-positive batch size (%d) from file name %r."
            % (batch_size, filename)
        )

    ds = get_dataset_small(filename, dtype=cast_to_type)
    return ds.batch(batch_size)

In [None]:
# Build the held-out TEST pipeline for evaluating AP1. This previously read
# the TRAIN split, which would evaluate the model on data it was trained on.
ap1_ds_test = prepare_dataset_for_evaluation(
    filename     = TFR_PATH(TFR_TEST(ap1_fout)),
    cast_to_type = tf.float64
)

In [None]:
# Pull the first evaluation batch. `mg_` is what would be fed to the model
# (see the TODO below); `sg_` is the paired spectrogram batch.
ds_iter = iter(ap1_ds_test)
mg_, sg_ = next(ds_iter)

# TODO: preview the VAE's reconstruction instead of the ground truth:
#   mu_x, logvar_x, _, _ = ap1_vae.predict(mg_)
#   print(mu_x.shape, logvar_x.shape)

In [None]:
# Spectrogram to visualise and invert. This is currently the ground-truth
# batch `sg_`; swap in `sample_mvn(mu_x, logvar_x)` to use the VAE output.
specgram_inv = sg_.numpy()
print(specgram_inv.shape)

In [None]:
# Plot the spectrogram. It is transposed and flipped vertically so that,
# assuming rows of `specgram_inv` are STFT frames (TODO confirm), time runs
# left to right and the lowest frequency bin sits at the bottom.
fig, axs = plt.subplots(1, 1, figsize=(10, 10))
axs.set_xticks([])
axs.set_yticks([])
axs.set_title("AP1 spectrogram (first evaluation batch)")
axs.set_xlabel("STFT frame")
axs.set_ylabel("Frequency bin")
axs.imshow(np.flipud(specgram_inv.T), aspect="auto", cmap="Spectral_r", interpolation="bicubic")
plt.show()

In [None]:
# Reconstruct a time-domain waveform from the (transposed) spectrogram,
# using the same hop length the dataset was built with.
y = spectrogram2audio(
    specgram_inv.T,
    ap1_overlap,
    True,  # NOTE(review): meaning of this positional flag is defined in data_utils.sg_preprocessor; pass it by keyword once confirmed
)
print(y.shape)

In [None]:
# Playback sample rate in Hz. NOTE(review): 22050 is assumed here and must
# match the sample rate the dataset was generated with; confirm against
# data_utils.
AP1_SAMPLE_RATE = 22050
ipd.display(ipd.Audio(y, rate=AP1_SAMPLE_RATE))