Skip to content
Permalink
Browse files

jit fix

  • Loading branch information
rusty1s committed Dec 3, 2019
1 parent eacb7d3 commit 261fa093028dc4673612931189ab7132f4dbc624
Showing with 11 additions and 5 deletions.
  1. +0 −4 .travis.yml
  2. +11 −1 torch_geometric/transforms/gdc.py
@@ -28,10 +28,6 @@ jobs:
- python3 -m pip install --upgrade virtualenv
- virtualenv -p python3 --system-site-packages "$HOME/venv"
- source "$HOME/venv/bin/activate"
- export NUMBA_CACHE_DIR="$HOME/numba_cache"
- mkdir $NUMBA_CACHE_DIR
- chmod 777 $NUMBA_CACHE_DIR
- echo $NUMBA_CACHE_DIR
env:
- CC=clang
- CXX=clang++
@@ -7,6 +7,16 @@
from torch_scatter import scatter_add


def jit():
def decorator(func):
try:
return numba.jit(cache=True)(func)
except RuntimeError:
return numba.jit(cache=False)(func)

return decorator


class GDC(object):
r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
`"Diffusion Improves Graph Learning" <https://www.kdd.in.tum.de/gdc>`_
@@ -485,7 +495,7 @@ def __neighbors_to_graph__(self, neighbors, neighbor_weights,
return edge_index, edge_weight

@staticmethod
@numba.njit(cache=True)
@jit()
def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
r"""Calculate the personalized PageRank vector for all nodes
using a variant of the Andersen algorithm

0 comments on commit 261fa09

Please sign in to comment.
You can’t perform that action at this time.