Here is an AOTAutograd friendly version of einsum that won't be decomposed into
views and transposes.

In [54]:
import torch
import torch_xla
from torch import Tensor
from typing import Optional
from torch_xla.core.xla_model import XLA_LIB
from torch.library import impl, custom_op

# Custom forward op: uses einsum internally
@custom_op("xla::custom_linear_forward", schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor", mutates_args=())
def custom_linear_forward(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    product = torch.einsum('...n,mn->...m', input, weight)
    if bias is not None:
        return product + bias
    return product

@custom_linear_forward.register_fake
def custom_linear_forward_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    product = torch.einsum('...n,mn->...m', input, weight)
    if bias is not None:
        return product + bias
    return product

@custom_op("xla::custom_linear_backward1", schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)", mutates_args=())
def custom_linear_backward1(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    grad_input = grad_weight = grad_bias = None

    if needs_input_grad_input:
        # This doesn't work
        grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
        print(f"grad_input = grad_output @ weight")
        print(f"grad_output = {grad_output}")
        print(f"weight = {weight}")
        print(f"grad_input = {grad_input}")

        # This works
        grad_input = grad_output @ weight
    else:
        grad_input = torch.zeros_like(input)

    if needs_input_grad_weight:
        grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
    else:
        grad_weight = torch.zeros_like(weight)

    if bias is not None and needs_input_grad_bias:
        grad_bias = torch.einsum('...m->m', grad_output)
    else:
        grad_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)

    return grad_input, grad_weight, grad_bias

@custom_linear_backward1.register_fake
def custom_linear_backward_fake(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    grad_input = grad_weight = grad_bias = None

    if needs_input_grad_input:
        grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
        print(f"grad_input = grad_output @ weight")
        print(f"grad_output = {grad_output}")
        print(f"weight = {weight}")
        print(f"grad_input = {grad_input}")
    else:
        grad_input = torch.zeros_like(input)
    
    if needs_input_grad_weight:
        grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
    else:
        grad_weight = torch.zeros_like(weight)

    if bias is not None and needs_input_grad_bias:
        grad_bias = torch.einsum('...m->m', grad_output)
    else:
        grad_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)

    return grad_input, grad_weight, grad_bias

# Now define the XLAPatchedLinear function that uses the custom ops
class XLAPatchedLinear(torch.autograd.Function):
    """
    A patched version of `torch.nn.functional.linear` that uses einsum via custom ops.
    By wrapping these calls in custom ops, AOTAutograd won't decompose einsum.
    """

    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor] = None):
        ctx.save_for_backward(input, weight, bias)
        # Call our custom forward op
        return torch.ops.xla.custom_linear_forward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        input, weight, bias = ctx.saved_tensors
        needs_input_grad_input = ctx.needs_input_grad[0]
        needs_input_grad_weight = ctx.needs_input_grad[1]
        needs_input_grad_bias = False
        if bias is not None:
            needs_input_grad_bias = ctx.needs_input_grad[2]

        print(f"grad_output = {grad_output}")
        print(f"input = {input}")
        print(f"weight = {weight}")
        print(f"bias = {bias}")

        # Call our custom backward op with the boolean flags
        grad_input, grad_weight, grad_bias = torch.ops.xla.custom_linear_backward1(
            grad_output,
            input,
            weight,
            bias,
            needs_input_grad_input,
            needs_input_grad_weight,
            needs_input_grad_bias
        )
        return grad_input, grad_weight, grad_bias


In [55]:
with torch_xla.runtime.xla_device():
  x0 = torch.randn(2, 3, requires_grad=True)
  w0 = torch.randn(4, 3, requires_grad=True)
  b0 = torch.randn(4, requires_grad=True)
  torch_xla.sync()

print("Inputs:")
print(x0, w0, b0)

x = x0.clone().detach().requires_grad_()
w = w0.clone().detach().requires_grad_()
b = b0.clone().detach().requires_grad_()

# Run forward
y = XLAPatchedLinear.apply(x, w, b)
loss = y.sum()
# Run backward
loss.backward()

print("Outputs:")
print(y, x.grad, w.grad, b.grad)

y0, xg0, wg0, bg0 = y.clone().detach(), x.grad.clone().detach(), w.grad.clone().detach(), b.grad.clone().detach()

Inputs:
tensor([[ 1.3808,  0.8378,  0.0665],
        [ 0.1580,  0.9724, -0.1232]], device='xla:0', requires_grad=True) tensor([[-1.8650, -0.2681,  1.7381],
        [-0.6676,  0.2368, -0.7860],
        [-0.2759, -0.5643, -0.3005],
        [-1.5966,  1.1333, -0.5190]], device='xla:0', requires_grad=True) tensor([ 0.5979, -0.8140,  0.1973, -1.6715], device='xla:0',
       requires_grad=True)
grad_output = tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]], device='xla:0')
input = tensor([[ 1.3808,  0.8378,  0.0665],
        [ 0.1580,  0.9724, -0.1232]], device='xla:0', requires_grad=True)
weight = tensor([[-1.8650, -0.2681,  1.7381],
        [-0.6676,  0.2368, -0.7860],
        [-0.2759, -0.5643, -0.3005],
        [-1.5966,  1.1333, -0.5190]], device='xla:0', requires_grad=True)
bias = tensor([ 0.5979, -0.8140,  0.1973, -1.6715], device='xla:0',
       requires_grad=True)
grad_input = grad_output @ weight
grad_output = tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]], device='xla:0')


In [56]:
from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear as OGXLAPatchedLinear
import torch_xla.runtime

print("Inputs:")
print(x0, w0, b0)

x = x0.clone().detach().requires_grad_()
w = w0.clone().detach().requires_grad_()
b = b0.clone().detach().requires_grad_()

# Run forward
y = OGXLAPatchedLinear.apply(x, w, b)
loss = y.sum()
# Run backward
loss.backward()

print("Outputs:")
print(y, x.grad, w.grad, b.grad)

y1, xg1, wg1, bg1 = y.clone().detach(), x.grad.clone().detach(), w.grad.clone().detach(), b.grad.clone().detach()
torch.testing.assert_close(y0, y1)
torch.testing.assert_close(xg0, xg1)

Inputs:
tensor([[ 1.3808,  0.8378,  0.0665],
        [ 0.1580,  0.9724, -0.1232]], device='xla:0', requires_grad=True) tensor([[-1.8650, -0.2681,  1.7381],
        [-0.6676,  0.2368, -0.7860],
        [-0.2759, -0.5643, -0.3005],
        [-1.5966,  1.1333, -0.5190]], device='xla:0', requires_grad=True) tensor([ 0.5979, -0.8140,  0.1973, -1.6715], device='xla:0',
       requires_grad=True)
Outputs:
tensor([[-2.0926, -1.5915, -0.6737, -2.9629],
        [-0.1712, -0.5923, -0.3564, -0.7579]], device='xla:0',
       grad_fn=<XLAPatchedLinearBackward>) tensor([[-4.4043,  0.5400,  0.1289],
        [-4.4043,  0.5400,  0.1289]], device='xla:0') tensor([[ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566]], device='xla:0') tensor([2., 2., 2., 2.], device='xla:0')


In [57]:
import torch_xla.runtime
import torch.nn.functional as F

print("Inputs:")
print(x0, w0, b0)

x = x0.clone().detach().requires_grad_()
w = w0.clone().detach().requires_grad_()
b = b0.clone().detach().requires_grad_()

# Run forward
y = F.linear(x, w, b)
loss = y.sum()
# Run backward
loss.backward()

print("Outputs:")
print(y, x.grad, w.grad, b.grad)

y2, xg2, wg2, bg2 = y.clone().detach(), x.grad.clone().detach(), w.grad.clone().detach(), b.grad.clone().detach()
torch.testing.assert_close(y0, y2)
torch.testing.assert_close(xg1, xg2)
torch.testing.assert_close(xg0, xg2)

Inputs:
tensor([[ 1.3808,  0.8378,  0.0665],
        [ 0.1580,  0.9724, -0.1232]], device='xla:0', requires_grad=True) tensor([[-1.8650, -0.2681,  1.7381],
        [-0.6676,  0.2368, -0.7860],
        [-0.2759, -0.5643, -0.3005],
        [-1.5966,  1.1333, -0.5190]], device='xla:0', requires_grad=True) tensor([ 0.5979, -0.8140,  0.1973, -1.6715], device='xla:0',
       requires_grad=True)
Outputs:
tensor([[-2.0926, -1.5915, -0.6737, -2.9629],
        [-0.1712, -0.5923, -0.3564, -0.7579]], device='xla:0',
       grad_fn=<AddmmBackward0>) tensor([[-4.4043,  0.5400,  0.1289],
        [-4.4043,  0.5400,  0.1289]], device='xla:0') tensor([[ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566],
        [ 1.5410,  1.8086, -0.0566]], device='xla:0') tensor([2., 2., 2., 2.], device='xla:0')


In [58]:
import torch
from functorch.compile import aot_function

# A custom compiler function that prints the graph.
def print_graph(gm, sample_inputs):
    # Print the FX Graph to observe the operations after decomposition
    print("=== Generated Graph ===")
    print(gm)
    return gm.forward

def my_einsum_func(x, y):
    # A simple einsum expression to test decomposition
    return XLAPatchedLinear.apply(x, y)

# Wrap the function with aot_function, using our custom compilers that print the graph
compiled_func = aot_function(
    my_einsum_func,
    fw_compiler=print_graph,
    bw_compiler=print_graph
)

# Run the compiled function with sample inputs
with torch_xla.runtime.xla_device():
  x = torch.randn(3, 3)
  y = torch.randn(3, 3)
  out = compiled_func(x, y)

print("=== Output ===")
print(out)


=== Generated Graph ===
<lambda>()



def forward(self, arg0_1, arg1_1):
    custom_linear_forward = torch.ops.xla.custom_linear_forward.default(arg0_1, arg1_1, None);  arg0_1 = arg1_1 = None
    return (custom_linear_forward,)
    
# To see more debug info, please use `graph_module.print_readable()`
=== Output ===
tensor([[ 1.3572, -1.4180,  0.3893],
        [-1.2944, -0.0214,  0.0524],
        [-1.0056, -0.4093, -1.5266]], device='xla:0')


## Verify the HLO lowering of einsum

In [59]:
import torch_xla.runtime

x = torch.randn(3, 3)
y = torch.randn(3, 3)

with torch_xla.runtime.xla_device():
  x = x.to('xla')
  y = y.to('xla')
  out = compiled_func(x, y)

print(torch_xla._XLAC._get_xla_tensors_text([out]))
print(torch_xla._XLAC._get_xla_tensors_hlo([out]))

IR {
  %0 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %1 = f32[3,3,1]{2,1,0} aten::as_strided(%0), xla_shape=f32[3,3,1]{2,1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[1,3,3]{2,1,0} aten::view(%2), xla_shape=f32[1,3,3]{2,1,0}
  %4 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %5 = f32[3,3,1]{2,1,0} aten::as_strided(%4), xla_shape=f32[3,3,1]{2,1,0}
  %6 = f32[3,3,1]{2,1,0} aten::as_strided(%5), xla_shape=f32[3,3,1]{2,1,0}
  %7 = f32[1,3,3]{2,1,0} aten::view(%6), xla_shape=f32[1,3,3]{2,1,0}
  %8 = f32[1,3,3]{2,1,0} aten::matmul(%7, %3), xla_shape=f32[1,3,3]{2,1,0}
  %9 = f32[3,1,3]{2,1,0} aten::view(%8), xla_shape=f32[3,1,3]{2,1,0}
  %10 = f32[3,3,1]{2,1,0} aten::as_strided(%9), xla_shape=f32[3,3,1]{2,1,0}
  %11 = f32[3,3]{1,0} aten::view(%10), xla_shape=f32[3,3]{1,0}, ROOT=0
}

HloModule IrToHlo.16, entry_computation_layout={(f32[3,3]{1,0}, f32[3,3]{1,0})->(f32[3,3]{1,0})}

ENTRY %IrToHlo.16 (p0.1: f32[3,3], 