# Einsum (the dumb, inefficient way)

In [43]:
import torch as t
import einops

In [61]:
def dumbsum(x, y, shapes):
    '''
    Naive two-operand einsum, written for intuition building only.

    Not vectorized, and does not handle splitting / merging / creating
    extra dims: every axis is a single whitespace-separated name.

    The main idea is to:
    1- loop over every index combination of the output dims
    2- inside that, loop over every index combination of the summed dims
       (names that appear in an input but not in the output)
    e.g. 'a b c d e, a c e -> a d b' is equivalent to
    for a in range(x.shape[0]):
      for d in range(x.shape[3]):
        for b in range(x.shape[1]):
          for c in range(x.shape[2]):
            for e in range(x.shape[4]):
              res[a, d, b] += x[a, b, c, d, e] * y[a, c, e]

    The nested loops are flattened into a single itertools.product over
    (output dims + summed dims), and res starts as a tensor of zeros that
    is updated in place.

    Args:
        x: first operand (torch tensor).
        y: second operand (torch tensor).
        shapes: pattern of the form 'x dims, y dims -> output dims'.

    Returns:
        A tensor with one axis per output name, using the promoted dtype of
        x and y and allocated on x's device.

    Raises:
        AssertionError: if the pattern is malformed, an output name is
            repeated or unknown, the number of names does not match a
            tensor's rank, or one name is bound to two different sizes.
    '''
    # Local import so the block does not depend on the file's import cell.
    import itertools

    def split_shape(shape):
        # str.split() with no argument drops empty pieces and accepts any whitespace.
        return shape.split()

    def parse(pattern):
        assert pattern.count(',') == 1, 'expected exactly two comma-separated inputs'
        assert pattern.count('->') == 1, "expected exactly one '->'"
        inputs, output = pattern.split('->')
        x_part, y_part = inputs.split(',')
        x_shape, y_shape, res_shape = (split_shape(s) for s in (x_part, y_part, output))
        in_letters = x_shape + y_shape
        assert set(res_shape).issubset(in_letters), 'output dims must appear in an input'
        assert len(set(res_shape)) == len(res_shape), 'output dims must be unique'
        # dict.fromkeys keeps first-seen order, so the summation order (and
        # therefore float rounding) is reproducible from run to run, which a
        # set difference would not guarantee.
        sum_shape = [d for d in dict.fromkeys(in_letters) if d not in res_shape]
        return x_shape, y_shape, res_shape, sum_shape

    def build_dim_lookup(tensor, t_shape, lookup=None):
        # Maps each dim name to its size, checking that shared names agree.
        lookup = {} if lookup is None else lookup
        for size, letter in zip(tensor.shape, t_shape):
            known = lookup.setdefault(letter, size)
            assert known == size, f'dim {letter!r} has conflicting sizes {known} and {size}'
        return lookup

    def ind(t_shape, indexes):
        # Concrete index tuple for a tensor whose axes are named by t_shape.
        return tuple(indexes[dim] for dim in t_shape)

    x_shape, y_shape, res_shape, sum_shape = parse(shapes)
    assert len(x_shape) == x.dim(), 'pattern rank does not match x'
    assert len(y_shape) == y.dim(), 'pattern rank does not match y'
    lookup = build_dim_lookup(x, x_shape)
    lookup = build_dim_lookup(y, y_shape, lookup=lookup)

    # Allocate with the dtype the products will actually have, so float64 or
    # integer inputs are not silently squeezed into float32.
    res = t.zeros(
        tuple(lookup[d] for d in res_shape),
        dtype=t.result_type(x, y),
        device=x.device,
    )

    # Output dims vary slowest, summed dims fastest, mirroring the nested
    # loops in the docstring.
    loop_dims = res_shape + sum_shape
    for combo in itertools.product(*(range(lookup[d]) for d in loop_dims)):
        indexes = dict(zip(loop_dims, combo))
        res[ind(res_shape, indexes)] += x[ind(x_shape, indexes)] * y[ind(y_shape, indexes)]
    return res

In [63]:
def dumb_test():
    '''Sanity check: a plain matrix product via dumbsum must match the `@` operator.'''
    lhs = t.rand((4, 5))
    rhs = t.rand((5, 3))
    expected = lhs @ rhs
    got = dumbsum(lhs, rhs, 'a b, b c -> a c')
    assert got.allclose(expected)

# Run the matmul sanity check as soon as the cell executes.
dumb_test()

In [72]:
def einops_test(x, y, pattern):
    '''Assert that dumbsum agrees with einops.einsum on the given operands and pattern.'''
    expected = einops.einsum(x, y, pattern)
    got = dumbsum(x, y, pattern)
    assert got.allclose(expected)

# Two 4-d operands sharing the dims a, b and d; each pattern below keeps a
# different subset (or ordering) of the dims in the output.
x = t.rand((10, 5, 2, 3))
y = t.rand((3, 10, 5, 7))
for pattern in (
    'a b c d, d a b e -> b e c',
    'a b c d, d a b e -> a b c d e',
    'a b c d, d a b e -> e d c b a',
    'a b c d, d a b e -> a',
    'a b c d, d a b e ->',
    'a b c d, d a b e -> a e',
):
    einops_test(x, y, pattern)