In [1]:
!pip3 install tensorflow_datasets dataclasses typing_extensions importlib_resources zipp tensorflow_metadata dill promise --no-deps

Collecting tensorflow_datasets
  Using cached https://files.pythonhosted.org/packages/93/83/85f14bcf27df5ae23502803502f8506eefec18a285fea909aa67dc9b736e/tensorflow_datasets-4.4.0-py3-none-any.whl
Collecting dataclasses
  Using cached https://files.pythonhosted.org/packages/fe/ca/75fac5856ab5cfa51bbbcefa250182e50441074fdc3f803f6e76451fab43/dataclasses-0.8-py3-none-any.whl
Collecting typing_extensions
  Using cached https://files.pythonhosted.org/packages/05/e4/baf0031e39cf545f0c9edd5b1a2ea12609b7fcba2d58e118b11753d68cf0/typing_extensions-4.0.1-py3-none-any.whl
Collecting importlib_resources
  Using cached https://files.pythonhosted.org/packages/24/1b/33e489669a94da3ef4562938cd306e8fa915e13939d7b8277cb5569cb405/importlib_resources-5.4.0-py3-none-any.whl
Collecting zipp
  Using cached https://files.pythonhosted.org/packages/bd/df/d4a4974a3e3957fd1c1fa3082366d7fff6e428ddb55f074bf64876f8e8ad/zipp-3.6.0-py3-none-any.whl
Collecting tensorflow_metadata
  Using cached https://files.pythonhosted

In [9]:
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import tensorflow as tf
import os
import json
import numpy as np
from sklearn.model_selection import train_test_split
import time

In [2]:
datasets_dir = '/home/andysilv/yandexsdc/seminars/attention/datasets'
if not os.path.exists(datasets_dir):
    os.makedirs(datasets_dir)

In [3]:
EN_TOKENIZER_PATH = 'en_tokenizer.json'
RU_TOKENIZER_PATH = 'ru_tokenizer.json'
NUM_WORDS = 30000
NO_CACHED_TOKENIZER = False
MAX_LENGTH = 50 + 2


def init_tokenizer(tokenizer_path, texts):
    if not os.path.exists(tokenizer_path) or NO_CACHED_TOKENIZER:
        print('initializing tokenizer and storing it to', tokenizer_path)
        tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=NUM_WORDS,
            filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n',
            lower=True, split=' ', char_level=False, oov_token='<UNK>',
            document_count=0
        )

        tokenizer.fit_on_texts(texts)
        with open(tokenizer_path, 'w') as f:
            f.write(tokenizer.to_json())
    else:
        print('loading tokenizer from', tokenizer_path)
        with open(tokenizer_path, 'r') as f:
            tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(f.read())
    return tokenizer


def load_datasets():
    config = tfds.translate.wmt.WmtConfig(
        version="0.0.1",
        language_pair=("ru", "en"),
        subsets={
            tfds.Split.TRAIN: ["yandexcorpus"],
            tfds.Split.VALIDATION: ["newstest2012", 'newstest2013', 'newstest2014', 'newstest2015', 'newstest2016', 'newstest2017'],
        },
    )
    builder = tfds.builder("wmt_translate", config=config)

    download_config = tfds.download.DownloadConfig(manual_dir=datasets_dir, extract_dir=datasets_dir)
    builder.download_and_prepare(download_config=download_config, download_dir=datasets_dir)
    return builder.as_dataset(as_supervised=True)


def preprocess_sentence(s):
    return '<start> ' + s.decode('utf-8') + ' <end>'


def build_tokenizers(datasets):
    ru_texts, en_texts = zip(*[
        (preprocess_sentence(ru), preprocess_sentence(en))
        for ru, en in datasets['train'].as_numpy_iterator()])
    en_tokenizer = init_tokenizer(EN_TOKENIZER_PATH, en_texts)
    ru_tokenizer = init_tokenizer(RU_TOKENIZER_PATH, ru_texts)
    
    return ru_tokenizer, en_tokenizer, ru_texts, en_texts



def max_length(tensor):
    return max(len(t) for t in tensor)

In [4]:
datasets = load_datasets()



In [5]:
ru_tokenizer, en_tokenizer, ru_texts, en_texts = build_tokenizers(datasets)

loading tokenizer from en_tokenizer.json
loading tokenizer from ru_tokenizer.json


In [6]:
input_tensor = ru_tokenizer.texts_to_sequences(ru_texts)
target_tensor = en_tokenizer.texts_to_sequences(en_texts)

In [7]:
input_tensor, target_tensor = zip(*[(ru, en) for ru, en in zip(input_tensor, target_tensor)
                                    if len(ru) <= MAX_LENGTH and len(en) <= MAX_LENGTH])
max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)

In [8]:
max_length_inp, max_length_tar

(52, 52)

In [9]:
input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, maxlen=max_length_inp, padding='post', value=0)
target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, maxlen=max_length_tar, padding='post', value=0)

In [10]:
# Creating training and validation sets using an 80-20 split
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)

# Show length
len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)

(764070, 764070, 191018, 191018)

In [11]:
BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
N_BATCH = BUFFER_SIZE // BATCH_SIZE
WORDS_EMBEDDING_SIZE = 256 # 256
HIDDEN_STATES = 1000 # 1024
vocab_inp_size = ru_tokenizer.num_words + 1  # pad
vocab_tar_size = en_tokenizer.num_words + 1  # pad

dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

In [12]:
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            enc_units, return_sequences=True, 
            return_state=True, 
            recurrent_activation='sigmoid', 
            recurrent_initializer='glorot_uniform')
        
    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)        
        return output, state
    
    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))

In [13]:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            dec_units, return_sequences=True, 
            return_state=True, 
            recurrent_activation='sigmoid', 
            recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        
        # used for attention
        self.W1 = tf.keras.layers.Dense(self.dec_units)
        self.W2 = tf.keras.layers.Dense(self.dec_units)
        self.V = tf.keras.layers.Dense(1)
        
    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        
        # hidden shape == (batch_size, hidden size)
        # hidden_with_time_axis shape == (batch_size, 1, hidden size)
        # we are doing this to perform addition to calculate the score
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        
        # score shape == (batch_size, max_length, hidden_size)
        score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))
        
        # attention_weights shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying score to self.V
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        
        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        
        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)
        
        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        
        # passing the concatenated vector to the GRU
        output, state = self.gru(x)
        
        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))
        
        # output shape == (batch_size * 1, vocab)
        x = self.fc(output)
        
        return x, state, attention_weights
        
    def initialize_hidden_state(self, batch_sz):
        return tf.zeros((self.batch_sz, self.dec_units))

In [14]:
encoder = Encoder(vocab_inp_size, WORDS_EMBEDDING_SIZE, HIDDEN_STATES, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, WORDS_EMBEDDING_SIZE, HIDDEN_STATES, BATCH_SIZE)

In [15]:
optimizer = tf.keras.optimizers.Adam()


def loss_function(real, pred):
    # ignore paddings in loss, which have 0 index
    mask = 1 - np.equal(real, 0)
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask
    return tf.reduce_mean(loss_)



In [16]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [None]:
encoder.load_weights(checkpoint_path)


In [None]:
EPOCHS = 10

encoder.batch_sz = BATCH_SIZE
decoder.batch_sz = BATCH_SIZE

for epoch in range(EPOCHS):
    start = time.time()
    
    hidden = encoder.initialize_hidden_state()
    total_loss = 0
    
    for (batch, (inp, targ)) in enumerate(dataset):
        loss = 0
        
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp, hidden)
            
            dec_hidden = enc_hidden
            
            # dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1)       
            dec_input = tf.expand_dims(targ[:, 0], 1)
            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                
                loss += loss_function(targ[:, t], predictions)
                
                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)
        
        batch_loss = (loss / int(targ.shape[1]))
        
        total_loss += batch_loss
        
        variables = encoder.variables + decoder.variables
        
        gradients = tape.gradient(loss, variables)
        
        optimizer.apply_gradients(zip(gradients, variables))
        
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))
    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)
    
    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss / N_BATCH))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))



Epoch 1 Batch 0 Loss 2.4536
Epoch 1 Batch 100 Loss 2.6749
Epoch 1 Batch 200 Loss 2.3028
Epoch 1 Batch 300 Loss 2.4789
Epoch 1 Batch 400 Loss 2.4676
Epoch 1 Batch 500 Loss 2.2050
Epoch 1 Batch 600 Loss 2.5052
Epoch 1 Batch 700 Loss 2.4129
Epoch 1 Batch 800 Loss 2.2379
Epoch 1 Batch 900 Loss 2.4893
Epoch 1 Batch 1000 Loss 2.3665
Epoch 1 Batch 1100 Loss 2.3412
Epoch 1 Batch 1200 Loss 2.0012
Epoch 1 Batch 1300 Loss 2.2255
Epoch 1 Batch 1400 Loss 2.5297
Epoch 1 Batch 1500 Loss 2.0364
Epoch 1 Batch 1600 Loss 2.1950
Epoch 1 Batch 1700 Loss 2.1993
Epoch 1 Batch 1800 Loss 2.2462
Epoch 1 Batch 1900 Loss 2.0792
Epoch 1 Batch 2000 Loss 1.8850
Epoch 1 Batch 2100 Loss 1.9330
Epoch 1 Batch 2200 Loss 2.2040
Epoch 1 Batch 2300 Loss 1.9729
Epoch 1 Batch 2400 Loss 2.0645
Epoch 1 Batch 2500 Loss 2.3787
Epoch 1 Batch 2600 Loss 1.8867
Epoch 1 Batch 2700 Loss 1.8932
Epoch 1 Batch 2800 Loss 2.3298
Epoch 1 Batch 2900 Loss 1.8589
Epoch 1 Batch 3000 Loss 1.9070
Epoch 1 Batch 3100 Loss 1.8551
Epoch 1 Batch 3200 L

In [81]:
def translate(text):
    text = preprocess_sentence(text)
    
    encoder.batch_sz = 1
    decoder.batch_sz = 1
    hidden = encoder.initialize_hidden_state()
    input_tensor = np.array(ru_tokenizer.texts_to_sequences([text]))
    enc_output, enc_hidden = encoder(input_tensor, hidden)
    
    dec_hidden = enc_hidden
    dec_input = np.array([[en_tokenizer.word_index['<start>']]])
    
    beam_width = 3
    max_len = 50
    initial_result = {
        "result": [en_tokenizer.word_index['<start>']],
        "log_prob": 0
    }
    
    results = [initial_result]
    current_states = [dict(
        dec_hidden=enc_hidden,
        dec_input=dec_input,
        **initial_result
    )]
    
    for t in range(1, max_len):
        new_current_states = []
        for current_state in current_states:
            # passing enc_output to the decoder
            predictions, dec_hidden, _ = decoder(
                current_state['dec_input'], current_state['dec_hidden'], enc_output)
            max_prob_inds = np.argpartition(predictions, -beam_width)[0, -beam_width:]
            for ind in max_prob_inds:
                probs = np.exp(predictions[0]) / np.sum(np.exp(predictions[0]))
                res = {
                    'log_prob': current_state['log_prob'] + np.log(probs[ind]),
                    'result': current_state['result'] + [ind]
                }
                results.append(res)
                
                if ind == en_tokenizer.word_index['<end>']:
                    continue
                new_current_states.append(dict(
                    dec_hidden=dec_hidden,
                    dec_input=np.array([[ind]]),
                    **res
                ))
        new_current_states.sort(key=lambda x: x['log_prob'])    
        new_current_states = new_current_states[-beam_width:]
        current_states = new_current_states
    
    for r in results:
        r['normalized_log_prob'] = r['log_prob'] / len(r['result'])
    
    results.sort(key=lambda k: -r['normalized_log_prob'])
    return results

In [91]:
results = translate('Выпить чай'.encode('utf-8'))
for r in results:
    if r['result'][-1] != en_tokenizer.word_index['<end>']:
        continue
    print(en_tokenizer.sequences_to_texts([r['result']]))

['<start> drink tea <end>']
['<start> a free tea <end>']
['<start> a free tea cup <end>']
['<start> a free tea with <end>']
['<start> a free tea cup of <end>']
['<start> we can a free tea <end>']
['<start> we can a free tea cup <end>']
['<start> we can a free tea with <end>']
['<start> a free tea cup of tea <end>']
['<start> we can a free tea cup of <end>']
['<start> a free tea cup of tea party <end>']
['<start> we can a free tea with tea party <end>']
['<start> we can a free tea cup of tea <end>']
['<start> a free tea cup of tea party to <UNK> <end>']
['<start> we can a free tea cup of tea party <end>']
['<start> a free tea cup of tea party to our tea is a cup <end>']
['<start> a free tea cup of tea party to our tea with a cup <end>']
['<start> a free tea cup of tea party to our tea with a cup to <end>']
['<start> a free tea cup of tea party to our tea with a cup of <end>']
['<start> a free tea cup of tea party to our tea is a cup of tea <end>']
['<start> a free tea cup of tea party t

In [None]:
class RNNSearchTrainer(object):
    EN_TOKENIZER_PATH = 'en_tokenizer.json'
    RU_TOKENIZER_PATH = 'ru_tokenizer.json'
    NO_CACHED_TOKENIZER = True
    
    
    def __init__(self):
        config = tfds.translate.wmt.WmtConfig(
            version="0.0.1",
            language_pair=("ru", "en"),
            subsets={
                tfds.Split.TRAIN: ["yandexcorpus"],
                tfds.Split.VALIDATION: ["newstest2012", 'newstest2013', 'newstest2014', 'newstest2015', 'newstest2016', 'newstest2017'],
            },
        )
        builder = tfds.builder("wmt_translate", config=config)
        
        download_config = tfds.download.DownloadConfig(manual_dir=datasets_dir, extract_dir=datasets_dir)
        builder.download_and_prepare(download_config=download_config, download_dir=datasets_dir)
        
        self._datasets = builder.as_dataset(as_supervised=True)
        train_examples = self._datasets['train']
        val_examples = self._datasets['validation']
        
        
        ru_texts, en_texts = zip(*[
                (self._preprocess_sentence(ru), self._preprocess_sentence(en))
                for ru, en in train_examples.as_numpy_iterator()])
        
        self._en_tokenizer = self._init_tokenizer(self.EN_TOKENIZER_PATH, en_texts)
        self._ru_tokenizer = self._init_tokenizer(self.RU_TOKENIZER_PATH, ru_texts)
        
    def preprocess_sentence(self, s):
        return '<start> ' + s.decode('utf-8') + ' <end>'
        
    def _init_tokenizer(self, tokenizer_path, texts):
        if not os.path.exists(tokenizer_path) or self.NO_CACHED_TOKENIZER:
            tokenizer = tf.keras.preprocessing.text.Tokenizer(
                num_words=RnnSearchModel.NUM_WORDS,
                filters='!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n',
                lower=True, split=' ', char_level=False, oov_token='<UNK>',
                document_count=0
            )

            tokenizer.fit_on_texts(texts)
            with open(tokenizer_path, 'w') as f:
                f.write(tokenizer.to_json())
        else:
            with open(tokenizer_path, 'r') as f:
                tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(f.read())
        return tokenizer
    
    def get_dataset(self):
        input_tensor, output_tensor = zip(*[
            (self.preprocess_sentence(ru), self.preprocess_sentence(en))
            for ru, en in dataset.as_numpy_iterator()])
        
    
    def train(self):
        rnn_search = RnnSearchModel()
        
        dataset = self._datasets['train'].shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
        dataset = dataset.batch(batch_size=80)     # batch_size=1 if you want to get only one element per step
        for e in range(epoch):
            print('Epoch', e)
            for num_batch, (ru, en) in enumerate(dataset.as_numpy_iterator()):
                tokenized_input_sentences = self._en_tokenizer.texts_to_sequences(
                    [self._preprocess_sentence(x) for x in en])
                tokenized_output_sentences = self._ru_tokenizer.texts_to_sequences(
                    [self._preprocess_sentence(x) for x in ru])
                
                rnn_search((tokenized_input_sentences, tokenized_output_sentences), True)

In [None]:
trainer = RNNSearchTrainer()

In [79]:
dataset = trainer._datasets['train']

In [66]:
ru_texts, en_texts = zip(*[
                (trainer.preprocess_sentence(ru), trainer.preprocess_sentence(en))
                for ru, en in dataset.as_numpy_iterator()])

In [70]:
for ru, en in dataset.as_numpy_iterator():
    print(len(trainer._ru_tokenizer.texts_to_sequences([trainer._preprocess_sentence(ru)])[0]))
    break

32


In [47]:
trainer._en_tokenizer.texts_to_sequences(['Some strange'])

[[74, 3075]]

In [80]:
ru_lengths, en_lengths = zip(*[(len(trainer._ru_tokenizer.texts_to_sequences([trainer._preprocess_sentence(ru)])[0]), 
                                len(trainer._en_tokenizer.texts_to_sequences([trainer._preprocess_sentence(en)])[0])) 
                               for ru, en in dataset.as_numpy_iterator()])

In [81]:
max(ru_lengths)

117

In [21]:

for (ru, en) in trainer._datasets['train'].as_numpy_iterator():
    print(trainer._preprocess_sentence(en))
    break

<start> The author of the catalogue also draws the attention of the readers to differences in the presentation of the material received, depending on whether they came from an archaeological museum or from an institution incidentally in the possession of such a type of collection. <end>


In [None]:
@tf.function
def f(x, y):
    

In [107]:
epoch = 10
dataset = train_examples.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=80)     # batch_size=1 if you want to get only one element per step

num_batch = 0


In [164]:
for (ru, en) in train_examples.as_numpy_iterator():
    x = en_tokenizer.texts_to_sequences([en.decode()])
    x[0] = [0] * 4 + x[0]
    break

In [165]:
print(x)

[[0, 0, 0, 0, 1, 1490, 2, 1, 7170, 46, 6862, 1, 655, 2, 1, 3634, 4, 2118, 5, 1, 2051, 2, 1, 610, 615, 1715, 10, 425, 35, 478, 20, 27, 8486, 1234, 19, 20, 27, 2278, 8785, 5, 1, 4713, 2, 59, 6, 355, 2, 1210]]


In [108]:
lengths = [len(s) for s in x]
max_len = max(lengths)

In [109]:
encoding.shape

(80, 14, 30000)

In [132]:
for i in range(len(x)):
    x[i] = [0] * (max_len - len(x[i])) + x[i]
encoding = tf.one_hot(x, 30000).numpy()
for i in range(len(x)):
    encoding[i][:max_len - lengths[i]] *= 0

In [135]:
next(i for i, z in enumerate(encoding[0,61]) if z)

6962

In [122]:
max_len

62

In [131]:
tf.one_hot?

In [104]:
en_tokenizer.word_index

{'the': 1,
 'of': 2,
 'and': 3,
 'to': 4,
 'in': 5,
 'a': 6,
 'is': 7,
 'for': 8,
 'that': 9,
 'on': 10,
 'with': 11,
 'as': 12,
 'it': 13,
 'are': 14,
 'be': 15,
 'by': 16,
 'you': 17,
 'this': 18,
 'or': 19,
 'from': 20,
 'was': 21,
 'not': 22,
 'at': 23,
 'have': 24,
 'will': 25,
 'i': 26,
 'an': 27,
 'we': 28,
 'all': 29,
 'which': 30,
 'can': 31,
 'has': 32,
 'but': 33,
 'their': 34,
 'they': 35,
 'he': 36,
 'your': 37,
 'one': 38,
 'if': 39,
 'his': 40,
 'its': 41,
 'more': 42,
 'other': 43,
 'there': 44,
 'were': 45,
 'also': 46,
 'our': 47,
 'when': 48,
 'new': 49,
 'who': 50,
 'been': 51,
 'about': 52,
 'time': 53,
 'only': 54,
 'had': 55,
 'so': 56,
 'these': 57,
 'up': 58,
 'such': 59,
 '1': 60,
 'any': 61,
 'no': 62,
 'people': 63,
 'into': 64,
 'them': 65,
 'would': 66,
 'may': 67,
 'than': 68,
 'do': 69,
 'out': 70,
 'some': 71,
 'what': 72,
 'first': 73,
 'use': 74,
 '2': 75,
 'should': 76,
 'most': 77,
 'information': 78,
 'world': 79,
 'after': 80,
 'well': 81,
 'inter

In [172]:
encoding.reshape()

(80, 62, 30000)

In [175]:
tf.transpose(encoding, [1, 0, 2]).shape

TensorShape([62, 80, 30000])

In [5]:
class RnnSearchModel(tf.keras.Model):
    NUM_WORDS = 30000
    HIDDEN_STATES = 1000
    WORDS_EMBEDDING_SIZE = 620
    MAXOUT_HIDDEN_LAYER_SIZE = 500
    ALIGNMENT_MODEL_HIDDEN_UNITS = 1000
    MAX_OUTPUT_SENTENCE_LEN = 50
    

    def __init__(self, **kwargs):
        super(RnnSearchModel, self).__init__(**kwargs)
        self._forward_enc_weights = self._init_encoder_weights()
        self._backward_enc_weights = self._init_encoder_weights()
        self._encoder_embedding_layer = tf.keras.layers.Embedding(self.NUM_WORDS, self.WORDS_EMBEDDING_SIZE)
        self._decoder_embedding_layer = tf.keras.layers.Embedding(self.NUM_WORDS, self.WORDS_EMBEDDING_SIZE)
        self._dec_weights = self._init_decoder_weights()
        
        
    def _init_encoder_weights():
        ortho_initializer = tf.keras.initializers.Orthogonal()
        normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=0.01)
        
        use_bias = True
        return {
            'W': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'Wr': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'Wz': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'U':  tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Ur': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Uz': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
        }
    
    def _init_decoder_weights(self):
        ortho_initializer = tf.keras.initializers.Orthogonal()
        normal_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=0.01)
        normal_initializer_001 = tf.keras.initializers.RandomNormal(mean=0., stddev=0.001)
        use_bias = True
        return {
            'va': tf.keras.layers.Dense(units=1, use_bias=use_bias, kernel_initializer=tf.keras.initializers.Zeros()),
            'Wa': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer_001),
            'Ua': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer_001),
            'Ws': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'W': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'Wr': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'Wz': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=normal_initializer),
            'U':  tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Ur': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Uz': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'C': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Cz': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'Cr': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'U0': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'V0': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'C0': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
            'W0': tf.keras.layers.Dense(units=self.HIDDEN_STATES, use_bias=use_bias, kernel_initializer=ortho_initializer),
        }
        
    
    def _pad_and_one_hot_encode(self, x):
        batch_size = len(x)
        lengths = [len(s) for s in x]
        max_len = max(lengths)
        # pad all input sequences with leading zeros to make their lengths equal to max_len
        for i in range(batch_size):
            x[i] = x[i] + [0] * (max_len - lengths[i])

        x = tf.one_hot(x, self.NUM_WORDS)
        return x
    
    def _get_encoder_states(self, x, weights):
        prev_h = tf.zeros(self.HIDDEN_STATES)
        hidden_states = []
        for x_i in x:
            ex_i = self._encoder_embedding_layer(x_i)
            r = tf.sigmoid(weights['Wr'](ex_i) + weights['Ur'](prev_h))
            z = tf.sigmoid(weights['Wz'](ex_i) + weights['Uz'](prev_h))
            cur_h_tmp = tf.keras.activations.tanh(weights['W'](ex_i) + weights['U'](r * prev_h))
            cur_h = (1 - z) * prev_h + z * cur_h_tmp
            hidden_states.append(cur_h)
            prev_h = cur_h
        return hidden_states
    
    def _get_predictions(self, weights, h):
        prev_s = tf.tanh(self._dec_weights['Ws'](h[0, self.HIDDEN_STATES:]))
        probs = []
        for i in range(self.MAX_OUTPUT_SENTENCE_LEN):
            e_i = self._dec_weights['va'](tf.tanh(self._dec_weights['Wa'](prev_s) + self._dec_weights['Ua'](h)))
            alpha_i = tf.keras.layers.Softmax()(e_i)
            c = tf.tensordot(alpha_i, h, 1)
            ey_i = self._decoder_embedding_layer(y_i)
            r = tf.sigmoid(weights['Wr'](ey_i) + weights['Ur'](prev_s) + weights['Cr'](c))
            z = tf.sigmoid(weights['Wz'](ey_i) + weights['Uz'](prev_s) + weights['Cz'](c))
            cur_s_tmp = tf.keras.activations.tanh(weights['W'](ey_i) + weights['U'](r * prev_h) + weights['Cr'](c))
            cur_s = (1 - z) * prev_s + z * cur_s_tmp
            t = weights['U0'](prev_s) + weights['V0'](ey_i) + weights['C0'](c)
            t = tfa.layers.Maxout(num_units=self.MAXOUT_HIDDEN_LAYER_SIZE)(t)
            probs_i = weights['W0'](t)
            prev_s = cur_s
            probs.append(probs[i])
        return tf.stack(probs)

        
        
    def call(self, inputs, training=False):
        x = self._pad_and_one_hot_encode(inputs)
        
        # reshape x to be    max_len x VocabSize x batch_size
        x = tf.transpose(encoding, [1, 0, 2])
        h_forward = self._get_encoder_states(x, self._forward_enc_weights)
        h_backward = self._get_encoder_states(x, self._backward_enc_weights)
        h = tf.concat([h_forward, h_backward], axis=1, name='h')  # shape must become max_len x batch_size x 2 * HIDDEN_STATES
        y_predicted = self._get_predictions(self._dec_weights, h)
        return y_predicted
        
        