diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index b0a4af35..2e410170 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -96,6 +96,9 @@ class DataConfig: handle_missing_values (bool): Whether to handle missing values in categorical columns as unknown + dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See + https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader + """ target: Optional[List[str]] = field( @@ -176,6 +179,11 @@ class DataConfig: metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"}, ) + dataloader_kwargs: Dict[str, Any] = field( + default_factory=dict, + metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."}, + ) + def __post_init__(self): assert ( len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0 diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 50849ae4..8cdcdaef 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -805,6 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: num_workers=self.config.num_workers, sampler=self.train_sampler, pin_memory=self.config.pin_memory, + **self.config.dataloader_kwargs, ) def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: @@ -823,6 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: shuffle=False, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, + **self.config.dataloader_kwargs, ) def _prepare_inference_data(self, df: DataFrame) -> DataFrame: @@ -865,6 +867,7 @@ def prepare_inference_dataloader( batch_size or self.batch_size, shuffle=False, num_workers=self.config.num_workers, + **self.config.dataloader_kwargs, ) def save_dataloader(self, path: Union[str, Path]) -> None: