Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the interplay between TUDataset and pre_transform that modify node features #4669

Merged
merged 3 commits into from May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed the interplay between `TUDataset` and a `pre_transform` that modifies node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669))
- Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pytorch_geometric/pull/4664), [#4667](https://github.com/pyg-team/pytorch_geometric/pull/4667))
- Refactored reading molecular positions from the SDF file for QM9 datasets ([#4654](https://github.com/pyg-team/pytorch_geometric/pull/4654))
- Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648))
Expand Down
38 changes: 16 additions & 22 deletions torch_geometric/datasets/tu_dataset.py
Expand Up @@ -121,7 +121,16 @@ def __init__(self, root: str, name: str,
self.name = name
self.cleaned = cleaned
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])

out = torch.load(self.processed_paths[0])
# The processed file must be a 3-tuple ``(data, slices, sizes)``. Either
# condition on its own marks a stale or foreign cache, so they are joined
# with ``or``: with ``and``, a 2-tuple ``(data, slices)`` is still a tuple,
# skips this check, and fails on the unpacking below with a bare
# ``ValueError`` instead of the actionable message. ``or`` also
# short-circuits, so ``len`` is never called on a non-tuple object.
if not isinstance(out, tuple) or len(out) != 3:
    raise RuntimeError(
        "The 'data' object was created by an older version of PyG. "
        "If this error occurred while loading an already existing "
        "dataset, remove the 'processed/' directory in the dataset's "
        "root folder and try again.")
self.data, self.slices, self.sizes = out

if self.data.x is not None and not use_node_attr:
num_node_attributes = self.num_node_attributes
self.data.x = self.data.x[:, num_node_attributes:]
Expand All @@ -141,34 +150,19 @@ def processed_dir(self) -> str:

@property
def num_node_labels(self) -> int:
if self.data.x is None:
return 0
for i in range(self.data.x.size(1)):
x = self.data.x[:, i:]
if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
return self.data.x.size(1) - i
return 0
return self.sizes['num_node_labels']

@property
def num_node_attributes(self) -> int:
if self.data.x is None:
return 0
return self.data.x.size(1) - self.num_node_labels
return self.sizes['num_node_attributes']

@property
def num_edge_labels(self) -> int:
if self.data.edge_attr is None:
return 0
for i in range(self.data.edge_attr.size(1)):
if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0):
return self.data.edge_attr.size(1) - i
return 0
return self.sizes['num_edge_labels']

@property
def num_edge_attributes(self) -> int:
if self.data.edge_attr is None:
return 0
return self.data.edge_attr.size(1) - self.num_edge_labels
return self.sizes['num_edge_attributes']

@property
def raw_file_names(self) -> List[str]:
Expand All @@ -189,7 +183,7 @@ def download(self):
os.rename(osp.join(folder, self.name), self.raw_dir)

def process(self):
self.data, self.slices = read_tu_data(self.raw_dir, self.name)
self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)

if self.pre_filter is not None:
data_list = [self.get(idx) for idx in range(len(self))]
Expand All @@ -201,7 +195,7 @@ def process(self):
data_list = [self.pre_transform(data) for data in data_list]
self.data, self.slices = self.collate(data_list)

torch.save((self.data, self.slices), self.processed_paths[0])
torch.save((self.data, self.slices, sizes), self.processed_paths[0])

def __repr__(self) -> str:
return f'{self.name}({len(self)})'
21 changes: 17 additions & 4 deletions torch_geometric/io/tu.py
Expand Up @@ -24,9 +24,11 @@ def read_tu_data(folder, prefix):
edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

node_attributes = node_labels = None
node_attributes = torch.empty((batch.size(0), 0))
if 'node_attributes' in names:
node_attributes = read_file(folder, prefix, 'node_attributes')

node_labels = torch.empty((batch.size(0), 0))
if 'node_labels' in names:
node_labels = read_file(folder, prefix, 'node_labels', torch.long)
if node_labels.dim() == 1:
Expand All @@ -35,11 +37,12 @@ def read_tu_data(folder, prefix):
node_labels = node_labels.unbind(dim=-1)
node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels]
node_labels = torch.cat(node_labels, dim=-1).to(torch.float)
x = cat([node_attributes, node_labels])

edge_attributes, edge_labels = None, None
edge_attributes = torch.empty((edge_index.size(1), 0))
if 'edge_attributes' in names:
edge_attributes = read_file(folder, prefix, 'edge_attributes')

edge_labels = torch.empty((edge_index.size(1), 0))
if 'edge_labels' in names:
edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
if edge_labels.dim() == 1:
Expand All @@ -48,6 +51,8 @@ def read_tu_data(folder, prefix):
edge_labels = edge_labels.unbind(dim=-1)
edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels]
edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float)

x = cat([node_attributes, node_labels])
edge_attr = cat([edge_attributes, edge_labels])

y = None
Expand All @@ -65,7 +70,14 @@ def read_tu_data(folder, prefix):
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
data, slices = split(data, batch)

return data, slices
sizes = {
'num_node_attributes': node_attributes.size(-1),
'num_node_labels': node_labels.size(-1),
'num_edge_attributes': edge_attributes.size(-1),
'num_edge_labels': edge_labels.size(-1),
}

return data, slices, sizes


def read_file(folder, prefix, name, dtype=None):
Expand All @@ -75,6 +87,7 @@ def read_file(folder, prefix, name, dtype=None):

def cat(seq):
    """Concatenate the usable tensors in ``seq`` along the last dimension.

    Entries that are ``None`` or that contain no elements (``numel() == 0``)
    are dropped. Remaining 1-dimensional tensors get a trailing dimension so
    that each one contributes a single column.

    Args:
        seq: Iterable of tensors and/or ``None`` values.

    Returns:
        The concatenated tensor, or ``None`` if nothing usable remains.
    """
    columns = [
        item.unsqueeze(-1) if item.dim() == 1 else item
        for item in seq
        if item is not None and item.numel() > 0
    ]
    return torch.cat(columns, dim=-1) if columns else None

Expand Down