Add README section for TPU and address comments.
jysohn23 committed Nov 18, 2019
1 parent 5a44823 commit 837fac2
Showing 2 changed files with 61 additions and 12 deletions.
43 changes: 43 additions & 0 deletions examples/README.md
@@ -6,6 +6,7 @@ similar API between the different models.
| Section | Description |
|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [TensorFlow 2.0 models on GLUE](#TensorFlow-2.0-Bert-models-on-GLUE) | Examples running BERT TensorFlow 2.0 model on the GLUE tasks. |
| [Running on TPUs](#running-on-tpus) | Examples of running fine-tuning tasks on Google TPUs to accelerate workloads. |
| [Language Model fine-tuning](#language-model-fine-tuning) | Fine-tuning the library models for language modeling on a text dataset. Causal language modeling for GPT/GPT-2, masked language modeling for BERT/RoBERTa. |
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
@@ -36,6 +37,48 @@ Quick benchmarks from the script (no other modifications):

Mixed precision (AMP) reduces the training time considerably for the same hardware and hyper-parameters (same batch size was used).

## Running on TPUs

You can accelerate your workloads on Google's TPUs. For information on how to set up your TPU environment, refer to the
[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
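
As a rough sketch of what that setup looks like (the `$TPU_IP_ADDRESS` value below is a placeholder and the exact steps depend on your Cloud TPU; the pytorch/xla README above is authoritative), you point PyTorch/XLA at your TPU and check that a device is visible:

```
# Placeholder TPU address; use the one from your own Cloud TPU setup.
export TPU_IP_ADDRESS=10.0.0.2
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"

# Quick sanity check that torch_xla can acquire a TPU device.
python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
```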

The following are some examples of running the `*_tpu.py` fine-tuning scripts on TPUs. All data preparation steps are
identical to your usual GPU + Hugging Face setup.

### GLUE

Before running any of these GLUE tasks you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
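
A minimal sketch of that step (assuming the gist's script is saved locally as `download_glue_data.py` and exposes `--data_dir`/`--tasks` flags):

```
# Download and unpack all GLUE tasks into $GLUE_DIR (script name and flags are
# assumptions based on the gist linked above).
python download_glue_data.py --data_dir /path/to/glue --tasks all
export GLUE_DIR=/path/to/glue
```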

To run your GLUE task on the MNLI dataset, for example, you can use something like the following:

```
export GLUE_DIR=/path/to/glue
export TASK_NAME=MNLI
python run_glue_tpu.py \
  --model_type bert \
  --model_name_or_path bert-base-cased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 3e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/$TASK_NAME \
  --overwrite_output_dir \
  --logging_steps 50 \
  --save_steps 200 \
  --num_cores=8 \
  --only_log_master
```
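
Here `--num_cores=8` uses all eight cores of a v2-8/v3-8 TPU and `--only_log_master` restricts logging to the master process. As the `run_glue_tpu.py` diff below shows, each core writes its final model to its own `final-xla<ordinal>` folder under `--output_dir`. If you then want a stand-alone evaluation of one of those exports, something like the following should work with the regular GPU/CPU `run_glue.py` example (a hypothetical sketch, assuming a checkpoint written with `save_pretrained(..., xla_device=True)` reloads like any other `from_pretrained()` checkpoint):

```
export GLUE_DIR=/path/to/glue
export TASK_NAME=MNLI
python run_glue.py \
  --model_type bert \
  --model_name_or_path /tmp/$TASK_NAME/final-xla0 \
  --task_name $TASK_NAME \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --max_seq_length 128 \
  --output_dir /tmp/$TASK_NAME/final-xla0
```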


## Language model fine-tuning

Based on the script [`run_lm_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_lm_finetuning.py).
30 changes: 18 additions & 12 deletions examples/run_glue_tpu.py
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -77,8 +78,8 @@ def set_seed(args):
def get_sampler(dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset,
num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
return DistributedSampler(
dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def train(args, train_dataset, model, tokenizer, disable_logging=False):
@@ -97,8 +98,14 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
{
'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': args.weight_decay,
},
{
'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
@@ -129,8 +136,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
logger.info("Saving model checkpoint to %s", output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir, xla_device=True)
model.save_pretrained(output_dir, xla_device=True)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))

model.train()
@@ -144,6 +150,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
loss = outputs[0] # model outputs are always tuple in transformers (see doc)

if args.gradient_accumulation_steps > 1:
xm.mark_step() # Mark step to evaluate graph so far or else graph will grow too big and OOM.
loss = loss / args.gradient_accumulation_steps

loss.backward()
@@ -350,25 +357,24 @@ def main(args):

logger.info("Training/evaluation parameters %s", args)

# Training
if args.do_train:
# Train the model.
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer, disable_logging=disable_logging)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

# Save trained model.
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))

# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
output_dir = os.path.join(args.output_dir, 'final-xla{}'.format(xm.get_ordinal()))
if args.do_train:
# Create output directory if needed
if not os.path.exists(output_dir):
os.makedirs(output_dir)

logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir, xla_device=True)
model.save_pretrained(output_dir, xla_device=True)
tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained.
