Skip to content

Commit

Permalink
to_data_list functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jul 4, 2019
1 parent 4e43734 commit 8b7ef55
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 26 deletions.
9 changes: 9 additions & 0 deletions test/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,13 @@ def test_batch():
assert data.batch.tolist() == [0, 0, 0, 1, 1]
assert data.num_graphs == 2

data_list = data.to_data_list()
assert len(data_list) == 2
assert len(data_list[0]) == 2
assert data_list[0].x.tolist() == [1, 2, 3]
assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
assert len(data_list[1]) == 2
assert data_list[1].x.tolist() == [1, 2]
assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]]

torch_geometric.set_debug(True)
49 changes: 39 additions & 10 deletions torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ class Batch(Data):

def __init__(self, batch=None, **kwargs):
    r"""Creates a new batch object.

    Args:
        batch (LongTensor, optional): Vector assigning each node to its
            respective example in the batch. (default: :obj:`None`)
        **kwargs: Graph attributes, forwarded to the :class:`Data`
            constructor.
    """
    super(Batch, self).__init__(**kwargs)
    # Bookkeeping required to undo batching again in `to_data_list`:
    # the Data (sub)class of the inputs and per-attribute slice offsets.
    # `__slices__` stays `None` unless set by `from_data_list`.
    self.__data_class__ = Data
    self.__slices__ = None
    self.batch = batch

@staticmethod
def from_data_list(data_list, follow_batch=[]):
Expand All @@ -28,26 +31,26 @@ def from_data_list(data_list, follow_batch=[]):
assert 'batch' not in keys

batch = Batch()
batch.__data_class__ = data_list[0].__class__
batch.__slices__ = {key: [0] for key in keys}

for key in keys:
batch[key] = []

for key in follow_batch:
batch['{}_batch'.format(key)] = []

cumsum = {}
cumsum = {key: 0 for key in keys}
batch.batch = []
for i, data in enumerate(data_list):
for key in data.keys:
item = data[key] + cumsum.get(key, 0)
if key in cumsum:
cumsum[key] += data.__inc__(key, item)
else:
cumsum[key] = data.__inc__(key, item)
item = data[key] + cumsum[key]
size = data[key].size(data.__cat_dim__(key, data[key]))
batch.__slices__[key].append(size + batch.__slices__[key][-1])
cumsum[key] += data.__inc__(key, item)
batch[key].append(item)

for key in follow_batch:
size = data[key].size(data.__cat_dim__(key, data[key]))
item = torch.full((size, ), i, dtype=torch.long)
batch['{}_batch'.format(key)].append(item)

Expand All @@ -62,12 +65,12 @@ def from_data_list(data_list, follow_batch=[]):
for key in batch.keys:
item = batch[key][0]
if torch.is_tensor(item):
batch[key] = torch.cat(
batch[key], dim=data_list[0].__cat_dim__(key, item))
batch[key] = torch.cat(batch[key],
dim=data_list[0].__cat_dim__(key, item))
elif isinstance(item, int) or isinstance(item, float):
batch[key] = torch.tensor(batch[key])
else:
raise ValueError('Unsupported attribute type.')
raise ValueError('Unsupported attribute type')

# Copy custom data functions to batch (does not work yet):
# if data_list.__class__ != Data:
Expand All @@ -82,6 +85,32 @@ def from_data_list(data_list, follow_batch=[]):

return batch.contiguous()

def to_data_list(self):
    r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
    from the batch object.
    The batch object must have been created via :meth:`from_data_list` in
    order to be able to reconstruct the initial objects.

    :rtype: list
    """

    # `__slices__` is only populated by `from_data_list`; it records, per
    # attribute, the cumulative slice boundaries of each graph inside the
    # concatenated batch tensors.  Without it the batch cannot be split.
    if self.__slices__ is None:
        raise RuntimeError(
            ('Cannot reconstruct data list from batch because the batch '
             'object was not created using Batch.from_data_list()'))

    # Drop the `batch` vector itself and any `{key}_batch` assignment
    # vectors (added via `follow_batch`) -- they are batch bookkeeping,
    # not graph attributes.
    keys = [key for key in self.keys if key[-5:] != 'batch']
    # Running offsets that `from_data_list` added to each attribute
    # (e.g. node-index shifts in `edge_index`); undone incrementally below.
    cumsum = {key: 0 for key in keys}
    data_list = []
    # Number of graphs == number of slice boundaries minus one.
    for i in range(len(self.__slices__[keys[0]]) - 1):
        # Instantiate the same Data (sub)class the batch was built from.
        data = self.__data_class__()
        for key in keys:
            # Cut graph i's slice out of the concatenated tensor, along
            # the attribute's concatenation dimension.
            data[key] = self[key].narrow(
                data.__cat_dim__(key, self[key]), self.__slices__[key][i],
                self.__slices__[key][i + 1] - self.__slices__[key][i])
            # Undo the batching offset, then advance it for the next graph.
            data[key] = data[key] - cumsum[key]
            cumsum[key] += data.__inc__(key, data[key])
        data_list.append(data)

    return data_list

@property
def num_graphs(self):
"""Returns the number of graphs in the batch."""
Expand Down
14 changes: 3 additions & 11 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,8 @@ class Data(object):
data.test_mask = torch.tensor([...], dtype=torch.uint8)
"""

def __init__(self,
x=None,
edge_index=None,
edge_attr=None,
y=None,
pos=None,
norm=None,
face=None,
**kwargs):
def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
pos=None, norm=None, face=None, **kwargs):
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
Expand Down Expand Up @@ -107,8 +100,7 @@ def __setitem__(self, key, value):
def keys(self):
    r"""Returns all names of graph attributes."""
    # An attribute counts only if it is set (not ``None``) and is not a
    # private/dunder bookkeeping entry such as ``__num_nodes__``.
    return [
        key for key in self.__dict__
        if self[key] is not None and not key.startswith('__')
        and not key.endswith('__')
    ]

def __len__(self):
Expand Down
7 changes: 2 additions & 5 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def process(self):
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
raise NotImplementedError

def __init__(self,
root,
transform=None,
pre_transform=None,
def __init__(self, root, transform=None, pre_transform=None,
pre_filter=None):
super(InMemoryDataset, self).__init__(root, transform, pre_transform,
pre_filter)
Expand Down Expand Up @@ -136,7 +133,7 @@ def collate(self, data_list):
elif isinstance(item[key], int) or isinstance(item[key], float):
s = slices[key][-1] + 1
else:
raise ValueError('Unsupported attribute type.')
raise ValueError('Unsupported attribute type')
slices[key].append(s)

if hasattr(data_list[0], '__num_nodes__'):
Expand Down

0 comments on commit 8b7ef55

Please sign in to comment.