馃拑 馃 Extract functional forms (#238)
Closes #239
cthoyt committed Jan 21, 2021
1 parent a2f9ee8 commit e216414
Showing 14 changed files with 2,067 additions and 68 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -41,6 +41,7 @@ PyKEEN
reference/lookup
reference/sealant
reference/constants
reference/nn/index

.. toctree::
:caption: Appendix
10 changes: 0 additions & 10 deletions docs/source/reference/models.rst
@@ -10,13 +10,3 @@ Base Classes
:no-inheritance-diagram:
:no-heading:
:headings: ~~

Initialization
--------------
.. automodule:: pykeen.nn.init
:members:

Extra Modules
-------------
.. automodule:: pykeen.nn
:members:
3 changes: 3 additions & 0 deletions docs/source/reference/nn/functional.rst
@@ -0,0 +1,3 @@
Functional
==========
.. automodapi:: pykeen.nn.functional
11 changes: 11 additions & 0 deletions docs/source/reference/nn/index.rst
@@ -0,0 +1,11 @@
:mod:`pykeen.nn`
================
.. automodule:: pykeen.nn

.. toctree::
:caption: nn
:name: nn

functional
similarity
init
4 changes: 4 additions & 0 deletions docs/source/reference/nn/init.rst
@@ -0,0 +1,4 @@
Initialization
--------------
.. automodule:: pykeen.nn.init
:members:
4 changes: 4 additions & 0 deletions docs/source/reference/nn/similarity.rst
@@ -0,0 +1,4 @@
Similarity
==========
.. automodule:: pykeen.nn.sim
:members:
177 changes: 177 additions & 0 deletions src/pykeen/nn/compute_kernel.py
@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-

"""Compute kernels for common sub-tasks."""

import numpy
import torch

from ..utils import extended_einsum, split_complex, tensor_product, view_complex


def _batched_dot_manual(
a: torch.FloatTensor,
b: torch.FloatTensor,
) -> torch.FloatTensor:
return (a * b).sum(dim=-1)


# TODO benchmark
def _batched_dot_matmul(
a: torch.FloatTensor,
b: torch.FloatTensor,
) -> torch.FloatTensor:
return (a.unsqueeze(dim=-2) @ b.unsqueeze(dim=-1)).view(a.shape[:-1])


# TODO benchmark
def _batched_dot_einsum(
a: torch.FloatTensor,
b: torch.FloatTensor,
) -> torch.FloatTensor:
return torch.einsum("...i,...i->...", a, b)


def batched_dot(
a: torch.FloatTensor,
b: torch.FloatTensor,
) -> torch.FloatTensor:
"""Compute "element-wise" dot-product between batched vectors."""
return _batched_dot_manual(a, b)


# TODO benchmark
def _complex_broadcast_optimized(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Manually split into real/imag, and used optimized broadcasted combination."""
(h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)]
return sum(
factor * tensor_product(hh, rr, tt).sum(dim=-1)
for factor, hh, rr, tt in [
(+1, h_re, r_re, t_re),
(+1, h_re, r_im, t_im),
(+1, h_im, r_re, t_im),
(-1, h_im, r_im, t_re),
]
)


# TODO benchmark
def _complex_direct(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Manually split into real/imag, and directly evaluate interaction."""
(h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)]
return (
(h_re * r_re * t_re).sum(dim=-1)
+ (h_re * r_im * t_im).sum(dim=-1)
+ (h_im * r_re * t_im).sum(dim=-1)
- (h_im * r_im * t_re).sum(dim=-1)
)


# TODO benchmark
def _complex_einsum(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Use einsum."""
x = h.new_zeros(2, 2, 2)
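# x[i, j, k] is the sign with which component i of h, j of r, and k of t
# (0 = real, 1 = imaginary) contributes to Re(h * r * conj(t))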
x[0, 0, 0] = 1
x[0, 1, 1] = 1
x[1, 0, 1] = 1
x[1, 1, 0] = -1
return extended_einsum(
"ijk,bhdi,brdj,btdk->bhrt",
x,
h.view(*h.shape[:-1], -1, 2),
r.view(*r.shape[:-1], -1, 2),
t.view(*t.shape[:-1], -1, 2),
)


def _complex_native_complex(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Use torch built-ins for computation with complex numbers."""
h, r, t = [view_complex(x=x) for x in (h, r, t)]
return torch.real(tensor_product(h, r, torch.conj(t)).sum(dim=-1))


# TODO benchmark
def _complex_native_complex_select(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Use torch built-ins for computation with complex numbers and select whether to combine hr or ht first."""
h, r, t = [view_complex(x=x) for x in (h, r, t)]
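# pick the cheaper order by comparing the sizes of the broadcast intermediates h * r and r * t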
hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)])
rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)])
t = torch.conj(t)
if hr_cost < rt_cost:
h = h * r
else:
t = r * t
return torch.real((h * t).sum(dim=-1))


# TODO benchmark
def _complex_select(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Decide based on result shape whether to combine hr or ht first."""
hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)])
rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)])
(h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)]
if hr_cost < rt_cost:
h_re, h_im = (h_re * r_re - h_im * r_im), (h_re * r_im + h_im * r_re)  # h <- h * r
else:
t_re, t_im = (t_re * r_re + t_im * r_im), (t_im * r_re - t_re * r_im)  # t <- conj(r) * t
return h_re @ t_re.transpose(-2, -1) + h_im @ t_im.transpose(-2, -1)


def _complex_to_stacked(h, r, t):
"""Stack the real/imaginary parts so that a plain tri-linear dot product evaluates Re(<h, r, conj(t)>)."""
(r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (r, t)]
h = torch.cat([h, h], dim=-1)  # re im re im
r = torch.cat([r_re, r_re, r_im, r_im], dim=-1)  # re re im im
t = torch.cat([t_re, t_im, t_im, -t_re], dim=-1)  # re im im -re; the sign realizes the conjugation of t
return h, r, t


# TODO benchmark
def _complex_stacked(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Stack vectors."""
h, r, t = _complex_to_stacked(h, r, t)
return (h * r * t).sum(dim=-1)


# TODO benchmark
def _complex_stacked_select(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Stack vectors and select order."""
h, r, t = _complex_to_stacked(h, r, t)
hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)])
rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)])
if hr_cost < rt_cost:
# the conjugation sign is already folded into t by _complex_to_stacked
h = h * r
else:
t = r * t
return h @ t.transpose(-2, -1)
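
The kernels above are alternative ways to compute the same two quantities, kept side by side pending benchmarking: a batched dot product over the last dimension, and the ComplEx-style interaction Re(<h, r, conj(t)>). A minimal sketch of how they can be exercised, assuming this commit's module layout and that the split-complex kernels read the last dimension as stacked real and imaginary halves (the underscore-prefixed helpers are imported here only for illustration):

import torch

from pykeen.nn.compute_kernel import _complex_broadcast_optimized, _complex_direct, batched_dot

# toy representations of shape (batch, num, 2 * dim), real and imaginary halves stacked
h = torch.rand(2, 3, 8)
r = torch.rand(2, 1, 8)
t = torch.rand(2, 3, 8)

# the batched dot product reduces only the last dimension
assert batched_dot(h, t).shape == (2, 3)

# the direct and the broadcast-optimized real/imaginary decompositions agree
scores = _complex_direct(h, r, t)
assert scores.shape == (2, 3)
assert torch.allclose(scores, _complex_broadcast_optimized(h, r, t))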
