In [14]:
import tensorflow as tf
import librosa

In [15]:
def build_artificial_dataset(num_samples: int, path=None):
    """
    Builds a toy dataset of `num_samples` identical (waveform, sampling_rate) pairs.

    @param num_samples: number of (identical) examples in the dataset
    @param path: audio file to load; defaults to librosa's 'nutcracker' example
    @return: tf.data.Dataset zipping waveforms (features) with their sampling
             rates (labels)
    """
    if path is None:
        path = librosa.ex('nutcracker')

    # The clip is the same for every example, so decode it once and replicate it
    # instead of re-reading the file on every iteration.
    y, sr = librosa.load(path)
    data = [y] * num_samples
    sampling_rates = [sr] * num_samples

    features_dataset = tf.data.Dataset.from_tensor_slices(data)
    labels_dataset = tf.data.Dataset.from_tensor_slices(sampling_rates)
    dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))

    return dataset

In [16]:
# Ten examples, all built from the same audio clip (see build_artificial_dataset).
ds = build_artificial_dataset(num_samples=10)

In [17]:
# Inspect the first (waveform, sampling_rate) pair of the dataset.
for example in ds.take(1):
    print(example)

(<tf.Tensor: shape=(2643264,), dtype=float32, numpy=
array([ 2.2716861e-06,  5.3327208e-06, -7.2473290e-06, ...,
        1.1170751e-05,  1.2871889e-06,  5.4120628e-06], dtype=float32)>, <tf.Tensor: shape=(), dtype=int32, numpy=22050>)


### 1: Augment the raw waveform in the `tf.data` input pipeline


In [18]:
from audiomentations import Compose, AddGaussianNoise, PitchShift, Shift

# Waveform-level augmentation chain; every transform is configured with p=0.5.
augmentations_pipeline = Compose([
    # Additive Gaussian noise with amplitude drawn from [0.001, 0.015].
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    # Pitch shift of up to +/- 4 semitones.
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    # Time shift with fraction drawn from [-0.5, 0.5].
    Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
])


def apply_pipeline(y, sr):
    """
    Runs `augmentations_pipeline` on a single waveform.

    Intended to be invoked through `tf.numpy_function`, so `y` and `sr` arrive
    as NumPy values rather than tensors.

    @param y: audio samples
    @param sr: sampling rate
    @return: augmented samples as float32, matching the `Tout=tf.float32`
             declared by the `tf.numpy_function` caller
    """
    # `sr` comes in as a NumPy scalar; hand the augmentation library a plain int.
    augmented = augmentations_pipeline(y, int(sr))
    # tf.numpy_function rejects a result whose dtype differs from Tout, so pin it.
    return augmented.astype("float32", copy=False)


@tf.function
def tf_apply_pipeline(feature, sr):
    """
    Applies the augmentation pipeline to one (waveform, sampling_rate) element.

    The augmentation runs in NumPy via `tf.numpy_function`, whose output has no
    static shape. The input's shape is re-attached (and checked at runtime) so
    downstream ops and `element_spec` still see a fully defined shape.

    @param feature: audio data (waveform tensor)
    @param sr: sampling rate (scalar tensor)
    @return: tuple of (augmented audio data, unchanged sampling rate)
    """
    augmented_feature = tf.numpy_function(
        apply_pipeline, inp=[feature, sr], Tout=tf.float32, name="apply_pipeline"
    )
    # NOTE(review): assumes every transform in the pipeline preserves length;
    # ensure_shape raises at runtime if a future transform violates that.
    augmented_feature = tf.ensure_shape(augmented_feature, feature.shape)

    return augmented_feature, sr


def augment_audio_dataset(dataset: tf.data.Dataset):
    """
    Returns a new dataset whose elements have passed through `tf_apply_pipeline`.

    @param dataset: dataset of (waveform, sampling_rate) pairs
    @return: mapped dataset of (augmented waveform, sampling_rate) pairs
    """
    return dataset.map(tf_apply_pipeline)

In [19]:
# Augment each waveform, then append a trailing channel axis for the model input.
# NOTE(review): this cell rebinds `ds`, so re-running it augments and expands twice.
ds = augment_audio_dataset(ds).map(
    lambda wave, rate: (tf.expand_dims(wave, axis=-1), rate)
)

### 2: Augment the spectrogram during the forward pass (SpecAugment layer inside the model)

In [20]:
# Read the concrete waveform shape from one example to size the model's Input layer.
for features, _label in ds.take(1):
    input_shape = features.shape
    print(input_shape)

(2643264, 1)


In [21]:
import kapre
from spec_augment import SpecAugment

def get_model(input_shape, num_classes: int = 10):
    """
    Builds a waveform classifier: mel spectrogram -> SpecAugment -> ResNet152V2.

    @param input_shape: shape of one input waveform, channels last
    @param num_classes: number of units in the final Dense layer
    @return: tf.keras.Model named "audio_model"
    """
    waveform = tf.keras.layers.Input(shape=input_shape, dtype=tf.float32)

    # Waveform -> decibel-scaled mel spectrogram (channels_last in and out).
    mel_layer = kapre.composed.get_melspectrogram_layer(
        n_fft=1024,
        return_decibel=True,
        n_mels=256,
        input_data_format='channels_last',
        output_data_format='channels_last',
    )
    spectrogram = mel_layer(waveform)

    # Frequency/time masking on the spectrogram; parameter names follow the paper.
    augmenter = SpecAugment(
        freq_mask_param=27,   # F in paper
        time_mask_param=100,  # T in paper
        n_freq_mask=1,        # mF in paper
        n_time_mask=2,        # mT in paper
        mask_value=-1,
    )
    masked_spectrogram = augmenter(spectrogram)

    # Randomly initialised backbone, global-average-pooled, no classification head.
    backbone = tf.keras.applications.resnet_v2.ResNet152V2(
        input_tensor=masked_spectrogram,
        include_top=False,
        pooling="avg",
        weights=None,
    )
    # No activation is given, so the head emits unnormalised (linear) scores.
    scores = tf.keras.layers.Dense(units=num_classes)(backbone.output)

    return tf.keras.Model(inputs=[waveform], outputs=[scores], name="audio_model")


In [22]:
# `input_shape` was read from the peeked example; num_classes keeps its default.
model = get_model(input_shape)

In [23]:
# Print the layer-by-layer architecture, output shapes and parameter counts.
model.summary()

Model: "audio_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 2643264, 1)] 0                                            
__________________________________________________________________________________________________
melspectrogram (Sequential)     (None, 10322, 256, 1 0           input_2[0][0]                    
__________________________________________________________________________________________________
SpecAugment (SpecAugment)       (None, 10322, 256, 1 0           melspectrogram[0][0]             
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 10328, 262, 1 0           SpecAugment[0][0]                
________________________________________________________________________________________