# torch/autograd/graph.py
import torch
import contextlib
from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set
from torch.utils.hooks import RemovableHandle
from torch.utils._python_dispatch import TorchDispatchMode
from collections import defaultdict
import weakref
__all__ = [
"saved_tensors_hooks",
"save_on_cpu",
"disable_saved_tensors_hooks",
"register_multi_grad_hook",
"allow_mutation_on_saved_tensors",
]
class saved_tensors_hooks():
"""Context-manager that sets a pair of pack / unpack hooks for saved tensors.
Use this context-manager to define how intermediary results of an operation
should be packed before saving, and unpacked on retrieval.
    In that context, the ``pack_hook`` function will be called every time an
operation saves a tensor for backward (this includes intermediary results
saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation). The output of
``pack_hook`` is then stored in the computation graph instead of the
original tensor.
The ``unpack_hook`` is called when the saved tensor needs to be accessed,
namely when executing :func:`torch.Tensor.backward()` or
:func:`torch.autograd.grad()`. It takes as argument the *packed* object
returned by ``pack_hook`` and should return a tensor which has the same
content as the original tensor (passed as input to the corresponding
``pack_hook``).
The hooks should have the following signatures:
pack_hook(tensor: Tensor) -> Any
unpack_hook(Any) -> Tensor
where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
of value, size, dtype and device.
Example::
>>> def pack_hook(x):
... print("Packing", x)
... return x
>>>
>>> def unpack_hook(x):
... print("Unpacking", x)
... return x
>>>
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
... y = a * b
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
    .. warning::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.
    .. warning::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
"""
def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
self.pack_hook = pack_hook
self.unpack_hook = unpack_hook
def __enter__(self):
torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
def __exit__(self, *args: Any):
torch._C._autograd._pop_saved_tensors_default_hooks()
class save_on_cpu(saved_tensors_hooks):
"""Context-manager under which tensors saved by the forward pass will be
stored on cpu, then retrieved for backward.
When performing operations within this context manager, intermediary
results saved in the graph during the forward pass will be moved to CPU,
then copied back to the original device when needed for the backward pass.
If the graph was already on CPU, no tensor copy is performed.
Use this context-manager to trade compute for GPU memory usage (e.g.
when your model doesn't fit in GPU memory during training).
Args:
pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
during packing and copied to GPU asynchronously during unpacking.
Defaults to ``False``.
Also see :ref:`cuda-memory-pinning`.
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda")
>>>
>>> def f(a, b, c):
... prod_1 = a * b # a and b are saved on GPU
... with torch.autograd.graph.save_on_cpu():
... prod_2 = prod_1 * c # prod_1 and c are saved on CPU
... y = prod_2 * a # prod_2 and a are saved on GPU
... return y
>>>
>>> y = f(a, b, c)
>>> del a, b, c # for illustration only
>>> # the content of a, b, and prod_2 are still alive on GPU
>>> # the content of prod_1 and c only live on CPU
>>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward
>>> # all intermediary tensors are released (deleted) after the call to backward
"""
def __init__(self, pin_memory=False):
def pack_to_cpu(tensor):
if not pin_memory:
return (tensor.device, tensor.cpu())
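            # If pin_memory is requested, allocate page-locked CPU memory (only when
            # CUDA is available and the tensor is dense) so the copy back to the GPU
            # in unpack_from_cpu can be asynchronous (non_blocking=True).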
packed = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
packed.copy_(tensor)
return (tensor.device, packed)
def unpack_from_cpu(packed):
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
super().__init__(pack_to_cpu, unpack_from_cpu)
# NOTE: [new checkpoint mechanism]
#
# Contents:
# - Definition
# - Mechanism design
# - Example
#
# Definition: Checkpointing
# =========================
#
# We define checkpoint as a context manager such that any variables that
# were saved by forward under this context AND remain saved at the point
# in time when we exit the context are cleared and marked as needed during
# recomputation.
#
# In other words, if saved tensors are manually cleared, e.g. by running
# backward, checkpoint will ignore those tensors because it only cares about
# the set of tensors saved at the point in time when the context exits.
#
# with checkpoint():
#     y = x.sin()                # saves x
#     z = y.exp()                # saves z
#     torch.autograd.grad(z, z)  # clears z, only x remains saved
#     # As we exit, clears x only
#
# This may also mean that we cannot simply halt execution early as soon as we've
# saved the right number of buffers.
#
# Special handling of input for the nested case
# ---------------------------------------------
#
# There is some special handling for the nested case: the inputs to nested
# checkpoints are treated as saved variables in the parent context.
#
# with checkpoint0():
#     with checkpoint1():  # saves `y` in check0
#         y = f(x)         # f's saved variables are cleared by check1
#     with checkpoint2():  # saves `z` in check0
#         z = g(y)         # g's saved variables are cleared by check2
# # exiting check0, clears `y` and `z`
# # whatever f and g save are hidden
#
# NB: We never need to recompute a function until we have finished running it
#
# TODO: Handling of free variables.
# The current stack was created to recompute values for some call to checkpoint.
# If that checkpointed function calls checkpoint one or more times (possibly
# in a nested way), the inputs to the top-most checkpoints are cleared, so if we
# detect that we are a direct child of the ambient frame, we save its inputs
#
# We can register as many hooks as we want, they all do the same thing
# Backward usually happens outside of the context of any checkpoint anyway
# so we'll at least need to call this once per recomputation stack
#
# We can reuse the checkpoint code because what we are doing is very similar
# to checkpointing, we want to manage the tensors saved and collect the ones
# that remain alive at the very end.
# Creating a new CheckpointStack creates an ambient frame which manages this
#
# Sketch of the mechanism
# =======================
# TODO: Why we need a stack of stacks
#
# Demonstrating the mechanism with an example
# -------------------------------------------
# TODO: reference cycle concerns
#
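# A rough usage sketch of the private `_checkpoint` helper defined below
# (illustrative only; this is an internal API and follows the definition above):
#
# >>> a = torch.randn(5, requires_grad=True)
# >>> def fn(x):
# ...     return x.sin().exp()
# >>> out = _checkpoint(fn, a)  # tensors saved under `fn` are cleared on exit
# >>> out.sum().backward()      # cleared tensors are recomputed by re-running `fn`
#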
class CheckpointFrame():
def __init__(self, parent, fn):
self.parent = weakref.ref(parent) if parent is not None else parent
self.fn = fn
self.exited = False
self.needed_counter = 0
        # Assume that entries are added/removed deterministically
# and that removal preserves order
self.saved_tensors: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self.child_inputs: List[Tuple[Any, ...]] = []
# Set this when we leave
self.num_needs_recompute = 0
self.idx_of_input_in_parent = None
self.num_child_inputs_need_recompute = 0
# During recomputation
self.recomputed: List[torch.Tensor] = []
self.recomputed_child_inputs: List[Tuple[Any, ...]] = []
class CheckpointStack():
def __init__(self, parent, target_frame=None):
self.parent: CheckpointStack = parent
self.target_frame: CheckpointFrame = target_frame
self.frames = [CheckpointFrame(parent=None, fn=None)]
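# Global stack of CheckpointStacks. A fresh CheckpointStack is pushed for a
# top-level call to _checkpoint and for every recomputation (i.e. whenever
# target_frame is not None); nested checkpoints reuse the stack at the top.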
checkpoint_stacks: List[CheckpointStack] = []
class _checkpoint_hook(saved_tensors_hooks):
def __init__(self):
def pack_hook(x):
# Snapshot the state of the current checkpoint stack
handle = _Handle()
frames = tuple(checkpoint_stacks[-1].frames)
frames[-1].saved_tensors[handle] = x
frames[-1].needed_counter += 1
return handle, frames, frames[-1].needed_counter - 1
def unpack_hook(saved):
handle, frames, idx = saved
# backward was called before leaving checkpoint context
if not frames[-1].exited:
assert handle in frames[-1].saved_tensors
return frames[-1].saved_tensors[handle]
# TODO: give a nice error for when we backwarded during the context
assert len(frames[-1].recomputed) <= frames[-1].num_needs_recompute
if len(frames[-1].recomputed) < frames[-1].num_needs_recompute:
# The first frame is always the ambient frame
for frame in frames[1:]:
if (frame.exited is False
or (len(frame.recomputed) == frame.num_needs_recompute
and len(frame.recomputed_child_inputs) == frame.num_child_inputs_need_recompute)):
continue
assert frame.parent is not None and frame.parent() is not None
parent = frame.parent()
inps = parent.recomputed_child_inputs if parent.exited else parent.child_inputs
args, kwargs = inps[frame.idx_of_input_in_parent]
with torch.autograd.enable_grad():
# Do recomputation in a fresh checkpoint stack
_checkpoint(frame.fn, *args, target_frame=frame, **kwargs)
ret = frames[-1].recomputed[idx]
return ret
super().__init__(pack_hook, unpack_hook)
def get_wrapped_fn(fn):
# Capture the current context, so we can replay it
def wrapped(*args, **kwargs):
return fn(*args, **kwargs)
return wrapped
def _checkpoint(fn, *args, target_frame: Optional[CheckpointFrame] = None, **kwargs):
needs_to_pop = False
if len(checkpoint_stacks) == 0 or target_frame is not None:
needs_to_pop = True
parent = None if len(checkpoint_stacks) == 0 else checkpoint_stacks[-1]
checkpoint_stacks.append(CheckpointStack(parent=parent, target_frame=target_frame))
curr_stack = checkpoint_stacks[-1]
if target_frame is None:
# Create a proper checkpoint frame and append it to the current stack
wrapped_fn = get_wrapped_fn(fn)
curr_frame = CheckpointFrame(parent=curr_stack.frames[-1], fn=wrapped_fn)
curr_stack.frames.append(curr_frame)
else:
# Don't need a checkpoint frame if we're starting a stack for recomputation
assert len(curr_stack.frames) == 1
curr_frame = curr_stack.frames[0]
if (curr_stack.target_frame is not None
and curr_frame.parent is not None
and curr_frame.parent() is curr_stack.frames[0]):
# Top-level checkpoints save their inputs to the target frame during recomputation
assert len(curr_stack.frames) == 2
curr_stack.target_frame.recomputed_child_inputs.append((args, kwargs))
with _checkpoint_hook():
# We can register this hook as many times as we want, it only reads global state
ret = fn(*args, **kwargs)
curr_frame.num_needs_recompute = len(curr_frame.saved_tensors)
if target_frame is None:
# Children register their inputs to the parent checkpoint to be cleared
# when the parent exits and restored when the parent is recomputed
inputs = (args, kwargs)
assert curr_frame.parent is not None
parent = curr_frame.parent()
assert parent is not None
curr_frame.idx_of_input_in_parent = len(parent.child_inputs)
parent.child_inputs.append(inputs)
parent.num_child_inputs_need_recompute += 1
else:
# Stack for recomputation is getting destroyed, save into the target frame
assert len(target_frame.recomputed) == 0
detached_saved = [t.detach() for t in curr_frame.saved_tensors.values()]
target_frame.recomputed.extend(detached_saved)
curr_frame.child_inputs.clear()
curr_frame.saved_tensors.clear()
curr_frame.exited = True
curr_stack.frames.pop()
if needs_to_pop:
checkpoint_stacks.pop()
return ret
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
"""Context-manager that disables the saved tensors default hooks feature.
    Useful if you are creating a feature that does not work with saved
tensors default hooks.
Args:
        error_message (str): When saved tensors default hooks are used while they
            are disabled, a RuntimeError with this error message is raised.
Example::
>>> message = "saved tensors default hooks are disabled"
>>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
... # Raises RuntimeError: saved tensors default hooks are disabled
... with torch.autograd.graph.save_on_cpu():
... pass
"""
try:
maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
torch._C._autograd._saved_tensors_hooks_disable(error_message)
yield
finally:
# See NOTE: [disabled_error_message invariant]
if maybe_prev_message is None:
torch._C._autograd._saved_tensors_hooks_enable()
else:
torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
def register_multi_grad_hook(tensors: Sequence[torch.Tensor], fn: Callable[[Sequence[Optional[torch.Tensor]]], None]):
r"""Registers a multi-grad backward hook.
The hook will be called after gradients with respect to every tensor in
:attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
is not part of the graph, or if a tensor is not needed to compute the gradients
for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
this tensor will be ignored and the hook will not wait for its gradient to be
computed.
After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
called with those gradients. ``None`` will be passed for tensors that did not
have their gradients computed.
The hook should not modify its arguments.
This function returns a handle with a method ``handle.remove()`` that removes the hook.
Example::
>>> import torch
>>>
>>> a = torch.rand(2, 3, requires_grad=True)
>>> b = torch.rand(2, 3, requires_grad=True)
>>> c = a * b
>>> d = a * b
>>>
>>> def fn(grads):
... print([g is not None for g in grads])
...
>>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
>>>
>>> c.sum().backward(retain_graph=True)
[True, True, True, False]
>>> c.sum().backward(inputs=(a,), retain_graph=True)
[True, False, True, False]
>>>
"""
count: Dict[int, int] = dict()
nb_calls = None
buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
def get_grad_fn(t):
        # For non-leaf tensors, return grad_fn; for leaf tensors, return the
        # AccumulateGrad node (obtained via a clone, since leaves have no grad_fn)
if t.requires_grad and t.grad_fn is None:
return t.clone().grad_fn.next_functions[0][0]
else:
return t.grad_fn
grad_fns = list(map(get_grad_fn, tensors))
def get_inner_hook(idx):
def inner_hook(grad: torch.Tensor):
nonlocal count, nb_calls, buffer
id = torch._C._current_graph_task_id()
assert id != -1, "expected this hook to be called inside a backward call"
count[id] = count.get(id, 0)
buffer[id] = buffer.get(id, [None] * len(tensors))
if count[id] == 0:
# On the first call, compute the actual nb_calls and buffer
nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns) # type: ignore[attr-defined]
buffer[id][idx] = grad
count[id] += 1
if count[id] == nb_calls:
fn(buffer[id])
del count[id]
del buffer[id]
return inner_hook
class Handle(RemovableHandle):
handles: Tuple[RemovableHandle, ...]
def __init__(self, handles: Tuple[RemovableHandle, ...]):
self.handles = handles
def remove(self):
for handle in self.handles:
handle.remove()
def __getstate__(self):
return self.handles
def __setstate__(self, state):
self.handles = state
handles: List[RemovableHandle] = []
for i, t in enumerate(tensors):
handles.append(t.register_hook(get_inner_hook(i)))
return Handle(tuple(handles))
# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
# - remember the python object id and the version of the tensor
# - remember aliasing information (data_ptr of base + version)
# - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
# - for each tensor aliased to it:
# - check using its object id and version to see if it has been saved
# - if it has been saved, clone it
# - delete the reference to the original
# 3. during backward
# - if the clone exists, the tensor must've been modified in-place
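#
# As an illustrative sketch, the three steps above map onto the public context
# manager defined at the end of this file (see allow_mutation_on_saved_tensors):
#
# >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
# ...     a = torch.rand(2, 3, requires_grad=True).clone()
# ...     out = (a ** 2).sum()   # step 1: `a` is saved for backward
# ...     a.sin_()               # step 2: `a` is cloned before the in-place op
# ...     out.backward()         # step 3: backward uses the saved clone of `a`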
_allow_mutation_on_saved_tensors_enabled = False
def _get_tid(t) -> Tuple[int, int, int]:
    # Identify a particular tensor object at a particular version
    return (id(t), t.data_ptr(), t._version)
def _get_sid(t) -> Tuple[int, int]:
    # Identify the underlying data at a particular version, ignoring which
    # Python object wraps it (used to track aliasing)
    return (t.data_ptr(), t._version)
class _Handle():
pass
class _swap_with_cloned(saved_tensors_hooks):
def __init__(self, ctx):
def pack_hook(t):
tid = _get_tid(t)
sid = _get_sid(t)
# Tensors saved for backward have an entry in _tid_to_weakhandle
handle: Optional[_Handle] = None
# Save aliasing information
ctx.sid_to_tid[sid].add(tid)
# NB: The same tensor (of the same version) can be saved multiple times
if tid not in ctx.tid_to_weakhandle:
handle = _Handle()
ctx.tid_to_weakhandle[tid] = handle
ctx.original[handle] = t
else:
# Store an additional strong reference to the handle
handle = ctx.tid_to_weakhandle[tid]
return handle
def unpack_hook(tup):
handle = tup
error_msg = (
"Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
"in which the graph was originally recorded.")
assert _allow_mutation_on_saved_tensors_enabled, error_msg
if handle in ctx.cloned:
res = ctx.cloned[handle]
else:
assert handle in ctx.original, error_msg
res = ctx.original[handle]
return res
super().__init__(pack_hook, unpack_hook)
class _CloneArgBeforeMutateMode(TorchDispatchMode):
def __init__(self, ctx):
self.ctx = ctx
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
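        # Before running the op, find arguments that the op mutates in place
        # (alias_info.is_write). If any tensor saved for backward aliases such an
        # argument, clone the saved tensor first so backward still sees the
        # pre-mutation value.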
for idx, arg in enumerate(func._schema.arguments):
if arg.alias_info is not None and arg.alias_info.is_write:
t = kwargs["out"] if arg.is_out else args[idx]
tid = _get_tid(t)
sid = _get_sid(t)
ctx = self.ctx
if sid in ctx.sid_to_tid:
for tid in ctx.sid_to_tid[sid]:
if tid not in ctx.tid_to_weakhandle:
# We know that if tid is in sid_to_tid, then it must also be in
# tid_to_weakhandle. However, it is possible for the tensor to be
# saved at one point, but cleared by backward before it is modified
# in-place. Consider the following example:
#
# >>> a = torch.randn(2, 3, requires_grad=True).clone()
# >>> out = (a**2).sum()
# >>> out.backward()
# >>> a.sin_()
continue
handle = ctx.tid_to_weakhandle[tid]
if handle in ctx.cloned:
# The same exact tensor has been cloned already
continue
ctx.cloned[handle] = ctx.original[handle].clone()
del ctx.original[handle]
rs = func(*args, **kwargs)
return rs
class _AllowMutationOnSavedContext():
def __init__(self):
self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set)
def clear(self):
self.cloned.clear()
self.original.clear()
self.tid_to_weakhandle.clear()
self.sid_to_tid.clear()
@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
"""Context manager under which mutating tensors saved for backward is allowed
Under this context manager, tensors saved for backward are cloned on mutation,
so the original version can still be used during backward. Normally, mutating a tensor
saved for backward will result in an error raised when it's used during backward.
To ensure the correct behavior, both the forward and backward should be run under
the same context manager.
    Returns:
An _AllowMutationOnSavedContext object storing the state managed by this
context manager. This object can be useful for debugging purposes. The state
managed by the context manager is automatically cleared upon exiting.
Example::
>>> import torch
>>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
... # forward
... a = torch.ones(2, 3, requires_grad=True)
... b = a.clone()
... out = (b**2).sum()
... b.sin_()
... # backward
... out.sum().backward()
...
tensor([[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
"""
global _allow_mutation_on_saved_tensors_enabled
ctx = _AllowMutationOnSavedContext()
with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
try:
if _allow_mutation_on_saved_tensors_enabled:
raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested")
_allow_mutation_on_saved_tensors_enabled = True
yield ctx
finally:
ctx.clear()
_allow_mutation_on_saved_tensors_enabled = False