diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 6f7495141a0..10acef91a5c 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -81,9 +81,9 @@ def __getitem__(self, index): def __len__(self): if self.train: - return 60000 + return len(self.train_data) else: - return 10000 + return len(self.test_data) def _check_exists(self): return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \