In [1]:
import os
import tempfile

from gluonnlp.base import get_repo_url
from gluonnlp.models.t5 import _build_t5_tokenizer
from gluonnlp.utils.misc import download

In [2]:
# Download the SentencePiece test model into a scratch directory and build the
# T5 tokenizer from it. The file is only needed while the tokenizer is built.
with tempfile.TemporaryDirectory() as dir_path:
    vocab_path = os.path.join(dir_path, 't5_spm.model')
    download(
        url=get_repo_url() + 'tokenizer_test_models/sentencepiece/case_t5/test_t5spm-5f05e7.model',
        path=vocab_path
    )
    # NOTE(review): the positional arguments are presumably do_lower=False and
    # extra_ids=100 -- confirm against the signature of _build_t5_tokenizer.
    tokenizer = _build_t5_tokenizer(vocab_path, False, 100)
    # No explicit os.remove(vocab_path) is needed: leaving the with-block
    # deletes the temporary directory together with everything inside it.

100%|██████████| 792k/792k [00:00<00:00, 17.0MiB/s]

Downloading /tmp/tmp2gil57ci/t5_spm.model from s3://gluonnlp-numpy-data/tokenizer_test_models/sentencepiece/case_t5/test_t5spm-5f05e7.model...





In [3]:
import random

import torch
import mxnet as mx
from mxnet import np, npx
import numpy as _np
from transformers import T5Model as HFT5
from transformers import T5Config as HFT5CFG
from gluonnlp.models.t5 import T5Model as GLT5
from gluonnlp.models.t5 import google_t5_small as GLT5CFG

# Seed every random number generator used in this notebook with the same value
# so that repeated runs are comparable.
# PyTorch:
torch.manual_seed(0)
# Python standard library:
random.seed(0)
# MXNet NumPy interface (`np` is `mxnet.np` here, not plain NumPy):
np.random.seed(0)
# MXNet NumPy extension module:
npx.random.seed(0)
# Plain NumPy (imported as `_np`):
_np.random.seed(0)

In [4]:
# Hugging Face reference model, built from a default T5Config rather than from
# a pretrained checkpoint.
hft5 = HFT5(HFT5CFG())
# Switch the torch module to evaluation mode before comparing outputs
# (presumably to disable dropout -- standard torch.nn.Module.eval() behaviour).
hft5.eval()

# GluonNLP counterpart, built from the google_t5_small configuration.
glt5 = GLT5.from_cfg(GLT5CFG())
# Intentionally left disabled: convert_params() calls
# gluon_t5_model.initialize(ctx=ctx) itself before copying the weights.
# glt5.initialize()

In [5]:
# hft5.encoder.block[0].layer[0].SelfAttention.q.weight
# glt5.encoder.layers[0].self_attn_q.weight.data()

# d = {}
# for (k, v) in hft5.state_dict().items(): 
#     d[k] = v.shape
# d

# glt5.collect_params()

In [6]:
# (Hugging Face key template, GluonNLP key template) pairs.
# convert_params() unpacks the entries by position, so the order and the number
# of entries must stay in sync with that function.
PARAM_MAP = [
    # 0. shared token embedding (no format fields)
    ('shared.weight', 'input_embedding_layer.weight'),
    # 1. relative position bias; field: encoder / decoder
    ('{}.block.0.layer.0.SelfAttention.relative_attention_bias.weight', '{}.relative_position_encoder._rel_pos_embed.weight'),
    # 2. self-attention sub-layer; fields: encoder / decoder, block/layer #,
    #    layer_norm->self_attn_layer_norm / SelfAttention.o->self_attn_proj
    ('{}.block.{}.layer.0.{}.weight', '{}.layers.{}.{}.weight'),
    # 3. attention projections; fields: encoder / decoder, block/layer #,
    #    0.Self->self / 1.EncDec->cross, q/k/v
    ('{}.block.{}.layer.{}Attention.{}.weight', '{}.layers.{}.{}_attn_{}.weight'),
    # 4. decoder cross-attention sub-layer; fields: block/layer #,
    #    layer_norm->cross_attn_layer_norm / EncDecAttention.o->cross_attn_proj
    ('decoder.block.{}.layer.1.{}.weight', 'decoder.layers.{}.{}.weight'),
    # 5. feed-forward sub-layer; fields: encoder / decoder, block/layer #,
    #    (encoder: 1 / decoder: 2), DenseReluDense.wi/wi_0/wi_1/wo / layer_norm
    ('{}.block.{}.layer.{}.{}.weight', '{}.layers.{}.ffn.{}.weight'),
    # 6. final layer norm; field: encoder / decoder
    ('{}.final_layer_norm.weight', '{}.final_layer_norm.weight'),
]


def convert_params(hf_t5_model, gluon_t5_model, ctx):
    """Copy the weights of a Hugging Face T5 model into a GluonNLP T5 model.

    Parameters
    ----------
    hf_t5_model
        Source model. Only ``state_dict()`` is used; every entry that is looked
        up must support ``.cpu().numpy()``.
    gluon_t5_model
        Destination model. It is initialized on ``ctx`` and its parameters are
        overwritten in place. Its ``num_layers`` and ``activation`` attributes
        decide which parameter names are generated.
    ctx
        Context forwarded to ``gluon_t5_model.initialize``.

    Returns
    -------
    The same ``gluon_t5_model`` object that was passed in.

    Raises
    ------
    ValueError
        If ``gluon_t5_model.activation`` is neither ``'relu'`` nor
        ``'gated-gelu'``. This is checked before the model is initialized or
        modified, so a failed call leaves the destination model untouched.

    A lookup error propagates unchanged when a generated parameter name is
    missing from either model.
    """
    # Names of the feed-forward dense layers for each supported activation.
    ffn_denses_by_activation = {
        'relu': ['wi', 'wo'],
        'gated-gelu': ['wi_0', 'wi_1', 'wo'],
    }
    activation = gluon_t5_model.activation
    if activation not in ffn_denses_by_activation:
        raise ValueError(
            'Unrecognized feed forward activation: {!r}'.format(activation)
        )
    ffn_denses = ffn_denses_by_activation[activation]

    gluon_t5_model.initialize(ctx=ctx)
    hf_params = hf_t5_model.state_dict()
    gluon_params = gluon_t5_model.collect_params()
    # TODO(yongyi-wu): add sanity check, eg. param #, layer #, ffn activation, etc.
    num_layers = gluon_t5_model.num_layers

    def convert(template, hf_fields=(), gluon_fields=()):
        # Fill in both key templates of one PARAM_MAP entry and copy the tensor.
        hf_template, gluon_template = template
        hf_name = hf_template.format(*hf_fields)
        gluon_name = gluon_template.format(*gluon_fields)
        gluon_params[gluon_name].set_data(hf_params[hf_name].cpu().numpy())

    (embedding, rel_pos_bias, self_attn_misc, attn_qkv,
     cross_attn_misc, ffn, final_norm) = PARAM_MAP

    convert(embedding)

    # The feed-forward block is sub-layer 1 in the encoder and sub-layer 2 in
    # the decoder, because the decoder inserts cross-attention at index 1.
    for stack, ffn_index in [('encoder', 1), ('decoder', 2)]:
        convert(rel_pos_bias, (stack,), (stack,))
        convert(final_norm, (stack,), (stack,))

        for layer in range(num_layers):
            for hf_part, gluon_part in [('layer_norm', 'self_attn_layer_norm'),
                                        ('SelfAttention.o', 'self_attn_proj')]:
                convert(self_attn_misc, (stack, layer, hf_part), (stack, layer, gluon_part))

            for proj in ['q', 'k', 'v']:
                convert(attn_qkv, (stack, layer, '0.Self', proj), (stack, layer, 'self', proj))
                if stack == 'decoder':
                    convert(attn_qkv, (stack, layer, '1.EncDec', proj), (stack, layer, 'cross', proj))

            if stack == 'decoder':
                for hf_part, gluon_part in [('layer_norm', 'cross_attn_layer_norm'),
                                            ('EncDecAttention.o', 'cross_attn_proj')]:
                    convert(cross_attn_misc, (layer, hf_part), (layer, gluon_part))

            for dense in ffn_denses:
                convert(
                    ffn,
                    (stack, layer, ffn_index, 'DenseReluDense.{}'.format(dense)),
                    (stack, layer, dense)
                )
            convert(ffn, (stack, layer, ffn_index, 'layer_norm'), (stack, layer, 'layer_norm'))

    return gluon_t5_model

In [7]:
# Copy the Hugging Face weights into the Gluon model on the CPU context.
# convert_params() initializes and overwrites the model it is given and returns
# that same object, so `glt5` keeps referring to the same model.
glt5 = convert_params(hft5, glt5, mx.cpu())

In [8]:
# Feed identical token ids through both models and require the final hidden
# states to agree within tolerance.
for src, tgt in [
    ('What time is it?', 'It is 2:00 PM.'),
    ('Hello World?', 'World Hello!'),
    ('These dead shall not have died in vain.', 'Government of the people, by the people, for the people shall not perish from the earth')
]:
    src_data = tokenizer.encode(src, int)
    tgt_data = tokenizer.encode(tgt, int)

    # Batch of one sequence for each framework.
    hf_src_data = torch.LongTensor([src_data])
    hf_tgt_data = torch.LongTensor([tgt_data])
    gl_src_data = np.array([src_data], dtype=np.int64)
    gl_tgt_data = np.array([tgt_data], dtype=np.int64)

    hf_res = hft5(input_ids=hf_src_data, decoder_input_ids=hf_tgt_data)['last_hidden_state'].detach().numpy()
    # The Gluon model additionally takes the length of each source and target
    # sequence in the batch.
    gl_res = glt5(gl_src_data, np.array([len(gl_src_data[0])]), gl_tgt_data, np.array([len(gl_tgt_data[0])]))

    # The formatted string is the assertion message itself. Wrapping it in
    # print() would make the message None, because print() returns None.
    assert np.allclose(hf_res, gl_res, rtol=1e-05, atol=1e-05), \
        'Transformer: {}\nGluon-nlp: {}'.format(hf_res, gl_res)

print('Done!')



Done!
