More NT subclass op support for SAM #111253
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111253
Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 659d5d7 with merge base 3eb5cae:

NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
        output = func(*t_args, **t_kwargs)
        return NestedTensor(output, **extract_kwargs(args[0]))
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
```
What's the plan here with torch compile?
(I suspect tests would fail)
Yeah, good question, I need to look into this. For the purposes of getting SAM running without torch.compile this is fine, but it's not great for the immediate next step.
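To make the fallback above concrete, here is a minimal, self-contained sketch of the unwrap → run → rewrap pattern at the `__torch_function__` level, using a toy wrapper subclass. `ToyNested`, `_values`, and the single handled op are illustrative stand-ins, not the real NestedTensor implementation:

```python
import torch

class ToyNested(torch.Tensor):
    @staticmethod
    def __new__(cls, values):
        # Wrapper subclass: metadata lives on the wrapper, data in _values.
        return torch.Tensor._make_wrapper_subclass(
            cls, values.shape, dtype=values.dtype, device=values.device
        )

    def __init__(self, values):
        self._values = values

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Handled op: unwrap to the dense buffer, run the op, rewrap.
        if func is torch.nn.functional.relu:
            out = func(args[0]._values, *args[1:], **kwargs)
            return ToyNested(out)
        # Everything else falls back to default torch.Tensor behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

nt = ToyNested(torch.randn(3, 4))
out = torch.nn.functional.relu(nt)   # routed through __torch_function__
print(type(out).__name__, out._values.shape)
```

The real subclass handles many more ops and carries jagged metadata alongside the buffer; the sketch only shows the dispatch shape of the fallback, which is also the part whose interaction with torch.compile is being asked about above.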
```cpp
// assume NTs are always batched
if (input.is_nested()) {
  return std::make_tuple(input, true);
}
```
Context: `conv_transpose2d` is composite implicit, so I can't override this behavior. It's easiest to add this hack here to avoid messing with shapes for the NT case.
Just to give an alternative: you could override `conv_transpose2d` for the `AutogradNestedTensor` key using `py_impl`. But maybe this is more of a pain, since it would require you to re-implement more of `conv_transpose2d` in Python?
Ah, thanks for the suggestion! Indeed this would be more work, as I'd have to implement more of `conv_transpose2d`, alongside the other 1d/2d/3d transposed / non-transposed variants.
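For reference, a rough sketch of what the `py_impl` alternative discussed above might look like. The registration spelling, the `AutogradNestedTensor` key usage, and the per-component fallback body are assumptions based on the reviewer's suggestion (autograd handling is elided entirely), which is roughly why it would mean re-implementing more of the op in Python:

```python
import torch
from torch._C import DispatchKey

# Assumed registration spelling: attach a Python kernel for the
# AutogradNestedTensor dispatch key to the conv_transpose2d.input overload.
@torch.ops.aten.conv_transpose2d.input.py_impl(DispatchKey.AutogradNestedTensor)
def conv_transpose2d_nt(input, weight, bias=None, stride=1, padding=0,
                        output_padding=0, groups=1, dilation=1):
    # Naive fallback: run the dense kernel per component and re-nest.
    # Real support would also need to handle autograd and the other
    # 1d/2d/3d transposed / non-transposed variants.
    outs = [
        torch.conv_transpose2d(
            t.unsqueeze(0), weight, bias, stride, padding,
            output_padding, groups, dilation,
        ).squeeze(0)
        for t in input.unbind()
    ]
    return torch.nested.as_nested_tensor(outs)
```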
```python
if bias is not None:
    new_values += bias
return NestedTensor(new_values, **extract_kwargs(inp))
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
```
@soulitzer FYI this change fixes linear to expect `weight` in the form of `(out_channels, in_channels)`. Backward and tests have to change to accommodate this as well.
whoops, thanks!
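For reference, a small sketch of the weight convention described above: `weight` is `(out_channels, in_channels)`, matching `torch.nn.functional.linear`, and is applied to the flattened jagged values buffer. The helper name and shapes here are illustrative, not the PR's actual implementation:

```python
import torch

def linear_on_jagged_values(values, weight, bias=None):
    # values: (total_tokens, in_channels) flattened jagged buffer
    # weight: (out_channels, in_channels), as in torch.nn.functional.linear
    out = values @ weight.t()
    if bias is not None:
        out = out + bias
    return out

values = torch.randn(10, 8)   # e.g. several variable-length sequences flattened together
weight = torch.randn(16, 8)   # (out_channels, in_channels)
bias = torch.randn(16)
assert linear_on_jagged_values(values, weight, bias).shape == (10, 16)
```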
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command. Details for Dev Infra team: raised by workflow job.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

@pytorchbot merge -f "ignore spurious failure"

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

@pytorchbot drci (please ignore this, I'm testing Dr.CI)
This is the final part of #110054. The broken trunk classification has been done on Dr.CI, so we can just check for that in trymerge for consistency when ghstack is used.

* [x] #110054
* [x] #110133
* [x] This PR to clean up the broken trunk logic.

One important change is that `get_classifications` doesn't need to query the jobs from Rockset for the head and merge base SHA anymore, saving a query there. The function looks a lot simpler now.

### Testing

#111253 had 1 broken trunk failure as detected by Dr.CI from the base commit https://hud.pytorch.org/pytorch/pytorch/commit/3eb5cae3af1207ac58f77c5ac78669e276824cb9 (valid), while trymerge didn't detect that because the ghstack base commit https://hud.pytorch.org/pytorch/pytorch/commit/be8e51717411e09d8e4343c055848d434964dfb5 didn't have the same failure (miss).

Pull Request resolved: #111520
Approved by: https://github.com/clee2000
Stack from ghstack (oldest at bottom):
With this PR, we have full op support for SAM without needing to unwrap subclass into jagged buffer -> run ops -> rewrap manually. Specifically, this was previously happening in the MaskDecoder.
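Illustrative sketch of the manual workflow this removes. It uses the strided `torch.nested` constructors just to have something runnable; the PR itself targets the jagged-layout Python subclass, and the op/helper choice is only an example:

```python
import torch

nt = torch.nested.nested_tensor([torch.randn(3, 8), torch.randn(5, 8)])

# Before: unwrap into dense components, run the op, rewrap manually
# (this is the kind of boilerplate that previously lived in the MaskDecoder).
outs = [torch.nn.functional.gelu(t) for t in nt.unbind()]
manual = torch.nested.as_nested_tensor(outs)

# After: with op support on the subclass, call the op on it directly.
direct = torch.nn.functional.gelu(nt)
```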
cc @cpuhrsch @bhosmer @drisspg @soulitzer @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10