In [1]:
# Make PyTorch print every dispatcher call together with the dispatch key it
# was routed to; the "[call] op=[...], key=[...]" lines in the outputs below
# come from this setting.
# NOTE(review): presumably this has to be set before torch is first imported — confirm.
# See https://dev-discuss.pytorch.org/t/a-small-mps-debugging-story/769
%env TORCH_SHOW_DISPATCH_TRACE=1

env: TORCH_SHOW_DISPATCH_TRACE=1


# Einsum investigation

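This notebook is a minimal reproducer of an einsum lowering issue in torch_xla: when `torch.einsum` is called from inside a custom op, the resulting IR no longer contains an einsum op and shows a sequence of other aten ops instead.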


First we define a utility that inspects the lowering.

In [2]:
import torch
import torch_xla
import time

# Two 3x3 zero tensors on the XLA device; test_lowering passes them to every
# function under test.
X = torch.zeros(3, 3, requires_grad=False, device='xla')
Y = torch.zeros(3, 3, requires_grad=False, device='xla')
# NOTE(review): presumably gives the dispatch trace printed during tensor
# creation time to finish before anything else is printed — confirm.
time.sleep(2)

def test_lowering(func):
  """Call `func(X, Y)` and report whether its XLA IR contains an einsum op.

  Args:
    func: a callable taking two tensors; it is invoked with the module-level
      XLA tensors `X` and `Y`.

  Prints "OK" when the text 'einsum' appears in the IR of the output tensor,
  otherwise prints a warning banner followed by the full IR text.
  """
  # NOTE(review): the pauses around the call presumably keep the dispatch trace
  # output from interleaving with the messages printed below — confirm.
  time.sleep(2)
  result = func(X, Y)
  time.sleep(2)

  ir_text = torch_xla._XLAC._get_xla_tensors_text([result])
  if 'einsum' in ir_text:
    print("OK")
    return

  print("!!!!!!!!!!WRONG!!!!!!!!!!! Did not find einsum in lowering")
  print("IR:")
  print(ir_text)

 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar], key=[CPU]
 [call] op=[aten::ones], key=[BackendSelect]
  [redispatch] op=[aten::ones], key=[CPU]
   [call] op=[aten::empty.memory_format], key=[BackendSelect]
    [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::fill_.Scalar]

Then we test this on a regular einsum function.

In [3]:
def plain_einsum(a, b):
  """Baseline: call torch.einsum directly, without any custom op wrapper."""
  return torch.einsum('...n,mn->...m', a, b)

test_lowering(plain_einsum)

 [call] op=[aten::einsum], key=[AutogradXLA]


OK


Next we define a custom op that wraps said einsum.

In [4]:
import torch
from torch import Tensor
from torch.library import custom_op

@custom_op("xla::custom_linear_forward123", schema="(Tensor input, Tensor weight) -> Tensor", mutates_args=())
def custom_linear_forward123(input: Tensor, weight: Tensor):
    return torch.einsum('...n,mn->...m', input, weight)


Let's test the custom op with the same utility.

In [5]:
# The custom op already takes (input, weight), so it can be passed directly.
test_lowering(custom_linear_forward123)

 [callBoxed] op=[xla::custom_linear_forward123], key=[AutogradXLA]
  [redispatchBoxed] op=[xla::custom_linear_forward123], key=[Functionalize]
   [callBoxed] op=[xla::custom_linear_forward123], key=[XLA]
    [call] op=[aten::einsum], key=[XLA]
     [call] op=[aten::unsqueeze], key=[ADInplaceOrView]
      [redispatch] op=[aten::unsqueeze], key=[XLA]
       [call] op=[aten::as_strided], key=[XLA]
     [call] op=[aten::permute], key=[ADInplaceOrView]
      [redispatch] op=[aten::permute], key=[XLA]
       [redispatchBoxed] op=[aten::permute], key=[Meta]
        [call] op=[aten::as_strided], key=[Functionalize]
         [call] op=[aten::as_strided], key=[Meta]
         [call] op=[aten::as_strided_copy], key=[XLA]
     [call] op=[aten::unsqueeze], key=[ADInplaceOrView]
      [redispatch] op=[aten::unsqueeze], key=[XLA]
       [call] op=[aten::as_strided], key=[XLA]
     [call] op=[aten::permute], key=[ADInplaceOrView]
      [redispatch] op=[aten::permute], key=[XLA]
       [redispatchBoxed]

!!!!!!!!!!WRONG!!!!!!!!!!! Did not find einsum in lowering
IR:
IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[]
  %1 = f32[3,3]{1,0} aten::expand(%0), xla_shape=f32[3,3]{1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[3,3,1]{2,1,0} aten::as_strided(%2), xla_shape=f32[3,3,1]{2,1,0}
  %4 = f32[1,3,3]{2,1,0} aten::view(%3), xla_shape=f32[1,3,3]{2,1,0}
  %5 = f32[] prim::Constant(), xla_shape=f32[]
  %6 = f32[3,3]{1,0} aten::expand(%5), xla_shape=f32[3,3]{1,0}
  %7 = f32[3,3,1]{2,1,0} aten::as_strided(%6), xla_shape=f32[3,3,1]{2,1,0}
  %8 = f32[3,3,1]{2,1,0} aten::as_strided(%7), xla_shape=f32[3,3,1]{2,1,0}
  %9 = f32[1,3,3]{2,1,0} aten::view(%8), xla_shape=f32[1,3,3]{2,1,0}
  %10 = f32[1,3,3]{2,1,0} aten::matmul(%9, %4), xla_shape=f32[1,3,3]{2,1,0}
  %11 = f32[3,1,3]{2,1,0} aten::view(%10), xla_shape=f32[3,1,3]{2,1,0}
  %12 = f32[3,3,1]{2,1,0} aten::as_strided(%11), xla_shape=f32[3,3,1]{2,1,0}
  %13 = f32[3,3]{1,0} aten::view(%12), xla_shape=f

What's different between these two traces?

The first one has

```
 [call] op=[aten::einsum], key=[AutogradXLA]
```

while the second one has

```
 [call] op=[aten::einsum], key=[XLA]
```

followed by a whole bunch of decomposed aten operations.

This suggests that when `torch.einsum` is dispatched with the `XLA` key
(rather than `AutogradXLA`), our registered einsum lowering is bypassed.
Instead, some other code in PyTorch handles it and decomposes the einsum into
a series of unsqueeze, permute, and matmul operations.

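## Try to enable the AutogradXLA dispatch key

Inside the custom op, we add `AutogradXLA` back to the enabled dispatch keys and remove it from the excluded set before calling `torch.einsum`, so that the einsum is dispatched with the `AutogradXLA` key again.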

In [8]:
import torch
from torch import Tensor
from torch.library import custom_op

@custom_op(
  "xla::custom_linear_forward_autograd",
  schema="(Tensor input, Tensor weight) -> Tensor",
  mutates_args=(),
)
def custom_linear_forward_autograd(input: Tensor, weight: Tensor):
  """Same einsum as before, but run with `AutogradXLA` forced back on.

  Reads the dispatch keys of `input` and the thread-local excluded key set,
  prints both, then adds `AutogradXLA` to the former and removes it from the
  latter. The einsum '...n,mn->...m' is evaluated while a
  `_ForceDispatchKeyGuard` built from those two sets is active.
  """
  autograd_xla = torch._C.DispatchKey.AutogradXLA

  original_keys = torch._C._dispatch_keys(input)
  print(f"These dispatch keys were enabled originally: {original_keys}")
  original_excluded = torch._C._dispatch_tls_local_exclude_set()
  print(f"These dispatch keys were excluded originally: {original_excluded}")

  # Both add() and remove() return a new key set; the originals are untouched.
  forced_include = original_keys.add(autograd_xla)
  forced_exclude = original_excluded.remove(autograd_xla)

  with torch._C._ForceDispatchKeyGuard(forced_include, forced_exclude):
    result = torch.einsum('...n,mn->...m', input, weight)
  return result


In [9]:
# The custom op already takes (input, weight), so it can be passed directly.
test_lowering(custom_linear_forward_autograd)

 [callBoxed] op=[xla::custom_linear_forward_autograd], key=[AutogradXLA]
  [redispatchBoxed] op=[xla::custom_linear_forward_autograd], key=[Functionalize]
   [callBoxed] op=[xla::custom_linear_forward_autograd], key=[XLA]
    [call] op=[aten::einsum], key=[AutogradXLA]


These dispatch keys were enabled originally: DispatchKeySet(XLA, ADInplaceOrView, AutogradXLA, AutocastXLA)
These dispatch keys were excluded originally: DispatchKeySet(Functionalize, AutogradOther, AutogradNestedTensor, AutocastCPU, AutocastMTIA, AutocastXPU, AutocastIPU, AutocastHPU, AutocastXLA, AutocastMPS, AutocastCUDA, AutocastPrivateUse1)
OK
