Skip to content

Commit

Permalink
optional pykeops
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 17, 2020
1 parent c39c717 commit 4a45a71
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ source=dgmc
exclude_lines =
pragma: no cover
raise
except
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ jobs:
install:
- pip install numpy
- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- pip install pykeops
- pip install torch-scatter
- pip install torch-sparse
- pip install torch-cluster
Expand Down
10 changes: 7 additions & 3 deletions dgmc/models/dgmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from torch_scatter import scatter_add
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn.inits import reset
from pykeops.torch import LazyTensor

try:
from pykeops.torch import LazyTensor
except ImportError:
LazyTensor = None

EPS = 1e-8

Expand Down Expand Up @@ -78,10 +82,10 @@ def reset_parameters(self):
self.psi_2.reset_parameters()
reset(self.mlp)

def __top_k__(self, x_s, x_t):
def __top_k__(self, x_s, x_t): # pragma: no cover
r"""Memory-efficient top-k correspondence computation."""
x_s, x_t = x_s.unsqueeze(-2), x_t.unsqueeze(-3)
if self.backend != 'test': # pragma: no cover
if LazyTensor is not None:
x_s, x_t = LazyTensor(x_s), LazyTensor(x_t)
S_ij = (-x_s * x_t).sum(dim=-1)
return S_ij.argKmin(self.k, dim=2, backend=self.backend)
Expand Down
1 change: 0 additions & 1 deletion docs/requirements_1.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
numpy
torch_nightly
pykeops==1.2
sphinx
sphinx_rtd_theme
3 changes: 0 additions & 3 deletions test/models/test_dgmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def test_dgmc_repr():
def test_dgmc_on_single_graphs():
set_seed()
model = DGMC(psi_1, psi_2, num_steps=1)
model.backend = 'test'
x, e = data.x, data.edge_index
y = torch.arange(data.num_nodes)
y = torch.stack([y, y], dim=0)
Expand Down Expand Up @@ -68,7 +67,6 @@ def test_dgmc_on_single_graphs():
def test_dgmc_on_multiple_graphs():
set_seed()
model = DGMC(psi_1, psi_2, num_steps=1)
model.backend = 'test'

batch = Batch.from_data_list([data, data])
x, e, b = batch.x, batch.edge_index, batch.batch
Expand All @@ -88,7 +86,6 @@ def test_dgmc_on_multiple_graphs():

def test_dgmc_include_gt():
model = DGMC(psi_1, psi_2, num_steps=1)
model.backend = 'test'

S_idx = torch.tensor([[[0, 1], [1, 2]], [[1, 2], [0, 1]]])
s_mask = torch.tensor([[True, False], [True, True]])
Expand Down

0 comments on commit 4a45a71

Please sign in to comment.