Skip to content

Commit

Permalink
add HeteroJK class (#9380)
Browse files Browse the repository at this point in the history
Hi,

this PR implements #9355, please do let me know if any additional
changes or documentation is needed.
Thank you

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
m-atalla and rusty1s authored Jun 6, 2024
1 parent 476e768 commit bb131e8
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the heterogeneous `HeteroJumpingKnowledge` module ([#9380](https://github.com/pyg-team/pytorch_geometric/pull/9380))
- Started work on GNN+LLM package ([#9350](https://github.com/pyg-team/pytorch_geometric/pull/9350))
- Added support for negative sampling in `LinkLoader` acccording to source and destination node weights ([#9316](https://github.com/pyg-team/pytorch_geometric/pull/9316))
- Added support for `EdgeIndex.unbind` ([#9298](https://github.com/pyg-team/pytorch_geometric/pull/9298))
Expand Down
54 changes: 53 additions & 1 deletion test/nn/models/test_jumping_knowledge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from torch_geometric.nn import JumpingKnowledge
from torch_geometric.nn import HeteroJumpingKnowledge, JumpingKnowledge
from torch_geometric.testing import is_full_test


Expand Down Expand Up @@ -38,3 +38,55 @@ def test_jumping_knowledge():
if is_full_test():
jit = torch.jit.script(model)
assert torch.allclose(jit(xs), out)


def test_hetero_jumping_knowledge():
num_nodes, channels, num_layers = 100, 17, 5

types = ["author", "paper"]
xs_dict = {
key: [torch.randn(num_nodes, channels) for _ in range(num_layers)]
for key in types
}

model = HeteroJumpingKnowledge(types, mode='cat')
model.reset_parameters()
assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=cat)'

out_dict = model(xs_dict)
for out in out_dict.values():
assert out.size() == (num_nodes, channels * num_layers)

if is_full_test():
jit = torch.jit.script(model)
jit_out = jit(xs_dict)
for key in types:
assert torch.allclose(jit_out[key], out_dict[key])

model = HeteroJumpingKnowledge(types, mode='max')
assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=max)'

out_dict = model(xs_dict)
for out in out_dict.values():
assert out.size() == (num_nodes, channels)

if is_full_test():
jit = torch.jit.script(model)
jit_out = jit(xs_dict)
for key in types:
assert torch.allclose(jit_out[key], out_dict[key])

model = HeteroJumpingKnowledge(types, mode='lstm', channels=channels,
num_layers=num_layers)
assert str(model) == (f'HeteroJumpingKnowledge(num_types=2, mode=lstm, '
f'channels={channels}, layers={num_layers})')

out_dict = model(xs_dict)
for out in out_dict.values():
assert out.size() == (num_nodes, channels)

if is_full_test():
jit = torch.jit.script(model)
jit_out = jit(xs_dict)
for key in types:
assert torch.allclose(jit_out[key], out_dict[key])
3 changes: 2 additions & 1 deletion torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .mlp import MLP
from .basic_gnn import GCN, GraphSAGE, GIN, GAT, PNA, EdgeCNN
from .jumping_knowledge import JumpingKnowledge
from .jumping_knowledge import JumpingKnowledge, HeteroJumpingKnowledge
from .meta import MetaLayer
from .node2vec import Node2Vec
from .deep_graph_infomax import DeepGraphInfomax
Expand Down Expand Up @@ -42,6 +42,7 @@
'PNA',
'EdgeCNN',
'JumpingKnowledge',
'HeteroJumpingKnowledge',
'MetaLayer',
'Node2Vec',
'DeepGraphInfomax',
Expand Down
67 changes: 63 additions & 4 deletions torch_geometric/nn/models/jumping_knowledge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -41,8 +41,12 @@ class JumpingKnowledge(torch.nn.Module):
num_layers (int, optional): The number of layers to aggregate. Needs to
be only set for LSTM-style aggregation. (default: :obj:`None`)
"""
def __init__(self, mode: str, channels: Optional[int] = None,
num_layers: Optional[int] = None):
def __init__(
self,
mode: str,
channels: Optional[int] = None,
num_layers: Optional[int] = None,
) -> None:
super().__init__()
self.mode = mode.lower()
assert self.mode in ['cat', 'max', 'lstm']
Expand All @@ -63,7 +67,7 @@ def __init__(self, mode: str, channels: Optional[int] = None,

self.reset_parameters()

def reset_parameters(self):
def reset_parameters(self) -> None:
r"""Resets all learnable parameters of the module."""
if self.lstm is not None:
self.lstm.reset_parameters()
Expand Down Expand Up @@ -94,3 +98,58 @@ def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.mode}, '
f'channels={self.channels}, layers={self.num_layers})')
return f'{self.__class__.__name__}({self.mode})'


class HeteroJumpingKnowledge(torch.nn.Module):
r"""A heterogeneous version of the :class:`JumpingKnowledge` module.
Args:
types (List[str]): The keys of the input dictionary.
mode (str): The aggregation scheme to use
(:obj:`"cat"`, :obj:`"max"` or :obj:`"lstm"`).
channels (int, optional): The number of channels per representation.
Needs to be only set for LSTM-style aggregation.
(default: :obj:`None`)
num_layers (int, optional): The number of layers to aggregate. Needs to
be only set for LSTM-style aggregation. (default: :obj:`None`)
"""
def __init__(
self,
types: List[str],
mode: str,
channels: Optional[int] = None,
num_layers: Optional[int] = None,
) -> None:
super().__init__()

self.mode = mode.lower()

self.jk_dict = torch.nn.ModuleDict({
key:
JumpingKnowledge(mode, channels, num_layers)
for key in types
})

def reset_parameters(self) -> None:
r"""Resets all learnable parameters of the module."""
for jk in self.jk_dict.values():
jk.reset_parameters()

def forward(self, xs_dict: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
r"""Forward pass.
Args:
xs_dict (Dict[str, List[torch.Tensor]]): A dictionary holding a
list of layer-wise representation for each type.
"""
return {key: jk(xs_dict[key]) for key, jk in self.jk_dict.items()}

def __repr__(self):
if self.mode == 'lstm':
jk = next(iter(self.jk_dict.values()))
return (f'{self.__class__.__name__}('
f'num_types={len(self.jk_dict)}, '
f'mode={self.mode}, channels={jk.channels}, '
f'layers={jk.num_layers})')
return (f'{self.__class__.__name__}(num_types={len(self.jk_dict)}, '
f'mode={self.mode})')

0 comments on commit bb131e8

Please sign in to comment.