to_hetero fix for leaf modules (e.g., `torch_geometric.nn.BatchNorm`) (#4027)

* add batch_norm fx fix

* linting
rusty1s committed Feb 8, 2022
1 parent a945389 commit ea135bf
Showing 2 changed files with 19 additions and 7 deletions.
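
For context, a minimal usage sketch of what this fix enables (the metadata and tensor sizes below are illustrative assumptions, not part of the commit): a model whose forward pass runs through a leaf module such as `torch_geometric.nn.BatchNorm` can now be converted with `to_hetero`, because the FX tracer keeps such modules opaque instead of tracing into their internals.

import torch
from torch import Tensor

from torch_geometric.nn import BatchNorm, to_hetero


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.batch_norm = BatchNorm(16)

    def forward(self, x: Tensor) -> Tensor:
        return self.batch_norm(x)


# Illustrative metadata: two node types and two edge types.
metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),
                                  ('paper', 'written_by', 'author')])

model = to_hetero(Net(), metadata)
out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)})
# `out` holds one normalized tensor per node type: (4, 16) and (8, 16).

This mirrors the `Net9` test added below.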
18 changes: 17 additions & 1 deletion test/nn/test_to_hetero_transformer.py
@@ -5,7 +5,7 @@
from torch.nn import Linear, ReLU, Sequential
from torch_sparse import SparseTensor

-from torch_geometric.nn import GINEConv
+from torch_geometric.nn import BatchNorm, GINEConv
from torch_geometric.nn import Linear as LazyLinear
from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero

@@ -113,6 +113,15 @@ def forward(self, x: Tensor) -> Tensor:
        return x


+class Net9(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.batch_norm = BatchNorm(16)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.batch_norm(x)
+
+
def test_to_hetero():
    metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),
                                      ('paper', 'written_by', 'author'),
@@ -195,6 +204,13 @@ def test_to_hetero():
    assert out['paper'].size() == (4, 32)
    assert out['author'].size() == (8, 32)

+    model = Net9()
+    model = to_hetero(model, metadata, debug=False)
+    out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)})
+    assert isinstance(out, dict) and len(out) == 2
+    assert out['paper'].size() == (4, 16)
+    assert out['author'].size() == (8, 16)
+

class GraphConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
8 changes: 2 additions & 6 deletions torch_geometric/nn/fx.py
@@ -5,7 +5,6 @@
from torch.nn import Module, ModuleDict, ModuleList, Sequential

from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.nn.dense import Linear

try:
    from torch.fx import Graph, GraphModule, Node
@@ -236,11 +235,8 @@ def symbolic_trace(
        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
    class Tracer(torch.fx.Tracer):
        def is_leaf_module(self, module: Module, *args, **kwargs) -> bool:
-            # We don't want to trace inside `MessagePassing` and lazy `Linear`
-            # modules, so we mark them as leaf modules.
-            return (isinstance(module, MessagePassing)
-                    or isinstance(module, Linear)
-                    or super().is_leaf_module(module, *args, **kwargs))
+            # TODO We currently only trace top-level modules.
+            return not isinstance(module, torch.nn.Sequential)

    return GraphModule(module, Tracer().trace(module, concrete_args))

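The `fx.py` change hinges on `torch.fx.Tracer.is_leaf_module`: a leaf submodule is recorded as a single opaque `call_module` node, while a non-leaf module is traced into line by line. Under the default tracer only `torch.nn` modules count as leaves, so a PyG module such as `BatchNorm` would previously be traced into. A minimal standalone sketch of that mechanism, outside PyG and using `torch.nn.BatchNorm1d` as an assumed stand-in:

import torch
from torch.fx import GraphModule, Tracer


class KeepTopLevelTracer(Tracer):
    def is_leaf_module(self, module, *args, **kwargs) -> bool:
        # Mirror the behaviour introduced above: every submodule except
        # `Sequential` containers is kept as an opaque leaf.
        return not isinstance(module, torch.nn.Sequential)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain BatchNorm1d as a stand-in for `torch_geometric.nn.BatchNorm`.
        self.norm = torch.nn.BatchNorm1d(16)

    def forward(self, x):
        return self.norm(x)


model = Model()
graph = KeepTopLevelTracer().trace(model)
gm = GraphModule(model, graph)
# The printed graph contains a single `call_module` node for `norm`
# rather than the traced internals of batch normalization.
print(gm.graph)

Keeping whole modules as single graph nodes is what lets `to_hetero` transform them per node type instead of failing on their internals.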
