diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 3d09bb2e..7fd68dbf 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -758,7 +758,7 @@ def _load_dataset_from_cache(self, tag: str = "train"): ) elif self.cache_mode is self.CACHE_MODES.DISK: try: - dataset = torch.load(self.cache_dir / f"{tag}_dataset") + dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) except FileNotFoundError: raise FileNotFoundError( f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader" diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index d94ea714..57176fdc 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -74,7 +74,7 @@ def pl_load( """ if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar - return torch.load(path_or_url, map_location=map_location) + return torch.load(path_or_url, map_location=map_location, weights_only=False) if str(path_or_url).startswith("http"): return torch.hub.load_state_dict_from_url( str(path_or_url), @@ -82,7 +82,7 @@ def pl_load( ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: - return torch.load(f, map_location=map_location) + return torch.load(f, map_location=map_location, weights_only=False) def check_numpy(x):