diff --git a/mnist/main.py b/mnist/main.py index 3f71e3cdbe..61a747d9cc 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -65,7 +65,7 @@ def forward(self, x): x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) - return F.log_softmax(x) + return F.log_softmax(x, dim=1) model = Net() if args.cuda: diff --git a/mnist_hogwild/main.py b/mnist_hogwild/main.py index 2ad9217ff8..be80047996 100644 --- a/mnist_hogwild/main.py +++ b/mnist_hogwild/main.py @@ -42,7 +42,7 @@ def forward(self, x): x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) - return F.log_softmax(x) + return F.log_softmax(x, dim=1) if __name__ == '__main__': args = parser.parse_args()