Implement narrow from a regular tensor to jagged tensor
ghstack-source-id: ec6d15238467d273092f99f675298e4aa8c42daa
Pull Request resolved: #112770
ani300 committed Nov 2, 2023
1 parent 09df6b7 commit 71e7f48
Showing 3 changed files with 88 additions and 6 deletions.
35 changes: 33 additions & 2 deletions torch/nested/__init__.py
@@ -1,7 +1,7 @@
-from typing import List, Optional
+from typing import List, Optional, Union

import torch
-from torch import Tensor
+from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]

from torch.types import _device as Device, _dtype as DType
@@ -186,3 +186,34 @@ def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires
        return nt
    else:
        raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")




def narrow(tensor: Tensor, dim: int, start: Union[int, SymInt, Tensor], length: Union[int, SymInt, Tensor], layout=torch.jagged):
    """Constructs a nested tensor by narrowing ``tensor`` along ``dim`` with the given ``start`` and ``length``."""
    if not isinstance(start, (int, SymInt, Tensor)):
        raise RuntimeError("start must be an integer or a tensor")

    if not isinstance(length, (int, SymInt, Tensor)):
        raise RuntimeError("length must be an integer or a tensor")

    if layout == torch.jagged and dim != 1:
        raise RuntimeError("jagged layout only supports dim=1")

    # Create the offset and lengths tensors from start and length
    if layout == torch.jagged:
        from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths

        if isinstance(start, (int, SymInt)):
            start = torch.tensor([int(start)], dtype=torch.int64, device=tensor.device)

        if isinstance(length, (int, SymInt)):
            length = torch.tensor([int(length)], dtype=torch.int64, device=tensor.device)

        nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)

    else:
        nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)

    return nt
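
A minimal usage sketch of the new API (illustrative values; assumes this commit is applied, so the call signature and the lengths()/offsets() accessors follow the diffs in this commit): a dense (batch, max_seq_len, dim) tensor is narrowed along dim=1 with per-batch starts and lengths, yielding a non-contiguous jagged NestedTensor over the same flattened buffer.

import torch

dense = torch.randn(3, 10, 4)     # batch=3, max_seq_len=10, dim=4
starts = torch.tensor([0, 2, 1])  # per-batch narrow starts
lengths = torch.tensor([5, 3, 7]) # per-batch narrow lengths

nt = torch.nested.narrow(dense, 1, starts, lengths, layout=torch.jagged)
print(nt.offsets())  # start of each row in the flattened (30, 4) buffer
print(nt.lengths())  # tensor([5, 3, 7])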
56 changes: 53 additions & 3 deletions torch/nested/_internal/nested_tensor.py
@@ -21,6 +21,7 @@ def get_tensor_id(tensor, *, coeff=1):
class NestedTensor(torch.Tensor):
    _values: torch.Tensor  # type: ignore[assignment]
    _offsets: torch.Tensor
    _lengths: Optional[torch.Tensor]
    # NOTE [ Singleton ints for ragged sizes and strides ]
    #
    # Jagged layout tensors are tensors that represent an n-dim tensor with a
@@ -46,6 +47,7 @@ def __new__(
        values,
        offsets,
        *,
        lengths=None,
        ragged_size=None,
        **kwargs,
    ):
@@ -69,7 +71,7 @@
        )
        return r

-    def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
+    def __init__(self, values, offsets, *, lengths=None, ragged_size=None, **kwargs):
        super().__init__()
        # Only support jagged for now.
        assert offsets is not None
@@ -97,21 +99,25 @@ def __init__(self, values, offsets, *, ragged_size=None, **kwargs):
        )
        self._values = values
        self._offsets = offsets
        self._lengths = lengths

    def values(self):
        return self._values

    def offsets(self):
        return self._offsets

    def lengths(self):
        return self._lengths

    def __repr__(self):
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )
        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

    def __reduce_ex__(self, proto):
        state = torch._utils._get_obj_state(self)
@@ -131,13 +137,14 @@ def __tensor_flatten__(self):
"requires_grad": self.requires_grad,
"ragged_size": self._size[self._ragged_idx],
}
return ["_values", "_offsets"], ctx
return ["_values", "_offsets", "_lengths"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, meta):
        assert len(inner_tensors) == 3
        values = inner_tensors["_values"]
        offsets = inner_tensors["_offsets"]
        lengths = inner_tensors["_lengths"]

        # NOTE [ Storing symbolic values as plain attributes on subclasses ]
        #
@@ -173,6 +180,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta):
        return NestedTensor(
            values,
            offsets=offsets,
            lengths=lengths,
            ragged_size=meta["ragged_size"],
            requires_grad=meta["requires_grad"],
        )
@@ -232,6 +240,17 @@ def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO.values(), None, None


# Not actually a view!
class ViewNoncontiguousNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor, lengths: torch.Tensor):  # type: ignore[override]
        return NestedTensor(values.detach(), offsets=offsets, lengths=lengths)

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO.values(), None, None


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: List[torch.Tensor],
@@ -285,5 +304,36 @@ def jagged_from_list(
    return ViewNestedFromBuffer.apply(values, offsets), offsets  # type: ignore[call-overload]


def jagged_from_tensor_and_lengths(
    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
    batch_size = tensor.shape[0]
    try:
        start_list = starts.expand(batch_size)
        length_list = lengths.expand(batch_size)
    except RuntimeError as e:
        raise RuntimeError(
            "When constructing a jagged nested tensor using narrow(), "
            "your start and length must be Tensors that broadcast to input.shape[0]"
        ) from e

    # Calculate jagged offsets: row i of the dense input begins at
    # i * max_seq_len in the flattened buffer, so each narrowed row starts
    # at that base plus its per-row start.
    max_seq_len = tensor.shape[1]
    offset_lengths = max_seq_len * torch.arange(
        0, batch_size, dtype=torch.int64, device=tensor.device
    )
    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
    offsets = start_list + offset_lengths

    # Reshape buffer to flatten the 1st and 2nd dimension
    if len(tensor.shape) > 2:
        values = tensor.reshape(-1, *tensor.shape[2:])
    else:
        values = tensor.reshape(-1)

    return ViewNoncontiguousNestedFromBuffer.apply(values, offsets, length_list), offsets, length_list  # type: ignore[call-overload]


def buffer_from_jagged(jagged):
    return ViewBufferFromNested.apply(jagged)
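
A worked example of the offset arithmetic in jagged_from_tensor_and_lengths (values are illustrative): for a (batch=3, max_seq_len=10, ...) input, row i begins at i * 10 in the flattened buffer, and adding the per-row starts yields the jagged offsets.

import torch

batch_size, max_seq_len = 3, 10
starts = torch.tensor([0, 2, 1])

offset_lengths = max_seq_len * torch.arange(0, batch_size, dtype=torch.int64)
print(offset_lengths)           # tensor([ 0, 10, 20])
print(starts + offset_lengths)  # tensor([ 0, 12, 21])  <- jagged offsets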
3 changes: 2 additions & 1 deletion torch/nested/_internal/ops.py
@@ -36,7 +36,8 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None:

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
-        "jt": lambda x: isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor) and x._lengths is None,
+        "jt_nc": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
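To illustrate the new schema tag (a sketch assuming this commit is applied; narrow() and the _lengths attribute follow the diffs above): "jt" now matches only contiguous jagged tensors, so kernels that assume a contiguous buffer are not silently fed narrow()'d inputs, while "jt_nc" also accepts non-contiguous ones.

import torch
from torch.nested._internal.nested_tensor import NestedTensor

nt = torch.nested.narrow(torch.randn(2, 8, 4), 1,
                         torch.tensor([0, 1]), torch.tensor([3, 5]),
                         layout=torch.jagged)

is_jt = isinstance(nt, NestedTensor) and nt._lengths is None  # False: non-contiguous
is_jt_nc = isinstance(nt, NestedTensor)                       # True
print(is_jt, is_jt_nc)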
