In [2]:
!pip install trax 
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

import sys
import os
import numpy as np
import textwrap

# Print numpy arrays in full rather than eliding the middle with '...'.
# NOTE(review): this makes every long token/mask array dump all of its values
# (see the outputs below); consider slicing at the print sites instead.
np.set_printoptions(threshold=sys.maxsize)
# Shared wrapper that hard-wraps detokenized text at 70 columns.
wrapper = textwrap.TextWrapper(width=70)

Collecting trax
[?25l  Downloading https://files.pythonhosted.org/packages/42/51/305b839f51d53abb393777f743e497d27bb341478f3fdec4d6ddaccc9fb5/trax-1.3.7-py2.py3-none-any.whl (521kB)
[K     |▋                               | 10kB 14.0MB/s eta 0:00:01[K     |█▎                              | 20kB 19.7MB/s eta 0:00:01[K     |█▉                              | 30kB 24.1MB/s eta 0:00:01[K     |██▌                             | 40kB 26.3MB/s eta 0:00:01[K     |███▏                            | 51kB 28.0MB/s eta 0:00:01[K     |███▊                            | 61kB 19.2MB/s eta 0:00:01[K     |████▍                           | 71kB 14.0MB/s eta 0:00:01[K     |█████                           | 81kB 15.0MB/s eta 0:00:01[K     |█████▋                          | 92kB 16.1MB/s eta 0:00:01[K     |██████▎                         | 102kB 17.2MB/s eta 0:00:01[K     |███████                         | 112kB 17.2MB/s eta 0:00:01[K     |███████▌                        | 122kB 17.2MB

# Dataset

In [3]:
def _scientific_papers_stream_fn(train):
    """Build a trax TFDS stream factory for one split of `scientific_papers`.

    Each example is selected as an ``(abstract, article)`` pair, in that order,
    which is the order `data_preprocessing` unpacks later.

    Args:
        train: True for the training split, False for the evaluation split.
    """
    return trax.data.TFDS('scientific_papers',
                          data_dir='content/',
                          keys=('abstract', 'article'),
                          train=train)


train_stream_fn = _scientific_papers_stream_fn(train=True)
eval_stream_fn = _scientific_papers_stream_fn(train=False)

[1mDownloading and preparing dataset 4.20 GiB (download: 4.20 GiB, generated: 7.07 GiB, total: 11.27 GiB) to content/scientific_papers/arxiv/1.1.1...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…






HBox(children=(FloatProgress(value=0.0, description='Generating splits...', max=3.0, style=ProgressStyle(descr…

HBox(children=(FloatProgress(value=0.0, description='Generating train examples...', max=203037.0, style=Progre…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-train.tfrecord...', max=20303…

HBox(children=(FloatProgress(value=0.0, description='Generating validation examples...', max=6436.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-validation.tfrecord...', max=…

HBox(children=(FloatProgress(value=0.0, description='Generating test examples...', max=6440.0, style=ProgressS…

HBox(children=(FloatProgress(value=0.0, description='Shuffling scientific_papers-test.tfrecord...', max=6440.0…

[1mDataset scientific_papers downloaded and prepared to content/scientific_papers/arxiv/1.1.1. Subsequent calls will reuse this data.[0m


# Tokenization

In [4]:
def tokenize(input_str, EOS=1):
    """Encode one string into subword ids and append the end-of-sequence id.

    ``trax.data.tokenize`` operates on streams, so the string is wrapped in a
    one-element iterator and the single encoded result is pulled out.

    Args:
        input_str: text to encode.
        EOS: id appended after the encoded tokens (defaults to 1).

    Returns:
        list of token ids whose last element is ``EOS``.
    """
    encoded_stream = trax.data.tokenize(iter([input_str]),
                                        vocab_dir='vocab/',
                                        vocab_file='vocab.subwords')
    encoded = next(encoded_stream)
    return [*encoded, EOS]


def detokenize(integers):
    """Decode subword ids back to text, hard-wrapped by the shared `wrapper`."""
    return wrapper.fill(
        trax.data.detokenize(integers,
                             vocab_dir='vocab/',
                             vocab_file='vocab.subwords'))


# Special token ids used to lay out each training example.
SEP = 0  # separator between article and abstract (same id as padding)
EOS = 1  # end-of-sequence marker


def data_preprocessing(stream):
    """Turn tokenized ``(abstract, article)`` pairs into LM training triples.

    Each example is laid out as ``article + [EOS, SEP] + abstract + [EOS]``.
    Because this is language modelling, the same array is yielded as both the
    input and the target. The mask is 0 over the article and the EOS/SEP pair,
    and 1 over the abstract and its trailing EOS. When used as per-token loss
    weights (presumably by the training loop; confirm), only the summary
    part is scored.

    Args:
        stream: iterable of ``(abstract, article)`` pairs of token-id
            sequences, e.g. the output of ``trax.data.Tokenize``.

    Yields:
        ``(joint, joint, mask)`` numpy arrays of equal length.
    """
    for abstract, article in stream:
        # Materialize each sequence once instead of calling list() repeatedly.
        article_ids = list(article)
        abstract_ids = list(abstract)
        joint = np.array(article_ids + [EOS, SEP] + abstract_ids + [EOS])
        # +2 covers the EOS/SEP pair after the article; +1 the final EOS.
        mask = np.concatenate([
            np.zeros(len(article_ids) + 2, dtype=int),
            np.ones(len(abstract_ids) + 1, dtype=int),
        ])
        yield joint, joint, mask

# Preprocessing steps, applied in order:
#   1. tokenize both text fields with the subword vocabulary,
#   2. build (input, target, mask) LM triples,
#   3. filter by length with a 2048-token limit.
_pipeline_steps = (
    trax.data.Tokenize(vocab_dir='vocab/', vocab_file='vocab.subwords'),
    data_preprocessing,
    trax.data.FilterByLength(2048),
)
input_pipeline = trax.data.Serial(*_pipeline_steps)

# Apply the preprocessing pipeline to both splits; each stream yields
# (input, target, mask) triples.
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# Peek at one training example. NOTE: next() consumes it, so the batched
# stream built from train_stream later will not see this example.
train_input, train_target, train_mask = next(train_stream)

# In language modelling the target is the input itself.
assert np.array_equal(train_input, train_target)
# The mask carries exactly one value per token.
assert train_mask.shape == train_input.shape

print(f'Mask sample:\n {train_mask} \n\n')
print('[Example][<EOS>][<pad>][Example Summary][<EOS>] \n')
print(f'Data sample:\n {detokenize(train_input)}')

Mask sample:
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 

# Bucketing

![alt text](https://sun9-62.userapi.com/impg/fuoqdikR_mdjXzKxawFUQl7mNWMiA2SHLsmvsA/UWaTaN0In-E.jpg?size=1444x976&quality=96&sign=775b3928fe8deb0c25f94926f5903efa&type=album)

In [15]:
# Bucket examples by length so little padding is wasted. Longer sequences get
# smaller batches, so each full batch holds 16*128 = 8*256 = ... = 2048 token
# slots:
#   up to  128 tokens -> batches of 16
#   up to  256 tokens -> batches of 8
#   up to  512 tokens -> batches of 4
#   up to 1024 tokens -> batches of 2
#   longer            -> batches of 1 (the pipeline already caps length at 2048)
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]  # one more entry than boundaries


def _bucket_by_length(stream):
    """Wrap `stream` in a fresh trax BucketByLength using the settings above."""
    return trax.data.BucketByLength(boundaries, batch_sizes)(stream)


train_batch_stream = _bucket_by_length(train_stream)
eval_batch_stream = _bucket_by_length(eval_stream)

# Peek at the first training batch (not a random one; this consumes it).
input_batch, _, mask_batch = next(train_batch_stream)

layout_legend = (
    '[article word values, ... article word values, 1 <EOS>, 0 <pad>, '
    '\n abstract word values, ... abstract word values, 1 <EOS>, '
    '\n 0s for padding (if there\'re no 0s in the end, we reached the maximum length)] \n'
)
print(input_batch.shape, '\n')
print(layout_legend)
print(input_batch[0])
print('\n Article:\n', detokenize(input_batch[0]))

(1, 1711) 

[article word values, ... article word values, 1 <EOS>, 0 <pad>, 
 abstract word values, ... abstract word values, 1 <EOS>, 
 0s for padding (if there're no 0s in the end, we reached the maximum length)] 

[  697    70   230 25719  5409  3601 19622     5   402   106   412   213
 23533 25901  2902  5147   697 24224     4  2652 26054     4 23662  3387
   669  4884 14272 14272 14272 27439  6774  7583  4884   669  6435  4884
 14272 14272 14272 27439  6774  7583  4884  2652  5243 13929 10018  1045
   402 12846  5065   556 23662  2627     2    39  1151  4712   320 25719
  5409  8033   132   213 10432 24255   364   598 27439  6774  7583     3
   213 17104 16071   524   527  1896   763 15986     4  2652 15661    21
  3610 17015  1791    39  1151   213  4469  1175   132   824  4590   598
   379  1248    87   286   446   783   742   320  1151  4237  8051  1248
 24529   497   121    10    32  1564  2703   364 27439  6774  7583     2
    97 17015  1791    39  1771    28  9819    70  10

# Transformer LM

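$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$

where $d_k$ is the dimension of the keys and $M$ is the causal (look-ahead) mask: $M_{ij}=0$ for $j \le i$ and $M_{ij}=-\infty$ for $j > i$. This ensures each position attends only to itself and earlier tokens.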

In [28]:
# Build the training loop for the Transformer LM and run 10 training steps.
# NOTE(review): `training_loop` and `TransformerLM` are not defined in any cell
# visible here (execution counts jump from In [15] to In [28]). On a fresh kernel
# this cell raises NameError unless the defining cells are restored above it.
# NOTE(review): per the captured output, loss stays ~10.41 and eval accuracy 0
# after 10 steps, so this run only shows the pipeline executes end to end.
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(10)  # number of training steps; consider a named config constant




Step      1: Total number of trainable weights: 316336
Step      1: Ran 1 train steps in 10.74 secs
Step      1: train CrossEntropyLoss |  10.41117954
Step      1: eval  CrossEntropyLoss |  10.41211128
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 99.00 secs
Step     10: train CrossEntropyLoss |  10.41329575
Step     10: eval  CrossEntropyLoss |  10.41292191
Step     10: eval          Accuracy |  0.00000000


In [30]:
# Instantiate the model with mode='eval' (presumably for evaluation/decoding).
# NOTE(review): TransformerLM is defined outside the visible cells; confirm
# that trained weights are loaded into `model` before it is used.
model = TransformerLM(mode='eval')