# Tensorflow Dataset

This notebook demonstrates processing training data 'on the fly' using a Tensorflow Dataset. This is an efficient way to handle data: 

* Files are read and examples are prepared as they are needed by the GPU/TPU device, with buffers to avoid data starvation.
* Experimenting with different augmentation techniques just involves code changes, rather than reprocessing the dataset.
* Augmented datasets can be effectively infinite.

A couple of caveats are in order, though:

* This implementation decodes arbitrary audio formats using `tf.py_function`. This is fairly inefficient, due to the Python global interpreter lock. Converting the data to wav files ahead of time will allow greater efficiency.
* For the same reason, it's best to implement augmentations, etc., using Tensorflow and numpy. Librosa in particular may require using a `tf.py_function`, and thus cause a drop in efficiency.

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow as tf
import tensorflow_datasets.public_api as tfds
from matplotlib import pyplot as plt
import soundfile as sf

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

import io
import os
import time

# Tally every file found anywhere below the Kaggle input directory.
count = sum(len(found) for _, _, found in os.walk('/kaggle/input'))
print('counted %d files.' % count)

class BirbsongDatasetBuilder(object):
    """Builds a tf.data.Dataset of audio windows, labels and log-mel features.

    Audio files are expected at ``<train_audio_dir>/<species>/<file>.ogg``; the
    species label of each example is the name of the directory containing it.

    Args:
        batch_size: Number of examples per batch.
        window_size_s: Length, in seconds, of the audio window kept from each
            file.
        train_audio_dir: Directory holding one sub-directory of .ogg files per
            species.
    """

    def __init__(self, batch_size, window_size_s=5,
                 train_audio_dir='/kaggle/input/birdclef-2021/train_short_audio'):
        self.batch_size = batch_size
        self.window_size_s = window_size_s
        self.sample_rate = 32000
        self.features_rate = 100
        self.train_audio_dir = train_audio_dir
        # '/' is used explicitly because labels are recovered later by
        # splitting the file name on '/'.
        self.train_file_pattern = train_audio_dir + '/*/*.ogg'
        # Created by build(); process_examples needs it for the enum lookup.
        self.species_enum_table = None

    def select_window(self, audio_tensor):
        """Returns exactly window_size_s seconds from the start of the audio.

        Clips shorter than the window are zero-padded at the end so that every
        example has the same length, which batching requires.

        Args:
            audio_tensor: Tensor whose first axis is time (samples).

        Returns:
            A tensor with int(window_size_s * sample_rate) entries on axis 0
            and the trailing dimensions of the input.
        """
        # In this example, we just select the first window_size_s.
        # In practice, you can select a random segment during training, or
        # use a heuristic to find an interesting segment.
        window_samples = int(self.window_size_s * self.sample_rate)
        window = audio_tensor[:window_samples]

        # Append zeros along the time axis; pad_len is 0 for long-enough clips.
        pad_len = window_samples - tf.shape(window)[0]
        zeros_shape = tf.concat(
            [tf.stack([pad_len]), tf.shape(window)[1:]], axis=0)
        window = tf.concat(
            [window, tf.zeros(zeros_shape, dtype=window.dtype)], axis=0)

        # The time dimension is now statically known; record it.
        window.set_shape(
            tf.TensorShape([window_samples]).concatenate(audio_tensor.shape[1:]))
        return window

    def get_species_enum_table(self):
        """Returns a static lookup table mapping species name -> int32 index.

        Species are the sorted entries of train_audio_dir; names not in the
        table map to -1.
        """
        species_list = sorted(os.listdir(self.train_audio_dir))
        # Explicit dtypes keep the table's key/value types fixed even when the
        # directory listing is empty.
        keys = tf.constant(species_list, dtype=tf.string)
        values = tf.range(len(species_list), dtype=tf.int32)
        species_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=-1)
        return species_table

    def augment_time_domain(self, audio, is_train):
        """Hook for time-domain augmentation; currently returns audio as-is."""
        if not is_train:
            return audio
        # Do interesting things here.
        return audio

    def augment_features(self, features, is_train):
        """Hook for feature-domain augmentation; currently a pass-through."""
        if not is_train:
            return features
        # Do interesting things here.
        return features

    def process_examples(self, ex0, ex1=None, is_train=True):
        """Merges (optionally) two batches and computes model features.

        Args:
            ex0: Batched dict with 'audio', 'label' and 'source' entries, as
                produced by _parse_and_trim_audio followed by batch().
            ex1: Optional second batch of the same structure to mix into ex0.
            is_train: Forwarded to the augmentation hooks.

        Returns:
            A dict with 'audio' (the possibly mixed audio), 'label' and
            'label_enum' (two columns when ex1 is given, one otherwise),
            'features' (log-mel spectrogram) and 'source' (file names of ex0).
        """
        if ex1 is not None:
            # Here we combine two examples.
            labels = tf.stack([ex0['label'], ex1['label']], axis=1)
            labels_enum = tf.stack(
                [self.species_enum_table.lookup(ex0['label']),
                 self.species_enum_table.lookup(ex1['label'])],
                axis=1)

            # This applies a random gain to each member of the batch
            # separately. The batch size is read from the data rather than
            # from self.batch_size so a smaller batch is handled as well.
            num_examples = tf.shape(ex0['audio'])[0]
            gain0 = tf.random.uniform([num_examples, 1, 1], 0.2, 0.5)
            merged_audio = gain0 * ex0['audio'] + (1 - gain0) * ex1['audio']
        else:
            labels = ex0['label'][:, tf.newaxis]
            labels_enum = self.species_enum_table.lookup(
                ex0['label'])[:, tf.newaxis]
            merged_audio = ex0['audio']

        # This is a good place to apply any extra augmentations to the time
        # domain audio.
        merged_audio = self.augment_time_domain(merged_audio, is_train)

        # You can easily replace feature extraction with another representation,
        # like PCEN.
        features = self.extract_melspec(merged_audio)
        features = self.augment_features(features, is_train)
        combined = {
            'audio': merged_audio,
            'label': labels,
            'label_enum': labels_enum,
            'features': features,
            'source': ex0['source'],
        }
        return combined

    def extract_melspec(self, audio):
        """Computes a log-mel spectrogram from batched audio.

        Args:
            audio: Tensor indexed as [batch, time, channel]; only channel 0 is
                used.

        Returns:
            Log-mel spectrogram whose last dimension has melspec_depth (100)
            bins, floored at log(log_floor).
        """
        # Create the features your model consumes.
        # Doing this work early saves time for the model.
        # There are many options; here's an example of creating a melspectrogram.
        frame_step = self.sample_rate // self.features_rate
        frame_length = int(0.1 * self.sample_rate)
        melspec_depth = 100
        lower_edge_hz = 100.0
        upper_edge_hz = 12000.0
        log_floor = 1e-2

        # Last dimension needs to be the time dimension.
        stfts = tf.signal.stft(
            audio[:, :, 0], frame_length=frame_length, frame_step=frame_step,
            pad_end=True)
        magnitude_spectrograms = tf.abs(stfts)
        num_spectrogram_bins = tf.shape(magnitude_spectrograms)[-1]
        linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            melspec_depth, num_spectrogram_bins,
            self.sample_rate, lower_edge_hz, upper_edge_hz)
        mel_spectrograms = tf.tensordot(magnitude_spectrograms,
                                        linear_to_mel_weight_matrix, 1)
        # tensordot loses the static shape; restore it from its operands.
        mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
                                   linear_to_mel_weight_matrix.shape[-1:]))
        # Flooring before the log avoids -inf on silent frames.
        logmel = tf.math.log(tf.maximum(mel_spectrograms, log_floor))
        return logmel

    def _parse_and_trim_audio(self, filename):
        """Decodes one audio file and trims/pads it to the window length.

        Args:
            filename: Scalar string tensor holding the path of the audio file.

        Returns:
            A dict with 'audio' ([samples, 1] float32 window), 'label' (name of
            the directory containing the file) and 'source' (the file name).
        """
        # In order for the map call to work properly, we need to wrap stateful
        # python operations with py_function. Otherwise, we get weird repeated
        # audio in the dataset which doesn't match the labels.
        # See also:
        # https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
        # https://www.tensorflow.org/api_docs/python/tf/py_function
        # https://www.tensorflow.org/guide/data#applying_arbitrary_python_logic
        #
        # Another alternative is to convert the dataset to wav files, and
        # use the tf.audio.decode_wav to parse the files.
        # This won't require using a tf.py_function, so can make better
        # use of multiple cores/threads.
        def _soundfile_read(filename):
            with open(filename.numpy(), 'rb') as audio_file:
                tmp = io.BytesIO(audio_file.read())
            # NOTE(review): the file's own sample rate is ignored; the audio
            # is assumed to already be at self.sample_rate - confirm.
            audio, _ = sf.read(tmp, dtype='float32')
            if audio.ndim > 1:
                # Multichannel files decode to [frames, channels]; average to
                # mono so the reshape below does not interleave channels.
                audio = audio.mean(axis=1)
            return audio
        [audio,] = tf.py_function(_soundfile_read, [filename], [tf.float32])
        audio.set_shape([None])
        audio = tf.reshape(audio, [-1, 1])

        audio = self.select_window(audio)
        label = tf.strings.split(filename, sep='/')[-2]
        # TODO: Lookup and include metadata.
        return {'audio': audio,
                'label': label,
                'source': filename}

    def build(self, is_train):
        """Builds the input pipeline.

        Files are split by a hash of their name into 100 buckets: bucket 0 is
        the validation set, the remaining buckets are the training set.

        Args:
            is_train: If True, returns an endlessly repeating, shuffled
                training dataset in which pairs of batches are mixed together.
                Otherwise returns a single pass over the validation files.

        Returns:
            A tf.data.Dataset yielding the dicts made by process_examples.
        """
        self.species_enum_table = self.get_species_enum_table()

        # Build a tensorflow Dataset from the training file pattern.
        ds = tf.data.Dataset.list_files(self.train_file_pattern)

        # Filter species here, to avoid reading files that you won't use.
        # ds = ds.filter(lambda x: tf.strings.split(x, '/')[-2] != 'batpig1')

        # Create a train/validation split. The split is applied before
        # repeat/shuffle so the shuffle buffer only holds files that are used.
        if is_train:
            ds = ds.filter(lambda x: tf.strings.to_hash_bucket(x, 100) != 0)
            ds = ds.repeat(-1)
            ds = ds.shuffle(64000)
        else:
            ds = ds.filter(lambda x: tf.strings.to_hash_bucket(x, 100) == 0)

        # First, we parse the audio from the ogg files and snip to the same length.
        ds = ds.map(self._parse_and_trim_audio,
                    num_parallel_calls=10,
                    deterministic=False)
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        # We typically work on batches of examples.
        # It's also much more efficient for pre-processing.
        ds = ds.batch(self.batch_size)
        # It's common to mix multiple examples to increase the difficulty for the model
        # during training. If you have additional noise files, you can zip those in, too.
        # Notice that process_examples has a signature which can handle one or two inputs;
        # this pattern can extend to handle an additional noise file, as well.
        if is_train:
            ds = tf.data.Dataset.zip((ds, ds))
            process_fn = lambda x0, x1: self.process_examples(x0, x1, is_train=True)
        else:
            process_fn = lambda x0: self.process_examples(x0, is_train=False)

        # It's good to avoid making too many map calls.
        # So process_examples does all of the augmentation and example-merging.
        ds = ds.map(process_fn,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # Prefetch creates a buffer for the model to read batches from.
        # This means the model (hopefully) is never waiting for preprocessing to finish.
        ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        return ds

tf.compat.v1.reset_default_graph()
ds_builder = BirbsongDatasetBuilder(batch_size=10)
ds = ds_builder.build(is_train=True)

# Now we have a training dataset.
# We can draw a batch to demonstrate that it works, but this is never how you want
# to use it in practice!
# Usually, you'll do model.fit(ds, ...)
# Then the model will draw automagically from the dataset as needed,
# and the dataset will keep its prefetch cache filled automagically.
it = iter(ds)
starttime = time.time()
for i in range(10):
    btime = time.time()
    # Use the built-in next() (iterator protocol) rather than the iterator's
    # .next() method.
    features = next(it)
    print('\t batch %02d : %5.3f s' % (i, time.time() - btime))
print('\nelasped per batch : ', (time.time() - starttime) / 10)
print('Feature labels : ', features['label'])
print('Feature label enums : ', features['label_enum'])
print('Feature melspec shape : ', features['features'].shape)

# Show the log-mel spectrogram of (up to) the first ten examples of the last
# batch on a 2 x 5 grid; each image is transposed and flipped so that the
# feature axis runs bottom-to-top.
plt.figure(figsize=(15, 5))
melspecs = features['features'].numpy()
for i, melspec in enumerate(melspecs[:10]):
    plt.subplot(2, 5, i + 1)
    plt.imshow(np.flipud(melspec.T),
               cmap='Greys', aspect='auto')

In [None]:
import IPython.display as ipd

# Look inside one of the examples of the last batch drawn above:
# print its source file and label(s), then play back its audio.
b = 9
audio = features['audio'][b, :, 0].numpy()
print(features['source'][b].numpy())
print(features['label'][b].numpy())
# Play back at the rate the dataset builder works with, instead of repeating
# the number here.
ipd.Audio(data=audio, rate=ds_builder.sample_rate)


In [None]:
import tensorflow_hub as hub

# Width of the model's output layer and depth of the one-hot label encoding.
# NOTE(review): presumably the number of species directories in the training
# data, so that every 'label_enum' value is a valid index - confirm.
N_SPECIES = 397
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-2

def build_model(input_shape):
    """Builds a small convolutional classifier with its loss attached.

    The model takes two inputs, named after the dataset dictionary keys they
    consume: 'features' (the spectrogram) and 'label_enum' (integer species
    indices). The binary cross-entropy loss is added to the model itself, so
    compile() needs no loss argument.

    Args:
        input_shape: Shape of one example's features, without the batch
            dimension (the dataset yields features as [B, T, D]).

    Returns:
        A tf.keras.Model whose output is N_SPECIES logits per example.
    """
    # Features shape is [B, T, D]
    # Designate input 'names' matching the keys in the dataset output dictionaries.
    features_input = tf.keras.Input(shape=input_shape, name='features')
    labels_enum = tf.keras.Input(shape=[None], name='label_enum', dtype=tf.int32)

    # Add a trailing channel axis so the features can be fed to Conv2D.
    activations = tf.expand_dims(features_input, -1)
    activations = tf.keras.layers.Conv2D(
        64, 5, (2, 2), activation='relu')(activations)
    activations = tf.keras.layers.Conv2D(
        64, 5, (2, 2), activation='relu')(activations)
    activations = tf.keras.layers.Conv2D(
        64, 3, (2, 2), activation='relu')(activations)
    activations = tf.keras.layers.Flatten()(activations)
    activations = tf.keras.layers.Dense(2 * N_SPECIES, activation='relu')(activations)
    logits = tf.keras.layers.Dense(N_SPECIES)(activations)
    model = tf.keras.Model(inputs=[features_input, labels_enum], outputs=logits)

    # Collapse the per-example label indices into a single multi-hot target.
    # reduce_max (not reduce_sum) keeps the target in {0, 1} even when the same
    # species appears more than once among an example's labels.
    labels_n_hot = tf.one_hot(labels_enum, depth=N_SPECIES, axis=-1)
    labels_n_hot = tf.reduce_max(labels_n_hot, axis=1)

    # Define the loss directly using the features in the dataset outputs.
    # There's not a great way to pass the labels from the dataset to the model.compile,
    # so far as I can tell.
    # Same goes for metrics; define them here, using the input features.
    bxe = tf.keras.losses.BinaryCrossentropy(from_logits=True)(labels_n_hot, logits)
    model.add_loss(bxe)
    return model

tf.compat.v1.reset_default_graph()

# The model is sized from the feature batch drawn earlier, with the batch
# dimension dropped.
feature_shape = features['features'].shape[1:]
model = build_model(input_shape=feature_shape)
# No loss is passed to compile: build_model attaches it to the model.
adam = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=adam)

# Fresh training dataset with the batch size used for training.
ds_builder = BirbsongDatasetBuilder(batch_size=32)
ds = ds_builder.build(is_train=True)
print('starting train loop...')
# This runs fine, but is slow on a single-CPU notebook.
# NOTE: the training dataset repeats indefinitely and is already batched by
# the builder, so pass steps_per_epoch to bound the length of an epoch.
# model.fit(ds, epochs=1, batch_size=32, verbose=1)