-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
grad_mode.py
208 lines (157 loc) · 6.62 KB
/
grad_mode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import sys
import torch
import functools
import inspect
from typing import Any, Callable, TypeVar, cast
# Names exported by `from ... import *`.
__all__ = ["no_grad", "enable_grad", "set_grad_enabled"]

# Typing helpers for the decorator form of `no_grad` / `enable_grad`.
# `F` is bound to "any callable", so a signature like `(func: F) -> F` tells a
# type checker that decorating a function hands back the same function type.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses implement ``__enter__`` / ``__exit__``. An instance can then be
    used in a ``with`` statement, or called on a function to decorate it.
    Every decorated call enters a *fresh* instance built with no arguments
    (``self.__class__()`` / ``type(self)()``), so a subclass used as a
    decorator must be constructible without arguments.
    """

    def __call__(self, func: F) -> F:
        """Wrap ``func`` so that each call runs inside this context manager.

        Generator functions are routed to ``_wrap_generator``. Otherwise the
        context would only be active while the generator object is created,
        not while its body runs.

        Args:
            func: the function or generator function to decorate.

        Returns:
            A ``functools.wraps``-decorated wrapper around ``func``.
        """
        if inspect.isgeneratorfunction(func):
            return self._wrap_generator(func)

        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            # A fresh instance per call keeps state saved in `__enter__`
            # (e.g. a previous mode) separate for nested or recursive calls.
            with self.__class__():
                return func(*args, **kwargs)

        return cast(F, decorate_context)

    def _wrap_generator(self, func):
        """Wrap each generator invocation with the context manager.

        The returned generator function drives the wrapped generator by hand.
        It enters a fresh context around every ``send`` / ``throw`` / ``close``
        it forwards, so the context is active only while the wrapped
        generator's own code runs. The context is left again whenever control
        returns to the caller through ``yield``.
        """
        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)

            # Generators are suspended and unsuspended at `yield`, hence we
            # make sure the grad mode is properly set every time the execution
            # flow returns into the wrapped generator and restored when it
            # returns through our `yield` to our caller (see PR #49017).
            cls = type(self)
            try:
                # Issuing `None` to a generator fires it up
                with cls():
                    response = gen.send(None)

                while True:
                    try:
                        # Forward the response to our caller and get its next request
                        request = yield response

                    except GeneratorExit:
                        # Inform the still active generator about its imminent closure
                        with cls():
                            gen.close()
                        raise

                    except BaseException as exc:
                        # Propagate the exception thrown at us by the caller.
                        # Use the single-argument form of `throw`: the
                        # (type, value, traceback) signature is deprecated
                        # since Python 3.12, and the exception instance
                        # already carries its own traceback.
                        with cls():
                            response = gen.throw(exc)

                    else:
                        # Pass the last request to the generator and get its response
                        with cls():
                            response = gen.send(request)

            # We let the exceptions raised above by the generator's `.throw` or
            # `.send` methods bubble up to our caller, except for StopIteration
            except StopIteration as e:
                # The generator informed us that it is done: take whatever its
                # returned value (if any) was and indicate that we're done too
                # by returning it (see docs for python's return-statement).
                return e.value

        return generator_context

    def __enter__(self) -> None:
        # Must be provided by subclasses.
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Must be provided by subclasses.
        raise NotImplementedError
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self):
        # NOTE(review): the `is_scripting()` guard suggests this class is also
        # compiled by TorchScript. Keep the body to constructs TorchScript
        # accepts, and confirm before restructuring it.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        # Give `prev` a concrete bool up front. `__enter__` overwrites it with
        # the grad mode that is active on entry.
        self.prev = False

    def __enter__(self):
        # Remember the current mode so that `__exit__` can restore it.
        self.prev = torch.is_grad_enabled()
        # Uses the public `torch.set_grad_enabled` rather than
        # `torch._C._set_grad_enabled`. Presumably this is for TorchScript
        # compatibility — confirm before changing.
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the mode captured on entry. This runs even when an exception
        # is propagating, and it returns None, so the exception is not
        # suppressed.
        torch.set_grad_enabled(self.prev)
class enable_grad(_DecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
    """
    def __enter__(self) -> None:
        # Save the mode in effect on entry. `prev` is only assigned here, so
        # `__exit__` relies on a matching `__enter__` having run first.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the saved mode. Returns None, so exceptions propagate.
        torch._C._set_grad_enabled(self.prev)
class set_grad_enabled:
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function. The new mode is
    applied as soon as the object is constructed. When it is used as a plain
    function call, the mode simply stays in effect and the returned object can
    be discarded. When it is used in a ``with`` statement, the previous mode
    is restored on exit.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode: bool) -> None:
        # Switch eagerly, here rather than in `__enter__`, so that a plain
        # call `torch.set_grad_enabled(flag)` takes effect immediately.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)

    def __enter__(self) -> None:
        # Nothing to do: `__init__` has already applied the requested mode.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the mode that was active before this object was created.
        torch._C._set_grad_enabled(self.prev)