New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Needs someone to complete] torch.einsum() #2008

wants to merge 1 commit into
base: master
Jump to file or symbol
Failed to load files and symbols.
+100 −1
Diff settings


Just for now

@@ -4,7 +4,7 @@
from functools import reduce
__all__ = [
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul',
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul', 'einsum'
@@ -245,3 +245,102 @@ def maybeSqueeze(tensor):
raise ValueError("both arguments to __matmul__ need to be at least 1D, "
"but they are {}D and {}D".format(dim_tensor1, dim_tensor2))
def _prod(xs):
res = 1
for x in xs:
res *= x
return res
def _einsum_reduce(t1, t1_indices, t2, t2_indices, dummy_indices):
preserved = set(t1_indices) & set(t2_indices) - dummy_indices
t1_broadcast = ''.join(set(t1_indices) - preserved - dummy_indices)
t2_broadcast = ''.join(set(t2_indices) - preserved - dummy_indices)
preserved = ''.join(preserved)
dummy_indices = ''.join(dummy_indices)
t1_indices = ''.join(t1_indices)
t2_indices = ''.join(t2_indices)
n_preserved = len(preserved)
t1_trans = [t1_indices.find(char) for char in preserved + t1_broadcast + dummy_indices]
t2_trans = [t2_indices.find(char) for char in preserved + dummy_indices + t2_broadcast]
t1 = t1.permute(*t1_trans).contiguous()
t2 = t2.permute(*t2_trans).contiguous()
s1 = t1.size()
s2 = t2.size()
preserved_dims = list(s1[:n_preserved])
t1_broadcast_dims = list(s1[n_preserved:n_preserved + len(t1_broadcast)])
t2_broadcast_dims = list(s2[n_preserved + len(dummy_indices):])
result = torch.bmm(
t1.view(_prod(preserved_dims), _prod(t1_broadcast_dims), -1),
t2.view(_prod(preserved_dims), -1, _prod(t2_broadcast_dims)),
return result.view(*(preserved_dims + t1_broadcast_dims + t2_broadcast_dims)), preserved + t1_broadcast + t2_broadcast
def _reduce_sum(input, axis):
for ax in sorted(axis, reverse=True):
input = input.sum(ax)
return input
def einsum(equation, *inputs):
match = re.match('([a-z,]+)(->[a-z]*)?', equation)
assert '...' not in equation, 'ellpisis'
assert match, 'wrong eq'
input_axis_labels =',')
assert len(inputs) == len(input_axis_labels), 'wrong inputs'
axis_labels = set(''.join(input_axis_labels))
output_axis_labels =[2:]
indices = ''.join(sorted(axis_labels))
counts = {ax: 0 for ax in indices}
for axes_ in input_axis_labels:
for ax in axes_:
counts[ax] += 1
output_axis_labels = ''.join(sorted(
ax for ax in indices
if counts[ax] == 1
for a in axis_labels:
input_count = sum(1 for s in input_axis_labels if a in s)
assert not (input_count > 2 and a not in output_axis_labels), 'exp space'
temp = inputs[0]
temp_axis_labels = input_axis_labels[0]
for i in range(len(inputs)-1):
axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1])
- set(output_axis_labels))
temp, temp_axis_labels = _einsum_reduce(temp,
missing_indices = set(temp_axis_labels) - set(output_axis_labels)
if missing_indices:
reduction_indices = [i for i, a in enumerate(temp_axis_labels)
if a not in output_axis_labels]
temp = _reduce_sum(temp, reduction_indices)
temp_axis_labels = ''.join(a for a in temp_axis_labels
if a in output_axis_labels)
assert sorted(temp_axis_labels) == sorted(output_axis_labels), 'wrong eq or inputs'
perm = [temp_axis_labels.index(a) for a in output_axis_labels]
return temp.permute(*perm)
ProTip! Use n and p to navigate between commits in a pull request.