Conversation

jbschlosser (Contributor) commented Oct 13, 2023

Stack from ghstack (oldest at bottom):

With this PR, we have full op support for SAM without needing to manually unwrap the subclass into its jagged buffer -> run ops -> rewrap. Specifically, this was previously happening in the MaskDecoder.
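
For illustration, a minimal sketch of the pattern this removes (hypothetical code, assuming a recent PyTorch with the jagged layout; `gelu` stands in for the ops the MaskDecoder uses):

```python
import torch
import torch.nn.functional as F

# A jagged NestedTensor with components of shape (3, 8) and (5, 8).
nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)], layout=torch.jagged
)

# Before: unwrap to the jagged values buffer, run the op, rewrap manually.
values = nt.values()        # dense buffer holding both components
out_values = F.gelu(values)
# ...manually rebuild a NestedTensor from out_values...

# After this PR: the op dispatches directly on the subclass.
out = F.gelu(nt)
```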

cc @cpuhrsch @bhosmer @drisspg @soulitzer @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

pytorch-bot bot commented Oct 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111253

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 659d5d7 with merge base 3eb5cae:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jbschlosser added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 585b4a7
Pull Request resolved: #111253
jbschlosser added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: 7111393
Pull Request resolved: #111253
    output = func(*t_args, **t_kwargs)
    return NestedTensor(output, **extract_kwargs(args[0]))
with torch._C.DisableTorchFunctionSubclass():
    return func(*args, **kwargs)
Contributor

What's the plan here with torch compile?

Contributor

(I suspect tests would fail)

Contributor Author

Yeah, good question; I need to look into this. For the purposes of getting SAM running without torch.compile, this is fine, but not great for the immediate next step.
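
For readers following along, a minimal self-contained sketch of the `__torch_function__` pattern under discussion (a hypothetical `WrapperTensor`, not the PR's exact implementation):

```python
import torch


class WrapperTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.functional.relu:
            # Unwrap, run the op on the plain tensor, rewrap the output.
            inner = args[0].as_subclass(torch.Tensor)
            return func(inner, *args[1:], **kwargs).as_subclass(cls)
        # Fallback: behave like a plain tensor for everything else. This is
        # fine in eager mode; torch.compile needs dedicated subclass support.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)


w = torch.randn(3).as_subclass(WrapperTensor)
out = torch.nn.functional.relu(w)  # routed through __torch_function__
assert isinstance(out, WrapperTensor)
```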

jbschlosser added a commit that referenced this pull request Oct 16, 2023
ghstack-source-id: 25d262e
Pull Request resolved: #111253
jbschlosser added a commit that referenced this pull request Oct 17, 2023
ghstack-source-id: f680f61
Pull Request resolved: #111253
Comment on lines +745 to +748
// assume NTs are always batched
if (input.is_nested()) {
  return std::make_tuple(input, true);
}
Contributor Author

Context: conv_transpose2d is CompositeImplicitAutograd, so I can't override this behavior. It's easiest to add this hack here to avoid messing with shapes for the NT case.

Contributor

Just to give an alternative: you could override conv_transpose2d for the AutogradNestedTensor key using py_impl. But maybe this is more of a pain, since it would require you to re-implement more of conv_transpose2d in Python?

Contributor Author

Ah, thanks for the suggestion! Indeed this would be more work, as I'd have to implement more of conv_transpose2d, alongside the other 1d/2d/3d transposed / non-transposed variants.
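
For reference, a rough sketch of this alternative (untested; py_impl and the overload name are internal APIs, and the signature here is an assumption):

```python
import torch
from torch._C import DispatchKey

# Hypothetical: register a Python implementation of conv_transpose2d for the
# AutogradNestedTensor dispatch key instead of the C++ special case. The
# shape/batching logic would need to be re-implemented here, for each conv
# variant -- which is why the C++ hack was preferred.
@torch.ops.aten.conv_transpose2d.input.py_impl(DispatchKey.AutogradNestedTensor)
def conv_transpose2d_nested(input, weight, bias=None, stride=1, padding=0,
                            output_padding=0, groups=1, dilation=1):
    raise NotImplementedError("re-implement conv_transpose2d for NTs here")
```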

if bias is not None:
    new_values += bias
return NestedTensor(new_values, **extract_kwargs(inp))
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
Contributor Author

@soulitzer FYI this change fixes linear to expect weight in the form of (out_channels, in_channels). Backward and tests have to change to accommodate this as well.
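
Concretely, a small sketch of the convention after this change (hypothetical shapes, assuming the jagged layout):

```python
import torch
import torch.nn.functional as F

nt = torch.nested.nested_tensor(
    [torch.randn(3, 16), torch.randn(5, 16)], layout=torch.jagged
)
weight = torch.randn(32, 16)  # (out_channels, in_channels), dense convention
bias = torch.randn(32)

out = F.linear(nt, weight, bias)  # each component: (seq_len, 16) -> (seq_len, 32)
```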

Contributor

whoops, thanks!


jbschlosser added the `topic: improvements`, `release notes: nested tensor`, and `module: nestedtensor` labels and removed the `module: cpu` label Oct 18, 2023
jbschlosser (Contributor Author)

@pytorchbot merge

pytorch-bot bot added the `ciflow/trunk` label Oct 18, 2023
pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

With this PR, we have full op support for SAM without needing to unwrap subclass into jagged buffer -> run ops -> rewrap manually. Specifically, this was previously happening in the MaskDecoder.

cc cpuhrsch bhosmer drisspg soulitzer jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
jbschlosser added a commit that referenced this pull request Oct 18, 2023
ghstack-source-id: 003bd81
Pull Request resolved: #111253
pytorchmergebot (Collaborator)

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team: raised by workflow job.

github-actions bot added the `module: cpu` label Oct 18, 2023
jbschlosser (Contributor Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

jbschlosser (Contributor Author)

@pytorchbot merge -f "ignore spurious failure"

pytorchmergebot (Collaborator)

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

huydhn (Contributor) commented Oct 18, 2023

@pytorchbot drci

(please ignore this, I'm testing Dr.CI)

pytorchmergebot pushed a commit that referenced this pull request Oct 19, 2023
This is the final part of #110054.  The broken trunk classification has been done on Dr.CI, so we can just check for that in trymerge for consistency when ghstack is used.

* [x] #110054
* [x] #110133
* [x] This PR to clean up the broken trunk logic.

One important change is that `get_classifications` doesn't need to query the jobs from Rockset for the head and merge base SHA anymore, saving a query there.  The function looks a lot simpler now.

### Testing

#111253 had 1 broken trunk failure as detected by Dr.CI from the base commit https://hud.pytorch.org/pytorch/pytorch/commit/3eb5cae3af1207ac58f77c5ac78669e276824cb9 (valid), while trymerge didn't detect it because the ghstack base commit https://hud.pytorch.org/pytorch/pytorch/commit/be8e51717411e09d8e4343c055848d434964dfb5 didn't have the same failure (miss).

Pull Request resolved: #111520
Approved by: https://github.com/clee2000
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
facebook-github-bot deleted the gh/jbschlosser/94/head branch November 18, 2023 15:27
Labels

`ciflow/trunk` · Merged · `module: cpu` · `module: nestedtensor` · `release notes: nested tensor` · `topic: improvements`

6 participants