Initialize iterator only once per fit call (#835)
At the moment, we initialize the iterators for the training and
validation datasets once per epoch. However, this is unnecessary and
creates an (ever so small) overhead.

With this PR, the iterators are created once per fit call only.

A provision was added to keep backwards compatibility, but a
DeprecationWarning will be raised with instructions on how to change
the code.
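
For reference, migrating code that overrides run_single_epoch roughly looks like
this (a condensed sketch of the example in the FAQ entry added below):

    # skorch <= 0.11: run_single_epoch received the Dataset
    def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
        for batch in self.get_iterator(dataset, training=training):
            ...

    # skorch >= 0.12: run_single_epoch receives an already initialized DataLoader
    def run_single_epoch(self, iterator, training, prefix, step_fn, **fit_params):
        for batch in iterator:
            ...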

To test the performance impact of this change, I ran the MNIST benchmark
(while doing so, I fixed a few issues with the script). The impact of this
PR on that benchmark is not noticeable; I would only expect a performance
difference on very small datasets. Still, I believe the benefits outweigh
the costs.

Side note:

The test_pickle_load test failed for me locally when cuda_available was set
to False. I'm not exactly sure what the reason is; it could be that the way
we patch torch.cuda.is_available breaks with some recent changes in PyTorch.
BenjaminBossan committed Feb 13, 2022
1 parent c58ae67 commit f5bb1a5
Showing 6 changed files with 145 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training

### Changed
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))

### Fixed

39 changes: 39 additions & 0 deletions docs/user/FAQ.rst
@@ -444,3 +444,42 @@ this is how to make the transition:
...
The same goes for the other three methods.

Migration from 0.11 to 0.12
^^^^^^^^^^^^^^^^^^^^^^^^^^^

In skorch 0.12, we made a change regarding the training step. Now, we initialize
the :class:`torch.utils.data.DataLoader` only once per fit call instead of once
per epoch. This is accomplished by calling
:py:meth:`skorch.net.NeuralNet.get_iterator` only once at the beginning of the
training process. For the majority of users, this should make no difference
in practice.

However, you might be affected if you wrote a custom
:py:meth:`skorch.net.NeuralNet.run_single_epoch`. The first argument to this
method is now the initialized ``DataLoader`` instead of a ``Dataset``.
Therefore, this method should no longer call
:py:meth:`skorch.net.NeuralNet.get_iterator`. You only need to change a few
lines of code to accomplish this, as shown below:

.. code:: python

    # before
    def run_single_epoch(self, dataset, ...):
        ...
        for batch in self.get_iterator(dataset, training=training):
            ...

    # after
    def run_single_epoch(self, iterator, ...):
        ...
        for batch in iterator:
            ...

Your old code should still work for the time being but will give a
``DeprecationWarning``. Starting from skorch v0.13, old code will raise an error
instead.

If it is necessary to have access to the ``Dataset`` inside of
``run_single_epoch``, you can access it on the ``DataLoader`` object using
``iterator.dataset``.
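
As a minimal sketch of that last point (assuming a custom net that overrides
``run_single_epoch``), the underlying dataset can be retrieved from the loader
like this:

    def run_single_epoch(self, iterator, training, prefix, step_fn, **fit_params):
        if iterator is None:
            return
        # the DataLoader keeps a reference to the Dataset it wraps
        dataset = iterator.dataset
        for batch in iterator:
            ...
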
81 changes: 37 additions & 44 deletions examples/benchmarks/mnist.py
@@ -25,7 +25,6 @@
"""

import argparse
import os
import time

import numpy as np
@@ -34,11 +33,12 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import torch
from torch import nn

from skorch.utils import to_device
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring
import torch
from torch import nn


BATCH_SIZE = 128
@@ -49,8 +49,8 @@
def get_data(num_samples):
mnist = fetch_openml('mnist_784')
torch.manual_seed(0)
X = mnist.data.astype('float32').reshape(-1, 1, 28, 28)
y = mnist.target.astype('int64')
X = mnist.data.values.astype('float32').reshape(-1, 1, 28, 28)
y = mnist.target.values.astype('int64')
X, y = shuffle(X, y)
X, y = X[:num_samples], y[:num_samples]
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
@@ -61,7 +61,7 @@ def get_data(num_samples):

class ClassifierModule(nn.Module):
def __init__(self):
super(ClassifierModule, self).__init__()
super().__init__()

self.cnn = nn.Sequential(
nn.Conv2d(1, 32, (3, 3)),
Expand Down Expand Up @@ -134,36 +134,35 @@ def report(losses, batch_sizes, y, y_proba, epoch, time, training=True):
def train_torch(
model,
X,
X_test,
y,
y_test,
batch_size,
device,
lr,
max_epochs,
):
model = to_device(model, device)

idx_train, idx_valid = next(iter(StratifiedKFold(
5, random_state=0).split(np.arange(len(X)), y)))
idx_train, idx_valid = next(iter(StratifiedKFold(5).split(np.arange(len(X)), y)))
X_train, X_valid, y_train, y_valid = (
X[idx_train], X[idx_valid], y[idx_train], y[idx_valid])
dataset_train = torch.utils.data.TensorDataset(
torch.tensor(X_train),
torch.tensor(y_train),
)
iterator_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size)
dataset_valid = torch.utils.data.TensorDataset(
torch.tensor(X_valid),
torch.tensor(y_valid),
)
iterator_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size)

optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)
criterion = nn.NLLLoss()

for epoch in range(max_epochs):
train_out = train_step(
model,
dataset_train,
iterator_train,
batch_size=batch_size,
device=device,
criterion=criterion,
@@ -173,7 +172,7 @@ def train_torch(

valid_out = valid_step(
model,
dataset_valid,
iterator_valid,
batch_size=batch_size,
device=device,
criterion=criterion,
Expand All @@ -185,13 +184,13 @@ def train_torch(
return model


def train_step(model, dataset, device, criterion, batch_size, optimizer):
def train_step(model, iterator, device, criterion, batch_size, optimizer):
model.train()
y_preds = []
losses = []
batch_sizes = []
tic = time.time()
for Xi, yi in torch.utils.data.DataLoader(dataset, batch_size=batch_size):
for Xi, yi in iterator:
Xi, yi = to_device(Xi, device), to_device(yi, device)
optimizer.zero_grad()
y_pred = model(Xi)
@@ -212,16 +211,14 @@ def train_step(model, dataset, device, criterion, batch_size, optimizer):
}


def valid_step(model, dataset, device, criterion, batch_size):
def valid_step(model, iterator, device, criterion, batch_size):
model.eval()
y_preds = []
losses = []
batch_sizes = []
tic = time.time()
with torch.no_grad():
for Xi, yi in torch.utils.data.DataLoader(
dataset, batch_size=batch_size,
):
for Xi, yi in iterator:
Xi, yi = to_device(Xi, device), to_device(yi, device)
y_pred = model(Xi)
y_pred = torch.log(y_pred)
@@ -255,13 +252,11 @@ def performance_torch(
model = train_torch(
model,
X_train,
X_test,
y_train,
y_test,
batch_size=batch_size,
device=device,
max_epochs=max_epochs,
lr=0.1,
lr=lr,
)

X_test = torch.tensor(X_test).to(device)
@@ -275,29 +270,27 @@ def main(device, num_samples):
# trigger potential cuda call overhead
torch.zeros(1).to(device)

if True:
print("\nTesting skorch performance")
tic = time.time()
score_skorch = performance_skorch(
*data,
batch_size=BATCH_SIZE,
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
)
time_skorch = time.time() - tic

if True:
print("\nTesting pure torch performance")
tic = time.time()
score_torch = performance_torch(
*data,
batch_size=BATCH_SIZE,
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
)
time_torch = time.time() - tic
print("\nTesting skorch performance")
tic = time.time()
score_skorch = performance_skorch(
*data,
batch_size=BATCH_SIZE,
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
)
time_skorch = time.time() - tic

print("\nTesting pure torch performance")
tic = time.time()
score_torch = performance_torch(
*data,
batch_size=BATCH_SIZE,
max_epochs=MAX_EPOCHS,
lr=LEARNING_RATE,
device=device,
)
time_torch = time.time() - tic

print("time skorch: {:.4f}, time torch: {:.4f}".format(
time_skorch, time_torch))
32 changes: 25 additions & 7 deletions skorch/net.py
@@ -1080,26 +1080,30 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
'dataset_train': dataset_train,
'dataset_valid': dataset_valid,
}
iterator_train = self.get_iterator(dataset_train, training=True)
iterator_valid = None
if dataset_valid is not None:
iterator_valid = self.get_iterator(dataset_valid, training=False)

for _ in range(epochs):
self.notify('on_epoch_begin', **on_epoch_kwargs)

self.run_single_epoch(dataset_train, training=True, prefix="train",
self.run_single_epoch(iterator_train, training=True, prefix="train",
step_fn=self.train_step, **fit_params)

self.run_single_epoch(dataset_valid, training=False, prefix="valid",
self.run_single_epoch(iterator_valid, training=False, prefix="valid",
step_fn=self.validation_step, **fit_params)

self.notify("on_epoch_end", **on_epoch_kwargs)
return self

def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
def run_single_epoch(self, iterator, training, prefix, step_fn, **fit_params):
"""Compute a single epoch of train or validation.
Parameters
----------
dataset : torch Dataset or None
The initialized dataset to loop over. If None, skip this step.
iterator : torch DataLoader or None
The initialized ``DataLoader`` to loop over. If None, skip this step.
training : bool
Whether to set the module to train mode or not.
@@ -1112,12 +1116,13 @@ def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
**fit_params : dict
Additional parameters passed to the ``step_fn``.
"""
if dataset is None:
if iterator is None:
return

batch_count = 0
for batch in self.get_iterator(dataset, training=training):
for batch in iterator:
self.notify("on_batch_begin", batch=batch, training=training)
step = step_fn(batch, **fit_params)
self.history.record_batch(prefix + "_loss", step["loss"].item())
@@ -1632,6 +1637,19 @@ def get_iterator(self, dataset, training=False):
mini-batches.
"""
# TODO: remove in skorch v0.13, see #835
if isinstance(dataset, DataLoader):
msg = (
"get_iterator was called with a DataLoader instance but it should be "
"called with a Dataset instead. Probably, you implemented a custom "
"run_single_epoch method. Its first argument is now a DataLoader, "
"not a Dataset. For more information, look here: "
"https://skorch.readthedocs.io/en/latest/user/FAQ.html"
"#migration-from-0-11-to-0-12. This will raise an error in skorch v0.13"
)
warnings.warn(msg, DeprecationWarning)
return dataset

if training:
kwargs = self.get_params_for('iterator_train')
iterator = self.iterator_train
11 changes: 6 additions & 5 deletions skorch/tests/callbacks/test_scoring.py
@@ -479,7 +479,7 @@ def test_multiple_scorings_with_dict(
with pytest.raises(ValueError, match=msg):
net.fit(*data)

@pytest.mark.parametrize('use_caching, count', [(False, 1), (True, 0)])
@pytest.mark.parametrize('use_caching, count', [(False, 5), (True, 2)])
def test_with_caching_get_iterator_not_called(
self, net_cls, module_cls, train_split, caching_scoring_cls, data,
use_caching, count,
@@ -499,10 +499,11 @@ def test_with_caching_get_iterator_not_called(
net.fit(*data)

# expected count should be:
# max_epochs * (1 (train) + 1 (valid) + 0 or 1 (from scoring,
# depending on caching))
count_expected = max_epochs * (1 + 1 + count)
assert net.get_iterator.call_count == count_expected
# fit loop: 1 (train) + 1 (valid) = 2
# scoring:
# with caching: 0
# without caching: 1 per epoch = 3
assert net.get_iterator.call_count == count

def test_subclassing_epoch_scoring(
self, classifier_module, classifier_data):
37 changes: 37 additions & 0 deletions skorch/tests/test_net.py
@@ -3609,6 +3609,43 @@ def evaluation_step(self, batch, training=False):
y_pred = net.predict(X)
assert y_pred.shape == (100, 2)

# TODO: remove in skorch v0.13
def test_net_with_custom_run_single_epoch(self, net_cls, module_cls, data):
# See #835. We changed the API to initialize the DataLoader only once
# per epoch. This test is to make sure that code that overrides
# run_single_epoch still works for the time being.
from skorch.dataset import get_len

class MyNet(net_cls):
def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
# code as in skorch<=0.11
# first argument should now be an iterator, not a dataset
if dataset is None:
return

# make sure that the "dataset" (really the DataLoader) can still
# access the Dataset if needed
assert hasattr(dataset, 'dataset')

batch_count = 0
for batch in self.get_iterator(dataset, training=training):
self.notify("on_batch_begin", batch=batch, training=training)
step = step_fn(batch, **fit_params)
self.history.record_batch(prefix + "_loss", step["loss"].item())
batch_size = (get_len(batch[0]) if isinstance(batch, (tuple, list))
else get_len(batch))
self.history.record_batch(prefix + "_batch_size", batch_size)
self.notify("on_batch_end", batch=batch, training=training, **step)
batch_count += 1

self.history.record(prefix + "_batch_count", batch_count)

net = MyNet(module_cls, max_epochs=2)
X, y = data
with pytest.deprecated_call():
net.fit(X, y)
# does not raise
net.predict(X)

class TestNetSparseInput:
@pytest.fixture(scope='module')
