# Demonstrating the Seq2Seq Wrapper on the CMUdict Dataset

The model is trained to map phoneme sequences (input) to word spellings (output).

In [1]:
import tensorflow as tf
import numpy as np

# preprocessed data
from datasets.cmudict import data, data_utils

In [2]:
# Load the preprocessed CMUdict artifacts from the pickle / .npy files under
# datasets/cmudict/. `data_ctl` is a lookup dict: the cells below read its
# 'idx2pho' and 'idx2alpha' maps. idx_words / idx_phonemes are presumably the
# index-encoded spellings and pronunciations.
data_ctl, idx_words, idx_phonemes = data.load_data(PATH='datasets/cmudict/')

# Phoneme sequences are the model input (X); letter spellings are the target (Y).
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(
    idx_phonemes,
    idx_words,
)

In [3]:
# Model hyper-parameters.
# Sequence lengths are the size of the last axis of the training arrays.
xseq_len = trainX.shape[-1]
yseq_len = trainY.shape[-1]
batch_size = 128
# Vocabulary size = number of entries in each index->symbol lookup.
# len() on a dict already counts its keys, so .keys() is unnecessary.
xvocab_size = len(data_ctl['idx2pho'])    # input side: phonemes
yvocab_size = len(data_ctl['idx2alpha'])  # output side: letters
emb_dim = 128

## Create an instance of the wrapper

In [4]:
import seq2seq_wrapper

In [5]:
# Instantiate the wrapper using the hyper-parameters from the cell above.
# As noted further down, the computational graph is built at construction
# time, so this cell must run before training or restoring a session.
model = seq2seq_wrapper.Seq2Seq(
    xseq_len=xseq_len,
    yseq_len=yseq_len,
    xvocab_size=xvocab_size,
    yvocab_size=yvocab_size,
    emb_dim=emb_dim,
    num_layers=3,  # depth handed to the wrapper; not defined with the other hyper-parameters
)

## Create data generators

See *data_utils.py* for details on how the random batches are generated.

In [6]:
# Random-batch generators (see data_utils.rand_batch_gen).
# Training batches use the `batch_size` hyper-parameter defined above instead of
# repeating the literal 128, so changing it in one place takes effect here too.
val_batch_size = 16  # validation batches: used during training and for the demo predictions below
val_batch_gen = data_utils.rand_batch_gen(validX, validY, val_batch_size)
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

- The computational graph was built when the model was instantiated.
- Now all we need to do is train the model on the processed CMUdict dataset, using the data generators.
- Internally, training runs a loop for *epochs* iterations.
- The model is evaluated periodically during training.

## Train

In [None]:
# Train the model. Per the notes above, training runs a loop for *epochs* iterations
# and evaluates periodically; val_batch_gen is presumably the evaluation source.
# NOTE: this cell was not executed in the saved run (In [None]). The returned `sess`
# is not used below; prediction uses `sess1` from the restore cell instead.
sess = model.train(train_batch_gen, val_batch_gen)

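## Restore the last saved session from disk

The training cell above was not executed in this saved run (`In [None]`), so the model is loaded from a previously saved session instead.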

In [7]:
# Reload the most recently saved session from disk, so prediction works even when the
# training cell above is skipped. `sess1` is the session the prediction cell uses.
sess1 = model.restore_last_session()

## Predict

In [8]:
# Predict on one random validation batch.
# next() is the idiomatic way to advance a generator (rather than calling the
# __next__ dunder directly). Element [0] of the batch is fed to the model:
# presumably the X (phoneme) side of an (X, Y) pair; see data_utils.rand_batch_gen.
output = model.predict(sess1, next(val_batch_gen)[0])
output.shape

(16, 16)


In [11]:
# Raw predictions: each row holds letter indices (decoded with idx2alpha below);
# the trailing zeros appear to be padding.
output

array([[ 2, 21,  6,  3,  9, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 4,  5, 16, 16, 18,  9, 20,  9, 14,  7,  0,  0,  0,  0,  0,  0],
       [18,  5, 16, 18,  1, 20,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [19,  8,  1, 16, 16,  1,  1, 12, 18,  0,  0,  0,  0,  0,  0,  0],
       [ 7,  1, 14,  9, 15, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [16, 15, 16, 12,  1,  3,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 6, 15, 18, 19, 12, 15, 14,  4,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 5, 14,  3, 15, 13,  2,  5, 18,  4,  0,  0,  0,  0,  0,  0,  0],
       [16, 18, 15,  6,  6,  9, 19,  5, 19,  0,  0,  0,  0,  0,  0,  0],
       [ 1, 12, 21, 19,  9, 15, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 7, 15, 12,  4, 19, 20, 15, 14,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 3, 15, 12, 12, 21, 19,  9, 14,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 3, 18,  5,  4,  5,  2, 12,  1,  0,  0,  0,  0,  0,  0,  0,  0],
       [18,  5,  3,  3, 14,  1, 20, 12,  5, 18,  0,

## Decode the predictions into words

The predicted spellings are plausible but often imperfect (e.g. `deppriting`, `shappaalr`).

In [14]:
# Map each row of predicted indices back to letters via the idx2alpha lookup
# and print one decoded word per line.
for predicted_ids in output:
    word = data_utils.decode_word(predicted_ids, data_ctl['idx2alpha'])
    print(word)

bufcin
deppriting
reprate
shappaalr
ganion
poplaca
forslond
encomberd
proffises
alusion
goldston
collusin
credebla
reccnatler
sandefr
prerine
