# adjoint.py
import warnings
import torch
import torch.nn as nn
from .odeint import SOLVERS, odeint
from .misc import _check_inputs, _flat_to_shape, _mixed_norm, _all_callback_names, _all_adjoint_callback_names


class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
                adjoint_options, t_requires_grad, *adjoint_params):

        ctx.shapes = shapes
        ctx.func = func
        ctx.adjoint_rtol = adjoint_rtol
        ctx.adjoint_atol = adjoint_atol
        ctx.adjoint_method = adjoint_method
        ctx.adjoint_options = adjoint_options
        ctx.t_requires_grad = t_requires_grad
        ctx.event_mode = event_fn is not None

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

            if event_fn is None:
                y = ans
                ctx.save_for_backward(t, y, *adjoint_params)
            else:
                event_t, y = ans
                ctx.save_for_backward(t, y, event_t, *adjoint_params)

        return ans
    @staticmethod
    def backward(ctx, *grad_y):
        with torch.no_grad():
            func = ctx.func
            adjoint_rtol = ctx.adjoint_rtol
            adjoint_atol = ctx.adjoint_atol
            adjoint_method = ctx.adjoint_method
            adjoint_options = ctx.adjoint_options
            t_requires_grad = ctx.t_requires_grad

            # Backprop as if integrating up to event time.
            # Does NOT backpropagate through the event time.
            event_mode = ctx.event_mode
            if event_mode:
                t, y, event_t, *adjoint_params = ctx.saved_tensors
                _t = t
                t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
                grad_y = grad_y[1]
            else:
                t, y, *adjoint_params = ctx.saved_tensors
                grad_y = grad_y[0]

            adjoint_params = tuple(adjoint_params)

            ##################################
            #      Set up initial state      #
            ##################################

            # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
            aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
            aug_state.extend([torch.zeros_like(param) for param in adjoint_params])  # vjp_params

            ##################################
            #    Set up backward ODE func    #
            ##################################

            # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
            def augmented_dynamics(t, y_aug):
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                y = y_aug[1]
                adj_y = y_aug[2]
                # ignore gradients wrt time and parameters

                with torch.enable_grad():
                    t_ = t.detach()
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    # If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
                    # doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
                    # wrt t here means we won't compute that if we don't need it.
                    func_eval = func(t if t_requires_grad else t_, y)

                    # Workaround for PyTorch bug #39784
                    _t = torch.as_strided(t, (), ())  # noqa
                    _y = torch.as_strided(y, (), ())  # noqa
                    _params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params)  # noqa

                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)
            # Add adjoint callbacks
            for callback_name, adjoint_callback_name in zip(_all_callback_names, _all_adjoint_callback_names):
                try:
                    callback = getattr(func, adjoint_callback_name)
                except AttributeError:
                    pass
                else:
                    setattr(augmented_dynamics, callback_name, callback)

            ##################################
            #       Solve adjoint ODE        #
            ##################################

            if t_requires_grad:
                time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
            else:
                time_vjps = None
            for i in range(len(t) - 1, 0, -1):
                if t_requires_grad:
                    # Compute the effect of moving the current time measurement point.
                    # We don't compute this unless we need to, to save some computation.
                    func_eval = func(t[i], y[i])
                    dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
                    aug_state[0] -= dLd_cur_t
                    time_vjps[i] = dLd_cur_t

                # Run the augmented system backwards in time.
                aug_state = odeint(
                    augmented_dynamics, tuple(aug_state),
                    t[i - 1:i + 1].flip(0),
                    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
                )
                aug_state = [a[1] for a in aug_state]  # extract just the t[i - 1] value
                aug_state[1] = y[i - 1]  # update to use our forward-pass estimate of the state
                aug_state[2] += grad_y[i - 1]  # update any gradients wrt state at this time point

            if t_requires_grad:
                time_vjps[0] = aug_state[0]

            # Only compute gradient wrt initial time when in event handling mode.
            if event_mode and t_requires_grad:
                time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])

            adj_y = aug_state[2]
            adj_params = aug_state[3:]

        return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
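

# Background note (informal): `OdeintAdjointMethod.backward` above implements the
# continuous adjoint sensitivity method. Writing a(s) for the adjoint state (the
# `vjp_y` slot of `aug_state`), the augmented system returned by
# `augmented_dynamics` is, roughly,
#
#          [ vjp_t ]     [ -a(s)^T df/ds     ]
#     d/ds [ y     ]  =  [  f(s, y)          ]
#          [ vjp_y ]     [ -a(s)^T df/dy     ]
#          [ vjp_p ]     [ -a(s)^T df/dtheta ]
#
# integrated backwards over each interval [t[i - 1], t[i]], with grad_y[i - 1]
# added to vjp_y at every observation time. The vector-Jacobian products are
# supplied by torch.autograd.grad with grad_outputs=-adj_y.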


def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None,
                   adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None):

    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if adjoint_params is None and not isinstance(func, nn.Module):
        raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they '
                         'can be specified explicitly via the `adjoint_params` argument. If there are no parameters '
                         'then it is allowable to set `adjoint_params=()`.')

    # Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options)
    if adjoint_rtol is None:
        adjoint_rtol = rtol
    if adjoint_atol is None:
        adjoint_atol = atol
    if adjoint_method is None:
        adjoint_method = method

    if adjoint_method != method and options is not None and adjoint_options is None:
        raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as "
                         "`options` has been passed then `adjoint_options` must be passed as well.")

    if adjoint_options is None:
        adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {}
    else:
        # Avoid in-place modifying a user-specified dict.
        adjoint_options = adjoint_options.copy()

    if adjoint_params is None:
        adjoint_params = tuple(find_parameters(func))
    else:
        adjoint_params = tuple(adjoint_params)  # in case adjoint_params is a generator.

    # Filter params that don't require gradients.
    oldlen_ = len(adjoint_params)
    adjoint_params = tuple(p for p in adjoint_params if p.requires_grad)
    if len(adjoint_params) != oldlen_:
        # Some params were excluded.
        # Issue a warning if a user-specified norm is specified.
        if 'norm' in adjoint_options and callable(adjoint_options['norm']):
            warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be "
                          "excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.")

    # Convert to flattened state.
    shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)

    # Handle the adjoint norm function.
    state_norm = options["norm"]
    handle_adjoint_norm_(adjoint_options, shapes, state_norm)

    ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
                                    adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)

    if event_fn is None:
        solution = ans
    else:
        event_t, solution = ans
        event_t = event_t.to(t)
        if decreasing_time:
            event_t = -event_t

    if shapes is not None:
        solution = _flat_to_shape(solution, (len(t),), shapes)

    if event_fn is None:
        return solution
    else:
        return event_t, solution
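

# Illustrative example (the names `A`, `y0` and `t` below are assumptions made
# for the sketch, not part of this module): if `func` is a plain callable
# rather than an nn.Module, the tensors to differentiate through must be named
# explicitly via `adjoint_params`, e.g.
#
#     A = torch.randn(2, 2, requires_grad=True)
#     ys = odeint_adjoint(lambda t, y: y @ A, y0, t, adjoint_params=(A,))
#
# or `adjoint_params=()` if there are no parameters at all.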


def find_parameters(module):

    assert isinstance(module, nn.Module)

    # If called within DataParallel, parameters won't appear in module.parameters().
    if getattr(module, '_is_replica', False):

        def find_tensor_attributes(module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad]
            return tuples

        gen = module._named_members(get_members_fn=find_tensor_attributes)
        return [param for _, param in gen]
    else:
        return list(module.parameters())


def handle_adjoint_norm_(adjoint_options, shapes, state_norm):
    """In-place modifies the adjoint options to choose or wrap the norm function."""

    # This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs.
    def default_adjoint_norm(tensor_tuple):
        t, y, adj_y, *adj_params = tensor_tuple
        # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
        return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params))

    if "norm" not in adjoint_options:
        # `adjoint_options` was not explicitly specified by the user. Use the default norm.
        adjoint_options["norm"] = default_adjoint_norm
    else:
        # `adjoint_options` was explicitly specified by the user...
        try:
            adjoint_norm = adjoint_options['norm']
        except KeyError:
            # ...but they did not specify the norm argument. Back to plan A: use the default norm.
            adjoint_options['norm'] = default_adjoint_norm
        else:
            # ...and they did specify the norm argument.
            if adjoint_norm == 'seminorm':
                # They told us they want to use seminorms. Slight modification to plan A: use the default norm,
                # but ignore the parameter state.
                def adjoint_seminorm(tensor_tuple):
                    t, y, adj_y, *adj_params = tensor_tuple
                    # (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
                    return max(t.abs(), state_norm(y), state_norm(adj_y))
                adjoint_options['norm'] = adjoint_seminorm
            else:
                # And they're using their own custom norm.
                if shapes is None:
                    # The state on the forward pass was a tensor, not a tuple. We don't need to do anything, they're
                    # already going to get given the full adjoint state as (t, y, adj_y, adj_params)
                    pass  # this branch included for clarity
                else:
                    # This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery
                    # doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint
                    # norm about that ourselves.
                    def _adjoint_norm(tensor_tuple):
                        t, y, adj_y, *adj_params = tensor_tuple
                        y = _flat_to_shape(y, (), shapes)
                        adj_y = _flat_to_shape(adj_y, (), shapes)
                        return adjoint_norm((t, *y, *adj_y, *adj_params))
                    adjoint_options['norm'] = _adjoint_norm
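

# Usage sketch (illustrative; `ODEFunc`, the shapes and the tolerances below are
# assumptions made for the example, not part of this module). The typical call
# pattern is an nn.Module vector field whose parameters are discovered via
# find_parameters(func) and differentiated through the adjoint pass:
#
#     import torch
#     import torch.nn as nn
#     from torchdiffeq import odeint_adjoint
#
#     class ODEFunc(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.net = nn.Linear(2, 2)
#
#         def forward(self, t, y):
#             return self.net(y)
#
#     func = ODEFunc()
#     y0 = torch.randn(8, 2)
#     t = torch.linspace(0., 1., 10)
#
#     # Gradients flow through the adjoint ODE solved in
#     # OdeintAdjointMethod.backward rather than through the solver's
#     # internal forward steps.
#     ys = odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-8)
#     ys.sum().backward()
#
#     # `adjoint_options=dict(norm="seminorm")` selects the seminorm set up by
#     # handle_adjoint_norm_ above, which ignores the parameter component of
#     # the adjoint state when the adaptive solver controls its step size.
#     ys = odeint_adjoint(func, y0, t, adjoint_options=dict(norm="seminorm"))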