Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile() support #1960

Merged
merged 29 commits into from
Dec 7, 2022
Merged

torch.compile() support #1960

merged 29 commits into from
Dec 7, 2022

Conversation

msaroufim
Copy link
Member

@msaroufim msaroufim commented Nov 9, 2022

Caveats

  1. Models will take longer to initialize
  2. Changes in batch sizes will trigger a recompilation, so batch size choice needs to be sufficiently small otherwise recompilations will outweigh any benefits
  3. TensorRT still has better performance than Inductor for inference, it's not clear to me yet whether users should do the tensorRT conversion via pytorch/tensorrt or via dynamo

scope of this PR

Just went ahead and added torch.compile support - to make it happen I made a few changes

  1. Added a tutorial
  2. In the base handler check if _dynamo is present and if yes enable pt 2.0
  3. If pt 2.0 is enabled and a user passed in a compile.json file to specify the backend then model compilation is enabled
  4. Enabled some reasonable defaults for mode that reduces overhead - aka cuda graphs
  5. For convenience I also changed install_dependencies.py to include a mode that installs nightly torch

Logs

2022-12-05T23:56:53,168 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Listening on port: /tmp/.ts.sock.9000
2022-12-05T23:56:53,172 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Successfully loaded /home/ubuntu/serve/ts/configs/metrics.yaml.
2022-12-05T23:56:53,172 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - [PID]15049
2022-12-05T23:56:53,172 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Torch worker started.
2022-12-05T23:56:53,173 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Python runtime: 3.8.13
2022-12-05T23:56:53,187 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Connection accepted: /tmp/.ts.sock.9000.
2022-12-05T23:56:53,228 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - model_name: densenet161, batchSize: 1
2022-12-05T23:56:57,592 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Compiled model with backend inductor

@msaroufim msaroufim mentioned this pull request Nov 14, 2022
@msaroufim msaroufim changed the title Experimental torchdynamo support torch.compile support Dec 5, 2022
examples/pt2/README.md Show resolved Hide resolved
ts/torch_handler/base_handler.py Show resolved Hide resolved
ts/torch_handler/base_handler.py Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Dec 6, 2022

Codecov Report

Merging #1960 (2e5c215) into master (c59c6ac) will decrease coverage by 0.44%.
The diff coverage is 20.96%.

@@            Coverage Diff             @@
##           master    #1960      +/-   ##
==========================================
- Coverage   53.80%   53.35%   -0.45%     
==========================================
  Files          70       70              
  Lines        3169     3220      +51     
  Branches       56       56              
==========================================
+ Hits         1705     1718      +13     
- Misses       1464     1502      +38     
Impacted Files Coverage Δ
ts/torch_handler/base_handler.py 0.00% <0.00%> (ø)
ts/utils/util.py 43.66% <56.52%> (+6.16%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

examples/pt2/README.md Outdated Show resolved Hide resolved
ts/utils/util.py Outdated Show resolved Hide resolved
@msaroufim msaroufim added enhancement New feature or request perf Performance issue p0 high priority labels Dec 6, 2022
ts_scripts/install_dependencies.py Outdated Show resolved Hide resolved
examples/pt2/README.md Show resolved Hide resolved
ts/torch_handler/base_handler.py Show resolved Hide resolved
@msaroufim msaroufim requested a review from lxning December 6, 2022 20:17
@msaroufim msaroufim changed the title torch.compile support torch.compile() support Dec 7, 2022
@lxning lxning merged commit e22bce0 into master Dec 7, 2022
@msaroufim msaroufim deleted the experimental_dynamo branch December 7, 2022 00:34
@msaroufim msaroufim mentioned this pull request Mar 18, 2023
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request p0 high priority perf Performance issue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants