202 changes: 202 additions & 0 deletions docs/source/features/scan.md
@@ -0,0 +1,202 @@
# Guide for using `scan` and `scan_layers`

This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.

## When should you use this

You should consider using [`scan_layers`][scan_layers] if you have a model with
many homogeneous (same shape, same logic) layers, for example LLMs. These models
can be slow to compile. `scan_layers` is a drop-in replacement for a for loop over
homogeneous layers, such as a stack of decoder layers. `scan_layers` traces the
first layer and reuses the compiled result for all subsequent layers, significantly
reducing the model compile time.

[`scan`][scan], on the other hand, is a lower-level higher-order op modeled after
[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
`scan_layers` under the hood. However, you may find it useful if you would like
to program some sort of loop logic where the loop itself has a first-class
representation in the compiler (specifically, an XLA `While` op).

## `scan_layers` example

Typically, a transformer model passes the input embedding through a sequence of
homogeneous decoder layers like the following:

```python
def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states
```

When this function is lowered into an HLO graph, the for loop is unrolled into a
flat list of operations, resulting in long compile times. To reduce compile
times, you can replace the for loop with a call to `scan_layers`, as shown in
[`decoder_with_scan.py`][decoder_with_scan]:

```python
def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)
```

You can train this decoder model by running the following command from the root
directory of a `pytorch/xla` source checkout.

```sh
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
```
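
Outside of the decoder example, `scan_layers` works on any sequence of
structurally identical modules. Below is a minimal, self-contained sketch (our
own illustration with made-up layer sizes, not code from the repository):

```python
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.scan_layers import scan_layers

device = torch_xla.device()

# All layers must be homogeneous: same structure and same parameter shapes.
layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(16)]).to(device)
x = torch.randn(4, 128, device=device)

# Equivalent to applying each layer in order, but only the first layer is traced.
out = scan_layers(layers, x)
torch_xla.sync()
print(out.shape)  # torch.Size([4, 128])
```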

## `scan` example

[`scan`][scan] takes a combine function and applies that function over the leading
dimension of tensors while carrying along state:

```python
def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...
```

You can use it to loop over the leading dimension of tensors efficiently. If `xs`
is a single tensor, this function is roughly equivalent to the following Python code:

```python
def scan(fn, init, xs):
  ys = []
  carry = init
  for i in range(xs.size(0)):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  return carry, torch.stack(ys, dim=0)
```

Under the hood, `scan` is implemented much more efficiently by lowering the loop
into an XLA `While` operation. This ensures that only one iteration of the loop
is compiled by XLA.
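
For instance, here is a minimal sketch (our own illustration, not one of the
bundled examples) that uses `scan` to compute an exponential moving average, a
recurrence that would otherwise unroll as a Python loop:

```python
import torch
import torch_xla
from torch_xla.experimental.scan import scan

def ema_step(carry, x):
  # Blend the running average with the next element; 0.9/0.1 are
  # arbitrary smoothing weights chosen for illustration.
  new_carry = 0.9 * carry + 0.1 * x
  return new_carry, new_carry

init = torch.zeros(1, device=torch_xla.device())
xs = torch.randn(10, 1, device=torch_xla.device())

# The loop over the 10 slices becomes a single XLA `While` op.
final_avg, history = scan(ema_step, init, xs)
torch_xla.sync()
```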

[`scan_examples.py`][scan_examples] contains some example code showing how to use
`scan`. In that file, `scan_example_cumsum` uses `scan` to implement a cumulative
sum. `scan_example_pytree` demonstrates how to pass PyTrees to `scan`.

You can run the examples with:

```sh
python3 examples/scan/scan_examples.py
```

The output should look something like the following:

```
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')
```

## Limitations

### AOTAutograd compatibility requirement

The functions/modules passed to `scan` and `scan_layers` must be traceable by
AOTAutograd. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
trace functions with custom Pallas kernels. That means if your decoder uses, for
example, flash attention, then it is incompatible with `scan`. We are working on
[supporting this important use case][flash-attn-issue] in nightly builds and in
upcoming releases.
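
If you are unsure whether a function is AOTAutograd traceable, one quick sanity
check is to run AOTAutograd over it in isolation, independent of `scan`. The
sketch below assumes the `functorch.compile.aot_function` API; treat it as an
illustration rather than an official compatibility test:

```python
import torch
from functorch.compile import aot_function

def layer_fn(x):
  # Stand-in for your layer's forward logic; a custom Pallas kernel
  # in here is what would typically break the trace.
  return torch.sin(x) * 2.0

# aot_function traces the forward and backward graphs ahead of time,
# which is the same step `scan` relies on.
noop_compiler = lambda gm, example_inputs: gm
traced = aot_function(layer_fn, fw_compiler=noop_compiler, bw_compiler=noop_compiler)

y = traced(torch.randn(4, requires_grad=True))
y.sum().backward()  # if tracing succeeded, gradients flow as usual
```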

### AOTAutograd overhead

Because `scan` uses AOTAutograd to figure out the backward pass of the input
function/module on every iteration, it is easy to become tracing-bound compared to
a for loop implementation. In fact, the `train_decoder_only_base.py` example runs
slower under `scan` than with a for loop as of PyTorch/XLA 2.6 due to this overhead.
We are working on [improving tracing speed][retracing-issue]. This is less of a
problem when your model is very large or has many layers, which are the situations
where you would want to use `scan` anyway.

## Compile time experiments

To demonstrate the compile time savings, we'll train a simple decoder with many
layers on a single TPU chip, comparing a for loop against `scan_layers`.

- Run the for loop implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 02m57s694ms418.595us
  ValueRate: 02s112ms586.097us / second
  Rate: 0.054285 / second
  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us; 99%=01m03s028ms571.841us
```

- Run the `scan_layers` implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
scan.decoder_with_scan.DecoderWithScan \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 29s996ms941.409us
  ValueRate: 02s529ms591.388us / second
  Rate: 0.158152 / second
  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us; 99%=18s995ms301.667us
```

We can see that the maximum compile time dropped from `1m03s` to `19s` by
switching to `scan_layers`.
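
To inspect `CompileTime` in your own training script, you can print the same
metrics programmatically (a short sketch using the `torch_xla.debug.metrics`
module):

```python
import torch_xla.debug.metrics as met

# ... run a few training steps ...

# Prints counters and metrics, including the CompileTime percentiles
# shown above.
print(met.metrics_report())
```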

## References

See https://github.com/pytorch/xla/issues/7253 for the design of `scan` and
`scan_layers`.

See the function doc comments of [`scan`][scan] and [`scan_layers`][scan_layers]
for details on how to use them.

<!-- xrefs -->

[scan]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py
[scan_layers]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py
[flash-attn-issue]: https://github.com/pytorch/xla/issues/8633
[retracing-issue]: https://github.com/pytorch/xla/issues/8632
[jax-lax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
[decoder_with_scan]: /examples/scan/decoder_with_scan.py
[scan_examples]: /examples/scan/scan_examples.py
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -40,6 +40,7 @@ PyTorch/XLA is a Python package that uses the XLA deep learning compiler to conn
    features/pallas.md
    features/stablehlo.md
    features/triton.md
+   features/scan.md

.. toctree::
:glob:
15 changes: 9 additions & 6 deletions examples/decoder_only_model.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Optional
 import math
 
 import torch
@@ -201,7 +200,7 @@ def forward(
 class DecoderOnlyModel(nn.Module):
 
   def __init__(self, config: DecoderOnlyConfig):
-    super(DecoderOnlyModel, self).__init__()
+    super().__init__()
     self.vocab_size = config.vocab_size
     self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
     self.layers = nn.ModuleList(
@@ -211,18 +210,22 @@ def __init__(self, config: DecoderOnlyConfig):
 
   def forward(
       self,
-      input_ids: torch.LongTensor = None,
+      input_ids: torch.LongTensor,
   ) -> torch.Tensor:
     inputs_embeds = self.embed_tokens(input_ids)
 
     # embed positions
     hidden_states = inputs_embeds
 
     # decoder layers
-    for idx, decoder_layer in enumerate(self.layers):
-      layer_outputs = decoder_layer(hidden_states,)
-      hidden_states = layer_outputs
+    hidden_states = self.run_decoder_layers(hidden_states)
 
     hidden_states = self.norm(hidden_states)
 
     # [B, S, H] -> [B, S, V]
     return self.output(hidden_states)
 
+  def run_decoder_layers(self, hidden_states):
+    for decoder_layer in self.layers:
+      hidden_states = decoder_layer(hidden_states)
+    return hidden_states
1 change: 1 addition & 0 deletions examples/scan/README.md
13 changes: 13 additions & 0 deletions examples/scan/decoder_with_scan.py
@@ -0,0 +1,13 @@
from typing_extensions import override
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel


class DecoderWithScan(DecoderOnlyModel):

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__(config)

  @override
  def run_decoder_layers(self, hidden_states):
    from torch_xla.experimental.scan_layers import scan_layers
    return scan_layers(self.layers, hidden_states)
85 changes: 85 additions & 0 deletions examples/scan/scan_examples.py
@@ -0,0 +1,85 @@
import torch
import torch_xla

from torch_xla.experimental.scan import scan


def scan_example_cumsum():
  """
  This example uses the `scan` function to compute the cumulative sum of a tensor.
  """

  # 1) Define a combine function that takes in the accumulated sum and the next element,
  #    and returns the new accumulated sum. We return two values, one is the "carry" that
  #    will be passed to the next iteration of this function call, and the other is the
  #    "output" that will be stacked into the final result.
  def cumsum(accumulated, element):
    accumulated += element
    return accumulated, accumulated

  # 2) Define an initial carry and the input tensor.
  init_sum = torch.tensor([0.0], device=torch_xla.device())
  xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device())
  torch_xla.sync()

  # 3) Call `scan` with our combine function, initial carry, and input tensor.
  final, result = scan(cumsum, init_sum, xs)
  torch_xla.sync()

  print("Final sum:", final)
  print("History of sums", result)


def scan_example_pytree():
  """
  This example uses the `scan` function to compute a running mean.

  It demonstrates using PyTrees as inputs and outputs, in particular, dictionaries.
  """
  # 1) Define an initial carry as a dictionary with two leaves:
  #    - 'sum' to accumulate the sum of all seen values
  #    - 'count' to count how many values have been seen
  carry = {
      'sum': torch.tensor([0.0], device=torch_xla.device()),
      'count': torch.tensor([0.0], device=torch_xla.device())
  }

  # 2) Define our input PyTree, which in this case is just a dictionary with one leaf:
  #    - 'values' is a 1D tensor representing data points we want to scan over.
  xs = {
      'values':
          torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device())
  }

  # Here, xs['values'] has shape [5]. The `scan` function will automatically slice
  # out one element (shape []) each iteration.

  # 3) Define our function (akin to a "step" function in jax.lax.scan). It:
  #    - takes in the current carry and the current slice of xs,
  #    - updates the sum/count in the carry,
  #    - computes a new output (the running mean),
  #    - returns the updated carry and that output.
  def fn(carry_dict, x_dict):
    new_sum = carry_dict['sum'] + x_dict['values']
    new_count = carry_dict['count'] + 1.0
    new_carry = {'sum': new_sum, 'count': new_count}
    running_mean = new_sum / new_count
    return new_carry, running_mean

  # 4) Call `scan` with our step function, initial carry, and input dictionary.
  final_carry, means_over_time = scan(fn, carry, xs)

  # 5) `final_carry` contains the final sum/count, while `means_over_time` is
  #    a 1D tensor with the running mean at each step.
  print("Final carry:", final_carry)
  print("Means over time:", means_over_time)


if __name__ == "__main__":
  for example in [
      scan_example_cumsum,
      scan_example_pytree,
  ]:
    print(f"\nRunning example: {example.__name__}", flush=True)
    example()
    print(flush=True)