diff --git a/convolutional-neural-networks/mnist-mlp/mnist_mlp_exercise.ipynb b/convolutional-neural-networks/mnist-mlp/mnist_mlp_exercise.ipynb index 895dfc4f4d..8203501f36 100644 --- a/convolutional-neural-networks/mnist-mlp/mnist_mlp_exercise.ipynb +++ b/convolutional-neural-networks/mnist-mlp/mnist_mlp_exercise.ipynb @@ -309,7 +309,7 @@ " # compare predictions to true label\n", " correct = np.squeeze(pred.eq(target.data.view_as(pred)))\n", " # calculate test accuracy for each object class\n", - " for i in range(batch_size):\n", + " for i in range(len(target)):\n", " label = target.data[i]\n", " class_correct[label] += correct[i].item()\n", " class_total[label] += 1\n", diff --git a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution.ipynb b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution.ipynb index eed4e46e98..5657e2ab2d 100644 --- a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution.ipynb +++ b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution.ipynb @@ -424,7 +424,7 @@ " # compare predictions to true label\n", " correct = np.squeeze(pred.eq(target.data.view_as(pred)))\n", " # calculate test accuracy for each object class\n", - " for i in range(batch_size):\n", + " for i in range(len(target)):\n", " label = target.data[i]\n", " class_correct[label] += correct[i].item()\n", " class_total[label] += 1\n", diff --git a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb index 47def3170a..50b322946c 100755 --- a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb +++ b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb @@ -509,7 +509,7 @@ " # compare predictions to true label\n", " correct = np.squeeze(pred.eq(target.data.view_as(pred)))\n", " # calculate test accuracy for each object class\n", - " for i in range(batch_size):\n", + " for i in range(len(target)):\n", " label = target.data[i]\n", " class_correct[label] += correct[i].item()\n", " class_total[label] += 1\n",