
improve memory footprint of torch.testing.assert_close #96131

Closed
wants to merge 11 commits

Conversation


@pmeier pmeier commented Mar 6, 2023

Stack from ghstack (oldest at bottom):

Redo of #90172 out of stack.


pytorch-bot bot commented Mar 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96131

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 8e1f3aa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pmeier pmeier added module: testing Issues related to the torch.testing module (not tests) topic: not user facing topic category labels Mar 6, 2023
@@ -1008,28 +1013,10 @@ def _compare_regular_values_close(
)
else:
msg = make_tensor_mismatch_msg(
actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier
pmeier (Collaborator, Author):

In the error case, we created the `mismatches = ~matches` tensor here and turned it back into a `matches` inside `make_tensor_mismatch_msg`. With a minor refactoring, we no longer need to invert `matches` and can use it directly.
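A minimal plain-Python sketch of the idea (illustrative names, not the actual PyTorch implementation): consuming the `matches` mask directly avoids materializing a second, inverted mask of the same size.

```python
def count_mismatches(matches):
    # consume the matches mask directly instead of allocating an
    # inverted copy first (the list analogue of `~matches`)
    return sum(1 for m in matches if not m)

matches = [True, False, True, False, False]
print(count_mismatches(matches))  # 3
```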

@@ -991,7 +997,6 @@ def _compare_regular_values_close(
identifier: Optional[Union[str, Callable[[str], str]]] = None,
) -> None:
"""Checks if the values of two tensors are close up to a desired tolerance."""
actual, expected = self._promote_for_comparison(actual, expected)
pmeier (Collaborator, Author):

We used to upcast unconditionally here, since that was needed for `isclose`. This is no longer the case, so we can just drop it.

Comment on lines +285 to +289
if not actual.dtype.is_floating_point and not actual.dtype.is_complex:
# TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid
# overflow
actual_flat = actual_flat.to(torch.int64)
expected_flat = expected_flat.to(torch.int64)
pmeier (Collaborator, Author):

However, we still need to upcast in the error case, since we want to display the absolute diff, and that is not supported for `torch.bool` and might overflow for other integer dtypes.
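To illustrate why the upcast matters, here is a plain-Python simulation of wrapping 8-bit subtraction (a hypothetical helper, not PyTorch code): the true absolute difference of the extreme `int8` values cannot be represented in `int8` itself.

```python
def int8_sub(a, b):
    # simulate two's-complement 8-bit subtraction, which wraps on overflow
    return ((a - b + 128) % 256) - 128

# the true absolute difference between the int8 extremes is 255 ...
print(abs(127 - (-128)))         # 255
# ... but 8-bit arithmetic wraps it around
print(abs(int8_sub(127, -128)))  # 1
```

Upcasting to a wider integer dtype (int64 here, though any strictly wider type would do, as the TODO notes) makes the displayed diff exact.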

Comment on lines +281 to +282
actual_flat = actual.flatten()
expected_flat = expected.flatten()
pmeier (Collaborator, Author):

Drive-by renaming: `a` and `b` were only used in the beginning and should be `actual` and `expected` now.

torch/testing/_comparison.py (outdated; thread resolved)
@pmeier pmeier marked this pull request as ready for review March 7, 2023 08:32
@pmeier pmeier requested review from mruberry and pearu March 7, 2023 08:32
pearu (Collaborator) left a comment:

LGTM! Thanks, @pmeier!

I have an OT feature request, so feel free to ignore it.

  # Ensure that only mismatches are used for the max_abs_diff computation
  abs_diff[matches_flat] = 0
  max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)

- rel_diff = abs_diff / torch.abs(b_flat)
+ rel_diff = abs_diff / torch.abs(expected_flat)
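For intuition, a plain-Python sketch of what the snippet above computes, with lists standing in for the flattened tensors (names mirror the diff, but this is not the library code):

```python
def max_abs_diff_over_mismatches(actual_flat, expected_flat, matches_flat):
    abs_diff = [abs(a, ) if False else abs(a - e) for a, e in zip(actual_flat, expected_flat)]
    # zero out matching positions so only mismatches contribute to the max
    abs_diff = [0 if m else d for d, m in zip(abs_diff, matches_flat)]
    max_abs_diff = max(abs_diff)
    return max_abs_diff, abs_diff.index(max_abs_diff)

print(max_abs_diff_over_mismatches([1, 0, 5], [1, 1, 2], [True, False, False]))
# (3, 2): greatest absolute difference of 3 at flat index 2
```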
pearu (Collaborator):

A slight OT suggestion: could we have a better normalization factor here (say, `(torch.abs(actual) + torch.abs(expected)) / 2`) for the case where `expected` contains zeros? Having zeros is typical, say, when comparing the indices of sparse tensors. At the moment, the mismatch messages from `assert_close` depend on the order of the inputs, for example:

>>> torch.testing.assert_close(torch.tensor([1, 0]), torch.tensor([1, 1]))
<snip>
Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: 1.0 at index (1,)
>>> torch.testing.assert_close(torch.tensor([1, 1]), torch.tensor([1, 0]))
<snip>
Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: inf at index (1,)

(btw, reporting relative differences for non-float tensors is often pointless as well).
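The suggested normalization could look like this sketch (plain Python, a hypothetical helper, not an existing API): with the average magnitude in the denominator, the reported relative difference no longer depends on argument order.

```python
def rel_diff_symmetric(actual, expected):
    # average-magnitude denominator, as suggested above; 0.0 when both are zero
    denom = (abs(actual) + abs(expected)) / 2
    return abs(actual - expected) / denom if denom else 0.0

print(rel_diff_symmetric(0, 1))  # 2.0
print(rel_diff_symmetric(1, 0))  # 2.0 -- identical in both orders
```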

pmeier (Collaborator, Author):

> Atm, the mismatch messages from assert_close depends on the order of inputs, for example:

It's not just the messages, it is the actual op. Internally, we rely on `torch.isclose`, and that is already asymmetric: it defines closeness as `abs(actual - expected) <= atol + rtol * abs(expected)`. Believe me when I say, we (the torch.testing team) wanted to change that, but there is just too much inertia. PyTorch is not an outlier here; numpy (and virtually every other array library) does the same.

Python's `math` module does the more sensible thing, defining closeness as `abs(actual - expected) <= max(atol, rtol * max(abs(actual), abs(expected)))`. You can read more about this whole issue in PEP 485.
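The two definitions can be contrasted in a small sketch (the helper names are mine, not library APIs; the formulas follow the definitions quoted above):

```python
def isclose_asymmetric(actual, expected, rtol=1e-5, atol=1e-8):
    # torch.isclose / numpy.isclose style: tolerance scales with |expected| only
    return abs(actual - expected) <= atol + rtol * abs(expected)

def isclose_symmetric(actual, expected, rtol=1e-5, atol=1e-8):
    # math.isclose / PEP 485 style: tolerance scales with the larger magnitude
    return abs(actual - expected) <= max(atol, rtol * max(abs(actual), abs(expected)))

# the asymmetric variant depends on argument order ...
print(isclose_asymmetric(1.0, 2.0, rtol=0.6, atol=0.0))  # True
print(isclose_asymmetric(2.0, 1.0, rtol=0.6, atol=0.0))  # False
# ... the symmetric one does not
print(isclose_symmetric(2.0, 1.0, rtol=0.6, atol=0.0))   # True
```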

At some point we tried to get this behavior specified by the Array API, but couldn't gain enough traction. See data-apis/array-api#170.

> (btw, reporting relative differences for non-float tensors is often pointless as well).

Doesn't that somewhat contradict the use case you gave earlier?

> for the case where expected contains zeros (having zeros is typical, say when comparing the indices of sparse tensors)

pytorchmergebot (Collaborator):

Rebased gh/pmeier/55/orig onto refs/remotes/origin/viable/strict because #96132 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/96131)

mruberry (Collaborator) left a comment.


pmeier commented Mar 17, 2023

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 17, 2023
@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

pytorchmergebot (Collaborator):

Successfully rebased gh/pmeier/54/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/96131)

pytorchmergebot pushed a commit that referenced this pull request Mar 17, 2023
ghstack-source-id: c42b920568ee9fe210a5f8ec42961ca1272c65ca
Pull Request resolved: #96131
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64-mps / test (default, 1, 1)

Details for Dev Infra team Raised by workflow job

test/test_mps.py Outdated
@@ -10070,7 +10070,7 @@ def test_mps_compat(self):
# If this test is successful, that means that all operations in the comparison logic are supported natively on
# the MPS backend. Please remove this test as well as the compatibility logic in
# torch.testing._comparison.TensorLikePair._equalize_attributes
- actual = torch.tensor(1.0, device="mps")
+ actual = torch.zeros(2, 3, 4, 5, device="mps")
pmeier (Collaborator, Author) commented Mar 17, 2023:

@kulinseth I've increased the shape to 4 dimensions here, because otherwise this test would pass although torch.testing.assert_close is not ready. See #95538 for details.

A collaborator replied:

Sounds good

@pmeier pmeier requested a review from kulinseth March 17, 2023 09:03

pmeier commented Mar 20, 2023

@kulinseth It seems the test is still passing: https://hud.pytorch.org/pr/96131#12075596190. Does that mean the behavior was fixed? Otherwise, could you send me a patch that consistently makes this test fail? I don't have access to an MPS machine and don't want to waste CI resources by pushing multiple times just for this one test.


pmeier commented Mar 28, 2023

@kulinseth any update on this?

kulinseth (Collaborator):

> @kulinseth any update on this?

Sorry for the delay, @pmeier. I think we have support up to 4 dims; if we increase the dimensions to 5, then the test starts failing.


pmeier commented Mar 28, 2023

Argh, my bad. Let me fix that.


pmeier commented Mar 29, 2023

@kulinseth The test fails now, but unfortunately the error is not recoverable:

test_mps.py::TestNoRegression::test_mps_compat Assertion failed: (0 <= mpsAxis && mpsAxis < 4 && "Runtime canonicalization must simplify reduction axes to minor 4 dimensions."), function encodeNDArrayOp, file GPUReductionOps.mm, line 76.
Fatal Python error: Aborted

and thus we never hit the xfail. Not sure what to do with this. I'll remove the test to unblock and leave a comment in #95538. LMK if you want to handle it differently.

@pytorch-bot pytorch-bot bot added the ciflow/mps Run MPS tests (subset of trunk) label Mar 29, 2023
pmeier added a commit that referenced this pull request Mar 29, 2023
ghstack-source-id: 1b796fd7695e8ba2673eb05cccf2f7d9174b21bd
Pull Request resolved: #96131

pmeier commented Mar 29, 2023

@pytorchbot merge -r viable/strict

pytorchmergebot (Collaborator):

@pytorchbot successfully started a rebase job. Check the current status here

pytorchmergebot (Collaborator):

Successfully rebased gh/pmeier/54/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/96131)

pytorchmergebot pushed a commit that referenced this pull request Mar 29, 2023
ghstack-source-id: ec7cd022806cea09dfd1cd4e1e91477d4d5dedf4
Pull Request resolved: #96131
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot (Collaborator):

Merge failed

Reason: a GraphQL query (fetching the PR's reviews, check suites, commit authors, changed files, comments, and labels) failed: [{'message': 'Something went wrong while executing your query. Please include 0402:1535:7C69A:100353:6424630D when reporting this issue.'}]

Details for Dev Infra team Raised by workflow job

pmeier added a commit that referenced this pull request Mar 29, 2023
ghstack-source-id: 11844b06eccc59a5eca1d577c2d6538427e74461
Pull Request resolved: #96131

pmeier commented Mar 29, 2023

@pytorchbot merge

pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Labels
ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged module: testing Issues related to the torch.testing module (not tests) open source topic: not user facing topic category
6 participants