# Einsum investigation

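This notebook is a minimal reproducer of an einsum lowering issue in PyTorch/XLA. A plain `torch.einsum` call is lowered as an einsum, but the same call wrapped in a custom op is not.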


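First, we define a utility that moves two tensors to the XLA device, applies a function to them, and checks whether `einsum` appears in the resulting IR.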

In [9]:
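```python
import torch
import torch_xla

def test_lowering(func):
  # Create two small CPU tensors and move them to the XLA device.
  x = torch.randn(3, 3, requires_grad=False)
  y = torch.randn(3, 3, requires_grad=False)

  x = x.to('xla')
  y = y.to('xla')

  # XLA tensors are lazy: this records IR nodes instead of executing.
  out = func(x, y)

  # Dump the pending IR graph that produces `out`.
  ir = torch_xla._XLAC._get_xla_tensors_text([out])
  if 'einsum' not in ir:
    print("!!!!!!!!!!WRONG!!!!!!!!!!! Did not find einsum in lowering")
    print("IR:")
    print(ir)
  else:
    print("OK")
```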

Then we test it on a plain `torch.einsum` call.

In [10]:
```python
test_lowering(lambda a, b: torch.einsum('...n,mn->...m', a, b))
```

OK
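The subscript string `'...n,mn->...m'` contracts the last axis of the input with the last axis of a `(m, n)` weight and keeps every leading axis of the input, which is the computation of a bias-free linear layer, `x @ w.T`. The sketch below checks this with NumPy, used here only because its `einsum` accepts the same subscript notation; the shapes and random values are arbitrary. It also shows that the contraction can be rewritten as reshape, plain matmul, reshape, which is the kind of primitive-op decomposition an einsum can be lowered into.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 3))  # leading "..." axes (2, 5), contraction axis n = 3
w = rng.standard_normal((4, 3))     # weight of shape (m, n) = (4, 3)

# einsum form, as used in this notebook.
out_einsum = np.einsum('...n,mn->...m', x, w)

# Equivalent bias-free linear layer: matmul against the transposed weight.
out_matmul = x @ w.T

# Equivalent decomposition: flatten leading axes, 2-D matmul, restore the shape.
out_flat = (x.reshape(-1, 3) @ w.T).reshape(2, 5, 4)

print(out_einsum.shape)                     # (2, 5, 4)
print(np.allclose(out_einsum, out_matmul))  # True
print(np.allclose(out_einsum, out_flat))    # True
```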


Next, we define a custom op that wraps the same einsum.

In [13]:
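```python
import torch
from torch import Tensor
from torch.library import custom_op

# Register a custom op whose body is the same einsum as above.
# Because `schema` is given explicitly, it is not inferred from the type hints.
@custom_op("xla::custom_linear_forward123", schema="(Tensor input, Tensor weight) -> Tensor", mutates_args=())
def custom_linear_forward123(input: Tensor, weight: Tensor) -> Tensor:
    return torch.einsum('...n,mn->...m', input, weight)
```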


Let's test the custom op:

In [14]:
```python
test_lowering(lambda a, b: custom_linear_forward123(a, b))
```

```
!!!!!!!!!!WRONG!!!!!!!!!!! Did not find einsum in lowering
IR:
IR {
  %0 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %1 = f32[3,3,1]{2,1,0} aten::as_strided(%0), xla_shape=f32[3,3,1]{2,1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[1,3,3]{2,1,0} aten::view(%2), xla_shape=f32[1,3,3]{2,1,0}
  %4 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %5 = f32[3,3,1]{2,1,0} aten::as_strided(%4), xla_shape=f32[3,3,1]{2,1,0}
  %6 = f32[3,3,1]{2,1,0} aten::as_strided(%5), xla_shape=f32[3,3,1]{2,1,0}
  %7 = f32[1,3,3]{2,1,0} aten::view(%6), xla_shape=f32[1,3,3]{2,1,0}
  %8 = f32[1,3,3]{2,1,0} aten::matmul(%7, %3), xla_shape=f32[1,3,3]{2,1,0}
  %9 = f32[3,1,3]{2,1,0} aten::view(%8), xla_shape=f32[3,1,3]{2,1,0}
  %10 = f32[3,3,1]{2,1,0} aten::as_strided(%9), xla_shape=f32[3,3,1]{2,1,0}
  %11 = f32[3,3]{1,0} aten::view(%10), xla_shape=f32[3,3]{1,0}, ROOT=0
}
```
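In this dump, the einsum inside the custom op has been decomposed into `aten::as_strided`, `aten::view`, and `aten::matmul` nodes. The `'einsum' not in ir` check in `test_lowering` is a substring test over the whole text; a stricter option is to count op names per IR node. The sketch below does that with a regular expression. The regex is an assumption based only on the single-output line format visible in this dump, not a documented format of `_get_xla_tensors_text`, and the helper names `ir_ops` and `lowers_to` are hypothetical.

```python
import re
from collections import Counter

# Matches one single-output IR node line such as
#   %8 = f32[1,3,3]{2,1,0} aten::matmul(%7, %3), xla_shape=f32[1,3,3]{2,1,0}
# and captures the qualified op name ("aten::matmul").
# Assumption: the shape token contains no spaces (true for this dump,
# not necessarily for tuple-shaped outputs).
NODE_RE = re.compile(r"^\s*%\d+\s*=\s*\S+\s+([\w:]+)\(")


def ir_ops(ir: str) -> Counter:
    """Count how often each qualified op name appears in an IR dump."""
    ops = Counter()
    for line in ir.splitlines():
        match = NODE_RE.match(line)
        if match:
            ops[match.group(1)] += 1
    return ops


def lowers_to(ir: str, op: str) -> bool:
    """True if some node's unqualified op name equals `op` exactly."""
    return any(name.split("::")[-1] == op for name in ir_ops(ir))


# The IR printed by the custom-op test in this notebook.
IR_TEXT = """IR {
  %0 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %1 = f32[3,3,1]{2,1,0} aten::as_strided(%0), xla_shape=f32[3,3,1]{2,1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[1,3,3]{2,1,0} aten::view(%2), xla_shape=f32[1,3,3]{2,1,0}
  %4 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %5 = f32[3,3,1]{2,1,0} aten::as_strided(%4), xla_shape=f32[3,3,1]{2,1,0}
  %6 = f32[3,3,1]{2,1,0} aten::as_strided(%5), xla_shape=f32[3,3,1]{2,1,0}
  %7 = f32[1,3,3]{2,1,0} aten::view(%6), xla_shape=f32[1,3,3]{2,1,0}
  %8 = f32[1,3,3]{2,1,0} aten::matmul(%7, %3), xla_shape=f32[1,3,3]{2,1,0}
  %9 = f32[3,1,3]{2,1,0} aten::view(%8), xla_shape=f32[3,1,3]{2,1,0}
  %10 = f32[3,3,1]{2,1,0} aten::as_strided(%9), xla_shape=f32[3,3,1]{2,1,0}
  %11 = f32[3,3]{1,0} aten::view(%10), xla_shape=f32[3,3]{1,0}, ROOT=0
}"""

ops = ir_ops(IR_TEXT)
print(ops.most_common())
# [('aten::as_strided', 5), ('aten::view', 4), ('xla::device_data', 2), ('aten::matmul', 1)]
print(lowers_to(IR_TEXT, "einsum"))  # False
print(lowers_to(IR_TEXT, "matmul"))  # True
```

Matching on the unqualified name keeps the check independent of the namespace prefix (`aten::` or `xla::`) that the node happens to carry.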

