Skip to content

Commit

Permalink
Make MessagePassing interface thread-safe (#9001)
Browse files Browse the repository at this point in the history
Fixes #8994
  • Loading branch information
rusty1s committed Mar 1, 2024
1 parent 8d625c4 commit 849ca0a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))

Expand Down
2 changes: 1 addition & 1 deletion benchmark/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
```bash
pip install ogb
```
1. Install `autoconf` required for `jemalloc` setup
1. Install `autoconf` required for `jemalloc` setup:
```bash
sudo apt-get install autoconf
```
Expand Down
39 changes: 23 additions & 16 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
fuse=self.fuse,
)

# Cache to potentially disable later on:
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = module.propagate

Expand All @@ -197,22 +196,30 @@ def __init__(

# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)
and self.inspector.can_read_source):
module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)
and not self.edge_updater.__module__.startswith(jinja_prefix)):
if self.inspector.can_read_source:

module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)

self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = module.edge_updater

self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
else:
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = (
self.__class__.edge_updater)

# Explainability:
self._explain: Optional[bool] = None
Expand Down
18 changes: 9 additions & 9 deletions torch_geometric/template.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import importlib
import os
import os.path as osp
import sys
import tempfile
from typing import Any

from jinja2 import Environment, FileSystemLoader

from torch_geometric import get_home_dir


def module_from_template(
module_name: str,
Expand All @@ -23,13 +21,15 @@ def module_from_template(
template = env.get_template(osp.basename(template_path))
module_repr = template.render(**kwargs)

instance_dir = osp.join(get_home_dir(), tmp_dirname)
os.makedirs(instance_dir, exist_ok=True)
instance_path = osp.join(instance_dir, f'{module_name}.py')
with open(instance_path, 'w') as f:
f.write(module_repr)
with tempfile.NamedTemporaryFile(
mode='w',
prefix=f'{module_name}_',
suffix='.py',
delete=False,
) as tmp:
tmp.write(module_repr)

spec = importlib.util.spec_from_file_location(module_name, instance_path)
spec = importlib.util.spec_from_file_location(module_name, tmp.name)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
Expand Down

0 comments on commit 849ca0a

Please sign in to comment.