Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#902-Patch avoid moving model #910

Merged
merged 6 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
Expand Down
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_tensorboard_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
Expand Down
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, display_gpu_info
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
Expand Down
1 change: 1 addition & 0 deletions examples/contrib/mnist/mnist_with_visdom_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/mnist_save_resume_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def run(
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
criterion = nn.NLLLoss()
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/mnist_with_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, lo
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
Expand Down
1 change: 1 addition & 0 deletions examples/mnist/mnist_with_visdom.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
if torch.cuda.is_available():
device = "cuda"

model.to(device) # Move model before creating optimizer
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
Expand Down
9 changes: 6 additions & 3 deletions examples/notebooks/FastaiLRFinder_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"source": [
"%matplotlib inline\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
Expand Down Expand Up @@ -121,8 +122,10 @@
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"criterion = nn.NLLLoss()\n",
"model = Net()\n",
"model.to(device) # Move model before creating optimizer\n",
"optimizer = optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)"
]
},
Expand Down Expand Up @@ -344,7 +347,7 @@
}
],
"source": [
"trainer = create_supervised_trainer(model, optimizer, criterion, device=\"cuda\")\n",
"trainer = create_supervised_trainer(model, optimizer, criterion, device=device)\n",
"ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {\"batch loss\": x})\n",
"\n",
"lr_finder = FastaiLRFinder()\n",
Expand All @@ -354,7 +357,7 @@
" \n",
"trainer.run(trainloader, max_epochs=10)\n",
"\n",
"evaluator = create_supervised_evaluator(model, metrics={\"acc\": Accuracy(), \"loss\": Loss(nn.NLLLoss())}, device=\"cuda\")\n",
"evaluator = create_supervised_evaluator(model, metrics={\"acc\": Accuracy(), \"loss\": Loss(nn.NLLLoss())}, device=device)\n",
"evaluator.run(testloader)\n",
"\n",
"print(evaluator.state.metrics)"
Expand Down Expand Up @@ -674,4 +677,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
15 changes: 12 additions & 3 deletions ignite/contrib/engines/tbptt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,28 @@ def create_supervised_tbptt_trainer(
tbtt_step (int): the length of time chunks (last one may be smaller).
dim (int): axis representing the time dimension.
device (str, optional): device type specification (default: None).
Applies to both model and batches.
Applies to batches.
non_blocking (bool, optional): if True and this copy is between CPU and GPU,
the copy may occur asynchronously with respect to the host. For other cases,
this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`,
`non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

.. warning::

The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.

For more information see:

* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

Returns:
Engine: a trainer engine with supervised update function.

"""
if device:
model.to(device)

def _update(engine, batch):
loss_list = []
Expand Down
44 changes: 28 additions & 16 deletions ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,29 @@ def create_supervised_trainer(
optimizer (`torch.optim.Optimizer`): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to batches and the model after starting the engine.
Applies to batches after starting the engine. Model *will not* be moved.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

Note: `engine.state.output` for this engine is defind by `output_transform` parameter and is the loss
Note:
`engine.state.output` for this engine is defind by `output_transform` parameter and is the loss
of the processed batch by default.

.. warning::

The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.

For more information see:

* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

Returns:
Engine: a trainer engine with supervised update function.
"""
Expand All @@ -74,12 +86,6 @@ def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[t

trainer = Engine(_update)

if device is not None:

@trainer.on(Events.STARTED)
def move_model(engine):
model.to(device)

return trainer


Expand All @@ -98,7 +104,7 @@ def create_supervised_evaluator(
model (`torch.nn.Module`): the model to train.
metrics (dict of str - :class:`~ignite.metrics.Metric`): a map of metric names to Metrics.
device (str, optional): device type specification (default: None).
Applies to batches and the model after starting the engine.
Applies to batches after starting the engine. Model *will not* be moved.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
Expand All @@ -107,9 +113,21 @@ def create_supervised_evaluator(
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.

Note: `engine.state.output` for this engine is defind by `output_transform` parameter and is
Note:
`engine.state.output` for this engine is defind by `output_transform` parameter and is
a tuple of `(batch_pred, batch_y)` by default.

.. warning::

The internal use of `device` has changed.
`device` will now *only* be used to move the input data to the correct device.
The `model` should be moved by the user before creating an optimizer.

For more information see:

* `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
* `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

Returns:
Engine: an evaluator engine with supervised inference function.
"""
Expand All @@ -124,12 +142,6 @@ def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tupl

evaluator = Engine(_inference)

if device is not None:

@evaluator.on(Events.STARTED)
def move_model(engine):
model.to(device)

for name, metric in metrics.items():
metric.attach(evaluator, name)

Expand Down
1 change: 1 addition & 0 deletions tests/ignite/contrib/engines/test_tbptt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_create_supervised_tbptt_trainer_callcounts(mock_detach_hidden):
def _test_create_supervised_tbptt_trainer(device):
# Defining dummy recurrent model with zero weights
model = nn.RNN(1, 1, bias=False)
model.to(device) # Move model before creating optimizer
for p in model.parameters():
p.data.zero_()

Expand Down
22 changes: 14 additions & 8 deletions tests/ignite/engine/test_create_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ def test_create_supervised_trainer_traced_with_cpu():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda():
device = "cuda"
model = Linear(1, 1)
model.to(device)
model.weight.data.zero_()
model.bias.data.zero_()
optimizer = SGD(model.parameters(), 0.1)
trainer = create_supervised_trainer(model, optimizer, mse_loss, device="cuda")
trainer = create_supervised_trainer(model, optimizer, mse_loss, device=device)

x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
Expand All @@ -102,8 +104,9 @@ def test_create_supervised_trainer_on_cuda():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_with_model_on_cpu_after_init():
def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
model = Linear(1, 1)
# Not moving model to cuda!
model.weight.data.zero_()
model.bias.data.zero_()
optimizer = SGD(model.parameters(), 0.1)
Expand All @@ -114,8 +117,8 @@ def test_create_supervised_trainer_on_cuda_with_model_on_cpu_after_init():
y = torch.tensor([[3.0], [5.0]])
data = [(x, y)]

model.to("cpu")
trainer.run(data)
with pytest.raises(RuntimeError, match=r"device type"):
trainer.run(data)


def test_create_supervised_evaluator():
Expand Down Expand Up @@ -192,11 +195,13 @@ def test_create_supervised_evaluator_traced_on_cpu():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda():
device = "cuda"
model = Linear(1, 1)
model.to(device)
model.weight.data.zero_()
model.bias.data.zero_()

evaluator = create_supervised_evaluator(model, device="cuda")
evaluator = create_supervised_evaluator(model, device=device)

x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[3.0], [5.0]])
Expand All @@ -215,8 +220,9 @@ def test_create_supervised_evaluator_on_cuda():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_with_model_on_cpu_after_init():
def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
model = Linear(1, 1)
# Not moving model to cuda!
model.weight.data.zero_()
model.bias.data.zero_()

Expand All @@ -226,8 +232,8 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu_after_init():
y = torch.tensor([[3.0], [5.0]])
data = [(x, y)]

model.to("cpu")
evaluator.run(data)
with pytest.raises(RuntimeError, match=r"device type"):
evaluator.run(data)


def test_create_supervised_evaluator_with_metrics():
Expand Down