From ea76568f74a67ae996e417c9a78c6e897d867aa1 Mon Sep 17 00:00:00 2001 From: Aroj Hada Date: Sun, 16 Feb 2025 15:36:37 +0100 Subject: [PATCH 1/2] Update python_utils.py for torch.load --- src/pytorch_tabular/utils/python_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index d94ea714..57176fdc 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -74,7 +74,7 @@ def pl_load( """ if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar - return torch.load(path_or_url, map_location=map_location) + return torch.load(path_or_url, map_location=map_location, weights_only=False) if str(path_or_url).startswith("http"): return torch.hub.load_state_dict_from_url( str(path_or_url), @@ -82,7 +82,7 @@ def pl_load( ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: - return torch.load(f, map_location=map_location) + return torch.load(f, map_location=map_location, weights_only=False) def check_numpy(x): From e59e21e8da0eed34a9f6ef3d001cd79d1dac03c4 Mon Sep 17 00:00:00 2001 From: Aroj Hada Date: Sun, 16 Feb 2025 15:37:09 +0100 Subject: [PATCH 2/2] Update tabular_datamodule.py for torch.load --- src/pytorch_tabular/tabular_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 3d09bb2e..7fd68dbf 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -758,7 +758,7 @@ def _load_dataset_from_cache(self, tag: str = "train"): ) elif self.cache_mode is self.CACHE_MODES.DISK: try: - dataset = torch.load(self.cache_dir / f"{tag}_dataset") + dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) except FileNotFoundError: raise FileNotFoundError( f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"