
Support multiple node type sampling in NeighborLoader (V2) #5521

Open

wants to merge 5 commits into master
Conversation

@Padarn (Contributor) commented Sep 24, 2022

This PR adds support for sampling multiple node types in NeighborLoader.

The interface follows what was discussed in the roadmap (#4765):

NeighborLoader(
    input_nodes=dict(
        paper=torch.LongTensor([0, 1, 2]),
        author=torch.LongTensor([0, 1, 2]),
    ),
    ...
)

Internally, it converts this to a list of tuples.

[('paper', 0), ('paper', 1), ...]

This is not very efficient, but the benchmarks in #4765 (comment) showed the overhead to be acceptable.
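For context, a minimal sketch of that flattening step (the names and values are illustrative only, not the PR's actual code):

import torch

# Hypothetical seed nodes per node type, mirroring the interface above.
input_nodes = dict(
    paper=torch.LongTensor([0, 1, 2]),
    author=torch.LongTensor([0, 1, 2]),
)

# Flatten the per-type index tensors into one list of (node_type, index)
# tuples, which is what the underlying torch DataLoader iterates over.
flat = [(node_type, int(i))
        for node_type, tensor in input_nodes.items()
        for i in tensor]
# -> [('paper', 0), ('paper', 1), ('paper', 2), ('author', 0), ('author', 1), ('author', 2)]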

TODO:

  • Add tests

Addresses #4765

cc @mananshah99

@Padarn (Contributor, Author) commented Sep 24, 2022

@mananshah99 I was trying to figure out how to cleanly add this, and the easiest approach seemed to be what I've done here: wrapping the input nodes up into a class.

I know the code is probably not in the right place, and I've not actually used it to support multiple node types yet, but I wanted to get feedback on this approach.

@mananshah99 (Contributor)

Sorry, just saw this in my inbox. Will have a review in shortly.

@Padarn (Contributor, Author) commented Oct 23, 2022

Hey, just a bump @mananshah99 :-)

@rusty1s (Member) commented Oct 24, 2022

I'll try to get this merged today. Sorry for the slowness on our end :)

@Padarn (Contributor, Author) commented Oct 25, 2022

Thanks! But actually it's nowhere near ready for merging - just looking for feedback on the approach before I implement it more thoroughly.

@mananshah99 (Contributor) left a comment

Left a few comments. Overall, I think this design makes sense, summarizing my understanding below (feel free to correct me if I am mistaken):

  • We will be supporting dictionaries of input nodes to a loader
  • Internally, the loader converts this to a SamplingInputNodes representation, which has methods to convert between the dict representation and the flattened list that the PyTorch DataLoader needs in order to define the list of samples to batch (I wonder if we can override that to work on dicts instead of needing to flatten here...)
  • Within collate_fn, we will re-convert back to a dict representation containing potentially multiple node types that we then pass to the sampling implementation (which needs to support this now).
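A rough sketch of the round trip described in the last bullet, using a hypothetical regroup helper in place of from_list (illustrative only, not the actual implementation):

from collections import defaultdict
from typing import Dict, List, Tuple

import torch

def regroup(index: List[Tuple[str, int]]) -> Dict[str, torch.Tensor]:
    # Convert a flat batch of (node_type, index) pairs back into a
    # per-node-type tensor dict that a heterogeneous sampler could consume.
    grouped = defaultdict(list)
    for node_type, i in index:
        grouped[node_type].append(i)
    return {node_type: torch.tensor(idx) for node_type, idx in grouped.items()}

batch = [('paper', 0), ('author', 2), ('paper', 5)]  # a batch drawn from the flat list
print(regroup(batch))  # {'paper': tensor([0, 5]), 'author': tensor([2])}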

Comment on lines +215 to +216
def node_types(self) -> Tuple[Optional[str]]:
    return tuple(self.input_nodes.keys())

nit: return a list instead of a tuple?

Comment on lines +229 to +233
@property
def as_list(self) -> Tuple[Tuple[str, int]]:
    return tuple([(node_type, int(i))
                  for node_type, tensor in self.input_nodes.items()
                  for i in tensor])
I guess we need this (and the below function) so that we can pass the right data to the PyTorch DataLoader constructor; perhaps we can leave a note here explaining that?

        super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)

    def collate_fn(self, index: NodeSamplerInput) -> Any:
        r"""Samples a subgraph from a batch of input nodes."""
        input_data: NodeSamplerInput = self.input_data[index]
        input_nodes: NodeSamplerInput = self.input_nodes[index]
Am I correct in understanding that, in the final version, InputData will be responsible for producing a Dict[str, NodeSamplerInput] (e.g. using from_list), and we will then sample on this dictionary by providing support for this in the sampler implementation?

Comment on lines +210 to +211
@dataclass
class SamplingInputNodes:
A few thoughts:

  • SamplingInputNodes and InputNodes as types have different functionalities, but almost identical names. Thoughts on changing this name (e.g. to something like SamplingInput) to increase the edit distance between them?
  • Would it be better to have input_nodes: InputNodes? I am not sure I understand the need for a separate Dict[Optional[str], Sequence] here. That would also make the relationship between this class and InputNodes more clear.
  • Am I correct in my understanding that we will pass sampling_input_nodes.as_list() to the DataLoader constructor, and once batches are selected by the torch DataLoader, we will convert them from_list back to a dict representation that we then pass to the sampler?
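To make the as_list / from_list discussion concrete, here is one way the wrapper could fit together; the class and method names mirror the diff, but the bodies are only a sketch, not the PR's implementation:

from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch

@dataclass
class SamplingInputNodes:
    # Seed-node indices keyed by node type.
    input_nodes: Dict[str, torch.Tensor]

    @property
    def node_types(self) -> List[str]:
        return list(self.input_nodes.keys())

    def as_list(self) -> List[Tuple[str, int]]:
        # Flattened form handed to the torch DataLoader constructor.
        return [(node_type, int(i))
                for node_type, tensor in self.input_nodes.items()
                for i in tensor]

    @classmethod
    def from_list(cls, pairs: List[Tuple[str, int]]) -> 'SamplingInputNodes':
        # Regroup a sampled batch into the dict representation that is
        # then passed on to the sampler.
        grouped: Dict[str, List[int]] = {}
        for node_type, i in pairs:
            grouped.setdefault(node_type, []).append(i)
        return cls({k: torch.tensor(v) for k, v in grouped.items()})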

@Padarn (Contributor, Author) commented Nov 8, 2022

Your understanding is correct. Thanks for the feedback.

I will finish the implementation taking your comments into account 👍

@denadai2 (Contributor)

> Your understanding is correct. Thanks for the feedback.
>
> I will finish the implementation taking your comments into account 👍

Hi @Padarn, thanks for your contribution! Would you still be interested in finishing it? It would be amazing to have this feature.
