From 9a1f71d77eb16cc8a416d1432246aada42c70e93 Mon Sep 17 00:00:00 2001
From: Snehil Chatterjee
Date: Tue, 1 Oct 2024 17:07:14 +0530
Subject: [PATCH 1/2] Add dataloader_kwargs support in DataConfig

---
 src/pytorch_tabular/config/config.py      | 9 ++++++++-
 src/pytorch_tabular/tabular_datamodule.py | 3 +++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index b0a4af35..e8fabc06 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -95,7 +95,9 @@ class DataConfig:

         handle_missing_values (bool): Whether to handle missing values in categorical columns as unknown

-
+
+        dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
+        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
     """

     target: Optional[List[str]] = field(
@@ -175,6 +177,11 @@ class DataConfig:
         default=True,
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )
+
+    dataloader_kwargs: Dict[str, Any] = field(
+        default_factory=dict,
+        metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
+    )

     def __post_init__(self):
         assert (
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 50849ae4..96a49c41 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -805,6 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             num_workers=self.config.num_workers,
             sampler=self.train_sampler,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs
         )

     def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -823,6 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             shuffle=False,
             num_workers=self.config.num_workers,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs
         )

     def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -865,6 +867,7 @@ def prepare_inference_dataloader(
             batch_size or self.batch_size,
             shuffle=False,
             num_workers=self.config.num_workers,
+            **self.config.dataloader_kwargs
         )

     def save_dataloader(self, path: Union[str, Path]) -> None:

From a34c6f6f7755bdc14b9593038619c0c1eae996e9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 1 Oct 2024 12:06:27 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_tabular/config/config.py      | 5 +++--
 src/pytorch_tabular/tabular_datamodule.py | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index e8fabc06..2e410170 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -95,9 +95,10 @@ class DataConfig:

         handle_missing_values (bool): Whether to handle missing values in categorical columns as unknown

-
+
         dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader.
         See https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
+
     """

     target: Optional[List[str]] = field(
@@ -177,7 +178,7 @@ class DataConfig:
         default=True,
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )
-
+
     dataloader_kwargs: Dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
     )
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 96a49c41..8cdcdaef 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -805,7 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             num_workers=self.config.num_workers,
             sampler=self.train_sampler,
             pin_memory=self.config.pin_memory,
-            **self.config.dataloader_kwargs
+            **self.config.dataloader_kwargs,
         )

     def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -824,7 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             shuffle=False,
             num_workers=self.config.num_workers,
             pin_memory=self.config.pin_memory,
-            **self.config.dataloader_kwargs
+            **self.config.dataloader_kwargs,
         )

     def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -867,7 +867,7 @@ def prepare_inference_dataloader(
             batch_size or self.batch_size,
             shuffle=False,
             num_workers=self.config.num_workers,
-            **self.config.dataloader_kwargs
+            **self.config.dataloader_kwargs,
         )

     def save_dataloader(self, path: Union[str, Path]) -> None:
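Usage sketch: the snippet below shows how the new option could be exercised, assuming the standard pytorch_tabular DataConfig import path and hypothetical column names (target, feature_1, feature_2, category_1); the extra keys are forwarded verbatim to torch.utils.data.DataLoader by train_dataloader, val_dataloader, and prepare_inference_dataloader.

    from pytorch_tabular.config import DataConfig

    # Hypothetical column names, for illustration only.
    data_config = DataConfig(
        target=["target"],
        continuous_cols=["feature_1", "feature_2"],
        categorical_cols=["category_1"],
        num_workers=4,
        pin_memory=True,
        # Forwarded as **dataloader_kwargs to every DataLoader the datamodule builds.
        dataloader_kwargs={"persistent_workers": True, "prefetch_factor": 4},
    )

Note that keys the datamodule already passes explicitly for a given loader (e.g. batch_size, shuffle, num_workers, pin_memory, sampler) should not be repeated in dataloader_kwargs, since DataLoader would then receive the same argument twice and raise a TypeError.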