diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index f3596ed0dc0..e5ddc35221c 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -44,7 +44,7 @@ def __init__( os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform ) os.makedirs(self.root, exist_ok=True) - if not isinstance(target_type, list): + if isinstance(target_type, str): target_type = [target_type] self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]