Skip to content
Permalink
Browse files

Added the ToSLIC transform and a more modular dataloader

  • Loading branch information
rusty1s committed Nov 13, 2019
1 parent 70e575d commit 00e5db93194c5c187ef2792ceeb782d2153251ae
@@ -117,7 +117,7 @@ You can find a complete list of all methods at :class:`torch_geometric.data.Data
Common Benchmark Datasets
-------------------------

PyTorch Geometric contains a large number of common benchmark datasets, *e.g.* all Planetoid datasets (Cora, Citeseer, Pubmed), all graph classification datasets from `http://graphkernels.cs.tu-dortmund.de/ <http://graphkernels.cs.tu-dortmund.de/>`_ and their `cleaned versions <https://github.com/nd7141/graph_datasets>`_, the QM7 and QM9 dataset, and a handful of 3D mesh/point cloud datasets like FAUST, ModelNet10/40 and ShapeNet.

Initializing a dataset is straightforward.
An initialization of a dataset will automatically download its raw files and process them to the previously described ``Data`` format.
@@ -8,6 +8,7 @@
'scipy',
'networkx',
'scikit-learn',
'scikit-image',
'requests',
'plyfile',
'pandas',
@@ -0,0 +1,85 @@
import sys
import random
import os.path as osp
import shutil

import torch
from torchvision.datasets.mnist import MNIST, read_image_file, read_label_file
import torchvision.transforms as T
from torch_geometric.data import download_url, extract_gz, DataLoader
from torch_geometric.data.makedirs import makedirs
from torch_geometric.transforms import ToSLIC

resources = [
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]


def test_to_superpixels():
    """End-to-end check of the `ToSLIC` transform on the MNIST test split,
    both standalone and through the PyG `DataLoader`."""
    # Use a random temp directory so parallel test runs do not collide.
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))

    raw_folder = osp.join(root, 'MNIST', 'raw')
    processed_folder = osp.join(root, 'MNIST', 'processed')

    makedirs(raw_folder)
    makedirs(processed_folder)
    for resource in resources:
        path = download_url(resource, raw_folder)
        # `raw_folder` is already an absolute path rooted at `root`;
        # re-joining it with `root` was redundant (os.path.join discards the
        # first argument when the second one is absolute).
        extract_gz(path, raw_folder)

    test_set = (
        read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')),
        read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')),
    )

    # Save the (smaller) test split under both filenames so `MNIST` can be
    # constructed without downloading the full training set.
    torch.save(test_set, osp.join(processed_folder, 'training.pt'))
    torch.save(test_set, osp.join(processed_folder, 'test.pt'))

    dataset = MNIST(root, download=False)

    # Plain ToSLIC: each image becomes Data(pos=[N, 2], x=[N, 1]).
    dataset.transform = T.Compose([T.ToTensor(), ToSLIC()])

    data, y = dataset[0]
    assert len(data) == 2
    assert data.pos.dim() == 2 and data.pos.size(1) == 2
    assert data.x.dim() == 2 and data.x.size(1) == 1
    assert data.pos.size(0) == data.x.size(0)
    assert y == 7

    # Batched loading adds the `batch` assignment vector.
    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for data, y in loader:
        assert len(data) == 3
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.batch.dim() == 1
        assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
        assert y.tolist() == [7, 2]
        break

    # With `add_seg`/`add_img`, the segmentation map and the input image are
    # attached to the data object as well.
    dataset.transform = T.Compose(
        [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)])

    data, y = dataset[0]
    assert len(data) == 4
    assert data.pos.dim() == 2 and data.pos.size(1) == 2
    assert data.x.dim() == 2 and data.x.size(1) == 1
    assert data.pos.size(0) == data.x.size(0)
    assert data.seg.size() == (1, 28, 28)
    assert data.img.size() == (1, 1, 28, 28)
    # SLIC labels are contiguous, so max label + 1 == number of superpixels.
    assert data.seg.max().item() + 1 == data.x.size(0)
    assert y == 7

    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for data, y in loader:
        assert len(data) == 5
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.batch.dim() == 1
        assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
        assert data.seg.size() == (2, 28, 28)
        assert data.img.size() == (2, 1, 28, 28)
        assert y.tolist() == [7, 2]
        break

    shutil.rmtree(root)
@@ -1,7 +1,8 @@
import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Batch
from torch_geometric.data import Data, Batch
from torch._six import container_abcs, string_classes, int_classes


class DataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` into mini-batches, recursively
    collating tensors, numbers, strings, mappings, and sequences.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (list or tuple, optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`[]`)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[],
                 **kwargs):
        def collate(batch):
            # Recursively collate a list of samples, mirroring the behavior
            # of `default_collate` but batching `Data` objects into a
            # `Batch` (so e.g. (Data, label) tuples are handled correctly).
            elem = batch[0]
            if isinstance(elem, Data):
                return Batch.from_data_list(batch, follow_batch)
            elif isinstance(elem, float):
                return torch.tensor(batch, dtype=torch.float)
            elif isinstance(elem, int_classes):
                return torch.tensor(batch)
            elif isinstance(elem, string_classes):
                return batch
            elif isinstance(elem, container_abcs.Mapping):
                return {key: collate([d[key] for d in batch]) for key in elem}
            elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
                # Named tuples: rebuild the same type field-by-field.
                return type(elem)(*(collate(s) for s in zip(*batch)))
            elif isinstance(elem, container_abcs.Sequence):
                return [collate(s) for s in zip(*batch)]

            raise TypeError('DataLoader found invalid type: {}'.format(
                type(elem)))

        super(DataLoader,
              self).__init__(dataset, batch_size, shuffle,
                             collate_fn=lambda batch: collate(batch), **kwargs)

class DataListLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` into a plain Python list
    (no tensor collation is performed).

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`False`)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        # Identity collate: each batch is simply the list of samples.
        super(DataListLoader,
              self).__init__(dataset, batch_size, shuffle,
                             collate_fn=lambda data_list: data_list, **kwargs)


class DenseDataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` into a dense mini-batch by stacking
    each attribute with :func:`default_collate` (all data objects must hold
    attributes of identical shape).

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`False`)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        def dense_collate(data_list):
            # Stack every attribute of the first sample's key set; keys are
            # assumed to be shared by all samples in the batch.
            batch = Batch()
            for key in data_list[0].keys:
                batch[key] = default_collate([d[key] for d in data_list])
            return batch

        super(DenseDataLoader,
              self).__init__(dataset, batch_size, shuffle,
                             collate_fn=dense_collate, **kwargs)
@@ -32,6 +32,7 @@
from .laplacian_lambda_max import LaplacianLambdaMax
from .generate_mesh_normals import GenerateMeshNormals
from .delaunay import Delaunay
from .to_superpixels import ToSLIC

__all__ = [
'Compose',
@@ -68,4 +69,5 @@
'LaplacianLambdaMax',
'GenerateMeshNormals',
'Delaunay',
'ToSLIC',
]
@@ -0,0 +1,63 @@
import torch
from skimage.segmentation import slic
from torch_scatter import scatter_mean
from torch_geometric.data import Data


class ToSLIC(object):
    r"""Converts an image to a superpixel representation using the
    :meth:`skimage.segmentation.slic` algorithm, resulting in a
    :obj:`torch_geometric.data.Data` object holding the centroids of
    superpixels in :obj:`pos` and their mean color in :obj:`x`.

    This transform can be used with any :obj:`torchvision` dataset.

    Example::

        from torchvision.datasets import MNIST
        from torch_geometric.transforms import ToSLIC

        transform = T.Compose([T.ToTensor(), ToSLIC(n_segments=75)])
        dataset = MNIST('/tmp/MNIST', download=True, transform=transform)

    Args:
        add_seg (bool, optional): If set to :obj:`True`, will add the
            segmentation result to the data object. (default: :obj:`False`)
        add_img (bool, optional): If set to :obj:`True`, will add the input
            image to the data object. (default: :obj:`False`)
        **kwargs (optional): Arguments to adjust the output of the SLIC
            algorithm. See the `SLIC documentation
            <https://scikit-image.org/docs/dev/api/skimage.segmentation.html
            #skimage.segmentation.slic>`_ for an overview.
    """

    def __init__(self, add_seg=False, add_img=False, **kwargs):
        self.add_seg = add_seg
        self.add_img = add_img
        self.kwargs = kwargs

    def __call__(self, img):
        # Move to (H, W, C) layout, as expected by scikit-image.
        img = img.permute(1, 2, 0)
        height, width, channels = img.size()
        num_pixels = height * width

        # Run SLIC on a double-precision copy and bring the per-pixel
        # superpixel labels back as a tensor.
        segmentation = torch.from_numpy(
            slic(img.to(torch.double).numpy(), **self.kwargs))
        labels = segmentation.view(num_pixels)

        # Mean color of each superpixel.
        x = scatter_mean(img.view(num_pixels, channels), labels, dim=0)

        # Per-pixel coordinates, flattened row-major to match the image.
        rows = torch.arange(height, dtype=torch.float)
        cols = torch.arange(width, dtype=torch.float)
        pos_y = rows.view(-1, 1).repeat(1, width).view(num_pixels)
        pos_x = cols.view(1, -1).repeat(height, 1).view(num_pixels)

        # Centroid (mean pixel position) of each superpixel, stored (x, y).
        pos = scatter_mean(torch.stack([pos_x, pos_y], dim=-1), labels, dim=0)

        data = Data(x=x, pos=pos)

        if self.add_seg:
            data.seg = segmentation.view(1, height, width)

        if self.add_img:
            # Restore (C, H, W) layout and prepend a batch dimension.
            data.img = img.permute(2, 0, 1).view(1, channels, height, width)

        return data

0 comments on commit 00e5db9

Please sign in to comment.
You can’t perform that action at this time.