# Einsum (the dumb, inefficient way)

In [43]:
import torch as t
import einops

## recursion and loops

In [61]:
def dumbsum(x, y, shapes):
    '''
    Naive einsum of two tensors with explicit nested loops (built by recursion).

    For my own intuition building sake, with absolutely no value for real-life
    use: not vectorized, and does not handle splitting / merging / creating
    extra dims. A dim name repeated inside one operand (e.g. 'a a') is read
    along its diagonal.

    The main idea is to:
    1- generate nested loops indexing each dim of the output
    2- generate nested loops summing over everything else
    e.g. 'a b c d e, a c e -> a d b'
    for a in range(x.shape[0]):
      for d in range(x.shape[3]):
        for b in range(x.shape[1]):
          tot = 0
          for c in range(x.shape[2]):
            for e in range(x.shape[4]):
              tot += x[a, b, c, d, e] * y[a, c, e]
          res[a, d, b] = tot

    In practice res is initialized to a tensor of zeros and updated in place
    instead of accumulating in a tot:
    res[a, d, b] += x[a, b, c, d, e] * y[a, c, e]

    Args:
        x, y: the two operands.
        shapes: einops-style pattern 'x dims, y dims -> res dims' with
            space-separated dim names; an empty right-hand side sums
            everything into a 0-d tensor.

    Returns:
        A new tensor whose dtype is the promotion of x.dtype and y.dtype,
        allocated on x's device.

    Raises:
        AssertionError: malformed pattern, rank mismatch between a tensor and
            its dim names, an output dim absent from both inputs, or one dim
            name bound to two different sizes.
    '''
    def split_shape(shape):
        # dim names are separated by (possibly repeated) spaces
        return [dim for dim in shape.split(' ') if dim]

    def parse(shapes):
        assert shapes.count(',') == 1
        assert shapes.count('->') == 1
        inputs, res_shape = shapes.split('->')
        x_shape, y_shape = inputs.split(',')
        x_shape, y_shape, res_shape = (split_shape(s) for s in (x_shape, y_shape, res_shape))
        assert set(res_shape).issubset(set(x_shape + y_shape))
        # summed dims in order of first appearance, so the loop order is deterministic
        sum_shape = [dim for dim in dict.fromkeys(x_shape + y_shape) if dim not in res_shape]
        return x_shape, y_shape, res_shape, sum_shape

    def build_dim_lookup(tensor, t_shape, lookup=None):
        # map each dim name to its size, checking every occurrence agrees
        lookup = {} if lookup is None else lookup
        for size, letter in zip(tensor.shape, t_shape):
            assert lookup.get(letter, size) == size
            lookup[letter] = size
        return lookup

    def iterate(loop_dims, fn, lookup, indexes):
        # one nested loop per dim name, outermost first; the innermost level calls fn
        if not loop_dims:
            fn(indexes)
            return
        dim, inner_dims = loop_dims[0], loop_dims[1:]
        for i in range(lookup[dim]):
            indexes[dim] = i
            iterate(inner_dims, fn, lookup, indexes)

    def ind(t_shape, indexes):
        # tuple index for a tensor labelled t_shape; () for a 0-d tensor
        return tuple(indexes[dim] for dim in t_shape)

    def close_sum(res, x_shape, y_shape, res_shape):
        def fn(indexes):
            res[ind(res_shape, indexes)] += x[ind(x_shape, indexes)] * y[ind(y_shape, indexes)]
        return fn

    x_shape, y_shape, res_shape, sum_shape = parse(shapes)
    assert len(x_shape) == x.dim()
    assert len(y_shape) == y.dim()
    lookup = build_dim_lookup(x, x_shape)
    lookup = build_dim_lookup(y, y_shape, lookup=lookup)
    # keep the inputs' (promoted) dtype and device, as x * y would; a bare
    # t.zeros would use the default dtype (float32 unless changed) on the CPU
    res = t.zeros(
        tuple(lookup[dim] for dim in res_shape),
        dtype=t.promote_types(x.dtype, y.dtype),
        device=x.device,
    )
    fn = close_sum(res, x_shape, y_shape, res_shape)
    # output dims outermost, summed dims innermost
    iterate(res_shape + sum_shape, fn, lookup, {})
    return res

## vectorized

In [134]:
def dumbsum_vectorized(x, y, shapes):
    '''
    Vectorized einsum of two tensors via broadcasting, permuting and summing.

    Still does not handle splitting / merging / creating extra dims, and does
    not handle repeated dims either (e.g. 'a a b, a a c -> a a'): such
    patterns are rejected with an AssertionError instead of silently giving a
    wrong result.

    The main idea is to:
    1- align the dimensions of x and y, filling the holes with size-1 dims
    2- multiply x and y (broadcasting expands the size-1 dims)
    3- sum out the dims that are not in the output
    e.g. 'a c d e, a c e -> a d'
    # align dims: append the missing ones as size 1
    x: 'a c d e'                    (nothing missing)
    y: 'a c e'      -> 'a c e d(1)'
    # order dims: output dims first, summed dims last
    x: 'a c d e'    -> 'a d c e'
    y: 'a c e d(1)' -> 'a d(1) c e'
    # mult and sum
    res = x * y             # 'a d c e'
    res = res.sum((2, 3))   # 'a d'

    Args:
        x, y: the two operands.
        shapes: einops-style pattern 'x dims, y dims -> res dims' with
            space-separated dim names; an empty right-hand side sums
            everything into a 0-d tensor.

    Returns:
        The result tensor, laid out in the order of the output dim names.

    Raises:
        AssertionError: malformed pattern, rank mismatch between a tensor and
            its dim names, an output dim absent from both inputs, or a dim
            name repeated within one operand or within the output.
    '''
    def split_shape(shape):
        # dim names are separated by (possibly repeated) spaces
        return [dim for dim in shape.split(' ') if dim]

    def parse(shapes):
        assert shapes.count(',') == 1
        assert shapes.count('->') == 1
        inputs, res_shape = shapes.split('->')
        x_shape, y_shape = inputs.split(',')
        x_shape, y_shape, res_shape = (split_shape(s) for s in (x_shape, y_shape, res_shape))
        for names in (x_shape, y_shape, res_shape):
            # a repeated name would collide in the position lookups below
            assert len(set(names)) == len(names), f'repeated dims are not supported: {names}'
        assert set(res_shape).issubset(set(x_shape + y_shape))
        # summed dims in order of first appearance, so the layout is deterministic
        sum_shape = [dim for dim in dict.fromkeys(x_shape + y_shape) if dim not in res_shape]
        return x_shape, y_shape, res_shape, sum_shape

    def expand(tensor, t_shape, all_dims):
        # append a trailing size-1 dim for every name missing from t_shape;
        # return the expanded tensor and the position of each dim name in it
        positions = {dim: pos for pos, dim in enumerate(t_shape)}
        for dim in all_dims:
            if dim not in positions:
                positions[dim] = tensor.dim()
                tensor = tensor.unsqueeze(-1)
        return tensor, positions

    x_shape, y_shape, res_shape, sum_shape = parse(shapes)
    assert len(x_shape) == x.dim()
    assert len(y_shape) == y.dim()
    # target layout shared by both operands: output dims first, summed dims last
    target = res_shape + sum_shape
    x, x_pos = expand(x, x_shape, target)
    y, y_pos = expand(y, y_shape, target)
    x = x.permute(tuple(x_pos[dim] for dim in target))
    y = y.permute(tuple(y_pos[dim] for dim in target))
    res = x * y
    dims = tuple(range(len(res_shape), len(target)))
    # guard: sum(()) would reduce over every dim instead of none
    if dims:
        res = res.sum(dims)
    return res

## tests

In [135]:
def dumb_test():
    # plain 2-D matmul: both naive implementations must agree with `@`
    lhs = t.rand((4, 5))
    rhs = t.rand((5, 3))
    pattern = 'a b, b c -> a c'
    results = [impl(lhs, rhs, pattern) for impl in (dumbsum, dumbsum_vectorized)]
    expected = lhs @ rhs
    for result in results:
        assert result.allclose(expected)

dumb_test()

In [136]:
def dumb_test2():
    # batched inner product: keep the two leading dims, sum the two trailing ones
    lhs, rhs = t.rand((5, 4, 3, 2)), t.rand((5, 4, 3, 2))
    pattern = 'a b c d, a b c d -> a b'
    results = [impl(lhs, rhs, pattern) for impl in (dumbsum, dumbsum_vectorized)]
    expected = (lhs * rhs).sum((-1, -2))
    for result in results:
        assert result.allclose(expected)

dumb_test2()

In [137]:
def dumb_test3():
    lhs = t.rand((10, 5, 2, 3))  # a b c d
    rhs = t.rand((3, 10, 5, 7))  # d a b e
    pattern = 'a b c d, d a b e -> a e'
    results = [impl(lhs, rhs, pattern) for impl in (dumbsum, dumbsum_vectorized)]
    # reference: lay both tensors out as (a, e, b, c, d), with a size-1 dim
    # wherever a name is missing, then multiply and sum out b, c and d
    lhs_aligned = lhs.unsqueeze(-1).permute((0, 4, 1, 2, 3))  # (a 1 b c d)
    rhs_aligned = rhs.unsqueeze(-1).permute((1, 3, 2, 4, 0))  # (a e b 1 d)
    expected = (lhs_aligned * rhs_aligned).sum((2, 3, 4))
    for result in results:
        assert result.allclose(expected)

dumb_test3()

In [138]:
def einops_test(x, y, pattern):
    # both naive implementations must match einops' reference einsum
    results = [impl(x, y, pattern) for impl in (dumbsum, dumbsum_vectorized)]
    reference = einops.einsum(x, y, pattern)
    for result in results:
        assert result.allclose(reference)

x = t.rand((10, 5, 2, 3))  # a b c d
y = t.rand((3, 10, 5, 7))  # d a b e
# partial output, permuted output, single dim, full reduction, matmul-like
einops_test(x, y, 'a b c d, d a b e -> b e c')
einops_test(x, y, 'a b c d, d a b e -> a b c d e')
einops_test(x, y, 'a b c d, d a b e -> e d c b a')
einops_test(x, y, 'a b c d, d a b e -> a')
einops_test(x, y, 'a b c d, d a b e ->')
einops_test(x, y, 'a b c d, d a b e -> a e')