diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 35c71bd2cf1..50e9af882bc 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -181,7 +181,7 @@ def __load_folds(self, folds: Optional[int]) -> None: self.root, self.base_folder, self.folds_list_file) with open(path_to_folds, 'r') as f: str_idx = f.read().splitlines()[folds] - list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ') + list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ') self.data = self.data[list_idx, :, :, :] if self.labels is not None: self.labels = self.labels[list_idx]