From 27931ba0a844d04455e5119f9812c084aea344b1 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 23 Aug 2019 08:22:30 +0200 Subject: [PATCH] Fixes #583 --- docs/source/faq.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/faq.rst b/docs/source/faq.rst index 5de46611e8fb..0e1491ae5947 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -100,9 +100,6 @@ do this, the most simple is the following: def update_fn(engine, batch): model.train() - if engine.state.iteration % accumulation_steps == 0: - optimizer.zero_grad() - x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) y_pred = model(x) loss = criterion(y_pred, y) / accumulation_steps @@ -110,6 +107,7 @@ do this, the most simple is the following: if engine.state.iteration % accumulation_steps == 0: optimizer.step() + optimizer.zero_grad() return loss.item()