Fix zero_grad placement resulting in zero gradient logs #2555

Merged · 48 commits · May 1, 2022

Changes from all commits
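Context for reviewers: the underlying issue (referenced as #2459 in the commits below) was that optimizer.zero_grad() ran immediately after optimizer.step() inside the update function, so by the time Events.ITERATION_COMPLETED handlers (e.g. gradient loggers) executed, the gradients were already flushed and the logged values were all zeros. The fix moves zero_grad() to the start of the next accumulation cycle. The following is a self-contained before/after sketch with a toy model and made-up helper names, not ignite's actual code:

import torch
from torch import nn

# Toy setup, for illustration only: one linear layer, plain SGD.
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
gradient_accumulation_steps = 1

def update_old(iteration, x, y):
    # Old placement: zero_grad() right after step() flushes the gradients,
    # so anything inspecting model.weight.grad after the update sees nothing useful.
    model.train()
    loss = loss_fn(model(x), y)
    loss.backward()
    if iteration % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

def update_new(iteration, x, y):
    # New placement (as in this PR): gradients are cleared only at the start of
    # the next accumulation cycle, so they remain inspectable after the update returns.
    if (iteration - 1) % gradient_accumulation_steps == 0:
        optimizer.zero_grad()
    model.train()
    loss = loss_fn(model(x), y)
    loss.backward()
    if iteration % gradient_accumulation_steps == 0:
        optimizer.step()
    return loss.item()

x, y = torch.tensor([[1.0]]), torch.tensor([[2.0]])
update_old(1, x, y)
print(model.weight.grad)  # zeroed (or None on recent PyTorch) -> gradient logs showed zeros
update_new(2, x, y)
print(model.weight.grad)  # non-zero -> gradient logs are meaningful again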
3484a59
Remove unnecessary code in BaseOutputHandler
sadra-barikbin Jan 22, 2022
2c8eed9
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
ccf2364
Add ReduceLROnPlateauScheduler
sadra-barikbin Feb 3, 2022
7f7dae6
Fix indentation issue
sadra-barikbin Feb 3, 2022
896e482
Fix another indentation issue
sadra-barikbin Feb 3, 2022
cbc8d04
Fix PEP8 related issues
sadra-barikbin Feb 3, 2022
47b0622
Fix other PEP8 related issues
sadra-barikbin Feb 3, 2022
91d058e
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
9fd7d61
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
b7dc921
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
e0644e3
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 3, 2022
c95a2be
Remove ReduceLROnPlateau's specific params and add link to it
sadra-barikbin Feb 3, 2022
96554d0
Fix state_dict bug and add a test
sadra-barikbin Feb 5, 2022
145dabc
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 9, 2022
0aee28a
Update docs
sadra-barikbin Feb 10, 2022
307803c
Merge branch 'master' into master
vfdev-5 Feb 14, 2022
0129572
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 14, 2022
7dbf8f5
Fix gradients flushed at the end for 'supervised_train_step' (#2459)
egaznep Feb 16, 2022
a17a5b2
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 19, 2022
b3ea962
Add doctest and fix typo
sadra-barikbin Feb 20, 2022
e2e6831
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
b88c9e1
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
8d0ae3c
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 20, 2022
408b271
Merge branch 'master' into master
vfdev-5 Feb 20, 2022
d4b513e
Merge branch 'master' into fix_2459
vfdev-5 Feb 21, 2022
102ba25
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 22, 2022
b715475
Fix gradient loggability for AMP, APEX and TPU (#2459)
egaznep Feb 24, 2022
310ba5a
Merge branch 'fix_2459' of github.com:egaznep/ignite into fix_2459
egaznep Feb 24, 2022
25065d0
Merge branch 'master' into fix_2459
vfdev-5 Mar 8, 2022
a5d9bab
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Mar 12, 2022
8ad4113
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Mar 22, 2022
f4ba590
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Mar 29, 2022
c594fd7
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Mar 29, 2022
52dbb5c
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Apr 10, 2022
93ce080
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Apr 13, 2022
68bada5
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Apr 14, 2022
ceb69a7
Merge branch 'pytorch:master' into master
sadra-barikbin Apr 15, 2022
0f0fca7
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Apr 17, 2022
b8715c6
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Apr 17, 2022
530e69a
Fix zero_grad place in trainer step
sadra-barikbin Apr 17, 2022
48ea731
Improve tests and fix bug
sadra-barikbin Apr 19, 2022
ee21110
Remove redundant stmts after pytest parametrize
sadra-barikbin Apr 20, 2022
0902534
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Apr 28, 2022
091f5ac
Merge branch 'master' into fix_2459
sadra-barikbin Apr 28, 2022
3cde27f
Refactor tests
sadra-barikbin Apr 28, 2022
d331109
autopep8 fix
sadra-barikbin Apr 28, 2022
fcbb154
Improvement
sadra-barikbin Apr 29, 2022
1d00d09
Fix bug
sadra-barikbin Apr 29, 2022
13 changes: 8 additions & 5 deletions ignite/engine/__init__.py
@@ -95,6 +95,8 @@ def supervised_training_step(
     )

     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -104,7 +106,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             optimizer.step()
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)

     return update
@@ -173,6 +174,8 @@ def supervised_training_step_amp(
     )

     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         with autocast(enabled=True):
@@ -185,12 +188,10 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
             if engine.state.iteration % gradient_accumulation_steps == 0:
                 scaler.step(optimizer)
                 scaler.update()
-                optimizer.zero_grad()
         else:
             loss.backward()
             if engine.state.iteration % gradient_accumulation_steps == 0:
                 optimizer.step()
-                optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)

     return update
@@ -256,6 +257,8 @@ def supervised_training_step_apex(
     )

     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -266,7 +269,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
             scaled_loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             optimizer.step()
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)

     return update
@@ -331,6 +333,8 @@ def supervised_training_step_tpu(
     )

     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -340,7 +344,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             xm.optimizer_step(optimizer, barrier=True)
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)

     return update
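The change above is identical across all four update functions (vanilla, AMP, APEX, TPU): zero_grad() now runs at the start of an accumulation cycle rather than right after the step, so the gradients survive past the end of the iteration. A handler can therefore log them on Events.ITERATION_COMPLETED. A minimal sketch with a toy model and a made-up handler name (illustrative only, not part of this PR):

import torch
from torch import nn
from ignite.engine import Events, create_supervised_trainer

# Toy model, optimizer and data, for illustration only.
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss(), gradient_accumulation_steps=2)

@trainer.on(Events.ITERATION_COMPLETED)
def log_grad_norm(engine):
    # Before this PR, grads were already flushed here on step iterations, so this
    # norm was logged as 0. After the fix it reflects the real gradients.
    norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None]))
    print(f"iter {engine.state.iteration}: grad norm = {norm.item():.4f}")

data = [(torch.tensor([[1.0]]), torch.tensor([[2.0]])) for _ in range(4)]
trainer.run(data, max_epochs=1)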
274 changes: 142 additions & 132 deletions tests/ignite/engine/test_create_supervised.py
@@ -124,6 +124,35 @@ def _():
    trainer.run(data)


+def _test_create_supervised_trainer_have_grad_after_iteration(
+    model_device: Optional[str] = None,
+    trainer_device: Optional[str] = None,
+    trace: bool = False,
+    amp_mode: str = None,
+    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    gradient_accumulation_steps: int = 1,
+):
+
+    trainer, model = _default_create_supervised_trainer(
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+
+    x = torch.tensor([[1.0], [1.0], [1.0], [1.0], [1.0]])
+    y = torch.tensor([[2.0], [3.0], [4.0], [5.0], [6.0]])
+    data = [(_x, _y) for _x, _y in zip(x, y)]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def _():
+        assert model.weight.grad != 0
+
+    trainer.run(data)
+
+
@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
def test_create_supervised_training_scalar_assignment():

@@ -340,25 +369,119 @@ def _test_create_evaluation_step(
    assert output_transform_mock.called


-def test_create_supervised_trainer():
-    _test_create_supervised_trainer_wrong_accumulation()
-    _test_create_supervised_trainer(gradient_accumulation_steps=1)
-    _test_create_supervised_trainer(gradient_accumulation_steps=3)
-    _test_create_mocked_supervised_trainer()
-
-
-def test_create_supervised_trainer_with_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu")
-    _test_create_mocked_supervised_trainer(trainer_device="cpu")
-
-
-def test_create_supervised_trainer_traced_with_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu", trace=True)
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu", trace=True)
-    _test_create_mocked_supervised_trainer(trainer_device="cpu", trace=True)
+@pytest.mark.parametrize(
+    ("trainer_device", "model_device", "trace", "amp_mode", "scaler"),
+    [
+        pytest.param(None, None, False, None, False, id="default"),
+        pytest.param("cpu", None, False, None, False, id="cpu"),
+        pytest.param("cpu", None, True, None, False, id="traced_with_cpu"),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            None,
+            False,
+            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+            id="cuda",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            False,
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            True,
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp_scaler",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available()),
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp_gradscaler",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "apex",
+            False,
+            marks=[
+                pytest.mark.skip(reason="Temporarily disabled, as it fails because of an issue from apex side"),
+                # pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                # pytest.mark.skipif(not find_spec("apex"), reason="Skip if no APEX")
+            ],
+            id="cuda_apex",
+        ),
+        pytest.param(
+            "xla",
+            "xla",
+            False,
+            None,
+            False,
+            marks=[
+                pytest.mark.tpu,
+                pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars"),
+                pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package"),
+            ],
+            id="tpu",
+        ),
+        pytest.param(
+            "cuda",
+            None,
+            False,
+            None,
+            False,
+            marks=[pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")],
+            id="cuda_with_model_on_cpu",
+        ),
+    ],
+)
+def test_create_supervised_trainer(trainer_device, model_device, trace, amp_mode, scaler):
+    _test_create_supervised_trainer_wrong_accumulation(model_device, trainer_device, amp_mode)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+    _test_create_mocked_supervised_trainer(model_device, trainer_device, trace, amp_mode, scaler)
+    _test_create_supervised_trainer_have_grad_after_iteration(
+        model_device, trainer_device, trace, amp_mode, scaler, gradient_accumulation_steps=1
+    )
+    _test_create_supervised_trainer_have_grad_after_iteration(
+        model_device, trainer_device, trace, amp_mode, scaler, gradient_accumulation_steps=3
+    )


@pytest.mark.skipif(find_spec("apex"), reason="Skip if APEX")
@@ -405,96 +528,6 @@ def test_create_supervised_trainer_scaler_not_amp():
        _test_create_supervised_trainer(amp_mode="apex", scaler=scaler)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
-
-
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_amp():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")
-
-
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_amp_scaler():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=True,
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=True,
-    )
-    _test_create_mocked_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
-    )
-    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=scaler,
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=scaler,
-    )
-    _test_create_mocked_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler
-    )
-
-
-# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-# @pytest.mark.skipif(not find_spec("apex"), reason="Skip if no APEX")
-@pytest.mark.skip(reason="Temporarily disabled, as it fails because of an issue from apex side")
-def test_create_supervised_trainer_on_cuda_apex():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="apex")
-
-
@pytest.mark.skipif(idist.has_xla_support, reason="Skip if has PyTorch XLA package")
def test_supervised_training_step_tpu_no_xla():
    with pytest.raises(ModuleNotFoundError, match="torch_xla cannot be imported, please install PyTorch XLA."):
@@ -509,21 +542,6 @@ def test_create_supervised_trainer_on_tpu_no_xla():
        _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_create_supervised_trainer_on_tpu():
-    model_device = trainer_device = "xla"
-    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
-
-
@pytest.mark.tpu
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_create_supervised_trainer_on_tpu_amp():
@@ -532,14 +550,6 @@ def test_create_supervised_trainer_on_tpu_amp():
    _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cuda")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cuda")
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cuda")
-    _test_create_mocked_supervised_trainer(trainer_device="cuda")
-
-
def test_create_supervised_evaluator():
    _test_create_supervised_evaluator()
    _test_mocked_supervised_evaluator()
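A note on the accumulation semantics the new tests exercise: after this change, gradients accumulate across the iterations of one window and are cleared only when the next window starts, so an ITERATION_COMPLETED handler sees the partially accumulated gradient on non-step iterations and the full gradient on step iterations, never a flushed one. A standalone sketch of that behaviour, assuming a toy model, toy data, and lr=0 purely so the gradient values stay easy to compare (illustrative only):

import torch
from torch import nn
from ignite.engine import Events, create_supervised_trainer

model = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0)  # lr=0: only the grads change between iterations
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss(), gradient_accumulation_steps=3)

grads = []

@trainer.on(Events.ITERATION_COMPLETED)
def record_grad(engine):
    # With the new zero_grad placement the gradient is never flushed here: it grows over
    # iterations 1-3 of a window and is reset only at the start of iteration 4.
    grads.append(model.weight.grad.clone())

data = [(torch.tensor([[1.0]]), torch.tensor([[2.0]])) for _ in range(6)]
trainer.run(data, max_epochs=1)
print([g.item() for g in grads])  # non-zero at every iteration; drops back between iterations 3 and 4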