Skip to content

Commit

Permalink
use matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 7, 2020
1 parent be1c4c7 commit 3a2eff6
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions dgmc/models/dgmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,16 @@ def reset_parameters(self):

def __top_k__(self, x_s, x_t): # pragma: no cover
r"""Memory-efficient top-k correspondence computation."""
x_s, x_t = x_s.unsqueeze(-2), x_t.unsqueeze(-3)
if LazyTensor is not None:
x_s = x_s.unsqueeze(-2) # [..., n_s, 1, d]
x_t = x_t.unsqueeze(-3) # [..., 1, n_t, d]
x_s, x_t = LazyTensor(x_s), LazyTensor(x_t)
S_ij = (-x_s * x_t).sum(dim=-1)
return S_ij.argKmin(self.k, dim=2, backend=self.backend)
else:
S_ij = (x_s * x_t).sum(dim=-1)
x_s = x_s # [..., n_s, d]
x_t = x_t.transpose(-1, -2) # [..., d, n_t]
S_ij = x_s @ x_t
return S_ij.topk(self.k, dim=2)[1]

def __include_gt__(self, S_idx, s_mask, y):
Expand Down

0 comments on commit 3a2eff6

Please sign in to comment.