Skip to content

Commit

Permalink
Fixed node mapping in RCDD dataset (#9234)
Browse files Browse the repository at this point in the history
Fix #9212
  • Loading branch information
EdisonLeeeee committed Apr 26, 2024
1 parent 212c4ce commit 2567552
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fix node mapping bug in `RCDD` dataset ([#9234](https://github.com/pyg-team/pytorch_geometric/pull/9234))
- Fixed incorrect treatment of `edge_label` and `edge_label_index` in `ToSparseTensor` transform ([#9199](https://github.com/pyg-team/pytorch_geometric/pull/9199))
- Fixed `EgoData` processing in `SnapDataset` in case filenames are unsorted ([#9195](https://github.com/pyg-team/pytorch_geometric/pull/9195))
- Fixed empty graph and isolated node handling in `to_dgl` ([#9188](https://github.com/pyg-team/pytorch_geometric/pull/9188))
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/datasets/rcdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def process(self) -> None:
mapping = torch.empty(len(node_df), dtype=torch.long)
for node_type in node_df['node_type'].unique():
mask = node_df['node_type'] == node_type
mask = torch.from_numpy(mask.values)
num_nodes = int(mask.sum())
mapping[mask] = torch.arange(num_nodes)
node_id = torch.from_numpy(node_df['node_id'][mask].values)
num_nodes = mask.sum()
mapping[node_id] = torch.arange(num_nodes)
data[node_type].num_nodes = num_nodes
x = np.vstack([
np.asarray(f.split(':'), dtype=np.float32)
for f in node_df['node_feat'][mask.numpy()]
for f in node_df['node_feat'][mask]
])
data[node_type].x = torch.from_numpy(x)

Expand Down

0 comments on commit 2567552

Please sign in to comment.