# Using T5 for Masked Language Modeling (MLM) Task

## Introduction

As described in the T5 paper [1], the authors pre-train the model with an objective called "random spans". As the name suggests, this method corrupts the input sequence by spans of tokens rather than individual tokens. Each noise span is then replaced by a unique sentinel, starting from `<extra_id_0>`, and the objective is to reconstruct the dropped spans. Because consecutive corrupted tokens are handled together as one span, this method yields a speed-up over a BERT-style objective while retaining performance.

For example, suppose the original sentence is "Thank you for inviting me to your party last week". After applying random spans, we may get "Thank you for \<extra_id_0\> to \<extra_id_1\> week", for which the target is "\<extra_id_0\> inviting me \<extra_id_1\> your party last \<extra_id_2\>". The final sentinel marks the end of the target.
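
The mapping from noise spans to a corrupted input and a target can be sketched in a few lines of plain Python. This is only an illustration: `corrupt_spans` is a hypothetical helper, it works on whitespace-separated words rather than subword token ids, and the span positions are chosen by hand.

```python
def corrupt_spans(words, spans):
    """Replace each span of words with a sentinel and build the matching target.

    `spans` is a list of (start, end) pairs, end exclusive, sorted and
    non-overlapping. Hypothetical helper for illustration only.
    """
    corrupted, target = [], []
    cursor = 0
    for n, (start, end) in enumerate(spans):
        sentinel = '<extra_id_{}>'.format(n)
        corrupted.extend(words[cursor:start])  # keep the uncorrupted words
        corrupted.append(sentinel)             # one sentinel per noise span
        target.append(sentinel)
        target.extend(words[start:end])        # dropped words move to the target
        cursor = end
    corrupted.extend(words[cursor:])
    target.append('<extra_id_{}>'.format(len(spans)))  # closing sentinel
    return ' '.join(corrupted), ' '.join(target)


words = 'Thank you for inviting me to your party last week'.split()
corrupted, target = corrupt_spans(words, [(3, 5), (6, 9)])
print(corrupted)  # Thank you for <extra_id_0> to <extra_id_1> week
print(target)     # <extra_id_0> inviting me <extra_id_1> your party last <extra_id_2>
```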

In this tutorial, we are going to: 
1. Load a pretrained T5 model
2. Perform the MLM task on some examples

## Preliminaries

To begin, we load the required packages. We also set the random seeds so that you can replicate our results on your machine. In addition, we introduce two helper functions: one that builds the masks for the corrupted input, and one that converts the model output into human-readable text. 

In [1]:
import itertools
import random
import warnings
warnings.filterwarnings('ignore')

import numpy as _np
from mxnet import np, npx
from gluonnlp.data.batchify import Pad
from gluonnlp.models import get_backbone
from gluonnlp.models.t5 import T5Seq2seq, mask_to_sentinel
from gluonnlp.sequence_sampler import BeamSearchSampler

In [2]:
npx.set_np()

random.seed(0)
np.random.seed(0)
npx.random.seed(0)
_np.random.seed(0)

In [3]:
def spans_to_masks(tokenizer, tokens, spans): 
    def _spans_to_masks(tokens, spans): 
        if isinstance(tokens[0], int): 
            masks = [0] * len(tokens)
            for i, span in enumerate(spans): 
                target = []
                for idx in span: 
                    assert 0 <= idx and idx < len(tokens), 'Span index out of range'
                    masks[idx] = 1
                    target.append(tokens[idx])
                # sentinels are the last vocabulary entries, so index -1 - i is <extra_id_i>
                print('{}: {}'.format(tokenizer.vocab.to_tokens(-1 - i), tokenizer.decode(target)), end='\t')
            print()
            return masks
        elif isinstance(tokens[0], list): 
            assert len(tokens) == len(spans), 'Every sample must have corresponding tokens and spans'
            res = []
            for i, (tok, s) in enumerate(zip(tokens, spans)): 
                print('[Sample {}]'.format(i), end='\t')
                res.append(_spans_to_masks(tok, s))
            return res
        else: 
            raise TypeError('Unsupported type of tokens: {}'.format(type(tokens)))
    return _spans_to_masks(tokens, spans)


def print_inference(tokenizer, output, spans): 
    for sample in range(output[0].shape[0]): 
        print('[SAMPLE {}]'.format(sample))
        n_spans = len(spans[sample])
        for beam in range(output[0].shape[1]): 
            print('(Beam {})'.format(beam), end='\t')
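            # NOTE: reads the global `enc_valid_length` defined later in this tutorial
            ele_output = output[0][sample, beam, :enc_valid_length[sample].item()]
            for n in range(n_spans): 
                # locate sentinel <extra_id_n>; sentinels are the last vocabulary entries
                i = np.where(ele_output == len(tokenizer.vocab) - 1 - n)[0][0].item()
                # assume the predicted span has as many tokens as the original one
                j = i + 1 + len(spans[sample][n])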
                ele_tokens = ele_output[i + 1:j].tolist()
                print('<extra_id_{}>: {}'.format(n, tokenizer.decode(ele_tokens)), end='\t')
            print()
        print()

## Model Loading

`get_backbone` is a handy way to load models and download pretrained weights (not limited to T5) from the GluonNLP repository. Here, we choose T5-large for illustration purposes. Alternatively, you can use `google_t5_small`, `google_t5_base`, `google_t5_3B`, or `google_t5_11B`. 

`T5Seq2seq` is an inference model equipped with the incremental decoding feature. Notice that it must be initialized with a `T5Model` instance. 

In [4]:
T5Model, cfg, tokenizer, local_params_path, _ = get_backbone('google_t5_large')
backbone = T5Model.from_cfg(cfg)
backbone.load_parameters(local_params_path)
t5mlm = T5Seq2seq(backbone)
t5mlm.hybridize()

For this MLM task, we will also leverage `BeamSearchSampler`, a powerful and easy-to-use tool that is useful in many scenarios. 

In [5]:
beam_size = 4
t5mlm_seacher = BeamSearchSampler(beam_size, t5mlm, eos_id=1, vocab_size=32128)

Since the output of our tokenizer is a Python list (or a nested Python list), samples can have different lengths. This allows more flexibility in manipulating intermediate results, but requires an additional step before feeding them into the model. `Pad` helps us group multiple samples into a single ndarray batch in a clean way. 

In [6]:
batcher = Pad(val=0, dtype=np.int32)
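
To make the padding step concrete, here is a minimal plain-Python sketch of the idea. It is not GluonNLP's implementation: `pad_batch` is a hypothetical helper, the token ids are made up, and the result is a nested list rather than an ndarray.

```python
def pad_batch(samples, pad_val=0):
    """Right-pad ragged lists of token ids to the length of the longest sample."""
    max_len = max(len(s) for s in samples)
    return [s + [pad_val] * (max_len - len(s)) for s in samples]


ragged = [[13, 7, 1], [5, 1], [9, 4, 2, 1]]    # made-up token ids
padded = pad_batch(ragged)
valid_length = [len(s) for s in ragged]        # true lengths before padding
print(padded)        # [[13, 7, 1, 0], [5, 1, 0, 0], [9, 4, 2, 1]]
print(valid_length)  # [3, 2, 4]
```

Because padding hides the true lengths, the valid length of every sample has to be recorded separately, which is why the tutorial tracks it alongside the batch.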

## A Toy Example

In this tutorial, we simply use a minibatch of two samples. We can inspect the tokenization result by passing `str` as the second argument. Notice that the tokenizer itself does not add EOS tokens, `</s>`, to the end of sequences. We leave the flexibility and responsibility to the user. 

In [7]:
text = [
    'Andrew Carnegie famously said , " My heart is in the work . " At CMU , we think about work a little differently .', 
    'Peace is a concept of societal friendship and harmony in the absence of hostility and violence . In a social sense , peace is commonly used to mean a lack of conflict and freedom from fear of violence between individuals or groups .' 
]
tokens = tokenizer.encode(text, int)
for ele_tokens in tokens: 
    ele_tokens.append(1) # append EOS token: </s>

In [8]:
for ele_tokens in tokenizer.encode(text, str): 
    print(ele_tokens)

['▁Andrew', '▁Carnegie', '▁famous', 'ly', '▁said', '▁', ',', '▁"', '▁My', '▁heart', '▁is', '▁in', '▁the', '▁work', '▁', '.', '▁"', '▁At', '▁C', 'MU', '▁', ',', '▁we', '▁think', '▁about', '▁work', '▁', 'a', '▁little', '▁differently', '▁', '.']
['▁Peace', '▁is', '▁', 'a', '▁concept', '▁of', '▁', 'societal', '▁friendship', '▁and', '▁harmony', '▁in', '▁the', '▁absence', '▁of', '▁host', 'ility', '▁and', '▁violence', '▁', '.', '▁In', '▁', 'a', '▁social', '▁sense', '▁', ',', '▁peace', '▁is', '▁commonly', '▁used', '▁to', '▁mean', '▁', 'a', '▁lack', '▁of', '▁conflict', '▁and', '▁freedom', '▁from', '▁fear', '▁of', '▁violence', '▁between', '▁individuals', '▁or', '▁groups', '▁', '.']


For illustration purposes, we define the noise spans manually, although technically this should be a random process. Notice that every noise span is a tuple of token indices. 

In [9]:
noise_spans = [
    [(11, 12, 13)], # sequence 1
    [(4, 5, 6), (28, 29, 30), (46, 47, 48)] # sequence 2
]
masks = spans_to_masks(tokenizer, tokens, noise_spans)

[Sample 0]	<extra_id_0>: in the work	
[Sample 1]	<extra_id_0>: concept of 	<extra_id_1>: peace is commonly	<extra_id_2>: individuals or groups	
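
The spans here are hand-picked. If you would rather draw them at random, one simplified option is sketched below. It is an assumption-laden illustration and not the sampling scheme from the T5 codebase: `sample_noise_spans` is a hypothetical helper, all spans share one fixed length, and `num_tokens` should exclude the trailing EOS token so that it is never masked.

```python
import random


def sample_noise_spans(num_tokens, num_spans, span_len, rng):
    """Sample `num_spans` non-overlapping, non-adjacent spans of `span_len` tokens.

    Returns a list of tuples of token indices, sorted by position.
    """
    free = num_tokens - num_spans * span_len + 1  # number of possible offsets
    if free < num_spans:
        raise ValueError('Sequence too short for the requested spans')
    offsets = sorted(rng.sample(range(free), num_spans))
    spans = []
    for i, off in enumerate(offsets):
        # Shifting by i * span_len leaves at least one kept token between spans.
        start = off + i * span_len
        spans.append(tuple(range(start, start + span_len)))
    return spans


rng = random.Random(0)  # local generator, so the global seeds are untouched
random_spans = sample_noise_spans(num_tokens=32, num_spans=2, span_len=3, rng=rng)
print(random_spans)
```

Keeping at least one unmasked token between spans matters because two touching spans would form a single run in the mask and therefore collapse into one sentinel.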


`mask_to_sentinel()` is adapted from `noise_span_to_unique_sentinel()` in the original [T5 repository](https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py). It maps noise spans to `<extra_id_*>` sentinels and collapses each span's tokens into a single sentinel. The inputs `tokens` and `masks` must have exactly the same shape. 
> For curious readers, there are many more useful implementations in T5's original repository. 

The preparation step is complete once we batchify the encoder input tokens (the corrupted sequences) and record the valid length of each sample. 

In [10]:
masked_tokens = mask_to_sentinel(tokens, masks, len(tokenizer.vocab))
enc_tokens = batcher(masked_tokens)
enc_valid_length = np.array([len(tok) for tok in masked_tokens], dtype=np.int32)
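
For intuition, the collapsing step can be sketched on a single sequence in plain Python. This is not the GluonNLP implementation: `collapse_masked_runs` is a hypothetical helper, the ids are made up, and it assumes that sentinel ids count down from the last vocabulary entry, as the helper functions in this tutorial do (`len(tokenizer.vocab) - 1 - n`).

```python
def collapse_masked_runs(tokens, masks, vocab_size):
    """Replace every run of masked tokens with a single sentinel id.

    The first run maps to `vocab_size - 1`, the second to `vocab_size - 2`, etc.
    """
    assert len(tokens) == len(masks), 'tokens and masks must align'
    out, n_runs, prev_masked = [], 0, False
    for tok, m in zip(tokens, masks):
        if m:
            if not prev_masked:                  # first token of a new run
                out.append(vocab_size - 1 - n_runs)
                n_runs += 1
            # remaining tokens of the same run are dropped
        else:
            out.append(tok)
        prev_masked = bool(m)
    return out


toy_tokens = [10, 11, 12, 13, 14, 15, 1]   # made-up ids; 1 plays the role of </s>
toy_masks = [0, 1, 1, 0, 0, 1, 0]
toy_corrupted = collapse_masked_runs(toy_tokens, toy_masks, vocab_size=100)
print(toy_corrupted)  # [10, 99, 13, 14, 98, 1]
```

The corrupted sequence is shorter than the original, which is why the valid lengths are computed from `masked_tokens` rather than from `tokens`.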

## Inference Step

We first get the initial states of the decoder by calling `init_states()`. Among other things, this method feeds the corrupted minibatch into the encoder and initializes the "past" keys and values of every decoder layer to zero. The returned states form a 4-tuple: the encoded results, the valid lengths of the corrupted sequences, our position (index) in incremental decoding, and the past keys and values. 

In [11]:
states = t5mlm.init_states(enc_tokens, enc_valid_length)

Then, we simply initiate the beam search with a `<pad>` token (token id = 0) for each sample. The beam search will leverage the incremental decoding implemented in `T5Seq2seq` to speed up the inference, though it may still take a while. 
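
If beam search is new to you, the toy version below may help. It is a conceptual sketch only and does not reflect how `BeamSearchSampler` is implemented: there is no batching, no cached decoder state and no length penalty, and `step_fn` is a hypothetical callback standing in for the model.

```python
import math


def beam_search(step_fn, start_id, eos_id, beam_size, max_len):
    """Keep the `beam_size` best partial sequences by summed log-probability.

    `step_fn(prefix)` returns a dict mapping each candidate next token id
    to its probability given the prefix.
    """
    beams = [([start_id], 0.0)]
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_id:            # finished beams are carried over
                candidates.append((seq, score))
                continue
            for tok, prob in step_fn(seq).items():
                candidates.append((seq + [tok], score + math.log(prob)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]       # prune to the best hypotheses
        if all(seq[-1] == eos_id for seq, _ in beams):
            break
    return beams


def toy_step(prefix):
    # Made-up distribution: after token 2 the "model" strongly prefers EOS (1).
    if prefix[-1] == 2:
        return {1: 0.9, 3: 0.1}
    return {2: 0.6, 3: 0.3, 1: 0.1}


best = beam_search(toy_step, start_id=0, eos_id=1, beam_size=2, max_len=5)
for seq, score in best:
    print(seq, round(score, 3))
```

With these toy probabilities, the two surviving hypotheses are `[0, 2, 1]` and `[0, 3, 2, 1]`, in that order.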

When the search is done, we can print the results nicely using our helper function and compare them with the masked tokens (see above). Happily, the pretrained T5-large gives reasonable (and sometimes perfect) guesses in our toy MLM task. 

In [12]:
output = t5mlm_seacher(np.zeros_like(enc_valid_length), states, enc_valid_length)
print_inference(tokenizer, output, noise_spans)

[SAMPLE 0]
(Beam 0)	<extra_id_0>: in the work	
(Beam 1)	<extra_id_0>: in work	
(Beam 2)	<extra_id_0>: in work	
(Beam 3)	<extra_id_0>: in work	

[SAMPLE 1]
(Beam 0)	<extra_id_0>: state of 	<extra_id_1>: peace is	<extra_id_2>: people state	
(Beam 1)	<extra_id_0>: state of 	<extra_id_1>: peace is	<extra_id_2>: people state	
(Beam 2)	<extra_id_0>: state of 	<extra_id_1>: peace is	<extra_id_2>: people state	
(Beam 3)	<extra_id_0>: state of 	<extra_id_1>: peace is	<extra_id_2>: people state	
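
Note that `print_inference` reads a fixed number of tokens after each sentinel, namely the length of the original span. An alternative is to split the output at the sentinels themselves, so that predictions of any length are shown in full. The sketch below assumes the decoder output follows the target format (a start token, then sentinel-delimited spans) and that sentinel ids count down from the last vocabulary entry; `split_by_sentinels` and the ids are hypothetical.

```python
def split_by_sentinels(output_ids, vocab_size, n_spans, eos_id=1):
    """Group decoder output ids by the sentinel that precedes them."""
    sentinels = {vocab_size - 1 - n: n for n in range(n_spans)}
    closing = vocab_size - 1 - n_spans   # sentinel that ends the last span
    predictions = {}
    current = None
    for tok in output_ids:
        if tok == eos_id or tok == closing:
            break                        # nothing useful after this point
        if tok in sentinels:
            current = sentinels[tok]     # a new span starts here
            predictions[current] = []
        elif current is not None:
            predictions[current].append(tok)
    return predictions


# Made-up ids with vocab_size=100: 0 is the start token,
# 99 and 98 open spans 0 and 1, 97 is the closing sentinel.
toy_output = [0, 99, 21, 22, 98, 23, 97, 1, 0, 0]
parsed = split_by_sentinels(toy_output, vocab_size=100, n_spans=2)
print(parsed)  # {0: [21, 22], 1: [23]}
```

Each list of ids can then be passed to `tokenizer.decode` to obtain the text of the corresponding span.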



## References

[1] Raffel, C., et al. "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". JMLR 2020