-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multiple node type sampling in NeighborLoader (V2) #5521
base: master
Are you sure you want to change the base?
Conversation
@mananshah99 I was trying to figure out how to cleanly add this, and the easiest seemed to be to do something like what I've done here: Wrapping the 'input nodes' up into a class. I know the code is probably not in the right place, and I've not actually used it to support multiple nodes yet, but I wanted to get feedback on this approach. |
Sorry, just saw this in my inbox. Will have a review in shortly |
hey just a bump @mananshah99 :-) |
I'll try to get this merged today. Sorry for the slowness on our end :) |
thanks! but actually its no where near ready for merging - just looking for feedback on the approach before I implement it more thoroughly |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few comments. Overall, I think this design makes sense, summarizing my understanding below (feel free to correct me if I am mistaken):
- We will be supporting dictionaries of input nodes to a loader
- Internally, the loader converts this to a
SamplingInputNodes
representation, which has methods to convert between this dict representation and a flattened list that we need for the PyTorch dataloader to define the list of samples to batch properly (I wonder if we can override that to work on dicts instead of needing to flatten here...) - Within
collate_fn
, we will re-convert back to a dict representation containing potentially multiple node types that we then pass to the sampling implementation (which needs to support this now).
def node_types(self) -> Tuple[Optional[str]]: | ||
return tuple(self.input_nodes.keys()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: return a list instead of a tuple?
@property | ||
def as_list(self) -> Tuple[Tuple[str, int]]: | ||
return tuple([(node_type, int(i)) | ||
for node_type, tensor in self.input_nodes.items() | ||
for i in tensor]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we need this (and the below function) so that we can pass the right data to the PyTorch DataLoader constructor; perhaps we can leave a note here explaining that?
super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) | ||
|
||
def collate_fn(self, index: NodeSamplerInput) -> Any: | ||
r"""Samples a subgraph from a batch of input nodes.""" | ||
input_data: NodeSamplerInput = self.input_data[index] | ||
input_nodes: NodeSamplerInput = self.input_nodes[index] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I correct in understanding that, in the final version, InputData
will be responsible for producing a Dict[str, NodeSamplerInput]
(e.g. using from_list
), and we will then sample on this dictionary by providing support for this in the sampler implementation?
@dataclass | ||
class SamplingInputNodes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few thoughts:
SamplingInputNodes
andInputNodes
as types have different functionalities, but almost identical names. Thoughts on changing this name (e.g. to something likeSamplingInput
) to increase the edit distance between them?- Would it be better to have
input_nodes: InputNodes
? I am not sure I understand the need for a separateDict[Optional[str], Sequence
here. That would also make the relationship between this class andInputNodes
more clear. - Am I correct in my understanding that we will pass
sampling_input_nodes.as_list()
to the DataLoader constructor, and once batches are selected by the torch DataLoader, we will convert themfrom_list
back to a dict representation that we then pass to the sampler?
Your understanding is correct. Thanks for the feedback. I will finish the implementation taking your comments into account 👍 |
Hi @Padarn, thanks for your contributioooon!! Would you still be interested on finishing it? It would be amazing to have this feature |
This PR adds functionality to allow for multiple node types to be sampled in
NeighbourLoader
.The interface looks as was discussed in the roadmap (#4765):
Internally, it converts this to a list of tuples.
This is not very efficient, but benchmarks #4765 (comment) showed it to be acceptable.
TODO:
Addresses #4765
cc @mananshah99