repeat for pretraining
zheyuye committed Jun 29, 2020
1 parent 8ee381b commit b460bbe
Showing 2 changed files with 109 additions and 117 deletions.
151 changes: 73 additions & 78 deletions scripts/pretraining/run_electra.py
@@ -15,7 +15,7 @@

from sklearn import metrics
from pretraining_utils import ElectraMasker, get_pretrain_data_npz, get_pretrain_data_text
from gluonnlp.utils.misc import grouper, set_seed, naming_convention, logging_config
from gluonnlp.utils.misc import grouper, repeat, set_seed, naming_convention, logging_config
from gluonnlp.initializer import TruncNorm
from gluonnlp.models.electra import ElectraModel, ElectraForPretrain, get_pretrained_electra
from gluonnlp.utils.parameter import clip_grad_global_norm
@@ -201,6 +201,7 @@ def init_comm(backend, gpus):

return store, num_workers, rank, local_rank, is_master_node, ctx_l


def final_save(model, save_dir, tokenizer):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -382,13 +383,11 @@ def train(args):
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
model.collect_params().zero_grad()
for update_count, batch_data in enumerate(grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)):
tic = time.time()
batch_id = 0
train_dataloader = grouper(data_train, len(ctx_l))
sample_l = next(train_dataloader)

for sample_l in grouper(batch_data, len(ctx_l))
# start training
for batch_data in grouper(repeat(data_train), len(ctx_l) * args.num_accumulated):
tic = time.time()
for sample_l in grouper(batch_data, len(ctx_l)):
loss_l = []
mlm_loss_l = []
rtd_loss_l = []
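The core of this change is driving pretraining with an endless batch stream instead of an epoch loop: grouper(repeat(data_train), len(ctx_l) * args.num_accumulated) hands out one group of micro-batches per optimizer update, so training is bounded only by num_train_steps. Below is a minimal, self-contained sketch of that pattern; repeat and grouper here are simplified stand-ins for the helpers imported from gluonnlp.utils.misc, and the batch, device, and step counts are invented for illustration.

    from itertools import zip_longest

    def repeat(iterable):
        # Stand-in for gluonnlp.utils.misc.repeat: cycle through the data forever.
        while True:
            for item in iterable:
                yield item

    def grouper(iterable, n):
        # Stand-in for gluonnlp.utils.misc.grouper: fixed-size groups, padded with None.
        args = [iter(iterable)] * n
        return zip_longest(*args, fillvalue=None)

    data_train = ['batch%d' % i for i in range(5)]   # toy "dataloader" with 5 batches
    num_ctx, num_accumulated, num_train_steps = 2, 2, 4

    step_num = 0
    for batch_data in grouper(repeat(data_train), num_ctx * num_accumulated):
        # each optimizer update consumes num_ctx * num_accumulated micro-batches
        for sample_l in grouper(batch_data, num_ctx):
            for sample, ctx in zip(sample_l, range(num_ctx)):
                if sample is None:
                    continue                          # padding slot, as checked in the script
                # forward/backward for `sample` on device `ctx` would go here
        step_num += 1
        if step_num >= num_train_steps:
            break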
@@ -439,70 +438,65 @@ def train(args):
log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
for ele in loss_l]).asnumpy() * loss_denom

# update
if (batch_id + 1) % args.num_accumulated == 0:
trainer.allreduce_grads()
# Here, the accumulated gradients are
# \sum_{n=1}^N g_n / loss_denom
# Thus, in order to clip the average gradient
# \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm
# We need to change the ratio to be
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)
trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
model.collect_params().zero_grad()

# saving
if step_num % save_interval == 0 or step_num >= num_train_steps:
if is_master_node:
states_option(
step_num, trainer, args.output_dir, local_rank, 'Saving')
if local_rank == 0:
param_path = parameters_option(
step_num, model, args.output_dir, 'Saving')

# logging
if step_num % log_interval == 0 and local_rank == 0:
# Output the loss per step
log_mlm_loss /= log_interval
log_rtd_loss /= log_interval
log_total_loss /= log_interval
toc = time.time()
logging.info(
'[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
step_num, log_mlm_loss, log_rtd_loss, log_total_loss,
trainer.learning_rate, total_norm, toc - tic, log_sample_num / (toc - tic),
(num_train_steps - step_num) / (step_num / (toc - train_start_time)) / 3600))
tic = time.time()

if args.do_eval:
evaluation(writer, step_num, masked_input, output)
writer.add_scalars('loss',
{'total_loss': log_total_loss,
'mlm_loss': log_mlm_loss,
'rtd_loss': log_rtd_loss},
step_num)
log_mlm_loss = 0
log_rtd_loss = 0
log_total_loss = 0
log_sample_num = 0

num_samples_per_update = 0

if step_num >= num_train_steps:
logging.info('Finish training step: %d', step_num)
finish_flag = True
break

batch_id += 1

# update
trainer.allreduce_grads()
# Here, the accumulated gradients are
# \sum_{n=1}^N g_n / loss_denom
# Thus, in order to clip the average gradient
# \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm
# We need to change the ratio to be
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)
trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
model.collect_params().zero_grad()

# saving
if step_num % save_interval == 0 or step_num >= num_train_steps:
if is_master_node:
states_option(
step_num, trainer, args.output_dir, local_rank, 'Saving')
if local_rank == 0:
param_path = parameters_option(
step_num, model, args.output_dir, 'Saving')

# logging
if step_num % log_interval == 0 and local_rank == 0:
# Output the loss per step
log_mlm_loss /= log_interval
log_rtd_loss /= log_interval
log_total_loss /= log_interval
toc = time.time()
logging.info(
'[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
step_num, log_mlm_loss, log_rtd_loss, log_total_loss,
trainer.learning_rate, total_norm, toc - tic, log_sample_num / (toc - tic),
(num_train_steps - step_num) / (step_num / (toc - train_start_time)) / 3600))
tic = time.time()

if args.do_eval:
evaluation(writer, step_num, masked_input, output)
writer.add_scalars('loss',
{'total_loss': log_total_loss,
'mlm_loss': log_mlm_loss,
'rtd_loss': log_rtd_loss},
step_num)
log_mlm_loss = 0
log_rtd_loss = 0
log_total_loss = 0
log_sample_num = 0

num_samples_per_update = 0

if step_num >= num_train_steps:
logging.info('Finish training step: %d', step_num)
break

if is_master_node:
state_path = states_option(step_num, trainer, args.output_dir, local_rank, 'Saving')
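The clipping comments in the hunk above rescale the threshold because the stored gradients hold the accumulated sum divided by loss_denom rather than the per-sample average. A small NumPy check of that identity, with arbitrary shapes and values; N plays the role of num_samples_per_update:

    import numpy as np

    N, loss_denom, max_grad_norm = 8, 4.0, 1.0
    g = np.random.RandomState(0).randn(N, 10)        # per-sample gradients g_n

    g_avg = g.mean(axis=0)                           # (1/N) * sum(g_n): what we want to clip
    g_acc = g.sum(axis=0) / loss_denom               # what actually sits in the parameter grads

    # Clipping g_avg at max_grad_norm is the same event as clipping g_acc at the scaled
    # bound, since g_acc = (N / loss_denom) * g_avg.
    assert (np.linalg.norm(g_avg) <= max_grad_norm) == \
           (np.linalg.norm(g_acc) <= max_grad_norm * N / loss_denom)
    # Dividing the reported norm by the same factor recovers the average-gradient norm,
    # which is what `total_norm = total_norm / (num_samples_per_update / loss_denom)` does.
    assert np.isclose(np.linalg.norm(g_acc) / (N / loss_denom), np.linalg.norm(g_avg))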
@@ -565,13 +559,14 @@ def evaluation(writer, step_num, masked_input, eval_input):
rtd_recall = accuracy(rtd_labels, rtd_preds, rtd_labels * rtd_preds)
rtd_auc = auc(rtd_labels, rtd_probs, length_masks)
writer.add_scalars('results',
{'mlm_accuracy': mlm_accuracy.asnumpy().item(),
'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(),
'rtd_accuracy': rtd_accuracy.asnumpy().item(),
'rtd_precision': rtd_precision.asnumpy().item(),
'rtd_recall': rtd_recall.asnumpy().item(),
'rtd_auc':rtd_auc},
step_num)
{'mlm_accuracy': mlm_accuracy.asnumpy().item(),
'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(),
'rtd_accuracy': rtd_accuracy.asnumpy().item(),
'rtd_precision': rtd_precision.asnumpy().item(),
'rtd_recall': rtd_recall.asnumpy().item(),
'rtd_auc': rtd_auc},
step_num)


if __name__ == '__main__':
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
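The pretraining log line above reports ETA as the remaining steps divided by the average update rate since training started, converted to hours. Spelled out with made-up numbers:

    num_train_steps = 100000
    step_num = 2500
    elapsed = 3600.0                                  # seconds since train_start_time

    steps_per_second = step_num / elapsed             # average rate over the run so far
    eta_hours = (num_train_steps - step_num) / steps_per_second / 3600
    print('ETA={:.2f}h'.format(eta_hours))            # 39.00h for these numbers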
75 changes: 36 additions & 39 deletions scripts/question_answering/run_squad.py
@@ -152,30 +152,29 @@ def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):

# TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality.
self.ChunkFeature = collections.namedtuple('ChunkFeature',
['qas_id',
'data',
'valid_length',
'segment_ids',
'masks',
'is_impossible',
'gt_start',
'gt_end',
'context_offset',
'chunk_start',
'chunk_length'])
['qas_id',
'data',
'valid_length',
'segment_ids',
'masks',
'is_impossible',
'gt_start',
'gt_end',
'context_offset',
'chunk_start',
'chunk_length'])
self.BatchifyFunction = bf.NamedTuple(self.ChunkFeature,
{'qas_id': bf.List(),
'data': bf.Pad(val=self.pad_id),
'valid_length': bf.Stack(),
'segment_ids': bf.Pad(),
'masks': bf.Pad(val=1),
'is_impossible': bf.Stack(),
'gt_start': bf.Stack(),
'gt_end': bf.Stack(),
'context_offset': bf.Stack(),
'chunk_start': bf.Stack(),
'chunk_length': bf.Stack()})

{'qas_id': bf.List(),
'data': bf.Pad(val=self.pad_id),
'valid_length': bf.Stack(),
'segment_ids': bf.Pad(),
'masks': bf.Pad(val=1),
'is_impossible': bf.Stack(),
'gt_start': bf.Stack(),
'gt_end': bf.Stack(),
'context_offset': bf.Stack(),
'chunk_start': bf.Stack(),
'chunk_length': bf.Stack()})

def process_sample(self, feature: SquadFeature):
"""Process the data to the following format.
@@ -529,8 +528,6 @@ def train(args):
trainer = mx.gluon.Trainer(qa_net.collect_params(),
args.optimizer, optimizer_params,
update_on_kvstore=False)
step_num = 0
finish_flag = False
num_samples_per_update = 0
loss_denom = float(len(ctx_l) * args.num_accumulated)

@@ -545,11 +542,12 @@
# start training
global_tic = time.time()
tic = time.time()
for update_count, batch_data in enumerate(grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)):
loss_l = []
span_loss_l = []
answerable_loss_l = []
for step_num, batch_data in enumerate(
grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)):
for sample_l in grouper(batch_data, len(ctx_l)):
loss_l = []
span_loss_l = []
answerable_loss_l = []
for sample, ctx in zip(sample_l, ctx_l):
if sample is None:
continue
@@ -599,17 +597,16 @@ def train(args):
total_norm = total_norm / (num_samples_per_update / loss_denom)

trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()

# saving
if step_num % save_interval == 0 or step_num >= num_train_steps:
if (step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps:
version_prefix = 'squad' + args.version
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
version_prefix,
step_num)
(step_num + 1))
params_saved = os.path.join(args.output_dir, ckpt_name)
qa_net.save_parameters(params_saved)
ckpt_candidates = [
@@ -622,27 +619,27 @@
logging.info('Params saved in: {}'.format(params_saved))

# logging
if step_num % log_interval == 0:
if (step_num + 1) % log_interval == 0:
log_span_loss /= log_sample_num
log_answerable_loss /= log_sample_num
log_total_loss /= log_sample_num
toc = time.time()
logging.info(
'Batch: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
' ETA={:.2f}h'.format(update_count + 1, epoch_size, log_span_loss,
' ETA={:.2f}h'.format((step_num + 1), epoch_size, log_span_loss,
log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm,
toc - tic, log_sample_num / (toc - tic),
(num_train_steps - step_num) / (step_num / (toc - global_tic)) / 3600))
(num_train_steps - (step_num + 1)) / ((step_num + 1) / (toc - global_tic)) / 3600))
tic = time.time()
log_span_loss = 0
log_answerable_loss = 0
log_total_loss = 0
log_sample_num = 0
num_samples_per_update = 0

if step_num >= num_train_steps:
logging.info('Finish training step: %d', step_num)
if (step_num + 1) >= num_train_steps:
logging.info('Finish training step: %d', (step_num + 1))
break

return params_saved
@@ -852,7 +849,7 @@ def eval_validation(ckpt_name, best_eval):
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
= qa_net.inference(tokens, segment_ids, valid_length, p_mask,
args.start_top_n, args.end_top_n)
args.start_top_n, args.end_top_n)
for i, qas_id in enumerate(sample.qas_id):
result = RawResultExtended(qas_id=qas_id,
start_top_logits=start_top_logits[i].asnumpy(),
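In run_squad.py the update counter now comes straight from enumerate, which starts at zero, so the save, log, and termination checks are all shifted to step_num + 1. A toy sketch of that bookkeeping; the interval values are invented and itertools.count() stands in for the endless batch stream:

    import itertools

    num_train_steps, save_interval, log_interval = 6, 3, 2

    for step_num, batch_data in enumerate(itertools.count()):
        update = step_num + 1                         # enumerate is 0-based; updates count from 1
        if update % save_interval == 0 or update >= num_train_steps:
            print('save checkpoint after update', update)
        if update % log_interval == 0:
            print('log metrics after update', update)
        if update >= num_train_steps:
            print('Finish training step:', update)
            break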
