Merge pull request #110 from dnouri/feature/no-optimizer-in-fit-loop
Remove optimizer parameter from train_step
benjamin-work committed Nov 15, 2017
2 parents d0b3f4b + 71bb97e commit 6deb46c
Showing 1 changed file with 4 additions and 4 deletions.
skorch/net.py: 4 additions & 4 deletions

@@ -450,7 +450,7 @@ def validation_step(self, Xi, yi):
         y_pred = self.infer(Xi)
         return self.get_loss(y_pred, yi, X=Xi, train=False)
 
-    def train_step(self, Xi, yi, optimizer):
+    def train_step(self, Xi, yi):
         """Perform a forward step using batched data, update module
         parameters, and return the loss.
@@ -459,7 +459,7 @@ def train_step(self, Xi, yi, optimizer):
         """
         self.module_.train()
-        optimizer.zero_grad()
+        self.optimizer_.zero_grad()
         y_pred = self.infer(Xi)
         loss = self.get_loss(y_pred, yi, X=Xi, train=True)
         loss.backward()
@@ -470,7 +470,7 @@ def train_step(self, Xi, yi, optimizer):
                 self.gradient_clip_value,
                 norm_type=self.gradient_clip_norm_type)
 
-        optimizer.step()
+        self.optimizer_.step()
         return loss
 
     def evaluation_step(self, Xi, training=False):
@@ -538,7 +538,7 @@ def fit_loop(self, X, y=None, epochs=None):
 
             for Xi, yi in self.get_iterator(dataset_train, train=True):
                 self.notify('on_batch_begin', X=Xi, y=yi, train=True)
-                loss = self.train_step(Xi, yi, self.optimizer_)
+                loss = self.train_step(Xi, yi)
                 self.history.record_batch('train_loss', loss.data[0])
                 self.history.record_batch('train_batch_size', len(Xi))
                 self.notify('on_batch_end', X=Xi, y=yi, train=True)
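The net effect: train_step now resolves the optimizer from the fitted self.optimizer_ attribute instead of receiving it as an argument, so fit_loop calls it with the batch data alone. A minimal sketch of overriding the new signature in a subclass (the MyNet class name is hypothetical; the body simply mirrors the diff above):

    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def train_step(self, Xi, yi):
            # No optimizer argument anymore: use the optimizer
            # instantiated and stored by the net itself.
            self.module_.train()
            self.optimizer_.zero_grad()
            y_pred = self.infer(Xi)
            loss = self.get_loss(y_pred, yi, X=Xi, train=True)
            loss.backward()
            self.optimizer_.step()
            return loss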
