diff --git a/test/data/test_lightning_data_module.py b/test/data/test_lightning_data_module.py
index 2a263003bc89..a86c7e370827 100644
--- a/test/data/test_lightning_data_module.py
+++ b/test/data/test_lightning_data_module.py
@@ -1,8 +1,8 @@
 import sys
 import random
 import shutil
-import pytest
 import os.path as osp
+import pytest
 
 import torch
 import torch.nn.functional as F
@@ -80,5 +80,14 @@ def test_lightning_dataset():
         strategy=pl.plugins.DDPSpawnPlugin(find_unused_parameters=False),
     )
     trainer.fit(model, data_module)
+    assert trainer._data_connector._val_dataloader_source.is_defined()
+    assert trainer._data_connector._test_dataloader_source.is_defined()
+
+    data_module = LightningDataset(train_dataset, batch_size=5, num_workers=2)
+    assert str(data_module) == ('LightningDataset(train_dataset=MUTAG(50), '
+                                'batch_size=5)')
+    trainer.fit(model, data_module)
+    assert not trainer._data_connector._val_dataloader_source.is_defined()
+    assert not trainer._data_connector._test_dataloader_source.is_defined()
 
     shutil.rmtree(root)
diff --git a/torch_geometric/data/lightning_data_module.py b/torch_geometric/data/lightning_data_module.py
index 8f79f819c3f1..d29c6c6a0ce3 100644
--- a/torch_geometric/data/lightning_data_module.py
+++ b/torch_geometric/data/lightning_data_module.py
@@ -33,7 +33,8 @@ class LightningDataset(LightningDataModule):
 
     Args:
         train_dataset: (Dataset) The training dataset.
-        val_dataset: (Dataset) The validation dataset.
+        val_dataset: (Dataset, optional) The validation dataset.
+            (default: :obj:`None`)
         test_dataset: (Dataset, optional) The test dataset.
             (default: :obj:`None`)
         batch_size (int, optional): How many samples per batch to load.
@@ -44,9 +45,15 @@ class LightningDataset(LightningDataModule):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric..loader.DataLoader`.
     """
-    def __init__(self, train_dataset: Dataset, val_dataset: Dataset,
-                 test_dataset: Optional[Dataset] = None, batch_size: int = 1,
-                 num_workers: int = 0, **kwargs):
+    def __init__(
+        self,
+        train_dataset: Dataset,
+        val_dataset: Optional[Dataset] = None,
+        test_dataset: Optional[Dataset] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
+        **kwargs,
+    ):
         super().__init__()
 
         if no_pytorch_lightning:
@@ -80,6 +87,12 @@ def __init__(self, train_dataset: Dataset, val_dataset: Dataset,
         else:
             self.persistent_workers = num_workers > 0
 
+        if self.val_dataset is None:
+            self.val_dataloader = None
+
+        if self.test_dataset is None:
+            self.test_dataloader = None
+
     def setup(self, stage: Optional[str] = None):
         from pytorch_lightning.plugins import DDPSpawnPlugin
         if not isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin):
@@ -110,8 +123,11 @@ def test_dataloader(self) -> DataLoader:
         return self.dataloader('test_dataset', shuffle=False)
 
     def __repr__(self) -> str:
-        return (f'{self.__class__.__name__}('
-                f'train_dataset={self.train_dataset}, '
-                f'val_dataset={self.val_dataset}, '
-                f'test_dataset={self.test_dataset}, '
-                f'batch_size={self.batch_size})')
+        args_repr = [f'train_dataset={self.train_dataset}']
+        if self.val_dataset is not None:
+            args_repr += [f'val_dataset={self.val_dataset}']
+        if self.test_dataset is not None:
+            args_repr += [f'test_dataset={self.test_dataset}']
+        args_repr += [f'batch_size={self.batch_size}']
+        args_repr = ', '.join(args_repr)
+        return f'{self.__class__.__name__}({args_repr})'
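
For context, a minimal usage sketch of the optional-dataset behavior introduced by this patch (not part of the diff itself; it assumes `LightningDataset` is importable from the patched `torch_geometric.data.lightning_data_module` module and reuses a MUTAG/`TUDataset` setup similar to the test above):

```python
# Hypothetical usage sketch, mirroring the test: only a training dataset is
# provided, so __init__ sets `val_dataloader`/`test_dataloader` to None and
# PyTorch Lightning skips validation and testing.
from torch_geometric.data.lightning_data_module import LightningDataset
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')  # path is an assumption
train_dataset = dataset[:50]  # sliced dataset; its repr is `MUTAG(50)`

data_module = LightningDataset(train_dataset, batch_size=5, num_workers=2)
print(data_module)
# LightningDataset(train_dataset=MUTAG(50), batch_size=5)
```

With the new `__repr__`, `val_dataset` and `test_dataset` are omitted from the string when they are `None`, which is the representation asserted in the updated test.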