Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 15, 2020
1 parent dc55fc9 commit f5c94a6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
31 changes: 18 additions & 13 deletions scripts/conversion_toolkits/convert_fairseq_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def test_model(fairseq_model, gluon_model, gpu):
ctx = mx.gpu(gpu) if gpu is not None else mx.cpu()
batch_size = 3
seq_length = 32
num_mask = 5
vocab_size = len(fairseq_model.task.dictionary)
padding_id = fairseq_model.model.decoder.sentence_encoder.padding_idx
input_ids = np.random.randint( # skip padding_id
Expand All @@ -281,34 +280,30 @@ def test_model(fairseq_model, gluon_model, gpu):
seq_length,
(batch_size,)
)
mlm_positions = np.random.randint(
0,
seq_length // 2,
(batch_size, num_mask)
)

for i in range(batch_size): # add padding, for fairseq padding mask
input_ids[i,valid_length[i]:] = padding_id

gl_input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=ctx)
gl_valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=ctx)
gl_masked_positions = mx.np.array(mlm_positions, dtype=np.int32, ctx=ctx)
# project the all tokens that is taking whole positions
gl_masked_positions = mx.npx.arange_like(gl_input_ids, axis=1)
gl_masked_positions = gl_masked_positions + mx.np.zeros_like(gl_input_ids)

fs_input_ids = torch.from_numpy(input_ids).cuda(gpu)
fs_masked_positions = torch.from_numpy(mlm_positions).cuda(gpu)
if gpu is not None:
fs_input_ids = fs_input_ids.cuda(gpu)

fairseq_model.model.eval()

gl_all_hiddens, gl_pooled, gl_mlm_scores = \
gluon_model(gl_input_ids, gl_valid_length, gl_masked_positions)

fs_mlm_scores, fs_extra = \
fairseq_model.model.cuda(gpu)(
fs_input_ids,
return_all_hiddens=True,
masked_tokens=fs_masked_positions)
return_all_hiddens=True)
fs_all_hiddens = fs_extra['inner_states']

# checking all_encodings_outputs
num_layers = fairseq_model.args.encoder_layers
for i in range(num_layers + 1):
gl_hidden = gl_all_hiddens[i].asnumpy()
Expand All @@ -322,7 +317,17 @@ def test_model(fairseq_model, gluon_model, gpu):
1E-3,
1E-3
)
#TODO(zheyuye), checking the masking scores
# checking masked_language_scores
gl_mlm_scores = gl_mlm_scores.asnumpy()
fs_mlm_scores = fs_mlm_scores.transpose(0, 1)
fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy()
for j in range(batch_size):
assert_allclose(
gl_mlm_scores[j, :valid_length[j], :],
fs_mlm_scores[j, :valid_length[j], :],
1E-3,
1E-3
)

def rename(save_dir):
"""Rename converted files with hash"""
Expand Down
15 changes: 9 additions & 6 deletions scripts/question_answering/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def parse_args():
help='list of gpus to run, e.g. 0 or 0,2,5. -1 means using cpu.')
# Training hyperparameters
parser.add_argument('--seed', type=int, default=100, help='Random seed')
parser.add_argument('--log_interval', type=int, default=100, help='The logging interval.')
parser.add_argument('--log_interval', type=int, default=50,
help='The logging interval for training')
parser.add_argument('--eval_log_interval', type=int, default=10,
help='The logging interval for evaluation')
parser.add_argument('--save_interval', type=int, default=None,
help='the number of steps to save model parameters.'
'default is every epoch')
Expand Down Expand Up @@ -135,7 +138,7 @@ def parse_args():
parser.add_argument('--all_evaluate', action='store_true',
help='Whether to evaluate all intermediate checkpoints '
'instead of only last one')
parser.add_argument('--max_saved_ckpt', type=int, default=10,
parser.add_argument('--max_saved_ckpt', type=int, default=5,
help='The maximum number of saved checkpoints')
args = parser.parse_args()
return args
Expand Down Expand Up @@ -323,8 +326,8 @@ def get_squad_features(args, tokenizer, segment):
The list of processed data features
"""
data_cache_path = os.path.join(CACHE_PATH,
'dev_{}_squad_{}.ndjson'.format(args.model_name,
args.version))
'{}_{}_squad_{}.ndjson'.format(
segment, args.model_name, args.version))
is_training = (segment == 'train')
if os.path.exists(data_cache_path) and not args.overwrite_cache:
data_features = []
Expand Down Expand Up @@ -637,7 +640,7 @@ def train(args):
ckpt_candidates = [
f for f in os.listdir(
args.output_dir) if f.endswith('.params')]
# keep last 10 checkpoints
# keep last `max_saved_ckpt` checkpoints
if len(ckpt_candidates) > args.max_saved_ckpt:
ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
Expand Down Expand Up @@ -841,7 +844,7 @@ def eval_validation(ckpt_name, best_eval):
num_workers=0,
shuffle=False)

log_interval = args.log_interval
log_interval = args.eval_log_interval
all_results = []
epoch_tic = time.time()
tic = time.time()
Expand Down

0 comments on commit f5c94a6

Please sign in to comment.