
TST/MAINT: cluster: use new array API assertions #19251

Merged
merged 8 commits into scipy:main from lucascolley:cluster-assertions on Sep 28, 2023

Conversation

lucascolley
Member

Reference issue

Follow-up to gh-19186 for gh-18668.

What does this implement/fix?

The new assertions xp_assert_close and xp_assert_equal are used in cluster, such that all of our array-API-converted code demonstrates the preferred methods of testing.
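For illustration, a minimal sketch of the kind of conversion involved (made-up values; the helper import path is an assumption rather than a line from the diff):

import numpy as np
from scipy.cluster.vq import whiten
# The helpers were added to SciPy's private test utilities in gh-19186; this
# import path is assumed for illustration.
from scipy._lib._array_api import xp_assert_close

xp = np  # stand-in for the namespace normally supplied by the test fixture

obs = xp.asarray([[1.0, 2.0], [3.0, 6.0]])
res = whiten(obs)  # each column is divided by its standard deviation

# Old style: a plain NumPy assertion that ignores namespace and dtype.
np.testing.assert_allclose(res, [[1.0, 1.0], [3.0, 3.0]])

# New style: the desired value is an array in the same namespace, and the
# assertion also checks the namespace and dtype of the actual result.
xp_assert_close(res, xp.asarray([[1.0, 1.0], [3.0, 3.0]]))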

There are a few other changes in here, like changing assert_equal(correspond(Z, y2), False) to assert not correspond(Z, y2), and some minor PEP 8 bits. I hope that's okay, but please let me know if you'd like those removed.

Additional information

@mdhaber how does this look? python dev.py test -s cluster -b all passes for me, but there are a few tricky bits in here with dtypes, scalars and tolerances which could do with a look over. I think that I've covered everything, but there's a chance that I've missed something, or included an unwanted change.

cc: @tupui

Member

@tupui tupui left a comment


LGTM, thanks Lucas 🚀

I will let Matt have a look

@tupui tupui added the scipy.cluster and array types (array API support and input array validation, see gh-18286) labels Sep 16, 2023
@tupui tupui added this to the 1.12.0 milestone Sep 16, 2023
@tupui
Member

tupui commented Sep 16, 2023

The CI seems to complain on Windows:

FAILED cluster/tests/test_hierarchy.py::test_cut_tree[numpy] - AssertionError: dtypes do not match.
Actual: int64
Desired: int32
FAILED cluster/tests/test_vq.py::TestVq::test_py_vq[numpy] - AssertionError: dtypes do not match.
Actual: int64
Desired: int32

@lucascolley
Member Author

lucascolley commented Sep 16, 2023

The CI seems to complain on Windows:

Okay, looks like we should make sure we have int64 for all platforms then.

Edit: let's see if CI is happy after 9af46c8.
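For illustration, a sketch of the kind of fix meant here (made-up values): pin the expected integer dtype explicitly rather than relying on the platform default, since NumPy's default integer type is int32 on Windows but int64 on most other platforms.

import numpy as np

# Platform-dependent: int32 on Windows, int64 on most Linux/macOS builds.
implicit = np.asarray([0, 0, 1, 1])

# Platform-independent: an explicit dtype keeps the dtype check passing everywhere.
explicit = np.asarray([0, 0, 1, 1], dtype=np.int64)

print(implicit.dtype, explicit.dtype)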

Contributor

@mdhaber mdhaber left a comment


At first glance, I'd think something is not right whenever the first argument of an xp_assert is explicitly converted to something else. Could you comment on those cases?

Using xp.asarray on the second argument is fine. However, there are some places where there is explicit dtype conversion to float64 that doesn't look like it would be needed. For instance, if the calculation of the reference value starts with Python floats, it would be surprising if the dtype of the result were other than float64, and I'd want to know why it was not.

If float32 is supported, are there any tests?
I see some places with integer output. Should these functions ever produce int32, and if so, are any of these cases tested?
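A hypothetical illustration of the convention behind these questions (names and values are made up; the helper import path is an assumption): the assertion takes the actual result first and the desired reference second, and only the reference should need wrapping.

import numpy as np
# Import path assumed for illustration, as above.
from scipy._lib._array_api import xp_assert_close

xp = np                                   # stand-in namespace
res = xp.sqrt(xp.asarray([1.0, 4.0]))     # "actual" result under test
expected = xp.asarray([1.0, 2.0])         # "desired" reference value

# Suspicious pattern: converting the actual result before asserting forces the
# dtypes to match and hides whatever dtype the function really returned.
#     xp_assert_close(xp.asarray(res, dtype=xp.float64), expected)

# Preferred pattern: raw result first, reference second.
xp_assert_close(res, expected)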

@lucascolley
Member Author

A lot of these tests were giving different dtypes (32-bit or 64-bit) for different namespaces. In particular, I remember torch returning some float32s where other namespaces return float64s. Also, as CI has caught, there can be variation in return types between platforms.

If float32 is supported, are there any tests?

All this complexity means that it would probably be a good idea to test every supported dtype, although this seems like quite a bit of work.
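As a rough illustration (not part of this PR), per-dtype coverage could look something like the following parametrized test, with made-up values:

import numpy as np
import pytest
from scipy.cluster.vq import whiten

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_whiten_dtypes(dtype):
    # Run the same check for every floating dtype we claim to support.
    obs = np.asarray([[1.0, 2.0], [3.0, 6.0]], dtype=dtype)
    expected = np.asarray([[1.0, 1.0], [3.0, 3.0]], dtype=dtype)
    np.testing.assert_allclose(whiten(obs), expected, rtol=1e-6)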

something is not right whenever the first argument of an xp_assert is explicitly converted to something else

This is mostly due to the 32- vs 64-bit stuff, but also compounded by the fact that lots of our assertions are called as (desired, actual) rather than (actual, desired).

@mdhaber
Contributor

mdhaber commented Sep 17, 2023

All this complexity means that it would probably be a good idea to test every supported dtype

I wouldn't ask you to add that here. But yes, we should test whatever we claim to support.

compounded by the fact that lots of our assertions are called as (desired, actual) rather than (actual, desired).

Would you mind swapping these?

I don't mean to ask you to do more than you meant to sign up for. But if we're going to update these to use xp_asserts, I think they should conform to all modern standards. Otherwise, I'd rather leave them alone. Of course, that's only because I was asked to review. Another maintainer is welcome to do what they think is right.

@lucascolley
Member Author

if we're going to update these to use xp_asserts, I think they should conform to all modern standards.

I completely agree 👍. No worries, I am happy to take this on. I might not be able to get round to this immediately, but hopefully in the next week or two.

There is going to be a lot of work to upgrade the tests for larger submodules to meet these standards. I will add it on to the list of follow-ups for fft, but I can't guarantee that I'll be able to get that done any time soon.

This work is definitely worthwhile in my opinion, since future changes to the namespaces we work with are likely to cause dtype issues. Having these tests in place should make the DX a lot better down the line.

Perhaps the solution for now will include a lot of check_dtype=False until work is done to document the expected dtypes returned by each function. A lot of our tests just take Python lists as inputs, which is where I suspect a lot of the deviation in return types between namespaces is coming from.
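To illustrate the suspicion about list inputs (assumes PyTorch is installed; values are made up): the same Python list is inferred as float64 by NumPy but float32 by torch, so any test that builds its expected value from a list inherits the library default.

import numpy as np
import torch

print(np.asarray([1.0, 2.0]).dtype)     # float64 (NumPy default)
print(torch.asarray([1.0, 2.0]).dtype)  # torch.float32 (torch default)

# Until the expected dtype of each function is documented, such comparisons
# can be relaxed, e.g. xp_assert_close(actual, desired, check_dtype=False).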

Member Author

@lucascolley lucascolley left a comment


@mdhaber I've remade this PR with the changes to the assertions, namespaces and dtypes split into separate commits now - hopefully that makes it easier to review.

I've fixed the zeros issue you pointed out above. For the rest of the dtype problems, I've added check_dtype=False for now, and commented inline with exactly what the mismatches are. Please could you point me in the right direction for these? Quite a few seem to be torch defaulting to float32. I'm hesitant to make any more changes since I'm not sure what the desired way to fix the issues is.

Alternatively, I would be happy for this to merge and to open an issue for the dtype discrepancies to be sorted out separately.


scipy/cluster/tests/test_vq.py (diff excerpt):

 @skip_if_array_api_gpu
 @array_api_compatible
 def test_kmeans_diff_convergence(self, xp):
     # Regression test for gh-8727
     obs = xp.asarray([-3, -1, 0, 1, 1, 8], dtype=xp.float64)
     res = kmeans(obs, xp.asarray([-3., 0.99]))
-    assert_allclose(res[0], xp.asarray([-0.4, 8.]))
-    assert_allclose(res[1], 1.0666666666666667)
+    xp_assert_close(res[0], xp.asarray([-0.4, 8.]), check_dtype=False)
Member Author


res[0] is float64, but xp.asarray([-0.4, 8.]) is float32 for torch.

Member


Same here, adding an explicit dtype=xp.float64 for the "expected" value seems more robust, and less likely to leave the future reader wondering why dtypes can't be checked here.
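For concreteness, a sketch of that suggestion applied to the snippet above (decorators and imports as in the existing test file; not necessarily the exact change that landed):

@skip_if_array_api_gpu
@array_api_compatible
def test_kmeans_diff_convergence(self, xp):
    # Regression test for gh-8727
    obs = xp.asarray([-3, -1, 0, 1, 1, 8], dtype=xp.float64)
    res = kmeans(obs, xp.asarray([-3., 0.99]))
    # With the expected value pinned to float64, check_dtype=False is no longer needed.
    xp_assert_close(res[0], xp.asarray([-0.4, 8.], dtype=xp.float64))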

Member Author


This is what I thought the solution was (seems pretty simple), but Matt seemed unsure:

One of the main dtype issues is just that the default dtype of torch, even when a double is provided, is single precision. Not sure what the best way to handle that is.

I didn't want to make any tests too strict by requiring a specific dtype, where we may be okay with a different one being returned as well.

If you are happy to just say 'expect what we have currently to test for regressions' then I can include that here 👍

Member


The test requiring res to have float64 dtype when the input to kmeans is also float64 seems like a perfectly reasonable requirement and good to test. I'd probably consider anything else a bug.

Contributor

@mdhaber mdhaber Sep 27, 2023


Oops. The main thing I meant to disagree with was changing the dtype of the actual result.
Before, it looked like dtypes were being specified wherever it was needed to force the test to pass. See #19251 (comment).

When the torch default being float32 came up (#19251 (comment)), the suggestion ended up being to force the dtype of the expected result to be float64 - but we did it where the array was created rather than at each place it was used. In retrospect, I suppose the original code would have been fine in that case. There was some added confusion because I didn't know that torch would default to float32 even when a double was provided.

Member


No worries. It's also not just PyTorch - all deep-learning-focused libraries default to float32, because that's much better supported on GPU/TPU, and more than enough precision for deep learning.
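A quick way to see this (assumes PyTorch is installed):

import torch

print(torch.get_default_dtype())   # torch.float32
print(torch.tensor(1.0).dtype)     # torch.float32, even though 1.0 is a Python double

# The default can be changed globally if double precision is wanted:
torch.set_default_dtype(torch.float64)
print(torch.tensor(1.0).dtype)     # torch.float64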

Contributor


Right. I knew it was useful, I just didn't know it was the default.

Here was the overall comment btw. #19251 (review)

Contributor

@mdhaber mdhaber left a comment


This looks much better. It is fine with me to merge as-is (maybe with the one change discussed about that really loose tolerance) and open an issue about dtypes.

One of the main dtype issues is just that the default dtype of torch, even when a double is provided, is single precision. Not sure what the best way to handle that is. But it doesn't need to be done here.

@lucascolley
Member Author

@tupui see Matt's comment above. I can open the issue for dtypes once this merges (assuming the rest still looks okay). We should leave my inline comments related to the dtypes unresolved so that they can be referenced from the issue.

@rgommers
Member

This looks quite good to me, only two minor comments.

@lucascolley
Member Author

lucascolley commented Sep 27, 2023

@rgommers e09ce5e has removed most of the check_dtype=False uses by specifying that we expect whatever dtype we currently output.

They remain in one test, where CI was catching a variation between int32 and int64 across platforms, which I have noted in a code comment. Let's see if CI catches anything else.

Edit: CI has caught a few more int32 vs int64 places. I'll revert to check_dtype=False unless this is actually a bug with a quick fix.

@lucascolley
Member Author

Only the Linux Meson tests / Linux - 32 bit (pull_request) job is failing on CI, due to outputting int32s. Turning off the dtype checks now, then hopefully this will be good to go.

Member

@rgommers rgommers left a comment


Most tests were tightened and are passing, and comments were added for the few remaining check_dtype=False instances. So this LGTM now - let's give this a go. Thanks @lucascolley, @mdhaber and @tupui!

@rgommers rgommers merged commit 45e875d into scipy:main Sep 28, 2023
22 of 23 checks passed
@lucascolley lucascolley deleted the cluster-assertions branch September 28, 2023 10:57
@lucascolley
Member Author

Opened gh-19319 to document the checks which are turned off, as requested.
