In [1]:
# Environment setup for the T5 beam-search demo: imports plus process-wide settings.
import warnings
# Silence all Python warnings for the rest of the session so cell output stays compact.
warnings.filterwarnings('ignore')
import random

# NumPy proper is aliased to `_np` because the bare name `np` is bound to MXNet's
# NumPy-compatible module by the `from mxnet import np, npx` line below.
# NOTE(review): `random`, `_np` and `tf` are not referenced by any cell visible here —
# confirm nothing else needs them before pruning.
import numpy as _np
import tensorflow as tf
from mxnet import np, npx
from gluonnlp.data.batchify import Pad
from gluonnlp.models import get_backbone
from gluonnlp.models.t5 import T5Seq2seq, mask_to_sentinel
from gluonnlp.sequence_sampler import BeamSearchSampler

# Switch MXNet into its NumPy-compatible array mode (see the MXNet `npx.set_np` docs).
npx.set_np()
# Infinite threshold: arrays are printed in full instead of being summarised, which is
# why the beam-search result further down is dumped element by element.
np.set_printoptions(threshold=np.inf)

In [2]:
# Look up the 'google_t5_base' backbone. Of the five values returned, the notebook keeps
# the model class, its config, the tokenizer and the local path of the parameter file;
# the fifth value is discarded.
T5Model, cfg, tokenizer, local_params_path, _ = get_backbone('google_t5_base')

In [3]:
# Instantiate the backbone from its config and load weights from the local parameter file.
backbone = T5Model.from_cfg(cfg)
backbone.load_parameters(local_params_path)
# Wrap the backbone in the seq2seq interface that the sampler below drives.
t5_nmt = T5Seq2seq(backbone)
t5_nmt.hybridize()
# Beam-search sampler over the wrapped model; the leading 4 matches the four candidates
# decoded per input in the final cell.
# NOTE(review): eos_id=1 and vocab_size=32128 are hard-coded — presumably T5's
# end-of-sequence id and the model's output vocabulary size. This is a different
# quantity from the `len(tokenizer.vocab)` passed to mask_to_sentinel later; confirm
# both against `cfg` / the tokenizer.
t5_searcher = BeamSearchSampler(4, t5_nmt, eos_id=1, vocab_size=32128)

In [4]:
# Batchify helper applied to the variable-length masked token lists further down; it is
# configured with fill value 0 and int32 output.
# NOTE(review): presumably pads every row to the longest sequence — confirm in Pad's docs.
batcher = Pad(val=0, dtype=np.int32)

In [5]:
# Example sentences to be partially masked; spacing around punctuation is kept verbatim.
text = [
    'Let us realize the arc of the moral universe is long, but it bends toward justice . ',
    'The road to success and the road to failure are almost exactly the same . ',
    'Four score and seven years ago our fathers brought forth , upon this continent , a new nation , conceived in Liberty , and dedicated to the proposition that all men are created equal .',
]
# Encode each sentence to integer ids, then report how many ids each sentence produced.
tokens = tokenizer.encode(text, int)
print(list(map(len, tokens)))

[21, 16, 43]


In [6]:
def _span_mask(length, masked_positions):
    """Return a 0/1 list of ``length`` entries with 1 at every index in ``masked_positions``.

    Positions outside ``range(length)`` are ignored, matching a plain
    ``[1 if i in positions else 0 for i in range(length)]`` comprehension.
    """
    masked = set(masked_positions)
    return [1 if i in masked else 0 for i in range(length)]


# One noise mask per sentence, aligned token-for-token with `tokens`. Lengths are read
# from the encoded sentences rather than copied by hand from the printed output, so the
# masks stay aligned if the text or tokenizer changes.
mask1 = _span_mask(len(tokens[0]), [6, 7, 8])
mask2 = _span_mask(len(tokens[1]), [9, 10])
mask3 = _span_mask(len(tokens[2]), [7, 8, 9, 23, 24, 35, 36, 37])

In [7]:
# Apply the three 0/1 masks to the encoded sentences, passing the tokenizer's vocabulary
# size. Only the first of the two returned values is kept.
# NOTE(review): presumably each run of masked positions is replaced by one sentinel id
# taken from the top of the vocabulary — confirm in gluonnlp.models.t5.mask_to_sentinel.
masked_tokens, _ = mask_to_sentinel(tokens, [mask1, mask2, mask3], len(tokenizer.vocab))

In [8]:
# Unpadded length of every masked sequence, stored as an int32 array.
enc_valid_length = np.array(list(map(len, masked_tokens)), dtype=np.int32)
# Combine the ragged token lists into one batch using the Pad helper.
enc_tokens = batcher(masked_tokens)

In [9]:
# Build the initial decoding states from the padded encoder inputs and their unpadded
# lengths; the result is handed to the beam-search sampler in the next cell.
states = t5_nmt.init_states(enc_tokens, enc_valid_length)

In [10]:
# Run beam search. The first argument is an all-zero array with one entry per input
# sequence (same shape/dtype as enc_valid_length).
# NOTE(review): presumably 0 is the decoder start-token id — confirm.
res = t5_searcher(np.zeros_like(enc_valid_length), states, enc_valid_length)
# Display the raw result tuple; with the unlimited print threshold set at the top of the
# notebook this dumps every generated id.
res

(array([[[    0,     3,     5,    37,     3,  4667,    13,     8,     3,
           4667,    13,     8,     3,  4667,    13,     8,     3,  4667,
             13,     8,     3,  4667,    13,     8,     3,  4667,    13,
              8,     3,  4667,    13,     8,     3,  4667,    13,     8,
              3,  4667,    13,     8,     3,  4667,    13,     8,     3,
           4667,    13,     8,     3,  4667,    13,     8,     3,  4667,
             13,     8,     3,  4667,    13,     8,     3,  4667,    13,
              8,     3,  4667,    13,     8,     3,  4667,    13,     8,
              3,  4667,    13,     8,     3,  4667,    13,     8,     3,
           4667,    13,     8,  6923,     3,     5,    37,  6923,     3,
              5,    37,  6923,     3,     5,  1563,   178,  3384,     8,
              3,  4667,    13,     8,  6923,     3,     5,    37,  6923,
              8,  6923,     8,  6923,     8,  6923,     1],
         [    0,     3,     5,    37,     3,  4667,    13,     8

In [11]:
# Decode and print every beam-search candidate for every input sentence.
# `samples` is the first element of the sampler output and is indexed below as
# (batch, beam, position).
samples = res[0]
# Read the number of candidates from the result instead of repeating the beam size that
# was passed to the sampler's constructor.
num_beams = samples.shape[1]
for batch in range(enc_tokens.shape[0]):
    print('#### ANOTHER SAMPLE ###')
    # NOTE(review): candidates are cut to the *encoder* input length, not to the length
    # the sampler actually generated — confirm whether the sampler's own valid-length
    # output should be used here instead.
    keep = enc_valid_length[batch].item()
    for i in range(num_beams):
        token_ids = samples[batch, i, :keep]
        # Clamp ids into [0, 31099] before decoding; ids above that bound (apparently the
        # sentinel/extra tokens) all collapse onto a single ordinary piece in the output.
        token_ids = np.clip(token_ids, 0, 31099).tolist() # TODO(yongyi-wu): handle out-of-range extra tokens
        print(tokenizer.decode(token_ids), end='\n\n')

#### ANOTHER SAMPLE ###
. The arc of the arc of the arc of the arc of

. The arc of the arc of the arc of the arc of

. The arc of the arc of the arc of the arc of

. The arc of the arc of the arc of the arc of

#### ANOTHER SAMPLE ###
tufted aretufted aretufted the same . . . 

. The road to success are the same . . 

. The road to failure are the same . . 

tufted aretufted aretufted the same . . . 

#### ANOTHER SAMPLE ###
tufted, tufted that alltufted nation ,tufted, tufted that alltufteds , andtufteds , andtufted, tufted, tufted that alltufted,

tufted, tufted that alltufted nation ,tufted, tufted that alltufteds , andtufteds , andtufted, tufted, tufted that alltufted,

tufted, tufted that alltufted nation ,tufted, tufted that alltufteds , andtufteds , andtufted, tufted, tufted that alltufted,

tufted, tufted that alltufted nation ,tufted, tufted that alltufteds , andtufteds ,tufted, tufted, tufted that alltufted, 

