diff --git a/docs/source/features/scan.md b/docs/source/features/scan.md
new file mode 100644
index 000000000000..d3959e6cae5d
--- /dev/null
+++ b/docs/source/features/scan.md
@@ -0,0 +1,202 @@
+# Guide for using `scan` and `scan_layers`
+
+This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.
+
+## When should you use this
+
+You should consider using [`scan_layers`][scan_layers] if you have a model with
+many homogeneous (same shape, same logic) layers, for example, LLMs. These
+models can be slow to compile. `scan_layers` is a drop-in replacement for a for
+loop over homogeneous layers, such as a stack of decoder layers. `scan_layers`
+traces the first layer and reuses the compiled result for all subsequent
+layers, significantly reducing the model compile time.
+
+[`scan`][scan], on the other hand, is a lower-level higher-order op modeled
+after [`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
+`scan_layers` under the hood. However, you may find it useful if you would like
+to express loop logic where the loop itself has a first-class representation in
+the compiler (specifically, an XLA `While` op).
+
+## `scan_layers` example
+
+Typically, a transformer model passes the input embedding through a sequence of
+homogeneous decoder layers like the following:
+
+```python
+def run_decoder_layers(self, hidden_states):
+  for decoder_layer in self.layers:
+    hidden_states = decoder_layer(hidden_states)
+  return hidden_states
+```
+
+When this function is lowered into an HLO graph, the for loop is unrolled into a
+flat list of operations, resulting in long compile times. To reduce compile
+times, you can replace the for loop with a call to `scan_layers`, as shown in
+[`decoder_with_scan.py`][decoder_with_scan]:
+
+```python
+def run_decoder_layers(self, hidden_states):
+  from torch_xla.experimental.scan_layers import scan_layers
+  return scan_layers(self.layers, hidden_states)
+```
+
+You can train this decoder model by running the following command from the root
+directory of a `pytorch/xla` source checkout:
+
+```sh
+python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
+```
+
+## `scan` example
+
+[`scan`][scan] takes a combine function and applies that function over the
+leading dimension of tensors while carrying along state:
+
+```python
+def scan(
+    fn: Callable[[Carry, X], tuple[Carry, Y]],
+    init: Carry,
+    xs: X,
+) -> tuple[Carry, Y]:
+  ...
+```
+
+You can use it to loop over the leading dimension of tensors efficiently. If
+`xs` is a single tensor, this function is roughly equivalent to the following
+Python code:
+
+```python
+def scan(fn, init, xs):
+  ys = []
+  carry = init
+  for i in range(xs.size(0)):
+    carry, y = fn(carry, xs[i])
+    ys.append(y)
+  return carry, torch.stack(ys, dim=0)
+```
+
+Under the hood, `scan` is implemented much more efficiently by lowering the loop
+into an XLA `While` operation. This ensures that only one iteration of the loop
+is compiled by XLA.
+
+[`scan_examples.py`][scan_examples] contains some example code showing how to
+use `scan`. In that file, `scan_example_cumsum` uses `scan` to implement a
+cumulative sum. `scan_example_pytree` demonstrates how to pass PyTrees to
+`scan`.
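+
+For a quick standalone illustration, here is a minimal cumulative-sum sketch in
+the same spirit (a condensed sketch rather than a verbatim copy of the example
+file; it assumes an XLA device is available):
+
+```python
+import torch
+import torch_xla
+from torch_xla.experimental.scan import scan
+
+device = torch_xla.device()
+
+def step(carry, x):
+  # Add the current element to the running sum. Return the new sum twice:
+  # once as the carry for the next step, once as this step's output.
+  new_carry = carry + x
+  return new_carry, new_carry
+
+init = torch.zeros(1, device=device)  # the running sum starts at zero
+xs = torch.tensor([[1.0], [2.0], [3.0]], device=device)  # scanned over dim 0
+final, history = scan(step, init, xs)
+torch_xla.sync()
+print(final)    # tensor([6.], device='xla:0')
+print(history)  # the per-step sums, stacked along a new leading dimension
+```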
+
+You can run the examples with:
+
+```sh
+python3 examples/scan/scan_examples.py
+```
+
+The output should look something like the following:
+
+```
+Running example: scan_example_cumsum
+Final sum: tensor([6.], device='xla:0')
+History of sums: tensor([[1.],
+        [3.],
+        [6.]], device='xla:0')
+
+
+Running example: scan_example_pytree
+Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
+Means over time: tensor([[1.0000],
+        [1.5000],
+        [2.0000],
+        [2.5000],
+        [3.0000]], device='xla:0')
+```
+
+## Limitations
+
+### AOTAutograd compatibility requirement
+
+The functions/modules passed to `scan` and `scan_layers` must be AOTAutograd
+traceable. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
+trace functions with custom Pallas kernels. That means that if your decoder
+uses, for example, flash attention, it is incompatible with `scan`. We are
+working on [supporting this important use case][flash-attn-issue] in nightly
+builds and upcoming releases.
+
+### AOTAutograd overhead
+
+Because `scan` uses AOTAutograd to figure out the backward pass of the input
+function/module on every iteration, it is easy to become tracing-bound compared
+to a for loop implementation. In fact, the `train_decoder_only_base.py` example
+runs slower under `scan` than with a for loop as of PyTorch/XLA 2.6 due to this
+overhead. We are working on [improving tracing speed][retracing-issue]. This is
+less of a problem when your model is very large or has many layers, which are
+exactly the situations where you would want to use `scan` anyway.
+
+## Compile time experiments
+
+To demonstrate the compile time savings, we will train a simple decoder with
+many layers on a single TPU chip, first with a for loop and then with
+`scan_layers`.
+
+- Run the for loop implementation:
+
+```sh
+❯ python3 examples/train_decoder_only_base.py \
+    --hidden-size 256 \
+    --num-layers 50 \
+    --num-attention-heads 4 \
+    --num-key-value-heads 2 \
+    --intermediate-size 2048 \
+    --num-steps 5 \
+    --print-metrics
+
+...
+
+Metric: CompileTime
+  TotalSamples: 3
+  Accumulator: 02m57s694ms418.595us
+  ValueRate: 02s112ms586.097us / second
+  Rate: 0.054285 / second
+  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
+    99%=01m03s028ms571.841us
+```
+
+- Run the `scan_layers` implementation:
+
+```sh
+❯ python3 examples/train_decoder_only_base.py \
+    scan.decoder_with_scan.DecoderWithScan \
+    --hidden-size 256 \
+    --num-layers 50 \
+    --num-attention-heads 4 \
+    --num-key-value-heads 2 \
+    --intermediate-size 2048 \
+    --num-steps 5 \
+    --print-metrics
+
+...
+
+Metric: CompileTime
+  TotalSamples: 3
+  Accumulator: 29s996ms941.409us
+  ValueRate: 02s529ms591.388us / second
+  Rate: 0.158152 / second
+  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
+    99%=18s995ms301.667us
+```
+
+We can see that the maximum compile time dropped from `1m03s` to `19s` by
+switching to `scan_layers`. The `CompileTime` metric can also be read
+programmatically; see the short sketch at the end of this guide.
+
+## References
+
+See https://github.com/pytorch/xla/issues/7253 for the design of `scan` and
+`scan_layers` themselves.
+
+See the function doc comments of [`scan`][scan] and [`scan_layers`][scan_layers]
+for details on how to use them.
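+
+## Inspecting compile time programmatically
+
+The `CompileTime` metric from the experiments above can also be read from your
+own script. Below is a minimal, hedged sketch using the metrics API in
+`torch_xla.debug.metrics`; the exact report contents vary between releases:
+
+```python
+import torch_xla.debug.metrics as met
+
+# ... run some training steps first, so that compilations have been recorded ...
+
+# Print the full human-readable metrics report, including CompileTime.
+print(met.metrics_report())
+
+# Or fetch the raw data recorded for a single metric.
+print(met.metric_data('CompileTime'))
+```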
+
+
+
+[scan]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py
+[scan_layers]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py
+[flash-attn-issue]: https://github.com/pytorch/xla/issues/8633
+[retracing-issue]: https://github.com/pytorch/xla/issues/8632
+[jax-lax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
+[decoder_with_scan]: /examples/scan/decoder_with_scan.py
+[scan_examples]: /examples/scan/scan_examples.py
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7b03724cc31a..0faa31c75bc6 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,7 @@ PyTorch/XLA is a Python package that uses the XLA deep learning compiler to conn
    features/pallas.md
    features/stablehlo.md
    features/triton.md
+   features/scan.md
 
 .. toctree::
    :glob:
diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py
index 79040e5d24d2..77cd120966a4 100644
--- a/examples/decoder_only_model.py
+++ b/examples/decoder_only_model.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Optional
 import math
 
 import torch
@@ -201,7 +200,7 @@ def forward(
 class DecoderOnlyModel(nn.Module):
 
   def __init__(self, config: DecoderOnlyConfig):
-    super(DecoderOnlyModel, self).__init__()
+    super().__init__()
     self.vocab_size = config.vocab_size
     self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
     self.layers = nn.ModuleList(
@@ -211,7 +210,7 @@ def __init__(self, config: DecoderOnlyConfig):
 
   def forward(
       self,
-      input_ids: torch.LongTensor = None,
+      input_ids: torch.LongTensor,
   ) -> torch.Tensor:
     inputs_embeds = self.embed_tokens(input_ids)
 
@@ -219,10 +218,14 @@ def forward(
     hidden_states = inputs_embeds
 
     # decoder layers
-    for idx, decoder_layer in enumerate(self.layers):
-      layer_outputs = decoder_layer(hidden_states,)
-      hidden_states = layer_outputs
+    hidden_states = self.run_decoder_layers(hidden_states)
 
     hidden_states = self.norm(hidden_states)
+    # [B, S, H] -> [B, S, V]
     return self.output(hidden_states)
+
+  def run_decoder_layers(self, hidden_states):
+    for decoder_layer in self.layers:
+      hidden_states = decoder_layer(hidden_states)
+    return hidden_states
diff --git a/examples/scan/README.md b/examples/scan/README.md
new file mode 120000
index 000000000000..78548adcd700
--- /dev/null
+++ b/examples/scan/README.md
@@ -0,0 +1 @@
+../../docs/source/features/scan.md
\ No newline at end of file
diff --git a/examples/scan/decoder_with_scan.py b/examples/scan/decoder_with_scan.py
new file mode 100644
index 000000000000..8e938430340e
--- /dev/null
+++ b/examples/scan/decoder_with_scan.py
@@ -0,0 +1,13 @@
+from typing_extensions import override
+from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel
+
+
+class DecoderWithScan(DecoderOnlyModel):
+
+  def __init__(self, config: DecoderOnlyConfig):
+    super().__init__(config)
+
+  @override
+  def run_decoder_layers(self, hidden_states):
+    from torch_xla.experimental.scan_layers import scan_layers
+    return scan_layers(self.layers, hidden_states)
diff --git a/examples/scan/scan_examples.py b/examples/scan/scan_examples.py
new file mode 100644
index 000000000000..5a4097d029ee
--- /dev/null
+++ b/examples/scan/scan_examples.py
@@ -0,0 +1,85 @@
+import torch
+import torch_xla
+
+from torch_xla.experimental.scan import scan
+
+
+def scan_example_cumsum():
+  """
+  This example uses the `scan` function to compute the cumulative sum of a tensor.
+ """ + + # 1) Define a combine function that takes in the accumulated sum and the next element, + # and returns the new accumulated sum. We return two values, one is the "carry" that + # will be passed to the next iteration of this function call, and the other is the + # "output" that will be stacked into the final result. + def cumsum(accumulated, element): + accumulated += element + return accumulated, accumulated + + # 2) Define an initial carry and the input tensor. + init_sum = torch.tensor([0.0], device=torch_xla.device()) + xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device()) + torch_xla.sync() + + # 3) Call `scan` with our combine function, initial carry, and input tensor. + final, result = scan(cumsum, init_sum, xs) + torch_xla.sync() + + print("Final sum:", final) + print("History of sums", result) + + +def scan_example_pytree(): + """ + This example uses the `scan` function to compute a running mean. + + It demonstrates using PyTrees as inputs and outputs, in particular, dictionaries. + """ + # 1) Define an initial carry as a dictionary with two leaves: + # - 'sum' to accumulate the sum of all seen values + # - 'count' to count how many values have been seen + carry = { + 'sum': torch.tensor([0.0], device=torch_xla.device()), + 'count': torch.tensor([0.0], device=torch_xla.device()) + } + + # 2) Define our input PyTree, which in this case is just a dictionary with one leaf: + # - 'values' is a 1D tensor representing data points we want to scan over. + xs = { + 'values': + torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device()) + } + + # Here, xs['values'] has shape [5]. The `scan` function will automatically slice + # out one element (shape []) each iteration. + + # 3) Define our function (akin to a "step" function in jax.lax.scan). It: + # - takes in the current carry and the current slice of xs, + # - updates the sum/count in the carry, + # - computes a new output (the running mean), + # - returns the updated carry and that output. + def fn(carry_dict, x_dict): + new_sum = carry_dict['sum'] + x_dict['values'] + new_count = carry_dict['count'] + 1.0 + new_carry = {'sum': new_sum, 'count': new_count} + running_mean = new_sum / new_count + return new_carry, running_mean + + # 4) Call `scan` with our step function, initial carry, and input dictionary. + final_carry, means_over_time = scan(fn, carry, xs) + + # 5) `final_carry` contains the final sum/count, while `means_over_time` is + # a 1D tensor with the running mean at each step. 
+ print("Final carry:", final_carry) + print("Means over time:", means_over_time) + + +if __name__ == "__main__": + for example in [ + scan_example_cumsum, + scan_example_pytree, + ]: + print(f"\nRunning example: {example.__name__}", flush=True) + example() + print(flush=True) diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py index a55d8e399892..256d6c8fed40 100644 --- a/examples/train_decoder_only_base.py +++ b/examples/train_decoder_only_base.py @@ -5,25 +5,28 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl +import argparse import time import itertools import torch import torch_xla -import torch.optim as optim import torch.nn as nn -class TrainDecoderOnlyBase(): +class TrainDecoderOnlyBase: - def __init__(self): - self.config = DecoderOnlyConfig() + def __init__(self, + decoder_cls=DecoderOnlyModel, + num_steps: int = 200, + config=DecoderOnlyConfig()): + self.config = config if xr.device_type() == 'NEURON': self.batch_size = 4 else: self.batch_size = 16 self.seq_len = 512 - self.num_steps = 200 + self.num_steps = num_steps self.num_epochs = 1 self.train_dataset_len = 1200000 # Roughly the size of Imagenet dataset. # For the purpose of this example, we are going to use fake data. @@ -34,7 +37,7 @@ def __init__(self): self.device = torch_xla.device() self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device) - self.model = DecoderOnlyModel(self.config).to(self.device) + self.model = decoder_cls(self.config).to(self.device) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001) self.loss_fn = nn.CrossEntropyLoss() # Compile the step fn @@ -43,6 +46,7 @@ def __init__(self): def _train_update(self, step, loss, tracker, epoch): print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}') + assert not torch.isnan(loss).item(), "Loss became NaN!" def run_optimizer(self): self.optimizer.step() @@ -79,5 +83,80 @@ def start_training(self): if __name__ == '__main__': - base = TrainDecoderOnlyBase() + parser = argparse.ArgumentParser("Train a decoder only model") + parser.add_argument( + "cls_name", + type=str, + nargs="?", + default=None, + help="The decoder model to train, as fully qualified Python class. 
+      Defaults to decoder_only_model.DecoderOnlyModel")
+  parser.add_argument(
+      "--num-steps",
+      type=int,
+      default=200,
+      help="Number of steps to train the model for")
+  parser.add_argument(
+      "--hidden-size",
+      type=int,
+      default=1024,
+      help="Hidden size of the model, aka the embedding size")
+  parser.add_argument(
+      "--num-layers",
+      type=int,
+      default=2,
+      help="Number of decoder layers in the model",
+  )
+  parser.add_argument(
+      "--num-attention-heads",
+      type=int,
+      default=8,
+      help="Number of attention heads in the model",
+  )
+  parser.add_argument(
+      "--num-key-value-heads",
+      type=int,
+      default=4,
+      help="Number of key-value heads in the model",
+  )
+  parser.add_argument(
+      "--intermediate-size",
+      type=int,
+      default=32 * 1024,
+      help="Intermediate size of the model, aka the up-projection output size",
+  )
+  parser.add_argument(
+      "--print-metrics",
+      action="store_true",
+      help="Print torch_xla metrics at the end of training",
+  )
+  args = parser.parse_args()
+
+  # Seed the RNG for deterministic results
+  torch.manual_seed(42)
+  torch_xla.manual_seed(42)
+
+  # Figure out the decoder model to use
+  decoder_cls = None
+  if args.cls_name is not None:
+    xm.master_print(f'Using decoder class: {args.cls_name}')
+    module, cls_name = args.cls_name.rsplit('.', 1)
+    decoder_cls = getattr(__import__(module, fromlist=[cls_name]), cls_name)
+
+  # Initialize config
+  config = DecoderOnlyConfig(
+      hidden_size=args.hidden_size,
+      num_hidden_layers=args.num_layers,
+      num_attention_heads=args.num_attention_heads,
+      num_key_value_heads=args.num_key_value_heads,
+      intermediate_size=args.intermediate_size,
+  )
+
+  params = []
+  if decoder_cls is not None:
+    params.append(decoder_cls)
+  base = TrainDecoderOnlyBase(*params, num_steps=args.num_steps, config=config)
   base.start_training()
+
+  if args.print_metrics:
+    print(torch_xla._XLAC._xla_metrics_report())
diff --git a/test/run_tests.sh b/test/run_tests.sh
index ba87a3ce3653..874f80a26746 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -243,6 +243,9 @@ function run_xla_op_tests3 {
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
   run_test "$CDIR/test_pallas.py"
   run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"
+
+  # Test examples
+  run_test "$CDIR/../examples/scan/scan_examples.py"
 
   # CUDA tests
   if [ -x "$(command -v nvidia-smi)" ]; then
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 76959214300a..5c20c683fe72 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -52,6 +52,9 @@ python3 "$TEST_CDIR/test_data_type.py"
 python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
 python3 "$TEST_CDIR/../examples/fsdp/train_decoder_only_fsdp_v2.py"
 python3 "$TEST_CDIR/../examples/train_resnet_amp.py"
+python3 "$TEST_CDIR/../examples/train_decoder_only_base.py"
+python3 "$TEST_CDIR/../examples/train_decoder_only_base.py" scan.decoder_with_scan.DecoderWithScan \
+  --num-steps 30 # TODO(https://github.com/pytorch/xla/issues/8632): Reduce scan tracing overhead
 
 # HACK: don't confuse local `torch_xla` folder with installed package
 # Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559