List node features can not be sliced with subgraph #6162

Open
euerka opened this issue Dec 7, 2022 — with Slack · 13 comments

euerka commented Dec 7, 2022

How can I pass subgraph `Data` objects as a node feature of the main graph `Data`?
`test_data = Data(edge_index=edge_index, pos=torch.tensor(np.array(pos)), graph=subgraphList)`

`subgraphList` is as follows:

```
[Data(x=[1, 3], edge_index=[2, 0], pos=[1, 2]),
 ...
 Data(x=[10, 3], edge_index=[2, 18])]
```
The issue appears when I try to mask some nodes:
`mask_test_data=test_data.subgraph(mask)`
`mask_test_data.graph` is the same as `test_data.graph`, while `mask_test_data.pos` is reduced in size compared to `test_data.pos`.

After testing, I found that the issue is caused by `graph` being a Python list. Is there another type of container in which I can store the subgraph `Data` objects, so that it is automatically aligned by `test_data.subgraph(mask)`?

[Slack Message](https://torchgeometricco.slack.com/archives/C01DN0B3B1N/p1670346485240969?thread_ts=1670346485.240969&cid=C01DN0B3B1N)
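
Until list attributes are handled by `subgraph`, one possible workaround is to filter the list by hand with the same mask. The sketch below uses plain Python lists and strings as stand-ins so that it runs without PyG; `filter_list_by_mask` is a hypothetical helper name, not a PyG function.

```python
def filter_list_by_mask(values, mask):
    """Return the items of `values` whose corresponding `mask` entry is truthy."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    return [value for value, keep in zip(values, mask) if keep]


# Stand-ins for the per-node subgraph Data objects.
subgraph_list = ["g0", "g1", "g2", "g3"]
mask = [True, False, True, False]

kept = filter_list_by_mask(subgraph_list, mask)
print(kept)  # ['g0', 'g2']
```

With a real `Data` object this would presumably look something like `mask_test_data.graph = filter_list_by_mask(test_data.graph, mask.tolist())` after calling `test_data.subgraph(mask)`, assuming `mask` is a boolean tensor with one entry per node.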
Member

rusty1s commented Dec 7, 2022

Thanks for the issue.

Contributor

cemunds commented Feb 1, 2023

Hi, I think I have a similar problem. Each node in my graph has a whole point cloud associated with it. So, for example, one graph would look like this:

Data(edge_index=[2, 2234], y=[1561], pos=[1561, 3], point_cloud=[1561],  x=[1561, 51])

Where point_cloud is itself a list of Data objects representing the point cloud for each node. Obviously, this gets too big to load into my GPU memory for bigger graphs, so I am testing some sampling strategies. When experimenting with RandomNodeLoader for example, in my training loop, I get a batch like this:

Data(edge_index=[2, 120], y=[65], pos=[65, 3], point_cloud=[1561], x=[65, 51], batch=[65], ptr=[8])

The point_cloud attribute still contains all point clouds in my graph instead of containing only the point clouds for the sampled subgraph. I looked into the RandomNodeLoader (which uses the subgraph function of the Data object) and noticed that only Tensor objects are considered when overwriting the Data object's fields according to the sampled subset. So, since the point_cloud field in my Data object is a vanilla Python list of Data objects, it is not processed.

It would be very helpful if vanilla lists were also considered, or if there were a special data structure for fields in a Data object that are themselves lists of Data objects.

Member

rusty1s commented Feb 2, 2023

I think this is a good suggestion - we currently don't do this. Let me see what I can do.

Contributor

cemunds commented Feb 3, 2023

Thanks for the quick response. As this feature is also somewhat crucial to my own research, I would be glad to help with implementing this once it's decided whether you want to go for a custom data structure, a vanilla list, or something else.

Member

rusty1s commented Feb 4, 2023

I think all we need to do is filter out the elements of a list whenever we encounter a list with length being equal to the number of nodes, e.g., here.

Contributor

cemunds commented Feb 6, 2023

I experimented a bit more with my dataset and I'm preparing a pull request for this. However, I think that to fully support list fields and get the behavior that I would expect, I would also need to make adjustments to torch_geometric.data.collate and consequently to torch_geometric.data.separate.

Member

rusty1s commented Feb 7, 2023

Do you think so? I think they should be already working as expected (as they don't do any filtering).

Contributor

cemunds commented Feb 7, 2023

Adjusting only the filtering works if the dataset consists of just one graph, e.g.:

Data(edge_index=[2, 2234], y=[1561], pos=[1561, 3], point_cloud=[1561],  x=[1561, 51])

If I adjust the filtering in Data.subgraph and pass this graph to a RandomNodeLoader for example, I would get the expected result:

Data(edge_index=[2, 56], y=[63], pos=[63, 3], point_cloud=[63], x=[63, 51], batch=[63], ptr=[3])

However, the dataset that I am working with consists of multiple such graphs. I am using collate at the end of my dataset's process function before saving the data to disk. Without any adjustments to collate, I get the following:

dataset
> MyDataset(15)
dataset[0]
> Data(edge_index=[2, 2234], y=[1561], pos=[1561, 3], point_cloud=[1561], x=[1561, 51])
dataset.data
> Data(edge_index=[2, 105427], y=[70641], pos=[70641, 3], point_cloud=[15])

The point_cloud field gets collated into a list of lists of Data objects. Even with adjustments to Data.subgraph, when passing this Data object to RandomNodeLoader, the point_cloud field would not get filtered correctly, because the length does not correspond to the number of nodes.
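
The difference between the two collation behaviors can be illustrated without PyG. In this sketch, `collate_nested` mimics what is described above (one inner list per graph), while `collate_extended` concatenates the per-node lists and records slice boundaries so that each graph can be recovered. The function names and the `slices` layout are assumptions for illustration, not PyG's actual `collate`/`separate` code.

```python
from itertools import accumulate


def collate_nested(per_graph_lists):
    """One entry per graph: the length equals the number of graphs."""
    return [list(items) for items in per_graph_lists]


def collate_extended(per_graph_lists):
    """One entry per node across all graphs, plus slice boundaries."""
    flat = [item for items in per_graph_lists for item in items]
    # slices[i]:slices[i + 1] is the range belonging to graph i.
    slices = [0, *accumulate(len(items) for items in per_graph_lists)]
    return flat, slices


def separate(flat, slices, graph_idx):
    """Recover the per-node list of a single graph."""
    return flat[slices[graph_idx]:slices[graph_idx + 1]]


graphs = [["a0", "a1", "a2"], ["b0", "b1"]]  # two graphs with 3 and 2 nodes

nested = collate_nested(graphs)
flat, slices = collate_extended(graphs)

print(len(nested))                # 2 (number of graphs)
print(len(flat))                  # 5 (total number of nodes)
print(separate(flat, slices, 1))  # ['b0', 'b1']
```

With the extended layout, the list length matches the total node count again, so a length-based filter in `Data.subgraph` would have something to work with.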

I know that passing this big Data object with multiple independent graphs into a RandomNodeLoader might not make much sense from a training perspective in the first place. I'm still experimenting with how to best split up my graphs in order to train on them. I'm just using RandomNodeLoader to illustrate the point here, as it uses Data.subgraph internally.

Contributor

cemunds commented Feb 7, 2023

I'm testing ClusterData and ClusterLoader right now to break down my graphs into mini-batches. There are some places where the filtering would need to be adjusted, e.g. here and here. However, I feel like adding something along the lines of

if isinstance(value, list):
    out[key] = [value[i] for i in node_idx]

in all of these places does not seem very elegant. I suppose torch_geometric.loader.utils.index_select could be extended for the case of list fields and then used in those places.

However, ClusterData selects items by range instead of index in its __getitem__ method here. So that would require something like

if isinstance(item, list) and len(item) == N:
    data[key] = item[start:start+length]

Essentially, we would have to add if clauses in a bunch of places to treat lists separately from tensor fields, or abstract this away somehow. What are your thoughts on this?
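
One way to abstract the two cases away might be a single helper that accepts either explicit node indices or a contiguous range. This is only a sketch on plain Python lists; `select_list_items` is a hypothetical name and does not exist in PyG.

```python
def select_list_items(values, selector):
    """Select items from a per-node list.

    `selector` is either a slice (a contiguous range, as in the ClusterData
    case) or an iterable of integer node indices (as in the subgraph case).
    """
    if isinstance(selector, slice):
        return values[selector]
    return [values[i] for i in selector]


point_clouds = ["pc0", "pc1", "pc2", "pc3", "pc4"]

by_index = select_list_items(point_clouds, [4, 0, 2])
start, length = 1, 3
by_range = select_list_items(point_clouds, slice(start, start + length))

print(by_index)  # ['pc4', 'pc0', 'pc2']
print(by_range)  # ['pc1', 'pc2', 'pc3']
```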

Member

rusty1s commented Feb 8, 2023

Yes, this is tricky. For collate, I assume you are right - in this case it seems more natural to also extend the list instead of merging it into a new one.

To support this more elegantly, I am afraid we need to add some kind of "index_select" interface that also supports lists. Extending index_select sounds like a good option, but it currently only gets called on Tensor values, and those calls are scattered across a lot of places in PyG (see the is_node_attr references).

Contributor

cemunds commented Feb 9, 2023

Yes, I also think some kind of interface that supports lists would be a better option. I made some quick and dirty adjustments in a few places in my own copy of PyG in order to run my experiments and it's working at the moment, but I only adjusted the places that I currently need, such as the mentioned Data.subgraph, collate, separate, ClusterData and ClusterLoader, and it doesn't look very nice without a proper abstraction. What would you propose as the course of action to implement this feature? I'd love to help with this, but I'm not sure if I'm familiar enough with the inner workings of PyG yet to implement this interface.

Member

rusty1s commented Feb 10, 2023

Personally, I would prefer to just start with filtering and modify the places accordingly, i.e., `subgraph`, `ClusterData`, etc. go through a unified interface to perform filtering of node and edge attributes, e.g., in `utils/`:

def filter_node_attributes(data, node_index_or_mask):
    ...

which is then called in each of these functions. WDYT?
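
A rough sketch of what such a function might do, using a plain dict in place of a `Data` object so that it runs without PyG. Only the function name and the `node_index_or_mask` parameter come from the comment above; the body, the extra `num_nodes` argument and the length-based check are assumptions.

```python
def filter_node_attributes(attrs, node_index_or_mask, num_nodes):
    """Filter every per-node list in `attrs`; pass other values through."""
    selector = list(node_index_or_mask)
    # A boolean mask is converted to the indices of its True entries.
    if selector and all(isinstance(s, bool) for s in selector):
        if len(selector) != num_nodes:
            raise ValueError("mask must have one entry per node")
        selector = [i for i, keep in enumerate(selector) if keep]

    out = {}
    for key, value in attrs.items():
        # Heuristic (assumption): a list with one item per node is treated
        # as a node attribute; everything else is copied unchanged.
        if isinstance(value, list) and len(value) == num_nodes:
            out[key] = [value[i] for i in selector]
        else:
            out[key] = value
    return out


attrs = {
    "point_cloud": ["pc0", "pc1", "pc2", "pc3"],  # one entry per node
    "name": "example-graph",                      # graph-level attribute
}

by_mask = filter_node_attributes(attrs, [True, False, False, True], num_nodes=4)
by_index = filter_node_attributes(attrs, [2, 1], num_nodes=4)

print(by_mask)   # {'point_cloud': ['pc0', 'pc3'], 'name': 'example-graph'}
print(by_index)  # {'point_cloud': ['pc2', 'pc1'], 'name': 'example-graph'}
```

Note that the length heuristic would misclassify a graph-level list that happens to contain exactly `num_nodes` items, so a real implementation may need a more explicit way to mark node attributes.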

Contributor

cemunds commented Feb 15, 2023

Sounds good to me. I'll see what I can do and prepare a pull request.

rusty1s added a commit that referenced this issue Feb 28, 2023
This adds a unified interface to perform filtering of node attributes in
order to support attributes of type `list` inside `Data` objects. For
now, this is only used in `Data.subgraph`, `ClusterData` and
`ClusterLoader`, but should probably be extended in the future, as
discussed in issue #6162.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>