Fix fft slow tests #50435

Closed

zasdfgbnm
Collaborator

The failure is:

______________________________________________________________________________________________________ TestCommonCUDA.test_variant_consistency_jit_fft_rfft_cuda_float64 _______________________________________________________________________________________________________
../.local/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py:889: in wrapper
    method(*args, **kwargs)
../.local/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py:889: in wrapper
    method(*args, **kwargs)
../.local/lib/python3.9/site-packages/torch/testing/_internal/common_device_type.py:267: in instantiated_test
    if op is not None and op.should_skip(generic_cls.__name__, name,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.testing._internal.common_methods_invocations.SpectralFuncInfo object at 0x7f7375f9b550>, cls_name = 'TestCommon', test_name = 'test_variant_consistency_jit', device_type = 'cuda', dtype = torch.float64

    def should_skip(self, cls_name, test_name, device_type, dtype):
>       for si in self.skips:
E       TypeError: 'NoneType' object is not iterable

../.local/lib/python3.9/site-packages/torch/testing/_internal/common_methods_invocations.py:186: TypeError

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 12, 2021

💊 CI failures summary and remediations

As of commit cc8a796 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

This comment was automatically generated by Dr. CI.

@mruberry
Collaborator

Would you elaborate on what's going on here, @zasdfgbnm?

cc @peterbell10

It seems like this line should be preventing skips from becoming None?

@mruberry added the triaged label Jan 12, 2021
@zasdfgbnm
Collaborator Author

zasdfgbnm commented Jan 12, 2021

When I run with PYTORCH_TEST_WITH_SLOW=1, test_ops.py fails with the above error for all of these fft ops. It seems that nobody is running slow tests, so nobody found this failure. The cause is in the __init__ function: it should replace skips=None with skips=[] regardless of TEST_WITH_SLOW, but currently it only does so if not TEST_WITH_SLOW, which is wrong.

To reproduce the failure, you can do

export PYTORCH_TEST_WITH_SLOW=1
python test/test_ops.py -v
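
The `__init__` problem described above can be sketched in isolation. This is a minimal stand-in, not the actual PyTorch source: `SkipInfo`, `BuggyInfo`, `FixedInfo`, the `test_with_slow` parameter, and the `test_fn_gradgrad` skip entry are all hypothetical names chosen for illustration.

```python
class SkipInfo:
    """Hypothetical stand-in for a skip entry; matches on test name only."""

    def __init__(self, test_name):
        self.test_name = test_name


class BuggyInfo:
    """Normalizes skips=None only inside the `not test_with_slow` branch."""

    def __init__(self, skips=None, test_with_slow=False):
        if not test_with_slow:
            skips = list(skips) if skips is not None else []
            # Illustrative slow-only skip added in fast mode.
            skips.append(SkipInfo("test_fn_gradgrad"))
        # With test_with_slow=True, skips is still None at this point.
        self.skips = skips

    def should_skip(self, test_name):
        # Iterating over None raises:
        # TypeError: 'NoneType' object is not iterable
        for si in self.skips:
            if si.test_name == test_name:
                return True
        return False


class FixedInfo(BuggyInfo):
    """Normalizes skips=None unconditionally, then adds the fast-mode skip."""

    def __init__(self, skips=None, test_with_slow=False):
        skips = list(skips) if skips is not None else []
        if not test_with_slow:
            skips.append(SkipInfo("test_fn_gradgrad"))
        self.skips = skips
```

In this sketch, `BuggyInfo(test_with_slow=True).should_skip(...)` fails with the same kind of `TypeError` as in the traceback, while `FixedInfo` returns a plain boolean in both modes.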

Collaborator

@mruberry mruberry left a comment

Nice fix, thanks @zasdfgbnm.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 4a3a378.

@malfet
Contributor

malfet commented Jan 13, 2021

It seems that nobody is running slow tests, so nobody found this failure.

Hmm, slow tests should be executed by the slow_test config (see pytorch_linux_xenial_cuda10_2_cudnn7_py3_slow_test, for example), though I wonder why the failures are not reported back.

@malfet malfet deleted the slow-test-fix branch January 13, 2021 15:40
@peterbell10
Collaborator

peterbell10 commented Jan 13, 2021

Ah, I think I see this issue. The CI is running with PYTORCH_TEST_SKIP_FAST but the test was never marked as slow, so it gets skipped on every CI run. I think OpInfo needs a way to directly add the @slowTest decorator to a particular test.

@mruberry I'm thinking of either extending the SkipInfo class to handle slow tests specifically, or introducing a mechanism to add a set of decorators to a particular test. The latter could either replace SkipInfo or coexist with it. What do you think?
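
The interaction described above can be modeled as a small truth table. This is a simplified model under stated assumptions, not PyTorch's test harness: `will_run` and its parameters are hypothetical, and the three rules encode only what this thread says about PYTORCH_TEST_WITH_SLOW, PYTORCH_TEST_SKIP_FAST, and tests marked with @slowTest.

```python
def will_run(*, test_with_slow, skip_fast, marked_slow, skipped_unless_slow):
    """Return True if a test would run under this simplified model.

    test_with_slow:      PYTORCH_TEST_WITH_SLOW=1 is set
    skip_fast:           PYTORCH_TEST_SKIP_FAST=1 is set
    marked_slow:         the test carries the slow-test marker
    skipped_unless_slow: a skip entry is added when slow tests are off
    """
    # A skip entry added only in fast mode hides the test in normal runs.
    if skipped_unless_slow and not test_with_slow:
        return False
    # A test marked slow runs only when slow tests are enabled.
    if marked_slow and not test_with_slow:
        return False
    # With skip-fast on, anything not marked slow is skipped.
    if skip_fast and not marked_slow:
        return False
    return True


# Before: skipped in fast mode, but never marked slow.
before = dict(marked_slow=False, skipped_unless_slow=True)
normal_run = will_run(test_with_slow=False, skip_fast=False, **before)
slow_ci_run = will_run(test_with_slow=True, skip_fast=True, **before)
manual_slow_run = will_run(test_with_slow=True, skip_fast=False, **before)

# After: the test itself is marked slow.
after = dict(marked_slow=True, skipped_unless_slow=False)
normal_run_after = will_run(test_with_slow=False, skip_fast=False, **after)
slow_ci_run_after = will_run(test_with_slow=True, skip_fast=True, **after)
```

Under this model the unmarked test is skipped in both the normal run and the slow CI run, and only a manual run with PYTORCH_TEST_WITH_SLOW=1 alone executes it, which matches how the failure was found.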

@mruberry
Collaborator

I could see either approach working. If SkipInfo can be extended in a natural way, for example with just an active_if=no_slow_tests option, that sounds good. If there are other scenarios where we need to apply decorators to specific tests, then I would create a separate mechanism, although I cannot think of any at the moment.

@peterbell10
Collaborator

I'm leaning towards creating a DecorateInfo class that has a decorators argument, like the main OpInfo argument. The decorators get added while instantiating the test, just before the point where the test is currently wrapped in a functor that queries op.should_skip(...). Then, optionally, SkipInfo can be reimplemented as a DecorateInfo that sets decorators = skipIf(True).

Adding the slowTest decorator requires this same logic anyway, so it seems reasonable to provide a generic interface in case it's needed elsewhere.

If that sounds alright, I'll create a PR.
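
A rough sketch of what this proposal could look like, using only the standard library. Everything here is an assumption for illustration rather than the design that was eventually merged: the `DecorateInfo` constructor arguments, the `is_active` method, the `apply_decorators` helper, and the `skip_info` factory are hypothetical.

```python
import unittest


class DecorateInfo:
    """Hypothetical: decorators to apply to tests matching a class/test name."""

    def __init__(self, decorators, cls_name=None, test_name=None):
        if not isinstance(decorators, (list, tuple)):
            decorators = [decorators]
        self.decorators = list(decorators)
        self.cls_name = cls_name
        self.test_name = test_name

    def is_active(self, cls_name, test_name):
        # None acts as a wildcard for either field.
        return ((self.cls_name is None or self.cls_name == cls_name) and
                (self.test_name is None or self.test_name == test_name))


def apply_decorators(test_fn, infos, cls_name, test_name):
    """Wrap test_fn with every decorator whose DecorateInfo matches."""
    for info in infos:
        if info.is_active(cls_name, test_name):
            for decorator in info.decorators:
                test_fn = decorator(test_fn)
    return test_fn


def skip_info(cls_name=None, test_name=None):
    """A skip expressed as a DecorateInfo carrying an unconditional skip."""
    return DecorateInfo(unittest.skip("Skipped!"), cls_name, test_name)


def sample_test():
    return "ran"


infos = [skip_info("TestGradients", "test_fn_gradgrad")]
wrapped = apply_decorators(sample_test, infos, "TestGradients", "test_fn_gradgrad")
untouched = apply_decorators(sample_test, infos, "TestCommon", "test_variant_consistency_jit")
```

Calling `wrapped()` raises `unittest.SkipTest`, while `untouched` is the original function. A slow-test marker could be passed in the same way as any other decorator, which is the generality the comment above argues for.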

@mruberry
Collaborator

Sounds good.

facebook-github-bot pushed a commit that referenced this pull request Jan 25, 2021
…0501)

Summary:
Follow up to #50435

I have confirmed this works by running
```
pytest test_ops.py -k test_fn_gradgrad_fft
```
both normally and with `PYTORCH_TEST_WITH_SLOW=1 PYTORCH_TEST_SKIP_FAST=1`. In the first case all tests are skipped; in the second they all run as they should.

Pull Request resolved: #50501

Reviewed By: ezyang

Differential Revision: D25956416

Pulled By: mruberry

fbshipit-source-id: c896a8cec5f19b8ffb9b168835f3743b6986dad7
Labels
cla signed, Merged, open source, triaged