In [1]:
import sys
import os

import numpy as np

import textwrap
# Shared text wrapper (70 columns) used when printing decoded / sample text.
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# Print NumPy arrays in full: raise the summarisation threshold so that
# large arrays are never abbreviated with '...'.
np.set_printoptions(threshold=sys.maxsize)

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [4]:
# Stream factories over the CNN / Daily Mail dataset. Each example is an
# (article, highlights) pair; `train` selects the training split or not.
train_stream_fn = trax.data.TFDS(
    'cnn_dailymail',
    train=True,
    keys=('article', 'highlights'),
)

eval_stream_fn = trax.data.TFDS(
    'cnn_dailymail',
    train=False,
    keys=('article', 'highlights'),
)

Downloading and preparing dataset cnn_dailymail/plain_text/3.0.0 (download: 558.32 MiB, generated: 1.27 GiB, total: 1.82 GiB) to /Users/takshshilarawat/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0...


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /Users/takshshilarawat/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteNE8IVC/cnn_dailymail-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=287113.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /Users/takshshilarawat/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteNE8IVC/cnn_dailymail-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=13368.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /Users/takshshilarawat/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0.incompleteNE8IVC/cnn_dailymail-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=11490.0), HTML(value='')))

Dataset cnn_dailymail downloaded and prepared to /Users/takshshilarawat/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0. Subsequent calls will reuse this data.


In [7]:
def tokenize(input_str, EOS=1):
    """Encode a string into subword ids and append an end-of-sentence id.

    Args:
        input_str: the text to encode.
        EOS: integer id appended after the encoded text (default 1).

    Returns:
        A Python list holding the ids produced by `trax.data.tokenize`
        for `input_str`, followed by `EOS`.
    """
    # trax.data.tokenize consumes a stream, so wrap the single string in a
    # one-element iterator and pull out its only result.
    token_stream = trax.data.tokenize(
        iter([input_str]),
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    )
    token_ids = list(next(token_stream))
    token_ids.append(EOS)
    return token_ids

def detokenize(integers):
    """Decode subword ids back into text and line-wrap it for display.

    Args:
        integers: sequence of token ids to decode.

    Returns:
        The decoded string, re-flowed with the notebook's shared `wrapper`.
    """
    decoded_text = trax.data.detokenize(
        integers,
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    )
    return wrapper.fill(decoded_text)

In [8]:
SEP = 0  # Padding or separator token
EOS = 1  # End of sentence token


def preprocess(stream):
    """Turn (article, summary) token pairs into language-model examples.

    For every pair the two token sequences are concatenated as
    ``article + [EOS, SEP] + summary + [EOS]``.

    Args:
        stream: iterable of (article, summary) pairs of token-id sequences.

    Yields:
        (inputs, targets, mask) where `inputs` and `targets` are the same
        joint array and `mask` is 0 over the article and the EOS/SEP pair
        and 1 over the summary and its trailing EOS.
    """
    for article, summary in stream:
        article_ids = list(article)
        summary_ids = list(summary)

        joint = np.array(article_ids + [EOS, SEP] + summary_ids + [EOS])

        # +2 covers the EOS and SEP after the article; +1 the final EOS.
        n_masked_out = len(article_ids) + 2
        n_masked_in = len(summary_ids) + 1
        mask = np.array([0] * n_masked_out + [1] * n_masked_in)

        yield joint, joint, mask


# Data pipeline: tokenize the raw text, build (input, target, mask)
# triples with `preprocess`, then apply a length filter with limit 2048.
input_pipeline = trax.data.Serial(
    trax.data.Tokenize(
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    ),
    preprocess,
    trax.data.FilterByLength(2048),
)

# Apply the same pipeline to the training and evaluation streams.
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# Pull one example to inspect it.
train_input, train_target, train_mask = next(train_stream)

# In a language model (LM) the input and the target are the same sequence.
assert np.array_equal(train_input, train_target)

In [9]:
# Decode the example drawn from the training stream back to text to check it.
print(f'Single example:\n\n {detokenize(train_input)}')

Single example:

 By . Daily Mail Reporter . UPDATED: . 09:53 EST, 12 January 2012 . A
motorist was stunned when he discovered double yellow lines had been
painted underneath his car while it was parked - and then given a
ticket. Flecks of yellow paint were even sprayed on the bumper of
Patrick McCrystal's car as the lines were painted under the front of
it in Kedleston Street, Derby. The 49-year-old had parked his Ford
Fiesta in the street near to a Co-operative store and a petrol
station, where he works, for three years. Stunned: Patrick McCrystal
with his Ford Fiesta, which was given a ticket after council workmen
had sprayed yellow lines under the car while it was parked legally .
When he parked for his 2pm shift, he noticed new yellow lines had been
painted across a housing block entrance. But there was a gap between
those lines and existing ones in the street, so Mr McCrystal parked
there, in his usual spot. Hours later, a colleague on his dinner break
saw that extra lines had be

In [10]:
# Length boundaries for bucketing and the batch size used for each bucket.
# `batch_sizes` has one more entry than `boundaries`; the final entry is
# presumably used for sequences longer than the last boundary — confirm
# against the trax BucketByLength documentation.
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

# Group examples of similar length into batches, for both splits.
train_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(
    train_stream
)
eval_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(
    eval_stream
)

In [11]:
# Draw one batch from the bucketed training stream; the middle element (the
# targets) is ignored here.
input_batch, _, mask_batch = next(train_batch_stream)

# Shape of the input batch.
input_batch.shape

(1, 1727)

In [12]:
# Raw token ids of the first row of the batch.
print(input_batch[0])

[  380   527   213   296 29725     4     5  2448  3620 15124   902    39
  1151   669 27439  6050 13459  1628  3528  1879 29725   391   592   166
   527    95     6  1260  2942   578     3   198    39  1151    92  6164
   132   181    64   527   805 29725     4     5  2840   132   401   509
   320 20669   320  3636 15052   578   770   527   213   947     3     9
 24046  1041    78    36   527   213 13965  1229  1536   390   527   213
   104     2   412  2659   527   101  1124   320   955   278   102  4956
   228  1019  4078     3  9175  6051     4   246  1019   846   379  2165
   132   186    64   527   401  9133  2840   947    18    46 19715   592
     2   103    23    46  1595   379 11423 17805   232    11 13049  7844
   809   213   401   947     2    36   527   213 13965  1229   132   213
   296     2  4872   447  1435 19715   379     9 26588   400   809   213
   947     2  1480   229  2232   691  3636 15052     2    39  2897   161
  1427   320  1536    78   644  2836     2   758 17

In [13]:
# The same row decoded back to (line-wrapped) text.
print('Article:\n\n', detokenize(input_batch[0]))

Article:

 One of the country’s biggest rail terminals will be ‘effectively
closed’ today because of over-running engineering works. There will be
no trains in or out of King’s Cross in London due to delays to Network
Rail works north of the station. The disruption comes on one of the
busiest travel days of the year, as thousands of people try to return
home after visiting family for Christmas. Scroll down for video .
Services in and out of London Kings Cross station have been cancelled
today, it has been announced . Frustration: Travellers at the London
station, one of the busiest in the country, where services are
cancelled . The disruptions at the station, which is managed by
Network Rail, will affect those planning to travel on East Coast,
First Hull Trains, Grand Central and Great Northern services. East
Coast Trains made the announcement on its website yesterday evening,
where it advised passengers to delay their travel if possible. It also
said that a revised timetable is curren

In [14]:
def create_tensor(t):
    """Build a `jnp` array from a (nested) Python list.

    Args:
        t: list, or list of lists, of numbers.

    Returns:
        The result of `jnp.array(t)`.
    """
    tensor = jnp.array(t)
    return tensor


def display_tensor(t, name):
    """Print a tensor's shape, labelled with `name`, followed by its values.

    Args:
        t: object exposing a `.shape` attribute (e.g. an array).
        name: label printed in front of the shape.
    """
    shape_line = f'{name} shape: {t.shape}\n'
    values_block = f'{t}\n'
    print(shape_line)
    print(values_block)

In [15]:
# Toy query / key / value matrices and a mask, used to illustrate attention.
# NOTE(review): `m` holds 0 / -1e9 entries, which looks like an additive
# mask; DotProductAttention in this notebook selects with `jnp.where(mask,
# ...)` instead, so `m` should not be passed to it as-is — confirm intent.
q = create_tensor([[1, 0, 0], [0, 1, 0]])
k = create_tensor([[1, 2, 3], [4, 5, 6]])
v = create_tensor([[0, 1, 0], [1, 0, 1]])
m = create_tensor([[0, 0], [-1e9, 0]])

display_tensor(q, 'query')
display_tensor(k, 'key')
display_tensor(v, 'value')
display_tensor(m, 'mask')

query shape: (2, 3)

[[1 0 0]
 [0 1 0]]

key shape: (2, 3)

[[1 2 3]
 [4 5 6]]

value shape: (2, 3)

[[0 1 0]
 [1 0 1]]

mask shape: (2, 2)

[[ 0.e+00  0.e+00]
 [-1.e+09  0.e+00]]





In [16]:
def DotProductAttention(query, key, value, mask):
    """Scaled dot-product attention.

    Args:
        query: array whose last axis is the embedding depth.
        key: array with the same last-axis size as `query`.
        value: array with the same last-axis size as `query`.
        mask: boolean-like array compatible with the score matrix, or None.
            Where the mask is truthy the score is kept; elsewhere it is
            replaced by -1e9 so that it vanishes after the softmax.

    Returns:
        The attention-weighted combination of `value`.

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dimensions of q, k, v aren't all the same"

    d_k = query.shape[-1]

    # Similarity of every query with every key, scaled by sqrt(depth).
    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(d_k)

    if mask is not None:
        scores = jnp.where(mask, scores, jnp.full_like(scores, -1e9))

    # Softmax over the last axis, computed as exp(x - logsumexp(x)).
    log_normalizer = trax.fastmath.logsumexp(scores, axis=-1, keepdims=True)
    weights = jnp.exp(scores - log_normalizer)

    return jnp.matmul(weights, value)

In [17]:
def compute_attention_heads_closure(n_heads, d_head):
    """Build a function that splits the feature axis into attention heads.

    Args:
        n_heads: number of attention heads.
        d_head: size of each head.

    Returns:
        A function mapping an array of shape
        (batch, seqlen, n_heads * d_head) to one of shape
        (batch * n_heads, seqlen, d_head).
    """
    def compute_attention_heads(x):
        n_batch, n_tokens = x.shape[0], x.shape[1]

        # (batch, seqlen, n_heads * d_head) -> (batch, seqlen, n_heads, d_head)
        per_head = jnp.reshape(x, (n_batch, n_tokens, n_heads, d_head))
        # -> (batch, n_heads, seqlen, d_head)
        heads_first = jnp.transpose(per_head, (0, 2, 1, 3))
        # -> (batch * n_heads, seqlen, d_head): heads folded into the batch axis
        return jnp.reshape(heads_first, (n_batch * n_heads, n_tokens, d_head))

    return compute_attention_heads

In [18]:
def dot_product_self_attention(q, k, v):
    """Causal self-attention: each position attends to itself and earlier ones.

    Args:
        q: queries; the second-to-last axis is taken as the sequence length.
        k: keys.
        v: values.

    Returns:
        The result of `DotProductAttention` under a lower-triangular mask.
    """
    seqlen = q.shape[-2]
    # Lower-triangular boolean mask with a leading axis of size 1.
    all_true = jnp.ones((1, seqlen, seqlen), dtype=jnp.bool_)
    causal_mask = jnp.tril(all_true, k=0)

    return DotProductAttention(q, k, v, causal_mask)

In [19]:
def compute_attention_output_closure(n_heads, d_head):
    """Build a function that merges attention heads back into one feature axis.

    Args:
        n_heads: number of attention heads.
        d_head: size of each head.

    Returns:
        A function mapping an array of shape
        (batch * n_heads, seqlen, d_head) to one of shape
        (batch, seqlen, n_heads * d_head).
    """
    def compute_attention_output(x):
        n_tokens = x.shape[1]

        # Unfold the heads from the leading axis: (batch, n_heads, seqlen, d_head)
        unfolded = jnp.reshape(x, (-1, n_heads, n_tokens, d_head))
        # -> (batch, seqlen, n_heads, d_head)
        tokens_first = jnp.transpose(unfolded, (0, 2, 1, 3))
        # Concatenate the heads along the feature axis.
        return jnp.reshape(tokens_first, (-1, n_tokens, n_heads * d_head))

    return compute_attention_output

In [20]:
def CausalAttention(d_feature, 
                    n_heads, 
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Multi-head causal self-attention layer.

    Args:
        d_feature: feature size of the input and output; must be divisible
            by `n_heads`.
        n_heads: number of attention heads.
        compute_attention_heads_closure: factory for the head-splitting function.
        dot_product_self_attention: function applied to (queries, keys, values).
        compute_attention_output_closure: factory for the head-merging function.
        mode: accepted for interface compatibility; not used in this body.

    Returns:
        A `tl.Serial` layer: three Dense projections (queries, keys,
        values) each split into heads, the attention function, the
        head-merging step and a final Dense layer.

    Raises:
        AssertionError: if `d_feature` is not a multiple of `n_heads`.
    """
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    # A weight-free layer, so the same instance is shared by all three towers.
    split_heads = tl.Fn(
        'AttnHeads',
        compute_attention_heads_closure(n_heads, d_head),
        n_out=1,
    )

    def projection_tower():
        """One Dense projection followed by the head split."""
        return [tl.Dense(d_feature), split_heads]

    # One input fans out into three towers: queries, keys and values.
    qkv_branches = tl.Branch(
        projection_tower(),  # queries
        projection_tower(),  # keys
        projection_tower(),  # values
    )

    attend = tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1)
    merge_heads = tl.Fn(
        'AttnOutput',
        compute_attention_output_closure(n_heads, d_head),
        n_out=1,
    )

    return tl.Serial(
        qkv_branches,
        attend,                # consumes Q, K, V
        merge_heads,           # folds the heads back into the feature axis
        tl.Dense(d_feature),   # final dense layer
    )



In [21]:
def DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation):
    """Layers of one decoder block: causal attention, then feed-forward.

    Args:
        d_model: feature size of the block's input and output.
        d_ff: width of the hidden feed-forward Dense layer.
        n_heads: number of attention heads.
        dropout: rate passed to every `tl.Dropout` in the block.
        mode: mode string passed to the dropout and attention layers.
        ff_activation: zero-argument callable returning the activation
            layer used in the feed-forward part (generally ReLU).

    Returns:
        A list of two `tl.Residual` layers: the attention sub-block followed
        by the feed-forward sub-block.
    """
    attention_sublayer = tl.Residual(
        tl.LayerNorm(),
        CausalAttention(d_model, n_heads=n_heads, mode=mode),
        tl.Dropout(rate=dropout, mode=mode),
    )

    feed_forward_layers = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    feed_forward_sublayer = tl.Residual(feed_forward_layers)

    return [attention_sublayer, feed_forward_sublayer]


In [22]:
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Decoder-only Transformer language model.

    Args:
        vocab_size: size of the vocabulary (embedding rows and output width).
        d_model: feature size used throughout the model.
        d_ff: hidden width of each block's feed-forward part.
        n_layers: number of decoder blocks.
        n_heads: attention heads per block.
        dropout: dropout rate.
        max_len: `max_len` passed to `tl.PositionalEncoding`.
        mode: mode string forwarded to the mode-aware layers.
        ff_activation: activation layer constructor for the feed-forward part.

    Returns:
        A `tl.Serial` model: shift right, embed and position-encode the
        tokens, apply the decoder blocks and a final layer norm, then a
        Dense projection to the vocabulary followed by log-softmax.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
        )

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        embedding_stack,
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )


In [23]:
from trax.supervised import training

def training_loop(TransformerLM, train_gen, eval_gen, output_dir="~/model"):
    """Assemble a trax training loop for a small Transformer language model.

    Args:
        TransformerLM: model constructor; called here with a fixed, small
            configuration (d_model=4, d_ff=16, n_layers=1, n_heads=2,
            mode='train').
        train_gen: generator of training data.
        eval_gen: generator of evaluation data.
        output_dir: directory handed to the loop; '~' is expanded.

    Returns:
        The configured `training.Loop`. It is returned without being run.
    """
    checkpoint_dir = os.path.expanduser(output_dir)

    schedule = trax.lr.warmup_and_rsqrt_decay(
        n_warmup_steps=1000, max_value=0.01)

    train_task = training.TrainTask(
        labeled_data=train_gen,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(learning_rate=0.01),
        lr_schedule=schedule,
        n_steps_per_checkpoint=10,
    )

    eval_task = training.EvalTask(
        labeled_data=eval_gen,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    )

    # Deliberately small model configuration, fixed inside this function.
    small_model = TransformerLM(
        d_model=4,
        d_ff=16,
        n_layers=1,
        n_heads=2,
        mode='train',
    )

    return training.Loop(
        small_model,
        train_task,
        eval_tasks=[eval_task],
        output_dir=checkpoint_dir,
    )

In [None]:
# Delete any checkpoint left in ~/model by a previous run (presumably so the
# new loop does not pick it up — confirm), then build the loop and run it
# for 10 steps.
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(10)

In [None]:
# Get the model architecture
model = TransformerLM(mode='eval')

# Load the pre-trained weights
model.init_from_file('model.pkl.gz', weights_only=True)

In [None]:
def next_symbol(cur_output_tokens, model):
    """Return the id of the most likely next token.

    Args:
        cur_output_tokens: list of the token ids seen or generated so far.
        model: callable taking a pair of (1, padded_length) arrays and
            returning a pair whose first element is indexable as
            [batch, position, vocab].

    Returns:
        The argmax, as a Python int, of the model output at the position
        just after the current tokens.
    """
    n_tokens = len(cur_output_tokens)

    # Pad with zeros to the smallest power of two strictly greater than the
    # current length, so there is always a slot for the next token.
    exponent = int(np.ceil(np.log2(n_tokens + 1)))
    target_length = 2 ** exponent
    zero_padding = [0] * (target_length - n_tokens)

    # Indexing with np.newaxis adds the leading batch axis of size 1.
    batch = np.array(cur_output_tokens + zero_padding)[np.newaxis, :]

    output, _ = model((batch, batch))
    next_token_log_probs = output[0, n_tokens, :]

    return int(np.argmax(next_token_log_probs))

In [None]:
def greedy_decode(input_sentence, model, max_length=None):
    """Greedily generate tokens after `input_sentence` until EOS is produced.

    The sentence is tokenized (which appends EOS) and followed by a 0
    separator; the model is then asked for one token at a time, each new
    token being appended to the context. The partial decoded output is
    printed after every step.

    Args:
        input_sentence: the text to condition on.
        model: model passed through to `next_symbol`.
        max_length: optional cap on the number of generated tokens. The
            default, None, keeps generating until EOS appears; passing a
            number guards against a model that never emits EOS.

    Returns:
        The generated tokens (including the final EOS, when reached)
        decoded to wrapped text.
    """
    EOS = 1

    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = []
    cur_output = 0

    while cur_output != EOS:
        # Without a cap this loop only ends when the model predicts EOS.
        if max_length is not None and len(generated_output) >= max_length:
            break

        cur_output = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)
        print(detokenize(generated_output))

    return detokenize(generated_output)

In [None]:
# Try the decoder on a short sentence: print the wrapped input first, then
# the generated output (greedy_decode also prints its progress step by step).
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))