In [1]:
# Imports

import tensorflow as tf
from tensorflow.keras import datasets, models, layers
from sklearn.model_selection import train_test_split
import wandb
from wandb.keras import WandbCallback
from matplotlib import pyplot
import os
import numpy as np

In [2]:
# Setting up input flows

dataset_path = "../../dakshina_dataset_v1.0/ta/lexicons"
train_path = dataset_path + "/ta.translit.sampled.train.tsv"
test_path = dataset_path + "/ta.translit.sampled.test.tsv"


def _read_tokens(path):
    """Return every whitespace-separated token of the UTF-8 file at `path`.

    The file is opened in a `with` block so the handle is closed as soon as
    the contents have been read, instead of being left to the garbage
    collector.
    """
    with open(path, "rb") as handle:
        return handle.read().decode("utf-8").split()


# The token stream is consumed in groups of three: the first token of each
# group goes to y_*, the second to X_*, and the third is parsed as an int
# into z_*.
# NOTE(review): this assumes every row holds exactly three whitespace-free
# fields; a malformed row would shift all following groups -- confirm against
# the dataset.
train_data = _read_tokens(train_path)
y_train, X_train, z_train = train_data[::3], train_data[1::3], [int(each) for each in train_data[2::3]]

test_data = _read_tokens(test_path)
y_test, X_test, z_test = test_data[::3], test_data[1::3], [int(each) for each in test_data[2::3]]

In [3]:
# Setting up the vocabulary

# Gather every distinct character seen on each side, over both the training
# and the test words, so that any word from either split can be encoded.
input_vocab = {char for word in X_train for char in word}
input_vocab.update(char for word in X_test for char in word)

output_vocab = {char for word in y_train for char in word}
output_vocab.update(char for word in y_test for char in word)

# The empty string is added as an extra symbol; onehot() uses it to fill the
# positions after the end of a word.
input_vocab.add("")
output_vocab.add("")

# Freeze each vocabulary into a sorted list so that indices are deterministic.
input_vocab = sorted(input_vocab)
output_vocab = sorted(output_vocab)
input_v_len, output_v_len = len(input_vocab), len(output_vocab)

# Character -> index lookup tables.
input_inv = {char: index for index, char in enumerate(input_vocab)}
output_inv = {char: index for index, char in enumerate(output_vocab)}

In [4]:
def onehot(X, y):
    """One-hot encode paired input/output words for teacher-forced training.

    Relies on the module-level lookup tables `input_inv` / `output_inv` and
    the vocabulary sizes `input_v_len` / `output_v_len`. Positions past the
    end of a word are filled with the one-hot vector of the padding symbol
    `""` of the corresponding vocabulary.

    Args:
        X: sequence of input words (strings over the input vocabulary).
        y: sequence of output words (strings over the output vocabulary),
            paired with `X` by position.

    Returns:
        A tuple `(encoder_input_data, decoder_input_data, decoder_target_data)`
        of float32 arrays with shapes
        `(len(X), max_input_len, input_v_len)`,
        `(len(X), max_output_len, output_v_len)` and
        `(len(X), max_output_len, output_v_len)`.
        `decoder_target_data` is `decoder_input_data` shifted one step to the
        left: target step `t` holds the character at input step `t + 1`.

    Raises:
        KeyError: if a word contains a character missing from the
            corresponding lookup table.
    """
    max_input_len = max((len(word) for word in X), default=0)
    max_output_len = max((len(word) for word in y), default=0)

    encoder_input_data = np.zeros((len(X), max_input_len, input_v_len), dtype="float32")
    decoder_input_data = np.zeros((len(X), max_output_len, output_v_len), dtype="float32")
    decoder_target_data = np.zeros((len(X), max_output_len, output_v_len), dtype="float32")

    # Each side is padded with the padding index of its OWN vocabulary.
    input_pad = input_inv[""]
    output_pad = output_inv[""]

    for i, (source, target) in enumerate(zip(X, y)):
        for t, char in enumerate(source):
            encoder_input_data[i, t, input_inv[char]] = 1.0
        # Slice from the explicit word length rather than a leftover loop
        # variable, so empty words are padded correctly as well.
        encoder_input_data[i, len(source):, input_pad] = 1.0

        for t, char in enumerate(target):
            decoder_input_data[i, t, output_inv[char]] = 1.0
            if t > 0:
                # The target is one step ahead of the decoder input.
                decoder_target_data[i, t - 1, output_inv[char]] = 1.0
        decoder_input_data[i, len(target):, output_pad] = 1.0
        # The shifted target has one fewer real step than the input; clamp at
        # zero so an empty word pads the whole row instead of indexing with -1.
        decoder_target_data[i, max(len(target) - 1, 0):, output_pad] = 1.0

    return encoder_input_data, decoder_input_data, decoder_target_data

In [22]:
def myRNN(latent_dim):
    """Build an encoder-decoder model out of two SimpleRNN layers.

    The encoder reads one-hot sequences of width `input_v_len` and only its
    returned state is kept. That state is passed as the initial state of the
    decoder, which reads one-hot sequences of width `output_v_len` and emits
    one output per timestep; a softmax Dense layer maps each of those outputs
    to `output_v_len` values.

    Args:
        latent_dim: number of units in both SimpleRNN layers.

    Returns:
        An uncompiled `tf.keras.Model` taking `[encoder_inputs,
        decoder_inputs]` and producing the per-timestep softmax output.
    """
    # Encoder: discard the layer output, keep the returned state.
    source_onehot = tf.keras.Input(shape=(None, input_v_len))
    encoder_rnn = layers.SimpleRNN(latent_dim, return_state=True)
    final_state = encoder_rnn(source_onehot)[1]

    # Decoder: started from the encoder state, returns the full sequence.
    target_onehot = tf.keras.Input(shape=(None, output_v_len))
    decoder_rnn = layers.SimpleRNN(latent_dim, return_sequences=True, return_state=True)
    hidden_sequence = decoder_rnn(target_onehot, initial_state=[final_state])[0]

    # Per-timestep softmax over the output vocabulary.
    char_probs = layers.Dense(output_v_len, activation="softmax")(hidden_sequence)

    return tf.keras.Model([source_onehot, target_onehot], char_probs)

In [23]:
# Instantiate the encoder-decoder network with 128 units per SimpleRNN layer.
model = myRNN(latent_dim=128)

In [24]:
# Print the layer-by-layer summary (output shapes, parameter counts,
# connections) of `model`.
model.summary()

Model: "functional_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, None, 27)]   0                                            
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, None, 47)]   0                                            
__________________________________________________________________________________________________
simple_rnn_10 (SimpleRNN)       [(None, 128), (None, 19968       input_11[0][0]                   
__________________________________________________________________________________________________
simple_rnn_11 (SimpleRNN)       [(None, None, 128),  22528       input_12[0][0]                   
                                                                 simple_rnn_10[0][1]   

In [25]:
def train(batch_size=64, epochs=2):
    """Compile the module-level `model` and fit it on the training words.

    The model is compiled with the "nadam" optimizer, categorical
    cross-entropy loss and an accuracy metric. `X_train` / `y_train` are
    one-hot encoded with `onehot()` on every call, and 10% of the samples are
    held out by Keras as validation data.

    Args:
        batch_size: number of samples per gradient update.
        epochs: number of passes over the training data.
    """
    model.compile(
        optimizer="nadam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    encoder_in, decoder_in, decoder_target = onehot(X_train, y_train)
    model.fit(
        [encoder_in, decoder_in],
        decoder_target,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
    )