# Scan op reference implementation

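We present two implementations of the scan op, a Python for-loop implementation run with PyTorch/XLA and a JAX implementation using `jax.lax.scan`, so that their losses and gradients can be compared for correctness.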

In [1]:
# Set PJRT_DEVICE in the process environment before the next cell imports
# torch_xla. The log recorded further down shows a CPU client being created.
# NOTE(review): presumably this variable is what selects the CPU backend --
# confirm against the PyTorch/XLA runtime documentation.
%env PJRT_DEVICE=CPU

env: PJRT_DEVICE=CPU


In [2]:
import torch_xla
import torch
from typing import Callable, TypeVar, Tuple

# Type aliases for the `scan` signature: the carry can be any type, while the
# scanned inputs and the per-step outputs are tensors.
Carry = TypeVar('Carry')
X = torch.Tensor
Y = torch.Tensor


def scan(fn: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry,
         xs: X) -> Tuple[Carry, Y]:
  """Reference scan implemented as a plain Python for-loop.

  Walks `xs` along dim 0, threading a carry through `fn` and collecting the
  per-step outputs.

  Args:
    fn: Step function called as `fn(carry, xs[i])`; must return a
      `(new_carry, y)` pair. Every `y` must be stackable with `torch.stack`
      (i.e. tensors of the same shape).
    init: Carry value passed to the first step.
    xs: Tensor whose slices along dim 0 are fed to `fn` one at a time.

  Returns:
    A `(final_carry, ys)` pair, where `final_carry` is the carry returned by
    the last step and `ys` is the per-step outputs stacked along a new dim 0.

  Raises:
    RuntimeError: If `xs` has zero length along dim 0. `fn` is never called in
      that case, so the shape of the stacked outputs cannot be determined.
  """
  num_steps = xs.size(0)
  if num_steps == 0:
    # torch.stack rejects an empty list with an opaque message; fail early
    # with one that names the actual problem (same exception type as before).
    raise RuntimeError(
        "scan requires `xs` to have at least one element along dim 0, "
        f"got xs with shape {tuple(xs.shape)}")

  carry = init
  ys = []

  # Index with xs[i] (rather than iterating the tensor) so each step shows up
  # as its own select op in the autograd graph.
  for i in range(num_steps):
    carry, y = fn(carry, xs[i])
    ys.append(y)

  # Stack the list of outputs into a single tensor
  return carry, torch.stack(ys)


# Test Function
def step_fn(carry, x):
  """One scan step: emit `carry * x` and advance the carry to `carry + x`.

  Returns the pair `(new_carry, y)`; both are computed from the incoming
  carry, so `y` does not see the updated carry.
  """
  return carry + x, carry * x


# Exercise the for-loop scan on the XLA device and check both the forward
# result and the gradients.
device = torch_xla.device()

# Initial carry: a 3-element vector (one entry per feature) that tracks
# gradients so we can inspect d(loss)/d(init_carry).
init_carry = torch.tensor([1.0, 1.0, 1.0], requires_grad=True, device=device)

# Input of shape (3, 3). `scan` iterates over dim 0, so each row is the `x`
# fed to one step and the second dim lines up with the carry's 3 entries.
xs = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                  requires_grad=True,
                  device=device)

# Run the scan: three steps of `step_fn`, one per row of `xs`.
final_carry, ys = scan(step_fn, init_carry, xs)

# Scalar loss for the backward pass (sum of all per-step outputs).
loss = ys.sum()
print("Loss:", loss)

# Populate `.grad` on the two leaf tensors created above.
loss.backward()

# Show the gradients of the loss w.r.t. the initial carry and the inputs.
print("init_carry grad:", init_carry.grad)
print("xs grad:", xs.grad)


I0000 00:00:1723077440.079486 1033945 cpu_client.cc:466] TfrtCpuClient created.


Loss: tensor(249., device='xla:0', grad_fn=<SumBackward0>)
init_carry grad: tensor([12., 15., 18.], device='xla:0')
xs grad: tensor([[12., 14., 16.],
        [ 9., 11., 13.],
        [ 6.,  8., 10.]], device='xla:0')


In [3]:
def print_grad_fn(grad_fn, level=0):
    """Print an autograd graph as an indented tree, one node per line.

    Starting from `grad_fn`, prints the node and then recurses into the first
    element of every entry in its `next_functions`, indenting two spaces per
    depth `level`. `None` nodes are skipped. A node reachable through several
    paths is printed once per path.
    """
    if grad_fn is None:
        return

    indent = "  " * level
    print(f"{indent}{grad_fn!s}")

    # A None child is handled by the guard at the top of the recursive call.
    for entry in grad_fn.next_functions:
        print_grad_fn(entry[0], level + 1)

# Dump the autograd graph hanging off the `loss` tensor computed in the
# previous cell, starting from its grad_fn.
print_grad_fn(loss.grad_fn)


<SumBackward0 object at 0x7fe2075cef50>
  <StackBackward0 object at 0x7fe2075cf490>
    <MulBackward0 object at 0x7fe2075cf520>
      <AccumulateGrad object at 0x7fe2075cf0d0>
      <SelectBackward0 object at 0x7fe2075cf010>
        <AccumulateGrad object at 0x7fe2075cf070>
    <MulBackward0 object at 0x7fe2075cf850>
      <AddBackward0 object at 0x7fe2075cf010>
        <AccumulateGrad object at 0x7fe2075cf070>
        <SelectBackward0 object at 0x7fe2075cf040>
          <AccumulateGrad object at 0x7fe2075cebc0>
      <SelectBackward0 object at 0x7fe2075cf0a0>
        <AccumulateGrad object at 0x7fe2075cf040>
    <MulBackward0 object at 0x7fe2075cf340>
      <AddBackward0 object at 0x7fe2075cf0a0>
        <AddBackward0 object at 0x7fe2075cf040>
          <AccumulateGrad object at 0x7fe2075cebc0>
          <SelectBackward0 object at 0x7fe2075ceb90>
            <AccumulateGrad object at 0x7fe2075cef20>
        <SelectBackward0 object at 0x7fe2075cebf0>
          <AccumulateGrad object at

Here we capture the IR graph of the backward pass of the loss computation.

In [4]:
# Rebuild the inputs from scratch: the earlier cell already ran backward() on
# its own tensors, and this cell needs fresh leaf tensors for a clean capture.
device = torch_xla.device()

init_carry = torch.tensor([1.0, 1.0, 1.0], requires_grad=True, device=device)

xs = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                  requires_grad=True,
                  device=device)

# Use the scan function
final_carry, ys = scan(step_fn, init_carry, xs)

# Loss for backward pass (sum of the outputs)
loss = ys.sum()
print("Loss:", loss)

# Cut the graph off at forward, so that the graph printed below only contains
# what backward() adds.
# NOTE(review): formatting each tensor after the sync presumably forces its
# value to be materialized -- confirm whether these three lines are still
# needed on top of torch_xla.sync().
torch_xla.sync()
_ = f"{init_carry}"
_ = f"{xs}"
_ = f"{loss}"

# Now trace the backwards.
loss.backward()

# The two gradient tensors are the roots of the graph we want to inspect.
tensors = [init_carry.grad, xs.grad]

# The label says "HLO", but the text printed below is a Graphviz `digraph`
# whose nodes are IR ops (aten::*, xla::*, prim::Constant), i.e. the IR graph
# in DOT form.
print("HLO graph:")
print(torch_xla._XLAC._get_xla_tensors_dot(tensors))

# Output gradients
print("init_carry grad:", init_carry.grad)
print("xs grad:", xs.grad)


Loss: tensor(249., device='xla:0', grad_fn=<SumBackward0>)
HLO graph:
digraph G {
  node0 [label="prim::Constant\nf32[]\nxla_shape=f32[]"]
  node1 [label="prim::Constant\nf32[]\nxla_shape=f32[]"]
  node2 [label="xla::device_data\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node3 [label="prim::Constant\nf32[]\nxla_shape=f32[]"]
  node4 [label="aten::expand\nf32[]\nxla_shape=f32[]"]
  node5 [label="aten::expand\nf32[3,3]{1,0}\nxla_shape=f32[3,3]{1,0}"]
  node6 [label="xla::generic_slice\nf32[1,3]{1,0}\nxla_shape=f32[1,3]{1,0}"]
  node7 [label="aten::view\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node8 [label="aten::mul\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node9 [label="xla::device_data\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node10 [label="xla::generic_slice\nf32[1,3]{1,0}\nxla_shape=f32[1,3]{1,0}"]
  node11 [label="aten::view\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node12 [label="aten::mul\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node13 [label="aten::add\nf32[3]{0}\nxla_shape=f32[3]{0}"]
  node14 [label="xla::devic

# Equivalent example in JAX

In [5]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import lax

# Define the step function
def step_fn(carry, x):
    """Scan body: returns `(carry + x, carry * x)` as `(new_carry, y)`.

    The emitted `y` is computed from the carry as it was on entry to the step.
    """
    emitted = carry * x
    advanced = carry + x
    return advanced, emitted

# Initial carry: same values as `init_carry` in the PyTorch cells (a
# 3-element vector).
init_carry = jnp.array([1.0, 1.0, 1.0])

# Input of shape (3, 3): the leading axis is the one scanned over, so each row
# is the `x` for one step. Same values as `xs` in the PyTorch cells.
xs = jnp.array([[1.0, 2.0, 3.0], 
                [4.0, 5.0, 6.0], 
                [7.0, 8.0, 9.0]])

# Use jax.lax.scan to apply the step function. These two results are not read
# again in this cell; the loss below re-runs the scan inside `compute_loss`.
final_carry, ys = lax.scan(step_fn, init_carry, xs)

# Define a function to compute the loss
def compute_loss(init_carry, xs):
    """Scalar loss: the sum of all per-step outputs of scanning `step_fn`.

    The final carry returned by the scan is discarded.
    """
    stacked_outputs = lax.scan(step_fn, init_carry, xs)[1]
    return jnp.sum(stacked_outputs)

# Evaluate the loss and its gradients in a single call. argnums=(0, 1)
# differentiates w.r.t. both arguments, so `grads` is indexed in argument
# order: grads[0] is for init_carry and grads[1] is for xs.
loss_value, grads = value_and_grad(compute_loss, argnums=(0, 1))(init_carry, xs)

# Print the results; these should match the values printed by the PyTorch
# cells above.
print("Loss:", loss_value)
print("Gradients of init_carry:", grads[0])
print("Gradients of xs:", grads[1])


Loss: 249.0
Gradients of init_carry: [12. 15. 18.]
Gradients of xs: [[12. 14. 16.]
 [ 9. 11. 13.]
 [ 6.  8. 10.]]
