diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 1248ab4f17b..7c206f55edb 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -249,13 +249,13 @@ class EMNIST(MNIST): md5 = "58c8d27c78d21e728a6bc7b3cc06412e" splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') # Merged Classes assumes Same structure for both uppercase and lowercase version - _merged_classes = set(['C', 'I', 'J', 'K', 'L', 'M', 'O', 'P', 'S', 'U', 'V', 'W', 'X', 'Y', 'Z']) - _all_classes = set(list(string.digits + string.ascii_letters)) + _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'} + _all_classes = set(string.digits + string.ascii_letters) classes_split_dict = { - 'byclass': list(_all_classes), + 'byclass': sorted(list(_all_classes)), 'bymerge': sorted(list(_all_classes - _merged_classes)), 'balanced': sorted(list(_all_classes - _merged_classes)), - 'letters': list(string.ascii_lowercase), + 'letters': ['N/A'] + list(string.ascii_lowercase), 'digits': list(string.digits), 'mnist': list(string.digits), }