-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Strict shape checking for NJTs with TestCase.assertEqual() #131898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131898
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 977b972 with merge base 2576dbb ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal. This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal. Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior. [ghstack-poisoned]
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal. This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal. Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior. [ghstack-poisoned]
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal. This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal. Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior. [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR utilizes the info from the existing OpInfo database `op_db` to contribute to general NJT testing. * New tests in `TestNestedTensorOpInfo` * `test_forward()` - compares forward output to an unbind-based reference * `test_backward()` - compares forward output and grads to an unbind-based reference * `test_forward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) to eager * `test_backward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) and grads to eager * To avoid adding a bunch of NJT-specific stuff to the `OpInfo` structure, this PR translates `op_db` -> a NJT-specific `njt_op_db`. * `UnaryUfuncInfo`s utilize a new `sample_inputs_unary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc. * `BinaryUfuncInfo`s utilize a new `sample_inputs_binary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc. * `ReductionOpInfo`s utilize a new `sample_inputs_njt_reduction()` which covers full reductions, reductions over the jagged dim, and reductions over the non-jagged dim * Several xfails were added to get things passing TODO (future PRs): * Pass non-contiguous / non-contiguous with holes NJTs (maybe we should have separate tests for these? most ops don't support NJTs with holes today) * Mixed (NT, T), (T, NT) inputs for binary ops * Handle other types of OpInfos (beyond unary pointwise, binary pointwise, and reduction) by manually by writing sample_inputs_funcs * Address all xfails via fixes Pull Request resolved: #131704 Approved by: https://github.com/soulitzer ghstack dependencies: #131898
…131937) **Background:** NJT utilizes a `jagged_unary_pointwise()` fallback that historically has assumed blindly that the first arg is an NJT. This assumption breaks certain ops; for example `pow(scalar, Tensor)` has an NJT as the second arg. This PR expands `jagged_unary_pointwise()` and the associated schema validation logic to handle an NJT in args other than the first position. Pull Request resolved: #131937 Approved by: https://github.com/soulitzer ghstack dependencies: #131898, #131704
) It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes. This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this. NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up. Pull Request resolved: #132776 Approved by: https://github.com/ani300, https://github.com/soulitzer ghstack dependencies: #131898, #131704, #131937
Stack from ghstack (oldest at bottom):
Background:
TestCase.assertEqual()
is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes
TestCase.assertEqual()
stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base
NestedTensorTestCase
that defines a helper functionassertEqualIgnoringNestedInts()
so that these tests can explicitly opt in to the looser comparison behavior.