Skip to content
Permalink
Browse files

torch nightly

  • Loading branch information...
rusty1s committed Aug 11, 2019
1 parent 6219117 commit eef181b2b9da63c88281ba74d186ebca86b4125b
Showing with 9 additions and 7 deletions.
  1. +2 −1 .travis.yml
  2. +7 −6 torch_geometric/nn/conv/message_passing.py
@@ -17,7 +17,7 @@ before_install:
- export CXX="g++-4.9"
install:
- pip install numpy
- pip install -q torch
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html
- pip install torch-scatter torch-sparse torch-cluster torch-spline-conv
- pip install cython && pip install gdist
- pip install pycodestyle
@@ -26,6 +26,7 @@ install:
- pip install sphinx
- pip install sphinx_rtd_theme
script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
- flake8 .
- python setup.py install
@@ -56,7 +56,7 @@ def __init__(self, aggr='add', flow='source_to_target'):
]
self.__update_args__ = getargspec(self.update)[0][2:]

def propagate(self, edge_index, size=None, **kwargs):
def propagate(self, edge_index, size=None, dim=0, **kwargs):
r"""The initial call to start propagating messages.
Args:
@@ -66,6 +66,7 @@ def propagate(self, edge_index, size=None, **kwargs):
size (list or tuple, optional): The size :obj:`[N, M]` of the
assignment matrix. If set to :obj:`None`, the size is tried to
get automatically inferred. (default: :obj:`None`)
dim ()
**kwargs: Any additional data which is needed to construct messages
and to update node embeddings.
"""
@@ -88,20 +89,20 @@ def propagate(self, edge_index, size=None, **kwargs):
assert len(tmp) == 2
if tmp[1 - idx] is not None:
if size[1 - idx] is None:
size[1 - idx] = tmp[1 - idx].size(0)
if size[1 - idx] != tmp[1 - idx].size(0):
size[1 - idx] = tmp[1 - idx].size(dim)
if size[1 - idx] != tmp[1 - idx].size(dim):
raise ValueError(__size_error_msg__)
tmp = tmp[idx]

if tmp is None:
message_args.append(tmp)
else:
if size[idx] is None:
size[idx] = tmp.size(0)
if size[idx] != tmp.size(0):
size[idx] = tmp.size(dim)
if size[idx] != tmp.size(dim):
raise ValueError(__size_error_msg__)

tmp = torch.index_select(tmp, 0, edge_index[idx])
tmp = torch.index_select(tmp, dim, edge_index[idx])
message_args.append(tmp)
else:
message_args.append(kwargs.get(arg, None))

0 comments on commit eef181b

Please sign in to comment.
You can’t perform that action at this time.