
training support for dynamo+torchxla integration #88449

Closed

Conversation

shunting314
Contributor

@shunting314 shunting314 commented Nov 3, 2022

We've already shown some promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for both training and inference, this PR tries to enable training for dynamo/torchxla.

Training is trickier than inference, and we may not expect much perf gain since:

  1. In the training case, torchxla only generates a single combined graph for fwd/bwd/optimizer, while with the torchxla_trace_once bridge we added in dynamo, due to how AOTAutograd works, we generate 3 graphs: one for the forward pass, one for the backward pass, and one for the optimizer (see the sketch after this list). XLA favors larger graphs because they give it more room to optimize.
  2. In the training case, tracing overhead can be overlapped with computation, so tracing overhead is not as big a deal for training as it is for inference. After all, training cares more about throughput while inference cares more about latency.
  3. In the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the tracing overhead, so the tracing overhead 'per example' shrinks.
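
As a rough illustration of point 1 (this uses the public functorch aot_function API purely as an example of the forward/backward split; it is not the dynamo/torchxla bridge code, and the printing "compiler" is a stand-in): AOTAutograd hands the forward and backward graphs to separate compiler callbacks, and the optimizer step is traced as yet another graph, so the XLA backend sees several smaller graphs instead of one combined one.

```python
import torch
from functorch.compile import aot_function

def make_printer(name):
    # A trivial "compiler": print the FX graph it receives and run it as-is.
    def compiler(fx_module, example_inputs):
        print(f"--- {name} graph ---")
        print(fx_module.code)
        return fx_module
    return compiler

def f(x, w):
    return (x @ w).relu().sum()

compiled = aot_function(f, fw_compiler=make_printer("forward"),
                        bw_compiler=make_printer("backward"))

x = torch.randn(4, 4)
w = torch.randn(4, 4, requires_grad=True)
compiled(x, w).backward()  # prints two separate graphs: forward and backward
```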

But we still want to add training support to dynamo/torchxla to make the work complete.

We added an '--iterations-per-run' argument to control how many iterations we run per measurement/device sync. This is to understand the impact of item 2 above.
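
A minimal sketch of what the flag controls (this is not the actual benchmark harness; model_iter_fn is a hypothetical helper that runs one fwd/bwd/optimizer step, and only the standard torch_xla sync calls are assumed): we run N iterations per device sync/measurement, so the tracing of one iteration can overlap with the device execution of the previous one.

```python
import time
import torch_xla.core.xla_model as xm

def measure(model_iter_fn, model, inputs, iterations_per_run=1):
    start = time.perf_counter()
    for _ in range(iterations_per_run):
        model_iter_fn(model, inputs)  # one fwd/bwd/optimizer step (hypothetical)
    xm.mark_step()        # flush pending lazy work to the device
    xm.wait_device_ops()  # block until the device finishes, then stop the clock
    return (time.perf_counter() - start) / iterations_per_run
```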

Results:

With '--iterations-per-run' set to 1, here are the perf numbers:

+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |             0.91   |                0.959    |
+-------------------------+--------------------+-------------------------+
| resnet50                |             0.917  |                0.932    |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |             0.912  |                0.905    |
+-------------------------+--------------------+-------------------------+
| alexnet                 |             1.038  |                0.974    |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |             0.881  |                0.835    |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |             0.903  |                0.931    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |             0.914  |                0.967    |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |             1.359  |                0.84     |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |             1.288  |                0.893    |
+-------------------------+--------------------+-------------------------+
| geomean                 |             1.0006 |                0.913794 |
+-------------------------+--------------------+-------------------------+

Overall it looks like graph breaks indeed cause a perf loss, but for BERT_pytorch and timm_vision_transformer we still see a perf gain. We need to do more experiments with a larger '--iterations-per-run'.

NOTE:
In torchbench.py I added the following code as a few workarounds:

from myscripts import workaround # TODO will remove this line before landing

Here is the content of workaround.py:

import os

import torch

# Monkey-patch: replace torch.nn.MaxPool2d with torch.nn.AvgPool2d when the
# REPLACE_MAXPOOL env var is set, to work around the MaxPool2d training issue
# on XLA devices (see item 1 below).
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d

It works around a few issues we found:

  1. MaxPool2d does not work for training in dynamo/torchxla: Dynamo can not optimize a model with MaxPool2d on XLA devices (torchdynamo#1837). WIP fixes from Brian in codegen fixes to fix tracing XLA autograd ops #90226 and https://github.com/pytorch/xla/pull/4276/files.
  2. A recent change in op decomposition (Rectify native_batch_norm schema by splitting it into two legit schemas #88697) causes batch_norm ops to fall back in torchxla. Fix from Jack in Lower _native_batch_norm_legit xla#4282 (comment). Confirmed the fix after adding a Deduper to handle duplicated returns from the fx graph generated by AOTAutograd (see the sketch after this list).
  3. We have an issue handling dropout because the random seed goes out of sync. Here is the fix: api to grab base seed as device data xla#4293 (confirmed the fix).
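
As a rough sketch of the Deduper idea mentioned in item 2 (the names and structure below are made up for illustration; this is not the actual bridge code): when the fx graph produced by AOTAutograd returns the same tensor multiple times, keep only the first occurrence and remember how to expand the deduplicated results back to the original arity.

```python
def dedup_outputs(outputs):
    # Map each distinct output object to its position in the deduped list.
    seen = {}
    deduped, index_map = [], []
    for out in outputs:
        key = id(out)
        if key not in seen:
            seen[key] = len(deduped)
            deduped.append(out)
        index_map.append(seen[key])
    return deduped, index_map

def expand_results(deduped_results, index_map):
    # Rebuild the original output arity from the deduplicated results.
    return [deduped_results[i] for i in index_map]
```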

Example command:

REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

@pytorch-bot

pytorch-bot bot commented Nov 3, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88449

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 82b8827:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@JackCaoG
Collaborator

JackCaoG commented Nov 5, 2022

I think the DeviceData that is missing the info is the random seed. The random seed does not have a tensor associated with it, hence it will never have a tensor_id or info. I guess we can work around this special case in the bridge.
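
A minimal sketch of the kind of special-casing described here (the try/except fallback and the None placeholder are assumptions, not the actual bridge code): when collecting tensor ids for graph inputs, tolerate entries such as the seed DeviceData that have no backing XLA tensor instead of failing.

```python
import torch_xla

def collect_tensor_ids(xla_args):
    ids = []
    for xla_arg in xla_args:
        try:
            ids.append(torch_xla._XLAC._xla_get_tensor_id(xla_arg))
        except RuntimeError:
            # e.g. the random-seed DeviceData, which has no tensor_id or info
            ids.append(None)
    return ids
```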

@shunting314
Contributor Author

shunting314 commented Nov 9, 2022

After some discussion, we found it is too much of a stretch to do the following two things at the same time:

  1. Instead of moving data across devices in the dynamo compiler function, do that once at the top level before calling the model.
  2. Enable AOTAutograd for training.

There are quite a few weird issues we need to resolve. Here are some of them

We feel it would be a good strategy to separate the concerns and handle these two items separately. We'll do item 2 first and follow up on item 1 after that. I'll put this PR on hold and work on a new PR that specifically handles item 2.

cc @JackCaoG @wconstab @jansel

@shunting314
Contributor Author

shunting314 commented Nov 14, 2022

@JackCaoG there is one more issue, about FakeTensor. AOTAutograd uses FakeTensors during compilation. It looks like FakeTensor on XLA has some issues, and I get the following error stack:

Traceback (most recent call last):
  File "/pytorch/torch/_dynamo/optimizations/backends.py", line 53, in inner
    return fn(model, **kwargs)
  File "/pytorch/torch/_dynamo/optimizations/backends.py", line 800, in torchxla_trace_once
    return integration.extract_compiled_graph(model, example_inputs)
  File "/pytorch/torch/_dynamo/optimizations/torchxla_integration.py", line 172, in extract_compiled_graph
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
  File "/pytorch/torch/_dynamo/optimizations/torchxla_integration.py", line 172, in <listcomp>
    torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
RuntimeError: /pytorch/xla/torch_xla/csrc/aten_xla_bridge.cpp:73 : Check failed: xtensor != nullptr

I can unblock myself for now by forcing AOTAutograd to use eager tensors, but this is something we should eventually look into (FakeTensor is mandated if we want to support dynamic shapes).

Here is the command to reproduce:

USE_FAKE_TENSOR=1 override_model=linear GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --only resnet18 --backend=aot_torchxla_trace_once

cc @Chillee as we discussed a bit about the FakeTensor usage in AOTAutograd.

@shunting314 shunting314 force-pushed the dynamo_torchxla_training branch 2 times, most recently from 21c191e to 8036941 on November 16, 2022 00:25
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2022
In #87741 we added inference support for the dynamo/torchxla integration. Later on, in #88449, we attempted to add training support. That attempt was not smooth because:
- we tried 2 things together:
   1. let dynamo trace the model on XLA rather than eager
   2. enable training
- it turns out neither of these two tasks is trivial.

Furthermore, item 2 (enabling training) depends on item 1 (tracing on XLA). We enable training via AOTAutograd, and AOTAutograd lifts all model parameters/buffers as graph inputs. Without item 1 being done, we would need to copy all graph inputs (including model parameters/buffers) from the eager device to XLA devices, which hurts performance a lot. Having a cache that maps eager parameters to XLA parameters does not solve the problem, since an update to either side will not sync automatically to the other; they will easily go out of sync.

This PR lets dynamo trace the model on XLA rather than eager. This is a preparation step toward enabling training.
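
A rough sketch of what "tracing on XLA rather than eager" means on the caller side (plain torch_xla device placement, not the bridge internals): the model and example inputs are placed on the XLA device up front, so the parameters/buffers that AOTAutograd lifts as graph inputs are already XLA tensors and nothing needs to be copied per step.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)             # params/buffers live on XLA
example_inputs = [torch.randn(4, 8, device=device)]  # inputs start on XLA too
```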

Also, tracing on XLA makes the data movement more efficient: we see a 1.5x geomean speedup compared to the previous 1.38x.
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |            1.38    |                 1.008   |
+-------------------------+--------------------+-------------------------+
| resnet50                |            1.227   |                 0.998   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |            1.544   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| alexnet                 |            1.085   |                 1.045   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |            2.028   |                 1.013   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |            1.516   |                 0.995   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           |            0.868   |                 1.01    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |            1.099   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |            3.26    |                 1.027   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |            2.182   |                 1.015   |
+-------------------------+--------------------+-------------------------+
| geomean                 |            1.50389 |                 1.01261 |
+-------------------------+--------------------+-------------------------+
```

Example command
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once
```

Pull Request resolved: #88904
Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel
@shunting314
Contributor Author

shunting314 commented Dec 3, 2022

The code is cleaned up. The 10 models we tested for inference can mostly be run now (one fails because of a fallback). But I just found an issue with Dropout which causes incorrect results: #90097 . cc @JackCaoG

@shunting314
Contributor Author

To work around the dropout issue, I force torch.nn.Dropout to be a NopModule (a no-op). Now the correctness checks for:
resnet18
resnet50
resnext50_32x4d
BERT_pytorch
timm_vision_transformer
mnasnet1_0
mobilenet_v2

all pass. But we still fail the correctness check for vgg16 and alexnet.

@shunting314
Contributor Author

shunting314 commented Dec 6, 2022

@JackCaoG and I figured out why we fail the correctness check for vgg16 and alexnet. It's because of the different sets of optimizations XLA applies to different graphs.

Here are the things we did to verify this:

  1. In the timed method we print the sum of the prediction tensor before and after the mark_step call as follows:
            # result[0] is the prediction
            print(f"result sum {result[0].sum()}")
            xm.mark_step()
            print(f"result after ms sum {result[0].sum()}")

and found that XLA gives different results:

result sum -0.019893646240234375
result after ms sum -0.019596099853515625

mark_step causes a larger graph to be compiled and optimized more aggressively. The optimizations XLA applies may cause numerical instability.

  2. Considering that the baseline compiles a single large graph, while the 'trace_once' optimization compiles 1 graph for the forward pass, 1 graph for the backward pass, and 1 graph for the optimizer, we can manually add a mark_step in forward_and_backward_pass right after the forward pass (see the sketch after this list). This way, we force the baseline to do similar graph breaks as trace_once. We verified that both vgg16 and alexnet then pass the correctness check.

  3. Jack also suggested using XLA_GET_TENSORS_OPBYOP=1 XLA_SYNC_TENSORS_OPBYOP=1 to disable all XLA optimizations. But unfortunately trace_once fails because the compiled graph is not found. Maybe the flags cause some change in the graph hash computation.
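
A hedged sketch of experiment 2 (the forward_and_backward_pass body below, including the dummy loss, is illustrative rather than the actual torchbench helper): insert an extra mark_step right after the forward pass so the baseline breaks its graph at roughly the same point as trace_once.

```python
import torch
import torch_xla.core.xla_model as xm

def forward_and_backward_pass(model, example_inputs, optimizer):
    optimizer.zero_grad(set_to_none=True)
    pred = model(*example_inputs)
    xm.mark_step()     # extra graph break after forward, mirroring trace_once
    loss = pred.sum()  # dummy loss, just for this illustration
    loss.backward()
    optimizer.step()
    xm.mark_step()
    return pred
```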

But I think 1 and 2 are strong enough to convince me that the root cause is the difference in graph size.

Jack suggested using 1e-3 as the tolerance for the correctness check.
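
For concreteness, a relaxed comparison along the lines Jack suggested might look like the sketch below (the helper name is made up; the real benchmark harness has its own accuracy check):

```python
import torch

def check_correctness(actual, expected, tol=1e-3):
    # Compare trace_once results against the baseline with a looser tolerance,
    # since XLA applies different optimizations to differently sized graphs.
    return torch.allclose(actual, expected, rtol=tol, atol=tol)
```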

@shunting314
Contributor Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

This PR needs to be approved by an authorized maintainer before merge.

Contributor

@qihqi qihqi left a comment


stamp

@shunting314
Contributor Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

This PR needs to be approved by an authorized maintainer before merge.

@malfet
Contributor

malfet commented Jan 5, 2023

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

This PR needs to be approved by an authorized maintainer before merge.

@malfet
Contributor

malfet commented Jan 5, 2023

@pytorchbot help

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'help' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@malfet
Contributor

malfet commented Jan 5, 2023

@pytorchbot help

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'help' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@shunting314
Contributor Author

Trying to dismiss the review from Jason to merge the PR, since Jason is on PTO.

@shunting314 shunting314 dismissed jansel’s stale review January 5, 2023 19:50

To merge the PR, since there is some issue with merging it.

@shunting314
Contributor Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

This PR needs to be approved by an authorized maintainer before merge.

@huydhn huydhn removed the request for review from jansel January 5, 2023 19:52
@malfet malfet added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 5, 2023
@huydhn
Contributor

huydhn commented Jan 5, 2023

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Jan 5, 2023

This PR needs to be approved by an authorized maintainer before merge.

@malfet
Contributor

malfet commented Jan 5, 2023

@pytorchbot merge -f "Apologies for the inconvenience"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Jan 25, 2023
This is a follow-up to the previous PR (#88449), moving the dynamo/TorchXLA bridge from the pytorch repo to the xla repo.

Overall the dynamo/TorchXLA integration has the following four layers of code:
- pybind layer: the bottom layer containing various pybind APIs as the foundation. This part resides in the xla repo.
- bridge layer: built upon the pybind layer to implement the trace-once functionality. This layer and its corresponding unit tests were previously in the pytorch repo. This PR (and the corresponding xla PR pytorch/xla#4476) moves them to the xla repo.
- dynamo backend registration: a thin layer that registers 4 dynamo backends (training/inference/trace_once/trace_everytime); see the sketch after this list. It remains in the pytorch repo.
- benchmark script: the torchbench.py script in dynamo is adapted so it can be used for the dynamo/TorchXLA integration. It remains in the pytorch repo.
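
As a small illustration of how thin the registration layer can be (this is a generic dynamo backend sketch, not the actual torchxla registration code; the bridge call is a placeholder): a dynamo backend is just a callable that takes an fx.GraphModule plus example inputs and returns a callable.

```python
import torch
import torch._dynamo as dynamo

def my_torchxla_backend(gm, example_inputs):
    # In the real integration this would hand the graph to the TorchXLA bridge;
    # here we simply run the captured graph eagerly.
    return gm.forward

@dynamo.optimize(my_torchxla_backend)
def fn(x):
    return torch.relu(x) + 1

print(fn(torch.randn(4)))
```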

We think the new code organization is cleaner.

I'll wait for the xla PR to land first before trying to merge this one.

Tests
1. run the unit tests moved to the xla repo
2. Test for inference:  `GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only resnet18`
3. Test for training: `GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only resnet18 --collect-outputs`

Pull Request resolved: #92601
Approved by: https://github.com/wconstab
@github-actions github-actions bot deleted the dynamo_torchxla_training branch July 6, 2024 01:52