-
Notifications
You must be signed in to change notification settings - Fork 559
Publish dynamic shape doc. #5631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
ca8b965
add bounded dynamic shape doc
vanbasten23 4a0caba
got more done
vanbasten23 290b788
Publish dynamic shape doc.
vanbasten23 1f0d849
remove unwanted file
vanbasten23 54a61cd
fix typo
vanbasten23 c0c7dc9
add rfc to the doc
vanbasten23 8acd5b7
fix a typo
vanbasten23 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Dynamic shape | ||
|
|
||
| Dynamic shape refers to the variable nature of a tensor shape where its shape depends on the value of another upstream tensor. For example: | ||
| ``` | ||
| >>> import torch, torch_xla | ||
| >>> in_tensor = torch.randint(low=0, high=2, size=(5,5), device='xla:0') | ||
| >>> out_tensor = torch.nonzero(in_tensor) | ||
| ``` | ||
| the shape of `out_tensor` depends on the value of `in_tensor` and is bounded by the shape of `in_tensor`. In other words, if you do | ||
| ``` | ||
| >>> print(out_tensor.shape) | ||
| torch.Size([<=25, 2]) | ||
| ``` | ||
| you can see the first dimension depends on the value of `in_tensor` and its maximum value is 25. We call the first dimension the dynamic dimension. The second dimension does not depend on any upstream tensors so we call it the static dimension. | ||
|
|
||
| Dynamic shape can be further categorized into bounded dynamic shape and unbounded dynamic shape. | ||
| - bounded dynamic shape: refers to a shape whose dynamic dimensions are bounded by static values. It works for accelerators that require static memory allocation (e.g. TPU). | ||
| - unbounded dynamic shape: refers to a shape whose dynamic dimensions can be infinitely large. It works for accelerators that don’t require static memory allocation (e.g. GPU). | ||
|
|
||
| Today, only the bounded dynamic shape is supported and it is in the experimental phase. | ||
|
|
||
| ## Bounded dynamic shape | ||
|
|
||
| Currently, we support multi-layer perceptron models (MLP) with dynamic size input on TPU. | ||
|
|
||
| This feature is controlled by a flag `XLA_EXPERIMENTAL="nonzero:masked_select"`. To run a model with the feature enabled, you can do: | ||
| ``` | ||
| XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python your_scripts.py | ||
| ``` | ||
|
|
||
| Here are some numbers we get when we run the MLP model for 100 iterations: | ||
|
|
||
| | | No dynamic shape | With dynamic shape | | ||
| | :--- | :----: | ---: | | ||
| | End-to-end training time | 29.49 | 20.03 | | ||
| | Number of compilations | 102 | 49 | | ||
| | Compilation cache hit | 198 | 1953 | | ||
|
|
||
|  | ||
|
|
||
| One of the motivations of the dynamic shape is to reduce the number of excessive recompilation when the shape keeps changing between iterations. From the figure above, you can see the number of compilations reduced by half which results in the drop of the training time. | ||
|
|
||
| To try it out, run | ||
| ``` | ||
| XLA_EXPERIMENTAL="nonzero:masked_select" PJRT_DEVICE=TPU python3 pytorch/xla/test/ds/test_dynamic_shape_models.py TestDynamicShapeModels.test_backward_pass_with_dynamic_input | ||
| ``` | ||
| For more details on how we plan to expand the dynamic shape support on PyTorch/XLA in the future, feel free to review our [RFC](https://github.com/pytorch/xla/issues/3884). | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.