
Torch Compile is slightly slower than eager mode. #98441

Closed
AdamLouly opened this issue Apr 5, 2023 · 5 comments
Labels
oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@AdamLouly
Contributor

AdamLouly commented Apr 5, 2023

🐛 Describe the bug

When running some models with PyTorch, I have noticed that torch.compile mode is slightly slower than eager mode.

It may or may not be related to this issue: #98102

One example is microsoft/deberta-v3-base.

To reproduce:

Go to the folder transformers/examples/pytorch/language-modeling/ and run:

eager mode:
python run_mlm.py --model_name_or_path microsoft/deberta-v3-base --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 1 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --do_train --do_eval --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --max_train_samples 1000

torch.compile:
python run_mlm.py --model_name_or_path microsoft/deberta-v3-base --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --num_train_epochs 1 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --do_train --do_eval --overwrite_output_dir --output_dir ./outputs/ --seed 1137 --fp16 --report_to none --max_train_samples 1000 --torch_compile

Results:

Metric                     Eager         TorchCompile
Avg of 2nd half            72.44162 ms   102.73143 ms
Train loss                 5.995         5.9397
Train runtime              0:03:09.17    0:04:38.75
Train samples              1000          1000
Train samples per second   5.286         3.587
Train steps per second     5.286         3.587
Eval accuracy              0.3637        0.3657
Eval loss                  4.8822        4.8525
Eval runtime               0:00:10.11    0:00:32.71
Eval samples               230           230
Eval samples per second    22.746        7.031
Eval steps per second      22.746        7.031
Perplexity                 131.92        128.0628

Ran on a single Tesla V100 16GB GPU.

Versions

[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+46672772b4 pypi_0 pypi
[conda] torch 2.1.0.dev20230404+cu117 pypi_0 pypi
[conda] torch-ort 1.14.0 pypi_0 pypi
[conda] torchaudio 2.0.0.dev20230313+cu117 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230313+cu117 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

@ezyang
Contributor

ezyang commented Apr 5, 2023

So, on the example from #98102 I noticed your trainer script was including compile time as part of the overall train samples per second. Is this using the same trainer? This won't tell you if the compiled code is faster or not. (Of course, it may not be good ROI to spend the time compiling if you're doing a very short run, but that's a separate question.)

@AdamLouly
Contributor Author

@ezyang "Avg of 2nd half" is the average step latency over the second half of the training steps (the last 500 steps in this case); that's what we use to measure performance.
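
For reference, the metric can be computed from per-step latencies roughly like this (a minimal sketch; the function name and the sample timings are illustrative, not taken from the actual trainer patch):

```python
# Minimal sketch: average the per-step latencies over the second half of
# training, which excludes one-time costs such as torch.compile warm-up.
def avg_second_half(step_times_ms):
    second_half = step_times_ms[len(step_times_ms) // 2:]
    return sum(second_half) / len(second_half)

# E.g. with 1000 steps, this averages the last 500 step times: a slow
# warm-up at the start does not affect the result.
times = [110.0] * 10 + [70.0] * 990
print(avg_second_half(times))  # prints 70.0
```

Because only the steady-state half is averaged, a regression here points at the compiled code itself rather than at compilation overhead.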

@ezyang
Contributor

ezyang commented Apr 5, 2023

OK, noob question, how do I get the "avg of second half" stats? Is

  epoch                    =        1.0
  train_loss               =     1.0188
  train_runtime            = 0:05:26.29
  train_samples            =       1000
  train_samples_per_second =      3.065
  train_steps_per_second   =      3.065

this? Or am I supposed to go do wandb or something

@AdamLouly
Contributor Author

OK, noob question, how do I get the "avg of second half" stats? Is

  epoch                    =        1.0
  train_loss               =     1.0188
  train_runtime            = 0:05:26.29
  train_samples            =       1000
  train_samples_per_second =      3.065
  train_steps_per_second   =      3.065

this? Or am I supposed to go do wandb or something

Replace the transformers trainer.py with this:
https://gist.github.com/AdamLouly/c2d7fe70c9d8d24d24ed2b04e5f903f2

You can either edit the installed package in place, or change the source and build from it.
Let me know if you have any questions.

@ngimel
Collaborator

ngimel commented Apr 5, 2023

DeBERTa is known to have lots of graph breaks; this should hopefully be fixed by #98158.

@desertfire desertfire added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 10, 2023