Permutation extended #76563
Conversation
Extended permutation support in integration (see more details on pytorch#1601). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve the permutation propagation rule in aten, with some known limitations at this time.

The idea in this implementation is the same as in our existing code, which is to permute input/output tensors outside of codegen. For a simplified binary op scenario `output = binaryOp(input0, input1)`:

1. In the simple case where `input0` and `input1` come with the same rank & permutation order, the output preserves that permutation;
2. Where `input0` and `input1` come with different ranks but **compatible** permutations, the tensor with the higher rank dictates the permutation of the output;
3. Where `input0` and `input1` come with different ranks and **incompatible** permutations, permutation propagation fails and the output tensor will be contiguous.

By **compatible** permutation, we mean that we can permute the higher-rank tensor to contiguous format and then apply a second permutation to the lower-rank tensor to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)`.

Some concrete examples (note that we comply with eager propagation in cases 1-3, but diverge in behavior for cases 4 and 5; a setup sketch for running these snippets follows the examples):

1. different rank & same permutation

```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1])        # stride (1, wc, c)
out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
```

2. different rank & compatible permutation

```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
```

3. different rank & compatible permutation with broadcasting

```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1)     # stride (1, 1, 1)
out = scripted_add(t0, t1)                                 # stride (hwc, 1, wc, c), preserving memory format of t0
```

4. different rank & incompatible permutation

```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w).cuda()                              # stride (w, 1)
jit_out = scripted_add(t0, t1)
# stride (hwc, 1, wc, c)
# stride (hwc, wc, c, 1)   # nvfuser outputs contiguous tensor
eager_out = eager_add(t0, t1)
# stride (hwc, 1, wc, c)
# stride (hwc, 1, wc, c)   # TI preserves memory format of LHS operand
```

5. different rank & incompatible permutation

```
t0 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
jit_out = scripted_add(t0, t1)
# stride (hwc, 1, wc, c)
# stride (hwc, 1, wc, c)   # nvfuser preserves memory format of the highest rank tensor
eager_out = eager_add(t0, t1)
# stride (hwc, 1, wc, c)
# stride (hwc, hw, w, 1)   # TensorIterator preserves memory format of LHS operand
```
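The snippets above reference `scripted_add` and `eager_add` without defining them. Below is a minimal sketch of a plausible setup; the helper definitions, the shape values, and the warm-up loop are assumptions for illustration, not part of the PR (depending on the build, the nvfuser backend may also need to be enabled explicitly).

```python
# Hypothetical setup for the examples above (assumed, not taken from the PR).
import torch

def eager_add(a, b):
    # eager mode: TensorIterator picks the output memory format
    return a + b

# scripting routes the same function through TorchScript, where the
# nvfuser integration extended by this PR can pick it up
scripted_add = torch.jit.script(eager_add)

b_, h, w, c = 2, 8, 8, 4  # arbitrary sizes, for illustration only
t0 = torch.randn(b_, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(c, h, w).cuda()                            # stride (hw, w, 1)

# warm-up so the profiling executor can hand the graph to the fuser
for _ in range(3):
    jit_out = scripted_add(t0, t1)

# case 2 (compatible permutation): both paths should keep t0's memory format
print(jit_out.stride())
print(eager_add(t0, t1).stride())
```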
Cherry-picked from csarofeen#1614
LGTM
@@ -158,20 +174,110 @@ struct MemoryFormat {
   // storing stride_order in `permuted_order` for a simpler life, so we don't
   // have to decode `permutation_` when we want to apply/restore permutation_.
   permuted_order_ = stride_order;
-  bool has_permutation_ = false;
+  bool has_permutation = false;
   for (const auto i : c10::irange(rank)) {
     permutation_ = permutation_ * 10 + stride_order[i];
What if someone calls `setPermutation` twice? With `permutation_` as a class member this will lead to weird results.
🤕 I'll reset `permutation_` to 0.
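For context on the reviewer's concern, here is a minimal Python sketch of why the accumulator needs that reset; the class and method names are only illustrative stand-ins for the C++ `MemoryFormat::setPermutation` shown in the diff above.

```python
# Illustrative sketch only: mimics the digit-encoding loop from the diff above
# to show why calling the setter twice without a reset gives a corrupted value.
class MemoryFormatSketch:
    def __init__(self):
        self.permutation_ = 0

    def set_permutation(self, stride_order):
        self.permutation_ = 0  # the fix: reset before re-encoding
        for axis in stride_order:
            self.permutation_ = self.permutation_ * 10 + axis

fmt = MemoryFormatSketch()
fmt.set_permutation([0, 3, 1, 2])
print(fmt.permutation_)   # 312 (the leading 0 vanishes in the integer encoding)
fmt.set_permutation([0, 3, 1, 2])
print(fmt.permutation_)   # still 312; without the reset it would be 3120312
```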
Bumping for review 🙇
Bumping up for review again 🙇
@pytorchbot merge this
Hey @jjsjann123. |
Summary: same as the PR description above.
Pull Request resolved: #76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d23619b030444e2a77daab6aaa60988b765ba471
Reviewed By: malfet
Differential Revision: D36101858
Pulled By: malfet
fbshipit-source-id: 17662c68d7f1b448d72b270d6cfa6b8aea463df6