-
-
Notifications
You must be signed in to change notification settings - Fork 655
Closed
Description
Link: https://pytorch.org/ignite/faq.html
The gradient-accumulation example on this page should be:
accumulation_steps = 4


def update_fn(engine, batch):
    """Run one training iteration with gradient accumulation.

    Every call performs a forward and backward pass on ``batch``. The loss is
    divided by ``accumulation_steps`` so that the gradients summed over
    ``accumulation_steps`` backward passes equal the gradient of the loss
    averaged over those batches. The optimizer only steps every
    ``accumulation_steps`` iterations, and the gradients are cleared *after*
    that step, so the update sees the contributions of all accumulated batches.

    NOTE(review): ``model``, ``prepare_batch``, ``device``, ``non_blocking``,
    ``criterion``, ``optimizer`` and ``Engine`` are assumed to be defined
    elsewhere (they are not shown in this snippet).

    Returns:
        The scaled loss of this batch (``loss.item()``).
    """
    model.train()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    # Scale so the accumulated gradient is the average over the accumulated batches.
    loss = criterion(y_pred, y) / accumulation_steps
    loss.backward()
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        # Clear only after stepping, so the step used every accumulated gradient.
        optimizer.zero_grad()
    return loss.item()


trainer = Engine(update_fn)
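instead of the current FAQ version shown below. That version calls `optimizer.zero_grad()` before the backward pass, on the same iteration as `optimizer.step()`. This discards the gradients accumulated by the preceding iterations, so each update uses only a single batch's gradient: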
accumulation_steps = 4


def update_fn(engine, batch):
    """Training step as currently shown in the FAQ (reported as buggy; kept for reference).

    NOTE(review): ``optimizer.zero_grad()`` runs on the same iteration as
    ``optimizer.step()`` and *before* ``loss.backward()``. On every
    ``accumulation_steps``-th iteration this discards whatever gradients the
    earlier iterations accumulated. Each step therefore uses only the current
    batch's gradient (further scaled by ``1 / accumulation_steps``).

    Returns:
        The scaled loss of this batch (``loss.item()``).
    """
    model.train()
    # BUG: clears the gradients accumulated by the preceding iterations.
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.zero_grad()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = criterion(y_pred, y) / accumulation_steps
    loss.backward()
    # Steps on the same iteration that just zeroed the gradients above.
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
    return loss.item()


trainer = Engine(update_fn)
vfdev-5 and liepieshov
Metadata
Metadata
Assignees
Labels
No labels