
Updated test_torch.py to use new OptimizerInfo infrastructure #125538

Closed
wants to merge 1 commit

Conversation

gambiTarun
Contributor

@gambiTarun gambiTarun commented May 4, 2024

Fixes #123451 (only addresses test_torch.py cases)

This PR updates test_grad_scaling_autocast and test_params_invalidated_with_grads_invalidated_between_unscale_and_step in test/test_torch.py to use the new OptimizerInfo infrastructure.

I have combined tests that call _grad_scaling_autocast_test into one called test_grad_scaling_autocast and used _get_optim_inputs_including_global_cliquey_kwargs to avoid hard-coded configurations.

```
$ lintrunner test/test_cuda.py
ok No lint issues.
```

cc @janeyx99 @crcrpar @vincentqb @jbschlosser @albanD


pytorch-bot bot commented May 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125538

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5499662 with merge base f9a7033:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 4, 2024

linux-foundation-easycla bot commented May 4, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: gambiTarun / name: Tarunbir Gambhir (5499662)

Contributor

@janeyx99 janeyx99 left a comment


thanks! let's see if ci passes

(A good next step after this is to change line 5884 to just optim_db, to test all optims, but that may run into failures and can be in a different PR.)

@jbschlosser jbschlosser added module: optimizer Related to torch.optim triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 7, 2024
@gambiTarun
Contributor Author

Hi @janeyx99, it seems there are some failed jobs. How do I fix them?

```
def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
    optimizer_ctor = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable"))
```
Contributor

Suggested change:

```
-        device, dtype, optim_info, skip=("differentiable"))
+        device, dtype, optim_info, skip=("differentiable",))
```

make it a tuple! right now it’s just a string
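The gotcha the reviewer is pointing out is plain Python, not PyTorch-specific: parentheses alone don't create a tuple, and a skip list that is accidentally a string behaves differently under membership tests. A minimal illustration:

```python
# Parentheses without a trailing comma do NOT create a tuple.
skip = ("differentiable")
print(type(skip).__name__)  # str

# The trailing comma is what makes a one-element tuple.
skip = ("differentiable",)
print(type(skip).__name__)  # tuple

# Why it matters: `in` does substring matching on a str,
# but element matching on a tuple.
print("diff" in "differentiable")     # True
print("diff" in ("differentiable",))  # False
```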

```
self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"fused": True})
optimizer_ctor = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
    device, dtype, optim_info, skip=("differentiable"))
```
Contributor

same here—try running locally first :)

@gambiTarun
Contributor Author

Thank you for the pointer @janeyx99! I couldn't run the test locally due to environment setup issues before. I fixed it now.

This resolves test_params_invalidated_with_grads_invalidated_between_unscale_and_step, but I see test_grad_scaling_autocast failing with the following traceback:

(torch-dev) tarunbirgambhir@Tarunbirs-MBP pytorch % pytest test/test_torch.py -k test_grad_scaling_autocast
========================================================================================= test session starts =========================================================================================
platform darwin -- Python 3.12.3, pytest-8.2.0, pluggy-1.5.0
rootdir: /Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch
configfile: pytest.ini
plugins: hypothesis-6.100.2
collected 1012 items / 1009 deselected / 3 selected                                                                                                                                                   
Running 3 items in this shard

test/test_torch.py FFF                                                                                                                                                                          [100%]

============================================================================================== FAILURES ===============================================================================================
_________________________________________________________________ TestTorchDeviceTypeCPU.test_grad_scaling_autocast_AdamW_cpu_float32 _________________________________________________________________
Traceback (most recent call last):
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_utils.py", line 2759, in wrapper
    method(*args, **kwargs)
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_optimizers.py", line 211, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 1202, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/test/test_torch.py", line 5924, in test_grad_scaling_autocast
    with context():
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 263, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 200, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: AssertionError not raised

To execute this test, run the following from the base repo dir:
     python test/test_torch.py -k test_grad_scaling_autocast_AdamW_cpu_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
_________________________________________________________________ TestTorchDeviceTypeCPU.test_grad_scaling_autocast_Adam_cpu_float32 __________________________________________________________________
Traceback (most recent call last):
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_utils.py", line 2759, in wrapper
    method(*args, **kwargs)
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_optimizers.py", line 211, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 1202, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/test/test_torch.py", line 5924, in test_grad_scaling_autocast
    with context():
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 263, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 200, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: AssertionError not raised

To execute this test, run the following from the base repo dir:
     python test/test_torch.py -k test_grad_scaling_autocast_Adam_cpu_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
__________________________________________________________________ TestTorchDeviceTypeCPU.test_grad_scaling_autocast_SGD_cpu_float32 __________________________________________________________________
Traceback (most recent call last):
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/Users/tarunbirgambhir/miniforge3/envs/torch-dev/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_utils.py", line 2759, in wrapper
    method(*args, **kwargs)
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_optimizers.py", line 211, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_device_type.py", line 1202, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/test/test_torch.py", line 5926, in test_grad_scaling_autocast
    self._run_scaling_case(
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/test/test_torch.py", line 5874, in _run_scaling_case
    self.assertEqual(c.grad, s.grad, atol=atol, rtol=1e-05)
  File "/Users/tarunbirgambhir/Documents/Projects/OpenSourceContributions/pytorch/torch/testing/_internal/common_utils.py", line 3642, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!

Mismatched elements: 3 / 64 (4.7%)
Greatest absolute difference: 0.0015461444854736328 at index (5, 3) (up to 0.001 allowed)
Greatest relative difference: 0.0014592085499316454 at index (5, 5) (up to 1e-05 allowed)

To execute this test, run the following from the base repo dir:
     python test/test_torch.py -k test_grad_scaling_autocast_SGD_cpu_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
======================================================================================= short test summary info =======================================================================================
FAILED [0.0627s] test/test_torch.py::TestTorchDeviceTypeCPU::test_grad_scaling_autocast_AdamW_cpu_float32 - AssertionError: AssertionError not raised
FAILED [0.0217s] test/test_torch.py::TestTorchDeviceTypeCPU::test_grad_scaling_autocast_Adam_cpu_float32 - AssertionError: AssertionError not raised
FAILED [0.1594s] test/test_torch.py::TestTorchDeviceTypeCPU::test_grad_scaling_autocast_SGD_cpu_float32 - AssertionError: Tensor-likes are not close!
================================================================================= 3 failed, 1009 deselected in 1.74s ==================================================================================
(torch-dev) tarunbirgambhir@Tarunbirs-MBP pytorch % 

The tests run fine with the original hard-coded optimizer_kwargs. Looks like some kwargs from _get_optim_inputs_including_global_cliquey_kwargs do not pass the tests.

```
context = contextlib.nullcontext
if optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
    from functools import partial
    context = partial(self.assertRaises, AssertionError)
```
Contributor

It looks like one of the errors is no longer getting raised...which sounds like a good thing!

It'd be good to know which configs do not pass in the failing SGD tests.

Contributor Author
@gambiTarun gambiTarun May 12, 2024

When I remove the if block, i.e., keep context = contextlib.nullcontext for all cases of optimizer_ctor, I get the following errors for the Adam and AdamW optims:

AssertionError: Tensor-likes are not close!

Mismatched elements: 64 / 64 (100.0%)
Greatest absolute difference: 5.1396894454956055 at index (6, 6) (up to 0.001 allowed)
Greatest relative difference: 10.834877967834473 at index (6, 6) (up to 1e-05 allowed)

To execute this test, run the following from the base repo dir:
     python test/test_torch.py -k test_grad_scaling_autocast_AdamW_cpu_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
AssertionError: Tensor-likes are not close!

Mismatched elements: 64 / 64 (100.0%)
Greatest absolute difference: 5.523576259613037 at index (6, 6) (up to 0.001 allowed)
Greatest relative difference: 28.40498161315918 at index (6, 6) (up to 1e-05 allowed)

To execute this test, run the following from the base repo dir:
     python test/test_torch.py -k test_grad_scaling_autocast_Adam_cpu_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

Contributor Author
@gambiTarun gambiTarun May 12, 2024

The Adam and AdamW tests fail when the configs (optim_input.kwargs) are {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True}. This is with the AssertionError context.

The SGD optimizer test fails for optim_input.kwargs == {'weight_decay': 0.1, 'maximize': True, 'fused': True}.

Contributor

The absolute difference of 5 is pretty major and signifies something seems wrong.

The Adam and AdamW tests fail when the configs (optim_input.kwargs) are {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True}. This is with the AssertionError context.

When you say "with the AssertionError context", do you mean with the context enabled? If so, I am not surprised that these two basic configs now pass + fail the assertion error. Or do you mean that these configs are what causes the major errors?

Ah, is that the only SGD config failing? If you comment it out, does the rest of the test pass?

And are these all for CPU?

Contributor Author
@gambiTarun gambiTarun May 13, 2024

Sorry, I wasn't clear about the failing cases.

  1. When the context is partial(self.assertRaises, AssertionError) for Adam and AdamW, the tests fail for configs {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True} with the error AssertionError: AssertionError not raised.
  2. When the context is contextlib.nullcontext for Adam and AdamW, the tests fail for all the configs with the error AssertionError: Tensor-likes are not close!. In this case, I am confused as to why the error is thrown even for the configs where case 1 fails; the mismatched-elements percentage is around 3.1% for {'lr': 0.01, 'fused': False} and {'lr': 0.01, 'fused': True}, but either 39.1% or 100% for the other configs.

For SGD, that is the only config that fails. I am able to pass the test if I skip this config.

All of these tests are for CPU.
I see that the native devices are ('cpu', 'cuda', 'meta'). I understand the cuda tests are skipped since my machine does not have CUDA support. What about the meta device type?

Contributor

Okay, debugging these errors will likely be a more involved task. I propose:

  • opening an issue tracking your observations here about this particular test (what fails, what doesn't when the context is on/off).
  • no longer using the _get_optim_inputs_including_global_cliquey_kwargs to get all the kwargs, but just having the kwargs be parametrized like in https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L689 but with the impl being "forloop", "foreach", "fused".
    • Depending on the impl, you can make the kwargs {"foreach": False}, {"foreach": True}, and {"fused": True} respectively.
  • This way we consolidate into one test but avoid signing up to address all the issues now.

And to answer your q about the meta device, it does input and output shape checking, as the tensors are not actually backed by real storage!

Contributor Author
@gambiTarun gambiTarun May 14, 2024

I can pass the tests locally by skipping the failing configs from the _get_optim_inputs_including_global_cliquey_kwargs.

Should I try pushing it and checking if the CI passes?

I will raise an issue covering the failing cases with the _get_optim_inputs_including_global_cliquey_kwargs configs.

Contributor

I am suggesting not using _get_optim_inputs_including_global_cliquey_kwargs at all, but to just use

```
kwargs = {"foreach": False} if forloop else ...
optim_cls(params, **kwargs)
```

Once you push that change, the tests should pass since it is not adding more coverage compared to before.
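The reviewer's proposal can be sketched as a small mapping from the parametrized impl name to optimizer kwargs (a hedged sketch; impl_to_kwargs is an illustrative helper name, not an API in the PR):

```python
def impl_to_kwargs(impl):
    # Mirror the suggested forloop / foreach / fused parametrization:
    # each impl name selects exactly one optimizer configuration.
    mapping = {
        "forloop": {"foreach": False},
        "foreach": {"foreach": True},
        "fused": {"fused": True},
    }
    if impl not in mapping:
        raise ValueError(f"unknown impl: {impl!r}")
    return mapping[impl]

for impl in ("forloop", "foreach", "fused"):
    print(impl, "->", impl_to_kwargs(impl))
```

This keeps the combined test to three known-good configurations per optimizer, matching the coverage the tests had before the refactor.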

Contributor

But yes, please file the issue!

…grad_scaling_autocast for test coverage as some configs fail, raising separate issue for documenting failed configs
@janeyx99
Contributor

There may be lint failures; please run lintrunner locally before pushing each commit to save time and resources. To install lintrunner:

```
pip install lintrunner
lintrunner init
```

and running it is

```
lintrunner -a
```

on uncommitted changes (i.e., after git add but before git commit).

@gambiTarun
Contributor Author

Hi @janeyx99, I did run lintrunner on this file before committing. There were no issues.

[screenshot: lintrunner output showing no lint issues]

@janeyx99
Contributor

wonderful! thank you

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 18, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
…h#125538)

Fixes pytorch#123451 (only addresses test_torch.py cases)

This PR solves the specific task to update `test_grad_scaling_autocast` and `test_params_invalidated_with_grads_invalidated_between_unscale_and_step` in `test/test_torch.py` to use the new OptimizerInfo infrastructure.

I have combined tests that call `_grad_scaling_autocast_test` into one called `test_grad_scaling_autocast` and used `_get_optim_inputs_including_global_cliquey_kwargs` to avoid hard-coded configurations.

```
$ lintrunner test/test_cuda.py
ok No lint issues.
```

Pull Request resolved: pytorch#125538
Approved by: https://github.com/janeyx99
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: optimizer Related to torch.optim open source topic: not user facing topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Update test_cuda.py and test_torch.py optim tests to use OptimizerInfo and optim_db
5 participants