Implement narrow from a regular tensor to jagged tensor #112770
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112770
Note: Links to docs will display an error until the doc builds have been completed.
✅ No Failures as of commit 03a94ee with merge base d4c810c. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: ec6d15238467d273092f99f675298e4aa8c42daa Pull Request resolved: #112770
torch/nested/__init__.py
Outdated
def narrow(tensor: Tensor, dim: int, start: Union[SymInt, Tensor], length: Union[SymInt, Tensor], layout=torch.jagged):
I don't think we should add a layout argument here. It doesn't follow the narrow syntax for regular tensors. If the user wants to change the layout, they can use .to(). Since this is a view operation, it shouldn't affect performance. Furthermore, a layout change is likely not supported when this is view-only.
> I don't think we should add a layout argument here. It doesn't follow the narrow syntax for regular tensors.

This is true, but we're also not limited to the original narrow() syntax, since this is a new function in the torch.nested namespace. So I sort of disagree. It's possible to create a view over a dense buffer with both the expanded jagged layout and the strided layout, and I think the layout arg is a natural extension for indicating which type of NT + metadata is to be constructed.
Also note that nested narrow for jagged is only supported on dim=1 due to representational restrictions, and it's potentially useful to take a ragged view on other dimensions using the more general strided layout. This function would be the way to do it, AFAICT.
> If the user wants to change the layout they can use .to().

Current semantics are such that this will copy, no? I don't quite see how we can guarantee torch.nested.narrow(...).to(layout=...) is non-copying if it's the difference between a C++ NT and a Python subclass jagged-layout NT.
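To make the semantics under discussion concrete, here is a plain-Python sketch (no torch, all names invented for illustration) of what a per-batch narrow on dim=1 produces: a ragged result selecting a different window from each row of a dense [B, N] input.

```python
def jagged_narrow(dense, start, length):
    # dense: list of B equal-length rows; start/length: per-row ints.
    # Each output component i is row i restricted to [start[i], start[i] + length[i]).
    return [row[s:s + l] for row, s, l in zip(dense, start, length)]

dense = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
ragged = jagged_narrow(dense, start=[0, 1, 2], length=[2, 3, 1])
# ragged == [[0, 1], [5, 6, 7], [10]]
```

In the real implementation each ragged component would be a view into the dense buffer rather than a copied slice, which is what makes the view/copy discussion above matter.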
torch.nested.narrow(...).to(layout=...) can be made non-copying via torch.compile and pattern matching.
But you're right that it's a new function. Maybe in the future we can extend torch.narrow with similar functionality.
Alright, I'm convinced.
raise RuntimeError(
    "When constructing a jagged nested tensor using narrow(), "
    "your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
) from e
You can use is_expandable_to in torch/_prims_common/__init__.py instead.
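For reference, the broadcast-expandability check can be sketched in a few lines of plain Python (a simplified stand-in, not the actual torch/_prims_common implementation): align the shapes from the right, and require each source dim to equal the target dim or be 1.

```python
def is_expandable_to(shape, desired):
    # True if `shape` can be broadcast-expanded to `desired`.
    if len(shape) > len(desired):
        return False
    for s, d in zip(reversed(shape), reversed(desired)):
        if s != d and s != 1:
            return False
    return True

is_expandable_to((5, 1), (5, 3))   # True: dim of size 1 expands
is_expandable_to((1,), (5, 3))     # True: missing leading dims are implied
is_expandable_to((5,), (5, 1))     # False: trailing 5 can't become 1
```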
I need the expanded tensors, so I figured I could kill two birds with one stone, but I agree the explicit check is probably better
torch/nested/_internal/ops.py
Outdated
@@ -36,7 +36,8 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

 arg_type_check_fns = {
     "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
-    "jt": lambda x: isinstance(x, NestedTensor),
+    "jt": lambda x: isinstance(x, NestedTensor) and x._lengths is None,
+    "jt_nc": lambda x: isinstance(x, NestedTensor),
This new type won't be too useful today, as you won't actually be able to register two funcs to the same aten op overload even if those ops have different schemas.
For now, what you probably want to do is just branch inside whatever op you are trying to implement.
Perhaps in the future we want to go the route of writing a general dispatching mechanism, I'm not sure. Or it's also entirely possible that we'd revert the t vs jt distinction as well. We don't have many ops implemented, so I think it's too early to decide today.
ah... I was trying to avoid adding an if statement to every single function we implement, but instead I will change the register_func call to accept a parameter indicating whether non-contiguous tensors are allowed (given most functions won't accept them at all)
and then add the two paths for the ones that allow them
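A minimal sketch of that plan in plain Python (all names hypothetical, standing in for the actual NestedTensor machinery): the registration helper takes an allow_non_contiguous flag, and a shared wrapper rejects non-contiguous inputs up front so individual op impls stay branch-free.

```python
JAGGED_OPS = {}

class FakeNT:
    """Minimal stand-in: a jagged NT is non-contiguous when it carries
    lengths stored separately from its offsets."""
    def __init__(self, lengths=None):
        self._lengths = lengths

def register_jagged_func(name, allow_non_contiguous=False):
    def register(impl):
        def wrapper(nt, *args, **kwargs):
            # Shared guard: reject non-contiguous inputs unless the op opts in.
            if nt._lengths is not None and not allow_non_contiguous:
                raise RuntimeError(f"{name}: non-contiguous jagged tensor not supported")
            return impl(nt, *args, **kwargs)
        JAGGED_OPS[name] = wrapper
        return wrapper
    return register

@register_jagged_func("clone", allow_non_contiguous=True)
def clone_impl(nt):
    return "cloned"

@register_jagged_func("matmul")  # contiguous only
def matmul_impl(nt):
    return "multiplied"
```

Ops that do opt in (like "clone" here) would then branch internally between the contiguous and non-contiguous paths.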
on second thought, "jt_nc" might not be the best name, but it is the one that allows both the contiguous and non-contiguous versions to pass through to the actual kernel code, vs "jt" only letting the contiguous ones pass, to maintain compatibility
yeah, I'm okay with your initial approach. I did something similar locally when I was playing around with this, and it minimized the changes needed to existing op impls.
torch/nested/__init__.py
Outdated
if isinstance(length, SymInt):
    length = torch.tensor([int(length)])

nt = jagged_from_tensor_and_lengths(tensor, start, length)
nit: You actually return 3 things from this function, but the name suggests you only expect to return one.
Do we actually want to return all 3 to the user?
yeah, I'll change it; I was trying to keep it similar to jagged_from_list, which returns its offsets and is semantically a very similar function
Yeah, I think it's okay for the helper function jagged_from_tensor_and_lengths to also return 3, but we don't necessarily want to return all 3 in the public API.
Nice work! I'm curious to see how this looks with SDPA :)
Do you mind adding a few tests to test/test_nestedtensor.py? I think we can check that things nicely error out for all ops when they encounter a non-contiguous jagged NT.
Along that line: I think we want to expand the definition of contiguous() to take lengths into account + add a corresponding test.
torch/nested/__init__.py
Outdated
Keyword arguments:
    layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
        Only strided and jagged layouts are supported. Default: if None, the jagged layout.
nit: the None default should indicate strided, consistent with the behavior in other places (e.g. nested_tensor() / as_nested_tensor()).
also, is None being handled? I didn't see it, but I may have just missed it
it is being handled with a RuntimeError
Looking really nice, thanks for all your hard work!
I'm just concerned about one thing: the format of offsets. Traditionally, the jagged layout defines it to be of shape B + 1 for batch size B, and it includes the final offset (e.g. offsets=[0, 3, 5, 7] with B=3 and a values buffer of shape [7, D]). For the contiguous case, this allows lengths to be computed with offsets.diff().
From what I can tell, this PR redefines offsets to leave off the final offset, resulting in offsets.shape[0] == lengths.shape[0] == B.
I don't think we can change this for the contiguous case, as it's well-established at this point. For consistency, I'd also argue that the offsets shape should be the same across the non-contiguous and contiguous cases. I admit this is a little strange, as the final offset doesn't have a ton of utility when lengths are defined separately.
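In plain Python, the B + 1 format described above looks like this (the list diff plays the role of offsets.diff()):

```python
# B = 3 components; the final offset (7) is the total number of rows
# in the values buffer, so per-component lengths are consecutive diffs.
offsets = [0, 3, 5, 7]
lengths = [b - a for a, b in zip(offsets, offsets[1:])]
assert lengths == [3, 2, 2]
```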
@soulitzer / @cpuhrsch / @drisspg: what are your opinions on this?
    ragged_size = get_tensor_id(offsets, coeff=1)
else:
    ragged_size = get_tensor_id(lengths, coeff=1)
B = offsets.shape[0]
hm, I don't think this is correct. The original (contiguous) jagged layout has offsets of shape B + 1, making lengths computable via offsets.diff(). For the non-contiguous case, I think we should maintain this shape, including the final offset even if it isn't strictly necessary.
@@ -232,6 +250,17 @@ def backward(ctx, gO: NestedTensor):  # type: ignore[override]
     return gO.values(), None, None


 # Not actually a view!
 class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
note: this will go away when we introduce proper dense -> jagged views, which I'm working on. No action needed for this PR
@@ -278,12 +307,50 @@ def jagged_from_list(
 offsets = torch.cat(
     [
         torch.zeros(1, dtype=torch.int64, device=values.device),
-        torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
+        torch.tensor([s[0] for s in sizes[:-1]], device=values.device).cumsum(
as mentioned above, we shouldn't change the offsets format; we should keep this as shape B + 1 with the final offset included
else:
    raise RuntimeError(
        "When constructing a jagged nested tensor using narrow(), "
        "your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
"your start and length must be a Tensor that broadcasts to input.shape[0] x 1" | |
"start and length must be Tensors that broadcast to input.shape[0]" |
(I believe we removed the need for x 1
in the logic)
    0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = start_list + offset_lengths
we should add the "final offset" to match the old jagged offsets format with shape B + 1
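As an illustration of that suggestion (plain Python, names invented; appending the total buffer length as the final offset is one plausible choice, not necessarily what the PR settles on), the narrow offsets over a dense [B, N] buffer could be built like this:

```python
def narrow_offsets(batch_size, max_len, start):
    # Row i of the flattened [B * N] buffer begins at i * max_len,
    # so its narrowed segment starts at i * max_len + start[i].
    offsets = [i * max_len + s for i, s in enumerate(start)]
    # Final offset appended to match the B + 1 jagged format
    # (lengths are still stored separately in the non-contiguous case).
    offsets.append(batch_size * max_len)
    return offsets

narrow_offsets(3, 4, [0, 1, 2])  # [0, 5, 10, 12]
```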
torch/nested/_internal/ops.py
Outdated
split_offsets = torch.cat(
    (
        offsets,
        torch.tensor(
            [values.shape[0]], device=offsets.device, dtype=offsets.dtype
        ),
    )
)
return torch.split(values, split_offsets.diff().tolist())
if we maintain the traditional jagged offsets format with shape B + 1, we can just use diff()
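A plain-Python sketch of why the B + 1 format simplifies this: with the final offset already present, the split sizes are just consecutive diffs, so no extra concatenation is needed.

```python
def split_by_offsets(values, offsets):
    # sizes is the plain-Python analogue of offsets.diff().tolist().
    sizes = [b - a for a, b in zip(offsets, offsets[1:])]
    out, pos = [], 0
    for n in sizes:
        out.append(values[pos:pos + n])
        pos += n
    return out

split_by_offsets(list(range(7)), [0, 3, 5, 7])
# [[0, 1, 2], [3, 4], [5, 6]]
```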
torch/nested/__init__.py
Outdated
r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
-    (maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
+    (maybe view) shows only the elements in the interval `[start, start+length)`. As nested representations
extreme nit: exclusive upper bound for interval
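The half-open convention, shown on a plain Python list for illustration (narrow on dim 0 of a 1-D sequence behaves like slicing):

```python
row = [10, 20, 30, 40, 50]
start, length = 1, 3
window = row[start:start + length]
assert window == [20, 30, 40]  # index start + length (4) is excluded
```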
torch/nested/__init__.py
Outdated
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
wording / formatting nit: rather than mentioning views in-line here, I think it'd be clearer to talk about view / copy behavior for each layout in a separate paragraph after this one
from offline discussion: we're in agreement that
I've updated all the offsets, hopefully I didn't miss any, but I can't seem to figure out the issue with the recompile happening in dynamo that breaks CI.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#112770 Approved by: https://github.com/cpuhrsch
@@ -187,3 +188,69 @@ def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires
     return nt
 else:
     raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


 def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
Hm, did the documentation for this make it onto the website?
I might have forgotten to commit the changes to the docs adding this function, as I don't see it in the commit