From b9eaba9bb2b7eba3493ead0063496faf568fa184 Mon Sep 17 00:00:00 2001
From: sanshi
Date: Tue, 21 Oct 2025 20:56:41 +0800
Subject: [PATCH] Refactor(examples/mnist): Clarify loss reduction and correct
 dataset length

This commit addresses two potential points of confusion and error in the
MNIST example, as detailed in issue #623.

1. Explicit loss reduction:
   train() calls F.nll_loss with the implicit default reduction='mean',
   whereas test() uses reduction='sum'. This change spells out
   reduction='mean' in train() so the asymmetry between the two
   functions is visible at the call site.

2. Correct dataset size with samplers:
   Using len(loader.dataset) to get the number of samples is wrong when
   a Sampler (e.g., SubsetRandomSampler for a validation split) is in
   use: it reports the full dataset size, not the subset size. The
   logic now uses len(loader.sampler) when the loader exposes a sampler
   and falls back to len(loader.dataset) otherwise, so logging and loss
   averaging use the number of samples actually iterated.

Fixes #623
---
 mnist/main.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/mnist/main.py b/mnist/main.py
index dee5a384cb..e7580cd98d 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -35,16 +35,20 @@ def forward(self, x):
 
 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
+    # Get the correct number of samples for logging:
+    # use len(train_loader.sampler) if a sampler is provided (e.g., SubsetRandomSampler),
+    # otherwise use the full dataset length.
+    data_len = len(train_loader.sampler) if train_loader.sampler is not None else len(train_loader.dataset)
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        loss = F.nll_loss(output, target)
+        loss = F.nll_loss(output, target, reduction='mean')  # explicit batch-mean loss
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
+                epoch, batch_idx * len(data), data_len,
                 100. * batch_idx / len(train_loader), loss.item()))
             if args.dry_run:
                 break
@@ -54,6 +58,7 @@ def test(model, device, test_loader):
     model.eval()
     test_loss = 0
     correct = 0
+    data_len = len(test_loader.sampler) if test_loader.sampler is not None else len(test_loader.dataset)
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
@@ -62,11 +67,11 @@ def test(model, device, test_loader):
             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
-    test_loss /= len(test_loader.dataset)
+    test_loss /= data_len  # average loss over the samples actually evaluated
 
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        test_loss, correct, data_len,
+        100. * correct / data_len))
 
 
 def main():
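
Note for reviewers (not part of the patch): a minimal, self-contained sketch of the two
behaviours the commit message describes. The dataset size, the 100-sample validation
split, and the variable names are made up for illustration; only the DataLoader,
SubsetRandomSampler, and F.nll_loss calls are real PyTorch API.

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

    # 1. reduction='mean' vs reduction='sum': the mean is the sum divided by the batch size.
    logits = torch.log_softmax(torch.randn(8, 10), dim=1)
    targets = torch.randint(0, 10, (8,))
    mean_loss = F.nll_loss(logits, targets, reduction='mean')  # what train() logs per batch
    sum_loss = F.nll_loss(logits, targets, reduction='sum')    # what test() accumulates
    assert torch.isclose(sum_loss, mean_loss * 8)

    # 2. len(loader.dataset) vs len(loader.sampler) when a subset sampler is in use.
    dataset = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
    val_loader = DataLoader(dataset, batch_size=64,
                            sampler=SubsetRandomSampler(range(100)))  # hypothetical 100-sample split
    print(len(val_loader.dataset))   # 1000 -- full dataset size, misleading for a subset loader
    print(len(val_loader.sampler))   # 100  -- samples actually drawn per epoch
    data_len = len(val_loader.sampler) if val_loader.sampler is not None else len(val_loader.dataset)
    print(data_len)                  # 100, matching what the patched code logs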