<a href="https://colab.research.google.com/github/pbrandl/nn_drsynth/blob/main/nn_drsyn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q dm-haiku optax

[?25l[K     |█▏                              | 10kB 19.8MB/s eta 0:00:01[K     |██▎                             | 20kB 23.1MB/s eta 0:00:01[K     |███▌                            | 30kB 17.0MB/s eta 0:00:01[K     |████▋                           | 40kB 15.0MB/s eta 0:00:01[K     |█████▊                          | 51kB 7.9MB/s eta 0:00:01[K     |███████                         | 61kB 7.4MB/s eta 0:00:01[K     |████████                        | 71kB 8.4MB/s eta 0:00:01[K     |█████████▏                      | 81kB 9.3MB/s eta 0:00:01[K     |██████████▍                     | 92kB 8.7MB/s eta 0:00:01[K     |███████████▌                    | 102kB 7.4MB/s eta 0:00:01[K     |████████████▊                   | 112kB 7.4MB/s eta 0:00:01[K     |█████████████▉                  | 122kB 7.4MB/s eta 0:00:01[K     |███████████████                 | 133kB 7.4MB/s eta 0:00:01[K     |████████████████▏               | 143kB 7.4MB/s eta 0:00:01[K     |█████████████████▎     

In [3]:
import pickle
from typing import Generator, Mapping, Tuple, NamedTuple, Sequence

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from google.colab import drive

# Notebook-level PRNG key built from the fixed seed 0.
random_key = jax.random.PRNGKey(0)

# Set Working Directories
# Every path is derived from the single mount point, so the mount call, the
# project folder and the dataset file cannot drift apart when one is edited.
mount_point = '/content/drive'
drive.mount(mount_point)
project_path = mount_point + '/My Drive/nn_drum'
dataset_path = project_path + '/snares_tensor.db'
print("Working in {}.".format(project_path))




Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Working in /content/drive/My Drive/nn_drum.


In [4]:
def load_dataset(path):
    """Read and return the object pickled in the file at ``path``.

    The file is opened in binary mode and is closed again before the
    function returns.

    Args:
        path: Location of the pickle file to read.

    Returns:
        Whatever object was stored in the file.

    NOTE: unpickling can execute arbitrary code, so only point this at files
    you created yourself or otherwise trust.
    """
    with open(path, mode='rb') as handle:
        contents = pickle.load(handle)
    return contents

# Unpickle the object stored at `dataset_path` and keep it in the module-level
# name `dataset`. Its type and layout are whatever was pickled into that file;
# they are not visible from this part of the notebook.
dataset = load_dataset(dataset_path)


In [5]:
# Default output size used by Decoder: 44100 // 3 == 14700 values per example
# (presumably a third of a second of audio at 44.1 kHz — TODO confirm).
output_shape = 44_100 // 3

class Encoder(hk.Module):
  """Encoder network producing a `mean` and a positive `stddev` per input.

  The input is flattened, passed through one ReLU hidden layer, and then
  projected by two separate linear heads of width `latent_size`.
  """

  def __init__(self, hidden_size: int = 2**12, latent_size: int = 512):
    """Stores the layer widths.

    Args:
      hidden_size: Width of the hidden linear layer.
      latent_size: Width of both output heads (mean and log-stddev).
    """
    super().__init__()
    self._hidden_size = hidden_size
    self._latent_size = latent_size

  def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Encodes `x`.

    Args:
      x: Input array; it is flattened by `hk.Flatten` before the first layer.

    Returns:
      A `(mean, stddev)` pair, both produced by linear heads of width
      `latent_size`.
    """
    # Layer construction order determines Haiku's auto-generated parameter
    # names (linear, linear_1, ...), so keep the order stable.
    flat = hk.Flatten()(x)
    hidden = hk.Linear(self._hidden_size)(flat)
    hidden = jax.nn.relu(hidden)

    mean = hk.Linear(self._latent_size)(hidden)
    log_stddev = hk.Linear(self._latent_size)(hidden)

    # The second head predicts log(stddev); exponentiating keeps the returned
    # standard deviation strictly positive.
    return mean, jnp.exp(log_stddev)


class Decoder(hk.Module):
  """Decoder network mapping latent vectors to logits of a fixed output shape.

  One ReLU hidden layer is followed by a linear layer with one unit per output
  element; the result is reshaped to `(-1, *output_shape)`.
  """

  def __init__(self, hidden_size: int = 2**12, output_shape: Sequence[int] = output_shape):
    """Stores the layer width and the normalised output shape.

    Args:
      hidden_size: Width of the hidden linear layer.
      output_shape: Per-example output shape. A bare integer `n` (such as the
        module-level default) is accepted and treated as the 1-D shape `(n,)`.
    """
    super().__init__()
    self._hidden_size = hidden_size
    # The default is a plain int, which cannot be star-unpacked in `__call__`.
    # Normalise scalars and sequences alike to a tuple.
    try:
      self._output_shape = tuple(output_shape)
    except TypeError:
      self._output_shape = (output_shape,)

  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Decodes `z`.

    Args:
      z: Latent array fed to the hidden linear layer.

    Returns:
      Logits reshaped to `(-1, *output_shape)`.
    """
    hidden = hk.Linear(self._hidden_size)(z)
    hidden = jax.nn.relu(hidden)

    # The layer width must be a concrete Python int, so the element count is
    # computed with NumPy rather than with a traced jnp operation.
    num_outputs = int(np.prod(self._output_shape))
    logits = hk.Linear(num_outputs)(hidden)
    logits = jnp.reshape(logits, (-1, *self._output_shape))

    return logits

In [8]:
class VariationalAutoEncoder(hk.Module):
    """Variational auto-encoder that chains an encoder and a decoder module.

    The encoder and decoder are supplied by the caller, so their sizes are
    configured where they are constructed rather than here.
    """

    def __init__(self, encoder: hk.Module, decoder: hk.Module):
        """Stores the two sub-modules.

        Args:
            encoder: Module called as ``encoder(x)``; must return a
                ``(mean, stddev)`` pair.
            decoder: Module called as ``decoder(z)``; must return logits.
        """
        # hk.Module needs its base constructor to run before the instance
        # is used.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, x: jnp.ndarray):
        """Encodes ``x``, samples a latent vector, and decodes it.

        A random key is drawn with ``hk.next_rng_key()``, so the transformed
        ``apply`` function has to be given an rng.

        Args:
            x: Input array; it is cast to float32 before encoding.

        Returns:
            A ``(mean, stddev, logits)`` tuple. The logits are the raw decoder
            output; no output non-linearity is applied here.
        """
        x = x.astype(jnp.float32)
        mean, stddev = self.encoder(x)

        # Sample z = mean + stddev * eps with eps drawn from a standard
        # normal, which keeps z differentiable w.r.t. mean and stddev.
        noise = jax.random.normal(hk.next_rng_key(), mean.shape)
        z = mean + stddev * noise
        logits = self.decoder(z)

        return mean, stddev, logits