
Commit

fix
zheyuye committed Jul 13, 2020
1 parent 4074a26 commit 44a09a3
Showing 3 changed files with 164 additions and 132 deletions.
100 changes: 59 additions & 41 deletions scripts/conversion_toolkits/convert_fairseq_roberta.py
@@ -9,7 +9,7 @@
 import numpy as np
 import torch
 from gluonnlp.utils.misc import sha1sum, logging_config
-from gluonnlp.models.roberta import RobertaModel as gluon_RobertaModel
+from gluonnlp.models.roberta import RobertaModel, RobertaForMLM
 from gluonnlp.data.tokenizers import HuggingFaceByteBPETokenizer
 from gluonnlp.data.vocab import Vocab as gluon_Vocab
 from fairseq.models.roberta import RobertaModel as fairseq_RobertaModel
@@ -163,19 +163,22 @@ def convert_config(fairseq_cfg, vocab_size, cfg):

 def convert_params(fairseq_model,
                    gluon_cfg,
-                   gluon_model_cls,
                    ctx,
+                   is_mlm=True,
                    gluon_prefix='robert_'):
     print('converting params')
     fairseq_params = fairseq_model.state_dict()
     fairseq_prefix = 'model.decoder.'
-    gluon_model = gluon_model_cls.from_cfg(
-        gluon_cfg,
-        use_mlm=True,
-        use_pooler=False,
-        output_all_encodings=True,
-        prefix=gluon_prefix
-    )
+    if is_mlm:
+        gluon_model = RobertaForMLM(backbone_cfg=gluon_cfg, prefix=gluon_prefix)
+        gluon_model.backbone_model._output_all_encodings = True
+    else:
+        gluon_model = RobertaModel.from_cfg(
+            gluon_cfg,
+            use_pooler=True,
+            output_all_encodings=True,
+            prefix=gluon_prefix
+        )
     gluon_model.initialize(ctx=ctx)
     gluon_model.hybridize()
     gluon_params = gluon_model.collect_params()
@@ -223,11 +226,6 @@ def convert_params(fairseq_model,
         ('sentence_encoder.embed_tokens.weight', 'tokens_embed_weight'),
         ('sentence_encoder.emb_layer_norm.weight', 'embed_ln_gamma'),
         ('sentence_encoder.emb_layer_norm.bias', 'embed_ln_beta'),
-        ('lm_head.dense.weight', 'lm_dense1_weight'),
-        ('lm_head.dense.bias', 'lm_dense1_bias'),
-        ('lm_head.layer_norm.weight', 'lm_ln_gamma'),
-        ('lm_head.layer_norm.bias', 'lm_ln_beta'),
-        ('lm_head.bias', 'tokens_embed_bias')
     ]:
         fs_name = fairseq_prefix + k
         gl_name = gluon_prefix + v
@@ -241,14 +239,26 @@ def convert_params(fairseq_model,
     gluon_params[gl_pos_embed_name].set_data(
         fairseq_params[fs_pos_embed_name].cpu().numpy()[padding_idx + 1:,:])

-    # assert untie=False
-    assert np.array_equal(
-        fairseq_params[fairseq_prefix + 'sentence_encoder.embed_tokens.weight'].cpu().numpy(),
-        fairseq_params[fairseq_prefix + 'lm_head.weight'].cpu().numpy()
-    )
+    if is_mlm:
+        for k, v in [
+            ('lm_head.dense.weight', 'mlm_proj_weight'),
+            ('lm_head.dense.bias', 'mlm_proj_bias'),
+            ('lm_head.layer_norm.weight', 'mlm_ln_gamma'),
+            ('lm_head.layer_norm.bias', 'mlm_ln_beta'),
+            ('lm_head.bias', 'tokens_embed_bias')
+        ]:
+            fs_name = fairseq_prefix + k
+            gl_name = gluon_prefix + v
+            gluon_params[gl_name].set_data(
+                fairseq_params[fs_name].cpu().numpy())
+        # assert untie=False
+        assert np.array_equal(
+            fairseq_params[fairseq_prefix + 'sentence_encoder.embed_tokens.weight'].cpu().numpy(),
+            fairseq_params[fairseq_prefix + 'lm_head.weight'].cpu().numpy()
+        )
     return gluon_model


 def test_model(fairseq_model, gluon_model, gpu):
     print('testing model')
     ctx = mx.gpu(gpu) if gpu is not None else mx.cpu()
@@ -278,16 +288,16 @@ def test_model(fairseq_model, gluon_model, gpu):

     fairseq_model.model.eval()

-    gl_all_hiddens, gl_x = \
+    gluon_all_hiddens, gluon_pooled, gluon_mlm_scores = \
         gluon_model(gl_input_ids, gl_valid_length)

-    fs_x, fs_extra = \
+    fairseq_mlm_scores, fs_extra = \
         fairseq_model.model.cuda(gpu)(fs_input_ids, return_all_hiddens=True)
     fs_all_hiddens = fs_extra['inner_states']

     num_layers = fairseq_model.args.encoder_layers
     for i in range(num_layers + 1):
-        gl_hidden = gl_all_hiddens[i].asnumpy()
+        gl_hidden = gluon_all_hiddens[i].asnumpy()
         fs_hidden = fs_all_hiddens[i]
         fs_hidden = fs_hidden.transpose(0, 1)
         fs_hidden = fs_hidden.detach().cpu().numpy()
@@ -299,13 +309,13 @@ def test_model(fairseq_model, gluon_model, gpu):
             1E-3
         )

-    gl_x = gl_x.asnumpy()
-    fs_x = fs_x.transpose(0, 1)
-    fs_x = fs_x.detach().cpu().numpy()
+    gluon_mlm_scores = gluon_mlm_scores.asnumpy()
+    fairseq_mlm_scores = fairseq_mlm_scores.transpose(0, 1)
+    fairseq_mlm_scores = fairseq_mlm_scores.detach().cpu().numpy()
     for j in range(batch_size):
         assert_allclose(
-            gl_x[j, :valid_length[j], :],
-            fs_x[j, :valid_length[j], :],
+            gluon_mlm_scores[j, :valid_length[j], :],
+            fairseq_mlm_scores[j, :valid_length[j], :],
             1E-3,
             1E-3
         )
@@ -337,24 +347,32 @@ def convert_fairseq_model(args):
     vocab_size = convert_vocab(args, fairseq_roberta)

     gluon_cfg = convert_config(fairseq_roberta.args, vocab_size,
-                               gluon_RobertaModel.get_cfg().clone())
+                               RobertaModel.get_cfg().clone())
     with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
         of.write(gluon_cfg.dump())

     ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
-    gluon_roberta = convert_params(fairseq_roberta,
-                                   gluon_cfg,
-                                   gluon_RobertaModel,
-                                   ctx,
-                                   gluon_prefix='roberta_')
-
-    if args.test:
-        test_model(fairseq_roberta, gluon_roberta, args.gpu)
+    for is_mlm in [False, True]:
+        gluon_roberta = convert_params(fairseq_roberta,
+                                       gluon_cfg,
+                                       ctx,
+                                       is_mlm=is_mlm,
+                                       gluon_prefix='roberta_')
+
+        if is_mlm:
+            if args.test:
+                test_model(fairseq_roberta, gluon_roberta, args.gpu)
+
+            gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True)
+            logging.info('Convert the RoBERTa MLM model in {} to {}'.
+                         format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                                os.path.join(args.save_dir, 'model_mlm.params')))
+        else:
+            gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
+            logging.info('Convert the RoBERTa backbone model in {} to {}'.
+                         format(os.path.join(args.fairseq_model_path, 'model.pt'), \
+                                os.path.join(args.save_dir, 'model.params')))

-    gluon_roberta.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True)
-    logging.info('Convert the RoBERTa model in {} to {}'.
-                 format(os.path.join(args.fairseq_model_path, 'model.pt'), \
-                        os.path.join(args.save_dir, 'model.params')))
     logging.info('Conversion finished!')
     logging.info('Statistics:')
     rename(args.save_dir)
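
As a usage note (not part of this commit): the loop above now writes two artifacts per fairseq checkpoint, model.params for the plain backbone and model_mlm.params for the backbone plus MLM head. Below is a minimal sketch of loading them back that reuses only calls appearing in the diff; the save_dir value, the CPU context, and the yacs-style merge_from_file call for model.yml are assumptions, not code from the script.

# Minimal sketch, not from this commit: reload the two converted artifacts.
# Assumes the cfg written to model.yml is a yacs-style CfgNode (as .dump() above suggests)
# and that save_dir points at the directory passed as args.save_dir.
import os

import mxnet as mx
from gluonnlp.models.roberta import RobertaModel, RobertaForMLM

save_dir = 'converted_roberta'  # hypothetical path
ctx = mx.cpu()

cfg = RobertaModel.get_cfg().clone()
cfg.merge_from_file(os.path.join(save_dir, 'model.yml'))

# Backbone only, written by the is_mlm=False branch above.
backbone = RobertaModel.from_cfg(cfg,
                                 use_pooler=True,
                                 output_all_encodings=True,
                                 prefix='roberta_')
backbone.load_parameters(os.path.join(save_dir, 'model.params'), ctx=ctx)

# Backbone plus MLM head, written by the is_mlm=True branch above.
mlm_model = RobertaForMLM(backbone_cfg=cfg, prefix='roberta_')
mlm_model.load_parameters(os.path.join(save_dir, 'model_mlm.params'), ctx=ctx)
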
4 changes: 1 addition & 3 deletions scripts/question_answering/run_squad.py
@@ -391,7 +391,6 @@ def get_network(model_name,
     backbone.load_parameters(
         backbone_params_path,
         ignore_extra=True,
-        allow_missing=True,
         ctx=ctx_l)
     num_params, num_fixed_params = count_parameters(backbone.collect_params())
     logging.info(
@@ -478,9 +477,8 @@ def train(args):
         batch_size=args.batch_size,
         num_workers=0,
         sampler=sampler)
-    # Froze parameters
     if 'electra' in args.model_name:
-        # does not work for albert model since parameters in all layers are shared
+        # Freeze parameters, does not work for albert model since parameters in all layers are shared
         if args.untunable_depth > 0:
             untune_params(qa_net, args.untunable_depth)
         if args.layerwise_decay > 0:
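
For context on the block above (not part of this commit): "freezing" the bottom layers via untune_params means preventing gradient updates for those parameters. The sketch below shows one common way to do that in Gluon by setting grad_req to 'null'; it illustrates the general technique only, and freeze_matching_params is a hypothetical helper, not the actual untune_params implementation in this repository.

# Illustrative only: freeze parameters by disabling their gradients.
# The real untune_params(qa_net, args.untunable_depth) selects layers by depth;
# this hypothetical helper selects them by a name keyword instead.
from mxnet.gluon import Block

def freeze_matching_params(net: Block, keyword: str) -> int:
    """Set grad_req='null' on every parameter whose name contains `keyword`."""
    n_frozen = 0
    for name, param in net.collect_params().items():
        if keyword in name:
            param.grad_req = 'null'  # the optimizer will no longer update this parameter
            n_frozen += 1
    return n_frozen

As the comment in the diff notes, freezing by depth only makes sense when each layer owns its parameters; ALBERT-style models share one set of weights across layers, so freezing "the bottom N layers" would effectively freeze the whole encoder.
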
