From ec87c23e52a27c3ca5a0c381dca55fdce2381dbb Mon Sep 17 00:00:00 2001 From: Abhinav R Nayak <38727070+abhinavnayak11@users.noreply.github.com> Date: Sat, 3 Jul 2021 11:33:05 +0530 Subject: [PATCH] Updated calculation of train_loss and valid_loss Using len(train_loader.dataset) and len(valid_loader.dataset) will result in wrong values, as both return the size of 'train_data' (i.e., 60000 images). This is the reason for such a difference between train_loss and valid_loss. 'train_loader' actually covers 0.8*train_data (i.e., 48000 images) and 'valid_loader' actually covers 0.2*train_data (i.e., 12000 images). Hence, accumulate the average batch loss for each batch and divide by the total number of batches. This works because all the batches in both dataloaders are of size 20 (reason: 48000%20=0 and 12000%20=0). --- .../mnist-mlp/mnist_mlp_solution_with_validation.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb index 9c76a6181a..69947e80b0 100644 --- a/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb +++ b/convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb @@ -418,7 +418,7 @@ " # perform a single optimization step (parameter update)\n", " optimizer.step()\n", " # update running training loss\n", - " train_loss += loss.item()*data.size(0)\n", + " train_loss += loss.item()\n", " \n", " ###################### \n", " # validate the model #\n", @@ -430,12 +430,12 @@ " # calculate the loss\n", " loss = criterion(output, target)\n", " # update running validation loss \n", - " valid_loss += loss.item()*data.size(0)\n", + " valid_loss += loss.item()\n", " \n", " # print training/validation statistics \n", " # calculate average loss over an epoch\n", - " train_loss = train_loss/len(train_loader.dataset)\n", - " valid_loss = 
valid_loss/len(valid_loader.dataset)\n", + " train_loss = train_loss/len(train_loader)\n", + " valid_loss = valid_loss/len(valid_loader)\n", " \n", " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n", " epoch+1, \n",