
feat: add PyTorch/XLA support #2182

Merged
merged 6 commits into pytorch:master on Apr 11, 2023

Conversation

morgandu
Contributor

Description

This PR adds PyTorch/XLA support to the TorchServe backend base handler.

Type of change

  • New feature (non-breaking change which adds functionality)
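
For orientation, the change comes down to device selection in the base handler: when torch_xla is installed and no CUDA GPU is assigned, inference can run on an XLA device. Below is a minimal sketch of that logic with hypothetical names (XLA_AVAILABLE, pick_device); it is illustrative only and not the PR's actual code.

# Illustrative sketch only; names here are hypothetical, not the PR's identifiers.
import torch

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def pick_device(gpu_id=None):
    # Prefer CUDA when a GPU is assigned to the worker; otherwise fall back
    # to an XLA device (e.g. a TPU core) if torch_xla is importable, else CPU.
    if torch.cuda.is_available() and gpu_id is not None:
        return torch.device("cuda:" + str(gpu_id))
    if XLA_AVAILABLE:
        return xm.xla_device()
    return torch.device("cpu")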

@morgandu marked this pull request as draft on March 16, 2023 at 23:34
@msaroufim (Member) left a comment


Very cool! Left some minor questions on the PR directly.

One question I had was whether this is the right way to use torch/xla nowadays, or whether users are now recommended to pass an XLA backend to torch.compile().

Since most of our CI machines run on AWS, it's unlikely we'll get a TPU available to fully test this, but I'm assuming this should work just fine on GPU as well, in which case a quick test would also be super helpful.

@@ -278,6 +303,9 @@ def inference(self, data, *args, **kwargs):
        with torch.no_grad():
            marshalled_data = data.to(self.device)
            results = self.model(marshalled_data, *args, **kwargs)
            if torch_xla_enabled:
                xm.mark_step()
Member

Not super familiar with XLA internals, but what does this line do?

Contributor Author

Removed the xm.mark_step(), since it is essential for training but optional for inference. In short, the pending computation is executed either upon an xm.mark_step() or when the value is retrieved; in our case it's the latter.
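
For readers unfamiliar with this, a tiny illustration of the lazy-execution behavior described above (assumes torch_xla is installed and an XLA device such as a TPU is available; this is not part of the PR):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(2, 2, device=device)
y = x @ x  # recorded in the lazy graph, not executed yet

# Either of the following forces the pending graph to compile and run:
xm.mark_step()  # explicit step barrier, essential in training loops
print(y.cpu())  # retrieving the value also triggers execution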

@@ -59,6 +59,24 @@ def check_pt2_enabled():
)


def check_torch_xla_enabled() -> bool:
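
For context, a check like this typically just probes whether torch_xla can be imported. A minimal sketch, not necessarily the PR's exact implementation:

def check_torch_xla_enabled() -> bool:
    # Sketch: report whether torch_xla is importable in this environment.
    try:
        import torch_xla  # noqa: F401
        return True
    except ImportError:
        return False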
Member

@lxning this is another good candidate for your new config change; it's possible that a user has XLA installed but doesn't necessarily want to compile the model with XLA.

@lxning (Collaborator) commented Mar 17, 2023

@msaroufim yes, the model YAML config can make this much easier. I'll send the PR early next week to unblock this PR.

@morgandu (Contributor, PR author) commented Mar 25, 2023

> @lxning this is another good candidate for your new config change; it's possible that a user has XLA installed but doesn't necessarily want to compile the model with XLA.

IIUC, the scenario mentioned above applies to GPU. Though, I have torch.cuda.is_available() and properties.get("gpu_id") is not None as the prioritized condition. For accelerator types that require torch_xla, users do have the option to compile with torchxla_trace_once, which is an experimental backend for Dynamo.
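
For illustration, using that experimental Dynamo backend might look roughly like this (a sketch assuming PyTorch/XLA 2.0 with the torchxla_trace_once backend registered; MyModel and example_input are placeholders, and this is not the PR's code):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MyModel().to(device).eval()          # placeholder model

# torchxla_trace_once is the experimental Dynamo backend mentioned above.
compiled_model = torch.compile(model, backend="torchxla_trace_once")

with torch.no_grad():
    output = compiled_model(example_input.to(device))  # placeholder input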

@yyetim commented Mar 17, 2023

torch.compile() is a good point. I'm guessing we'll need both this version to support pytorch <2.0, and another change to support pytorch 2.0 models.

@msaroufim (Member) commented Mar 18, 2023

So we actually already support torch.compile (#1960), and you can pass in a custom backend via a compile.json.

I don't think supporting both workflows is a huge deal, but I'm curious which one you would prefer people use, assuming they have 2.0 installed.

@morgandu (Contributor, PR author)

> torch.compile() is a good point. I'm guessing we'll need both this version to support pytorch <2.0, and another change to support pytorch 2.0 models.

As discussed, we decided to prioritize pytorch/xla 2.0 and above.

@morgandu (Contributor, PR author) commented Mar 25, 2023

> So we actually already support torch.compile (#1960), and you can pass in a custom backend via a compile.json.
>
> I don't think supporting both workflows is a huge deal, but I'm curious which one you would prefer people use, assuming they have 2.0 installed.

Added torchxla_trace_once backend

@morgandu requested a review from msaroufim on March 27, 2023 at 17:37
@morgandu marked this pull request as ready for review on March 27, 2023 at 17:38
codecov bot commented Mar 28, 2023

Codecov Report

Merging #2182 (c3a6b93) into master (c37da18) will increase coverage by 0.10%.
The diff coverage is 86.66%.

❗ Current head c3a6b93 differs from the pull request's most recent head 4537fbb. Consider uploading reports for commit 4537fbb to get more accurate results.

@@            Coverage Diff             @@
##           master    #2182      +/-   ##
==========================================
+ Coverage   71.31%   71.41%   +0.10%     
==========================================
  Files          73       73              
  Lines        3336     3348      +12     
  Branches       57       57              
==========================================
+ Hits         2379     2391      +12     
  Misses        954      954              
  Partials        3        3              
Impacted Files Coverage Δ
ts/torch_handler/base_handler.py 54.97% <85.71%> (+1.97%) ⬆️
ts/utils/util.py 71.79% <100.00%> (+0.36%) ⬆️

... and 1 file with indirect coverage changes


@pytorch deleted a comment from morgandu on Mar 28, 2023
@cloud-tpu-inference-github-bot added and removed the kokoro:run and kokoro:force-run labels (which trigger a kokoro presubmit) on Mar 30, 2023
@cloud-tpu-inference-github-bot added and removed the kokoro:force-run label (which triggers a kokoro presubmit) on Mar 31, 2023
@msaroufim (Member)

The PR looks good, but I was hoping we could have the test you're running checked in, and only run it if a TPU is found.

@morgandu (Contributor, PR author) commented Apr 5, 2023

> The PR looks good, but I was hoping we could have the test you're running checked in, and only run it if a TPU is found.

Added test, PTAL
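
One common way to gate such a test on accelerator availability is a pytest skip marker. A sketch only, not necessarily how the checked-in test is written:

import pytest

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


@pytest.mark.skipif(not XLA_AVAILABLE, reason="torch_xla not installed / no XLA device")
def test_handler_runs_on_xla_device():
    # Smoke test: just confirm an XLA device can be acquired.
    device = xm.xla_device()
    assert str(device).startswith("xla")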

@msaroufim (Member)

LGTM, thank you. As an FYI, we're removing compile.json in the next release, but I'll make the change and test out the kokoro CI directly.

@msaroufim requested a review from lxning on April 6, 2023 at 05:09
@morgandu (Contributor, PR author) commented Apr 6, 2023

Thanks for the heads up!

@cloud-tpu-inference-github-bot added and removed the kokoro:force-run label (which triggers a kokoro presubmit) on Apr 6, 2023
@morgandu (Contributor, PR author)

@lxning, following up on the review request.

@lxning merged commit 4ea172d into pytorch:master on Apr 11, 2023
5 participants