In [1]:
# Standard library
import os
import sys
import textwrap

# Third-party
import numpy as np
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# Display configuration
# Text wrapper used to format detokenized output at 70 columns.
wrapper = textwrap.TextWrapper(width=70)
# Print numpy arrays in full rather than summarizing them with "...".
np.set_printoptions(threshold=sys.maxsize)

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [2]:
# Dataset: CNN/DailyMail (article, highlights) pairs via TFDS.
# Both splits share the data directory and feature keys; only `train` differs.
_tfds_kwargs = dict(data_dir='data/', keys=('article', 'highlights'))

train_stream_fn = trax.data.TFDS('cnn_dailymail', train=True, **_tfds_kwargs)
eval_stream_fn = trax.data.TFDS('cnn_dailymail', train=False, **_tfds_kwargs)

In [3]:
def tokenize(input_str, EOS=1):
    """Encode ``input_str`` as subword ids and append the ``EOS`` id.

    Uses the ``summarize32k.subword.subwords`` vocabulary in ``vocab_dir/``.

    Args:
        input_str: text to encode.
        EOS: id appended after the encoded tokens (default 1).

    Returns:
        A plain Python list of token ids ending in ``EOS``.
    """
    token_stream = trax.data.tokenize(
        iter([input_str]),
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords')
    token_ids = [*next(token_stream)]
    token_ids.append(EOS)
    return token_ids

def detokenize(integers):
    """Decode subword ids back into text wrapped by the module ``wrapper``.

    Uses the same ``summarize32k.subword.subwords`` vocabulary as ``tokenize``.

    Args:
        integers: sequence of token ids.

    Returns:
        The decoded text, line-wrapped with ``wrapper.fill``.
    """
    vocab_kwargs = {'vocab_dir': 'vocab_dir/',
                    'vocab_file': 'summarize32k.subword.subwords'}
    return wrapper.fill(trax.data.detokenize(integers, **vocab_kwargs))

In [4]:
# Special tokens
SEP = 0  # Padding / separator
EOS = 1  # End of sentence


def preprocess(stream):
    """Concatenate each (article, summary) pair into one LM training example.

    For every pair the generator yields ``(joint, joint, mask)``:

    * ``joint`` = ``article + [EOS, SEP] + summary + [EOS]`` as a numpy array;
      the same array serves as input and target for the language model.
    * ``mask`` is 0 over the article and its trailing EOS/SEP, and 1 over the
      summary and its closing EOS, so only summary positions are weighted.

    Args:
        stream: iterable of ``(article, summary)`` token-id sequences. Each
            sequence may be any iterable (list, array, or one-shot iterator).

    Yields:
        ``(joint, joint, mask)`` numpy arrays of equal length.
    """
    for article, summary in stream:
        # Materialize once: re-calling list() on a one-shot iterator would
        # return [] and desynchronize the mask from `joint`.
        article_ids = list(article)
        summary_ids = list(summary)

        joint = np.array(article_ids + [EOS, SEP] + summary_ids + [EOS])
        mask = np.array([0] * (len(article_ids) + 2) +
                        [1] * (len(summary_ids) + 1))
        yield joint, joint, mask


# Tokenize -> concatenate article/summary with `preprocess` -> filter out
# overly long examples (FilterByLength(2048)).
input_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords'),
    preprocess,
    trax.data.FilterByLength(2048)
)


train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# NOTE: this consumes the first example of `train_stream`; later consumers of
# the stream only see the examples after it.
train_input, train_target, train_mask = next(train_stream)

# In a language model (LM) the targets are the inputs themselves.
# np.array_equal also checks shapes instead of relying on broadcasting.
assert np.array_equal(train_input, train_target), \
    "inputs and targets should be identical for a language model"

In [5]:
# Show the loss mask: 0s over the article (plus EOS/SEP), 1s over the summary.
print('Single example mask:\n\n', train_mask)

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0

In [6]:
example_text = detokenize(train_input)
print('Single example:\n\n', example_text)

Single example:

 Fears are growing that Britain's jails are becoming a hotbed of
extremism after it was revealed today that nearly half the inmates of
one top security prison are Muslim. Some 42 per cent of those housed
at Category A Whitemoor jail - and more than a quarter of those in
London prisons - consider themselves to be of Islamic faith. Experts
now fear large numbers are being radicalised on the inside, where they
say the spread of Jihadist ideas is rife. Figures show more than a
quarter of inmates in London jails are Muslim, with one Category A
jail revealing 42 per cent of its convicts follow the Islamic faith .
Whitemoor inmate Zia Al Haq, left, was jailed for 18 for planning bomb
attacks in London while Nezar Hindawi, right, was handed a 45-year
sentence for plotting to blow up a jet . A source at Cambridgeshire
jail Whitemoor told the Sunday People: 'Whitemoor is now effectively
run by Muslims, many of whom are Jihadis.' A 2012 probe into the jail
branded it a 'Taliban r

In [7]:
# Bucketing to create batched generators.
# `batch_sizes` has one entry per bucket (len(boundaries) + 1); the last entry
# applies to sequences longer than the final boundary.
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

train_batch_stream, eval_batch_stream = (
    trax.data.BucketByLength(boundaries, batch_sizes)(stream)
    for stream in (train_stream, eval_stream)
)

In [9]:
# Pull one (inputs, targets, masks) batch; targets equal inputs for an LM.
first_batch = next(train_batch_stream)
input_batch, _, mask_batch = first_batch
input_batch.shape

(1, 1924)

In [10]:
# Raw integer token ids of the first example in the batch.
first_ids = input_batch[0]
print(first_ids)

[  202 14607  5318   861 15883  1353   452  1292 11146    76   130  8016
    19    90   196     3   200   148   130  1188  3768   213  1483   527
   607   678     2    90  7511  1171 19941     2   130  1188  1475    61
    38  5799  1019   156     3    13  2492    61  7084   444   110   285
   130   296  1793     7    26  1325   320  3273   132   566  1292   355
   213  8866 25727     4   527  2528  1480    40    46  1471    78   350
  1068   132    28  6294   214 22820 11214    21    76    28   238   527
  4131 14777 24302   691   213  5071   763  4945     3   577  7511   213
   350  1034  9423   248 24792   119  2063   132  2045     2   130  1188
   133  1084   285    13  1353  7087  7033    47   809  2946   180   286
   320  2040    38   150  4394  3771     3    13   540   320 14657  5849
 15037   111   130  1188    76   186    28    60    76    41   547   213
  1282   132    31 11261     3    52  1353   213   669 27634     4 20096
  7086 27634   391  4394    76  4872    28   666  2

In [11]:
# Decode the first example: the article followed by its summary.
decoded_article = detokenize(input_batch[0])
print('Article:\n\n', decoded_article)

Article:

 (CNN) -- My mom was always sports mad -- my dad not so much. But both
my parents understood the importance of significant events, so when
opportunity knocked, my parents opened up all doors for me. I grew up
knowing full well that my country wasn't allowed to participate in
international sports following the sporting boycott of 1977 which had
been placed on South Africa in a protest against apartheid -- a system
of racial segregation enforced by the ruling white minority. So when
the South African rugby team toured New Zealand in 1981, my parents
made sure that I was awoken at 04:30 to watch all three Test matches.
I got to snuggle between my parents -- and a first -- they put the
television in their bedroom. It was the "flour bomb" Test -- where a
light aircraft flew over Auckland's Eden Park before and during the
match, dropping flour bombs, smoke bombs and anti-apartheid leaflets
-- where I began to realize that other countries didn't like us. I
didn't understand why at t

In [12]:
def create_tensor(t):
    """Convert ``t`` (nested lists or an array) into a ``jnp`` array."""
    tensor = jnp.array(t)
    return tensor


def display_tensor(t, name):
    """Print ``name`` with the shape of ``t``, then ``t`` itself.

    Each of the two lines is followed by a blank line.
    """
    print(f'{name} shape: {t.shape}\n', f'{t}\n', sep='\n')

In [18]:
def DotProductAttention(query, key, value, mask):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(depth)) @ value`` with the softmax
    over the last axis, where ``depth = query.shape[-1]``.

    Args:
        query: array of shape (..., len_q, depth).
        key: array of shape (..., len_k, depth).
        value: array of shape (..., len_k, depth).
        mask: boolean array broadcastable to the (..., len_q, len_k) scores,
            or None for no masking. Where the mask is False the score is
            replaced by -1e9, so that position receives ~zero weight.

    Returns:
        Array of shape (..., len_q, depth).

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """
    # The original `assert ..., ` with a dangling comma was a SyntaxError.
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], (
        "query, key and value must have the same last dimension")
    depth = query.shape[-1]

    # Standard scaled dot-product scaling by sqrt(depth).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax written as exp(x - logsumexp(x)) to avoid overflow in exp.
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
    weights = jnp.exp(dots - logsumexp)
    return jnp.matmul(weights, value)

In [21]:
def compute_attention_heads_closure(n_heads, d_head):
    """Return a function that splits a projection into attention heads.

    The returned function maps an array of shape
    (batch_size, seqlen, n_heads * d_head) to
    (batch_size * n_heads, seqlen, d_head). Heads are stacked along axis 0:
    all heads of example 0 first, then all heads of example 1, and so on.
    """

    def compute_attention_heads(x):
        """Split ``x`` into per-head slices stacked along the first axis."""
        batch_size, seqlen = x.shape[0], x.shape[1]
        per_head = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
        heads_first = jnp.transpose(per_head, (0, 2, 1, 3))
        return jnp.reshape(heads_first, (-1, seqlen, d_head))

    return compute_attention_heads

In [23]:
def dot_product_self_attention(q, k, v):
    """Causal (masked) dot-product self-attention.

    Builds a lower-triangular boolean mask of shape (1, seq_len, seq_len),
    with ``seq_len = q.shape[1]``, so position i attends only to positions
    <= i. It then delegates to ``DotProductAttention``.
    """
    seq_len = q.shape[1]
    causal_mask = jnp.tril(
        jnp.ones((1, seq_len, seq_len), dtype=jnp.bool_), k=0)
    return DotProductAttention(q, k, v, causal_mask)

In [25]:
def compute_attention_output_closure(n_heads, d_head):
    """Return a function that merges per-head outputs back together.

    The returned function maps (batch_size * n_heads, seqlen, d_head) to
    (batch_size, seqlen, n_heads * d_head). It undoes the head split done by
    ``compute_attention_heads_closure``.
    """

    def compute_attention_output(x):
        """Concatenate the heads of ``x`` along the feature axis."""
        seqlen = x.shape[1]
        grouped = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        seq_first = jnp.transpose(grouped, (0, 2, 1, 3))
        return jnp.reshape(seq_first, (-1, seqlen, n_heads * d_head))

    return compute_attention_output

In [29]:
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Multi-headed causal self-attention as a trax layer.

    The input goes through three independent Dense(d_feature) projections
    (queries, keys, values). Each projection is split into ``n_heads`` heads
    of size ``d_feature // n_heads``. Causal attention is applied, the heads
    are merged, and a final Dense(d_feature) mixes them.

    Args:
        d_feature: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        compute_attention_heads_closure: factory ``(n_heads, d_head) -> fn``
            that splits a projection into heads.
        dot_product_self_attention: ``(q, k, v) -> attention`` function.
        compute_attention_output_closure: factory ``(n_heads, d_head) -> fn``
            that merges heads back together.
        mode: accepted for interface compatibility; unused in this body.

    Returns:
        A ``tl.Serial`` layer.
    """
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    split_heads = tl.Fn('AttnHeads',
                        compute_attention_heads_closure(n_heads, d_head),
                        n_out=1)
    # Query, key and value branches share a shape but each has its own Dense.
    qkv_projections = tl.Branch(
        *[[tl.Dense(d_feature), split_heads] for _ in range(3)])

    return tl.Serial(
        qkv_projections,
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1),
        tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1),
        tl.Dense(d_feature),
    )

In [30]:
causal_attention_layer = CausalAttention(d_feature=512, n_heads=8)
print(causal_attention_layer)

Serial[
  Branch_out3[
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
  ]
  DotProductAttn_in3
  AttnOutput
  Dense_512
]


In [31]:
def DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation):
    """Build one Transformer decoder block as a list of two residual layers.

    1. Residual(LayerNorm -> CausalAttention -> Dropout)
    2. Residual(LayerNorm -> Dense(d_ff) -> ff_activation -> Dropout
                -> Dense(d_model) -> Dropout)

    Args:
        d_model: model width.
        d_ff: hidden width of the feed-forward sub-layer.
        n_heads: number of attention heads.
        dropout: dropout rate used by every Dropout layer.
        mode: trax mode string passed to attention and dropout.
        ff_activation: activation layer class, instantiated once.

    Returns:
        A list of two ``tl.Residual`` layers.
    """
    def make_dropout():
        return tl.Dropout(rate=dropout, mode=mode)

    attention = CausalAttention(d_model, n_heads=n_heads, mode=mode)

    feed_forward = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),
        make_dropout(),
        tl.Dense(d_model),
        make_dropout(),
    ]

    attention_sublayer = tl.Residual(tl.LayerNorm(), attention, make_dropout())
    feed_forward_sublayer = tl.Residual(feed_forward)
    return [attention_sublayer, feed_forward_sublayer]

In [32]:
decoder_block_layers = DecoderBlock(d_model=512, d_ff=2048, n_heads=8,
                                    dropout=0.1, mode='train',
                                    ff_activation=tl.Relu)
print(decoder_block_layers)

[Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Serial[
        Branch_out3[
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
        ]
        DotProductAttn_in3
        AttnOutput
        Dense_512
      ]
      Dropout
    ]
  ]
  Add_in2
], Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Dense_2048
      Relu
      Dropout
      Dense_512
      Dropout
    ]
  ]
  Add_in2
]]


In [37]:
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Build a decoder-only Transformer language model.

    Pipeline: ShiftRight -> Embedding -> Dropout -> PositionalEncoding ->
    ``n_layers`` decoder blocks -> LayerNorm -> Dense(vocab_size) ->
    LogSoftmax. The output is log-probabilities over the vocabulary at each
    position.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        d_ff: feed-forward hidden width.
        n_layers: number of decoder blocks.
        n_heads: attention heads per block.
        dropout: dropout rate.
        max_len: maximum length supported by the positional encoding.
        mode: trax mode string ('train', 'eval', ...).
        ff_activation: activation layer class for the feed-forward sub-layers.

    Returns:
        A ``tl.Serial`` model.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    decoder_stack = []
    for _ in range(n_layers):
        decoder_stack.append(
            DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        embedding_stack,
        decoder_stack,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )


In [38]:
one_layer_lm = TransformerLM(n_layers=1)
print(one_layer_lm)

Serial[
  ShiftRight(1)
  Embedding_33300_512
  Dropout
  PositionalEncoding
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Serial[
          Branch_out3[
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
          ]
          DotProductAttn_in3
          AttnOutput
          Dense_512
        ]
        Dropout
      ]
    ]
    Add_in2
  ]
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Dense_2048
        Relu
        Dropout
        Dense_512
        Dropout
      ]
    ]
    Add_in2
  ]
  LayerNorm
  Dense_33300
  LogSoftmax
]


In [39]:
from trax.supervised import training

def training_loop(TransformerLM, train_gen, eval_gen, output_dir="~/model"):
    """Create a trax ``training.Loop`` for a tiny Transformer LM.

    Args:
        TransformerLM: model factory. It is called with d_model=4, d_ff=16,
            n_layers=1, n_heads=2, mode='train'.
        train_gen: labeled training batch stream.
        eval_gen: labeled evaluation batch stream.
        output_dir: checkpoint directory; ``~`` is expanded.

    Returns:
        A ``training.Loop``. It trains with CrossEntropyLoss, Adam, and a
        warmup/rsqrt-decay LR schedule, checkpoints every 10 steps, and
        evaluates CrossEntropyLoss and Accuracy.
    """
    resolved_dir = os.path.expanduser(output_dir)
    schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)

    train_task = training.TrainTask(
        labeled_data=train_gen,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.01),
        lr_schedule=schedule,
        n_steps_per_checkpoint=10,
    )
    eval_task = training.EvalTask(
        labeled_data=eval_gen,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    )

    small_model = TransformerLM(d_model=4, d_ff=16, n_layers=1, n_heads=2,
                                mode='train')
    return training.Loop(small_model, train_task,
                         eval_tasks=[eval_task], output_dir=resolved_dir)

In [40]:
# Delete any checkpoint in ~/model (training_loop's default output_dir) so the
# run starts fresh.
# NOTE(review): presumably trax's Loop would otherwise resume from that
# checkpoint — confirm.
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
# Train for 10 steps; train/eval metrics are logged at steps 1 and 10.
loop.run(10)


Step      1: Ran 1 train steps in 9.16 secs
Step      1: train CrossEntropyLoss |  10.41526318
Step      1: eval  CrossEntropyLoss |  10.41306114
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 60.16 secs
Step     10: train CrossEntropyLoss |  10.41435432
Step     10: eval  CrossEntropyLoss |  10.41310978
Step     10: eval          Accuracy |  0.00000000


In [41]:
# Build the model for inference and load its weights (weights_only=True).
# NOTE(review): this reads 'model.pkl.gz' from the current directory, not the
# ~/model checkpoint written by training_loop. It also uses TransformerLM's
# default sizes (d_model=512, n_layers=6, ...), not the tiny d_model=4 config
# trained above — presumably a separately provided pretrained checkpoint.
model = TransformerLM(mode='eval')
model.init_from_file('model.pkl.gz', weights_only=True)

In [44]:
def next_symbol(cur_output_tokens, model):
    """Return the most likely next token id for a partial sequence.

    The tokens are right-padded with zeros to ``2**ceil(log2(n + 1))`` (always
    more than ``n`` tokens), given a batch axis of size 1, and passed to
    ``model`` as an ``(inputs, targets)`` pair. ``TransformerLM`` shifts its
    input right, so the log-probabilities at index ``n`` predict the token
    that follows the ``n`` given ones.

    Args:
        cur_output_tokens: list of int token ids produced so far.
        model: callable taking ``(inputs, targets)`` and returning a pair whose
            first element is indexed as [batch, position, vocab].

    Returns:
        The argmax token id as a Python int.
    """
    n_tokens = len(cur_output_tokens)
    # Power-of-two padding; presumably to limit distinct input shapes.
    target_len = 2 ** int(np.ceil(np.log2(n_tokens + 1)))

    padded = np.array(cur_output_tokens + [0] * (target_len - n_tokens))[None, :]
    output, _ = model((padded, padded))
    return int(np.argmax(output[0, n_tokens, :]))

In [46]:
def greedy_decode(input_sentence, model, max_output_len=None, verbose=True):
    """Greedily generate a continuation (e.g. a summary) of ``input_sentence``.

    The sentence is tokenized (``tokenize`` appends EOS) and followed by the
    separator id ``SEP``. This mirrors the ``article + [EOS, SEP]`` layout
    built by ``preprocess``. Tokens are then appended one at a time using
    ``next_symbol``. Decoding stops when EOS is produced or when
    ``max_output_len`` tokens have been generated.

    Args:
        input_sentence: text to condition on.
        model: model passed through to ``next_symbol``.
        max_output_len: optional cap on generated tokens. None (the default)
            decodes until EOS, as before; set it to guard against a model
            that never emits EOS.
        verbose: if True (the default), print the detokenized output after
            every step.

    Returns:
        The detokenized generated text, including EOS if it was produced.
    """
    cur_output_tokens = tokenize(input_sentence) + [SEP]
    generated_output = []

    while max_output_len is None or len(generated_output) < max_output_len:
        cur_output = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)
        if verbose:
            print(detokenize(generated_output))
        if cur_output == EOS:
            break

    return detokenize(generated_output)

In [47]:
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
decoded = greedy_decode(test_sentence, model)
print(decoded)

It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips. 

:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>


In [None]:
article = "It’s the posing craze sweeping the U.S. after being brought to fame by skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert Pujols - and even Republican politician Rick Perry. But now four students at Riverhead High School on Long Island, New York, have been suspended for dropping to a knee and taking up a prayer pose to mimic Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were all suspended for one day because the ‘Tebowing’ craze was blocking the hallway and presenting a safety hazard to students. Scroll down for video. Banned: Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured left) were all suspended for one day by Riverhead High School on Long Island, New York, for their tribute to Broncos quarterback Tim Tebow. Issue: Four of the pupils were suspended for one day because they allegedly did not heed to warnings that the 'Tebowing' craze at the school was blocking the hallway and presenting a safety hazard to students."
print(wrapper.fill(article), '\n')
summary = greedy_decode(article, model)
print(summary)

**Expected Output:**
```text
Jordan
Jordan Ful
Jordan Fulcol
Jordan Fulcoly
Jordan Fulcoly,
Jordan Fulcoly, Wayne
Jordan Fulcoly, Wayne Dre
Jordan Fulcoly, Wayne Drexe
Jordan Fulcoly, Wayne Drexel
Jordan Fulcoly, Wayne Drexel,
.
.
.

Final summary:

Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students.<EOS>
```