
Implement narrow from a regular tensor to jagged tensor #112770

Closed · wants to merge 14 commits

Conversation

@ani300 (Collaborator) commented Nov 2, 2023

pytorch-bot (bot) commented Nov 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112770

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 03a94ee with merge base d4c810c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ani300 added a commit that referenced this pull request Nov 2, 2023
ghstack-source-id: ec6d15238467d273092f99f675298e4aa8c42daa
Pull Request resolved: #112770
@ani300 changed the title from "Implement narrow from a regular tensor to jagged tensor" to "[RFC] Implement narrow from a regular tensor to jagged tensor" on Nov 2, 2023



def narrow(tensor: Tensor, dim: int, start: Union[SymInt, Tensor], length: Union[SymInt, Tensor], layout=torch.jagged):
Contributor:

I don't think we should add a layout argument here. It doesn't follow the narrow syntax for regular tensors. If the user wants to change the layout they can use to(). Since this is a view operation it shouldn't affect performance. Furthermore, a layout change is likely not supported when this is view-only.

Contributor:

> I don't think we should add a layout argument here. It doesn't follow the narrow syntax for regular tensors.

This is true, but we're also not limited to the original narrow() syntax since this is a new function in the torch.nested namespace.

So I sort of disagree. It's possible to create a view over a dense buffer with both the expanded jagged layout and the strided layout, and I think the layout arg is a natural extension for indicating which type of NT + metadata is to be constructed.

Also note that nested narrow for jagged is only supported on dim=1 due to representational restrictions, and it's potentially useful to take a ragged view on other dimensions using the more general strided layout. This function would be the way to do it AFAICT.

> If the user wants to change the layout they can use to().

Current semantics are such that this will copy, no? I don't quite see how we can guarantee torch.nested.narrow(...).to(layout=...) is non-copying if it's the difference between a C++ NT and a python subclass jagged layout NT.

Contributor:

torch.nested.narrow(...).to(layout=...) can be made non-copying via torch.compile and pattern matching.

But you're right that it's a new function. Maybe in the future we can extend torch.narrow with similar functionality.

Alright, I'm convinced.
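For concreteness, a rough usage sketch of the API under discussion, based on the signature shown above (the example values and the final default layout are illustrative, not taken from this PR):

import torch

# Per this thread, jagged-layout narrow is only supported on dim=1, and the
# per-batch start/length are tensors that broadcast to the batch dimension.
dense = torch.randn(3, 10, 5)              # [B, max_seqlen, D]
starts = torch.tensor([0, 2, 1])           # one start per batch element
lengths = torch.tensor([4, 6, 3])          # one length per batch element

# Produces a (possibly non-contiguous) jagged NestedTensor viewing `dense`.
nt = torch.nested.narrow(dense, 1, starts, lengths, layout=torch.jagged)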

raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
) from e
Contributor:

You can use is_expandable_to in torch/_prims_common/__init__.py instead

Collaborator (author):

I need the expanded tensors, so I figured I could kill two birds with one stone, but I agree the explicit check is probably better
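A small sketch of the reviewer's suggestion; is_expandable_to is an internal helper in torch/_prims_common, so its location may change, and _validate_start_length below is a hypothetical name, not the code that landed:

import torch
from torch._prims_common import is_expandable_to

def _validate_start_length(start: torch.Tensor, length: torch.Tensor, batch_size: int) -> None:
    # Each of start/length must broadcast to the batch dimension of the input.
    for name, t in (("start", start), ("length", length)):
        if not is_expandable_to(t.shape, (batch_size,)):
            raise RuntimeError(f"{name} must be a Tensor that broadcasts to input.shape[0]")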

Additional review threads on torch/nested/_internal/nested_tensor.py (resolved).
@@ -36,7 +36,8 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

arg_type_check_fns = {
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor) and x._lengths is None,
"jt_nc": lambda x: isinstance(x, NestedTensor),
Contributor:

This new type won't be too useful today as you won't actually be able to register two funcs to the same aten op overload even if those ops have different schemas.

For now, what you probably want to do is just to branch inside whatever op you are trying to implement.

Perhaps in the future we want to go the route of writing a general dispatching mechanism, I'm not sure. Or it's also entirely possible that we'd revert the t vs jt distinction as well. We don't have many ops implemented, so I think it's too early to decide today.

Collaborator (author):

ah... I was trying to get out of adding an if statement to every single function we implement, but instead I will change the register_func call to accept a parameter indicating whether non-contiguous tensors are allowed or not (given most functions won't accept them at all)

Collaborator (author):

and then add the two paths for the ones that allow them

Collaborator (author):

on second thought, "jt_nc" might not be the best name, but it is the one that allows both the contiguous and non-contiguous versions to pass through to the actual kernel code, vs "jt" only letting the contiguous ones pass to maintain compatibility

Contributor:

yeah I'm okay with your initial approach. I did something similar to that locally when I was playing around with this and it minimized the changes needed to existing op impls
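A minimal sketch of the approach being agreed on here: branch on contiguity inside an op implementation instead of registering two impls for the same aten overload. The _lengths attribute follows the diff above; the function and helper names are hypothetical:

def some_op_impl(func, *args, **kwargs):
    nt = args[0]
    if nt._lengths is None:
        # contiguous jagged NT: offsets alone describe the layout
        return contiguous_path(nt)  # hypothetical helper
    # non-contiguous jagged NT (e.g. produced by narrow()): handle separately or reject
    raise RuntimeError(f"{func}: not yet supported for non-contiguous jagged NestedTensor")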

if isinstance(length, SymInt):
length = torch.tensor([int(length)])

nt = jagged_from_tensor_and_lengths(tensor, start, length)
Contributor:

nit: You actually returned 3 things from this function, but the name suggests that you only expected to return one thing.

Contributor:

Do we actually want to return all 3 to the user?

Collaborator (author):

yeah, I'll change it, I was trying to keep it similar to jagged_from_list which returns its offsets and is semantically a very similar function

Contributor:

Yeah I think it's okay for the helper function jagged_from_tensor_and_lengths to also return 3, but we don't necessarily want to return all 3 in the public API.
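In other words, a sketch of the agreed split (not the exact code that landed): the internal helper keeps returning a 3-tuple, while the public entry point only surfaces the NestedTensor.

import torch
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

def narrow(tensor, dim, start, length, layout=torch.jagged):
    # ... argument validation elided ...
    # helper returns (nt, offsets, lengths); only the NT is returned to the user
    nt, _offsets, _lengths = jagged_from_tensor_and_lengths(tensor, start, length)
    return nt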

Another review thread on torch/nested/_internal/nested_tensor.py (resolved).
ani300 added a commit that referenced this pull request Nov 3, 2023
ghstack-source-id: 7df806f7755432a844b1c3435f723b160303e54f
Pull Request resolved: #112770
@jbschlosser (Contributor) left a comment:

Nice work! I'm curious to see how this looks with SDPA :)

Do you mind adding a few tests into test/test_nestedtensor.py? I think we can check that things nicely error out for all ops when they encounter a non-contiguous jagged NT.

Along that line: I think we want to expand the definition of contiguous() to take lengths into account + add a corresponding test.
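A sketch of the kind of test being asked for here, assuming the contiguity change lands as described (test name and exact construction are hypothetical):

def test_narrow_jagged_is_noncontiguous(self):
    dense = torch.randn(3, 10, 5)
    starts = torch.tensor([0, 2, 1])
    lengths = torch.tensor([4, 6, 3])
    nt = torch.nested.narrow(dense, 1, starts, lengths, layout=torch.jagged)
    # With lengths stored separately, the jagged NT should report non-contiguous.
    self.assertFalse(nt.is_contiguous())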

Additional review threads on torch/nested/__init__.py and torch/nested/_internal/nested_tensor.py (resolved).

Keyword arguments:
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the jagged layout.
Contributor:

nit: None default should indicate strided, as is consistent with the behavior in other places (e.g. nested_tensor() / as_nested_tensor()).

also is None being handled? I didn't see it but I may have just missed it

Collaborator (author):

it is being handled with a RuntimeError

ani300 added a commit that referenced this pull request Nov 3, 2023
ghstack-source-id: f712cd7556e787a7b93770f40188e08f4e058654
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 6, 2023
ghstack-source-id: 1075ff48d892296e1cb96e92de4864afa19f1ba1
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 7, 2023
ghstack-source-id: bbcaec97e0ced5496675e725bfef03a66371c0e7
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 7, 2023
ghstack-source-id: 94d57818c238208d051ea966a78a5062ead0f883
Pull Request resolved: #112770
@ani300 changed the title from "[RFC] Implement narrow from a regular tensor to jagged tensor" to "Implement narrow from a regular tensor to jagged tensor" on Nov 7, 2023
@ani300 added the "release notes: nested tensor" label (Changes that have a direct impact on nested tensors) on Nov 7, 2023
ani300 added a commit that referenced this pull request Nov 7, 2023
ghstack-source-id: 2bf3fe5a469f26ed06d95dbcb24028c47355f103
Pull Request resolved: #112770
@jbschlosser (Contributor) left a comment:

Looking really nice, thanks for all your hard work!

I'm just concerned about one thing: the format of offsets. Traditionally, the jagged layout defines it to be of shape B + 1 for batch size B, and it includes the final offset (e.g. offsets=[0, 3, 5, 7] with B=3 and a values buffer of shape [7, D]). For the contiguous case, this allows for lengths to be computed with offsets.diff().

From what I can tell, this PR redefines offsets to leave off the final offset, resulting in offsets.shape[0] == lengths.shape[0] == B.

I don't think we can change this for the contiguous case, as it's well-established at this point. For consistency, I'd also argue that the offsets shape should be the same across non-contiguous and contiguous cases. I admit this is a little strange, as the final offset doesn't have a ton of utility when lengths are defined separately.

@soulitzer / @cpuhrsch / @drisspg: what are your opinions on this?

ragged_size = get_tensor_id(offsets, coeff=1)
else:
ragged_size = get_tensor_id(lengths, coeff=1)
B = offsets.shape[0]
Contributor:

hm I don't think this is correct. The original (contiguous) jagged layout has offsets of shape B + 1, making lengths computable via offsets.diff(). For the non-contiguous case, I think we should maintain this shape, including the final offset even if it isn't strictly necessary.
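To spell out the convention referenced here with concrete numbers (same example as the review comment above):

import torch

# B + 1 offsets for B = 3 sequences packed into a values buffer of shape [7, D]
offsets = torch.tensor([0, 3, 5, 7])
lengths = offsets.diff()        # tensor([3, 2, 2]) -- per-sequence lengths
B = offsets.shape[0] - 1        # batch size is offsets.numel() - 1, not offsets.shape[0]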

@@ -232,6 +250,17 @@ def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO.values(), None, None


# Not actually a view!
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
Contributor:

note: this will go away when we introduce proper dense -> jagged views, which I'm working on. No action needed for this PR

@@ -278,12 +307,50 @@ def jagged_from_list(
offsets = torch.cat(
[
torch.zeros(1, dtype=torch.int64, device=values.device),
torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
torch.tensor([s[0] for s in sizes[:-1]], device=values.device).cumsum(
Contributor:

as mentioned above, we shouldn't change the offsets format; we should keep this as shape B + 1 with the final offset included
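That is, keep the original construction: cumsum over all component lengths, not sizes[:-1]. A small sketch with concrete numbers (the real code derives seq_lens from the per-component sizes):

import torch

seq_lens = torch.tensor([3, 2, 2], dtype=torch.int64)   # s[0] for each component
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), seq_lens.cumsum(dim=0)])
# offsets == tensor([0, 3, 5, 7]): shape B + 1, final offset included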

else:
raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
Contributor:

Suggested change
"your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
"start and length must be Tensors that broadcast to input.shape[0]"

(I believe we removed the need for x 1 in the logic)

0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = start_list + offset_lengths
Contributor:

we should add the "final offset" to match the old jagged offsets format with shape B + 1

Comment on lines 482 to 490
split_offsets = torch.cat(
(
offsets,
torch.tensor(
[values.shape[0]], device=offsets.device, dtype=offsets.dtype
),
)
)
return torch.split(values, split_offsets.diff().tolist())
Contributor:

if we maintain the traditional jagged offsets format with shape B + 1, we can just use diff()
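i.e., with B + 1 offsets the block above collapses to something like the following (sketch; `values` and `offsets` as in the surrounding code):

return torch.split(values, offsets.diff().tolist())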

r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
Contributor:

Suggested change
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
(maybe view) shows only the elements in the interval `[start, start+length)`. As nested representations

extreme nit: exclusive upper bound for interval

Comment on lines 194 to 196
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
Contributor:

wording / formatting nit: rather than mentioning views in-line here, I think it'd be clearer to talk about view / copy behavior for each layout in a separate paragraph after this one

@jbschlosser (Contributor):

from offline discussion: we're in agreement that offsets should consistently be in the old format: shape B + 1

@ani300 (Collaborator, author) commented Nov 8, 2023

I've updated all the offsets, hopefully I didn't miss any, but I can't seem to figure out the issue with the recompile happening on dynamo that breaks CI

ani300 added a commit that referenced this pull request Nov 8, 2023
ghstack-source-id: 20ea5dde599fc8326f9fff1faf742ff2cd62753e
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 8, 2023
ghstack-source-id: 952a3948709ec3a3654da4c0b7866d290f042814
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 8, 2023
ghstack-source-id: b95f8e47918dff02a9ac86a08ec3d790be2fbea4
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 8, 2023
ghstack-source-id: 7517fb544cf352a8810067b2ad85ebfdd484893f
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 8, 2023
ghstack-source-id: d55799b6027b641eee4815ed9e0a25a7bf9fcf99
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 10, 2023
ghstack-source-id: 2fbbc0e0cd2a59eb6c4dde4ba1d6cc83ffb52b42
Pull Request resolved: #112770
ani300 added a commit that referenced this pull request Nov 10, 2023
ghstack-source-id: df56044c88154d07ff6850b347ac6be215c3ad26
Pull Request resolved: #112770
@ani300 (Collaborator, author) commented Nov 13, 2023

@pytorchbot merge

@pytorch-bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) on Nov 13, 2023
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
@@ -187,3 +188,69 @@ def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
Contributor:

Hm, did the documentation for this make it onto the website?

Collaborator (author):

I might have forgotten to commit the changes to the docs adding this function, as I don't see it in the commit

@facebook-github-bot deleted the gh/ani300/2/head branch on November 17, 2023
Labels: ciflow/inductor, ciflow/trunk, Merged, module: dynamo, open source, release notes: nested tensor

6 participants