Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 12, 2024
1 parent 1b195a0 commit 74ed43b
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 60 deletions.
14 changes: 14 additions & 0 deletions test/nn/conv/test_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,17 @@ def test_graph_conv():
assert torch.allclose(jit((x1, None), adj3.t()), out2, atol=1e-6)
assert torch.allclose(jit((x1, x2), adj4.t()), out3, atol=1e-6)
assert torch.allclose(jit((x1, None), adj4.t()), out4, atol=1e-6)


class EdgeGraphConv(GraphConv):
def message(self, x_j, edge_weight):
return edge_weight.view(-1, 1) * x_j


def test_inheritance():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_weight = torch.rand(4)

conv = EdgeGraphConv(8, 16)
assert conv(x, edge_index, edge_weight).size() == (4, 16)
21 changes: 20 additions & 1 deletion test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class MyConvWithSelfLoops(MessagePassing):
def __init__(self, aggr: str = 'add'):
super().__init__(aggr=aggr)

def forward(self, x: Tensor, edge_index: torch.Tensor) -> Tensor:
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
edge_index, _ = add_self_loops(edge_index)

# propagate_type: (x: Tensor)
Expand Down Expand Up @@ -144,6 +144,25 @@ def test_my_conv_out_of_bounds():
conv(x, edge_index, value)


class MyCommentedConv(MessagePassing):
r"""This layer calls `self.propagate()` internally."""
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
# `self.propagate()` is used here to propagate messages.
return self.propagate(edge_index, x=x)


def test_my_commented_conv():
# Check that `self.propagate` occurences in comments are correctly ignored.
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = MyCommentedConv()
conv(x, edge_index)

jit = torch.jit.script(conv)
jit(x, edge_index)


def test_my_conv_jit():
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
Expand Down
5 changes: 1 addition & 4 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import (
Any,
Expand All @@ -27,7 +26,7 @@
MISSING = '???'


class Dataset(torch.utils.data.Dataset, ABC):
class Dataset(torch.utils.data.Dataset):
r"""Dataset base class for creating graph datasets.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
create_dataset.html>`__ for the accompanying tutorial.
Expand Down Expand Up @@ -79,12 +78,10 @@ def process(self) -> None:
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
raise NotImplementedError

@abstractmethod
def len(self) -> int:
r"""Returns the number of data objects stored in the dataset."""
raise NotImplementedError

@abstractmethod
def get(self, idx: int) -> BaseData:
r"""Gets the data object at index :obj:`idx`."""
raise NotImplementedError
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import os.path as osp
import warnings
from abc import ABC
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -30,7 +29,7 @@
from torch_geometric.io import fs


class InMemoryDataset(Dataset, ABC):
class InMemoryDataset(Dataset):
r"""Dataset base class for creating graph datasets which easily fit
into CPU memory.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
Expand Down
42 changes: 30 additions & 12 deletions torch_geometric/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union

import torch
from torch import Tensor


Expand Down Expand Up @@ -32,9 +33,27 @@ def __init__(self, cls: Type):
self._signature_dict: Dict[str, Signature] = {}
self._source_dict: Dict[str, str] = {}

def _get_modules(self, cls: Type) -> List[str]:
from torch_geometric.nn import MessagePassing

modules: List[str] = []
for base_cls in cls.__bases__:
if base_cls not in {object, torch.nn.Module, MessagePassing}:
modules.extend(self._get_modules(base_cls))

modules.append(cls.__module__)
return modules

@property
def _modules(self) -> List[str]:
return self._get_modules(self._cls)

@property
def _globals(self) -> Dict[str, Any]:
return sys.modules[self._cls.__module__].__dict__
out: Dict[str, Any] = {}
for module in self._modules:
out.update(sys.modules[module].__dict__)
return out

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self._cls.__name__})'
Expand Down Expand Up @@ -301,17 +320,6 @@ def collect_param_data(

# Inspecting Method Bodies ################################################

@property
def can_read_source(self) -> bool:
r"""Returns :obj:`True` if able to read the source file of the
inspected class.
"""
try:
inspect.getfile(self._cls)
return True
except Exception:
return False

def get_source(self, cls: Optional[Type] = None) -> str:
r"""Returns the source code of :obj:`cls`."""
cls = cls or self._cls
Expand Down Expand Up @@ -388,6 +396,7 @@ def get_params_from_method_call(
# (3) Parse the function call:
for cls in self._cls.__mro__:
source = self.get_source(cls)
source = remove_comments(source)
match = find_parenthesis_content(source, f'self.{func_name}')
if match is not None:
for i, kwarg in enumerate(split(match, sep=',')):
Expand Down Expand Up @@ -515,3 +524,12 @@ def split(content: str, sep: str) -> List[str]:
if start != len(content): # Respect dangling `sep`:
outs.append(content[start:].strip())
return outs


def remove_comments(content: str) -> str:
content = re.sub(r'\s*#.*', '', content)
content = re.sub(re.compile(r'r"""(.*?)"""', re.DOTALL), '', content)
content = re.sub(re.compile(r'"""(.*?)"""', re.DOTALL), '', content)
content = re.sub(re.compile(r"r'''(.*?)'''", re.DOTALL), '', content)
content = re.sub(re.compile(r"'''(.*?)'''", re.DOTALL), '', content)
return content
5 changes: 3 additions & 2 deletions torch_geometric/io/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,10 @@ def cp(
if use_cache and clear_cache and cache_dir is not None:
try:
rm(cache_dir)
except PermissionError: # FIXME
except Exception: # FIXME
# Windows test yield "PermissionError: The process cannot access
# the file because it is being used by another process"
# the file because it is being used by another process".
# Users may also observe "OSError: Directory not empty".
# This is a quick workaround until we figure out the deeper issue.
pass

Expand Down
4 changes: 1 addition & 3 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union

import torch
Expand All @@ -15,7 +14,7 @@
BaseMetric = torch.nn.Module # type: ignore


class LinkPredMetric(BaseMetric, ABC):
class LinkPredMetric(BaseMetric):
r"""An abstract class for computing link prediction retrieval metrics.
Args:
Expand Down Expand Up @@ -117,7 +116,6 @@ def reset(self) -> None:
self.accum.zero_()
self.total.zero_()

@abstractmethod
def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
r"""Compute the specific metric.
To be implemented separately for each metric class.
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/edge_updater.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor

{% for module in modules %}
from {{module}} import *
{%- endfor %}


{% include "collect.jinja" %}
Expand Down
44 changes: 25 additions & 19 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def __init__(
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
# Optimize `propagate()` via `*.jinja` templates:
if not self.propagate.__module__.startswith(jinja_prefix):
if self.inspector.can_read_source:
try:
module = module_from_template(
module_name=f'{jinja_prefix}_propagate',
template_path=osp.join(root_dir, 'propagate.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
module=self.__module__,
module=self.inspector._modules,
collect_name='collect',
signature=self._get_propagate_signature(),
collect_param_dict=self.inspector.get_flat_param_dict(
Expand All @@ -185,34 +185,40 @@ def __init__(
fuse=self.fuse,
)

# Cache to potentially disable later on:
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = module.propagate

self.__class__.propagate = module.propagate
self.__class__.collect = module.collect
else:
except Exception: # pragma: no cover
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = self.__class__.propagate

# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)
and self.inspector.can_read_source):
module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
module=self.__module__,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)
and not self.edge_updater.__module__.startswith(jinja_prefix)):
try:
module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)

self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = module.edge_updater

self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
except Exception: # pragma: no cover
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = (
self.__class__.edge_updater)

# Explainability:
self._explain: Optional[bool] = None
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/propagate.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor

{% for module in modules %}
from {{module}} import *
{%- endfor %}


{% include "collect.jinja" %}
Expand Down
9 changes: 3 additions & 6 deletions torch_geometric/nn/models/rev_gnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -145,7 +144,7 @@ def backward(ctx, *grad_outputs):
return (None, None, None, None) + gradients


class InvertibleModule(torch.nn.Module, ABC):
class InvertibleModule(torch.nn.Module):
r"""An abstract class for implementing invertible modules.
Args:
Expand All @@ -168,13 +167,11 @@ def forward(self, *args):
def inverse(self, *args):
return self._fn_apply(args, self._inverse, self._forward)

@abstractmethod
def _forward(self):
pass
raise NotImplementedError

@abstractmethod
def _inverse(self):
pass
raise NotImplementedError

def _fn_apply(self, args, fn, fn_inverse):
if not self.disable:
Expand Down
18 changes: 9 additions & 9 deletions torch_geometric/template.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import importlib
import os
import os.path as osp
import sys
import tempfile
from typing import Any

from jinja2 import Environment, FileSystemLoader

from torch_geometric import get_home_dir


def module_from_template(
module_name: str,
Expand All @@ -23,13 +21,15 @@ def module_from_template(
template = env.get_template(osp.basename(template_path))
module_repr = template.render(**kwargs)

instance_dir = osp.join(get_home_dir(), tmp_dirname)
os.makedirs(instance_dir, exist_ok=True)
instance_path = osp.join(instance_dir, f'{module_name}.py')
with open(instance_path, 'w') as f:
f.write(module_repr)
with tempfile.NamedTemporaryFile(
mode='w',
prefix=f'{module_name}_',
suffix='.py',
delete=False,
) as tmp:
tmp.write(module_repr)

spec = importlib.util.spec_from_file_location(module_name, instance_path)
spec = importlib.util.spec_from_file_location(module_name, tmp.name)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
Expand Down

0 comments on commit 74ed43b

Please sign in to comment.