Skip to content

Commit

Permalink
Use last dimension for prediction by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ottonemo committed Sep 22, 2017
1 parent 847809a commit f21d41d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions inferno/net.py
Expand Up @@ -717,7 +717,7 @@ def predict(self, X):
"""
self.module_.train(False)
return self.predict_proba(X).argmax(1)
return self.predict_proba(X).argmax(-1)

def get_loss(self, y_pred, y_true, X=None, train=False):
"""Return the loss for this batch.
Expand Down Expand Up @@ -929,7 +929,7 @@ def get_loss(self, y_pred, y, X=None, train=False):
return self.criterion_(y_pred_log, y)

def predict(self, X):
return self.predict_proba(X).argmax(1)
return self.predict_proba(X).argmax(-1)

def fit(self, X, y, **fit_params):
"""See `NeuralNet.fit`.
Expand Down

0 comments on commit f21d41d

Please sign in to comment.