diff --git a/examples/run_glue.py b/examples/run_glue.py
index 1558a812c3e3..b2baa6afadbb 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -20,6 +20,7 @@
 import argparse
 import glob
 import logging
+import math
 import os
 import random
 
@@ -29,6 +30,15 @@
                               TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
 
+_TORCH_XLA_INSTALLED = True
+try:
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    import torch_xla.distributed.parallel_loader as pl
+    import torch_xla.distributed.xla_multiprocessing as xmp
+except ImportError:
+    _TORCH_XLA_INSTALLED = False
+
 try:
     from torch.utils.tensorboard import SummaryWriter
 except:
@@ -78,27 +88,44 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
+def get_sampler(dataset, args):
+    if args.local_rank == -1 and not args.use_tpu:
+        return RandomSampler(dataset)
+    num_replicas, rank = None, None
+    if args.use_tpu:
+        num_replicas = xm.xrt_world_size()
+        rank = xm.get_ordinal()
+    return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank)
+
+
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
 
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
-    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+
+    train_sampler = get_sampler(train_dataset, args)
+    if args.use_tpu:
+        dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+        train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
+        num_batches = len(dataloader)
+    else:
+        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
+        num_batches = len(train_dataloader)
 
     if args.max_steps > 0:
         t_total = args.max_steps
-        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+        args.num_train_epochs = args.max_steps // (num_batches // args.gradient_accumulation_steps) + 1
     else:
-        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+        t_total = num_batches // args.gradient_accumulation_steps * args.num_train_epochs
 
     # Prepare optimizer and schedule (linear warmup and decay)
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
-        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-        ]
+        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
+    ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
     if args.fp16:
@@ -120,21 +147,22 @@ def train(args, train_dataset, model, tokenizer):
 
     # Train!
     logger.info("***** Running training *****")
-    logger.info("  Num examples = %d", len(train_dataset))
+    logger.info("  Num examples = %d", num_batches * args.train_batch_size)
     logger.info("  Num Epochs = %d", args.num_train_epochs)
     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
-                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+                (args.train_batch_size * args.gradient_accumulation_steps
+                 * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
-    tr_loss, logging_loss = 0.0, 0.0
+    loss = None
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] or args.use_tpu)
     set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
     for _ in train_iterator:
-        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0] or args.use_tpu)
         for step, batch in enumerate(epoch_iterator):
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
@@ -157,14 +185,17 @@ def train(args, train_dataset, model, tokenizer):
             else:
                 loss.backward()
 
-            tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
+            if (step + 1) % args.gradient_accumulation_steps == 0:
                 if args.fp16:
                     torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                 else:
                     torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-                optimizer.step()
+                if args.use_tpu:
+                    xm.optimizer_step(optimizer, barrier=True)
+                else:
+                    optimizer.step()
+
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
                 global_step += 1
 
@@ -176,8 +207,7 @@ def train(args, train_dataset, model, tokenizer):
                         for key, value in results.items():
                             tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                     tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
-                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
-                    logging_loss = tr_loss
+                    tb_writer.add_scalar('loss', loss.item(), global_step)
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                     # Save model checkpoint
@@ -185,18 +215,15 @@ def train(args, train_dataset, model, tokenizer):
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-                    model_to_save.save_pretrained(output_dir)
+                    model_to_save.save_pretrained(output_dir, xla_device=args.use_tpu)
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
-            if args.tpu:
-                args.xla_model.optimizer_step(optimizer, barrier=True)
-                model.zero_grad()
-                global_step += 1
-
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
+        if args.metrics_debug:
+            xm.master_print(met.metrics_report())
         if args.max_steps > 0 and global_step > args.max_steps:
             train_iterator.close()
             break
@@ -204,7 +231,7 @@ def train(args, train_dataset, model, tokenizer):
     if args.local_rank in [-1, 0]:
         tb_writer.close()
 
-    return global_step, tr_loss / global_step
+    return global_step, loss.item()
 
 
 def evaluate(args, model, tokenizer, prefix=""):
@@ -220,19 +247,23 @@
             os.makedirs(eval_output_dir)
 
         args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
-        # Note that DistributedSampler samples randomly
+        # Note that DistributedSampler samples randomly.
+        # Also note that we don't shard the eval set for TPU multiprocessing, as the loss is not reduced across client processes.
         eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
         eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+        num_batches = len(eval_dataloader)
+        if args.use_tpu:
+            eval_dataloader = pl.ParallelLoader(eval_dataloader, [args.device]).per_device_loader(args.device)
 
         # Eval!
         logger.info("***** Running evaluation {} *****".format(prefix))
-        logger.info("  Num examples = %d", len(eval_dataset))
+        logger.info("  Num examples = %d", num_batches * args.eval_batch_size)
        logger.info("  Batch size = %d", args.eval_batch_size)
         eval_loss = 0.0
         nb_eval_steps = 0
         preds = None
         out_label_ids = None
-        for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=args.use_tpu):
             model.eval()
             batch = tuple(t.to(args.device) for t in batch)
 
@@ -323,96 +354,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     return dataset
 
 
-def main():
-    parser = argparse.ArgumentParser()
-
-    ## Required parameters
-    parser.add_argument("--data_dir", default=None, type=str, required=True,
-                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--task_name", default=None, type=str, required=True,
-                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
-    parser.add_argument("--output_dir", default=None, type=str, required=True,
-                        help="The output directory where the model predictions and checkpoints will be written.")
-
-    ## Other parameters
-    parser.add_argument("--config_name", default="", type=str,
-                        help="Pretrained config name or path if not the same as model_name")
-    parser.add_argument("--tokenizer_name", default="", type=str,
-                        help="Pretrained tokenizer name or path if not the same as model_name")
-    parser.add_argument("--cache_dir", default="", type=str,
-                        help="Where do you want to store the pre-trained models downloaded from s3")
-    parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after tokenization. Sequences longer "
-                             "than this will be truncated, sequences shorter will be padded.")
-    parser.add_argument("--do_train", action='store_true',
-                        help="Whether to run training.")
-    parser.add_argument("--do_eval", action='store_true',
-                        help="Whether to run eval on the dev set.")
-    parser.add_argument("--evaluate_during_training", action='store_true',
-                        help="Rul evaluation during training at each logging step.")
-    parser.add_argument("--do_lower_case", action='store_true',
-                        help="Set this flag if you are using an uncased model.")
-
-    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
-                        help="Batch size per GPU/CPU for training.")
-    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
-                        help="Batch size per GPU/CPU for evaluation.")
-    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
-                        help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument("--learning_rate", default=5e-5, type=float,
-                        help="The initial learning rate for Adam.")
-    parser.add_argument("--weight_decay", default=0.0, type=float,
-                        help="Weight deay if we apply some.")
-    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
-                        help="Epsilon for Adam optimizer.")
-    parser.add_argument("--max_grad_norm", default=1.0, type=float,
-                        help="Max gradient norm.")
-    parser.add_argument("--num_train_epochs", default=3.0, type=float,
-                        help="Total number of training epochs to perform.")
-    parser.add_argument("--max_steps", default=-1, type=int,
-                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
-    parser.add_argument("--warmup_steps", default=0, type=int,
-                        help="Linear warmup over warmup_steps.")
-
-    parser.add_argument('--logging_steps', type=int, default=50,
-                        help="Log every X updates steps.")
-    parser.add_argument('--save_steps', type=int, default=50,
-                        help="Save checkpoint every X updates steps.")
-    parser.add_argument("--eval_all_checkpoints", action='store_true',
-                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--overwrite_output_dir', action='store_true',
-                        help="Overwrite the content of the output directory")
-    parser.add_argument('--overwrite_cache', action='store_true',
-                        help="Overwrite the cached training and evaluation sets")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-
-    parser.add_argument('--tpu', action='store_true',
-                        help="Whether to run on the TPU defined in the environment variables")
-    parser.add_argument('--tpu_ip_address', type=str, default='',
-                        help="TPU IP address if none are set in the environment variables")
-    parser.add_argument('--tpu_name', type=str, default='',
-                        help="TPU name if none are set in the environment variables")
-    parser.add_argument('--xrt_tpu_config', type=str, default='',
-                        help="XRT TPU config if none are set in the environment variables")
-
-    parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O1',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
-    parser.add_argument("--local_rank", type=int, default=-1,
-                        help="For distributed training: local_rank")
-    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
-    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
-    args = parser.parse_args()
-
+def main(args):
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
         raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
 
@@ -435,23 +377,6 @@
         args.n_gpu = 1
     args.device = device
 
-    if args.tpu:
-        if args.tpu_ip_address:
-            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
-        if args.tpu_name:
-            os.environ["TPU_NAME"] = args.tpu_name
-        if args.xrt_tpu_config:
-            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
-
-        assert "TPU_IP_ADDRESS" in os.environ
-        assert "TPU_NAME" in os.environ
-        assert "XRT_TPU_CONFIG" in os.environ
-
-        import torch_xla
-        import torch_xla.core.xla_model as xm
-        args.device = xm.xla_device()
-        args.xla_model = xm
-
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
@@ -492,11 +417,17 @@
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
+    # Setup for using TPUs
+    if args.use_tpu:
+        if not _TORCH_XLA_INSTALLED:
+            raise ImportError('Could not import torch_xla package. Please install torch_xla first: '
+                              'https://github.com/pytorch/xla.')
+        args.device = xm.xla_device()
+
     model.to(args.device)
 
     logger.info("Training/evaluation parameters %s", args)
 
-
     # Training
     if args.do_train:
         train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
@@ -505,7 +436,7 @@
 
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
-    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
             os.makedirs(args.output_dir)
 
@@ -514,7 +445,7 @@
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(args.output_dir)
+        model_to_save.save_pretrained(args.output_dir, xla_device=args.use_tpu)
         tokenizer.save_pretrained(args.output_dir)
 
         # Good practice: save your training arguments together with the trained model
@@ -548,5 +479,104 @@
     return results
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--data_dir", default=None, type=str, required=True,
+                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--task_name", default=None, type=str, required=True,
+                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model predictions and checkpoints will be written.")
+
+    ## Other parameters
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name")
+    parser.add_argument("--cache_dir", default="", type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
+    parser.add_argument("--max_seq_length", default=128, type=int,
+                        help="The maximum total input sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the dev set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Run evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight decay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
+    parser.add_argument("--warmup_steps", default=0, type=int,
+                        help="Linear warmup over warmup_steps.")
+
+    parser.add_argument('--logging_steps', type=int, default=50,
+                        help="Log every X updates steps.")
+    parser.add_argument('--save_steps', type=int, default=50,
+                        help="Save checkpoint every X updates steps.")
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Avoid using CUDA when available")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+                             " See details at https://nvidia.github.io/apex/amp.html")
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="For distributed training: local_rank")
+    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
+
+    parser.add_argument('--use_tpu', action='store_true', help='Whether to use TPUs.')
+    parser.add_argument('--num_cores', default=8, type=int, help='Number of TPU cores to use.')
+    parser.add_argument('--metrics_debug', action='store_true', help='Whether to print debug metrics.')
+    return parser.parse_args()
+
+
+def _tpu_mp_fn(rank, args):
+    global logger
+    logger_name = '{} [{}]'.format(__name__, xm.get_ordinal(defval=-1))
+    logger = logging.getLogger(logger_name)
+    main(args)
+
+def main_cli():
+    args = get_args()
+    if args.use_tpu:
+        xmp.spawn(_tpu_mp_fn, args=(args,), nprocs=args.num_cores)
+    else:
+        main(args)
+
 if __name__ == "__main__":
-    main()
+    main_cli()
diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index d51eefab58fd..42ce416e86b1 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -232,7 +232,7 @@
         self.base_model._prune_heads(heads_to_prune)
 
-    def save_pretrained(self, save_directory):
+    def save_pretrained(self, save_directory, xla_device=False):
         """ Save a model and its configuration file to a directory, so that it
             can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
         """
@@ -246,7 +246,11 @@ def save_pretrained(self, save_directory):
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
 
-        torch.save(model_to_save.state_dict(), output_model_file)
+        if xla_device:
+            import torch_xla.core.xla_model as xm
+            xm.save(model_to_save.state_dict(), output_model_file)
+        else:
+            torch.save(model_to_save.state_dict(), output_model_file)
         logger.info("Model weights saved in {}".format(output_model_file))
 
     @classmethod
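
Illustrative usage (not part of the patch): a hypothetical invocation of the TPU path added above. Every flag comes from get_args() in this diff; the model choice, data directory, and output directory are placeholders.

    python examples/run_glue.py \
        --model_type bert \
        --model_name_or_path bert-base-cased \
        --task_name MRPC \
        --do_train \
        --do_eval \
        --data_dir /path/to/glue/MRPC \
        --max_seq_length 128 \
        --per_gpu_train_batch_size 8 \
        --learning_rate 5e-5 \
        --num_train_epochs 3.0 \
        --output_dir /tmp/mrpc_tpu_output \
        --use_tpu \
        --num_cores 8 \
        --metrics_debug

With --use_tpu set, main_cli() launches num_cores worker processes through xmp.spawn(); each worker runs main() on its own XLA device and loads data through the ParallelLoader path shown in train().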