<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Requirements

*   **[JAX](https://jax.readthedocs.io/en/latest/index.html)** is a NumPy-like library for high-performance numerical computing.
*   **[Flax](https://github.com/google/flax)** is a neural network library and ecosystem on top of JAX.
*   **[IPython](https://ipython.readthedocs.io/en/stable/api/generated/IPython.display.html)** is used for displaying and playing audio samples in this notebook.
*   **google.colab** is used to get access to Google Drive for importing the dataset file.
*   **pickle** is used to deserialize the dataset file back into its original Python object.

In [None]:
!pip install flax

import pickle

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import app
from absl import flags
from flax import linen as nn
from flax.training import train_state
from google.colab import drive
from IPython.core.display import display
from IPython.display import Audio

# For Type Annotations
from typing import Generator, Mapping, Tuple, NamedTuple, Sequence
PRNGKey = jnp.ndarray



In [None]:
# @title Set Working Directories
# Mount Google Drive under /content/drive; both paths below point into that mount.
drive.mount('/content/drive') 
# Project folder on the mounted drive (editable through the Colab form).
project_path = '/content/drive/My Drive/nn_drum' # @param
# Pickled dataset file that is read by the "Load the Dataset" cell.
dataset_path = '/content/drive/My Drive/nn_drum/snares_tensor.db' # @param



Mounted at /content/drive


In [None]:
# @title Load the Dataset
# @markdown The dataset is loaded as a NumPy array and then reshaped according to the batch size.
batch_size =  5 # @param
# @markdown The sample rate of the samples in the database.
sample_rate = 44100 # @param

def load_dataset(path, to_numpy=False):
    """Unpickle the object stored at `path`.

    Args:
        path: Location of the pickled file.
        to_numpy: If true, return the result of calling `.numpy()` on the
            unpickled object instead of the object itself.

    Returns:
        The unpickled object, or its `.numpy()` conversion.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    if to_numpy:
        return loaded.numpy()
    return loaded

# Load every sample as one array; its first axis is counted as the number of samples.
samples = load_dataset(dataset_path, to_numpy=True)
num_data = samples.shape[0]

# Keep only as many samples as fill whole batches: reshape raises a ValueError
# when num_data is not a multiple of batch_size, so the remainder is dropped.
num_batches = num_data // batch_size
if num_batches == 0:
    raise ValueError(
        f"Need at least batch_size={batch_size} samples, but the dataset has {num_data}."
    )
dataset = samples[:num_batches * batch_size].reshape(num_batches, batch_size, -1)
print(f"Dataset shape {dataset.shape} with (num_batches, batch_size, sample_length).")

Dataset shape (86, 5, 14700) with (num_batches, batch_size, sample_length).


In [None]:
#@markdown Listen to a sample
# Play the first sample of the first batch at the dataset's sample rate.
Audio(data=dataset[0, 0], rate=sample_rate)

In [None]:
class Encoder(nn.Module):
    """Encoder half of the VAE: produces the mean and log-variance fed to `reparameterize`.

    Attributes:
        latents: Width of both output layers ('fc2_mean' and 'fc2_logvar').
        hidden: Width of the hidden layer 'fc1'. Defaults to 44100 // 3, the
            value that was previously hard-coded.
    """
    latents: int
    hidden: int = 44100 // 3

    @nn.compact
    def __call__(self, x):
        """Return `(mean, logvar)`, each with `latents` features in the last axis."""
        x = nn.Dense(self.hidden, name='fc1')(x)
        x = nn.relu(x)
        # Two separate heads share the same hidden representation.
        mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
        return mean_x, logvar_x


class Decoder(nn.Module):
    """Decoder half of the VAE: maps a latent vector to `output_size` values.

    Attributes:
        hidden: Width of the hidden layer 'fc1' (default 500, the value that
            was previously hard-coded).
        output_size: Width of the output layer 'fc2' (default 44100 // 3, the
            value that was previously hard-coded).
    """
    hidden: int = 500
    output_size: int = 44100 // 3

    @nn.compact
    def __call__(self, z):
        """Return the output of 'fc2' for latent input `z`; no activation is applied to it."""
        z = nn.Dense(self.hidden, name='fc1')(z)
        z = nn.relu(z)
        z = nn.Dense(self.output_size, name='fc2')(z)
        return z


class VAE(nn.Module):
    """Variational autoencoder composed of an `Encoder` and a `Decoder`."""
    latents: int = 20

    def setup(self):
        # Submodules are created here so that both __call__ and generate can use them.
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        """Encode `x`, draw a latent using `z_rng`, and decode it.

        Returns:
            Tuple `(recon_x, mean, logvar)`; `recon_x` is the decoder output
            with no sigmoid applied.
        """
        mu, log_var = self.encoder(x)
        latent = reparameterize(z_rng, mu, log_var)
        return self.decoder(latent), mu, log_var

    def generate(self, z):
        """Decode latent `z` and pass the result through a sigmoid."""
        decoded = self.decoder(z)
        return nn.sigmoid(decoded)


def reparameterize(rng, mean, logvar):
    """Draw a sample `mean + eps * std` with standard-normal noise `eps`.

    Args:
        rng: JAX PRNG key used to draw the noise.
        mean: Mean array.
        logvar: Log-variance array; the noise is drawn with this shape.

    Returns:
        `mean + eps * exp(0.5 * logvar)`.
    """
    std = jnp.exp(0.5 * logvar)
    # Use the fully qualified jax.random: a bare `random` name is never
    # imported in this notebook, so `random.normal` raised a NameError.
    eps = jax.random.normal(rng, logvar.shape)
    return mean + eps * std

@jax.vmap
def kl_divergence(mean, logvar):
    """Return -0.5 * sum(1 + logvar - mean**2 - exp(logvar)) for one example.

    `jax.vmap` maps the function over the leading axis of both arguments, so
    the decorated function returns one value per leading-axis entry.
    """
    per_dim = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    return -0.5 * jnp.sum(per_dim)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    """Summed binary cross-entropy between `labels` and `sigmoid(logits)` for one example.

    `jax.vmap` maps the function over the leading axis of both arguments, so
    the decorated function returns one loss value per leading-axis entry.

    Args:
        logits: Pre-sigmoid model outputs.
        labels: Targets with the same shape as `logits`.
            NOTE(review): cross-entropy presumably expects targets in [0, 1];
            confirm the audio samples are scaled accordingly.

    Returns:
        `-sum(labels * log(p) + (1 - labels) * log(1 - p))` with `p = sigmoid(logits)`.
    """
    log_p = nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) == log_sigmoid(-x). The previous form,
    # log(-expm1(log_sigmoid(x))), evaluates log(0) = -inf once log_sigmoid(x)
    # rounds to 0 for large x, which turns the loss into inf/NaN.
    log_not_p = nn.log_sigmoid(-logits)
    return -jnp.sum(labels * log_p + (1. - labels) * log_not_p)



In [None]:
@jax.jit
def train_step(state, batch, z_rng):
    """Run one gradient update on `batch` and return the updated train state.

    Args:
        state: Train state; its `params` are differentiated and its
            `apply_fn` runs the model's forward pass.
        batch: Input batch; it is also used as the reconstruction target.
        z_rng: PRNG key handed to the model for drawing the latent sample.

    Returns:
        The state returned by `state.apply_gradients`.
    """
    def loss_fn(params):
        # Use the apply function stored on the state. The previous
        # `model().apply` called the module instance with no arguments, which
        # raises a TypeError, and tied this function to a global.
        recon_x, mean, logvar = state.apply_fn({'params': params}, batch, z_rng)

        # Both terms are averaged over the batch before being added.
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        return bce_loss + kld_loss

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

In [None]:
model = VAE(300)

# Hyperparameters. They were previously read from a `FLAGS` object that is
# never defined in this notebook; adjust the values here instead.
learning_rate = 1e-3
epochs = 1

# Separate keys for parameter initialisation, for the latent sample drawn
# during initialisation, and for the training loop.
rng = jax.random.PRNGKey(0)
rng, init_key, z_key = jax.random.split(rng, 3)

# Dummy input shaped like one batch of the dataset: (batch_size, sample_length).
init_data = jnp.ones(dataset.shape[1:], jnp.float32)

state = train_state.TrainState.create(
    # `model` is already a module instance; `model()` invoked its __call__
    # without arguments and raised a TypeError.
    apply_fn=model.apply,
    params=model.init(init_key, init_data, z_key)['params'],
    tx=optax.adam(learning_rate),
)

for epoch in range(epochs):
    for batch in dataset:
        # Fresh key per step so every batch draws different latent noise.
        rng, key = jax.random.split(rng)
        state = train_step(state, batch, key)


TypeError: ignored