
Commit

xi -> Xi
ottonemo committed Oct 13, 2017
1 parent aeafd7e commit 155df7b
Showing 2 changed files with 22 additions and 22 deletions.
skorch/dataset.py (4 changes: 2 additions & 2 deletions)
@@ -203,9 +203,9 @@ def __getitem__(self, i):
         if is_pandas_ndframe(X):
             X = {k: X[k].values.reshape(-1, 1) for k in X}

-        xi = multi_indexing(X, i)
+        Xi = multi_indexing(X, i)
         yi = y if y is None else multi_indexing(y, i)
-        return self.transform(xi, yi)
+        return self.transform(Xi, yi)


 class CVSplit(object):
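A note on the renamed variable: Xi is whatever multi_indexing extracts from X at index i. Below is a minimal, illustrative sketch of that behavior for the dict-of-arrays case produced by the pandas branch above; multi_indexing_sketch is a hypothetical stand-in, not skorch's actual helper.

import numpy as np

def multi_indexing_sketch(data, i):
    # Hypothetical stand-in for skorch's multi_indexing: apply the
    # index to each value of a dict, or directly to an array/list.
    if isinstance(data, dict):
        return {k: v[i] for k, v in data.items()}
    return data[i]

# A DataFrame-like X is first converted to a dict of (n, 1) arrays:
X = {'a': np.arange(5).reshape(-1, 1), 'b': np.arange(5, 10).reshape(-1, 1)}
Xi = multi_indexing_sketch(X, 2)
print(Xi)  # {'a': array([2]), 'b': array([7])}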
skorch/net.py (40 changes: 20 additions & 20 deletions)
@@ -410,7 +410,7 @@ def initialize(self):
     def check_data(self, X, y=None):
         pass

-    def validation_step(self, xi, yi):
+    def validation_step(self, Xi, yi):
         """Perform a forward step using batched data and return the
         resulting loss.
@@ -419,10 +419,10 @@ def validation_step(self, xi, yi):
         """
         self.module_.eval()
-        y_pred = self.infer(xi)
-        return self.get_loss(y_pred, yi, X=xi, train=False)
+        y_pred = self.infer(Xi)
+        return self.get_loss(y_pred, yi, X=Xi, train=False)

-    def train_step(self, xi, yi, optimizer):
+    def train_step(self, Xi, yi, optimizer):
         """Perform a forward step using batched data, update module
         parameters, and return the loss.
@@ -432,8 +432,8 @@ def train_step(self, xi, yi, optimizer):
         """
         self.module_.train()
         optimizer.zero_grad()
-        y_pred = self.infer(xi)
-        loss = self.get_loss(y_pred, yi, X=xi, train=True)
+        y_pred = self.infer(Xi)
+        loss = self.get_loss(y_pred, yi, X=Xi, train=True)
         loss.backward()

         if self.gradient_clip_value is not None:
@@ -445,7 +445,7 @@ def train_step(self, xi, yi, optimizer):
         optimizer.step()
         return loss

-    def evaluation_step(self, xi, training=False):
+    def evaluation_step(self, Xi, training=False):
         """Perform a forward step to produce the output used for
         prediction and scoring.
@@ -455,7 +455,7 @@ def evaluation_step(self, xi, training=False):
         """
         self.module_.train(training)
-        return self.infer(xi)
+        return self.infer(Xi)

     def fit_loop(self, X, y=None, epochs=None):
         """The proper fit loop.
@@ -491,23 +491,23 @@ def fit_loop(self, X, y=None, epochs=None):
         for _ in range(epochs):
             self.notify('on_epoch_begin', X=X, y=y)

-            for xi, yi in self.get_iterator(dataset_train, train=True):
-                self.notify('on_batch_begin', X=xi, y=yi, train=True)
-                loss = self.train_step(xi, yi, self.optimizer_)
+            for Xi, yi in self.get_iterator(dataset_train, train=True):
+                self.notify('on_batch_begin', X=Xi, y=yi, train=True)
+                loss = self.train_step(Xi, yi, self.optimizer_)
                 self.history.record_batch('train_loss', loss.data[0])
-                self.history.record_batch('train_batch_size', len(xi))
-                self.notify('on_batch_end', X=xi, y=yi, train=True)
+                self.history.record_batch('train_batch_size', len(Xi))
+                self.notify('on_batch_end', X=Xi, y=yi, train=True)

             if X_valid is None:
                 self.notify('on_epoch_end', X=X, y=y)
                 continue

-            for xi, yi in self.get_iterator(dataset_valid, train=False):
-                self.notify('on_batch_begin', X=xi, y=yi, train=False)
-                loss = self.validation_step(xi, yi)
+            for Xi, yi in self.get_iterator(dataset_valid, train=False):
+                self.notify('on_batch_begin', X=Xi, y=yi, train=False)
+                loss = self.validation_step(Xi, yi)
                 self.history.record_batch('valid_loss', loss.data[0])
-                self.history.record_batch('valid_batch_size', len(xi))
-                self.notify('on_batch_end', X=xi, y=yi, train=False)
+                self.history.record_batch('valid_batch_size', len(Xi))
+                self.notify('on_batch_end', X=Xi, y=yi, train=False)

             self.notify('on_epoch_end', X=X, y=y)
         return self
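For orientation, the train_step/validation_step pair renamed above follows the standard PyTorch pattern. Here is a self-contained sketch with generic, assumed objects (module, criterion, and optimizer are placeholders, not skorch attributes); note that loss.data[0] in the diff is the pre-0.4 PyTorch spelling of today's loss.item().

import torch
from torch import nn

module = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

def train_step(Xi, yi):
    module.train()                # training mode (dropout etc. active)
    optimizer.zero_grad()         # clear gradients from the last batch
    y_pred = module(Xi)           # cf. self.infer(Xi)
    loss = criterion(y_pred, yi)  # cf. self.get_loss(y_pred, yi, X=Xi)
    loss.backward()               # backpropagate
    optimizer.step()              # update parameters
    return loss

def validation_step(Xi, yi):
    module.eval()                 # evaluation mode, no parameter update
    y_pred = module(Xi)
    return criterion(y_pred, yi)

Xi, yi = torch.randn(4, 3), torch.randn(4, 1)
print(train_step(Xi, yi).item(), validation_step(Xi, yi).item())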
@@ -586,9 +586,9 @@ def forward(self, X, training=False):
         dataset = self.dataset(X, use_cuda=self.use_cuda)
         iterator = self.get_iterator(dataset, train=training)
         y_infer = []
-        for xi, _ in iterator:
+        for Xi, _ in iterator:
             y_infer.append(
-                self.evaluation_step(xi, training=training))
+                self.evaluation_step(Xi, training=training))
         return torch.cat(y_infer, dim=0)

     def infer(self, x):
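Finally, forward above is the batched-inference counterpart: it runs evaluation_step on each batch Xi and concatenates the outputs. A minimal sketch of the same pattern, with an assumed module and a plain list standing in for the iterator:

import torch
from torch import nn

module = nn.Linear(3, 2)
batches = [torch.randn(4, 3) for _ in range(3)]  # stand-in for get_iterator

def evaluation_step(Xi, training=False):
    module.train(training)  # defaults to eval mode, as in the diff
    return module(Xi)

y_infer = [evaluation_step(Xi) for Xi in batches]
y_pred = torch.cat(y_infer, dim=0)
print(y_pred.shape)  # torch.Size([12, 2]): all batches stacked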
