Implement narrow from a regular tensor to jagged tensor #112770

Closed · wants to merge 14 commits
Changes from 2 commits
60 changes: 58 additions & 2 deletions torch/nested/__init__.py
@@ -1,7 +1,7 @@
from typing import List, Optional
from typing import List, Optional, Union

import torch
from torch import Tensor
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]

from torch.types import _device as Device, _dtype as DType
@@ -186,3 +186,59 @@ def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")


def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.jagged):
r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
Contributor:

Suggested change:
-(maybe view) shows only the elements in the interval `[start, start+length]`. As nested representations
+(maybe view) shows only the elements in the interval `[start, start+length)`. As nested representations

extreme nit: exclusive upper bound for interval

can also be tensors of shape `tensor.shape[0] x 1`.


Args:
tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
for the nested tensor if using the jagged layout or will be copied for the strided layout.
dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
jagged layout, while strided supports all dims
start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op

Keyword arguments:
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the jagged layout.

Contributor:

nit: None default should indicate strided, as is consistent with the behavior in other places (e.g. nested_tensor() / as_nested_tensor()).

also is None being handled? I didn't see it but I may have just missed it

Collaborator Author:

it is being handled with a RuntimeError

Example::

>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
>>> nt.is_leaf
True
"""
if not isinstance(start, (int, Tensor)):
raise RuntimeError("start must be an integer or a tensor")

if not isinstance(length, (int, Tensor)):
raise RuntimeError("length must be an integer or a tensor")

if layout == torch.strided:
nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
elif layout == torch.jagged:
if dim != 1:
raise RuntimeError("jagged layout only supports dim=1")

from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

if isinstance(start, int):
start = torch.tensor([start], device=tensor.device, dtype=torch.int64)

if isinstance(length, int):
length = torch.tensor([length], device=tensor.device, dtype=torch.int64)

nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
else:
raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")

return nt
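
For readers skimming the diff, here is a minimal usage sketch of the torch.nested.narrow API added above. The shapes and values are illustrative, and the lengths() call at the end assumes the lengths-based jagged NestedTensor this PR introduces:

```python
import torch

# Dense input: batch of 3 rows, each padded to length 5, with feature dim 4.
base = torch.randn(3, 5, 4)
starts = torch.tensor([0, 1, 2], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2], dtype=torch.int64)

# Narrow dim=1 with a different start/length per row. With layout=torch.jagged
# the result is backed by `base`'s buffer plus offsets/lengths metadata rather
# than a copy.
nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)

# Per-row lengths are retained on the result (lengths() is None only when the
# jagged tensor is contiguous).
print(nt.lengths())
```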
62 changes: 59 additions & 3 deletions torch/nested/_internal/nested_tensor.py
@@ -2,6 +2,7 @@

import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.fx.experimental.symbolic_shapes import free_symbols
from torch.utils.weak import WeakTensorKeyDictionary
from typing import * # noqa: F403
@@ -21,6 +22,7 @@ def get_tensor_id(tensor, *, coeff=1):
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Singleton ints for ragged sizes and strides ]
#
# Jagged layout tensors are tensors that represent a n-dim tensor with a
@@ -46,6 +48,7 @@ def __new__(
values,
offsets,
*,
lengths=None,
ragged_size=None,
**kwargs,
):
@@ -69,7 +72,7 @@
)
return r

def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
def __init__(self, values, offsets, *, lengths=None, ragged_size=None, **kwargs):
super().__init__()
# Only support jagged for now.
assert offsets is not None
@@ -97,21 +100,25 @@ def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
)
self._values = values
self._offsets = offsets
self._lengths = lengths

def values(self):
return self._values

def offsets(self):
return self._offsets

def lengths(self):
return self._lengths

def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
@@ -131,13 +138,14 @@ def __tensor_flatten__(self):
"requires_grad": self.requires_grad,
"ragged_size": self._size[self._ragged_idx],
}
return ["_values", "_offsets"], ctx
return ["_values", "_offsets", "_lengths"], ctx

@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta):
assert len(inner_tensors) == 3
values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"]
lengths = inner_tensors["_lengths"]

# NOTE [ Storing symbolic values as plain attributes on subclasses ]
#
@@ -173,6 +181,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta):
return NestedTensor(
values,
offsets=offsets,
lengths=lengths,
ragged_size=meta["ragged_size"],
requires_grad=meta["requires_grad"],
)
@@ -232,6 +241,17 @@ def backward(ctx, gO: NestedTensor):  # type: ignore[override]
return gO.values(), None, None


# Not actually a view!
class ViewNonContiguousNestedFromBuffer(torch.autograd.Function):
Contributor:

note: this will go away when we introduce proper dense -> jagged views, which I'm working on. No action needed for this PR

@staticmethod
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor): # type: ignore[override]
return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)

@staticmethod
def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO.values(), None, None


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
tensors: List[torch.Tensor],
@@ -285,5 +305,41 @@ def jagged_from_list(
return ViewNestedFromBuffer.apply(values, offsets), offsets # type: ignore[call-overload]


def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size, 1)) and is_expandable_to(
lengths.shape, (batch_size, 1)
):
start_list = starts.expand(batch_size, 1)
length_list = lengths.expand(batch_size, 1)
else:
raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
Contributor:

Suggested change:
-"your start and length must be a Tensor that broadcasts to input.shape[0] x 1"
+"start and length must be Tensors that broadcast to input.shape[0]"

(I believe we removed the need for x 1 in the logic)

)

# Calculate jagged offsets
assert (
len(tensor.shape) >= 2
), "tensor must at least be 2D for the nested narrow op to work"
max_seq_len = tensor.shape[1]
offset_lengths = max_seq_len * torch.arange(
0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = start_list + offset_lengths
Contributor:

we should add the "final offset" to match the old jagged offsets format with shape B + 1


# Reshape buffer to flatten the 1st and 2nd dimension
if len(tensor.shape) > 2:
values = tensor.reshape(-1, *tensor.shape[2:])
else:
values = tensor.reshape(-1)

return ViewNonContiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list # type: ignore[call-overload]


def buffer_from_jagged(jagged):
return ViewBufferFromNested.apply(jagged)
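
To make the offset arithmetic in jagged_from_tensor_and_lengths concrete, here is a small worked sketch of the same computation using 1-D starts/lengths (plain tensor math rather than the helper itself); the last two lines also illustrate the reviewer's note about appending a final, B + 1-th offset:

```python
import torch

# Dense input of shape (batch_size=3, max_seq_len=5, ...), narrowed along dim=1.
batch_size, max_seq_len = 3, 5
starts = torch.tensor([0, 1, 2], dtype=torch.int64)
lengths = torch.tensor([3, 2, 2], dtype=torch.int64)

# Row i of the flattened (batch * seq, ...) buffer begins at i * max_seq_len,
# so each jagged offset is the per-row start shifted by that row base.
row_base = max_seq_len * torch.arange(batch_size, dtype=torch.int64)  # [0, 5, 10]
offsets = starts + row_base                                           # [0, 6, 12]

# The "B + 1" offsets format adds a final entry marking the end of the last row.
offsets_b_plus_1 = torch.cat([offsets, (offsets[-1] + lengths[-1]).unsqueeze(0)])
print(offsets_b_plus_1.tolist())  # [0, 6, 12, 14]
```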
3 changes: 2 additions & 1 deletion torch/nested/_internal/ops.py
@@ -36,7 +36,8 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

arg_type_check_fns = {
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor) and x._lengths is None,
"jt_nc": lambda x: isinstance(x, NestedTensor),
Contributor:

This new type won't be too useful today as you won't actually be able to register two funcs to the same aten op overload even if those ops have different schemas.

For now, what you probably want to do is just to branch inside whatever op you are trying to implement.

Perhaps in the future we want to go the route of writing a general dispatching mechanism, I'm not sure. Or it's also entirely possible that we'd revert the t vs jt distinction as well. We don't have many ops implemented, so I think it's too early to decide today.

Collaborator Author:

ah... I was trying to get out of adding an if statement to every single function we implement, but instead I will change the register_func call to accept a parameter of whether non-contiguous tensors are allowed or not (given most functions won't accept them at all)

Collaborator Author:

and then add the two paths for the ones that allow them

Collaborator Author:

on second thought, "jt_nc" might not be the best name, but it is the one that allows both the contiguous and noncont. versions to pass through to the actual kernel code, vs "jt" only letting the contiguous ones pass to maintain compatibility

Contributor:

yeah I'm okay with your initial approach. I did something similar to that locally when I was playing around with this and it minimized the changes needed to existing op impls

"any": lambda x: True,
}
for i, named_arg_type in enumerate(named_arg_types):
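
As a rough illustration of the registration-flag approach the author describes in the thread above, here is a self-contained sketch; the flag, registry names, and helper below are hypothetical and not this PR's actual internals:

```python
from typing import Callable, Dict

# Hypothetical registry: maps an aten op to its jagged implementation and records
# whether that implementation also accepts non-contiguous (lengths-based) inputs.
JAGGED_OPS_TABLE: Dict[object, Callable] = {}
ALLOWS_NON_CONTIGUOUS: Dict[object, bool] = {}


def register_jagged_func(aten_op, schema_str: str, *, allow_non_contiguous: bool = False):
    """Register `func` for `aten_op`; most ops keep the default and are rejected
    early when handed a lengths-based (non-contiguous) jagged tensor."""
    def wrapper(func):
        JAGGED_OPS_TABLE[aten_op] = func
        ALLOWS_NON_CONTIGUOUS[aten_op] = allow_non_contiguous
        return func
    return wrapper


def check_contiguity(aten_op, nt) -> None:
    # Single branch here instead of an if-statement inside every op implementation.
    if nt.lengths() is not None and not ALLOWS_NON_CONTIGUOUS.get(aten_op, False):
        raise RuntimeError(f"{aten_op}: non-contiguous jagged tensors are not supported")
```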