feat: Added model_fn to support non-standard model function in create_trainer (#3055) #3074

Merged
merged 6 commits
Oct 3, 2023
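For readers skimming the diff, here is a minimal usage sketch (not part of the PR itself) of the new `model_fn` argument; the two-input model and the dict-shaped batch below are hypothetical:

```python
import torch
import torch.nn as nn

from ignite.engine import create_supervised_trainer

# A model whose forward() takes two tensors, so the default ``model(x)``
# call inside the update function would not work as-is.
class TwoInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 2)

    def forward(self, first, second):
        return self.fc(torch.cat([first, second], dim=1))

model = TwoInputNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# ``prepare_batch`` still returns ``(x, y)``; ``model_fn`` decides how the
# model is applied to ``x`` (here ``x`` is assumed to be a dict of tensors).
trainer = create_supervised_trainer(
    model,
    optimizer,
    criterion,
    model_fn=lambda model, x: model(x["first"], x["second"]),
)
```

Previously this required writing a custom update function; with `model_fn` the standard factory can be reused.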
51 changes: 45 additions & 6 deletions ignite/engine/__init__.py
@@ -51,6 +51,7 @@
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training.

@@ -71,6 +72,8 @@
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Callable: update function.

@@ -91,6 +94,8 @@
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

if gradient_accumulation_steps <= 0:
@@ -104,7 +109,7 @@
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -128,6 +133,7 @@
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch.cuda.amp``.

@@ -149,6 +155,7 @@
scaler: GradScaler instance for gradient scaling. (default: None)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Callable: update function
@@ -171,6 +178,8 @@
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

try:
@@ -190,7 +199,7 @@
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)

Check warning (Codecov / codecov/patch) on line 202 in ignite/engine/__init__.py: added line #L202 was not covered by tests.
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -219,6 +228,7 @@
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using apex.

@@ -239,6 +249,7 @@
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Callable: update function.
@@ -260,6 +271,8 @@
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

try:
@@ -278,7 +291,7 @@
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)

Check warning (Codecov / codecov/patch) on line 294 in ignite/engine/__init__.py: added line #L294 was not covered by tests.
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -302,6 +315,7 @@
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch_xla``.

@@ -322,6 +336,7 @@
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Callable: update function.
@@ -343,6 +358,8 @@
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
try:
import torch_xla.core.xla_model as xm
@@ -360,7 +377,7 @@
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -414,6 +431,7 @@
amp_mode: Optional[str] = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""Factory function for creating a trainer for supervised models.

@@ -444,6 +462,7 @@
(default: False)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
a trainer engine with supervised update function.
@@ -525,6 +544,8 @@
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

device_type = device.type if isinstance(device, torch.device) else device
@@ -543,6 +564,7 @@
output_transform,
_scaler,
gradient_accumulation_steps,
model_fn,
)
elif mode == "apex":
_update = supervised_training_step_apex(
@@ -555,6 +577,7 @@
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
elif mode == "tpu":
_update = supervised_training_step_tpu(
@@ -567,6 +590,7 @@
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
else:
_update = supervised_training_step(
@@ -579,6 +603,7 @@
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)

trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
@@ -595,6 +620,7 @@
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation.
@@ -612,6 +638,7 @@
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Inference function.
@@ -629,13 +656,15 @@
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""

def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -649,6 +678,7 @@
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -666,6 +696,7 @@
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
Inference function.
@@ -683,6 +714,8 @@
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
try:
from torch.cuda.amp import autocast
@@ -694,7 +727,7 @@
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -710,6 +743,7 @@
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
amp_mode: Optional[str] = None,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""
Factory function for creating an evaluator for supervised models.
@@ -730,6 +764,7 @@
output expected by metrics. If you change it you should use `output_transform` in metrics.
amp_mode: can be ``amp``, model will be casted to float16 using
`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_
model_fn: the model function that receives `model` and `x`, and returns `y_pred`.

Returns:
an evaluator engine with supervised inference function.
@@ -754,6 +789,8 @@
Added ``amp_mode`` argument for automatic mixed precision.
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added `model_fn` to customize model's application on the sample
"""
device_type = device.type if isinstance(device, torch.device) else device
on_tpu = "xla" in device_type if device_type is not None else False
@@ -768,6 +805,7 @@
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)
else:
evaluate_step = supervised_evaluation_step(
@@ -777,6 +815,7 @@
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)

evaluator = Engine(evaluate_step)
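The same argument is exposed on the evaluation side. A self-contained sketch follows (the dict-input model and metric choice are illustrative, not taken from the PR):

```python
import torch.nn as nn

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# Toy model whose forward() expects a named tensor rather than the raw batch.
class DictInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, features):
        return self.fc(features)

model = DictInputNet()

# The evaluator applies the model through ``model_fn`` instead of the default
# ``model(x)``, so ``x`` can be whatever structure ``prepare_batch`` yields.
evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy()},
    model_fn=lambda model, x: model(x["features"]),
)
```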