Skip to content

Commit

Permalink
fix the T.neq not giving int8 type issue #29
Browse files Browse the repository at this point in the history
  • Loading branch information
hma02 committed Mar 28, 2017
1 parent 23440c6 commit db33b0f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion lib/layers.py
Expand Up @@ -309,6 +309,6 @@ def errors_top_x(self, y, num_top=5):
# represents a mistake in prediction
y_pred_top_x = T.argsort(self.p_y_given_x, axis=1)[:, -num_top:]
y_top_x = y.reshape((y.shape[0], 1)).repeat(num_top, axis=1)
return T.mean(T.min(T.neq(y_pred_top_x, y_top_x), axis=1))
return T.mean(T.min(T.neq(y_pred_top_x, y_top_x).astype('int8'), axis=1))
else:
raise NotImplementedError()

0 comments on commit db33b0f

Please sign in to comment.