In [1]:
import os

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, Bidirectional, LSTM, GRU, Dense
from tensorflow.keras.layers import Dropout, BatchNormalization, LeakyReLU, PReLU
from tensorflow.keras.layers import Input, MaxPooling2D, Reshape, MaxPool2D, Lambda, AveragePooling2D
from tensorflow.keras.optimizers import Adam

import pandas as pd

In [2]:
def create_character_mapping():

    character_map = {' ': 0}

    for i in range(97, 123):
        character_map[chr(i)] = len(character_map)

    return character_map

character_map = create_character_mapping()

In [7]:
print(character_map)

{' ': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}


In [13]:
print(char_map)
print(index_map)

{"'": 0, '<SPACE>': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, 'i': 10, 'j': 11, 'k': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27}
{1: "'", 2: ' ', 3: 'a', 4: 'b', 5: 'c', 6: 'd', 7: 'e', 8: 'f', 9: 'g', 10: 'h', 11: 'i', 12: 'j', 13: 'k', 14: 'l', 15: 'm', 16: 'n', 17: 'o', 18: 'p', 19: 'q', 20: 'r', 21: 's', 22: 't', 23: 'u', 24: 'v', 25: 'w', 26: 'x', 27: 'y', 28: 'z'}


In [17]:
import codecs
vocab_file = '../vocab.txt'
lines = []
with codecs.open(vocab_file, "r", "utf-8") as fin:
      lines.extend(fin.readlines())
        
print(lines)

["# List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which\n", '# will be ignored by the parser.\n', '# begin of vocabulary\n', ' \n', 'a\n', 'b\n', 'c\n', 'd\n', 'e\n', 'f\n', 'g\n', 'h\n', 'i\n', 'j\n', 'k\n', 'l\n', 'm\n', 'n\n', 'o\n', 'p\n', 'q\n', 'r\n', 's\n', 't\n', 'u\n', 'v\n', 'w\n', 'x\n', 'y\n', 'z\n', "'\n", '-\n', '# end of vocabulary']


In [27]:
token_to_index = {}
index_to_token = {}
index = 0
for line in lines:
    line = line[:-1]  # Strip the '\n' char.
    if line.startswith("#"):
        # Skip from reading comment line.
        continue
    token_to_index[line] = index
    index_to_token[index] = line
    index += 1
print(token_to_index)
print(index_to_token)

{' ': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, "'": 27, '-': 28}
{0: ' ', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 27: "'", 28: '-'}


ModuleNotFoundError: No module named 'utils'

In [2]:
def get_data_detail(meta):
    result = {}
    result['max_input_length'] = meta['spec_length'].max()
    result['max_label_length'] = meta['label_length'].max()
    result['num_samples'] = meta.shape[0]
    return result

In [15]:
#Each tf.train.Example record contains one or more "features", and the input pipeline typically converts these features into tensors.
def _parse_batch(record_batch, config):

    # Create a description of the features
    feature_description = {
        'feature': tf.io.FixedLenFeature([config['max_input_length'],40,1], tf.float32),
        'label': tf.io.FixedLenFeature([config['max_label_length']], tf.int64),
    }

    # Parse the input `tf.Example` proto using the dictionary above
    example = tf.io.parse_example(record_batch, feature_description)
    
    return example['feature'], example['label']

In [16]:
def get_dataset_from_tfrecords(config, tfrecords_dir='tfrecords' , split='train', batch_size=64, n_epochs=10):
    
    if split not in ('train', 'test', 'validate'):
        raise ValueError("split must be either 'train', 'test' or 'validate'")
    
    # List all *.tfrecord files for the selected split
    pattern = os.path.join(tfrecords_dir, '{}*.tfrecord'.format(split))
    files_ds = tf.data.Dataset.list_files(pattern)

    # Disregard data order in favor of reading speed
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    files_ds = files_ds.with_options(ignore_order)

    # Read TFRecord files in an interleaved order
    ds = tf.data.TFRecordDataset(files_ds,
                                 compression_type='ZLIB')
    # Prepare batches
    ds = ds.batch(batch_size)

    # Parse a batch into a dataset of [audio, label] pairs
    ds = ds.map(lambda x: _parse_batch(x, config))

    # Repeat the training data for n_epochs. Don't repeat test/validate splits.
    if split == 'train':
        ds = ds.repeat(n_epochs)

    return ds.prefetch(buffer_size=AUTOTUNE)

In [26]:
def ctc_loss_lambda_func(y_true, y_pred):
    """Function for computing the CTC loss"""

    if len(y_true.shape) > 2:
        y_true = tf.squeeze(y_true)

    input_length = tf.math.reduce_sum(y_pred, axis=-1, keepdims=False)
    input_length = tf.math.reduce_sum(input_length, axis=-1, keepdims=True)
    label_length = tf.math.count_nonzero(y_true, axis=-1, keepdims=True, dtype="int64")

    loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    loss = tf.reduce_mean(loss)

    return loss

def build_baseline_model(input_size, d_model, learning_rate=3e-4):

    input_data = Input(name="input", shape=input_size)

    conv_1 = Conv2D(32, (3,3), activation = 'relu', padding='same')(input_data)
    pool_1 = MaxPool2D(pool_size=(3, 2), strides=2)(conv_1)
    
    conv_2 = Conv2D(64, (3,3), activation = 'relu', padding='same')(pool_1)
    batch_norm_2 = BatchNormalization()(conv_2)
    
    conv_3 = Conv2D(64, (3,3), activation = 'relu', padding='same')(batch_norm_2)
    batch_norm_3 = BatchNormalization()(conv_3)
    pool_3 = MaxPool2D(pool_size=(1, 2))(batch_norm_3)
    
    shape = pool_3.get_shape()
    blstm = Reshape((shape[1], shape[2] * shape[3]))(pool_3)
    
    blstm = Bidirectional(LSTM(64, return_sequences=True, dropout = 0.5))(blstm)
    blstm = Dropout(rate=0.5)(blstm)
    output_data = Dense(d_model, activation = 'softmax')(blstm)

#     optimizer = RMSprop(learning_rate=learning_rate)
    optimizer = Adam(learning_rate=learning_rate)
    
    model = Model(inputs=input_data, outputs=output_data)
    model.compile(optimizer=optimizer, loss=ctc_loss_lambda_func)
    model.summary()
    return model

In [27]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

ROOT = '../../data/dev_clean_final/'
meta = pd.read_csv(os.path.join(ROOT,'metadata.csv'), index_col = 'index')

config = get_data_detail(meta)

num_rows = 40
num_columns = config['max_input_length']
num_label = 28

train_ds = get_dataset_from_tfrecords(config, tfrecords_dir=os.path.join(ROOT, 'TFrecords'), split='train')

model = build_baseline_model(input_size = (num_columns, num_rows, 1), d_model = num_label)



Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 1406, 40, 1)]     0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 1406, 40, 32)      320       
_________________________________________________________________
max_pooling2d_12 (MaxPooling (None, 702, 20, 32)       0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 702, 20, 64)       18496     
_________________________________________________________________
batch_normalization_12 (Batc (None, 702, 20, 64)       256       
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 702, 20, 64)       36928     
_________________________________________________________________
batch_normalization_13 (Batc (None, 702, 20, 64)       256 

In [28]:
model.fit(train_ds, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
 14/260 [>.............................] - ETA: 2:08 - loss: 242.6212

KeyboardInterrupt: 