Skip to content

Commit

Permalink
fix sanity check in case val_dataset is undefined
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Dec 1, 2021
1 parent d54c3ff commit e8915ad
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
11 changes: 10 additions & 1 deletion test/data/test_lightning_data_module.py
@@ -1,8 +1,8 @@
import sys
import random
import shutil
import pytest
import os.path as osp
import pytest

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -80,5 +80,14 @@ def test_lightning_dataset():
strategy=pl.plugins.DDPSpawnPlugin(find_unused_parameters=False),
)
trainer.fit(model, data_module)
assert trainer._data_connector._val_dataloader_source.is_defined()
assert trainer._data_connector._test_dataloader_source.is_defined()

data_module = LightningDataset(train_dataset, batch_size=5, num_workers=2)
assert str(data_module) == ('LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5)')
trainer.fit(model, data_module)
assert not trainer._data_connector._val_dataloader_source.is_defined()
assert not trainer._data_connector._test_dataloader_source.is_defined()

shutil.rmtree(root)
34 changes: 25 additions & 9 deletions torch_geometric/data/lightning_data_module.py
Expand Up @@ -33,7 +33,8 @@ class LightningDataset(LightningDataModule):
Args:
train_dataset: (Dataset) The training dataset.
val_dataset: (Dataset) The validation dataset.
val_dataset: (Dataset, optional) The validation dataset.
(default: :obj:`None`)
test_dataset: (Dataset, optional) The test dataset.
(default: :obj:`None`)
batch_size (int, optional): How many samples per batch to load.
Expand All @@ -44,9 +45,15 @@ class LightningDataset(LightningDataModule):
**kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.DataLoader`.
"""
def __init__(self, train_dataset: Dataset, val_dataset: Dataset,
test_dataset: Optional[Dataset] = None, batch_size: int = 1,
num_workers: int = 0, **kwargs):
def __init__(
self,
train_dataset: Dataset,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: int = 0,
**kwargs,
):
super().__init__()

if no_pytorch_lightning:
Expand Down Expand Up @@ -80,6 +87,12 @@ def __init__(self, train_dataset: Dataset, val_dataset: Dataset,
else:
self.persistent_workers = num_workers > 0

if self.val_dataset is None:
self.val_dataloader = None

if self.test_dataset is None:
self.test_dataloader = None

def setup(self, stage: Optional[str] = None):
from pytorch_lightning.plugins import DDPSpawnPlugin
if not isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin):
Expand Down Expand Up @@ -110,8 +123,11 @@ def test_dataloader(self) -> DataLoader:
return self.dataloader('test_dataset', shuffle=False)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'train_dataset={self.train_dataset}, '
f'val_dataset={self.val_dataset}, '
f'test_dataset={self.test_dataset}, '
f'batch_size={self.batch_size})')
args_repr = [f'train_dataset={self.train_dataset}']
if self.val_dataset is not None:
args_repr += [f'val_dataset={self.val_dataset}']
if self.test_dataset is not None:
args_repr += [f'test_dataset={self.test_dataset}']
args_repr += [f'batch_size={self.batch_size}']
args_repr = ', '.join(args_repr)
return f'{self.__class__.__name__}({args_repr})'

0 comments on commit e8915ad

Please sign in to comment.