feat: Added model_fn to support non-standard model function in create_trainer (#3055) (#3074)

* feat: Added model_fn to support non-standard model function in create_trainer (#3055)

* Added versionchanged docstring and tests for model_fn (#3074)

* fix style formatting

* Update test_create_supervised.py

* Apply suggestions from code review

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
invoker-bot and vfdev-5 committed Oct 3, 2023
1 parent b92ad52 commit b8751f2
Showing 2 changed files with 82 additions and 18 deletions.
51 changes: 45 additions & 6 deletions ignite/engine/__init__.py
@@ -51,6 +51,7 @@ def supervised_training_step(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training.
@@ -71,6 +72,8 @@ def supervised_training_step(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Callable: update function.
@@ -91,6 +94,8 @@ def supervised_training_step(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

if gradient_accumulation_steps <= 0:
@@ -104,7 +109,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -128,6 +133,7 @@ def supervised_training_step_amp(
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch.cuda.amp``.
@@ -149,6 +155,7 @@ def supervised_training_step_amp(
scaler: GradScaler instance for gradient scaling. (default: None)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Callable: update function
@@ -171,6 +178,8 @@ def supervised_training_step_amp(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

try:
@@ -190,7 +199,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -219,6 +228,7 @@ def supervised_training_step_apex(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using apex.
@@ -239,6 +249,7 @@ def supervised_training_step_apex(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Callable: update function.
@@ -260,6 +271,8 @@ def supervised_training_step_apex(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

try:
@@ -278,7 +291,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -302,6 +315,7 @@ def supervised_training_step_tpu(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch_xla``.
@@ -322,6 +336,7 @@ def supervised_training_step_tpu(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Callable: update function.
@@ -343,6 +358,8 @@ def supervised_training_step_tpu(
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
try:
import torch_xla.core.xla_model as xm
@@ -360,7 +377,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -414,6 +431,7 @@ def create_supervised_trainer(
amp_mode: Optional[str] = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""Factory function for creating a trainer for supervised models.
@@ -444,6 +462,7 @@ def create_supervised_trainer(
(default: False)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
a trainer engine with supervised update function.
@@ -525,6 +544,8 @@ def output_transform_fn(x, y, y_pred, loss):
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

device_type = device.type if isinstance(device, torch.device) else device
@@ -543,6 +564,7 @@ def output_transform_fn(x, y, y_pred, loss):
output_transform,
_scaler,
gradient_accumulation_steps,
model_fn,
)
elif mode == "apex":
_update = supervised_training_step_apex(
@@ -555,6 +577,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
elif mode == "tpu":
_update = supervised_training_step_tpu(
@@ -567,6 +590,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
else:
_update = supervised_training_step(
@@ -579,6 +603,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)

trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
@@ -595,6 +620,7 @@ def supervised_evaluation_step(
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation.
@@ -612,6 +638,7 @@ def supervised_evaluation_step(
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Inference function.
Expand All @@ -629,13 +656,15 @@ def supervised_evaluation_step(
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -649,6 +678,7 @@ def supervised_evaluation_step_amp(
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -666,6 +696,7 @@ def supervised_evaluation_step_amp(
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
Inference function.
@@ -683,6 +714,8 @@ def supervised_evaluation_step_amp(
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
try:
from torch.cuda.amp import autocast
@@ -694,7 +727,7 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -710,6 +743,7 @@ def create_supervised_evaluator(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
amp_mode: Optional[str] = None,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""
Factory function for creating an evaluator for supervised models.
@@ -730,6 +764,7 @@ def create_supervised_evaluator(
output expected by metrics. If you change it you should use `output_transform` in metrics.
amp_mode: can be ``amp``, model will be casted to float16 using
`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.
Returns:
an evaluator engine with supervised inference function.
@@ -754,6 +789,8 @@ def create_supervised_evaluator(
Added ``amp_mode`` argument for automatic mixed precision.
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
device_type = device.type if isinstance(device, torch.device) else device
on_tpu = "xla" in device_type if device_type is not None else False
@@ -768,6 +805,7 @@ def create_supervised_evaluator(
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)
else:
evaluate_step = supervised_evaluation_step(
@@ -777,6 +815,7 @@ def create_supervised_evaluator(
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)

evaluator = Engine(evaluate_step)
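Usage note (not part of the commit): a minimal sketch of how the new model_fn argument could be used, assuming a model whose forward pass takes two tensors rather than a single input. The TwoInputNet module, the dict-shaped batches, and the hyperparameters below are invented for illustration; only create_supervised_trainer and its model_fn keyword come from the diff above.

import torch
import torch.nn as nn

from ignite.engine import create_supervised_trainer


class TwoInputNet(nn.Module):
    """Toy model whose forward takes two tensors instead of one (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 2)

    def forward(self, a, b):
        return self.fc(torch.cat([a, b], dim=1))


model = TwoInputNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# The default update step would call ``model(x)``; ``model_fn`` replaces that call,
# here by unpacking the dict produced by the default ``prepare_batch``.
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    model_fn=lambda model, x: model(x["a"], x["b"]),
)

# Hypothetical data: each batch is ``(x, y)`` where ``x`` is a dict of tensors.
data = [
    ({"a": torch.randn(4, 10), "b": torch.randn(4, 10)}, torch.randint(0, 2, (4,)))
    for _ in range(8)
]
trainer.run(data, max_epochs=1)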

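A matching sketch for the evaluation side, reusing the model and batch layout from the sketch above; the Accuracy metric and the variable names are again only illustrative.

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# Same custom call as in the trainer, so metrics are computed on the custom forward.
evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy()},
    model_fn=lambda model, x: model(x["a"], x["b"]),
)

val_data = [
    ({"a": torch.randn(4, 10), "b": torch.randn(4, 10)}, torch.randint(0, 2, (4,)))
    for _ in range(4)
]
state = evaluator.run(val_data)
print(state.metrics["accuracy"])  # default output_transform already yields (y_pred, y)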