diff --git a/othello/pytorch/OthelloNNet.py b/othello/pytorch/OthelloNNet.py index b6c1a11a1..7429b6edc 100644 --- a/othello/pytorch/OthelloNNet.py +++ b/othello/pytorch/OthelloNNet.py @@ -50,7 +50,7 @@ def forward(self, s): s = F.dropout(F.relu(self.fc_bn1(self.fc1(s))), p=self.args.dropout, training=self.training) # batch_size x 1024 s = F.dropout(F.relu(self.fc_bn2(self.fc2(s))), p=self.args.dropout, training=self.training) # batch_size x 512 - pi = self.fc3(s) # batch_size x 512 - v = self.fc4(s) # batch_size x 512 + pi = self.fc3(s) # batch_size x action_size + v = self.fc4(s) # batch_size x 1 - return F.log_softmax(pi), F.tanh(v) + return F.log_softmax(pi, dim=1), F.tanh(v)