diff --git a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb index 9c76a6181a..69947e80b0 100644 --- a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb +++ b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb @@ -418,7 +418,7 @@ " # perform a single optimization step (parameter update)\n", " optimizer.step()\n", " # update running training loss\n", - " train_loss += loss.item()*data.size(0)\n", + " train_loss += loss.item()\n", " \n", " ###################### \n", " # validate the model #\n", @@ -430,12 +430,12 @@ " # calculate the loss\n", " loss = criterion(output, target)\n", " # update running validation loss \n", - " valid_loss += loss.item()*data.size(0)\n", + " valid_loss += loss.item()\n", " \n", " # print training/validation statistics \n", " # calculate average loss over an epoch\n", - " train_loss = train_loss/len(train_loader.dataset)\n", - " valid_loss = valid_loss/len(valid_loader.dataset)\n", + " train_loss = train_loss/len(train_loader)\n", + " valid_loss = valid_loss/len(valid_loader)\n", " \n", " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n", " epoch+1, \n",