From e1d1add04fef2386442d1abd0e531438b243e910 Mon Sep 17 00:00:00 2001
From: jmarin
Date: Thu, 14 Mar 2024 21:28:01 +0100
Subject: [PATCH] Reset the gradients before computing them in loss.backward()

---
 beginner_source/basics/quickstart_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/beginner_source/basics/quickstart_tutorial.py b/beginner_source/basics/quickstart_tutorial.py
index 07a1be517d..2bb4622d4e 100644
--- a/beginner_source/basics/quickstart_tutorial.py
+++ b/beginner_source/basics/quickstart_tutorial.py
@@ -152,9 +152,9 @@ def train(dataloader, model, loss_fn, optimizer):
         loss = loss_fn(pred, y)
 
         # Backpropagation
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        optimizer.zero_grad()
 
         if batch % 100 == 0:
             loss, current = loss.item(), (batch + 1) * len(X)
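
For context, a minimal sketch of the train() function as it reads after this patch, following the surrounding code in quickstart_tutorial.py. The device selection here is a simplified assumption; the tutorial defines it earlier in the file. PyTorch accumulates gradients across backward passes, so calling optimizer.zero_grad() before loss.backward() makes the intent explicit: each batch starts from zeroed gradients before new ones are computed and the optimizer steps.

    import torch

    # Simplified device selection; the tutorial defines this earlier.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def train(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation: clear accumulated gradients first, then
            # compute new gradients and take an optimizer step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")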