
# Libraries

In [None]:
!pip install trax
import random
import numpy as np

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

!pip list | grep trax

# Data

In [None]:
# Both splits read the same ParaCrawl en-de dataset; only the `train` flag
# differs. On first use TFDS downloads and prepares the data under `data_dir`.
train_stream_fn = trax.data.TFDS('para_crawl/ende',
                                 data_dir='./data/',
                                 keys=('en', 'de'),
                                 eval_holdout_size=0.01, # 1% for eval
                                 train=True)

eval_stream_fn = trax.data.TFDS('para_crawl/ende',
                                data_dir='./data/',
                                keys=('en', 'de'),
                                eval_holdout_size=0.01, # 1% for eval
                                train=False)

# TFDS returns a stream *factory*, not a stream. Call each factory once to get
# the generators of (en, de) pairs that the tokenization cell consumes as
# `train_stream` / `eval_stream`; without this the notebook fails on a fresh
# kernel with a NameError.
train_stream = train_stream_fn()
eval_stream = eval_stream_fn()

[1mDownloading and preparing dataset 1.22 GiB (download: 1.22 GiB, generated: 4.04 GiB, total: 5.26 GiB) to ./data/para_crawl/ende/1.2.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…






HBox(children=(FloatProgress(value=0.0, description='Generating splits...', max=1.0, style=ProgressStyle(descr…

HBox(children=(FloatProgress(value=0.0, description='Generating train examples...', max=16264448.0, style=Prog…

HBox(children=(FloatProgress(value=0.0, description='Shuffling para_crawl-train.tfrecord...', max=16264448.0, …

[1mDataset para_crawl downloaded and prepared to ./data/para_crawl/ende/1.2.0. Subsequent calls will reuse this data.[0m


# Tokenization

In [None]:
# Vocabulary location; also passed to the decoding helpers further down.
VOCAB_FILE = 'vocab.subword'
VOCAB_DIR = '/content/'

# Build the tokenizing pipeline step once, then apply it to each split.
tokenize_pairs = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)
tokenized_train_stream = tokenize_pairs(train_stream)
tokenized_eval_stream = tokenize_pairs(eval_stream)

EOS = 1  # token id appended to mark the end of a sentence

def append_eos(stream):
    """Yield every (inputs, targets) pair of `stream` with EOS appended to both.

    Args:
        stream: iterable of 2-tuples of token-id sequences.

    Yields:
        A 2-tuple of NumPy arrays: the inputs followed by EOS and the
        targets followed by EOS.
    """
    for source_ids, target_ids in stream:
        yield np.array([*source_ids, EOS]), np.array([*target_ids, EOS])

# Wrap both tokenized streams so every sentence ends with the EOS id.
# append_eos is a generator function, so nothing is consumed until iteration.
# NOTE(review): the names are rebound to their own wrapped versions, so
# re-running this cell without re-running the previous one appends EOS twice.
tokenized_train_stream = append_eos(tokenized_train_stream)
tokenized_eval_stream = append_eos(tokenized_eval_stream)

[34mSingle tokenized example input:[0m [ 1053 29982     2 24373     5     6 14989     2  7201  9466     2 24373
 11983 11314 28837     2 11314 28837     2  2511 27810     2 27403]
[34mSingle tokenized example target:[0m [ 2407 16718  5769     2 14732  1740    47 19765     5     2 14649    28
     2  9708 27889  2803  7131  5461     2  7131  5461     2  2511 27810
     2 27403]


# Functions

* word2Ind: maps a word to its index
* ind2Word: maps an index back to its word
* word2Count: maps a word to the number of times it appears
* num_words: total number of words seen
* tokenize(): converts a text sentence into a list of token indices, splitting words into subwords
* detokenize(): converts a list of token indices back into a sentence (string)

In [None]:
def tokenize(input_str, vocab_file=None, vocab_dir=None):
    """Encode one sentence as a single-row batch of token ids ending in EOS.

    Args:
        input_str: the sentence to encode.
        vocab_file: vocabulary file name, forwarded to trax.data.tokenize.
        vocab_dir: directory of the vocabulary file, forwarded likewise.

    Returns:
        NumPy array of shape (1, n_tokens + 1) whose last entry is the EOS id.
    """
    eos_id = 1
    # trax.data.tokenize operates on streams, so wrap the sentence in a
    # one-item stream and pull out its only result.
    single_sentence_stream = iter([input_str])
    token_ids = next(trax.data.tokenize(single_sentence_stream,
                                        vocab_file=vocab_file, vocab_dir=vocab_dir))
    ids_with_eos = list(token_ids) + [eos_id]
    # Add a leading batch axis of size 1.
    return np.array(ids_with_eos).reshape(1, -1)


def detokenize(integers, vocab_file=None, vocab_dir=None):
    """Decode token ids back into a sentence, stopping at the first EOS.

    Args:
        integers: token ids, given as a single id, a flat sequence, or an
            array with extra singleton axes (e.g. shape (1, n)); singleton
            axes are squeezed away before decoding.
        vocab_file: vocabulary file name, forwarded to trax.data.detokenize.
        vocab_dir: directory of the vocabulary file, forwarded likewise.

    Returns:
        The result of trax.data.detokenize on the ids that precede the first
        EOS (or on all ids when no EOS is present).
    """
    EOS = 1
    # np.squeeze collapses a one-token input to a 0-d array, which cannot be
    # iterated; atleast_1d restores a single axis so that case works too.
    integers = list(np.atleast_1d(np.squeeze(integers)))
    if EOS in integers:
        # Drop EOS and everything after it.
        integers = integers[:integers.index(EOS)]

    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)

# Bucketing

![alt text](https://sun9-62.userapi.com/impg/fuoqdikR_mdjXzKxawFUQl7mNWMiA2SHLsmvsA/UWaTaN0In-E.jpg?size=1444x976&quality=96&sign=775b3928fe8deb0c25f94926f5903efa&type=album)

In [None]:
# Length boundaries of the buckets and the batch size used for each bucket.
# There is one more batch size than boundaries; per the Trax docs the extra
# entry is for examples longer than the largest boundary.
boundaries =  [8,   16,  32, 64, 128, 256, 512]
batch_sizes = [256, 128, 64, 32, 16,    8,   4,  2]

# Pipeline steps shared by both splits: bucket on the lengths of elements 0
# and 1 of each example, then add loss weights with id 0 as the id to mask.
bucket_by_length = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)
add_loss_weights = trax.data.AddLossWeights(id_to_mask=0)

# NOTE(review): filtered_train_stream / filtered_eval_stream are not created
# in any visible cell -- presumably a length-filtering step over the tokenized
# streams; confirm it exists before running.
train_batch_stream = add_loss_weights(bucket_by_length(filtered_train_stream))
eval_batch_stream = add_loss_weights(bucket_by_length(filtered_eval_stream))

# Model Architecture

In [None]:
# Build the model; no arguments are passed, so NMTAttn's own defaults apply.
model = NMTAttn()
# Printing the model shows its nested layer structure (see the output below).
print(model)

Serial_in2_out2[
  Select[0,1,0,1]_in2_out4
  Parallel_in2_out2[
    Serial[
      Embedding_33300_1024
      LSTM_1024
      LSTM_1024
    ]
    Serial[
      Serial[
        ShiftRight(1)
      ]
      Embedding_33300_1024
      LSTM_1024
    ]
  ]
  PrepareAttentionInput_in3_out4
  Serial_in4_out2[
    Branch_in4_out3[
      None
      Serial_in4_out2[
        _in4_out4
        Serial_in4_out2[
          Parallel_in3_out3[
            Dense_1024
            Dense_1024
            Dense_1024
          ]
          PureAttention_in4_out2
          Dense_1024
        ]
        _in2_out2
      ]
    ]
    Add_in2
  ]
  Select[0,2]_in3_out2
  LSTM_1024
  LSTM_1024
  Dense_33300
  LogSoftmax
]


# Training

In [None]:
# Advance training by 10 steps; the log below reports train/eval metrics at
# steps 1 and 10.
training_loop.run(10)


Step      1: Total number of trainable weights: 148492820
Step      1: Ran 1 train steps in 124.57 secs
Step      1: train CrossEntropyLoss |  10.45017815
Step      1: eval  CrossEntropyLoss |  10.44099426
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 470.58 secs
Step     10: train CrossEntropyLoss |  10.31279373
Step     10: eval  CrossEntropyLoss |  10.08163643
Step     10: eval          Accuracy |  0.05136187


# Decoding

In [None]:
def greedy_decode_test(sentence, NMTAttn=None, vocab_file=None, vocab_dir=None):
    """Decode `sentence` with sampling_decode and return only the translated text.

    Args:
        sentence: the input sentence to decode.
        NMTAttn: the model instance handed to sampling_decode.
        vocab_file: vocabulary file name, forwarded to sampling_decode.
        vocab_dir: directory of the vocabulary file, forwarded likewise.

    Returns:
        The third element of the 3-tuple returned by sampling_decode.
    """
    # NOTE(review): no temperature is passed, so "greedy" behaviour presumably
    # relies on sampling_decode's default -- confirm against its definition.
    decoded = sampling_decode(sentence, NMTAttn, vocab_file=vocab_file, vocab_dir=vocab_dir)
    _first, _second, translated_sentence = decoded
    return translated_sentence

# Example input; replace with any sentence to decode.
your_sentence = 'Your sentence.'
# The trailing semicolon suppresses the notebook's echo of the returned value.
greedy_decode_test(sentence=your_sentence, NMTAttn=model,
                   vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR);

# Evaluation

In [2]:
def jaccard_similarity(candidate, reference):
    """Return the Jaccard similarity between the unigram sets of two sequences.

    Duplicates and ordering are ignored: each sequence is reduced to the set
    of its distinct items before comparison.

    Args:
        candidate: iterable of hashable items (e.g. token ids) to evaluate.
        reference: iterable of hashable items to compare against.

    Returns:
        len(intersection) / len(union) as a float in [0, 1]. Two empty inputs
        are treated as identical and score 1.0 instead of dividing by zero.
    """
    candidate_unigrams, reference_unigrams = set(candidate), set(reference)
    all_elems = candidate_unigrams | reference_unigrams
    if not all_elems:
        # Both inputs are empty: 0/0 is undefined, so use the convention that
        # two empty sets are identical.
        return 1.0

    return len(candidate_unigrams & reference_unigrams) / len(all_elems)

jaccard_similarity([1, 2, 3], [1, 2, 3, 4])

0.75