diff --git a/experimental/torch_xla2/docs/support_a_new_model.md b/experimental/torch_xla2/docs/support_a_new_model.md
new file mode 100644
index 000000000000..07a2be41e748
--- /dev/null
+++ b/experimental/torch_xla2/docs/support_a_new_model.md
@@ -0,0 +1,157 @@
+# Run a model under torch_xla2
+
+Supporting a new model in torch_xla2 means getting
+that model to run using torch_xla2 and produce correct results.
+
+Running a model usually consists of executing a list of torch ops
+on a set of tensors (i.e. the parameters and inputs) to
+produce new tensor(s). These ops should just work.
+
+However, there are cases where a model doesn't run on
+torch_xla2, because:
+
+1. Some op it needs is not implemented.
+2. Some op it needs is implemented incorrectly.
+3. Some non-torch-op code interacts with torch_xla2 in an unfriendly manner.
+
+Here we present a few steps for diagnosing and fixing such issues.
+
+# Step 1. Attempt to run the model
+
+To run a model under torch_xla2, the first step is to
+instantiate the model and run it under normal torch.
+This usually means eager-mode torch on CPU. (NOTE: for large
+models, it's recommended to build a model of the same architecture
+but smaller, by using fewer layers / smaller dim sizes; OR, use a GPU
+so that it runs reasonably fast.)
+
+In this example, we will use the `BERT_pytorch` model from
+torchbench.
+
+## Install torchbench and instantiate the model
+
+```bash
+git clone https://github.com/pytorch/benchmark.git torchbench
+cd torchbench
+pip install torchvision torchaudio
+pip install -e .
+```
+
+Now that torchbench is installed, we need to download
+the model:
+
+```bash
+python install.py BERT_pytorch
+```
+
+NOTE: if you run `python install.py` without positional args,
+it will download ALL 100+ models, which can take some time.
+
+Now, let's verify that the model is there by importing it in Python.
+Note that `Model(...)` returns a benchmark object; the model itself
+and its sample inputs are obtained with `get_module()`:
+
+```python
+import torchbenchmark.models.BERT_pytorch
+
+benchmark = torchbenchmark.models.BERT_pytorch.Model(
+    test='eval', device='cpu'
+)
+model, sample_inputs = benchmark.get_module()
+
+print(model(*sample_inputs))
+```
+
+If the above succeeds, then the model is ready.
+
+# Step 2. Attempt to run the model in torch_xla2
+
+To run the model in torch_xla2, we need to do 2 things:
+
+1. Move the model's weights to the XLA device (i.e. to XLA tensors).
+2. Move the sample_inputs to the XLA device (i.e. to XLA tensors).
+
+The API for the above is the `to_xla` method on the `Environment` class.
+To get the current environment, one can use `torch_xla2.default_env()`.
+
+i.e.
+
+```python
+xla_env = torch_xla2.default_env()
+model2 = xla_env.to_xla(model)
+sample_inputs = xla_env.to_xla(sample_inputs)
+with xla_env:
+  print(model2(*sample_inputs))
+```
+
+If the issue is with operators, you might get something like this:
+
+```bash
+Traceback (most recent call last):
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py", line 13, in <module>
+    benchmark = benchmark_cls(test="eval", device = "cpu") # test = train or eval device = cuda or cpu
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/util/model.py", line 39, in __call__
+    obj = type.__call__(cls, *args, **kwargs)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/__init__.py", line 174, in __init__
+    bert = BERT(
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/bert.py", line 30, in __init__
+    self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/bert.py", line 24, in __init__
+    self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/token.py", line 6, in __init__
+    super().__init__(vocab_size, embed_size, padding_idx=0)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 145, in __init__
+    self.reset_parameters()
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 154, in reset_parameters
+    init.normal_(self.weight)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 172, in normal_
+    return torch.overrides.handle_torch_function(
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function
+    result = mode.__torch_function__(public_api, types, args, kwargs)
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 210, in __torch_function__
+    return func(*args, **(kwargs or {}))
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 175, in normal_
+    return _no_grad_normal_(tensor, mean, std, generator)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 20, in _no_grad_normal_
+    return tensor.normal_(mean, std, generator=generator)
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 224, in __torch_dispatch__
+    return self.env.dispatch(func, types, args, kwargs)
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 297, in dispatch
+    raise OperatorNotFound(
+torch_xla2.tensor.OperatorNotFound: Operator with name aten::normal_ has no lowering
+```
+
+Sometimes it's helpful to see how this operator got called.
+Note that, many times, the very fact that an operator is called
+can be unexpected.
+
+We can turn on logging with
+`xla_env.config.debug_print_each_op`, which will print each operator
+as it is run.
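+
+For example, here is a minimal sketch (`model2` and `sample_inputs`
+are the XLA-moved objects from the step above; the per-op trace is
+emitted via the standard `logging` module at DEBUG level, so the root
+logger must be configured to show DEBUG messages):
+
+```python
+import logging
+
+import torch_xla2
+
+# The per-op trace is logged at DEBUG level.
+logging.basicConfig(level=logging.DEBUG)
+
+xla_env = torch_xla2.default_env()
+xla_env.config.debug_print_each_op = True
+
+with xla_env:
+  print(model2(*sample_inputs))
+```
+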
+The logs look like this:
+
+```
+2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::view
+2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::gelu
+2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::view
+2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::t
+2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: transpose
+2024-06-16 15:03:13,729 - root - DEBUG - DISPATCH: aten::transpose.int
+2024-06-16 15:03:13,730 - root - DEBUG - FUNCTION: permute
+2024-06-16 15:03:13,730 - root - DEBUG - DISPATCH: aten::permute
+2024-06-16 15:03:13,731 - root - DEBUG - FUNCTION: aten::addmm
+2024-06-16 15:03:13,737 - root - DEBUG - FUNCTION: aten::view
+2024-06-16 15:03:13,739 - root - DEBUG - FUNCTION: aten::add.Tensor
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::slice.Tensor
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::select.int
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::t
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: transpose
+2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::transpose.int
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: permute
+2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::permute
+2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::addmm
+2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::_log_softmax
+2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::view
+2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::t
+2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: transpose
+2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::transpose.int
+2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: permute
+2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::permute
+2024-06-16 15:03:13,764 - root - DEBUG - FUNCTION: aten::addmm
+2024-06-16 15:03:13,788 - root - DEBUG - FUNCTION: aten::view
+2024-06-16 15:03:13,790 - root - DEBUG - FUNCTION: aten::_log_softmax
+```
+
+Each `FUNCTION` line is an op intercepted at the `__torch_function__`
+layer, and each `DISPATCH` line is an op intercepted at the
+`__torch_dispatch__` layer; this tells you both which op failed and
+which ops ran before it.
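+
+If the logs (or the traceback above) point at a missing operator, the
+fix is to implement a lowering for it alongside the existing ones in
+`torch_xla2/ops/jaten.py`. The sketch below is illustrative only: the
+`op` registration decorator (and importing it from `jaten`) is an
+assumption modeled on the existing lowerings in that file, and a real
+`aten::normal_` lowering would need proper PRNG-key plumbing rather
+than a fixed seed.
+
+```python
+# Hypothetical sketch of a lowering for the missing aten::normal_ op.
+# A lowering implements the aten op's semantics in terms of
+# jax.numpy / jax.random, operating on jax arrays.
+import jax
+import torch
+
+from torch_xla2.ops.jaten import op  # assumed name/location; check jaten.py
+
+
+@op(torch.ops.aten.normal_)
+def _aten_normal_(self, mean=0.0, std=1.0, generator=None):
+  # normal_ fills the tensor with samples drawn from N(mean, std**2).
+  key = jax.random.PRNGKey(0)  # a real lowering must not hard-code the seed
+  return jax.random.normal(key, self.shape, dtype=self.dtype) * std + mean
+```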
diff --git a/experimental/torch_xla2/examples/eager_mode.py b/experimental/torch_xla2/examples/eager_mode.py
index 755f24b0d2b9..a824001a4844 100644
--- a/experimental/torch_xla2/examples/eager_mode.py
+++ b/experimental/torch_xla2/examples/eager_mode.py
@@ -38,4 +38,11 @@ def model_func(param, inputs):
 
 print(model_func(m.state_dict(), inputs))
 
+print('---=====')
+with xla_env:
+  m2 = MyModel()
+  inputs = (torch.randn(3, 3, 28, 28),)
+  print(m2(*inputs))
+
+
diff --git a/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py b/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py
new file mode 100644
index 000000000000..fc0b4653d6c6
--- /dev/null
+++ b/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py
@@ -0,0 +1,49 @@
+import torch
+import time
+import torch_xla2
+import torch_xla2.interop
+import os
+import importlib
+import sys
+import logging
+
+# Send DEBUG-level logs (including the per-op trace) to stdout.
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+# NOTE: replace the path below with your torchbench installation,
+# i.e. the directory where you cloned https://github.com/pytorch/benchmark.
+TORCH_BENCH_PATH = os.path.expanduser('~/git/qihqi/benchmark')
+sys.path.append(TORCH_BENCH_PATH)
+model_name = "torchbenchmark.models.BERT_pytorch"  # replace with the model you're working on
+module = importlib.import_module(model_name)
+benchmark_cls = getattr(module, "Model", None)
+benchmark = benchmark_cls(test="eval", device="cpu")  # test: train or eval; device: cuda or cpu
+
+model, example = benchmark.get_module()
+
+env = torch_xla2.default_env()
+env.config.debug_print_each_op = False
+model = env.to_xla(model)
+example = env.to_xla(example)
+with env:
+  start = time.perf_counter()
+  print(model(*example))
+  end = time.perf_counter()
+  print('Eager mode time', end - start)
+
+
+def func_call(state, example):
+  return torch.func.functional_call(model, state, example, tie_weights=False)
+
+jitted = torch_xla2.interop.jax_jit(func_call)
+start = time.perf_counter()
+print(jitted(model.state_dict(), example))
+end = time.perf_counter()
+print('Jitted mode time', end - start)
\ No newline at end of file
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 1a0baf07af33..e7c4ad99954e 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -2,7 +2,6 @@
 import contextlib
 from typing import Optional
 import jax
-from jax import dlpack as jaxdl
 import jax.numpy as jnp
 import numpy
 import torch
@@ -10,7 +9,6 @@
 import torch.utils._mode_utils as mode_utils
 import torch.utils._python_dispatch as torch_dispatch
 import torch.utils._pytree as torch_pytree
-import torch.utils.dlpack as torchdl
 from torch_xla2 import config
 from torch_xla2.ops import mappings
 
@@ -46,6 +44,16 @@ def j2t_dtype(dtype):
   return mappings.j2t_dtype(dtype)
 
 
+@contextlib.contextmanager
+def log_nested(message):
+  # Log the op name indented by the current nesting depth.
+  logging.debug((' ' * log_nested.level) + message)
+  log_nested.level += 1
+  try:
+    yield
+  finally:
+    # Restore depth even if the wrapped op raises (e.g. OperatorNotFound).
+    log_nested.level -= 1
+
+log_nested.level = 0
+
+
 class XLATensor2(torch.Tensor):
 
   @staticmethod
@@ -208,10 +216,11 @@ def __torch_function__(self, types, args=(), kwargs=None) -> torch.Tensor:
-    try:
-      return self.env.dispatch(func, types, args, kwargs)
-    except OperatorNotFound:
-      return func(*args, **(kwargs or {}))
+    with log_nested(f'FUNCTION: {_name_of_func(func)}'):
+      try:
+        return self.env.dispatch(func, types, args, kwargs)
+      except OperatorNotFound:
+        return func(*args, **(kwargs or {}))
 
 
 class XLADispatchMode(torch_dispatch.TorchDispatchMode):
@@ -220,13 +229,13 @@ def __init__(self, env):
     self.env = env
 
   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    self.env.maybe_log(f'__torch_dispatch__: {_name_of_func(func)}')
-    if isinstance(func, torch._ops.OpOverloadPacket):
-      with self:
+    with log_nested(f'DISPATCH: {_name_of_func(func)}'):
+      if isinstance(func, torch._ops.OpOverloadPacket):
+        with self:
+          return func(*args, **kwargs)
+      if func.namespace != 'aten':
         return func(*args, **kwargs)
-    if func.namespace != 'aten':
-      return func(*args, **kwargs)
-    return self.env.dispatch(func, types, args, kwargs)
+      return self.env.dispatch(func, types, args, kwargs)
 
 def _name_of_func(func):
   if hasattr(func, 'name'):
@@ -357,7 +366,3 @@ def j2t_iso(self, jaxarray):
 
   def j2t_copy(self, args):
     pass
-
-  def maybe_log(self, log):
-    if self.config.debug_print_each_op:
-      logging.info(log)