In [15]:
# Use the CPU PJRT backend for quick debugging and crash recovery.
%env PJRT_DEVICE=CPU

env: PJRT_DEVICE=CPU


In [19]:
from typing import Callable, TypeVar

import torch
import torch_xla.core.xla_builder as xb
from torch._ops import HigherOrderOperator
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented

import torch_xla

# Type variables used in the `scan` signatures: `Carry` is the loop-carried
# state, `X` is the input that is scanned over, and `Y` is the output that `fn`
# returns alongside the new carry.
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')


class ScanOp(HigherOrderOperator):
  """Higher-order operator registered under the name "scan".

  The operator itself carries no logic; it only forwards calls to
  `HigherOrderOperator.__call__`, which dispatches to the implementations
  registered on the instance.
  """

  def __init__(self):
    super().__init__("scan")

  def __call__(
      self,
      fn: Callable[[Carry, X], tuple[Carry, Y]],
      init: Carry,
      xs: X,
      /,
  ) -> tuple[Carry, Y]:
    """Dispatches `(fn, init, xs)`; all three arguments are positional-only."""
    dispatched = super().__call__(fn, init, xs)  # type: ignore
    return dispatched


# Single module-level operator instance; the `py_impl` registrations in this
# notebook attach their implementations to it.
scan_op = ScanOp()


def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  """User-facing entry point: forwards to the `scan_op` higher-order operator.

  Args:
    fn: step function taking `(carry, x)` and returning `(new_carry, y)`.
    init: initial carry.
    xs: the inputs that are scanned over.

  Returns:
    Whatever `scan_op` returns, annotated as a `(carry, ys)` tuple.
  """
  outputs = scan_op(fn, init, xs)
  return outputs


def dynamic_update_slice(ys: xb.Op, y: xb.Op, idx: xb.Op) -> xb.Op:
  """Writes `y` into `ys` at offset `idx` along the first dimension.

  The start index is `idx` for the first dimension and zero for every other
  dimension of `ys`.
  """
  # See https://openxla.org/xla/operation_semantics#dynamicupdateslice.
  # presumably `broadcast([1])` prepends a size-1 leading dimension so the
  # update has the same rank as `ys` -- TODO confirm
  update = y.broadcast([1])
  trailing_zeros = [idx.zeros_like() for _ in range(ys.shape().rank - 1)]
  start_indices = [idx, *trailing_zeros]
  # TODO: This is buggy
  return ys.dynamic_update_slice(update, start_indices)


def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op:
  """Extracts the entry of `xs` at offset `idx` along the first dimension.

  A window of size 1 in the first dimension (full size in all others) is
  sliced out starting at `(idx, 0, ..., 0)`, then reshaped to drop the leading
  dimension.
  """
  rank = xs.shape().rank
  start_indices = [idx] + [idx.zeros_like() for _ in range(rank - 1)]

  # Window: one entry along dim 0, everything along the remaining dims.
  window_sizes = list(xs.shape().sizes)
  window_sizes[0] = 1
  window = xs.dynamic_slice(start_indices, window_sizes)

  # Drop the leading size-1 dimension.
  element_sizes = list(xs.shape().sizes)[1:]
  return window.reshape(element_sizes)


# See https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md for
# the meaning of CompositeExplicitAutogradNonFunctional.
@scan_op.py_impl(DispatchKey.CompositeExplicitAutogradNonFunctional)
@scan_op.py_impl(DispatchKey.XLA)
def scan_dense(fn, init, xs):
  """Forward implementation of scan.

  `fn` is traced once on placeholder tensors, lowered to an XLA computation,
  and then called from the body of an XLA while loop that walks over the
  leading dimension of `xs`.

  Args:
    fn: step function `(carry, x) -> (new_carry, y)`. It must return exactly
      two tensors and keep the carry shape unchanged.
    init: initial carry tensor.
    xs: input tensor; `xs[i]` is fed to `fn` at step `i`. `xs[0]` is read to
      build the placeholder, so the leading dimension must be non-empty.

  Returns:
    `(carry, ys)` where `carry` is the carry after the last step and `ys` has
    shape `(xs.shape[0], *y.shape)` with the dtype of the `y` produced by `fn`.

  Raises:
    AssertionError: if `fn` changes the shape of the carry.
  """

  # Abstractly trace and lower `fn`.
  # Later we will include `fn_computation` within the while loop body.
  device = torch_xla.device()
  fake_carry = torch.empty(init.size(), dtype=init.dtype).to(device)
  fake_xs = torch.empty(xs[0].size(), dtype=xs[0].dtype).to(device)
  fn_outputs = fn(fake_carry, fake_xs)
  fn_ctx = torch_xla._XLAC.lowering.LoweringContext()
  fn_ctx.set_name_string("my_ctx")
  fn_ctx.build(list(fn_outputs))
  fn_hlo = fn_ctx.hlo()
  fn_computation = xb.computation_from_module_proto("my_fn_computation", fn_hlo)
  xs_len = xs.shape[0]

  # Figure out the shape and dtype of `ys` from the abstract tracing.
  fn_carry_out, fn_y_out = fn_outputs
  fn_carry_shape, fn_y_shape = fn_carry_out.shape, fn_y_out.shape
  assert fn_carry_shape == init.shape, (
      "`fn` must keep the `carry` shape unchanged. "
      f"Got {fn_carry_shape} but expected {init.shape}")

  def cond_fn(num_iters: xb.Op, carry, xs, ys):
    # Keep looping while the countdown counter is still positive.
    return num_iters > xb.Op.scalar(num_iters.builder(), 0, dtype=xb.Type.S64)

  def body_fn(num_iters: xb.Op, carry: xb.Op, xs: xb.Op, ys: xb.Op):
    xs_len_op = xb.Op.scalar(num_iters.builder(), xs_len, dtype=xb.Type.S64)
    one = xb.Op.scalar(num_iters.builder(), 1, dtype=xb.Type.S64)
    # `num_iters` counts down from `xs_len`, so `idx` counts up from 0.
    idx = xs_len_op - num_iters
    x = dynamic_slice(xs, idx)
    result = xb.Op.call(fn_computation, (carry, x))
    carry = result.get_tuple_element(0)
    y = result.get_tuple_element(1)
    ys = dynamic_update_slice(ys, y, idx)
    return xb.Op.tuple((num_iters - one, carry, xs, ys))

  num_iters = torch.tensor(xs_len, device=device)
  # Allocate the output buffer with the dtype `fn` actually produces; relying
  # on torch's default dtype would mismatch whenever `y` is not float32.
  ys = torch.zeros((xs_len, *fn_y_shape), dtype=fn_y_out.dtype, device=device)
  # Loop state, in the order expected by `cond_fn` / `body_fn`.
  carry = (num_iters, init, xs, ys)
  builder = xb.create_builder('scan')
  carry_param = []
  for i, val in enumerate(carry):
    carry_param.append(xb.mkparam(builder, i, xb.tensor_shape(val)))
  res = xb.Op.mkwhile(tuple(carry_param), cond_fn, body_fn)
  computation = res.build('scan')

  _last_iter, carry, xs, ys = torch_xla._XLAC._xla_user_computation(
      'xla::scan', carry, computation)

  return carry, ys


import torch.autograd


class Scan(torch.autograd.Function):
  """Autograd wrapper around `scan` (work in progress).

  `forward` runs `scan` below the autograd dispatch keys; `backward` runs a
  second `scan` over the flipped inputs and output gradients.

  NOTE(review): the formulas in `backward_step_fn` do not differentiate `fn`;
  they look like placeholders. Confirm before relying on the gradients.
  """

  @staticmethod
  def forward(ctx, fn, init, xs):
    """Runs the scan and stashes `fn`, `xs` and the final carry for backward."""
    # Forward pass, save inputs for backward
    ctx._fn = fn
    with torch._C._AutoDispatchBelowAutograd():
      carry, ys = scan(fn, init, xs)
    # NOTE(review): this saves the *final* carry, not `init`; `backward` then
    # reuses it as the carry of every step. Confirm this is intended.
    ctx.save_for_backward(xs, carry)
    return carry, ys

  @staticmethod
  def backward(ctx, grad_carry, grad_ys):
    """Returns `(None, grad_init, grad_xs)` for the inputs `(fn, init, xs)`."""
    # Debugging trace: shows whether autograd reaches this function.
    print("I'm in backward!!!!!!!!!!!!!!!!!")

    fn = ctx._fn
    xs, carry = ctx.saved_tensors

    # Define the backward step function for each step in reverse
    def backward_step_fn(grad_carry, grad_y, carry, x):
      # Compute gradients for carry and x at each step
      carry, y = fn(carry, x)
      grad_x = grad_carry * x + grad_y
      grad_carry = grad_carry * carry + grad_y
      return grad_carry, grad_x

    # Reverse loop to accumulate gradients using `scan`
    # We flip the gradients and the input tensors to simulate reverse iteration
    flipped_grad_ys = grad_ys.flip(dims=[0])
    flipped_xs = xs.flip(dims=[0])

    # Initial gradient for the carry
    grad_init = grad_carry.clone()

    with torch._C._AutoDispatchBelowAutograd():
      # `torch.stack` requires `grad_ys` and `xs` to have identical shapes.
      # NOTE(review): the final carry of this reverse scan is discarded and the
      # incoming `grad_carry` clone is returned as the gradient of `init`.
      _, reversed_grad_xs = scan(
          lambda gc, gyx: backward_step_fn(gc, gyx[0], carry, gyx[1]),
          grad_init, torch.stack([flipped_grad_ys, flipped_xs], dim=1))

    # Flip the gradients back to original order
    grad_xs = reversed_grad_xs.flip(dims=[0])

    # `fn` is not a tensor, so it gets no gradient.
    return None, grad_init, grad_xs


# Previous registration, kept for reference: it reported autograd as not
# implemented instead of routing through `Scan`.
# scan_op.py_impl(DispatchKey.AutogradXLA)(
#     autograd_not_implemented(scan_op, deferred_error=True))
# Route the AutogradXLA dispatch key through the `Scan` autograd function.
scan_op.py_impl(DispatchKey.AutogradXLA)(Scan.apply)


@scan_op.py_functionalize_impl
def scan_func(ctx, fn, init, xs):
  """Functionalization rule for `scan_op`.

  Unwraps `init` and `xs`, functionalizes `fn`, re-dispatches `scan_op` to the
  next handler, and wraps the outputs again.
  """
  raw_init = ctx.unwrap_tensors(init)
  raw_xs = ctx.unwrap_tensors(xs)
  with ctx.redispatch_to_next():
    outputs = scan_op(ctx.functionalize(fn), raw_init, raw_xs)
    return ctx.wrap_tensors(outputs)


Test the `scan` operation when input and output have the same shape

In [20]:
def cumsum(arr):
  """Running sum over the leading dimension of `arr`, computed with `scan`.

  The initial carry is a hard-coded zero vector of length 3 created on the
  notebook-level `device`.
  """

  def add_step(running_total, row):
    return running_total + row, running_total + row

  zero_init = torch.tensor([0.0] * 3, device=device)
  _final_total, running_sums = scan(add_step, zero_init, arr)
  return running_sums


# Build a 5x3 input on the CPU whose i-th row is [i, i, i], then move it to
# the XLA device.
device = torch_xla.device()
arr = torch.stack([
    torch.tensor([1.0, 1.0, 1.0], device=torch.device("cpu")) * i
    for i in range(5)
])
arr = arr.to(device)

# Run the scan-based running sum and print it.
# presumably `torch_xla.sync()` forces the pending XLA computation to execute
# before printing -- TODO confirm
cumulative_sum = cumsum(arr)
torch_xla.sync()
print("Result: ", cumulative_sum)

Result:  FunctionalTensor(lvl=0, value=\
tensor([[ 0.,  0.,  0.],
        [ 1.,  1.,  1.],
        [ 3.,  3.,  3.],
        [ 6.,  6.,  6.],
        [10., 10., 10.]], device='xla:0'))


Test the `scan` operation when input and output have different shapes

In [21]:
def explode(arr):
  """Running sum over dim 0 of `arr` where each output row is tiled 3 times.

  Exercises `scan` with a per-step output whose shape differs from the input
  row. The initial carry is a hard-coded zero vector of length 3 created on
  the notebook-level `device`.
  """

  def tiled_step(running_total: torch.Tensor, row: torch.Tensor):
    return running_total + row, (running_total + row).tile(3)

  zero_init = torch.tensor([0.0] * 3, device=device)
  _final_total, tiled_sums = scan(tiled_step, zero_init, arr)
  return tiled_sums


# Build a 5x3 input on the CPU whose i-th row is [i, i, i], then move it to
# the XLA device.
device = torch_xla.device()
arr = torch.stack([
    torch.tensor([1.0, 1.0, 1.0], device=torch.device("cpu")) * i
    for i in range(5)
])
arr = arr.to(device)

# Each output row is the running sum tiled three times, so the result is 5x9.
res = explode(arr)
torch_xla.sync()
print("Result: ", res)

Result:  FunctionalTensor(lvl=0, value=\
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
        [10., 10., 10., 10., 10., 10., 10., 10., 10.]], device='xla:0'))


Test the backward pass of `scan`

In [22]:
# A simple function to be applied at each step of the scan
def step_fn(carry, x):
    """Scan step: the new carry is `carry + x`, the output is `carry * x`."""
    return carry + x, carry * x

device = torch_xla.device()

# Initial carry: a length-3 vector that requires grad.
init_carry = torch.tensor([1.0, 1.0, 1.0], requires_grad=True, device=device)

# Example input tensor of shape (3, 3); `scan` iterates over the first dimension.
xs = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], requires_grad=True, device=device)

# Use the scan function
final_carry, ys = scan(step_fn, init_carry, xs)

# Loss for backward pass (sum of the outputs)
loss = ys.sum()
# Bare expression so the notebook displays the loss value.
loss

tensor(249., device='xla:0', grad_fn=<SumBackward0>)

In [23]:
def print_grad_fn(grad_fn, level=0):
    """Prints the autograd graph rooted at `grad_fn` as an indented tree.

    Each node is printed via `str()`, indented by two spaces per `level`;
    children are taken from the first element of each `next_functions` entry,
    and `None` nodes are skipped.
    """
    if grad_fn is not None:
        indent = "  " * level
        print(indent + str(grad_fn))
        for edge in grad_fn.next_functions:
            child = edge[0]
            if child is not None:
                print_grad_fn(child, level + 1)

# Dump the autograd graph that produced `loss`, then run the backward pass.
print_grad_fn(loss.grad_fn)

# NOTE(review): the recorded output of this cell is empty, so this call
# presumably did not complete -- confirm.
loss.backward()

: 

TODO: the following work-in-progress implementation is meant to add support for Dynamo tracing

In [14]:
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
import torch.utils._pytree as pytree
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._higher_order_ops.utils import reenter_make_fx


@scan_op.py_impl(ProxyTorchDispatchMode)
def scan_tracing(mode, fn, init, xs):
  """Proxy-tracing rule for `scan_op` (work in progress).

  When `mode.enable_tracing` is set, `fn` is traced into a graph module that
  is registered on the tracer root, and a `call_function` node for `scan_op`
  is emitted; otherwise the call is forwarded to `scan_op` unchanged.
  """

  def _trace_scan(proxy_mode, scan_op, fn, init, xs):
    # NOTE(review): `fn` is traced with `init` only, although `scan` documents
    # `fn` as taking `(carry, x)` -- confirm the intended example inputs.
    body_graph = reenter_make_fx(fn)(init)

    # Find the first `scan_graph_<i>` attribute name not yet used on the root.
    next_name = None
    i = 0
    while not next_name:
      candidate = f"scan_graph_{i}"
      if hasattr(proxy_mode.tracer.root, candidate):
        i += 1
      else:
        next_name = candidate
    body_graph_name = next_name
    assert not hasattr(proxy_mode.tracer.root, body_graph_name)
    proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
    # Emit the call node with the traced graph in place of the Python callable.
    args = (body_graph, init, xs)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", scan_op, proxy_args, {}, name="scan")

    # The example output is obtained by calling `fn` once.
    # NOTE(review): `fn` receives the whole `xs` here rather than a single
    # slice, and the result is not stacked along a leading dimension, so its
    # metadata presumably does not match what `scan` returns -- confirm.
    out = fn(init, xs)
    return track_tensor_tree(
        out, out_proxy, constant=None, tracer=proxy_mode.tracer)

  if mode.enable_tracing:
    return _trace_scan(mode, scan_op, fn, init, xs)
  else:
    return scan_op(fn, init, xs)


@scan_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(mode, fn, init, xs):
  """Fake-tensor rule for `scan_op`.

  `scan_op` is called as `scan_op(fn, init, xs)`, so the handler receives the
  mode followed by those three arguments. The previous signature was the
  `while_loop` one (`cond_fn, body_fn, carried_inputs, additional_inputs`) and
  could not be called with scan's argument list.

  Args:
    mode: the active `FakeTensorMode`.
    fn: step function `(carry, x) -> (new_carry, y)`.
    init: initial carry.
    xs: input tensor scanned over its leading dimension; `xs[0]` is read, so
      the leading dimension must be non-empty.

  Returns:
    `(carry, ys)` where `carry` is the result of one step of `fn` and `ys` is
    an uninitialized tensor of shape `(xs.shape[0], *y.shape)` with the dtype
    and device of `y`.
  """
  with mode:
    # One step is enough to learn the output metadata.
    carry, y = fn(init, xs[0])
    # The stacked output gains a leading dimension of length `xs.shape[0]`.
    ys = y.new_empty((xs.shape[0], *y.shape))
    return carry, ys

