diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 521056b8ae2..32252c55309 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -41,9 +41,14 @@ class STL10(CIFAR10): ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] ] + splits = ('train', 'train+unlabeled', 'unlabeled', 'test') def __init__(self, root, split='train', transform=None, target_transform=None, download=False): + if split not in self.splits: + raise ValueError('Split "{}" not found. Valid splits are: {}'.format( + split, ', '.join(self.splits), + )) self.root = os.path.expanduser(root) self.transform = transform self.target_transform = target_transform