In [1]:
import torch
import torch_xla
from torch import Tensor
from typing import Optional
from torch_xla.core.xla_model import XLA_LIB
from torch.library import impl, custom_op

# Custom forward op: uses einsum internally
@custom_op("xla::custom_linear_forward", schema="(Tensor input, Tensor weight, Tensor? bias) -> Tensor", mutates_args=())
def custom_linear_forward(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    product = torch.einsum('...n,mn->...m', input, weight)
    if bias is not None:
        return product + bias
    return product
  
@custom_linear_forward.register_fake
def custom_linear_forward_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Fake implementation registered for ``custom_linear_forward``.

    Runs the same einsum (and optional bias add) as the real op, so the
    result is produced by exactly the same sequence of torch calls.
    """
    result = torch.einsum('...n,mn->...m', input, weight)
    if bias is None:
        return result
    return result + bias

@custom_op("xla::custom_linear_backward", schema="(Tensor grad_output, Tensor input, Tensor weight, Tensor? bias, bool needs_input_grad_input, bool needs_input_grad_weight, bool needs_input_grad_bias) -> (Tensor, Tensor, Tensor)", mutates_args=())
def custom_linear_backward(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    grad_input = grad_weight = grad_bias = None
    
    if needs_input_grad_input:
        grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
    else:
        grad_input = torch.zeros_like(input)
    
    if needs_input_grad_weight:
        grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
    else:
        grad_weight = torch.zeros_like(weight)
    
    if bias is not None and needs_input_grad_bias:
        grad_bias = torch.einsum('...m->m', grad_output)
    else:
        grad_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)

    return grad_input, grad_weight, grad_bias

@custom_linear_backward.register_fake
def custom_linear_backward_fake(
    grad_output: Tensor,
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    needs_input_grad_input: bool,
    needs_input_grad_weight: bool,
    needs_input_grad_bias: bool
):
    """Fake implementation registered for ``custom_linear_backward``.

    Follows the same branches and torch calls as the real op: each requested
    gradient comes from an einsum, each unrequested one is a zero placeholder.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``.
    """
    # Gradient for `input`: placeholder unless requested.
    if not needs_input_grad_input:
        d_input = torch.zeros_like(input)
    else:
        d_input = torch.einsum('...m,mn->...n', grad_output, weight)

    # Gradient for `weight`: placeholder unless requested.
    if not needs_input_grad_weight:
        d_weight = torch.zeros_like(weight)
    else:
        d_weight = torch.einsum('...m,...n->mn', grad_output, input)

    # Gradient for `bias`: placeholder when there is no bias or it is not requested.
    if bias is None or not needs_input_grad_bias:
        d_bias = torch.zeros((weight.size(0),), dtype=grad_output.dtype, device=grad_output.device)
    else:
        d_bias = torch.einsum('...m->m', grad_output)

    return d_input, d_weight, d_bias

# Now define the XLAPatchedLinear function that uses the custom ops
class XLAPatchedLinear(torch.autograd.Function):
    """
    A patched version of `torch.nn.functional.linear` that uses einsum via custom ops.
    By wrapping these calls in custom ops, AOTAutograd won't decompose einsum.
    """

    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Optional[Tensor] = None):
        ctx.save_for_backward(input, weight, bias)
        # Call our custom forward op
        return torch.ops.xla.custom_linear_forward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        input, weight, bias = ctx.saved_tensors
        needs_input_grad_input = ctx.needs_input_grad[0]
        needs_input_grad_weight = ctx.needs_input_grad[1]
        needs_input_grad_bias = False
        if bias is not None:
            needs_input_grad_bias = ctx.needs_input_grad[2]

        # Call our custom backward op with the boolean flags
        grad_input, grad_weight, grad_bias = torch.ops.xla.custom_linear_backward(
            grad_output, 
            input, 
            weight, 
            bias, 
            needs_input_grad_input, 
            needs_input_grad_weight, 
            needs_input_grad_bias
        )
        return grad_input, grad_weight, grad_bias





In [3]:
# Leaf tensors: (2, 3) input, (4, 3) weight, (4,) bias, all tracking gradients.
# The generator is consumed in x, w, b order, so the random draws happen in that order.
x, w, b = (torch.randn(*shape, requires_grad=True) for shape in ((2, 3), (4, 3), (4,)))

# Forward through the patched linear, then reduce to a scalar loss.
y = XLAPatchedLinear.apply(x, w, b)
loss = y.sum()

# Backward, then show the gradient accumulated on each leaf.
loss.backward()
print(x.grad, w.grad, b.grad)

tensor([[-0.4687, -2.7780,  0.8191],
        [-0.4687, -2.7780,  0.8191]]) tensor([[ 1.8054, -0.7140, -0.1606],
        [ 1.8054, -0.7140, -0.1606],
        [ 1.8054, -0.7140, -0.1606],
        [ 1.8054, -0.7140, -0.1606]]) tensor([2., 2., 2., 2.])


In [4]:
import torch
from functorch.compile import aot_function

# A custom compiler function that prints the graph.
def print_graph(gm, sample_inputs):
    """Print ``gm`` under a banner line and return its ``forward`` attribute.

    Args:
        gm: Object to print; its ``forward`` attribute is returned.
        sample_inputs: Accepted but not used.

    Returns:
        ``gm.forward``.
    """
    # One call, newline-separated: banner first, then the graph's printed form.
    print("=== Generated Graph ===", gm, sep="\n")
    return gm.forward

def my_einsum_func(x, y):
    """Apply ``XLAPatchedLinear`` to ``x`` and ``y`` with no bias.

    Serves as the function traced to test whether einsum gets decomposed.
    """
    patched_linear = XLAPatchedLinear.apply
    return patched_linear(x, y)

# Wrap my_einsum_func with aot_function; print_graph is passed as both the
# forward and the backward compiler so any graph handed to it gets printed.
compiled_func = aot_function(my_einsum_func, fw_compiler=print_graph, bw_compiler=print_graph)

# Two random 3x3 sample inputs (drawn in x, y order), then one compiled call.
x, y = torch.randn(3, 3), torch.randn(3, 3)
out = compiled_func(x, y)

# Banner and result, newline-separated.
print("=== Output ===", out, sep="\n")


=== Generated Graph ===
<lambda>()



def forward(self, arg0_1, arg1_1):
    custom_linear_forward = torch.ops.xla.custom_linear_forward.default(arg0_1, arg1_1, None);  arg0_1 = arg1_1 = None
    return (custom_linear_forward,)
    
# To see more debug info, please use `graph_module.print_readable()`
=== Output ===
tensor([[-3.4562e-01, -6.5829e-01, -2.4595e+00],
        [ 9.8354e-01, -4.9411e-01, -3.9749e+00],
        [-1.0354e+00, -1.8021e+00, -8.4746e-04]])
