Commit 6818c00
Merge pull request #123 from dnouri/feature/train-vs-training
Feature/train vs training
benjamin-work committed Dec 1, 2017
2 parents 5184dd3 + 6605ba9 commit 6818c00
Showing 5 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions docs/user/neuralnet.rst
@@ -446,8 +446,8 @@ the loss is determined. Below we show an example of overriding
         super().__init__(*args, **kwargs)
         self.lambda1 = lambda1
 
-    def get_loss(self, y_pred, y_true, X=None, train=False):
-        loss = super().get_loss(y_pred, y_true, X=X, train=train)
+    def get_loss(self, y_pred, y_true, X=None, training=False):
+        loss = super().get_loss(y_pred, y_true, X=X, training=training)
         loss += self.lambda1 * sum([w.abs().sum() for w in self.module_.parameters()])
         return loss
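For orientation, here is a minimal sketch of what the documented override looks like after this rename (assuming NeuralNet is importable from skorch; the class name RegularizedNet and the lambda1 default are illustrative, not part of the diff):

    from skorch import NeuralNet

    class RegularizedNet(NeuralNet):  # illustrative name, not part of the diff
        def __init__(self, *args, lambda1=0.01, **kwargs):
            super().__init__(*args, **kwargs)
            self.lambda1 = lambda1

        def get_loss(self, y_pred, y_true, X=None, training=False):
            # forward the renamed `training` flag to the base implementation
            loss = super().get_loss(y_pred, y_true, X=X, training=training)
            # add an L1 penalty over all module parameters
            loss += self.lambda1 * sum(w.abs().sum() for w in self.module_.parameters())
            return loss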
2 changes: 1 addition & 1 deletion examples/word_language_model/net.py
@@ -92,7 +92,7 @@ def sample_n(self, num_words, input, temperature=1., hidden=None):
 
     def score(self, X, y=None):
         ds = self.get_dataset(X)
-        target_iterator = self.get_iterator(ds, train=False)
+        target_iterator = self.get_iterator(ds, training=False)
 
         y_true = np.concatenate([skorch.utils.to_numpy(y) for _, y in target_iterator])
         y_pred = self.predict(X)
4 changes: 2 additions & 2 deletions skorch/callbacks.py
@@ -190,8 +190,8 @@ class BatchScoring(ScoringBase):
     """
    # pylint: disable=unused-argument,arguments-differ
-    def on_batch_end(self, net, X, y, train, **kwargs):
-        if train != self.on_train:
+    def on_batch_end(self, net, X, y, training, **kwargs):
+        if training != self.on_train:
             return
 
         y = self.target_extractor(y)
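Custom callbacks must follow the same rename, since fit_loop now notifies them with training=... instead of train=.... A minimal sketch, assuming skorch's Callback base class (PrintPhase is a hypothetical example, not from this commit):

    from skorch.callbacks import Callback

    class PrintPhase(Callback):  # hypothetical example callback
        # the batch hooks now receive `training` instead of `train`
        def on_batch_end(self, net, X=None, y=None, training=False, **kwargs):
            phase = 'train' if training else 'valid'
            print(phase, 'batch of size', len(X))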
30 changes: 15 additions & 15 deletions skorch/net.py
@@ -309,7 +309,7 @@ def on_epoch_end(self, net, **kwargs):
         pass
 
     # pylint: disable=unused-argument
-    def on_batch_begin(self, net, train=False, **kwargs):
+    def on_batch_begin(self, net, training=False, **kwargs):
         self.history.new_batch()
 
     def on_batch_end(self, net, **kwargs):
@@ -449,7 +449,7 @@ def validation_step(self, Xi, yi):
         """
         self.module_.eval()
         y_pred = self.infer(Xi)
-        return self.get_loss(y_pred, yi, X=Xi, train=False)
+        return self.get_loss(y_pred, yi, X=Xi, training=False)
 
     def train_step(self, Xi, yi):
         """Perform a forward step using batched data, update module
@@ -462,7 +462,7 @@ def train_step(self, Xi, yi):
         self.module_.train()
         self.optimizer_.zero_grad()
         y_pred = self.infer(Xi)
-        loss = self.get_loss(y_pred, yi, X=Xi, train=True)
+        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
         loss.backward()
 
         if self.gradient_clip_value is not None:
@@ -537,23 +537,23 @@ def fit_loop(self, X, y=None, epochs=None):
         for _ in range(epochs):
             self.notify('on_epoch_begin', **on_epoch_kwargs)
 
-            for Xi, yi in self.get_iterator(dataset_train, train=True):
-                self.notify('on_batch_begin', X=Xi, y=yi, train=True)
+            for Xi, yi in self.get_iterator(dataset_train, training=True):
+                self.notify('on_batch_begin', X=Xi, y=yi, training=True)
                 loss = self.train_step(Xi, yi)
                 self.history.record_batch('train_loss', loss.data[0])
                 self.history.record_batch('train_batch_size', len(Xi))
-                self.notify('on_batch_end', X=Xi, y=yi, train=True)
+                self.notify('on_batch_end', X=Xi, y=yi, training=True)
 
             if X_valid is None:
                 self.notify('on_epoch_end', **on_epoch_kwargs)
                 continue
 
-            for Xi, yi in self.get_iterator(dataset_valid, train=False):
-                self.notify('on_batch_begin', X=Xi, y=yi, train=False)
+            for Xi, yi in self.get_iterator(dataset_valid, training=False):
+                self.notify('on_batch_begin', X=Xi, y=yi, training=False)
                 loss = self.validation_step(Xi, yi)
                 self.history.record_batch('valid_loss', loss.data[0])
                 self.history.record_batch('valid_batch_size', len(Xi))
-                self.notify('on_batch_end', X=Xi, y=yi, train=False)
+                self.notify('on_batch_end', X=Xi, y=yi, training=False)
 
             self.notify('on_epoch_end', **on_epoch_kwargs)
         return self
@@ -661,7 +661,7 @@ def forward_iter(self, X, training=False):
         self.module_.train(training)
 
         dataset = self.get_dataset(X)
-        iterator = self.get_iterator(dataset, train=training)
+        iterator = self.get_iterator(dataset, training=training)
         for Xi, _ in iterator:
             yp = self.evaluation_step(Xi, training=training)
             yield yp
@@ -760,7 +760,7 @@ def predict(self, X):
         return y_pred
 
     # pylint: disable=unused-argument
-    def get_loss(self, y_pred, y_true, X=None, train=False):
+    def get_loss(self, y_pred, y_true, X=None, training=False):
         """Return the loss for this batch.
 
         Parameters
@@ -840,7 +840,7 @@ def get_dataset(self, X, y=None):
 
         return dataset(X, y, **kwargs)
 
-    def get_iterator(self, dataset, train=False):
+    def get_iterator(self, dataset, training=False):
         """Get an iterator that allows to loop over the batches of the
         given data.
@@ -854,7 +854,7 @@ def get_iterator(self, dataset, train=False):
         Usually, ``self.dataset``, initialized with the corresponding
         data, is passed to ``get_iterator``.
 
-        train : bool (default=False)
+        training : bool (default=False)
           Whether to use ``iterator_train`` or ``iterator_test``.
 
         Returns
@@ -864,7 +864,7 @@
           mini-batches.
         """
-        if train:
+        if training:
             kwargs = self._get_params_for('iterator_train')
             iterator = self.iterator_train
         else:
@@ -1123,7 +1123,7 @@ def _prepare_target_for_loss(self, y):
         # pass, even though it will fail with NLLLoss
         return y
 
-    def get_loss(self, y_pred, y_true, X=None, train=False):
+    def get_loss(self, y_pred, y_true, X=None, training=False):
         y_true = to_var(y_true, use_cuda=self.use_cuda)
         y_pred_log = torch.log(y_pred)
         return self.criterion_(
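As a usage sketch of the renamed get_iterator keyword (net and X are hypothetical here; assumes a fitted NeuralNet):

    # training=True selects self.iterator_train and its params,
    # training=False selects self.iterator_test
    ds = net.get_dataset(X)
    for Xi, _ in net.get_iterator(ds, training=False):
        print(len(Xi))  # batch sizes produced by iterator_test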
4 changes: 2 additions & 2 deletions skorch/tests/test_net.py
@@ -486,11 +486,11 @@ def test_change_get_loss(self, net_cls, module_cls, data):
 
         class MyNet(net_cls):
             # pylint: disable=unused-argument
-            def get_loss(self, y_pred, y_true, X=None, train=False):
+            def get_loss(self, y_pred, y_true, X=None, training=False):
                 y_true = to_var(y_true, use_cuda=False)
                 loss_a = torch.abs(y_true.float() - y_pred[:, 1]).mean()
                 loss_b = ((y_true.float() - y_pred[:, 1]) ** 2).mean()
-                if train:
+                if training:
                     self.history.record_batch('loss_a', to_numpy(loss_a)[0])
                     self.history.record_batch('loss_b', to_numpy(loss_b)[0])
                 return loss_a + loss_b
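Reading such per-batch records back out of the history is straightforward (a sketch; assumes a net like MyNet above has been fit on some X, y):

    net.fit(X, y)
    # 'loss_a' is recorded only on training batches because of the
    # `if training:` guard above; batch 0 of an epoch is a training batch
    print(net.history[-1, 'batches', 0, 'loss_a'])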
