
data.LightningData module (#3613)
* update

* update

* update

* update doc

* update

* update

* update

* rename

* typo

* fix doc

* update

* update

* update

* typo

* update

* typo

* update

* typo

* update

* typo

* typo
rusty1s committed Dec 3, 2021
1 parent e8915ad commit 8012326
Showing 11 changed files with 773 additions and 339 deletions.
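
For orientation, the two classes introduced here wrap PyG datasets into PyTorch Lightning datamodules. Below is a minimal usage sketch (not part of this diff), assuming pytorch_lightning is installed, a single GPU is available, and LinearGraphModule refers to the model defined in the new test file further down:

import pytorch_lightning as pl

from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset

# Graph-level train/val/test splits wrapped into a single datamodule:
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG').shuffle()
datamodule = LightningDataset(dataset[:50], dataset[50:80], dataset[80:90],
                              batch_size=5, num_workers=3)

model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(model, datamodule)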
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
@@ -50,6 +50,7 @@ jobs:
pip install h5py
pip install numba
pip install tabulate
pip install matplotlib
pip install git+https://github.com/the-virtual-brain/tvb-geodesic.git
pip install pytorch-memlab
93 changes: 0 additions & 93 deletions test/data/test_lightning_data_module.py

This file was deleted.

267 changes: 267 additions & 0 deletions test/data/test_lightning_datamodule.py
@@ -0,0 +1,267 @@
import sys
import math
import random
import shutil
import os.path as osp
import warnings
import pytest

import torch
import torch.nn.functional as F

from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset, Planetoid, DBLP
from torch_geometric.data import LightningDataset, LightningNodeData

try:
    from pytorch_lightning import LightningModule
    no_pytorch_lightning = False
except (ImportError, ModuleNotFoundError):
    LightningModule = torch.nn.Module
    no_pytorch_lightning = True


class LinearGraphModule(LightningModule):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        from torchmetrics import Accuracy

        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, x, batch):
        # Basic test to ensure that the dataset is not replicated:
        self.trainer.datamodule.train_dataset.data.x.add_(1)

        x = self.lin1(x).relu()
        x = global_mean_pool(x, batch)
        x = self.lin2(x)
        return x

    def training_step(self, data, batch_idx):
        y_hat = self(data.x, data.batch)
        loss = F.cross_entropy(y_hat, data.y)
        self.train_acc(y_hat.softmax(dim=-1), data.y)
        self.log('loss', loss, batch_size=data.num_graphs)
        self.log('train_acc', self.train_acc, batch_size=data.num_graphs)
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data.x, data.batch)
        self.val_acc(y_hat.softmax(dim=-1), data.y)
        self.log('val_acc', self.val_acc, batch_size=data.num_graphs)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


@pytest.mark.skipif(no_pytorch_lightning, reason='PL not available')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('strategy', [None, 'ddp_spawn'])
def test_lightning_dataset(strategy):
    import pytorch_lightning as pl

    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = TUDataset(root, name='MUTAG').shuffle()
    train_dataset = dataset[:50]
    val_dataset = dataset[50:80]
    test_dataset = dataset[80:90]
    shutil.rmtree(root)

    gpus = 1 if strategy is None else torch.cuda.device_count()
    if strategy == 'ddp_spawn':
        strategy = pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)

    model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)

    trainer = pl.Trainer(strategy=strategy, gpus=gpus, max_epochs=1,
                         log_every_n_steps=1)
    datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                                  batch_size=5, num_workers=3)
    old_x = train_dataset.data.x.clone()
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'val_dataset=MUTAG(30), '
                               'test_dataset=MUTAG(10), batch_size=5, '
                               'num_workers=3, pin_memory=True, '
                               'persistent_workers=True)')
    trainer.fit(model, datamodule)
    new_x = train_dataset.data.x
    offset = 10 + 6 + 2 * gpus  # `train_steps` + `val_steps` + `sanity`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()

    # Test with `val_dataset=None` and `test_dataset=None`:
    warnings.filterwarnings('ignore', '.*Skipping val loop.*')
    trainer = pl.Trainer(strategy=strategy, gpus=gpus, max_epochs=1,
                         log_every_n_steps=1)
    datamodule = LightningDataset(train_dataset, batch_size=5, num_workers=3)
    assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
                               'batch_size=5, num_workers=3, '
                               'pin_memory=True, persistent_workers=True)')
    trainer.fit(model, datamodule)
    assert not trainer._data_connector._val_dataloader_source.is_defined()
    assert not trainer._data_connector._test_dataloader_source.is_defined()


class LinearNodeModule(LightningModule):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        from torchmetrics import Accuracy

        self.lin = torch.nn.Linear(in_channels, out_channels)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, x):
        # Basic test to ensure that the dataset is not replicated:
        self.trainer.datamodule.data.x.add_(1)

        return self.lin(x)

    def training_step(self, data, batch_idx):
        y_hat = self(data.x)[data.train_mask]
        y = data.y[data.train_mask]
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat.softmax(dim=-1), y)
        self.log('loss', loss, batch_size=y.size(0))
        self.log('train_acc', self.train_acc, batch_size=y.size(0))
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data.x)[data.val_mask]
        y = data.y[data.val_mask]
        self.val_acc(y_hat.softmax(dim=-1), y)
        self.log('val_acc', self.val_acc, batch_size=y.size(0))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


@pytest.mark.skipif(no_pytorch_lightning, reason='PL not available')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('loader', ['full', 'neighbor'])
@pytest.mark.parametrize('strategy', [None, 'ddp_spawn'])
def test_lightning_node_data(strategy, loader):
    import pytorch_lightning as pl

    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, name='Cora')
    data = dataset[0]
    data_repr = ('Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], '
                 'train_mask=[2708], val_mask=[2708], test_mask=[2708])')
    shutil.rmtree(root)

    model = LinearNodeModule(dataset.num_features, dataset.num_classes)

    if strategy is None or loader == 'full':
        gpus = 1
    else:
        gpus = torch.cuda.device_count()

    if strategy == 'ddp_spawn' and loader == 'full':
        data = data.cuda()  # This is necessary to test sharing of data.

    if strategy == 'ddp_spawn':
        strategy = pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)

    batch_size = 1 if loader == 'full' else 32
    num_workers = 0 if loader == 'full' else 3
    kwargs, kwargs_repr = {}, ''
    if loader == 'neighbor':
        kwargs['num_neighbors'] = [5]
        kwargs_repr += 'num_neighbors=[5], '

    trainer = pl.Trainer(strategy=strategy, gpus=gpus, max_epochs=5,
                         log_every_n_steps=1)
    datamodule = LightningNodeData(data, loader=loader, batch_size=batch_size,
                                   num_workers=num_workers, **kwargs)
    old_x = data.x.clone().cpu()
    assert str(datamodule) == (f'LightningNodeData(data={data_repr}, '
                               f'loader={loader}, batch_size={batch_size}, '
                               f'num_workers={num_workers}, {kwargs_repr}'
                               f'pin_memory={loader != "full"}, '
                               f'persistent_workers={loader != "full"})')
    trainer.fit(model, datamodule)
    new_x = data.x.cpu()
    if loader == 'full':
        offset = 5 + 5 + 1  # `train_steps` + `val_steps` + `sanity`
    else:
        offset = 0
        offset += gpus * 2  # `sanity`
        offset += 5 * gpus * math.ceil(140 / (gpus * batch_size))  # `train`
        offset += 5 * gpus * math.ceil(500 / (gpus * batch_size))  # `val`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()


class LinearHeteroNodeModule(LightningModule):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        from torchmetrics import Accuracy

        self.lin = torch.nn.Linear(in_channels, out_channels)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, x):
        # Basic test to ensure that the dataset is not replicated:
        self.trainer.datamodule.data['author'].x.add_(1)

        return self.lin(x)

    def training_step(self, data, batch_idx):
        y_hat = self(data['author'].x)[data['author'].train_mask]
        y = data['author'].y[data['author'].train_mask]
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat.softmax(dim=-1), y)
        self.log('loss', loss, batch_size=y.size(0))
        self.log('train_acc', self.train_acc, batch_size=y.size(0))
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data['author'].x)[data['author'].val_mask]
        y = data['author'].y[data['author'].val_mask]
        self.val_acc(y_hat.softmax(dim=-1), y)
        self.log('val_acc', self.val_acc, batch_size=y.size(0))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


@pytest.mark.skipif(no_pytorch_lightning, reason='PL not available')
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
def test_lightning_hetero_node_data():
    import pytorch_lightning as pl

    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = DBLP(root)
    data = dataset[0]
    shutil.rmtree(root)

    model = LinearHeteroNodeModule(data['author'].num_features,
                                   int(data['author'].y.max()) + 1)

    gpus = torch.cuda.device_count()
    strategy = pl.plugins.DDPSpawnPlugin(find_unused_parameters=False)

    trainer = pl.Trainer(strategy=strategy, gpus=gpus, max_epochs=5,
                         log_every_n_steps=1)
    datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                                   batch_size=32, num_workers=3)
    old_x = data['author'].x.clone()
    trainer.fit(model, datamodule)
    new_x = data['author'].x
    offset = 0
    offset += gpus * 2  # `sanity`
    offset += 5 * gpus * math.ceil(400 / (gpus * 32))  # `train`
    offset += 5 * gpus * math.ceil(400 / (gpus * 32))  # `val`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.
    assert trainer._data_connector._val_dataloader_source.is_defined()
    assert trainer._data_connector._test_dataloader_source.is_defined()
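
A companion sketch (not part of this diff) for node-level training with LightningNodeData and neighbor sampling, again assuming pytorch_lightning is installed, a single GPU is available, and LinearNodeModule is the model defined earlier in this test file:

import pytorch_lightning as pl

from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Mini-batch node-level training via neighbor sampling; split masks such as
# `data.train_mask` are picked up by the datamodule by default:
datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                               batch_size=32, num_workers=3)

model = LinearNodeModule(dataset.num_features, dataset.num_classes)
trainer = pl.Trainer(gpus=1, max_epochs=5)
trainer.fit(model, datamodule)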
3 changes: 2 additions & 1 deletion torch_geometric/data/__init__.py
@@ -4,7 +4,7 @@
from .batch import Batch
from .dataset import Dataset
from .in_memory_dataset import InMemoryDataset
from .lightning_data_module import LightningDataset
from .lightning_datamodule import LightningDataset, LightningNodeData
from .download import download_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz

@@ -16,6 +16,7 @@
    'Dataset',
    'InMemoryDataset',
    'LightningDataset',
    'LightningNodeData',
    'download_url',
    'extract_tar',
    'extract_zip',
