In [3]:
import glob
import numpy as np
import tensorflow as tf

In [4]:
# Tensor geometry shared by the TFRecord parser and the model.
Tx, n_freq = 4797, 101  # spectrogram time steps fed to the model, frequencies per time step
Ty = 1196               # time steps in the model output (the Conv1D output length in the summary below)
num_classes = 21        # size of the last axis of each label sequence Y

In [5]:
def _extract_feature(record, feature):
    """Read the float values stored under one key of a serialized tf.train.Example.

    Arguments:
    record -- object whose ``.numpy()`` returns the serialized Example bytes
    feature -- key of the feature to look up in the Example

    Returns:
    the ``float_list.value`` container of the requested feature
    """
    serialized = record.numpy()
    parsed = tf.train.Example.FromString(serialized)
    feature_map = parsed.features.feature
    return feature_map[feature].float_list.value

In [32]:
# Load tf record dataset
def parser(record):
    """Decode one serialized tf.train.Example into an ``(X, Y)`` pair of float32 tensors.

    The record is parsed once with the native ``tf.io.parse_single_example`` op
    instead of being deserialized twice through ``tf.py_function``; this keeps
    the input pipeline inside the TensorFlow graph and gives both outputs a
    static shape.

    Arguments:
    record -- scalar string tensor holding a serialized Example with float
              features "X" (Tx * n_freq values) and "Y" (Ty * num_classes values)

    Returns:
    X -- float32 tensor of shape [Tx, n_freq]
    Y -- float32 tensor of shape [Ty, num_classes]
    """
    # FixedLenFeature reshapes the flat float list to the requested shape and
    # fails if the feature is missing or has a different number of values.
    feature_spec = {
        "X": tf.io.FixedLenFeature([Tx, n_freq], tf.float32),
        "Y": tf.io.FixedLenFeature([Ty, num_classes], tf.float32),
    }
    example = tf.io.parse_single_example(record, feature_spec)

    return example["X"], example["Y"]
    
def dataset_input_fn(filenames, batch_size, num_epochs, shuffle_buffer_size=10000):
    """Build the shuffled, batched tf.data pipeline over a set of TFRecord files.

    Arguments:
    filenames -- path(s) of the TFRecord files to read
    batch_size -- number of examples per batch
    num_epochs -- how many times to repeat the data; None repeats indefinitely
    shuffle_buffer_size -- number of parsed examples kept in the shuffle buffer.
        With the shapes produced by ``parser`` each example is about 2 MB of
        float32 data, so the default of 10000 can hold up to roughly 20 GB;
        pass a smaller value when memory is limited.

    Returns:
    dataset -- tf.data.Dataset yielding (X, Y) batches
    """
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    # Prepare the next batch while the current one is being consumed.
    dataset = dataset.prefetch(1)

    return dataset

In [35]:
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout, Input, Masking, TimeDistributed, LSTM, Conv1D
from tensorflow.keras.layers import GRU, Bidirectional, BatchNormalization, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

In [36]:
def seq_model(input_shape, n_classes):
    """
    Assemble the Conv1D -> GRU -> GRU -> per-time-step Dense network as a Keras model.

    Arguments:
    input_shape -- shape of one input example, without the batch axis (Keras convention)
    n_classes -- number of units of the final Dense layer applied at every time step

    Returns:
    model -- uncompiled Keras Model instance with sigmoid outputs

    NOTE(review): every Dropout layer is built with 0.8, which Keras treats as
    the fraction of units to drop -- confirm 0.8 is meant as a drop rate and
    not as a keep probability.
    """

    inputs = Input(shape=input_shape)

    # Convolutional front end: strided 1-D convolution, batch norm, ReLU, dropout.
    conv = Conv1D(196, kernel_size=15, strides=4)(inputs)
    conv = BatchNormalization()(conv)
    conv = Activation('relu')(conv)
    conv = Dropout(0.8)(conv)

    # First recurrent stage: GRU returning the full sequence, dropout, batch norm.
    rnn_first = GRU(units=128, return_sequences=True)(conv)
    rnn_first = Dropout(0.8)(rnn_first)
    rnn_first = BatchNormalization()(rnn_first)

    # Second recurrent stage: same as the first, with one extra dropout after batch norm.
    rnn_second = GRU(units=128, return_sequences=True)(rnn_first)
    rnn_second = Dropout(0.8)(rnn_second)
    rnn_second = BatchNormalization()(rnn_second)
    rnn_second = Dropout(0.8)(rnn_second)

    # Output head: the same sigmoid Dense layer applied independently at each time step.
    outputs = TimeDistributed(Dense(n_classes, activation="sigmoid"))(rnn_second)

    return Model(inputs=inputs, outputs=outputs)

In [37]:
# Tx, n_freq and num_classes are defined once in the configuration cell near the
# top of the notebook; they are reused here instead of being redefined so the
# parser and the model cannot silently disagree on shapes.
n_classes = num_classes  # alias kept for any cell that refers to n_classes
keras_model = seq_model((Tx, n_freq), n_classes)

In [38]:
# Print the layer-by-layer output shapes and parameter counts of the model.
keras_model.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 4797, 101)]       0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 1196, 196)         297136    
_________________________________________________________________
batch_normalization_6 (Batch (None, 1196, 196)         784       
_________________________________________________________________
activation_2 (Activation)    (None, 1196, 196)         0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 1196, 196)         0         
_________________________________________________________________
gru_4 (GRU)                  (None, 1196, 128)         124800    
_________________________________________________________________
dropout_9 (Dropout)          (None, 1196, 128)         0   

In [39]:
# Adam optimizer with explicit hyper-parameters.
opt = Adam(
    lr=0.0001,
    beta_1=0.9,
    beta_2=0.999,
    decay=0.01,
)

# NOTE(review): the loss is categorical cross-entropy while seq_model ends in a
# sigmoid layer -- confirm whether the labels are one-hot (softmax would be the
# usual pairing) or multi-label (binary_crossentropy would be the usual pairing).
keras_model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=["accuracy"],
)

In [41]:
# Resolve the TFRecord files in this cell so it does not depend on a later cell
# having been executed first (the notebook must survive Restart & Run All).
tfrecord_path = glob.glob('../data/interim/*.tfrecord')
training_set = dataset_input_fn(tfrecord_path, 64, None)

# The dataset is passed to fit() directly; Dataset.make_one_shot_iterator() is
# deprecated. Because the dataset repeats indefinitely (num_epochs=None),
# steps_per_epoch is what delimits an epoch.
history = keras_model.fit(
    training_set,
    steps_per_epoch=10,
    epochs=5,
    verbose=1,
)

Epoch 1/5
 2/10 [=====>........................] - ETA: 1:03 - loss: 0.2611 - acc: 0.0538

KeyboardInterrupt: 

In [12]:
# Convert the Keras model into an Estimator, with ../models/test/ passed as its model directory.
model = tf.keras.estimator.model_to_estimator(keras_model, model_dir="../models/test/")

In [13]:
# Sorted so the file order is the same on every run and every platform.
tfrecord_path = sorted(glob.glob('../data/interim/*.tfrecord'))
if not tfrecord_path:
    # Fail early with a clear message instead of an obscure error inside training.
    raise FileNotFoundError("No .tfrecord files found under ../data/interim/")


def train_input():
    """Input function for the Estimator: batches of 64, repeated indefinitely."""
    return dataset_input_fn(tfrecord_path, 64, None)


# The dataset never ends on its own, so `steps` is what stops training.
model.train(input_fn=train_input, steps=7000)

W0722 17:40:27.945183 4525540800 deprecation.py:323] From /Users/az01640/Projets/multrigger-word/.venv/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0722 17:40:27.994925 4525540800 deprecation.py:506] From /Users/az01640/Projets/multrigger-word/.venv/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0722 17:40:27.996093 4525540800 deprecation.py:506] From /Users/az01640/Projets/multrigger-word/.venv/lib/python3.7/site-pack

KeyboardInterrupt: 