In [1]:
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1

env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1


In [2]:
import torch
import torch.nn as nn
import torch_xla



In [3]:
from torch_xla.experimental.scan import scan

In [4]:
# Module-level tensor that `step_fn` below mutates in place on every call. It is
# not part of `carry` or `xs`, so that mutation is a side effect of `step_fn`.
weird_tensor = torch.tensor([0.0, 0.0], device=torch_xla.device())

def step_fn(carry, x):
  """Impure scan step: accumulates `x` into the carry and bumps a global.

  Args:
    carry: Running value threaded through the scan.
    x: The slice of `xs` for the current step.

  Returns:
    A `(new_carry, y)` pair where `new_carry = carry + x` and `y` is
    `new_carry` plus the module-level `weird_tensor` *after* it has been
    incremented in place by 1.0.
  """
  accumulated = carry + x
  # Side effect: in-place update of state that lives outside the function.
  weird_tensor.add_(1.0)
  return accumulated, accumulated + weird_tensor

# Initial carry and three all-zero steps of input, both placed on the XLA device.
init = torch.tensor([0.0, 0.0], device=torch_xla.device())
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device=torch_xla.device())

# Bug: because we only trace the DAG subgraph rooted at the outputs of `step_fn`,
# mutations to `weird_tensor` aren't captured. In PyTorch/XLA, mutations are
# supported via a Copy-on-Write mechanism, where we update the reference inside
# `weird_tensor` to point at an updated tensor. To fix the bug, we need to trace
# `step_fn` more than once and verify that the HLO from each trace is the same.
# A later trace of `step_fn` will use the already-mutated tensor and therefore
# collect a larger, distinct graph, which exposes the in-place mutation. In
# summary, just like JAX requires `step_fn` to be a pure function, we also need
# to prevent side effects in order to extract a single shared HLO computation.
scan(step_fn, init, xs)

AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.add.Tensor(tensor([...], device='xla:0', size=(2,)), 1.0)

In [None]:
import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree


def loopy_scan(fn, init, xs):
  """A simple scan implemented with for loops serving as reference
  implementation."""
  carry = init
  ys = []
  xs_len = len(next(iter(tree_iter(xs))))
  for i in range(xs_len):
    carry, y = fn(carry, tree_map(lambda x: x[i], xs))
    ys.append(y)
  ys = tree_map(lambda *x: torch.stack(x), *ys)
  return carry, ys


# Fresh zero-valued tensor for this cell; `step_fn` below increments it in place
# on every call.
weird_tensor = torch.tensor([0.0, 0.0], device=torch_xla.device())

def step_fn(carry, x):
  """Impure scan step: accumulates `x` into the carry and bumps a global.

  Args:
    carry: Running value threaded through the scan.
    x: The slice of `xs` for the current step.

  Returns:
    A `(new_carry, y)` pair where `new_carry = carry + x` and `y` is
    `new_carry` plus the module-level `weird_tensor` *after* it has been
    incremented in place by 1.0.
  """
  accumulated = carry + x
  # Side effect: in-place update of state that lives outside the function.
  weird_tensor.add_(1.0)
  return accumulated, accumulated + weird_tensor

# Same inputs as the `scan` cell: a zero carry and three all-zero steps.
init = torch.tensor([0.0, 0.0], device=torch_xla.device())
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device=torch_xla.device())

# The for-loop reference calls `step_fn` once per step, so every row of `ys`
# reflects one more in-place increment of `weird_tensor` (1, 2, 3) while the
# carry stays at zero.
loopy_scan(step_fn, init, xs)

(tensor([0., 0.], device='xla:0'),
 tensor([[1., 1.],
         [2., 2.],
         [3., 3.]], device='xla:0'))