diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index e165b957c67..1612e03063c 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -81,7 +81,7 @@ def __init__(self, db_path, classes='train', classes = [classes] else: classes = [c + '_' + classes for c in categories] - if type(classes) == list: + elif type(classes) == list: for c in classes: c_short = c.split('_') c_short.pop(len(c_short) - 1)