horovod for squad
zheyuye committed Jul 9, 2020
1 parent 1d374a2 commit 838be2a
Showing 3 changed files with 89 additions and 61 deletions.
47 changes: 6 additions & 41 deletions scripts/pretraining/run_electra.py
@@ -15,7 +15,7 @@

from sklearn import metrics
from pretraining_utils import ElectraMasker, get_pretrain_data_npz, get_pretrain_data_text
from gluonnlp.utils.misc import grouper, repeat, set_seed, naming_convention, logging_config
from gluonnlp.utils.misc import grouper, repeat, set_seed, naming_convention, logging_config, init_comm
from gluonnlp.initializer import TruncNorm
from gluonnlp.models.electra import ElectraModel, ElectraForPretrain, get_pretrained_electra
from gluonnlp.utils.parameter import clip_grad_global_norm
@@ -170,39 +170,6 @@ def get_pretraining_model(model_name, ctx_l,
'corrupted_tokens'])


def init_comm(backend, gpus):
"""Init communication backend"""
# backend specific implementation
if backend == 'horovod':
try:
import horovod.mxnet as hvd # pylint: disable=import-outside-toplevel
except ImportError:
logging.info('horovod must be installed.')
sys.exit(1)
hvd.init()
store = None
num_workers = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()
is_master_node = rank == local_rank
ctx_l = [mx.gpu(local_rank)]
logging.info('GPU communication supported by horovod')
else:
store = mx.kv.create(backend)
num_workers = store.num_workers
rank = store.rank
local_rank = 0
is_master_node = rank == local_rank
if gpus == '-1' or gpus == '':
ctx_l = [mx.cpu()]
logging.info('Running on CPU')
else:
ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')]
logging.info('GPU communication supported by KVStore')

return store, num_workers, rank, local_rank, is_master_node, ctx_l


def final_save(model, save_dir, tokenizer):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -261,6 +228,9 @@ def states_option(step_num, trainer, ckpt_dir, local_rank=0, option='Saving'):
def train(args):
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
args.comm_backend, args.gpus)
logging.info('Training info: num_buckets: {}, '
'num_workers: {}, rank: {}'.format(
args.num_buckets, num_workers, rank))
cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
args.max_seq_length,
args.hidden_dropout_prob,
@@ -269,9 +239,6 @@ def train(args):
args.generator_layers_scale)
data_masker = ElectraMasker(
tokenizer, args.max_seq_length, args.mask_prob)
logging.info('Training info: num_buckets: {}, '
'num_workers: {}, rank: {}'.format(
args.num_buckets, num_workers, rank))
if args.from_raw_text:
if args.cached_file_path and not os.path.exists(args.cached_file_path):
os.mkdir(args.cached_file_path)
@@ -342,8 +309,6 @@ def train(args):
'epsilon': 1e-6,
'correct_bias': False,
})
# TODO(zheyuye), absence of layer-wise decay, although the decay power
# is 1.0 in the electra model
if args.comm_backend == 'horovod':
trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
else:
@@ -448,9 +413,9 @@ def train(args):
# We need to change the ratio to be
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)
trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
trainer.update(num_samples_per_update / loss_denom)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
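The clipping call in the hunk above rescales its threshold because the losses were pre-divided by loss_denom while gradients are summed over num_samples_per_update samples. A small numerical sketch of that rescaling with made-up values (none of these numbers come from the repository):

# Toy arithmetic for the clipping-threshold rescaling above (hypothetical values).
# The stored gradient is sum_{n=1}^{N} g_n / loss_denom, so clipping the *mean*
# per-sample gradient at max_grad_norm means clipping the stored gradient at
# max_grad_norm * N / loss_denom, exactly as the comment in the diff states.
max_grad_norm = 1.0
loss_denom = 8.0                # len(ctx_l) * args.num_accumulated
num_samples_per_update = 256    # N: samples seen since the last optimizer step

scaled_threshold = max_grad_norm * num_samples_per_update / loss_denom   # 32.0
stored_norm = 40.0              # pretend return value of clip_grad_global_norm
reported_norm = stored_norm / (num_samples_per_update / loss_denom)      # 1.25
print(scaled_threshold, reported_norm)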
70 changes: 50 additions & 20 deletions scripts/question_answering/run_squad.py
@@ -21,9 +21,16 @@
from eval_utils import squad_eval
from squad_utils import SquadFeature, get_squad_examples, convert_squad_example_to_feature
from gluonnlp.models import get_backbone
from gluonnlp.utils.misc import grouper, repeat, set_seed, parse_ctx, logging_config, count_parameters
from gluonnlp.utils.misc import repeat, grouper, set_seed, init_comm, \
parse_ctx, logging_config, count_parameters
from gluonnlp.initializer import TruncNorm
from gluonnlp.utils.parameter import clip_grad_global_norm, grad_global_norm
from gluonnlp.data.sampler import SplitSampler
from gluonnlp.utils.parameter import grad_global_norm, clip_grad_global_norm

try:
import horovod.mxnet as hvd
except ImportError:
pass

mx.npx.set_np()

@@ -48,6 +55,10 @@ def parse_args():
parser.add_argument('--output_dir', type=str, default='squad_out',
help='The output directory where the model params will be written.'
' default is squad_out')
# Communication
parser.add_argument('--comm_backend', type=str, default='device',
choices=['horovod', 'dist_sync_device', 'device'],
help='Communication backend.')
parser.add_argument('--gpus', type=str, default='0',
help='list of gpus to run, e.g. 0 or 0,2,5. -1 means using cpu.')
# Training hyperparameters
@@ -384,8 +395,11 @@ def untune_params(model, untunable_depth, not_included=[]):
continue
value.grad_req = 'null'


def train(args):
ctx_l = parse_ctx(args.gpus)
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
args.comm_backend, args.gpus)

cfg, tokenizer, qa_net, use_segmentation = \
get_network(args.model_name, ctx_l,
args.classifier_dropout,
@@ -439,12 +453,15 @@ def train(args):
sum([ele.is_impossible for ele in train_features])))
logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'
.format(len(train_dataset), num_impossible))
sampler = SplitSampler(len(train_dataset), num_parts=num_workers,
part_index=rank, even_size=True)
train_dataloader = mx.gluon.data.DataLoader(
train_dataset,
batchify_fn=dataset_processor.BatchifyFunction,
batch_size=args.batch_size,
num_workers=0,
shuffle=True)
num_workers=4,
shuffle=True,
sampler=sampler)
# Freeze parameters
if 'electra' in args.model_name:
# does not work for albert model since parameters in all layers are shared
@@ -453,17 +470,24 @@ def train(args):
if args.layerwise_decay > 0:
qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)

logging.info('Creating distributed trainer...')
# Collect differentiable parameters
param_dict = qa_net.collect_params()
# Do not apply weight decay to all the LayerNorm and bias
for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
v.wd_mult = 0.0
# Collect differentiable parameters
params = [p for p in qa_net.collect_params().values() if p.grad_req != 'null']
params = [p for p in param_dict.values() if p.grad_req != 'null']
# Set grad_req if gradient accumulation is required
if args.num_accumulated > 1:
logging.info('Using gradient accumulation. Effective global batch size = {}'
.format(args.num_accumulated * args.batch_size * len(ctx_l)))
.format(args.num_accumulated * args.batch_size * len(ctx_l) * num_workers))
for p in params:
p.grad_req = 'add'
# backend specific implementation
if args.comm_backend == 'horovod':
# Horovod: fetch and broadcast parameters
hvd.broadcast_parameters(param_dict, root_rank=0)

epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
if args.num_train_steps is not None:
num_train_steps = args.num_train_steps
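The 'effective global batch size' log message in the hunk above multiplies four factors; a tiny self-contained illustration with hypothetical values (none taken from the scripts):

# Hypothetical values, only to unpack the log message above: one optimizer step
# consumes batch_size samples per context, on num_ctx contexts per worker, on
# num_workers Horovod workers, accumulated over num_accumulated backward passes.
num_accumulated = 4     # args.num_accumulated
batch_size = 8          # args.batch_size (per context)
num_ctx = 2             # len(ctx_l)
num_workers = 4         # hvd.size() when using the horovod backend
effective_global_batch = num_accumulated * batch_size * num_ctx * num_workers
print(effective_global_batch)   # 256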
@@ -504,9 +528,12 @@ def train(args):
'beta2': adam_betas[1],
'epsilon': args.adam_epsilon,
})
trainer = mx.gluon.Trainer(qa_net.collect_params(),
args.optimizer, optimizer_params,
update_on_kvstore=False)
if args.comm_backend == 'horovod':
trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
else:
trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params,
update_on_kvstore=False)

num_samples_per_update = 0
loss_denom = float(len(ctx_l) * args.num_accumulated)

@@ -516,7 +543,7 @@ def train(args):
log_sample_num = 0
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()
param_dict.zero_grad()

# start training
global_tic = time.time()
@@ -575,17 +602,18 @@ def train(args):
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)
else:
total_norm = grad_global_norm(parameters)
total_norm = grad_global_norm(params)

total_norm = total_norm / (num_samples_per_update / loss_denom)
trainer.update(num_samples_per_update / loss_denom)
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()
param_dict.zero_grad()

# saving
if (step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps:
if local_rank == 0 and is_master_node and (
step_num + 1) % save_interval == 0 or (step_num + 1) >= num_train_steps:
version_prefix = 'squad' + args.version
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
version_prefix,
Expand All @@ -602,7 +630,7 @@ def train(args):
logging.info('Params saved in: {}'.format(params_saved))

# logging
if (step_num + 1) % log_interval == 0:
if local_rank == 0 and (step_num + 1) % log_interval == 0:
log_span_loss /= log_sample_num
log_answerable_loss /= log_sample_num
log_total_loss /= log_sample_num
@@ -611,8 +639,8 @@ def train(args):
'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
' ETA={:.2f}h'.format((step_num + 1), num_train_steps, log_span_loss,
log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm,
toc - tic, log_sample_num / (toc - tic),
log_answerable_loss, log_total_loss, trainer.learning_rate,
total_norm, toc - tic, log_sample_num / (toc - tic),
(num_train_steps - (step_num + 1)) / ((step_num + 1) / (toc - global_tic)) / 3600))
tic = time.time()
log_span_loss = 0
@@ -622,7 +650,9 @@ def train(args):
num_samples_per_update = 0

if (step_num + 1) >= num_train_steps:
logging.info('Finish training step: %d', (step_num + 1))
logging.info(
'Finish training step: {} within {:.2f} hours'.format(
step_num + 1, (toc - global_tic) / 3600))
break

return params_saved
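Taken together, the run_squad.py changes follow a standard Horovod data-parallel recipe: give each worker its own shard via SplitSampler, broadcast the initial parameters from rank 0, and let hvd.DistributedTrainer average gradients across workers. Below is a condensed, hypothetical sketch of that recipe with a toy dataset and model (launched with something like horovodrun -np 4 python train_sketch.py); it is not an excerpt from the repository.

import mxnet as mx
import horovod.mxnet as hvd
from gluonnlp.data.sampler import SplitSampler

hvd.init()
rank, num_workers = hvd.rank(), hvd.size()
ctx = mx.gpu(hvd.local_rank()) if mx.context.num_gpus() > 0 else mx.cpu()

# Toy stand-ins for the SQuAD dataset and qa_net.
train_dataset = mx.gluon.data.SimpleDataset(list(range(1000)))
net = mx.gluon.nn.Dense(2)
net.initialize(ctx=ctx)
param_dict = net.collect_params()

# Each rank only iterates over its own part of the data.
sampler = SplitSampler(len(train_dataset), num_parts=num_workers,
                       part_index=rank, even_size=True)
train_dataloader = mx.gluon.data.DataLoader(train_dataset, batch_size=32,
                                            sampler=sampler, num_workers=0)

# Start all workers from identical weights, then let Horovod all-reduce gradients.
hvd.broadcast_parameters(param_dict, root_rank=0)
trainer = hvd.DistributedTrainer(param_dict, 'adam', {'learning_rate': 2e-5})

When --comm_backend is left at 'device', the script instead builds a plain mx.gluon.Trainer, so single-machine runs keep working unchanged.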
33 changes: 33 additions & 0 deletions src/gluonnlp/utils/misc.py
@@ -545,3 +545,36 @@ def check_version(min_version: str,
warnings.warn(msg)
else:
raise AssertionError(msg)

def init_comm(backend, gpus):
"""Init communication backend"""
# backend specific implementation
import mxnet as mx
if backend == 'horovod':
try:
import horovod.mxnet as hvd # pylint: disable=import-outside-toplevel
except ImportError:
logging.info('horovod must be installed.')
sys.exit(1)
hvd.init()
store = None
num_workers = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()
is_master_node = rank == local_rank
ctx_l = [mx.gpu(local_rank)]
logging.info('GPU communication supported by horovod')
else:
store = mx.kv.create(backend)
num_workers = store.num_workers
rank = store.rank
local_rank = 0
is_master_node = rank == local_rank
if gpus == '-1' or gpus == '':
ctx_l = [mx.cpu()]
logging.info('Running on CPU')
else:
ctx_l = [mx.gpu(int(x)) for x in gpus.split(',')]
logging.info('GPU communication supported by KVStore')

return store, num_workers, rank, local_rank, is_master_node, ctx_l
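For completeness, a minimal usage sketch of the relocated helper; the call below is hypothetical (toy model, CPU fallback via an empty gpus string) and only mirrors how run_electra.py and run_squad.py now consume it.

import mxnet as mx
from gluonnlp.utils.misc import init_comm

# 'device' (or 'dist_sync_device') goes through a KVStore; 'horovod' requires
# horovod to be installed and the script to be launched via horovodrun.
store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
    backend='device', gpus='')   # empty gpus string -> run on CPU

net = mx.gluon.nn.Dense(2)       # toy stand-in for the real model
net.initialize(ctx=ctx_l)

if store is None:
    # Horovod path: hvd.init() has already been called inside init_comm.
    import horovod.mxnet as hvd
    trainer = hvd.DistributedTrainer(net.collect_params(), 'adam',
                                     {'learning_rate': 1e-3})
else:
    # KVStore path, matching what the scripts do for the non-horovod backends.
    trainer = mx.gluon.Trainer(net.collect_params(), 'adam',
                               {'learning_rate': 1e-3},
                               update_on_kvstore=False)

Moving this helper out of run_electra.py into gluonnlp.utils.misc is what lets run_squad.py import the same backend selection instead of duplicating it.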
