Commit: #902 Patch: avoid moving model (#910)
* Remove model move in `create_supervised`

- Remove the model move
- Update the docstring and include a warning note
- Updated the tests to make sure the behavior is as expected.

* Updated use of `create_supervised`

Grepped the repo and made the required changes to keep everything
in line with the new behavior of `create_supervised`.

Hopefully I didn't miss anything.
I did *not* rerun FastaiLRFinder_MNIST.ipynb.
The output would change and add unnecessary size to the repo.
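
For reference, the pattern this commit standardizes on, sketched minimally (`Net` is a stand-in for any `torch.nn.Module`; hyperparameters are illustrative):

```python
import torch
import torch.nn as nn
from torch.optim import SGD

from ignite.engine import create_supervised_trainer

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Net()     # any torch.nn.Module; illustrative placeholder
model.to(device)  # move the model *before* constructing the optimizer

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# `device` now only controls where input batches are moved;
# the trainer no longer touches the model's placement.
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
```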

* autopep8 fix

* Update test_create_supervised.py

Co-authored-by: AutoPEP8 <>
Co-authored-by: vfdev <vfdev.5@gmail.com>
kai-tub and vfdev-5 committed Apr 10, 2020
1 parent 7b0b4a2 commit 400c26a
Showing 13 changed files with 69 additions and 30 deletions.
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_neptune_logger.py
@@ -78,6 +78,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_tensorboard_logger.py
@@ -77,6 +77,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_tqdm_logger.py
@@ -54,6 +54,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, display_gpu_info
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_visdom_logger.py
@@ -73,6 +73,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
1 change: 1 addition & 0 deletions examples/mnist/mnist.py
@@ -54,6 +54,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
1 change: 1 addition & 0 deletions examples/mnist/mnist_save_resume_engine.py
@@ -82,6 +82,7 @@ def run(
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
criterion = nn.NLLLoss()
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
1 change: 1 addition & 0 deletions examples/mnist/mnist_with_tensorboard.py
@@ -92,6 +92,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, lo
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
1 change: 1 addition & 0 deletions examples/mnist/mnist_with_visdom.py
@@ -67,6 +67,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
9 changes: 6 additions & 3 deletions examples/notebooks/FastaiLRFinder_MNIST.ipynb
@@ -21,6 +21,7 @@
"source": [
"%matplotlib inline\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
@@ -121,8 +122,10 @@
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"criterion = nn.NLLLoss()\n",
"model = Net()\n",
"model.to(device) # Move model before creating optimizer\n",
"optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)"
]
},
@@ -344,7 +347,7 @@
}
],
"source": [
"trainer = create_supervised_trainer(model, optimizer, criterion, device=\"cuda\")\n",
"trainer = create_supervised_trainer(model, optimizer, criterion, device=device)\n",
"ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {\"batch loss\": x})\n",
"\n",
"lr_finder = FastaiLRFinder()\n",
@@ -354,7 +357,7 @@
" \n",
"trainer.run(trainloader, max_epochs=10)\n",
"\n",
"evaluator = create_supervised_evaluator(model, metrics={\"acc\": Accuracy(), \"loss\": Loss(nn.NLLLoss())}, device=\"cuda\")\n",
"evaluator = create_supervised_evaluator(model, metrics={\"acc\": Accuracy(), \"loss\": Loss(nn.NLLLoss())}, device=device)\n",
"evaluator.run(testloader)\n",
"\n",
"print(evaluator.state.metrics)"
@@ -674,4 +677,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
15 changes: 12 additions & 3 deletions ignite/contrib/engines/tbptt.py
@@ -50,19 +50,28 @@ def create_supervised_tbptt_trainer(
tbtt_step (int): the length of time chunks (last one may be smaller).
dim (int): axis representing the time dimension.
device (str, optional): device type specification (default: None).
Applies to both model and batches.
Applies to batches.
non_blocking (bool, optional): if True and this copy is between CPU and GPU,
the copy may occur asynchronously with respect to the host. For other cases,
this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`,
`non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.
.. warning::
The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.
For more information see:
* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_
Returns:
Engine: a trainer engine with supervised update function.
"""
if device:
model.to(device)

def _update(engine, batch):
loss_list = []
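
The same ordering now applies to the truncated-BPTT helper. A minimal sketch of the updated call pattern (model and hyperparameters are illustrative, and the `ignite.contrib.engines` import path is assumed):

```python
import torch
import torch.nn as nn
from torch.optim import SGD

from ignite.contrib.engines import create_supervised_tbptt_trainer

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.RNN(1, 1, bias=False)
model.to(device)  # the user moves the model; `device` below only moves batches

optimizer = SGD(model.parameters(), lr=0.1)
trainer = create_supervised_tbptt_trainer(
    model, optimizer, nn.MSELoss(), tbtt_step=2, dim=0, device=device
)
```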
44 changes: 28 additions & 16 deletions ignite/engine/__init__.py
@@ -47,17 +47,29 @@ def create_supervised_trainer(
optimizer (`torch.optim.Optimizer`): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches and the model after starting the engine.
Applies to batches after starting the engine. Model *will not* be moved.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
Note:
`engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
of the processed batch by default.
.. warning::
The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.
For more information see:
* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_
Returns:
Engine: a trainer engine with supervised update function.
"""
@@ -74,12 +86,6 @@ def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[t

trainer = Engine(_update)

if device is not None:

@trainer.on(Events.STARTED)
def move_model(engine):
model.to(device)

return trainer
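
The rationale, per the PyTorch references linked in the docstring: parameters of a model after `.to(device)` may be different objects from the ones the optimizer was constructed with, so an optimizer built too early can end up updating stale tensors. A sketch of the two orderings (illustrative only; assumes a CUDA device is available):

```python
import torch.nn as nn
from torch.optim import SGD

# Wrong: the optimizer captures the CPU parameters; after the move it
# may reference objects that are no longer the model's live parameters.
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.1)
model.to("cuda")

# Right: move first, then construct the optimizer over the CUDA parameters.
model = nn.Linear(1, 1).to("cuda")
optimizer = SGD(model.parameters(), lr=0.1)
```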


@@ -98,7 +104,7 @@ def create_supervised_evaluator(
model (`torch.nn.Module`): the model to train.
metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
device (str, optional): device type specification (default: None).
Applies to batches and the model after starting the engine.
Applies to batches after starting the engine. Model *will not* be moved.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
@@ -107,9 +113,21 @@ def create_supervised_evaluator(
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is
Note:
`engine.state.output` for this engine is defined by `output_transform` parameter and is
a tuple of `(batch_pred, batch_y)` by default.
.. warning::
The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.
For more information see:
* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_
Returns:
Engine: an evaluator engine with supervised inference function.
"""
@@ -124,12 +142,6 @@ def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tupl

evaluator = Engine(_inference)

if device is not None:

@evaluator.on(Events.STARTED)
def move_model(engine):
model.to(device)

for name, metric in metrics.items():
metric.attach(evaluator, name)

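
A matching evaluator-side sketch (the metric choice is illustrative):

```python
import torch
from torch.nn import Linear

from ignite.engine import create_supervised_evaluator
from ignite.metrics import MeanSquaredError

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Linear(1, 1)
model.to(device)  # the evaluator no longer moves the model

evaluator = create_supervised_evaluator(
    model, metrics={"mse": MeanSquaredError()}, device=device
)
```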
1 change: 1 addition & 0 deletions tests/ignite/contrib/engines/test_tbptt.py
@@ -84,6 +84,7 @@ def test_create_supervised_tbptt_trainer_callcounts(mock_detach_hidden):
def _test_create_supervised_tbptt_trainer(device):
# Defining dummy recurrent model with zero weights
model = nn.RNN(1, 1, bias=False)
model.to(device) # Move model before creating optimizer
for p in model.parameters():
p.data.zero_()

22 changes: 14 additions & 8 deletions tests/ignite/engine/test_create_supervised.py
@@ -81,11 +81,13 @@ def test_create_supervised_trainer_traced_with_cpu():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda():
device = "cuda"
model = Linear(1, 1)
model.to(device)
model.weight.data.zero_()
model.bias.data.zero_()
optimizer = SGD(model.parameters(), 0.1)
trainer = create_supervised_trainer(model, optimizer, mse_loss, device="cuda")
trainer = create_supervised_trainer(model, optimizer, mse_loss, device=device)

x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
@@ -102,8 +104,9 @@ def test_create_supervised_trainer_on_cuda():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_with_model_on_cpu_after_init():
def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
model = Linear(1, 1)
# Not moving model to cuda!
model.weight.data.zero_()
model.bias.data.zero_()
optimizer = SGD(model.parameters(), 0.1)
@@ -114,8 +117,8 @@ def test_create_supervised_trainer_on_cuda_with_model_on_cpu_after_init():
y = torch.tensor([[3.0], [5.0]])
data = [(x, y)]

model.to("cpu")
trainer.run(data)
with pytest.raises(RuntimeError, match=r"device type"):
trainer.run(data)


def test_create_supervised_evaluator():
@@ -192,11 +195,13 @@ def test_create_supervised_evaluator_traced_on_cpu():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda():
device = "cuda"
model = Linear(1, 1)
model.to(device)
model.weight.data.zero_()
model.bias.data.zero_()

evaluator = create_supervised_evaluator(model, device="cuda")
evaluator = create_supervised_evaluator(model, device=device)

x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
@@ -215,8 +220,9 @@ def test_create_supervised_evaluator_on_cuda():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_with_model_on_cpu_after_init():
def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
model = Linear(1, 1)
# Not moving model to cuda!
model.weight.data.zero_()
model.bias.data.zero_()

@@ -226,8 +232,8 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu_after_init():
y = torch.tensor([[3.0], [5.0]])
data = [(x, y)]

model.to("cpu")
evaluator.run(data)
with pytest.raises(RuntimeError, match=r"device type"):
evaluator.run(data)


def test_create_supervised_evaluator_with_metrics():
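
For context, a compressed sketch of the behavior these tests pin down. The `match` pattern is deliberately loose because the exact RuntimeError message varies across PyTorch versions; the test name is hypothetical:

```python
import pytest
import torch
from torch.nn import Linear
from torch.nn.functional import mse_loss
from torch.optim import SGD

from ignite.engine import create_supervised_trainer


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_model_left_on_cpu_raises():
    model = Linear(1, 1)  # deliberately *not* moved to cuda
    optimizer = SGD(model.parameters(), 0.1)
    trainer = create_supervised_trainer(model, optimizer, mse_loss, device="cuda")

    data = [(torch.tensor([[1.0]]), torch.tensor([[2.0]]))]
    # Inputs land on cuda while the model's weights stay on cpu.
    with pytest.raises(RuntimeError, match=r"device type"):
        trainer.run(data)
```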
