NJT <-> padded dense conversions #125947

Draft: jbschlosser wants to merge 39 commits into base branch gh/jbschlosser/140/base.

Changes from all commits (39 commits):
560c873  NJT <-> padded dense conversions (jbschlosser, May 10, 2024)
693979b  Update on "NJT <-> padded dense conversions" (jbschlosser, May 10, 2024)
9d86762  Update on "NJT <-> padded dense conversions" (jbschlosser, May 14, 2024)
75a274c  Update on "NJT <-> padded dense conversions" (jbschlosser, May 14, 2024)
84a6f5f  Update on "NJT <-> padded dense conversions" (jbschlosser, May 14, 2024)
afc43d1  Update on "NJT <-> padded dense conversions" (jbschlosser, May 15, 2024)
8e77035  Update on "NJT <-> padded dense conversions" (jbschlosser, May 15, 2024)
ca8b676  Update on "NJT <-> padded dense conversions" (jbschlosser, May 17, 2024)
e33eecd  Update on "NJT <-> padded dense conversions" (jbschlosser, May 17, 2024)
b2ad249  Update on "NJT <-> padded dense conversions" (jbschlosser, May 17, 2024)
29dc557  Update on "NJT <-> padded dense conversions" (jbschlosser, May 17, 2024)
0a19ede  Update on "NJT <-> padded dense conversions" (jbschlosser, May 20, 2024)
a24cea6  Update on "NJT <-> padded dense conversions" (jbschlosser, May 20, 2024)
8b2bd58  Update on "NJT <-> padded dense conversions" (jbschlosser, May 20, 2024)
29921ff  Update on "NJT <-> padded dense conversions" (jbschlosser, May 21, 2024)
85a0586  Update on "NJT <-> padded dense conversions" (jbschlosser, May 21, 2024)
974c8e9  Update on "NJT <-> padded dense conversions" (jbschlosser, May 21, 2024)
bfade9b  Update on "NJT <-> padded dense conversions" (jbschlosser, May 22, 2024)
29828b5  Update on "NJT <-> padded dense conversions" (jbschlosser, May 22, 2024)
100d48f  Update on "NJT <-> padded dense conversions" (jbschlosser, May 22, 2024)
f7459d7  Update on "NJT <-> padded dense conversions" (jbschlosser, May 23, 2024)
be32c1d  Update on "NJT <-> padded dense conversions" (jbschlosser, May 23, 2024)
3c103f1  Update on "NJT <-> padded dense conversions" (jbschlosser, May 24, 2024)
434d898  Update on "NJT <-> padded dense conversions" (jbschlosser, May 24, 2024)
02f1243  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 3, 2024)
fa9fae2  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 4, 2024)
e1567d6  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 4, 2024)
cb275b2  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 5, 2024)
0dab59f  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 7, 2024)
2def51c  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 11, 2024)
69bc5d2  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 11, 2024)
0bcfcf8  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 12, 2024)
2c99461  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 12, 2024)
1779701  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 13, 2024)
33c514e  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 14, 2024)
484fb1b  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 17, 2024)
6033447  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 18, 2024)
4d11b18  Update on "NJT <-> padded dense conversions" (jbschlosser, Jun 24, 2024)
df387ed  Update on "NJT <-> padded dense conversions" (jbschlosser, Jul 1, 2024)
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -14686,6 +14686,11 @@
    CUDA: _fbgemm_dense_to_jagged_forward_symint
    CPU: _padded_dense_to_jagged_forward_cpu

- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int dim=1, SymInt? sum_S=None) -> Tensor
  variants: function
  device_check: NoCheck
  dispatch: {}

- func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor
  dispatch:
    NestedTensorCPU: NestedTensor_softmax_dropout
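For orientation: the new private op takes the padded dense tensor, the per-sequence offsets, a dummy nested tensor (used to route dispatch to the jagged-layout kernels), the ragged dimension (dim, default 1), and an optional sum_S, the total number of valid (non-padded) rows, i.e. offsets[-1]. Below is a minimal sketch of what those inputs look like for a made-up batch; the names and shapes are illustrative only, and the public wrapper torch.nested.nested_tensor_from_padded added later in this PR is the intended entry point.

import torch

# Hypothetical batch of three sequences with lengths 2, 5 and 3, feature dim 8.
lengths = torch.tensor([2, 5, 3])
offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])  # tensor([0, 2, 7, 10])

# Padded dense layout: (batch, max_len, feature), with short rows filled by the padding value.
padded = torch.randn(len(lengths), int(lengths.max()), 8)

# sum_S is the total number of valid rows across the batch, i.e. offsets[-1]; per the
# NB comment in the test below, it must be passed to get the fused dense -> jagged lowering.
sum_S = int(offsets[-1])  # 10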
1 change: 1 addition & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -444,6 +444,7 @@ aten::_nested_from_padded
aten::_nested_from_padded.out
aten::_nested_from_padded_and_nested_example
aten::_nested_from_padded_and_nested_example.out
aten::_nested_from_padded_tensor
aten::_nested_get_jagged_dummy
aten::_nested_get_lengths
aten::_nested_get_max_seqlen
94 changes: 94 additions & 0 deletions test/test_nestedtensor.py
@@ -5651,6 +5651,100 @@ def test_unbind_backward(self, device, dtype):
        expected_grad.unbind()[1].add_(1.0)
        torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)

    @dtypes(torch.float32, torch.double, torch.half)
    @parametrize("nt_dim", [2, 3, 4])
    @parametrize("requires_grad", [False, True])
    def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad):
        if nt_dim == 2:
            post_seq_len_shape = ()
        elif nt_dim == 3:
            post_seq_len_shape = (10,)
        elif nt_dim == 4:
            post_seq_len_shape = (9, 10)

        nt = torch.nested.nested_tensor(
            [
                torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
                for n in range(2, 9)
            ],
            layout=torch.jagged,
            requires_grad=requires_grad,
        )

        PADDING_VAL = 4.2
        expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL)
        for i, component in enumerate(nt.unbind()):
            expected_padded[i, : component.shape[0]].copy_(component)

        padded = nt.to_padded_tensor(PADDING_VAL)
        self.assertEqual(expected_padded, padded)

        # convert padded dense -> NJT
        nt2 = torch.nested.nested_tensor_from_padded(padded, nt.offsets())
        self.assertEqual(nt, nt2)

        if requires_grad:
            # ensure gradients flow through conversions
            nt2.backward(torch.ones_like(nt2))
            self.assertEqual(nt.grad, torch.ones_like(nt))

    # blows up due to test parametrization otherwise
    @torch._dynamo.utils.disable_cache_limit()
    @skipIfTorchDynamo("SDPA test compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.double, torch.half)
    @parametrize("nt_dim", [2, 3, 4])
    @parametrize("requires_grad", [False, True])
    def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad):
        if nt_dim == 2:
            post_seq_len_shape = ()
        elif nt_dim == 3:
            post_seq_len_shape = (10,)
        elif nt_dim == 4:
            post_seq_len_shape = (9, 10)

        nt = torch.nested.nested_tensor(
            [
                torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
                for n in range(2, 9)
            ],
            layout=torch.jagged,
            requires_grad=requires_grad,
        )

        def f(x):
            return x.sin() + 1

        PADDING_VAL = 4.2

        @torch.compile(fullgraph=True)
        def g(nt):
            padded = nt.to_padded_tensor(PADDING_VAL)
            padded = f(padded)
            # NB: sum_S must be specified to use the lowering for dense -> jagged
            # and get full fusion
            return torch.nested.nested_tensor_from_padded(
                padded, nt.offsets(), sum_S=nt.values().shape[0]
            )

        expected_output = f(nt)
        if requires_grad:
            expected_output.backward(torch.ones_like(expected_output))
            expected_grad = nt.grad.clone().detach()
            nt.grad = None

        compiled_output = g(nt)
        if requires_grad:
            compiled_output.backward(torch.ones_like(compiled_output))
            compiled_grad = nt.grad.clone().detach()
            self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3)

        self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3)

        # TODO: Verify that computation fusion happens


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
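On the trailing TODO about verifying fusion: one way to check by hand (an approach not included in this PR) is to have Inductor print the code it generates and look for a single kernel covering the padded conversion, the pointwise sin + add, and the conversion back to jagged.

import torch

# Ask Inductor to dump the generated (Triton/C++) code for inspection.
torch._logging.set_logs(output_code=True)

# Reusing the compiled function g and the nested tensor nt from the test above;
# the printed output can then be scanned for a single fused kernel.
g(nt)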
7 changes: 6 additions & 1 deletion tools/autograd/derivatives.yaml
@@ -2802,9 +2802,14 @@
  cpu_nested_shape_example: non_differentiable

- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
  self: at::_nested_from_padded(grad, self._nested_tensor_size())
  self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), c10::optional<c10::SymInt>(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())"
  padding: non_differentiable

- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int dim=1, SymInt? sum_S=None) -> Tensor
  padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef<c10::SymInt>(padded.sym_sizes()))
  offsets: non_differentiable
  dummy: non_differentiable

- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)
  self: grad.values()
  nested_size: non_differentiable
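These two formulas make the conversions differentiable in both directions: for a jagged-layout self, the backward of to_padded_tensor re-packs the incoming gradient with _nested_from_padded_tensor (reusing the original offsets, ragged index, and total row count), while the backward of _nested_from_padded_tensor pads the gradient back out with 0.0, so nothing flows into the padding positions. A small eager-mode sketch of the resulting round-trip gradient behavior, mirroring the requires_grad branch of test_to_padded_tensor above (shapes are arbitrary):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(n, 4) for n in (2, 5, 3)],
    layout=torch.jagged,
    requires_grad=True,
)

padded = nt.to_padded_tensor(0.0)                                    # NJT -> padded dense
nt2 = torch.nested.nested_tensor_from_padded(padded, nt.offsets())   # padded dense -> NJT

# The round trip is the identity on valid positions, so the gradient is all ones there
# and the padding positions contribute nothing.
nt2.backward(torch.ones_like(nt2))
assert torch.equal(nt.grad.values(), torch.ones_like(nt.values()))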
1 change: 1 addition & 0 deletions torch/_dynamo/trace_rules.py
@@ -176,6 +176,7 @@
"torch._nested_tensor_from_mask": SkipFunctionVariable,
"torch._nested_from_padded": SkipFunctionVariable,
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
"torch.nested.nested_tensor_from_padded": UserFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,
8 changes: 8 additions & 0 deletions torch/nested/__init__.py
@@ -13,6 +13,7 @@
"as_nested_tensor",
"nested_tensor",
"nested_tensor_from_jagged",
"nested_tensor_from_padded",
"narrow",
]

@@ -389,3 +390,10 @@ def nested_tensor_from_jagged(
    from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths

    return nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=jagged_dim)


def nested_tensor_from_padded(padded: Tensor, offsets: Tensor, ragged_idx=1, sum_S=None):
    from torch.nested._internal.nested_tensor import nested_from_padded

    # TODO: implement this for the strided layout?
    return nested_from_padded(padded, offsets, ragged_idx, sum_S)
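A short usage sketch for the new wrapper, with arbitrary sizes: convert a jagged NestedTensor to a padded dense tensor and back. Only ragged_idx=1 is currently supported (see nested_from_padded below), and sum_S is optional in eager mode but, per the test above, needed for the fused lowering under torch.compile.

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(n, 16) for n in (3, 7, 5)],
    layout=torch.jagged,
)

padded = nt.to_padded_tensor(-1.0)   # shape (3, 7, 16); rows shorter than 7 are filled with -1.0
nt2 = torch.nested.nested_tensor_from_padded(
    padded,
    nt.offsets(),
    sum_S=nt.values().shape[0],      # total number of valid rows: 3 + 7 + 5 = 15
)
assert torch.equal(nt2.values(), nt.values())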
10 changes: 10 additions & 0 deletions torch/nested/_internal/nested_tensor.py
@@ -550,3 +550,13 @@ def nested_view_from_values_offsets_lengths(
        min_seqlen_tensor,
        max_seqlen_tensor,
    )  # type: ignore[return-value]


@torch._dynamo.allow_in_graph
def nested_from_padded(padded, offsets, ragged_idx=1, sum_S=None):
    if ragged_idx != 1:
        raise RuntimeError("nested_from_padded(): only ragged_idx=1 supported for now")

    return torch._nested_from_padded_tensor(
        padded, offsets, _nt_view_dummy(), ragged_idx, sum_S
    )
81 changes: 81 additions & 0 deletions torch/nested/_internal/ops.py
@@ -1097,6 +1097,87 @@ def values_default(func, *args, **kwargs):
    return inp._values.detach()


@register_jagged_func(
    torch.ops.aten.to_padded_tensor.default, "self: jt, padding: any, output_size: any?"
)
def to_padded_tensor_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    # TODO: Handle the rest of output_size
    output_size = new_kwargs["output_size"]
    if output_size is not None:
        max_seq_len = output_size[inp._ragged_idx]
    else:
        max_seq_len = inp._max_seqlen

    if inp.is_cuda:
        # values with dim > 2 are not supported by the underlying FBGEMM kernel, so do shape gymnastics
        values = inp.values()
        values_shape = values.shape
        if values.dim() > 2:
            values = values.flatten(start_dim=1)

        padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
            values,
            [inp._offsets],
            [max_seq_len],
            new_kwargs["padding"],
        )

        # shape gymnastics part 2
        if len(values_shape) > 2:
            padded_out = padded_out.unflatten(-1, values_shape[1:])

        return padded_out
    else:
        # TODO: backup non-FBGEMM impl
        return None


@register_jagged_func(
    torch.ops.aten._nested_from_padded_tensor.default,
    "padded: t, offsets: t, dummy: jt, dim: any?, sum_S: any?",
)
def _nested_from_padded_tensor_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    if new_kwargs["dim"] != 1:
        raise RuntimeError(
            "_nested_from_padded_tensor(): only dim=1 supported for jagged layout"
        )

    padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]

    # non-3D padded is not supported by the underlying FBGEMM kernel, so do shape gymnastics
    padded_shape = padded.shape
    if padded.dim() > 3:
        padded = padded.flatten(start_dim=2)
    elif padded.dim() < 3:
        padded = padded.unsqueeze(-1)

    if padded.is_cuda:
        values = torch.ops.aten._padded_dense_to_jagged_forward(
            padded, [offsets], new_kwargs["sum_S"]
        )

        # shape gymnastics part 2
        if len(padded_shape) > 3:
            values = values.unflatten(-1, padded_shape[2:])
        elif len(padded_shape) < 3:
            values = values.squeeze(-1)

        return NestedTensor(values, offsets)
    else:
        # TODO: backup non-FBGEMM impl
        return None


@register_jagged_func(
    torch.ops.aten._nested_view_from_jagged.default,
    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
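Both kernels registered above lean on FBGEMM ops that only handle 2D values (jagged -> padded) and 3D padded tensors (padded -> jagged), hence the "shape gymnastics": trailing dimensions past the ragged one are flattened before the kernel call and restored afterwards, and both paths currently bail out with None off-CUDA, as the TODOs note. A standalone illustration of just the shape bookkeeping, with made-up sizes and without calling the kernels themselves:

import torch

values = torch.randn(10, 9, 16)            # (total_rows, 9, 16): dim > 2, not kernel-friendly
flat_values = values.flatten(start_dim=1)  # (10, 144): what _jagged_to_padded_dense_forward would see

# ... the kernel would return padded output of shape (B, max_len, 144) ...
padded_flat = torch.randn(3, 7, 144)
padded = padded_flat.unflatten(-1, values.shape[1:])  # (3, 7, 9, 16): trailing dims restored
assert padded.shape == (3, 7, 9, 16)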