Here is an AOTAutograd-friendly linear layer built on einsum: the einsum calls are wrapped in custom ops,
so AOTAutograd won't decompose them into views and transposes.

In [9]:
import torch
import torch_xla
from torch import Tensor
from typing import Optional
from torch_xla.core.xla_model import XLA_LIB
from torch.library import impl, custom_op

# Custom forward op: uses einsum internally
@custom_op("xla::custom_linear_forward", schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor", mutates_args=())
def custom_linear_forward(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Real kernel for ``xla::custom_linear_forward``.

    Contracts the last axis of ``input`` (``n``) against axis 1 of ``weight``
    (``'...n,mn->...m'``) through the XLA einsum binding, then adds ``bias``
    when one is given.
    """
    out = torch_xla._XLAC._xla_einsum('...n,mn->...m', (input, weight))
    return out if bias is None else out + bias
  
@custom_linear_forward.register_fake
def custom_linear_forward_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Fake kernel for ``xla::custom_linear_forward``.

    Used during tracing. It repeats the real kernel's math with stock
    ``torch.einsum``, so the traced output has the same shape and dtype.
    """
    out = torch.einsum('...n,mn->...m', input, weight)
    if bias is None:
        return out
    return out + bias

@custom_op("xla::custom_linear_backward", schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)", mutates_args=())
def custom_linear_backward(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    """Real kernel for ``xla::custom_linear_backward``.

    Returns ``(grad_input, grad_weight, grad_bias)``. The schema requires three
    real tensors, so any gradient that was not requested comes back as a
    zero tensor of the matching shape instead of None.
    """
    grad_input = (
        torch_xla._XLAC._xla_einsum('...m,mn->...n', (grad_output, weight))
        if needs_input_grad_input
        else torch.zeros_like(input)
    )
    # Sums the outer product over every leading (batch) axis.
    grad_weight = (
        torch_xla._XLAC._xla_einsum('...m,...n->mn', (grad_output, input))
        if needs_input_grad_weight
        else torch.zeros_like(weight)
    )
    if bias is None or not needs_input_grad_bias:
        # Placeholder of length m (= weight.size(0)) so the output arity stays fixed.
        grad_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)
    else:
        grad_bias = torch_xla._XLAC._xla_einsum('...m->m', (grad_output, ))

    return grad_input, grad_weight, grad_bias

@custom_linear_backward.register_fake
def custom_linear_backward_fake(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    """Fake kernel for ``xla::custom_linear_backward``.

    Mirrors the real kernel with stock ``torch.einsum``, so traced outputs have
    the same shapes and dtypes. Gradients that were not requested are zero
    tensors, exactly as in the real kernel.
    """
    grad_input = (
        torch.einsum('...m,mn->...n', grad_output, weight)
        if needs_input_grad_input
        else torch.zeros_like(input)
    )
    grad_weight = (
        torch.einsum('...m,...n->mn', grad_output, input)
        if needs_input_grad_weight
        else torch.zeros_like(weight)
    )
    if bias is None or not needs_input_grad_bias:
        grad_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)
    else:
        grad_bias = torch.einsum('...m->m', grad_output)

    return grad_input, grad_weight, grad_bias

# Now define the XLAPatchedLinear function that uses the custom ops
class XLAPatchedLinear(torch.autograd.Function):
    """
    A patched version of `torch.nn.functional.linear` that uses einsum via custom ops.
    By wrapping these calls in custom ops, AOTAutograd won't decompose einsum.

    Forward computes ``'...n,mn->...m'`` (i.e. ``input @ weight.T``) plus ``bias``
    when given. ``bias`` is optional and may be omitted from ``apply`` entirely.
    """

    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor] = None):
        # save_for_backward accepts None, which comes back as None in saved_tensors.
        ctx.save_for_backward(input, weight, bias)
        # Call our custom forward op
        return torch.ops.xla.custom_linear_forward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        """Return one gradient per forward input, using None for any input that needs none.

        The custom backward op always produces three tensors, with zero
        placeholders for unrequested gradients. Those placeholders must not be
        handed back to autograd:
        - autograd rejects a non-None gradient for a forward input that was not
          a tensor (``bias=None``);
        - autograd rejects surplus non-None gradients when ``apply`` was called
          with only ``(input, weight)``;
        - zero tensors for frozen inputs are wasted work.
        """
        input, weight, bias = ctx.saved_tensors
        needs_input_grad_input = ctx.needs_input_grad[0]
        needs_input_grad_weight = ctx.needs_input_grad[1]
        # Short-circuit: ctx.needs_input_grad has only two entries when apply()
        # was called without a bias argument.
        needs_input_grad_bias = bias is not None and ctx.needs_input_grad[2]

        # Call our custom backward op with the boolean flags
        grad_input, grad_weight, grad_bias = torch.ops.xla.custom_linear_backward(
            grad_output,
            input,
            weight,
            bias,
            needs_input_grad_input,
            needs_input_grad_weight,
            needs_input_grad_bias
        )
        # Drop the op's zero placeholders: None tells autograd "no gradient".
        return (
            grad_input if needs_input_grad_input else None,
            grad_weight if needs_input_grad_weight else None,
            grad_bias if needs_input_grad_bias else None,
        )



In [10]:
with torch_xla.runtime.xla_device():
  # Eager smoke test: 2 samples, 3 input features, 4 output features.
  x, w, b = (
      torch.randn(2, 3, requires_grad=True),
      torch.randn(4, 3, requires_grad=True),
      torch.randn(4, requires_grad=True),
  )

  y = XLAPatchedLinear.apply(x, w, b)  # forward -> xla::custom_linear_forward
  loss = y.sum()
  loss.backward()                      # backward -> xla::custom_linear_backward
  print(x.grad, w.grad, b.grad)

tensor([[-1.1934, -2.8848,  0.7188],
        [-1.1934, -2.8848,  0.7188]], device='xla:0') tensor([[ 1.3555, -1.8750, -0.3125],
        [ 1.3555, -1.8750, -0.3125],
        [ 1.3555, -1.8750, -0.3125],
        [ 1.3555, -1.8750, -0.3125]], device='xla:0') tensor([2., 2., 2., 2.], device='xla:0')


Building Einsum: ...n,mn->...m
Building Einsum: ...m,mn->...n
Building Einsum: ...m,...n->mn
Building Einsum: ...m->m


In [11]:
import torch
from functorch.compile import aot_function

# A custom compiler function that prints the graph.
def print_graph(gm, sample_inputs):
    """AOTAutograd compiler hook that only prints the graph.

    It prints the FX graph module it receives, so the post-decomposition ops
    are visible, and then returns the module's ``forward`` as the "compiled"
    callable. ``sample_inputs`` is accepted for interface compatibility and
    is not used.
    """
    banner = "=== Generated Graph ==="
    print(banner)
    print(gm)
    compiled = gm.forward
    return compiled

def my_einsum_func(x, y):
    """Bias-free linear ``x @ y.T`` routed through ``XLAPatchedLinear``.

    Tracing this function shows whether the custom einsum op survives
    AOTAutograd's decomposition.
    """
    patched_linear = XLAPatchedLinear.apply
    return patched_linear(x, y)

# Compile with AOTAutograd. Both hooks just print the FX graph they are handed and
# run it as-is.
# NOTE(review): x and y below do not require grad, so presumably only the forward
# graph gets traced (the saved output shows a single graph). Confirm before relying
# on bw_compiler here.
compiled_func = aot_function(my_einsum_func, fw_compiler=print_graph, bw_compiler=print_graph)

with torch_xla.runtime.xla_device():
  x, y = torch.randn(3, 3), torch.randn(3, 3)
  out = compiled_func(x, y)

print("=== Output ===")
print(out)


=== Generated Graph ===
<lambda>()



def forward(self, arg0_1, arg1_1):
    custom_linear_forward = torch.ops.xla.custom_linear_forward.default(arg0_1, arg1_1, None);  arg0_1 = arg1_1 = None
    return (custom_linear_forward,)
    
# To see more debug info, please use `graph_module.print_readable()`
=== Output ===
tensor([[ 0.6927, -1.0702,  0.1617],
        [ 3.4787, -0.7434,  0.0236],
        [ 1.6258,  0.4528, -0.1720]], device='xla:0')


Building Einsum: ...n,mn->...m


## Verify the HLO lowering of einsum

In [12]:
import torch_xla.runtime

# Inputs are sampled on CPU, then moved explicitly to the XLA device.
x, y = torch.randn(3, 3), torch.randn(3, 3)

with torch_xla.runtime.xla_device():
  x, y = x.to('xla'), y.to('xla')
  out = compiled_func(x, y)

# Dump the lazy-tensor IR, then the lowered HLO, for `out` to see how einsum is lowered.
for dump in (torch_xla._XLAC._get_xla_tensors_text, torch_xla._XLAC._get_xla_tensors_hlo):
  print(dump([out]))

IR {
  %0 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %1 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %2 = f32[3,3]{1,0} aten::einsum(%1, %0), xla_shape=f32[3,3]{1,0}, ROOT=0
}

HloModule IrToHlo.6, entry_computation_layout={(f32[3,3]{1,0}, f32[3,3]{1,0})->(f32[3,3]{1,0})}

ENTRY %IrToHlo.6 (p0.1: f32[3,3], p1.2: f32[3,3]) -> (f32[3,3]) {
  %p1.2 = f32[3,3]{1,0} parameter(1)
  %p0.1 = f32[3,3]{1,0} parameter(0)
  %dot.3 = f32[3,3]{1,0} dot(f32[3,3]{1,0} %p1.2, f32[3,3]{1,0} %p0.1), lhs_contracting_dims={1}, rhs_contracting_dims={1}, frontend_attributes={grad_x="false",grad_y="false"}
  %transpose.4 = f32[3,3]{1,0} transpose(f32[3,3]{1,0} %dot.3), dimensions={0,1}
  ROOT %tuple.5 = (f32[3,3]{1,0}) tuple(f32[3,3]{1,0} %transpose.4)
}




Building Einsum: ...n,mn->...m


In [None]:
import torch_xla.runtime

# Baseline: call stock torch.einsum (no custom op) and inspect the IR/HLO it produces.
# NOTE(review): 'ab,bc->ab' sums y over c and scales x elementwise (see the sum/mul in
# the IR) rather than performing a matmul ('ab,bc->ac'). Confirm this equation is intended.
x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)

with torch.enable_grad(), torch_xla.runtime.xla_device():
  # .to('xla') of a grad-requiring tensor already requires grad; requires_grad_() keeps it explicit.
  x = x.to('xla').requires_grad_()
  y = y.to('xla').requires_grad_()
  out = torch.einsum('ab,bc->ab', x, y)

for dump in (torch_xla._XLAC._get_xla_tensors_text, torch_xla._XLAC._get_xla_tensors_hlo):
  print(dump([out]))

IR {
  %0 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %1 = f32[3,3,1]{2,1,0} aten::view(%0), xla_shape=f32[3,3,1]{2,1,0}
  %2 = f32[1,3,3]{0,2,1} aten::permute(%1), xla_shape=f32[1,3,3]{0,2,1}
  %3 = f32[1,3,1]{2,1,0} aten::sum(%2), xla_shape=f32[1,3,1]{2,1,0}
  %4 = f32[3,3]{1,0} xla::device_data(), xla_shape=f32[3,3]{1,0}
  %5 = f32[3,3,1]{2,1,0} aten::view(%4), xla_shape=f32[3,3,1]{2,1,0}
  %6 = f32[3,3,1]{2,1,0} aten::permute(%5), xla_shape=f32[3,3,1]{2,1,0}
  %7 = f32[3,3,1]{2,1,0} aten::mul(%6, %3), xla_shape=f32[3,3,1]{2,1,0}
  %8 = f32[3,3]{1,0} aten::view(%7), xla_shape=f32[3,3]{1,0}, ROOT=0
}

HloModule IrToHlo.20, entry_computation_layout={(f32[3,3]{1,0}, f32[3,3]{1,0})->(f32[3,3]{1,0})}

%AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
  %x.7 = f32[] parameter(0)
  %y.8 = f32[] parameter(1)
  ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
}

ENTRY %IrToHlo.20 (p0.1: f32[3,3], p1.12: f32[3,3]) -> (f32[3,3]) {
  %constant.5 = s32[] constant(3)
  %p1.12

: 