[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/8.0-NSYNTH-iterator.ipynb)

## Iterator for NSynth 

The NSynth dataset is a collection of roughly 300,000 musical notes stored as raw waveforms. To feed these into a Seq2Seq model as spectrograms, I wrote a small dataset class that computes the (mel) spectrograms on the fly in TensorFlow, using the code from the spectrogramming notebook.

![a dataset iterator for tensorflow 2.0](https://github.com/timsainb/tensorflow2-generative-models/blob/master/imgs/nsynth-dataset.png?raw=1)

### Install packages if in colab

In [14]:
### install necessary packages if in colab
def run_subprocess_command(cmd):
    """Run a command given as a single string and echo its output line by line.

    Args:
        cmd: The command line, e.g. ``"pip install numba"``. It is tokenised
            with ``shlex.split`` so that quoted arguments such as
            ``'librosa==0.7.2'`` reach the program without the quote
            characters (a plain ``str.split`` would pass them through).

    Returns:
        The exit status of the command (0 on success).
    """
    # Imported locally so the helper does not depend on cell/import ordering.
    import shlex
    import subprocess

    # stderr is merged into stdout so that error messages are printed
    # alongside the normal output instead of bypassing this function.
    with subprocess.Popen(
        shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as process:
        for line in process.stdout:
            # errors="replace": a stray non-UTF-8 byte must not abort the echo loop.
            print(line.decode(errors="replace").strip())
    # Leaving the ``with`` block closes the pipe and waits for the process,
    # so the return code is available here.
    return process.returncode


import sys
import subprocess

# True only when the notebook is running inside Google Colab.
IN_COLAB = "google.colab" in sys.modules

# Commands executed in order when running on Colab.
# The requirement specifiers are written without shell quotes: the commands are
# split on whitespace rather than run through a shell, so quote characters would
# be handed to pip literally.
# NOTE(review): the tf-nightly preview pin dates from 2019, while the recorded
# run in this notebook reports TF 2.5.0 -- confirm the pin is still wanted.
colab_requirements = [
    "pip install tf-nightly-gpu-2.0-preview==2.0.0.dev20190513",
    "pip install librosa",
    "pip uninstall librosa -y",
    "pip install librosa==0.7.2",
    "pip install pathlib2",
    # same pin as the explicit `!pip install numba==0.48` cell further down
    "pip install numba==0.48",
]
if IN_COLAB:
    for requirement_cmd in colab_requirements:
        run_subprocess_command(requirement_cmd)

Uninstalling librosa-0.7.2:
Successfully uninstalled librosa-0.7.2


### load packages

In [21]:
!apt-get install python-pip python3-pip


Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  libpython-all-dev python-all python-all-dev python-asn1crypto
  python-cffi-backend python-crypto python-cryptography python-dbus
  python-enum34 python-gi python-idna python-ipaddress python-keyring
  python-keyrings.alt python-pip-whl python-pkg-resources python-secretstorage
  python-setuptools python-six python-wheel python-xdg python3-asn1crypto
  python3-cffi-backend python3-crypto python3-cryptography python3-idna
  python3-keyring python3-keyrings.alt python3-pkg-resources
  python3-secretstorage python3-setuptools python3-six python3-wheel
  python3-xdg
Suggested packages:
  python-crypto-doc python-cryptography-doc python-cryptography-vectors
  python-dbus-dbg python-dbus-doc python-enum34-doc python-gi-cairo
  gnome-keyring libkf5wallet-bin gir1.2-gnomekeyring-1.0 python-fs
  python-gdata python-keyczar python-secretstorage-do

In [22]:
# Reinstall librosa pinned to 0.7.2, install pathlib2, and pin numba to 0.48.
# NOTE(review): presumably numba is pinned because librosa 0.7.2 does not work
# with newer numba releases -- confirm before relaxing either pin.
!pip uninstall librosa -y
!pip install "librosa==0.7.2" 
!pip3 install pathlib2
!pip install numba==0.48

Uninstalling librosa-0.7.2:
  Successfully uninstalled librosa-0.7.2
Processing /root/.cache/pip/wheels/4c/6e/d7/bb93911540d2d1e44d690a1561871e5b6af82b69e80938abef/librosa-0.7.2-cp37-none-any.whl
Installing collected packages: librosa
Successfully installed librosa-0.7.2
Collecting numba==0.48
[?25l  Downloading https://files.pythonhosted.org/packages/39/dc/5ce4a94d98e8a31cab21b150e23ca2f09a7dd354c06a69f71801ecd890db/numba-0.48.0-1-cp37-cp37m-manylinux2014_x86_64.whl (3.5MB)
[K     |████████████████████████████████| 3.6MB 4.3MB/s 
Collecting llvmlite<0.32.0,>=0.31.0dev0
[?25l  Downloading https://files.pythonhosted.org/packages/a0/10/d02c0ac683fc47ecda3426249509cf771d748b6a2c0e9d5ebbee76a7b80a/llvmlite-0.31.0-cp37-cp37m-manylinux1_x86_64.whl (20.2MB)
[K     |████████████████████████████████| 20.2MB 1.3MB/s 
[?25hInstalling collected packages: llvmlite, numba
  Found existing installation: llvmlite 0.34.0
    Uninstalling llvmlite-0.34.0:
      Successfully uninstalled llvmlite-0.3

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.io import FixedLenFeature, parse_single_example
from librosa.core.time_frequency import mel_frequencies
from pathlib2 import Path
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
print(tf.__version__)

2.5.0


### Download or load dataset
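TensorFlow Datasets will automatically download the dataset to this location, or load it from there if it has already been prepared. Note that the full NSynth download is large (about 73 GiB, plus a similar amount again for the generated files).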

In [5]:
DATA_DIR = Path("data").resolve()
DATA_DIR

PosixPath('/content/data')

In [None]:
ds_train, ds_test = tfds.load(
    name="nsynth", split=["train", "test"], data_dir=DATA_DIR
)

[1mDownloading and preparing dataset nsynth/full/2.3.3 (download: 73.07 GiB, generated: 73.09 GiB, total: 146.16 GiB) to /content/data/nsynth/full/2.3.3...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=1069.0, style=ProgressStyle(descrip…

In [None]:
ds_train

### Prepare spectrogramming and parameters

In [None]:
def _normalize_tensorflow(S, hparams):
    """Rescale a dB-valued spectrogram to the unit interval.

    ``hparams.min_level_db`` maps to 0 and 0 dB maps to 1; anything outside
    that range is clipped.
    """
    floor_db = hparams.min_level_db
    rescaled = (S - floor_db) / -floor_db
    return tf.clip_by_value(rescaled, 0, 1)

def _tf_log10(x):
    """Base-10 logarithm via change of base: log10(x) = ln(x) / ln(10)."""
    ln_x = tf.math.log(x)
    # ln(10) is built in the same dtype as ln(x) so the division needs no cast
    ln_10 = tf.math.log(tf.constant(10, dtype=ln_x.dtype))
    return ln_x / ln_10


def _amp_to_db_tensorflow(x):
    """Convert amplitudes to decibels: ``20 * log10(max(|x|, 1e-5))``.

    The 1e-5 floor (i.e. -100 dB) keeps the logarithm away from log(0).
    """
    # Only a lower bound is needed. The previous form clipped to an upper bound
    # of 1e100, which lies outside the float32 range and never had any effect.
    magnitude = tf.maximum(tf.abs(x), 1e-5)
    return 20 * _tf_log10(magnitude)


def _stft_tensorflow(signals, hparams):
    """Short-time Fourier transform of ``signals`` with a Hann window.

    Window length, hop size and FFT size are taken from ``hparams.win_length``,
    ``hparams.hop_length`` and ``hparams.n_fft``; the end of the signal is padded.
    """
    stft_kwargs = dict(
        frame_length=hparams.win_length,
        frame_step=hparams.hop_length,
        fft_length=hparams.n_fft,
        window_fn=tf.signal.hann_window,
        pad_end=True,
    )
    return tf.signal.stft(signals, **stft_kwargs)


def spectrogram_tensorflow(y, hparams):
    """Turn a waveform into a normalised dB magnitude spectrogram.

    Steps: STFT -> magnitude -> decibels -> subtract ``hparams.ref_level_db``
    -> rescale/clip to [0, 1] via ``_normalize_tensorflow``.
    """
    stft_magnitude = tf.abs(_stft_tensorflow(y, hparams))
    level_db = _amp_to_db_tensorflow(stft_magnitude) - hparams.ref_level_db
    return _normalize_tensorflow(level_db, hparams)

In [None]:
class HParams(object):
    """Minimal attribute bag for hyperparameters.

    Stand-in for the HParams class that was removed from tf 2.0alpha: every
    keyword argument becomes an attribute of the instance.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

In [None]:
# All tunable settings for the spectrogramming / dataset pipeline live here.
hparams = HParams(
    # --- network ---
    batch_size=32,
    # --- spectrogramming ---
    sample_rate=16000,
    create_spectrogram=True,  # when False the dataset yields the parsed features only
    win_length=1024,  # STFT window length, in samples
    n_fft=1024,  # FFT size
    hop_length=400,  # STFT hop, in samples
    ref_level_db=50,  # subtracted from the dB spectrogram
    min_level_db=-100,  # dB value that maps to 0 after normalisation
    # --- mel scaling ---
    num_mel_bins=128,
    mel_lower_edge_hertz=0,
    mel_upper_edge_hertz=8000,
    # --- inversion ---
    power=1.5,  # for spectral inversion
    griffin_lim_iters=50,
    pad=True,
)

### Create the dataset class

In [None]:
class NSynthDataset(object):
    """Batched ``tf.data`` pipelines over NSynth TFRecord files.

    Every serialized example is parsed into a dict of tensors and, when
    ``hparams.create_spectrogram`` is true, gains a ``"spectrogram"`` entry
    holding a mel-scaled spectrogram computed from the raw audio.

    The first ``test_size`` parsed records become ``dataset_test`` and all
    remaining records become ``dataset_train``. Both are shuffled, prefetched
    and batched with ``hparams.batch_size``.

    All hyperparameters are read from the ``hparams`` object passed to the
    constructor (never from a notebook-level global), so several datasets with
    different settings can coexist.

    Args:
        tf_records: TFRecord file path(s), forwarded to ``tf.data.TFRecordDataset``.
        hparams: object exposing ``batch_size``, ``create_spectrogram`` and the
            spectrogram / mel attributes read below (``win_length``,
            ``hop_length``, ``n_fft``, ``ref_level_db``, ``min_level_db``,
            ``sample_rate``, ``num_mel_bins``, ``mel_lower_edge_hertz``,
            ``mel_upper_edge_hertz``).
        is_training: stored as ``self.is_training``; not used by the class itself.
        test_size: number of leading records reserved for the test split.
        prefetch: how many elements (spectrograms) to prefetch.
        num_parallel_calls: how many parallel calls prepare data in ``map``.
        n_samples: how many items are in the dataset; stored as ``self.nsamples``
            and not used by the class itself.
        shuffle_buffer: buffer size handed to ``Dataset.shuffle``.
    """

    def __init__(
        self,
        tf_records,
        hparams,
        is_training=True,
        test_size=1000,
        prefetch=1000,  # how many spectrograms to prefetch
        num_parallel_calls=10,  # how many parallel threads should be preparing data
        n_samples=305979,  # how many items are in the dataset
        shuffle_buffer=1000,
    ):
        self.is_training = is_training
        self.nsamples = n_samples
        self.test_size = test_size
        self.hparams = hparams
        self.prefetch = prefetch
        self.shuffle_buffer = shuffle_buffer
        self.num_parallel_calls = num_parallel_calls
        # prepare for mel scaling (only needed when spectrograms are requested)
        if self.hparams.create_spectrogram:
            self.mel_matrix = self._make_mel_matrix()
        # create dataset of tfrecords
        self.raw_dataset = tf.data.TFRecordDataset(tf_records)
        # parse (and optionally spectrogram) every record in parallel
        self.dataset = self.raw_dataset.map(
            self._parse_function, num_parallel_calls=num_parallel_calls
        )
        # make and split train and test datasets
        self.prepare_datasets()

    def _shuffle_prefetch_batch(self, dataset):
        """Apply the shared shuffle -> prefetch -> batch tail to one split."""
        return (
            dataset.shuffle(self.shuffle_buffer)
            .prefetch(self.prefetch)
            .batch(self.hparams.batch_size)
        )

    def prepare_datasets(self):
        """Build ``self.dataset_train`` and ``self.dataset_test`` from ``self.dataset``.

        The test split is the first ``test_size`` records; the training split
        is everything after them.
        """
        # Note: there are better ways to do batching and shuffling here:
        #    https://www.tensorflow.org/guide/performance/datasets
        self.dataset_train = self._shuffle_prefetch_batch(
            self.dataset.skip(self.test_size)
        )
        self.dataset_test = self._shuffle_prefetch_batch(
            self.dataset.take(self.test_size)
        )

    def _make_mel_matrix(self):
        """Return the linear-to-mel weight matrix used to mel-scale spectrograms.

        The TensorFlow mel weights are rescaled per band with the Slaney-style
        factor ``2 / (upper_edge - lower_edge)`` (as librosa does), and then each
        band (column) is divided by its sum.
        """
        hp = self.hparams
        n_mels = hp.num_mel_bins
        # create mel matrix
        mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=n_mels,
            num_spectrogram_bins=int(hp.n_fft / 2) + 1,
            sample_rate=hp.sample_rate,
            lower_edge_hertz=hp.mel_lower_edge_hertz,
            upper_edge_hertz=hp.mel_upper_edge_hertz,
            dtype=tf.dtypes.float32,
        )
        # band-edge frequencies: n_mels + 2 points so every band has a lower and
        # an upper neighbour
        mel_f = mel_frequencies(
            n_mels=n_mels + 2,
            fmin=hp.mel_lower_edge_hertz,
            fmax=hp.mel_upper_edge_hertz,
        )
        # Slaney-style mel is scaled to be approx constant energy per channel (from librosa)
        band_widths = mel_f[2 : n_mels + 2] - mel_f[:n_mels]
        enorm = tf.dtypes.cast(
            tf.expand_dims(tf.constant(2.0 / band_widths), 0),
            tf.float32,
        )
        # normalize matrix
        mel_matrix = tf.multiply(mel_matrix, enorm)
        mel_matrix = tf.divide(mel_matrix, tf.reduce_sum(mel_matrix, axis=0))

        return mel_matrix

    def print_feature_list(self):
        """Print the feature names stored in the first raw record."""
        # get the first serialized record
        element = list(self.raw_dataset.take(count=1))[0]
        # parse the element in to the example message
        example = tf.train.Example()
        example.ParseFromString(element.numpy())
        print(list(example.features.feature))

    def _parse_function(self, example_proto):
        """Parse one serialized example and optionally add its mel spectrogram.

        Only a subset of the stored features is extracted (use
        ``print_feature_list`` to see all of them). When
        ``hparams.create_spectrogram`` is true, the returned dict also holds
        ``"spectrogram"``: the audio's normalised spectrogram projected onto
        the mel matrix, transposed, with a trailing singleton axis appended.
        """
        features = {
            "id": FixedLenFeature([], dtype=tf.string),
            "pitch": FixedLenFeature([1], dtype=tf.int64),
            "velocity": FixedLenFeature([1], dtype=tf.int64),
            "audio": FixedLenFeature([64000], dtype=tf.float32),
            "instrument/source": FixedLenFeature([1], dtype=tf.int64),
            "instrument/family": FixedLenFeature([1], dtype=tf.int64),
            "instrument/label": FixedLenFeature([1], dtype=tf.int64),
        }
        example = parse_single_example(example_proto, features)

        if self.hparams.create_spectrogram:
            # create spectrogram
            spectrogram = spectrogram_tensorflow(example["audio"], self.hparams)
            # create melspectrogram: project onto mel bands, put the mel axis
            # first, then append a trailing singleton axis
            mel_spectrogram = tf.transpose(
                tf.tensordot(spectrogram, self.mel_matrix, 1)
            )
            example["spectrogram"] = tf.expand_dims(mel_spectrogram, axis=2)

        return example

### Produce the dataset from tfrecords

In [None]:
# Sort the shard paths: glob order depends on the filesystem, and NSynthDataset
# carves its test split off the first records it reads, so a stable file order
# keeps the train/test split the same from run to run.
training_tfrecords = sorted(
    str(path) for path in (DATA_DIR / "nsynth").glob("**/*train.tfrecord*")
)

In [None]:
hparams.batch_size = 32

In [None]:
dset = NSynthDataset(training_tfrecords, hparams)

In [None]:
dset

### Test plot an example from the dataset

In [None]:
ex = next(iter(dset.dataset_test))
np.shape(ex["spectrogram"].numpy())

In [None]:
# Show one spectrogram (index 10) from the batch drawn above.
fig, ax = plt.subplots(ncols=1, figsize=(15, 4))
example_spec = np.squeeze(ex["spectrogram"].numpy()[10])
cax = ax.matshow(example_spec, aspect="auto", origin="lower")
# Axis labels follow the transpose in `_parse_function`, which puts the mel axis
# first -- assumes the STFT output is laid out as (frames, bins).
ax.set_xlabel("frame")
ax.set_ylabel("mel bin")
# trailing semicolon keeps the Colorbar repr out of the cell output
fig.colorbar(cax, label="normalised level");

### test how fast we can iterate over the dataset
Iteration runs at around 15 batches/second on my local machine with 10 parallel threads (`num_parallel_calls=10`). It may be slower on Colab or other free resources.

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
# Draw up to 1000 training batches and discard them; the tqdm bar reports the
# resulting iteration rate. zip() with range(1000) caps the number of batches.
for batch, train_x in tqdm(zip(range(1000), dset.dataset_train), total=1000):
    pass