In [None]:
# !pip install --quiet tensorflow_io
# !pip install --quiet tensorflow_addons
!pip install --quiet tensorflow_probability

In [None]:
import tensorflow as tf
# import tensorflow_io as tfio
# import tensorflow_addons as tfa
import tensorflow_probability as tfp
from kaggle_datasets import KaggleDatasets
import pandas as pd
import numpy as np
from sklearn import model_selection
import os
import glob
import tqdm
import math
import matplotlib.pyplot as plt


### TPU setup

In [None]:
# Resolve the TPU cluster, connect this runtime to it and initialise the TPU
# system, then list the logical TPU devices that became visible.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

tpu_devices = tf.config.list_logical_devices('TPU')
print("All devices: ", tpu_devices)

In [None]:
# Distribution strategy built on the TPU resolved above.
strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Sentinel that lets tf.data choose parallelism / buffer sizes itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# GCS location of the competition dataset (as reported by KaggleDatasets) and
# the folder holding its training TFRecords.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('rfcx-species-audio-detection')
TRAIN_TFREC = GCS_DS_PATH + "/tfrecords/train"

## 1. Helper functions for TF dataset

In [None]:
def _int32(x):
    return tf.cast(x, 'int32')

def random_truncated_normal(mean, stddev, min_val, max_val):
    """Draw one scalar from N(mean, stddev) by rejection until it lies in range.

    A value is drawn with standard deviation ``stddev``; while it falls outside
    ``[min_val, max_val]`` it is redrawn with the standard deviation scaled by
    the attempt number (``stddev * 2``, ``stddev * 3``, ...).

    Args:
        mean: Scalar mean of the normal distribution.
        stddev: Scalar standard deviation used for the first draw.
        min_val: Smallest accepted value (inclusive).
        max_val: Largest accepted value (inclusive).

    Returns:
        A scalar tensor inside ``[min_val, max_val]``.

    Note:
        If ``min_val > max_val`` no draw can ever be accepted and the loop does
        not terminate.
    """
    def _out_of_range(i, x):
        # tf.logical_or (instead of Python `or`) keeps the predicate a boolean
        # tensor, so it is valid in graph mode without relying on autograph.
        return tf.logical_or(tf.less(x, min_val), tf.greater(x, max_val))

    def _resample(i, x):
        return tf.add(i, 1.0), tf.random.normal([], mean, stddev * i)

    # Start from a real draw rather than a fixed sentinel: a constant start
    # value would be returned unsampled whenever it happens to lie in range.
    first_draw = tf.random.normal([], mean, stddev)
    _, value = tf.while_loop(
        _out_of_range, _resample, [tf.constant(2.0), first_draw])
    return value

def obtain_onehot_label(labels, t_starts, t_ends, window_start, window_end,
                        num_classes=24):
    """Build a multi-hot label for every annotated event overlapping a window.

    An event ``i`` overlaps the window when ``window_end > t_starts[i]`` and
    ``window_start < t_ends[i]``.

    Args:
        labels: 1-D tensor of class ids, one per event (cast to int32).
        t_starts: 1-D tensor with the start of each event.
        t_ends: 1-D tensor with the end of each event.
        window_start: Scalar start of the window, same unit as ``t_starts``.
        window_end: Scalar end of the window, same unit as ``t_ends``.
        num_classes: Length of the returned vector.

    Returns:
        float32 tensor of shape ``(num_classes,)`` with 1 for every class that
        has at least one overlapping event and 0 elsewhere.
    """
    species = tf.cast(labels, 'int32')
    overlaps = tf.logical_and(tf.greater(window_end, t_starts),
                              tf.less(window_start, t_ends))
    # One one-hot row per event, zeroed out for events outside the window.
    per_event = tf.one_hot(species, num_classes) * tf.expand_dims(
        tf.cast(overlaps, 'float32'), -1)
    # Several events of the same class must still yield 1, hence the clip.
    return tf.clip_by_value(tf.reduce_sum(per_event, axis=0), 0, 1)
    
def slice_record(features, window_size=5.0, clip_duration=60.0):
    """Cut a random fixed-length window around one annotated event.

    One annotation is picked uniformly at random; the window centre is drawn
    from a truncated normal centred on that event, and the label is rebuilt
    for every annotation overlapping the chosen window.

    Args:
        features: dict with
            'audio_wav'  - 1-D waveform tensor,
            'rate'       - sample rate as a float scalar,
            'label_info' - 1-D string tensor, one comma-separated annotation
                           per entry; fields 0, 2 and 4 are read as species
                           id, start time and end time
                           (NOTE(review): times presumably in seconds - confirm).
        window_size: Window length, multiplied by ``rate`` to get samples.
        clip_duration: Length of the recording in the same unit as
            ``window_size``; the window is kept inside
            ``[0, clip_duration * rate]``.

    Returns:
        dict with 'signal' (the sliced waveform), 'rate' and 'label_info'
        (the multi-hot label produced by ``obtain_onehot_label``).
    """
    rate = features['rate']

    items = tf.strings.split(features['label_info'], sep=',').to_tensor()
    spid = tf.strings.to_number(tf.gather(items, 0, axis=1))
    # Event boundaries converted to sample positions.
    tmin = tf.strings.to_number(tf.gather(items, 2, axis=1)) * rate
    tmax = tf.strings.to_number(tf.gather(items, 4, axis=1)) * rate

    idx = tf.random.uniform(
        shape=(), minval=0, maxval=tf.shape(spid)[0], dtype='int32')
    mean = (tmax[idx] + tmin[idx]) / 2
    duration = tmax[idx] - tmin[idx]

    window_size = window_size * rate
    half_window = window_size / 2

    # +/- 3 sigma spans the longer of the event and the window.
    stddev = tf.math.maximum(duration, window_size) / 2.0 / 3.0

    # The centre must stay half a window away from BOTH ends of the clip.
    # (Subtracting a full window at the upper end would make the last
    # half-window of the recording unreachable.)
    value = random_truncated_normal(
        mean, stddev,
        min_val=half_window,
        max_val=clip_duration * rate - half_window)

    idx_center = _int32(value)
    idx_start = idx_center - (_int32(window_size) // 2)
    idx_end = idx_start + _int32(window_size)

    label = obtain_onehot_label(
        spid, _int32(tmin), _int32(tmax), idx_start, idx_end)
    signal = features['audio_wav'][idx_start: idx_end]

    # Build a fresh dict instead of overwriting entries of the caller's one.
    return {
        'signal': signal,
        'rate': rate,
        'label_info': label
    }

def parse_function(example_proto, mode='training'):
    """Decode one serialized TFRecord example into a sliced waveform.

    Args:
        example_proto: Serialized ``tf.train.Example`` holding the string
            features 'recording_id', 'audio_wav' and 'label_info'.
        mode: Accepted for API compatibility; not used by this function.

    Returns:
        The 'signal' entry produced by ``slice_record`` (labels are dropped;
        only the signal is returned for the sake of demonstration).
    """
    schema = {
        key: tf.io.FixedLenFeature([], tf.string, default_value='')
        for key in ('recording_id', 'audio_wav', 'label_info')
    }
    record = tf.io.parse_single_example(example_proto, schema)

    # The squeeze below drops the channel axis, so the audio must be mono.
    samples, sample_rate = tf.audio.decode_wav(record['audio_wav'])
    record['audio_wav'] = tf.squeeze(tf.cast(samples, 'float32'), axis=-1)
    record['rate'] = tf.cast(sample_rate, 'float32')

    # Keep the text between the first pair of double quotes, then split it
    # into one annotation string per ';'.
    quoted_parts = tf.strings.split(record['label_info'], sep='"')
    record['label_info'] = tf.strings.split(quoted_parts[1], sep=';')

    sliced = slice_record(record)
    return sliced['signal']  # return only signal for the sake of demonstration
   

## 2. (LogMel)SpecLayer and random augmentation

In [None]:
def random_white_noise(x, p, snr_min=15.0, snr_max=30.0):
    """Add white noise to a random subset of the waveforms in a batch.

    For each example a peak signal-to-noise ratio (in dB) is drawn uniformly
    from ``[snr_min, snr_max]``; zero-mean Gaussian noise is scaled so that
    its peak is that many dB below the peak of *that example's* signal.

    Args:
        x: float tensor, expected shape ``(batch, samples)`` (the per-example
            factors have shape ``(batch, 1)`` and broadcast over the last axis).
        p: Probability that an example receives noise.
        snr_min: Lower bound of the sampled SNR in dB.
        snr_max: Upper bound of the sampled SNR in dB.

    Returns:
        Tensor with the same shape as ``x``.
    """
    shape = tf.shape(x)
    batch_size = shape[0]

    snr_db = tf.random.uniform((batch_size, 1), snr_min, snr_max)
    # Peaks are taken per example; a batch-wide peak would make the noise
    # level of one clip depend on the loudest clip in the batch.
    signal_peak = tf.math.reduce_max(tf.math.abs(x), axis=-1, keepdims=True)
    noise_peak = signal_peak / tf.math.pow(10.0, snr_db / 20.0)

    apply = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample((batch_size, 1)), 'float32')

    noise = tf.random.normal(shape=shape)
    noise -= tf.math.reduce_mean(noise, axis=-1, keepdims=True)
    # Normalise each row to unit peak before scaling to the target level.
    noise /= tf.math.reduce_max(tf.math.abs(noise), axis=-1, keepdims=True)

    return x + noise * noise_peak * apply


def random_volume_control(x, p, db_limit=10, mode='uniform'):
    """Apply a random gain (in dB) to a random subset of the waveforms.

    A gain ``db`` is drawn per example from ``[-db_limit, +db_limit]`` and
    shaped over time by an envelope selected with ``mode``; the waveform is
    multiplied by ``10 ** (db * envelope / 20)``.

    Args:
        x: float tensor, expected shape ``(batch, samples)``.
        p: Probability that an example is modified.
        db_limit: Maximum absolute gain in dB.
        mode: Envelope over time -
            'uniform': constant 1,
            'fade':    linear ramp from 1 down to 0,
            'cosine':  one period of a cosine,
            'sine':    one period of a sine.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``mode`` is not one of the values above.
    """
    # Keep the shape as integers: it is used for tensor shapes below, and
    # shape arguments must be integer typed.
    shape = tf.shape(x)
    batch_size, n_samples = shape[0], shape[1]
    length = tf.cast(n_samples, 'float32')

    db = tf.random.uniform((batch_size, 1), -db_limit, +db_limit)

    # 1-D time axis; it broadcasts against the (batch, 1) gains, so there is
    # no need to tile it per example.
    t = tf.cast(tf.range(n_samples), 'float32')
    if mode == 'uniform':
        envelope = tf.ones_like(t)
    elif mode == 'fade':
        # Reverse along time. The max() guards the single-sample case
        # against a 0/0 division.
        envelope = t[::-1] / tf.math.maximum(length - 1.0, 1.0)
    elif mode == 'cosine':
        envelope = tf.math.cos(t / length * math.pi * 2)
    elif mode == 'sine':
        envelope = tf.math.sin(t / length * math.pi * 2)
    else:
        raise ValueError(
            "mode has to be either 'uniform', 'fade', 'cosine' or 'sine'")

    apply = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample((batch_size, 1)), 'float32')
    # apply == 0 gives an exponent of 0, i.e. a gain of exactly 1.
    gain = 10 ** (db * envelope / 20 * apply)
    return x * gain

def random_time_mask(input, p, max_mask=50):
    """Zero a random contiguous span along axis 1 for a subset of examples.

    For every example a width is drawn from ``U(0, max_mask)`` and a start from
    ``U(0, size_of_axis_1 - width)``; with probability ``p`` the positions in
    ``[start, start + width)`` along axis 1 are replaced by 0.

    Args:
        input: Tensor expected to be ``(batch, time, freq)``; the mask has
            shape ``(batch, time, 1)`` and broadcasts over the last axis.
        p: Probability that an example is masked.
        max_mask: Upper bound of the mask width.

    Returns:
        Tensor with the same shape and dtype as ``input``.
    """
    n_batch = tf.shape(input)[0]

    def _per_example(value):
        # Float vector with `value` repeated once per example.
        return tf.cast(tf.repeat(value, [n_batch]), 'float32')

    upper = _per_example(tf.shape(input)[1])
    lower = _per_example(tf.constant(0))
    widest = _per_example(tf.constant(max_mask))

    width = tfp.distributions.Uniform(low=lower, high=widest).sample()
    start = tfp.distributions.Uniform(low=lower, high=upper - width).sample()

    gate = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample((n_batch)), 'float32')
    # A zero width leaves the example untouched.
    width = width * gate

    positions = tf.reshape(tf.range(upper[0]), (1, -1, 1))
    # (1, time, 1) against (batch,) broadcasts to (1, time, batch) ...
    masked = tf.math.logical_and(positions >= start, positions < start + width)
    # ... which is rearranged to (batch, time, 1).
    masked = tf.transpose(masked, (2, 1, 0))

    return tf.where(masked, tf.constant(0, dtype=input.dtype), input)

def random_freq_mask(input, p, max_mask=16):
    """Zero a random contiguous span along axis 2 for a subset of examples.

    For every example a width is drawn from ``U(0, max_mask)`` and a start from
    ``U(0, size_of_axis_2 - width)``; with probability ``p`` the positions in
    ``[start, start + width)`` along axis 2 are replaced by 0.

    Args:
        input: Tensor expected to be ``(batch, time, freq)``; the mask has
            shape ``(batch, 1, freq)`` and broadcasts over axis 1.
        p: Probability that an example is masked.
        max_mask: Upper bound of the mask width.

    Returns:
        Tensor with the same shape and dtype as ``input``.
    """
    n_batch = tf.shape(input)[0]

    def _per_example(value):
        # Float vector with `value` repeated once per example.
        return tf.cast(tf.repeat(value, [n_batch]), 'float32')

    upper = _per_example(tf.shape(input)[2])
    lower = _per_example(tf.constant(0))
    widest = _per_example(tf.constant(max_mask))

    width = tfp.distributions.Uniform(low=lower, high=widest).sample()
    start = tfp.distributions.Uniform(low=lower, high=upper - width).sample()

    gate = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample((n_batch)), 'float32')
    # A zero width leaves the example untouched.
    width = width * gate

    positions = tf.reshape(tf.range(upper[0]), (1, -1, 1))
    # (1, freq, 1) against (batch,) broadcasts to (1, freq, batch) ...
    masked = tf.math.logical_and(positions >= start, positions < start + width)
    # ... which is rearranged to (batch, 1, freq).
    masked = tf.transpose(masked, (2, 0, 1))

    return tf.where(masked, tf.constant(0, dtype=input.dtype), input)

def random_brightness(input, p, max_delta=0.2):
    """Add one random constant offset per example to a subset of the batch.

    Args:
        input: Tensor whose first axis is the batch; the offsets have shape
            ``(batch, 1, 1)`` and broadcast over the remaining axes.
        p: Probability that an example is shifted.
        max_delta: Offsets are drawn from ``U(-max_delta, max_delta)``.

    Returns:
        ``input`` plus the (possibly zero) per-example offset.
    """
    per_example = (tf.shape(input)[0], 1, 1)
    offset = tfp.distributions.Uniform(
        low=-max_delta, high=max_delta).sample(per_example)
    gate = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample(per_example), 'float32')
    return tf.math.add(input, offset * gate)

def random_gaussian_noise(input, p, scale=0.2):
    """Add element-wise Gaussian noise to a random subset of the batch.

    Args:
        input: Tensor whose first axis is the batch; the on/off gate has shape
            ``(batch, 1, 1)`` and broadcasts over the remaining axes.
        p: Probability that an example receives noise.
        scale: Standard deviation of the zero-mean noise.

    Returns:
        ``input`` plus the (possibly zeroed) noise.
    """
    full_shape = tf.shape(input)
    n_batch = full_shape[0]
    noise = tfp.distributions.Normal(loc=0, scale=scale).sample(full_shape)
    gate = tf.cast(
        tfp.distributions.Bernoulli(probs=p).sample((n_batch, 1, 1)), 'float32')
    return input + noise * gate


class SpecLayer(tf.keras.layers.Layer):
    """Keras layer turning a batch of waveforms into log-mel "images".

    Pipeline in ``call``: STFT magnitude -> ``mel_power`` -> mel projection ->
    log10 -> per-example standardisation -> (training only) random
    augmentation -> per-example rescale to [0, 255] and transpose to
    ``(batch, mel, frames)``.

    Args:
        fft_length: FFT size passed to ``tf.signal.stft``.
        frame_length: STFT window length in samples.
        frame_step: STFT hop length in samples.
        mel_power: Exponent applied to the STFT magnitude before the mel
            projection.
        num_mel_bins: Number of mel bands.
        sample_rate: Sample rate used to build the mel filterbank.
        lower_edge_hertz: Lowest frequency covered by the mel filterbank.
        upper_edge_hertz: Highest frequency covered by the mel filterbank.
        name: Layer name.
        dtype: Layer dtype.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self,
                 fft_length=2048,
                 frame_length=2048,
                 frame_step=512,
                 mel_power=2.0,
                 num_mel_bins=384,
                 sample_rate=48_000,
                 lower_edge_hertz=40,
                 upper_edge_hertz=20_000,
                 name='spec_layer',
                 dtype='float32',
                 **kwargs):

        super().__init__(name=name, dtype=dtype, **kwargs)

        self.fft_length = fft_length
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.mel_power = mel_power
        # Kept so that get_config() can reproduce the constructor call.
        self.num_mel_bins = num_mel_bins
        self.sample_rate = sample_rate
        self.lower_edge_hertz = lower_edge_hertz
        self.upper_edge_hertz = upper_edge_hertz

        # Constant (spectrogram_bins, mel_bins) projection matrix; it is fully
        # determined by the constructor arguments.
        self.mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=num_mel_bins,
            num_spectrogram_bins=fft_length//2+1,
            sample_rate=sample_rate,
            lower_edge_hertz=lower_edge_hertz,
            upper_edge_hertz=upper_edge_hertz)

    def build(self, input_shape):
        # Nothing to create: the mel matrix is a constant tensor, not a
        # variable. (Appending to `self.non_trainable_weights` has no effect,
        # because that property returns a fresh list on every access.)
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'fft_length': self.fft_length,
            'frame_length': self.frame_length,
            'frame_step': self.frame_step,
            'mel_power': self.mel_power,
            'num_mel_bins': self.num_mel_bins,
            'sample_rate': self.sample_rate,
            'lower_edge_hertz': self.lower_edge_hertz,
            'upper_edge_hertz': self.upper_edge_hertz,
        })
        return config

    def call(self, inputs, training=False, tta=False):
        """Convert waveforms to scaled log-mel spectrograms.

        Args:
            inputs: float tensor, expected shape ``(batch, samples)``.
            training: When truthy, waveform and spectrogram augmentations
                are applied.
            tta: Accepted but not used by this method.

        Returns:
            Tensor of shape ``(batch, mel_bins, frames)`` scaled per example
            to the range [0, 255].
        """
        x = inputs

        if training:
            # Waveform-level augmentation.
            x = random_volume_control(x, p=0.5, mode='sine')
            x = random_white_noise(x, p=0.3)

        x = tf.math.abs(
            tf.signal.stft(
                x,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
                window_fn=tf.signal.hann_window,
                pad_end=True
            )
        )
        x = tf.matmul(tf.math.pow(x, self.mel_power), self.mel_matrix)
        # log10; the small constant avoids log(0).
        x = tf.math.log(x + 1e-6) / tf.math.log(10.0)
        x = self.standardize(x)

        if training:
            # Spectrogram-level augmentation.
            x = random_time_mask(x, p=0.5)
            x = random_freq_mask(x, p=0.5)
            x = random_brightness(x, p=0.3)
            x = random_gaussian_noise(x, p=0.2)

        return self.preprocess(x)

    @staticmethod
    def standardize(x):
        """Shift/scale each example to zero mean and unit std over axes 1, 2."""
        x -= tf.math.reduce_mean(x, [1,2], True)
        x /= tf.math.reduce_std( x, [1,2], True) + 1e-6
        return x

    @staticmethod
    def preprocess(x):
        """Rescale each example to [0, 255] and swap axes 1 and 2."""
        x -= tf.math.reduce_min(x, [1,2], True)
        # A constant example (e.g. silence) has max == 0 after the shift;
        # divide_no_nan maps it to 0 instead of producing NaN.
        x = tf.math.divide_no_nan(x, tf.math.reduce_max(x, [1,2], True))
        x *= tf.constant(255.)
        x = tf.transpose(x, (0, 2, 1))
        return x

## 3. Creating dataset

tfrecords --> waveform signal

In [None]:
# All training shards, read in parallel, parsed into waveforms and grouped
# into global batches of 32 examples per replica.
train_files = tf.io.gfile.glob(TRAIN_TFREC + '/*.tfrec')
global_batch_size = 32*strategy.num_replicas_in_sync

dataset = (
    tf.data.TFRecordDataset(filenames=train_files, num_parallel_reads=AUTOTUNE)
    .map(parse_function, AUTOTUNE)
    .batch(global_batch_size)
    .prefetch(AUTOTUNE)
)

## 4. Create (and fit) model

waveform signal --> log-mel spectrogram<br>
iterating over 4,727 training examples in ~20 sec


In [None]:
# The "model" is just the spectrogram layer; fitting it only drives the
# dataset through the layer in training mode.
with strategy.scope():
    spec_layer = SpecLayer()
    model = tf.keras.Sequential([spec_layer])
    model.compile()
    model.fit(dataset, epochs=5)

## 5. Some output

In [None]:
fig, axes = plt.subplots(2,1, figsize=(10, 10))

# First example of one batch: raw waveform on top, its log-mel image below.
wave = next(iter(dataset))
axes[0].plot(wave.numpy()[0])
axes[0].set(title='Waveform', xlabel='sample', ylabel='amplitude')

logmelspec = SpecLayer()(wave)
# origin='lower' puts mel bin 0 (the lowest band) at the bottom of the image.
axes[1].imshow(logmelspec.numpy()[0].astype(np.uint8), origin='lower')
axes[1].set(title='Log-mel spectrogram', xlabel='frame', ylabel='mel bin')

fig.tight_layout()
plt.show()