No need to call xm.mark_step() explicitly (#4)
For gradient accumulation we accumulate over batches yielded by a
`ParallelLoader` instance, which already marks the step itself on each
`next()`, so the explicit `xm.mark_step()` call is redundant.
jysohn23 committed Nov 21, 2019
1 parent 6ef1edd commit 3129ad3
Showing 1 changed file with 0 additions and 1 deletion.
1 change: 0 additions & 1 deletion examples/run_glue_tpu.py
@@ -150,7 +150,6 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
 loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

 if args.gradient_accumulation_steps > 1:
-    xm.mark_step()  # Mark step to evaluate graph so far or else graph will grow too big and OOM.
     loss = loss / args.gradient_accumulation_steps

 loss.backward()
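For context, a minimal sketch of the surrounding training loop, assuming torch_xla's ParallelLoader API; the device setup, loader construction, and optimizer handling below are illustrative and not part of the diff. The per-device loader returned by ParallelLoader marks the XLA step each time it yields a batch, which is why the explicit xm.mark_step() inside the gradient-accumulation branch can be dropped.

    # Sketch only (not the actual run_glue_tpu.py code); assumes torch_xla's
    # ParallelLoader, xm.xla_device(), and xm.optimizer_step() APIs.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    para_loader = pl.ParallelLoader(train_dataloader, [device])

    for step, batch in enumerate(para_loader.per_device_loader(device)):
        outputs = model(**batch)
        loss = outputs[0]
        if args.gradient_accumulation_steps > 1:
            # No xm.mark_step() needed here: the per-device loader already
            # marked the step when it yielded this batch, so the XLA graph
            # stays bounded without an extra call.
            loss = loss / args.gradient_accumulation_steps
        loss.backward()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()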
