diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 0eb0b15821b2..ba8aae0465d1 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -69,6 +69,17 @@ def supervised_training_step( Returns: Callable: update function. + Example:: + + from ignite.engine import Engine, supervised_training_step + + model = ... + optimizer = ... + loss_fn = ... + + update_fn = supervised_training_step(model, optimizer, loss_fn, 'cuda') + trainer = Engine(update_fn) + .. versionadded:: 0.5.0 """ @@ -115,6 +126,18 @@ def supervised_training_step_amp( Returns: Callable: update function + Example:: + + from ignite.engine import Engine, supervised_training_step_amp + + model = ... + optimizer = ... + loss_fn = ... + scaler = torch.cuda.amp.GradScaler(2**10) + + update_fn = supervised_training_step_amp(model, optimizer, loss_fn, 'cuda', scaler=scaler) + trainer = Engine(update_fn) + .. versionadded:: 0.5.0 """ @@ -170,6 +193,17 @@ def supervised_training_step_apex( Returns: Callable: update function. + Example:: + + from ignite.engine import Engine, supervised_training_step_apex + + model = ... + optimizer = ... + loss_fn = ... + + update_fn = supervised_training_step_apex(model, optimizer, loss_fn, 'cuda') + trainer = Engine(update_fn) + .. versionadded:: 0.5.0 """ @@ -220,6 +254,17 @@ def supervised_training_step_tpu( Returns: Callable: update function. + Example:: + + from ignite.engine import Engine, supervised_training_step_tpu + + model = ... + optimizer = ... + loss_fn = ... + + update_fn = supervised_training_step_tpu(model, optimizer, loss_fn, 'xla') + trainer = Engine(update_fn) + .. versionadded:: 0.5.0 """ try: