Commit ffaa1ce

[RFC] Separate CPU offload activation to its own wrapper
Pull Request resolved: #85459

Passing `offload_to_cpu=True` to `checkpoint_wrapper` is confusing, because it causes the activation checkpointing arguments to be ignored in favor of CPU offloading. This isn't ideal from an API design perspective, so this proposes making `offload_wrapper` its own concept. Offload to CPU and activation checkpointing can then be composed, for example:

```
apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
model = offload_wrapper(model)
```

Will polish / add tests if this proposal sounds good.

ghstack-source-id: f5c0100
Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
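For context, a minimal end-to-end sketch of the proposed composition (illustrative only: a toy `nn.Linear` stack stands in for real `TransformerLayer` modules, and the device/input handling is an assumption, not part of this change):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-in for a real transformer stack.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)).to(device)

# Recompute each Linear's activations on backward (in place of TransformerLayer)...
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda mod: isinstance(mod, nn.Linear),
)
# ...and additionally offload the root module's saved activations to CPU.
model = offload_wrapper(model)

# Reentrant checkpointing (the default) needs at least one input that requires grad.
inp = torch.randn(2, 10, device=device, requires_grad=True)
model(inp).sum().backward()
```

Note that `apply_activation_checkpointing` wraps the matching submodules in place, while `offload_wrapper` returns a new wrapper around the root module.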
Parent commit: 3b1ec75

3 files changed: +133 −92 lines

test/distributed/fsdp/test_checkpoint_wrapper.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -7,6 +7,7 @@
 import torch.nn as nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
+    offload_wrapper,
     apply_activation_checkpointing,
     CheckpointWrapper,
     CheckpointImpl
@@ -21,6 +22,9 @@

 import unittest

+_SAVED_PREFIX = '_saved_'
+GRAD_FN_NEXT_FUNCTIONS = 'next_functions'
+
 class CheckpointWrapperTest(TestCase):
     def setUp(self):
         super().setUp()
@@ -72,7 +76,7 @@ def forward(self, a, b, c=None, d=None, **kwargs):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
-            partial(checkpoint_wrapper, offload_to_cpu=True),
+            offload_wrapper,
         ]:
             with self.subTest(wrapper=wrapper):
                 model = wrapper(MyModel())
@@ -211,6 +215,7 @@ def check_fn(l):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
+            offload_wrapper,
         ]:
             model = MyModel()
             if n_linear is None:
@@ -276,7 +281,7 @@ def testing_cpu_offload_unpack_hook(packed):
         orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
         torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init

-        model = checkpoint_wrapper(model, offload_to_cpu=True)
+        model = offload_wrapper(model)

         inp = torch.randn(3, 10, device='cuda')
         loss = model(inp).sum()
@@ -286,7 +291,7 @@ def testing_cpu_offload_unpack_hook(packed):

         def dfs(grad_fn):
             for e in dir(grad_fn):
-                if not e.startswith('_saved_'):
+                if not e.startswith(_SAVED_PREFIX):
                     continue

                 saved = getattr(grad_fn, e)
@@ -295,7 +300,7 @@ def dfs(grad_fn):
                     nonlocal offload_verified
                     offload_verified = True

-            if hasattr(grad_fn, 'next_functions'):
+            if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS):
                 for next_grad_fn, _ in grad_fn.next_functions:
                     dfs(next_grad_fn)

```

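The CPU-offload test above patches the saved-tensor unpack hook and then walks the autograd graph to confirm offloading. Purely as an illustration of that traversal (not code from this PR), a standalone sketch of following `grad_fn.next_functions` and the `_saved_` attributes:

```python
import torch
import torch.nn as nn

# Illustrative helper (not from the diff): walk the autograd graph the same way
# the test's dfs() does -- via grad_fn.next_functions -- and report which
# '_saved_' attributes each node exposes.
def list_saved_attrs(grad_fn, seen=None):
    seen = set() if seen is None else seen
    if grad_fn is None or grad_fn in seen:
        return
    seen.add(grad_fn)
    for attr in dir(grad_fn):
        if attr.startswith('_saved_'):
            print(type(grad_fn).__name__, attr)
    for next_fn, _ in getattr(grad_fn, 'next_functions', ()):
        list_saved_attrs(next_fn, seen)

loss = nn.Linear(10, 10)(torch.randn(3, 10)).sum()
list_saved_attrs(loss.grad_fn)
```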
test/distributed/fsdp/test_fsdp_checkpoint.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -12,6 +12,7 @@
 )
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
+    offload_wrapper,
 )
 from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
@@ -65,9 +66,11 @@ def __init__(
         l3 = nn.Linear(3, 3).cuda()

         if checkpoint_layer:
-            ckpt_wrapper = partial(
-                checkpoint_wrapper, offload_to_cpu=offload_activations
-            )
+            if offload_activations:
+                ckpt_wrapper = offload_wrapper
+            else:
+                ckpt_wrapper = checkpoint_wrapper
+            pass

             l1 = ckpt_wrapper(l1)
             l2 = ckpt_wrapper(l2)
@@ -110,11 +113,13 @@ def _verify_parity(self, losses, outputs, models):
     @parametrize("offload_activations", [True, False])
     def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
         # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
+        if offload_activations:
+            checkpoint_wrapper = offload_wrapper
+
         ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
             TestFSDPCheckpoint.SequentialModule(
                 wrap_fsdp=True, cpu_offload=cpu_offload
             ),
-            offload_to_cpu=offload_activations,
         )
         # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
         inner_ckpt = TestFSDPCheckpoint.SequentialModule(
@@ -166,13 +171,15 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
         # Runs FSDP with no checkpointing
         fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
         # Runs checkpoint-wrapped FSDP
+        if offload_activations:
+            checkpoint_wrapper = offload_wrapper
+
         checkpointed_fsdp = checkpoint_wrapper(
             FSDP(deepcopy(seq), cpu_offload=cpu_offload),
-            offload_to_cpu=offload_activations,
         )
         # Runs FSDP-wrapped checkpointed module
         fsdp_wrapped_checkpoint = FSDP(
-            checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
+            checkpoint_wrapper(deepcopy(seq)),
             cpu_offload=cpu_offload,
         )
         # Runs FSDP with manual calls to checkpoint.
```

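The FSDP tests above pick between the two wrappers with the `offload_activations` flag. A hypothetical helper capturing that pattern (the `wrap_layer` name and toy layers are illustrative, not part of the PR):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
)

# Hypothetical helper mirroring the test pattern: both wrappers accept just the
# module, so a boolean flag can select between recomputation and CPU offload.
def wrap_layer(layer: nn.Module, offload_activations: bool) -> nn.Module:
    wrapper_fn = offload_wrapper if offload_activations else checkpoint_wrapper
    return wrapper_fn(layer)

l1 = wrap_layer(nn.Linear(3, 3), offload_activations=True)
l2 = wrap_layer(nn.Linear(3, 3), offload_activations=False)
```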
torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py

Lines changed: 111 additions & 82 deletions
```diff
@@ -6,7 +6,7 @@
 import torch.nn as nn
 from torch.autograd.graph import save_on_cpu
 from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
-from torch.utils.checkpoint import checkpoint
+from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint

 _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"

@@ -15,42 +15,14 @@ class CheckpointImpl(Enum):
     NO_REENTRANT = auto()


-class CheckpointWrapper(torch.nn.Module):
+class ActivationWrapper(torch.nn.Module):
     """
-    An nn.Module that wraps another nn.Module with checkpointing. Note that this
-    module is not meant to be used directly, but instead it is to be used
-    through the ``checkpoint_wrapper`` function.
+    Base class for Activation Checkpoint and Activation Offload.
+    Not meant to be instantiated directly.
     """
-    def __init__(
-        self,
-        mod: torch.nn.Module,
-        checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
-        offload_to_cpu: bool = False,
-        checkpoint_fn=None,
-        *checkpoint_fn_args,
-        **checkpoint_fn_kwargs,
-    ):
+    def __init__(self, mod):
         super().__init__()
         self._checkpoint_wrapped_module = mod
-        self.checkpoint_impl = checkpoint_impl
-        self.offload_to_cpu = offload_to_cpu
-        if self.offload_to_cpu:
-            self.checkpoint_fn = None
-        else:
-            if checkpoint_fn is None:
-                # use torch.utils.checkpoint
-                self.checkpoint_fn = partial(
-                    checkpoint,
-                    use_reentrant=(
-                        self.checkpoint_impl == CheckpointImpl.REENTRANT
-                    ),
-                )
-            else:
-                self.checkpoint_fn = partial(
-                    checkpoint_fn,
-                    *checkpoint_fn_args,
-                    **checkpoint_fn_kwargs,
-                )
         # state_dict post hook to remove prefix to allow loading into a
         # non-checkpoint wrapped module.
         self._register_state_dict_hook(self._post_state_dict_hook)
@@ -60,6 +32,9 @@ def __init__(
             self._pre_load_state_dict_hook, with_module=True
         )

+    def forward(self, *args, **kwargs):
+        raise ValueError("Subclasses should implement forward().")
+
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
@@ -71,44 +46,6 @@ def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
         return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]

-    def forward(self, *args, **kwargs):
-        if self.offload_to_cpu:
-            with save_on_cpu(pin_memory=True):
-                return self._checkpoint_wrapped_module(*args, **kwargs)
-        else:
-            # Support keyword arguments for reentrant checkpoint. Note that this
-            # only works if user has specified self.checkpoint_impl and is not
-            # using their own custom checkpoint_fn.
-            if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
-                # Pack the args and kwargs
-                flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
-
-                # Function that only takes (packed) args, but can unpack them
-                # into the original args and kwargs for the checkpointed
-                # function, and runs that function.
-                def my_function(*inputs):
-                    # unpack back into args and kwargs
-                    unpacked_args, unpacked_kwargs = _unpack_kwargs(
-                        inputs, kwarg_keys
-                    )
-                    # run original module
-                    return self._checkpoint_wrapped_module(
-                        *unpacked_args, **unpacked_kwargs
-                    )
-
-                # Pass the function that only takes packed args into reentrant
-                # checkpoint API.
-                return self.checkpoint_fn(  # type: ignore[misc]
-                    my_function,
-                    *flat_args,
-                )
-            else:
-                return self.checkpoint_fn(  # type: ignore[misc]
-                    self._checkpoint_wrapped_module,
-                    *args,
-                    **kwargs
-                )
-
     def named_parameters(
         self,
         *args,
@@ -155,10 +92,107 @@ def _pre_load_state_dict_hook(
         _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.")


+class OffloadWrapper(ActivationWrapper):
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def forward(self, *args, **kwargs):
+        with save_on_cpu(pin_memory=True):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+
+
+class CheckpointWrapper(ActivationWrapper):
+    """
+    An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. Note that this
+    module is not meant to be used directly, but instead it is to be used
+    through the ``checkpoint_wrapper`` function.
+    """
+    def __init__(
+        self,
+        mod: torch.nn.Module,
+        checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
+        checkpoint_fn=None,
+        *checkpoint_fn_args,
+        **checkpoint_fn_kwargs,
+    ):
+        super().__init__(mod)
+        self.checkpoint_impl = checkpoint_impl
+        if checkpoint_fn is None:
+            # use torch.utils.checkpoint
+            self.checkpoint_fn = partial(
+                torch_utils_checkpoint,
+                use_reentrant=(
+                    self.checkpoint_impl == CheckpointImpl.REENTRANT
+                ),
+            )
+        else:
+            # Construct user-specified checkpoint function.
+            self.checkpoint_fn = partial(
+                checkpoint_fn,
+                *checkpoint_fn_args,
+                **checkpoint_fn_kwargs,
+            )
+
+    def forward(self, *args, **kwargs):
+        # Support keyword arguments for reentrant checkpoint. Note that this
+        # only works if user has specified self.checkpoint_impl and is not
+        # using their own custom checkpoint_fn.
+        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
+            # Pack the args and kwargs
+            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
+
+            # Function that only takes (packed) args, but can unpack them
+            # into the original args and kwargs for the checkpointed
+            # function, and runs that function.
+            def my_function(*inputs):
+                # unpack back into args and kwargs
+                unpacked_args, unpacked_kwargs = _unpack_kwargs(
+                    inputs, kwarg_keys
+                )
+                # run original module
+                return self._checkpoint_wrapped_module(
+                    *unpacked_args, **unpacked_kwargs
+                )
+
+            # Pass the function that only takes packed args into reentrant
+            # checkpoint API.
+            return self.checkpoint_fn(  # type: ignore[misc]
+                my_function,
+                *flat_args,
+            )
+        else:
+            return self.checkpoint_fn(  # type: ignore[misc]
+                self._checkpoint_wrapped_module,
+                *args,
+                **kwargs
+            )
+
+def offload_wrapper(
+    module: torch.nn.Module
+) -> torch.nn.Module:
+    """
+    A convenience wrapper for activation offloading to CPU. If the module is wrapped
+    with this function, all subsequent calls to the module will automatically
+    offload intermediate activations to the CPU. Wrappers with activation
+    offload can be composed with ones that do recomputation-based
+    checkpoint to trade off increased compute versus increased CPU
+    memory usage and additional H2D transfers.
+    Usage::
+        offloaded_module = offload_wrapper(module)
+        outputs = checkpointed_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+    return OffloadWrapper(module)
+
+
 def checkpoint_wrapper(
     module: torch.nn.Module,
     checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
-    offload_to_cpu: bool = False,
     checkpoint_fn=None,
     *checkpoint_fn_args,
     **checkpoint_fn_kwargs,
@@ -181,14 +215,6 @@ def checkpoint_wrapper(
            specified. Note that for implementations using reentrant checkpoint
            from ``torch.utils.checkpoint``, keyword arguments will only be
            supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`.
-        offload_to_cpu (Optional[bool]):
-            Whether to offload activations of this wrapped module to CPU. Note
-            that if this is specified, ``checkpoint_impl`` and ``checkpoint_fn``
-            arguments will be ignored in favor of the activations being
-            offloaded to CPU. Default is ``False``. Wrappers with activation
-            offload can be composed with ones that do recomputation-based
-            checkpoint to trade off increased compute versus increased CPU
-            memory usage and additional H2D transfers.
        checkpoint_fn (Optional[Callable]):
            Functional checkpoint implementation to use. If this is specified,
            it will be used over the default ``torch.utils.checkpoint.checkpoint``
@@ -202,7 +228,7 @@ def checkpoint_wrapper(
    """

    return CheckpointWrapper(
-        module, checkpoint_impl, offload_to_cpu, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
+        module, checkpoint_impl, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
    )


@@ -219,13 +245,16 @@ def apply_activation_checkpointing(
    their checkpoint-wrapped modules.
    Note::
        This function will not wrap the overall root module. If this is needed, please directly use
-        :class:`CheckpointWrapper`.
+        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
    Usage::
        model = nn.Sequential(
            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
        )
        check_fn = lambda l: isinstance(l, nn.Linear)
+        # Checkpoint activations
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
+        # Or Offload activations to CPU
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
    Args:
        model (nn.Module):
            The model whose submodules should be wrapped with activation checkpointing.
```

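As the new `OffloadWrapper.forward` shows, `offload_wrapper` is a thin module-level wrapper around `torch.autograd.graph.save_on_cpu`. A small sketch of the underlying context-manager form it delegates to (the toy module and device fallback are assumptions, not part of the PR):

```python
import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu

# What OffloadWrapper.forward delegates to: running the wrapped forward under
# save_on_cpu keeps autograd's saved activations in (pinned) host memory and
# copies them back to the original device lazily during backward.
module = nn.Linear(10, 10)
inp = torch.randn(3, 10)
if torch.cuda.is_available():
    module, inp = module.cuda(), inp.cuda()

with save_on_cpu(pin_memory=True):
    out = module(inp)
out.sum().backward()
```

Pinned host memory makes the host-to-device copies during backward cheaper, which is why the wrapper passes `pin_memory=True`.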