In [1]:
# Debugging flags
# XLA_IR_DEBUG / XLA_HLO_DEBUG: presumably ask torch_xla to attach extra debug
# metadata to the IR / HLO it builds (see the PyTorch/XLA troubleshooting docs).
# PJRT_DEVICE: selects TPU as the PJRT runtime device.
# NOTE(review): these likely need to be set before torch_xla is imported — confirm.
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env PJRT_DEVICE=TPU

env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: PJRT_DEVICE=TPU


## Problem 1: `place_to_host` and `place_to_device` are not recognized by PyTorch AOTAutograd

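You can't call `place_to_host` on a torch functional tensor. AOTAutograd traces `fn` with `FunctionalTensor` wrappers instead of real XLA tensors (see the printed type below), and `place_to_host` fails with `Check failed: xtensor` when torch_xla tries to get the underlying XLA tensor from one.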

In [2]:
import logging
import torch_xla
import torch_xla.runtime
from torch_xla.experimental.stablehlo_custom_call import place_to_host, place_to_device

import torch
from functorch.compile import aot_function

import time

# XLA device on which the tensors in this notebook are created.
device = torch_xla.device()

def fn(a):
  """Round-trip `a` through host placement and back to the device.

  Numerically this is meant to be the identity. Before issuing the placement
  calls it prints the Python type and shape of the argument it was given and
  pauses for one second.
  """
  print("a:", type(a), a.shape)
  time.sleep(1)
  return place_to_device(place_to_host(a))

def compiler_fn(m: torch.fx.GraphModule, _):
  """Pass-through AOTAutograd compiler: print the graph's source, return it as is.

  Args:
    m: the captured graph module; only its `code` attribute is read here.
    _: second argument supplied by the caller; ignored.

  Returns:
    The same `m` object, unmodified.
  """
  generated_source = m.code
  print(generated_source)
  return m

# Four random leaf tensors on the XLA device; only `a` is used in this cell.
a, b, c, d = [torch.randn(4, 4, 4, 4, requires_grad=True, device=device) for _ in range(4)]
# NOTE(review): presumably flushes pending XLA work so tensor creation is not mixed into the trace — confirm.
torch_xla.sync()
# Wrap `fn` with AOTAutograd; the same printing compiler handles forward and backward graphs.
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Detached copy of `a` that is its own leaf requiring grad.
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
# The call is expected to fail here (see the logged "Check failed: xtensor" output);
# the RuntimeError is caught and logged so the notebook keeps running.
try:
  res = aot_print_fn(cloned_a)
except RuntimeError as e:
  logging.exception(e)



a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4, 4, 4])


ERROR:root:torch_xla/csrc/aten_xla_bridge.cpp:105 : Check failed: xtensor 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	torch_xla::bridge::GetXlaTensors(c10::IListRef<at::Tensor> const&)
	
	
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_E

## Problem 2: AOTAutograd ignores `saved_tensors_hooks`

In [3]:
import torch.nn as nn
from torch.autograd.graph import saved_tensors_hooks
from torch_xla.experimental.stablehlo_custom_call import (
  place_to_host, place_to_device
)
from functorch.compile import aot_module

class OffloadingModule(torch.nn.Module):
  """Runs a wrapped module with saved-tensor hooks that move activations off-device.

  While the wrapped module executes, every tensor autograd saves for backward
  is passed through `place_to_host`; when it is needed again it is passed
  through `place_to_device`. Both hooks print the type and shape they see.
  """

  def __init__(self, m):
    super().__init__()
    # The module whose forward pass is executed under the hooks.
    self.m = m

  def forward(self, *args, **kwargs):
    """Call the wrapped module with the given arguments under the hooks."""

    def offload_to_host(saved):
      print(f"Packing {type(saved)} {saved.shape}")
      return place_to_host(saved)

    def restore_to_device(saved):
      print(f"Unpacking {type(saved)} {saved.shape}")
      return place_to_device(saved)

    with saved_tensors_hooks(offload_to_host, restore_to_device):
      output = self.m(*args, **kwargs)
    return output
    
class Layer(torch.nn.Module):
  """Small test layer: a 4-to-4 linear projection followed by an elementwise sine."""

  def __init__(self) -> None:
    super().__init__()
    self.l = nn.Linear(4, 4)

  def forward(self, x):
    """Return `sin(linear(x))`."""
    return torch.sin(self.l(x))

# Build the layer inside the XLA device context, then wrap it so its forward
# runs under the pack/unpack saved-tensor hooks.
with torch_xla.runtime.xla_device():
  layer = Layer()
  layer = OffloadingModule(layer)

# Random leaf input on the XLA device.
a = torch.randn(4, requires_grad=True, device=device)
torch_xla.sync()
# Trace the wrapped module with AOTAutograd using the printing compiler for both graphs.
aot_print_fn = aot_module(layer, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Detached copy of `a` that is its own leaf requiring grad.
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
# Only the forward pass runs here; the printed forward graph contains no
# place_to_host/place_to_device calls and no "Packing" line appears in the output.
res = aot_print_fn(cloned_a)




def forward(self, primals_1, primals_2, primals_3):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    unsqueeze = torch.ops.aten.unsqueeze.default(primals_3, 0);  primals_3 = None
    mm = torch.ops.aten.mm.default(unsqueeze, t)
    squeeze = torch.ops.aten.squeeze.dim(mm, 0);  mm = None
    add = torch.ops.aten.add.Tensor(squeeze, primals_2);  squeeze = primals_2 = None
    sin = torch.ops.aten.sin.default(add)
    return (sin, t, unsqueeze, add)
    


## Solution: trace the function into forward and backward graphs, then wrap the activations

In [2]:
import torch
import torch.autograd
from torch.library import impl, register_fake
from torch_xla.core.xla_model import XLA_LIB
from torch_xla.experimental.stablehlo_custom_call import place_to_host, place_to_device

@torch.library.custom_op("xla::place_to_host", mutates_args=())
def to_host(t: torch.Tensor) -> torch.Tensor:
  return place_to_host(t)

@to_host.register_fake
def _(t: torch.Tensor) -> torch.Tensor:
  return torch.empty_like(t)

def to_host_backward(ctx, grad):
    return grad

to_host.register_autograd(to_host_backward)

@torch.library.custom_op("xla::place_to_device", mutates_args=())
def to_device(t: torch.Tensor) -> torch.Tensor:
  return place_to_device(t)

@to_device.register_fake
def _(t: torch.Tensor) -> torch.Tensor:
  return torch.empty_like(t)

def to_device_backward(ctx, grad):
    return grad

to_device.register_autograd(to_device_backward)




In [None]:
import torch
import torch_xla
from functorch.compile import aot_function, make_boxed_func  # type:ignore
from functools import partial

# XLA device on which the tensors in this cell are created.
device = torch_xla.device()

# NOTE(review): `a` is assigned again with the same expression further down,
# before its first use, so this first allocation looks redundant.
a = torch.randn(4, requires_grad=True, device=device)

def my_layer(t):
  """Add 456, round-trip through the host/device custom ops, then add 123.

  Uses the `xla::place_to_host` and `xla::place_to_device` ops through
  `torch.ops.xla`, so those ops must already be registered.
  """
  shifted = t + 456
  on_host = torch.ops.xla.place_to_host(shifted)  # type:ignore
  back_on_device = torch.ops.xla.place_to_device(on_host)  # type:ignore
  return back_on_device + 123

def compiler_fn(name: str, m: torch.fx.GraphModule, _):
  """Print the captured graph's source under a label and return it boxed.

  Args:
    name: label inserted into the "Captured ... graph:" header line.
    m: the captured graph module; its `code` attribute is printed.
    _: third argument supplied by the caller; ignored.

  Returns:
    The result of `make_boxed_func(m)`.
  """
  from time import sleep

  header = f"Captured {name} graph:"
  print(header)
  print(m.code)
  # Pause for three seconds after printing, before handing the graph back.
  sleep(3)
  boxed = make_boxed_func(m)
  return boxed

# Random leaf input on the XLA device.
a = torch.randn(4, requires_grad=True, device=device)
torch_xla.sync()
# Trace `my_layer` with AOTAutograd; each compiler callback is bound to the
# label ("forward"/"backward") of the graph it will receive.
aot_print_fn = aot_function(
  my_layer,
  fw_compiler=partial(compiler_fn, "forward"),
  bw_compiler=partial(compiler_fn, "backward")
)
# Detached copy of `a` that is its own leaf, so backward() populates `cloned_a.grad`.
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
res = aot_print_fn(cloned_a)
torch_xla.sync()
# Reduce to a scalar and backpropagate through the compiled function.
res.sum().backward()
torch_xla.sync()
# In the recorded output the gradient is all ones, i.e. the host/device ops act as identity in backward.
print(f"Res: {res}, cloned a grad: {cloned_a.grad}")

Captured forward graph:



def forward(self, primals_1):
    add = torch.ops.aten.add.Tensor(primals_1, 456);  primals_1 = None
    place_to_host = torch.ops.xla.place_to_host.default(add);  add = None
    place_to_device = torch.ops.xla.place_to_device.default(place_to_host);  place_to_host = None
    add_1 = torch.ops.aten.add.Tensor(place_to_device, 123);  place_to_device = None
    return (add_1,)
    
Captured backward graph:



def forward(self, tangents_1):
    return (tangents_1,)
    
Res: tensor([578.8382, 579.8059, 579.2380, 580.3248], device='xla:0',
       grad_fn=<CompiledFunctionBackward>), cloned a grad: tensor([1., 1., 1., 1.], device='xla:0')
