In [1]:
import warnings
warnings.filterwarnings('ignore')
import random

import numpy as _np
import tensorflow as tf
from mxnet import np, npx
from gluonnlp.models import get_backbone
from gluonnlp.models.t5 import T5NMTInference
from gluonnlp.sequence_sampler import BeamSearchSampler

npx.set_np()

In [2]:
# set
random.seed(0)
_np.random.seed(0)
np.random.seed(0)
npx.random.seed(0)

# def noise_span_to_unique_sentinel(tokens, noise_mask, vocab_size):
#     prev_token_is_noise = tf.pad(noise_mask[:-1], [[1, 0]])
#     first_noise_tokens = tf.logical_and(
#         noise_mask, tf.logical_not(prev_token_is_noise))
#     subsequent_noise_tokens = tf.logical_and(noise_mask, prev_token_is_noise)
#     sentinel = vocab_size - tf.cumsum(tf.cast(first_noise_tokens, tokens.dtype))
#     tokens = tf.where(first_noise_tokens, sentinel, tokens)
#     return tf.boolean_mask(tokens, tf.logical_not(subsequent_noise_tokens))

In [3]:
T5Model, cfg, tokenizer, local_params_path, _ = get_backbone('google_t5_small')

In [4]:
backbone = T5Model.from_cfg(cfg)
backbone.load_parameters(local_params_path)
t5_nmt = T5NMTInference(backbone)
t5_searcher = BeamSearchSampler(4, t5_nmt, eos_id=1, vocab_size=32128)

In [5]:
text = 'Thank you for inviting me to your party last week .'
tokens = tokenizer.encode(text, int)
print(tokens)
print(tokenizer.encode(text, str))

[1562, 25, 21, 14256, 140, 12, 39, 1088, 336, 471, 3, 5]
['▁Thank', '▁you', '▁for', '▁inviting', '▁me', '▁to', '▁your', '▁party', '▁last', '▁week', '▁', '.']


In [6]:
# tokens = tf.constant(tokens)
# masks = tf.constant(masks)
# assert(tokens.shape == masks.shape)
# tokens = noise_span_to_unique_sentinel(tokens, masks, 32100)

src_tokens = np.array([[1562, 25, 32099, 12, 32098, 471, 3, 5]])
src_valid_length = np.array([src_tokens.shape[1]])
tgt_token = np.array([32099])

In [7]:
states = t5_nmt.init_states(src_tokens, src_valid_length)

In [8]:
res = t5_searcher(tgt_token, states, src_valid_length)

In [9]:
res

(array([[[32099, 32098,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,     5,     3,     5,
              3,     5,     3,     5,     3,     5,     3,     5,     3,
              5,     3,     5,     3,     5,     3,  1562,  1562,  1562,
           1562,  1562,  1562,  1562,  1562,  1562,