
Fix inference_mode decorator #68617

Closed

wants to merge 1 commit into from

Conversation

@milesial (Contributor) commented Nov 18, 2021

This fixes the case where `torch.inference_mode` is called with `mode=False` (disabled): when used as a decorator, it ignored the argument and enabled inference mode anyway.

`_DecoratorContextManager` is changed so that the decorated function enters a copy of the original instance instead of a new instance created with default parameters.

I also added more tests to cover this case.

Current behaviour:

```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
False
```

New behaviour (fixed):

```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
True
```
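The root cause is in `_DecoratorContextManager.__call__` (see the diff further down): the decorator wrapper re-enters the context manager by instantiating `self.__class__()` with no arguments, so any constructor state is lost. A simplified sketch of the pre-fix logic (class body trimmed to the relevant method):

```python
import functools

class _DecoratorContextManager:
    """Allow a context manager to also be used as a decorator (simplified)."""

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            # Bug: a brand-new instance is built with *default* constructor
            # arguments, so inference_mode(mode=False) is re-created here as
            # inference_mode(), i.e. mode=True.
            with self.__class__():
                return func(*args, **kwargs)
        return decorate_context
```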

pytorch-probot bot commented Nov 18, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/milesial/pytorch/blob/6c2acd94c4403739240af6a5f218d3bff4f89d5e/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

**Triggered Workflows**

| Workflow | Labels (bold enabled) | Status |
| --- | --- | --- |
| linux-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla | ✅ triggered |
| linux-docs | ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux | ✅ triggered |
| linux-vulkan-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile | ✅ triggered |
| linux-xenial-py3.6-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers | ✅ triggered |
| linux-xenial-py3.6-clang7-onnx | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx | ✅ triggered |
| linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/win | ✅ triggered |

**Skipped Workflows**

| Workflow | Labels | Status |
| --- | --- | --- |
| caffe2-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| docker-builds | ciflow/all | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| ios-12-5-1-x86-64-full-jit | ciflow/all, ciflow/ios, ciflow/macos | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow | 🚫 skipped |
| linux-docs-push | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos | 🚫 skipped |
| parallelnative-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| periodic-libtorch-linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-bionic-cuda11.5-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |

You can add a comment to the PR and tag @pytorchbot with the following commands:

```
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow
```

For more information, please take a look at the CI Flow Wiki.

facebook-github-bot (Contributor) commented Nov 18, 2021

💊 CI failures summary and remediations

As of commit 6c2acd9 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@milesial (Contributor, Author) commented:
Not sure why the Windows build is failing.

@dagitses added the `triaged` label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Nov 19, 2021
```diff
@@ -24,7 +24,7 @@ def __call__(self, func: F) -> F:

         @functools.wraps(func)
         def decorate_context(*args, **kwargs):
-            with self.__class__():
+            with deepcopy(self):
```
Collaborator:

Could you detail how `deepcopy` is different from creating a new instance here?
In particular, this wrapper works just fine for `set_grad_enabled()`, which has the same API as `inference_mode`.

Contributor (Author):

`deepcopy` will also copy the internal state. This is only useful for `inference_mode`, which has the `mode` attribute (and parameter to `__init__`). If we create a new instance from there (with no arguments), then `inference_mode` will use the default mode and not the one from `self`, which results in the described bug.

For `no_grad`, `enable_grad` and `set_grad_enabled` this does not change the current behaviour.

I guess `copy` instead of `deepcopy` could also work if you see any drawback to `deepcopy`.
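A quick illustration of the difference (a sketch relying on `mode` being stored as an attribute, as described above; this is not the actual test added in this PR):

```python
from copy import deepcopy

import torch

cm = torch.inference_mode(mode=False)

# Re-instantiating with default arguments loses the constructor state:
fresh = cm.__class__()   # same as torch.inference_mode(), i.e. mode=True

# Copying preserves it:
copied = deepcopy(cm)
assert copied.mode is False
assert fresh.mode is True
```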

Collaborator:

Oh, `set_grad_enabled` is not a subclass of this, so it would have the same issue if it were.

Collaborator:

But I'm not sure the `deepcopy` will work fine here, as you don't call the `__init__`.
Maybe we want to have a method that each subclass defines to clone itself?

Contributor (Author):

It does work fine for the other current subclasses (`no_grad` and `enable_grad`) because they either don't have internal state, or if they do, it is set during `__enter__`, which gets called anyway.

If you expect that in the future there will be another subclass which does important things in `__init__` and not in `__enter__`, then adding `clone` would be a good option too (by default, it would return a new instance with `self.__class__()`, and `inference_mode` would override that).
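A minimal sketch of what that `clone` approach could look like (simplified; names taken from this discussion rather than the exact merged code):

```python
import functools

class _DecoratorContextManager:
    def clone(self):
        # Default: a fresh instance with default arguments; subclasses
        # that carry constructor state should override this.
        return self.__class__()

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            # Each call enters its own context manager, but with the
            # state that was chosen at decoration time.
            with self.clone():
                return func(*args, **kwargs)
        return decorate_context


class inference_mode(_DecoratorContextManager):
    def __init__(self, mode=True):
        self.mode = mode

    def clone(self):
        # Preserve mode=False (or True) across decorated calls.
        return self.__class__(self.mode)
```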

Contributor (Author):

A benefit of the `clone` method solution would also be that we could move `set_grad_enabled` to a subclass of `_DecoratorContextManager` so that it can also be used as a decorator (see the sketch below). Let me know if that is preferred.
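A hypothetical sketch of that idea, building on the `clone` sketch above and assuming `set_grad_enabled` keeps its current semantics of flipping grad mode already in `__init__` (this refactor is not part of this PR; `torch._C._set_grad_enabled` is the internal switch used by the existing implementation):

```python
import torch

class set_grad_enabled(_DecoratorContextManager):
    def __init__(self, mode: bool):
        # The existing class flips grad mode immediately on construction,
        # so remember both the requested and the previous mode.
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __enter__(self):
        pass  # grad mode was already set in __init__

    def __exit__(self, *exc_info):
        torch._C._set_grad_enabled(self.prev)

    def clone(self):
        # Re-apply the mode requested at decoration time on every call.
        return self.__class__(self.mode)
```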

Collaborator:

Sorry for the delay, I was out of the office.
I think the clone could be a nice thing here indeed. Do you think you can update this PR to add the clone method on the main class?

Contributor (Author):

Done. Also squashed and rebased on master.

@albanD (Collaborator) left a comment:

Awesome, thanks!

facebook-github-bot (Contributor) commented:
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

desertfire pushed a commit that referenced this pull request Dec 13, 2021
Summary:
This fixes the case when `torch.inference_mode` is called with `mode=False` (disabled). When used as a decorator, it ignored the argument and enabled inference mode anyway.

`_DecoratorContextManager` is changed so that a new instance is a copy instead of a new instance with default parameters.

I also added more tests to cover this case.

Current behaviour:

```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
False
```

New behaviour (fixed):

```python
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> @torch.inference_mode(mode=False)
... def func(x):
...     return x * x
...
>>> out = func(x)
>>> out.requires_grad
True
```

Pull Request resolved: #68617

Reviewed By: mrshenli

Differential Revision: D32958434

Pulled By: albanD

fbshipit-source-id: 133c69970ef8bffb9fc9ab5142dedcffc4c32945
desertfire pushed a commit that referenced this pull request Dec 14, 2021
Labels
cla signed, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

5 participants