Commit: Add einsum

wesselb committed Dec 14, 2021
1 parent 9b5fa6d commit 6f4e598
Showing 10 changed files with 131 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -378,6 +378,7 @@ jit_to_numpy(a) # Caches results for `B.jit`.
```
transpose(a, perm=None) (alias: t, T)
matmul(a, b, tr_a=False, tr_b=False) (alias: mm, dot)
einsum(equation, *elements)
trace(a, axis1=0, axis2=1)
kron(a, b)
svd(a, compute_uv=True)
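For orientation, here is a minimal sketch of how the new primitive is called, assuming the `import lab as B` convention from the README and a NumPy backend (this example is not part of the diff):

```python
import numpy as np

import lab as B

a = np.random.randn(3, 3)
b = np.random.randn(3, 3)

# Matrix multiplication expressed as an Einstein summation.
prod = B.einsum("ij,jk->ik", a, b)

# Should agree with the dedicated matmul primitive.
assert np.allclose(prod, B.matmul(a, b))
```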
5 changes: 5 additions & 0 deletions lab/autograd/linear_algebra.py
@@ -22,6 +22,11 @@ def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    return anp.matmul(a, b)


@dispatch
def einsum(equation: str, *elements: Numeric):
    return anp.einsum(equation, *elements)


@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    # Correctly handle special cases.
5 changes: 5 additions & 0 deletions lab/jax/linear_algebra.py
@@ -35,6 +35,11 @@ def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    return jnp.matmul(a, b)


@dispatch
def einsum(equation: str, *elements: Numeric):
    return jnp.einsum(equation, *elements)


@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    # Correctly handle special cases.
17 changes: 16 additions & 1 deletion lab/linear_algebra.py
@@ -13,6 +13,7 @@
    "matmul",
    "mm",
    "dot",
    "einsum",
    "kron",
    "trace",
    "svd",
@@ -102,7 +103,21 @@ def matmul(a, b, tr_a: bool = False, tr_b: bool = False):  # pragma: no cover


@dispatch
-@abstract(promote=None)
+@abstract(promote_from=1)
def einsum(equation: str, *elements: Numeric):  # pragma: no cover
    """Tensor contraction via Einstein summation.

    Args:
        equation (str): Equation.
        *elements (tensor): Tensors to contract.

    Returns:
        tensor: Contraction.
    """


@dispatch
@abstract()
def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1):  # pragma: no cover
    """Compute the trace of a tensor.
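Note the `@abstract(promote_from=1)` decorator: the equation at index 0 is a plain string and must be excluded from promotion, so only the tensor arguments from index 1 onwards are promoted to a common type before the fallback re-dispatches. A rough sketch of the effect, assuming lab's usual handling of plain Python numbers (hypothetical example, not part of the diff):

```python
import numpy as np

import lab as B

a = np.random.randn(3, 3)

# The string "ij,->ij" is passed through untouched, while `a` and the plain
# Python number are treated as numeric operands.
out = B.einsum("ij,->ij", a, 2.0)
assert np.allclose(out, 2 * a)
```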
5 changes: 5 additions & 0 deletions lab/numpy/linear_algebra.py
@@ -22,6 +22,11 @@ def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    return np.matmul(a, b)


@dispatch
def einsum(equation: str, *elements: Numeric):
    return np.einsum(equation, *elements)


@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    # Correctly handle special cases.
5 changes: 5 additions & 0 deletions lab/tensorflow/linear_algebra.py
@@ -17,6 +17,11 @@ def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    return tf.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b)


@dispatch
def einsum(equation: str, *elements: Numeric):
    return tf.einsum(equation, *elements)


@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    # Correctly handle special cases.
5 changes: 5 additions & 0 deletions lab/torch/linear_algebra.py
@@ -19,6 +19,11 @@ def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False):
    return torch.matmul(a, b)


@dispatch
def einsum(equation: str, *elements: Numeric):
    return torch.einsum(equation, *elements)


@dispatch
def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None):
    # Correctly handle special cases.
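Because every backend registers the same signature under multiple dispatch, the identical call works unchanged on PyTorch tensors and returns a PyTorch tensor; a quick sketch under the same assumptions as the NumPy example above:

```python
import torch

import lab as B

a = torch.randn(3, 3)
b = torch.randn(3, 3)

# Dispatch selects the torch method, so the result stays a torch tensor.
out = B.einsum("ij,jk->ik", a, b)
assert isinstance(out, torch.Tensor)
```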
24 changes: 19 additions & 5 deletions lab/util.py
@@ -156,17 +156,28 @@ def batch_computation(f, xs, ranks):
    return B.reshape(res, *(batch_shape + B.shape(res)[1:]))


-def abstract(promote=None):
+def abstract(promote=None, promote_from=None):
    """Create a decorator for an abstract function.

    Args:
-        promote (int, optional): Number of arguments to promote. Set to `-1`
-            to promote all arguments, and set to `None` or `0` to promote no
-            arguments. Defaults to `None`.
+        promote (int, optional): Number of arguments to promote. Set to `-1` to promote
+            all arguments, and set to `None` or `0` to promote no arguments. Defaults to
+            `None`. Cannot be specified in conjunction with `promote_from`.
+        promote_from (int, optional): Index from which to promote arguments. Set to `-1`
+            or `None` to promote no arguments, and set to `0` to promote all arguments.
+            Defaults to `None`. Cannot be specified in conjunction with `promote`.

    Returns:
        function: Decorator.
    """
    if promote is not None and promote_from is not None:
        raise ValueError("Specify either `promote` or `promote_from`.")

    # If `promote` isn't given, we can safely give it the value of
    # `promote_from`: either `promote_from` is given, which is fine; or
    # `promote_from` isn't given, so `promote` remains at `None`.
    if promote is None:
        promote = promote_from

    def decorator(f):
        @wraps(f)
@@ -183,7 +194,10 @@ def wrapper(*args, **kw_args):
            types_before = tuple(plum.type_of(arg) for arg in args)

            # Promote.
-            args = plum.promote(*args[:promote_index]) + args[promote_index:]
+            if promote_from is None:
+                args = plum.promote(*args[:promote_index]) + args[promote_index:]
+            else:
+                args = args[:promote_index] + plum.promote(*args[promote_index:])

            # Enforce a change in types. Otherwise, the call will recurse, which
            # means that an implementation is not available.
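To make the difference between the two modes concrete, here is a self-contained toy model of the slicing above; promotion is mocked by upper-casing strings rather than calling `plum.promote` (illustration only, mirroring the `f6`/`f6_from` cases in the tests below):

```python
def split_promote(args, index, from_mode):
    """Toy model of the argument slicing in `abstract`."""

    def promote(*xs):
        # Mock promotion by upper-casing strings.
        return tuple(x.upper() for x in xs)

    if from_mode:
        # `promote_from=index`: promote everything from `index` onwards.
        return args[:index] + promote(*args[index:])
    # `promote=index`: promote the first `index` arguments.
    return promote(*args[:index]) + args[index:]


# `promote=2` promotes the first two arguments ...
assert split_promote(("a", "a", "a"), 2, from_mode=False) == ("A", "A", "a")
# ... whereas `promote_from=2` promotes everything from index 2 onwards.
assert split_promote(("a", "a", "a"), 2, from_mode=True) == ("a", "a", "A")
```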
7 changes: 7 additions & 0 deletions tests/test_linear_algebra.py
@@ -65,6 +65,13 @@ def test_matmul(f, check_lazy_shapes):
    )


def test_einsum(check_lazy_shapes):
    for eq in ["ij,ij->", "ij,jk->ik", "ii,ii->"]:
        check_function(B.einsum, (Value(eq), Tensor(3, 3), Tensor(3, 3)))
    for eq in ["...ij,...ij->...", "...ij,...jk->...ik", "...ii,...ii->..."]:
        check_function(B.einsum, (Value(eq), Tensor(4, 3, 3), Tensor(4, 3, 3)))


def test_trace(check_lazy_shapes):
    # Check default call.
    check_function(
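Outside the test harness, the batched equations exercised above correspond to broadcasting over leading dimensions; a small sketch with NumPy, under the same `import lab as B` assumption:

```python
import numpy as np

import lab as B

x = np.random.randn(4, 3, 3)
y = np.random.randn(4, 3, 3)

# Batched matrix multiplication: `...` broadcasts over the batch dimension
# while `j` is contracted.
out = B.einsum("...ij,...jk->...ik", x, y)
assert out.shape == (4, 3, 3)
```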
69 changes: 63 additions & 6 deletions tests/test_util.py
@@ -120,6 +120,10 @@ def test_metadata(check_lazy_shapes):


def test_abstract(check_lazy_shapes):
    # Test that `promote` and `promote_from` cannot be specified at the same time.
    with pytest.raises(ValueError):
        abstract(promote=1, promote_from=1)(lambda: None)

    class General:
        pass

@@ -134,10 +138,11 @@ class Specific:
    plum.promote = lambda *args: (b,) * len(args)

    # Define some abstract functions.

    @B.dispatch
    @abstract()
    def f1(*args: General):
-        return args
+        pass

    @B.dispatch
    def f1(*args: Specific):
@@ -146,7 +151,7 @@ def f1(*args: Specific):
    @B.dispatch
    @abstract(promote=None)
    def f2(*args: General):
-        return args
+        pass

    @B.dispatch
    def f2(*args: Specific):
@@ -155,59 +160,111 @@ def f2(*args: Specific):
    @B.dispatch
    @abstract(promote=-1)
    def f3(*args: General):
-        return args
+        pass

    @B.dispatch
    def f3(*args: Specific):
        return args

    @B.dispatch
    @abstract(promote_from=-1)
    def f3_from(*args: General):
        pass

    @B.dispatch
    def f3_from(*args: Specific):
        return args

    @B.dispatch
    @abstract(promote=0)
    def f4(*args: General):
-        return args
+        pass

    @B.dispatch
    def f4(*args: Specific):
        return args

    @B.dispatch
    @abstract(promote_from=0)
    def f4_from(*args: General):
        pass

    @B.dispatch
    def f4_from(*args: Specific):
        return args

    @B.dispatch
    @abstract(promote=1)
    def f5(*args: General):
-        return args
+        pass

    @B.dispatch
    def f5(arg: Specific, *args: General):
        return (arg,) + args

    @B.dispatch
    @abstract(promote_from=1)
    def f5_from(*args: General):
        pass

    @B.dispatch
    def f5_from(arg: General, *args: Specific):
        return (arg,) + args

    @B.dispatch
    @abstract(promote=2)
    def f6(*args: General):
-        return args
+        pass

    @B.dispatch
    def f6(arg1: Specific, arg2: Specific, *args: General):
        return (arg1, arg2) + args

    @B.dispatch
    @abstract(promote_from=2)
    def f6_from(*args: General):
        pass

    @B.dispatch
    def f6_from(arg1: General, arg2: General, *args: Specific):
        return (arg1, arg2) + args

    # Register methods.
    B.f1 = f1
    B.f2 = f2
    B.f3 = f3
    B.f3_from = f3_from
    B.f4 = f4
    B.f4_from = f4_from
    B.f5 = f5
    B.f5_from = f5_from
    B.f6 = f6
    B.f6_from = f6_from

    # Test promotion.
    with pytest.raises(NotFoundLookupError):
        f1(a, a, a)

    with pytest.raises(NotFoundLookupError):
        f2(a, a, a)

    assert f3(a, a, a) == (b, b, b)
    with pytest.raises(NotFoundLookupError):
        f3_from(a, a, a)

    with pytest.raises(NotFoundLookupError):
        f4(a, a, a)
    assert f4_from(a, a, a) == (b, b, b)

    assert f5(a, a, a) == (b, a, a)
    assert f5(a) == (b,)
    assert f5_from(a, a, a) == (a, b, b)
    assert f5_from(a, a) == (a, b)

    assert f6(a, a, a) == (b, b, a)
    assert f6(a, a) == (b, b)
    assert f6_from(a, a, a, a) == (a, a, b, b)
    assert f6_from(a, a, a) == (a, a, b)

    # Put back promotion function.
    plum.promote = plum_promote