### Imports

In [1]:
import os

import tensorflow as tf

from data_generator import DataGenerator
from model import ASRModel
from train_test_utils import model_evaluate, model_fit

In [2]:
# Record which TensorFlow release this notebook was executed with.
print(tf.version.VERSION)

2.0.0


In [3]:
# `tf.test.is_gpu_available` is deprecated in TF 2.x and spins up a session to probe devices.
# Listing the physical GPU devices is the lightweight replacement (available since TF 2.0).
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
len(gpu_devices) > 0

True

### Paths

In [4]:
# Root of the LibriSpeech subset; each split lives in a "<split>/<split>_all/" folder beneath it.
DATA_ROOT = "./LibriSpeech100"
train_path = DATA_ROOT + "/train/train_all/"
dev_path = DATA_ROOT + "/dev/dev_all/"
test_path = DATA_ROOT + "/test/test_all/"

### Create `DataGenerator` objects for the train, validation (dev), and test splits

In [5]:
# One generator per split; the dev split serves as the validation set during training.
train_data, val_data, test_data = (
    DataGenerator(split_path) for split_path in (train_path, dev_path, test_path)
)

### Build Model

In [6]:
# Per-frame input feature dimension; presumably must match what DataGenerator yields — TODO confirm.
NUM_FEATURES = 20

model = ASRModel()
# Batch size and sequence length are left unspecified (None) so variable-length inputs are accepted.
model.build(input_shape=(None, None, NUM_FEATURES))
optimizer = tf.keras.optimizers.Adam()  # default Adam hyper-parameters

In [7]:
# Print the layer-by-layer summary to stdout (print is Keras' default sink in TF 2.0).
model.summary(print_fn=print)

Model: "asr_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  multiple                  3280      
_________________________________________________________________
batch_normalization (BatchNo multiple                  80        
_________________________________________________________________
time_distributed (TimeDistri multiple                  630       
_________________________________________________________________
activation (Activation)      multiple                  0         
Total params: 3,990
Trainable params: 3,950
Non-trainable params: 40
_________________________________________________________________


### Checkpointing and Training

In [8]:
# Track optimizer state and model weights together so both can be restored later.
ckpt_dir = './training_checkpoints'
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
# Retain at most the two most recent checkpoints on disk.
manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=2)

In [9]:
# Number of training epochs.
EPOCHS = 20

# model_fit reports WER (word error rate) as its metric — see the "WER" columns in the training log —
# so the returned metric values are named accordingly (presumably per-epoch histories — TODO confirm).
train_losses, train_wers, val_losses, val_wers = model_fit(
    model, optimizer, train_data, manager, val_ds=val_data, epochs=EPOCHS
)

Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.
78
156
234
312
390
468
546
Epoch:  1  Loss: 796.766663  WER:  1.00022876686435
 Validation Loss: 445.175873  Validation WER:  1.008242400131115


### Restore Checkpoint and Evaluate on the Test Set (metric: word error rate, WER — lower is better)

In [10]:
# Create the output folder if needed; `open` would otherwise raise FileNotFoundError on a fresh checkout.
os.makedirs('outputs', exist_ok=True)
# Handle passed to model_evaluate below; close it once evaluation finishes so predictions are flushed.
save_file = open('outputs/predictions.txt', 'w', encoding='utf-8')

In [11]:
# Load the most recent checkpoint tracked by `manager` before testing.
# If none has been written, `latest_checkpoint` is None and `restore(None)` is effectively a no-op in eager mode.
ckpt.restore(manager.latest_checkpoint)

try:
    # `acc` is the test-set WER (word error rate) reported by model_evaluate, not an accuracy.
    _, acc = model_evaluate(model, test_data, test=True, save_file=save_file)
finally:
    # Always close the predictions file so buffered output is flushed, even if evaluation fails.
    # Re-run the previous cell to re-open it before running this cell again.
    save_file.close()

Test Loss: 476.435669  Test WER:  1.0050001220789566


In [12]:
# Despite its name, `acc` holds the test-set WER returned by model_evaluate.
test_wer = acc
print(test_wer)

1.0050001220789566
