Skip to content

Conversation

jbschlosser
Copy link
Contributor

@jbschlosser jbschlosser commented Jul 26, 2024

Stack from ghstack (oldest at bottom):

Background: TestCase.assertEqual() is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes TestCase.assertEqual() stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base NestedTensorTestCase that defines a helper function assertEqualIgnoringNestedInts() so that these tests can explicitly opt in to the looser comparison behavior.

@jbschlosser jbschlosser requested a review from a team as a code owner July 26, 2024 15:11
Copy link

pytorch-bot bot commented Jul 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131898

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 977b972 with merge base 2576dbb (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.

[ghstack-poisoned]
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.

[ghstack-poisoned]
**Background**: `TestCase.assertEqual()` is commonly used during test case validation. Historically, to support NSTs, the logic was written to compare two nested tensors by unbinding them and comparing their components. This logic applied to NJTs as well, which in practice meant that two NJTs with different nested ints in their shapes could compare equal if their components were equal.

This PR changes the above logic so that NJTs are no longer unbound during comparison, allowing them to receive full shape validation. This makes `TestCase.assertEqual()` stricter for NJTs, requiring them to have the same nested ints in their shapes to compare equal.

Note that some tests rely on the old, looser behavior. To address this, the PR introduces a base `NestedTensorTestCase` that defines a helper function `assertEqualIgnoringNestedInts()` so that these tests can explicitly opt in to the looser comparison behavior.

[ghstack-poisoned]
@jbschlosser
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 30, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jul 30, 2024
This PR utilizes the info from the existing OpInfo database `op_db` to contribute to general NJT testing.
* New tests in `TestNestedTensorOpInfo`
    * `test_forward()` - compares forward output to an unbind-based reference
    * `test_backward()` - compares forward output and grads to an unbind-based reference
    * `test_forward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) to eager
    * `test_backward_compile()` - compares forward compile output (`backend="aot_eager_decomp_partition"`) and grads to eager
* To avoid adding a bunch of NJT-specific stuff to the `OpInfo` structure, this PR translates `op_db` -> a NJT-specific `njt_op_db`.
    * `UnaryUfuncInfo`s utilize a new `sample_inputs_unary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `BinaryUfuncInfo`s utilize a new `sample_inputs_binary_njt_pointwise()` which iterates through a comprehensive list of NJTs: contiguous / non-contiguous, dims 2, 3, and 4, transposed / not, etc.
    * `ReductionOpInfo`s utilize a new `sample_inputs_njt_reduction()` which covers full reductions, reductions over the jagged dim, and reductions over the non-jagged dim
* Several xfails were added to get things passing

TODO (future PRs):
* Pass non-contiguous / non-contiguous with holes NJTs (maybe we should have separate tests for these? most ops don't support NJTs with holes today)
* Mixed (NT, T), (T, NT) inputs for binary ops
* Handle other types of OpInfos (beyond unary pointwise, binary pointwise, and reduction) by manually by writing sample_inputs_funcs
* Address all xfails via fixes
Pull Request resolved: #131704
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898
pytorchmergebot pushed a commit that referenced this pull request Jul 31, 2024
…131937)

**Background:** NJT utilizes a `jagged_unary_pointwise()` fallback that historically has assumed blindly that the first arg is an NJT. This assumption breaks certain ops; for example `pow(scalar, Tensor)` has an NJT as the second arg.

This PR expands `jagged_unary_pointwise()` and the associated schema validation logic to handle an NJT in args other than the first position.
Pull Request resolved: #131937
Approved by: https://github.com/soulitzer
ghstack dependencies: #131898, #131704
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2024
)

It's possible to construct an NJT with "holes" by specifying both `offsets` and `lengths` metadata. When `nt.clone(memory_format=torch.contiguous_format)` is called on such an NJT, the result should be an NJT without holes.

This PR fixes this in simplistic way using `unbind()`, which isn't really supported in `torch.compile`. The longer term solution involves writing a proper kernel to support this.

NB: Another limitation is that the returned NJT does not have the same ragged structure as the input. While we could manually hack the nested int registry (or update the union find when that lands), this is the first instance where a NJT with holes and an NJT without holes could have the same ragged structure, and getting those to play nicely together requires some fairly involved updates. For now, this PR punts on these updates until we can clean this up.
Pull Request resolved: #132776
Approved by: https://github.com/ani300, https://github.com/soulitzer
ghstack dependencies: #131898, #131704, #131937
@github-actions github-actions bot deleted the gh/jbschlosser/166/head branch August 30, 2024 02:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants