
Conversation

@CaoE (Collaborator) commented Jul 27, 2023

Add backward check for test_memory_format.

@pytorch-bot bot commented Jul 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106104

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit b2924b1 with merge base a0cfaf0:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@CaoE CaoE force-pushed the add_backward_check_cl branch from b98b481 to ba4e739 Compare July 27, 2023 04:54
@CaoE CaoE added the topic: not user facing, ciflow/trunk, and ciflow/periodic labels Jul 27, 2023
@CaoE CaoE force-pushed the add_backward_check_cl branch 9 times, most recently from 32a7f8b to ba6de4b Compare July 31, 2023 02:59
@CaoE CaoE requested review from jgong5 and mingfeima August 1, 2023 01:20
@mikaylagawarecki mikaylagawarecki self-requested a review August 1, 2023 02:47
@jgong5 (Collaborator) left a comment

For the cases where the backward memory format checks are disabled, are they issues that need to be fixed?

@CaoE (Collaborator, Author) commented Aug 1, 2023

> For the cases where the backward memory format checks are disabled, are they issues that need to be fixed?

I will collect such issues. If there are no related existing issues, I will create corresponding GitHub issues.

@CaoE (Collaborator, Author) commented Aug 15, 2023

@mikaylagawarecki For the test breakages on MPS and CUDA, I created corresponding issues #107214, #107199, and #107201.

I have no Mac machine to reproduce the MPS test breakage, so I'm waiting for replies on whether #107214 is a real issue. After that, this PR should be ready.
From my perspective, the code is ready to be reviewed. Would you mind reviewing this draft PR first?

@mikaylagawarecki (Contributor) left a comment

@CaoE Thank you very much for your hard work with this PR as well as filing issues for the failing tests!!

Regarding the failing Mac test, I think the issue you filed is sufficient, and it is OK to skip this test when landing this PR.

Separately, what do you think of removing the corresponding tests in test_nn.py?

if isinstance(desired_outputs, torch.Tensor):
    desired_outputs = (desired_outputs,)
# === Do backward pass. ===
ref_diff_outputs = tuple(t for t in desired_outputs if _req_grad(t))
@mikaylagawarecki (Contributor) commented Aug 15, 2023

nit: This is an edge case, but this line might not be modular if something in desired_outputs is a TensorList (I'm not sure whether that is ever the case though); we could use pytree.tree_flatten instead of _traverse_object. Similarly below for diff_outputs.

@CaoE (Collaborator, Author) replied

Used pytree.tree_flatten instead.
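
For context, a minimal sketch of what that looks like (hand-written for illustration, not the PR's exact code; the sample nested outputs and the _req_grad helper below are assumptions):

import torch
import torch.utils._pytree as pytree

def _req_grad(t):
    # assumed helper: tensors that require grad are treated as differentiable outputs
    return isinstance(t, torch.Tensor) and t.requires_grad

# desired_outputs may contain nested TensorLists; tree_flatten handles that uniformly
desired_outputs = (torch.randn(2, 3, requires_grad=True),
                   [torch.randn(2, 3), torch.randn(2, 3, requires_grad=True)])
flat_outputs, _ = pytree.tree_flatten(desired_outputs)
ref_diff_outputs = tuple(t for t in flat_outputs if _req_grad(t))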


if (
    input_mem_format != torch.contiguous_format
    or module_mem_formats != torch.contiguous_format
@mikaylagawarecki (Contributor) commented

The second check will always be True here because we are checking the list instead of the current item.

Suggested change:
-    or module_mem_formats != torch.contiguous_format
+    or module_mem_format != torch.contiguous_format

@CaoE (Collaborator, Author) replied

Comment on lines 743 to 749
grad_outputs = tuple(
    torch.rand_like(t)
    for t in diff_outputs
)
grad_outputs = tuple(
    t1.copy_(t2)
    for (t1, t2) in zip(grad_outputs, ref_grad_outputs)
)
@mikaylagawarecki (Contributor) commented

nit: maybe this would be a bit cleaner but up to you!

Suggested change:
-    grad_outputs = tuple(
-        torch.rand_like(t)
-        for t in diff_outputs
-    )
-    grad_outputs = tuple(
-        t1.copy_(t2)
-        for (t1, t2) in zip(grad_outputs, ref_grad_outputs)
-    )
+    grad_outputs = tuple(
+        torch.empty_like(t1).copy_(t2)
+        for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
+    )

@CaoE (Collaborator, Author) replied

It is cleaner. Fixed as suggested.

    ref_diff_outputs,
    ref_diff_inputs,
    grad_outputs=ref_grad_outputs,
    allow_unused=True,
@mikaylagawarecki (Contributor) commented

for my understanding, why do we set allow_unused=True?

@CaoE (Collaborator, Author) replied

Thanks for your comments. I removed allow_unused=True.

ModuleInfo(torch.nn.AdaptiveAvgPool2d,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           # Fails on backward check if output size is 1x1
           gradcheck_memformat=False,
@mikaylagawarecki (Contributor) commented

Really appreciate the detailed comments here regarding why each of these is set to False!

nit: Separately, is there any way we could make this an xfail? I'm hoping that if these are fixed, whoever sends the fixing PR will get a signal to un-xfail these tests. I worry that if we make this a flag, re-enabling the tests might fall through the cracks.

I'm thinking maybe something to the effect of

DecorateInfo(
    unittest.expectedFailure, 'TestModule',
    'test_memory_format',
    active_if=lambda p: p['training']
)

This seems reasonable to me because the non-backward version will still be tested if training=False and we can change the check here to if training and len(ref_diff_outputs) > 0

Let me know your thoughts on this!

@CaoE (Collaborator, Author) replied

Thanks for your suggestion! It's a good idea. I just found that there is active_if, which lets a DecorateInfo be activated according to the input parameters of the test.
Modified as suggested.
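
For reference, the resulting entry would look roughly like this (a sketch assembled from the snippets above; the exact skips= placement and formatting in the PR may differ):

ModuleInfo(torch.nn.AdaptiveAvgPool2d,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           skips=(
               # Fails on backward check if output size is 1x1
               DecorateInfo(
                   unittest.expectedFailure,
                   'TestModule',
                   'test_memory_format',
                   active_if=lambda p: p['training'],
               ),)
           ),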

@mikaylagawarecki (Contributor) commented Aug 15, 2023

One more thing -- do you think it makes sense to extend this to test that the channels_last params/buffers of the model also have gradients of the correct memory format? This could be a followup PR (and if you would prefer, I could send that PR instead, but I'm curious to get your thoughts as a developer working on channels_last_3d).

I was looking into #107199, and the complexity of the code paths / amount of branching made me wonder whether we might have silent correctness issues for the memory format of gradients of params/buffers as well.
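
The kind of check being discussed might look roughly like this (a hand-written sketch, not part of this PR; the module and shapes are arbitrary):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
x = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)
conv(x).sum().backward()
# the question: does the parameter gradient preserve the channels_last layout?
print(conv.weight.grad.is_contiguous(memory_format=torch.channels_last))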

@CaoE CaoE force-pushed the add_backward_check_cl branch from 499aaf6 to d2e84d4 Compare August 16, 2023 03:55
DecorateInfo(skipMPS),)
DecorateInfo(skipMPS),
# Fails on backward check if output size is 1x1
DecorateInfo(
@mikaylagawarecki (Contributor) commented

Seems like the issue is that this gives an unexpected success when run using inductor. Do you know what the issue is here for 1x1 outputs in eager mode? Would it be possible to fix it?

Otherwise I am okay with skipping the test and filing an issue

@CaoE (Collaborator, Author) replied

For 1x1 outputs, mean is applied on CPU instead of AdaptiveAvgPool2d, so the grad will be channels-first. At first I added expectedFailure for CPU, but it still fails on CUDA. I will check this further.
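
For illustration, the behavior can be reproduced with something like this (a rough CPU sketch, not the test code itself):

import torch

x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last).requires_grad_()
torch.nn.AdaptiveAvgPool2d(1)(x).sum().backward()
# with a 1x1 output the kernel falls back to mean, so the input grad may come
# back contiguous (channels-first) instead of channels_last
print(x.grad.is_contiguous(memory_format=torch.channels_last))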

@CaoE (Collaborator, Author) replied

For the unexpected success when run using inductor, please see #107861.

@mikaylagawarecki (Contributor) commented

Hm, how come this doesn't give an unexpected success in CI on this PR anymore?

@CaoE (Collaborator, Author) commented Aug 25, 2023

This is because I added:

# See https://github.com/pytorch/pytorch/issues/107861
# When inductor tests are turned on, the setting of requires_grad will be lost
for t1, t2 in zip(
    torch.utils._pytree.tree_flatten(d_args)[0],
    torch.utils._pytree.tree_flatten(module_input.forward_input.args)[0],
):
    t1.requires_grad_(t2.requires_grad)
for t1, t2 in zip(
    torch.utils._pytree.tree_flatten(d_kwargs)[0],
    torch.utils._pytree.tree_flatten(module_input.forward_input.kwargs)[0],
):
    # restore requires_grad on kwargs, mirroring the args loop above
    t1.requires_grad_(t2.requires_grad)

When inductor is turned on, this succeeds because backwards are not executed.

@mikaylagawarecki (Contributor) commented

Oh thank you! I couldn't see the diff because it was force-pushed; merging now.

@CaoE CaoE force-pushed the add_backward_check_cl branch 6 times, most recently from edb86f1 to e088221 Compare August 20, 2023 13:43
@huydhn (Contributor) commented Aug 22, 2023

@pytorchbot drci

@CaoE CaoE force-pushed the add_backward_check_cl branch from e088221 to 20ab3e0 Compare August 24, 2023 10:29
@mikaylagawarecki (Contributor) commented

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented

Successfully rebased add_backward_check_cl onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add_backward_check_cl && git pull --rebase)

@CaoE (Collaborator, Author) commented Aug 25, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot (Collaborator) commented

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / linux-focal-rocm5.6-py3.8 / test (distributed, 1, 2, linux.rocm.gpu)

Details for Dev Infra team (raised by workflow job).

@CaoE CaoE force-pushed the add_backward_check_cl branch from 0beb7a6 to b2924b1 Compare August 25, 2023 05:11
@CaoE (Collaborator, Author) commented Aug 25, 2023

> @CaoE thanks again for adding this test. Following up,
>
>   1. I will do an in-depth review of add channel last 3d support for maxpool3d on CPU #97775 (Could you update that to make sure this test is running on that PR)
>   2. Would you be interested in removing the related test_nn.py tests?

@mikaylagawarecki Sorry for the slow reply. I've been occupied by some urgent tasks recently.
I may not have much time to do this in the short term, but I can gradually remove these tests in later PRs.

@CaoE (Collaborator, Author) commented Aug 25, 2023

@mikaylagawarecki (Contributor) commented

@pytorchbot merge -f "macos test_multilayer_var failures are unrelated"

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
Add backward check for test_memory_format.

Pull Request resolved: #106104
Approved by: https://github.com/mikaylagawarecki
Labels
ciflow/inductor, ciflow/periodic, ciflow/trunk, Merged, open source, Reverted, topic: not user facing