 import torch.nn as nn
 from torch.autograd.graph import save_on_cpu
 from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
-from torch.utils.checkpoint import checkpoint
+from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint

 _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"

@@ -15,42 +15,14 @@ class CheckpointImpl(Enum):
     NO_REENTRANT = auto()


-class CheckpointWrapper(torch.nn.Module):
+class ActivationWrapper(torch.nn.Module):
     """
-    An nn.Module that wraps another nn.Module with checkpointing. Note that this
-    module is not meant to be used directly, but instead it is to be used
-    through the ``checkpoint_wrapper`` function.
+    Base class for Activation Checkpoint and Activation Offload.
+    Not meant to be instantiated directly.
     """
-    def __init__(
-        self,
-        mod: torch.nn.Module,
-        checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
-        offload_to_cpu: bool = False,
-        checkpoint_fn=None,
-        *checkpoint_fn_args,
-        **checkpoint_fn_kwargs,
-    ):
+    def __init__(self, mod):
         super().__init__()
         self._checkpoint_wrapped_module = mod
-        self.checkpoint_impl = checkpoint_impl
-        self.offload_to_cpu = offload_to_cpu
-        if self.offload_to_cpu:
-            self.checkpoint_fn = None
-        else:
-            if checkpoint_fn is None:
-                # use torch.utils.checkpoint
-                self.checkpoint_fn = partial(
-                    checkpoint,
-                    use_reentrant=(
-                        self.checkpoint_impl == CheckpointImpl.REENTRANT
-                    ),
-                )
-            else:
-                self.checkpoint_fn = partial(
-                    checkpoint_fn,
-                    *checkpoint_fn_args,
-                    **checkpoint_fn_kwargs,
-                )
         # state_dict post hook to remove prefix to allow loading into a
         # non-checkpoint wrapped module.
         self._register_state_dict_hook(self._post_state_dict_hook)
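The two hooks registered in ``__init__`` keep checkpoints interchangeable with the unwrapped module: the post-``state_dict`` hook strips the ``_checkpoint_wrapped_module.`` prefix from saved keys, and the pre-``load_state_dict`` hook adds it back. A minimal sketch of the effect, assuming the wrapper is imported from ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper``::

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

    lin = nn.Linear(10, 10)
    wrapped = checkpoint_wrapper(lin)
    # Keys match the plain module, so the checkpoint loads into either one.
    assert set(wrapped.state_dict().keys()) == set(lin.state_dict().keys())
    nn.Linear(10, 10).load_state_dict(wrapped.state_dict())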
@@ -60,6 +32,9 @@ def __init__(
             self._pre_load_state_dict_hook, with_module=True
         )

+    def forward(self, *args, **kwargs):
+        raise ValueError("Subclasses should implement forward().")
+
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
@@ -71,44 +46,6 @@ def __getitem__(self, key: int) -> Any:
7146 """Forward indexing calls in case the module is a nn.Sequential."""
7247 return self ._checkpoint_wrapped_module .__getitem__ (key ) # type: ignore[operator]
7348
74- def forward (self , * args , ** kwargs ):
75- if self .offload_to_cpu :
76- with save_on_cpu (pin_memory = True ):
77- return self ._checkpoint_wrapped_module (* args , ** kwargs )
78- else :
79- # Support keyword arguments for reentrant checkpoint. Note that this
80- # only works if user has specified self.checkpoint_impl and is not
81- # using their own custom checkpoint_fn.
82- if self .checkpoint_impl == CheckpointImpl .REENTRANT and kwargs != {}:
83- # Pack the args and kwargs
84- flat_args , kwarg_keys = _pack_kwargs (* args , ** kwargs )
85-
86- # Function that only takes (packed) args, but can unpack them
87- # into the original args and kwargs for the checkpointed
88- # function, and runs that function.
89- def my_function (* inputs ):
90- # unpack back into args and kwargs
91- unpacked_args , unpacked_kwargs = _unpack_kwargs (
92- inputs , kwarg_keys
93- )
94- # run original module
95- return self ._checkpoint_wrapped_module (
96- * unpacked_args , ** unpacked_kwargs
97- )
98-
99- # Pass the function that only takes packed args into reentrant
100- # checkpoint API.
101- return self .checkpoint_fn ( # type: ignore[misc]
102- my_function ,
103- * flat_args ,
104- )
105- else :
106- return self .checkpoint_fn ( # type: ignore[misc]
107- self ._checkpoint_wrapped_module ,
108- * args ,
109- ** kwargs
110- )
111-
11249 def named_parameters (
11350 self ,
11451 * args ,
@@ -155,10 +92,107 @@ def _pre_load_state_dict_hook(
         _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.")


+class OffloadWrapper(ActivationWrapper):
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def forward(self, *args, **kwargs):
+        with save_on_cpu(pin_memory=True):
+            return self._checkpoint_wrapped_module(*args, **kwargs)
+
+
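``OffloadWrapper.forward`` is simply the wrapped call run under ``torch.autograd.graph.save_on_cpu``, the saved-tensor-hooks context manager that parks tensors saved for backward in (optionally pinned) CPU memory and copies them back to the device during the backward pass. A standalone sketch of the same pattern, with an illustrative module and shapes::

    import torch
    import torch.nn as nn
    from torch.autograd.graph import save_on_cpu

    module = nn.Linear(1024, 1024).cuda()
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    with save_on_cpu(pin_memory=True):
        y = module(x)        # tensors saved for backward are kept on CPU
    y.sum().backward()       # and moved back to the GPU when needed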
+class CheckpointWrapper(ActivationWrapper):
+    """
+    An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. Note that this
+    module is not meant to be used directly, but instead it is to be used
+    through the ``checkpoint_wrapper`` function.
+    """
+    def __init__(
+        self,
+        mod: torch.nn.Module,
+        checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
+        checkpoint_fn=None,
+        *checkpoint_fn_args,
+        **checkpoint_fn_kwargs,
+    ):
+        super().__init__(mod)
+        self.checkpoint_impl = checkpoint_impl
+        if checkpoint_fn is None:
+            # use torch.utils.checkpoint
+            self.checkpoint_fn = partial(
+                torch_utils_checkpoint,
+                use_reentrant=(
+                    self.checkpoint_impl == CheckpointImpl.REENTRANT
+                ),
+            )
+        else:
+            # Construct user-specified checkpoint function.
+            self.checkpoint_fn = partial(
+                checkpoint_fn,
+                *checkpoint_fn_args,
+                **checkpoint_fn_kwargs,
+            )
+
+    def forward(self, *args, **kwargs):
+        # Support keyword arguments for reentrant checkpoint. Note that this
+        # only works if user has specified self.checkpoint_impl and is not
+        # using their own custom checkpoint_fn.
+        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
+            # Pack the args and kwargs
+            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
+
+            # Function that only takes (packed) args, but can unpack them
+            # into the original args and kwargs for the checkpointed
+            # function, and runs that function.
+            def my_function(*inputs):
+                # unpack back into args and kwargs
+                unpacked_args, unpacked_kwargs = _unpack_kwargs(
+                    inputs, kwarg_keys
+                )
+                # run original module
+                return self._checkpoint_wrapped_module(
+                    *unpacked_args, **unpacked_kwargs
+                )
+
+            # Pass the function that only takes packed args into reentrant
+            # checkpoint API.
+            return self.checkpoint_fn(  # type: ignore[misc]
+                my_function,
+                *flat_args,
+            )
+        else:
+            return self.checkpoint_fn(  # type: ignore[misc]
+                self._checkpoint_wrapped_module,
+                *args,
+                **kwargs
+            )
+
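The packing branch above exists because the reentrant variant of ``torch.utils.checkpoint`` does not pass keyword arguments through to the checkpointed function, so kwargs are flattened into a positional tuple plus their key names and rebuilt inside the checkpointed closure. A small sketch of the round trip through the helpers imported at the top of the file; the values in the comments are the expected results::

    from torch.distributed.utils import _pack_kwargs, _unpack_kwargs

    flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
    # flat_args == (1, 2, 3, 4); kwarg_keys == ("a", "b")
    args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
    # args == (1, 2); kwargs == {"a": 3, "b": 4}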
+def offload_wrapper(
+    module: torch.nn.Module
+) -> torch.nn.Module:
+    """
+    A convenience wrapper for activation offloading to CPU. If the module is wrapped
+    with this function, all subsequent calls to the module will automatically
+    offload intermediate activations to the CPU. Wrappers with activation
+    offload can be composed with ones that do recomputation-based
+    checkpoint to trade off increased compute versus increased CPU
+    memory usage and additional H2D transfers.
+    Usage::
+        offloaded_module = offload_wrapper(module)
+        outputs = offloaded_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+    return OffloadWrapper(module)
+
+
 def checkpoint_wrapper(
     module: torch.nn.Module,
     checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
-    offload_to_cpu: bool = False,
     checkpoint_fn=None,
     *checkpoint_fn_args,
     **checkpoint_fn_kwargs,
@@ -181,14 +215,6 @@ def checkpoint_wrapper(
             specified. Note that for implementations using reentrant checkpoint
             from ``torch.utils.checkpoint``, keyword arguments will only be
             supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT``.
-        offload_to_cpu (Optional[bool]):
-            Whether to offload activations of this wrapped module to CPU. Note
-            that if this is specified, ``checkpoint_impl`` and ``checkpoint_fn``
-            arguments will be ignored in favor of the activations being
-            offloaded to CPU. Default is ``False``. Wrappers with activation
-            offload can be composed with ones that do recomputation-based
-            checkpoint to trade off increased compute versus increased CPU
-            memory usage and additional H2D transfers.
         checkpoint_fn (Optional[Callable]):
             Functional checkpoint implementation to use. If this is specified,
             it will be used over the default ``torch.utils.checkpoint.checkpoint``
@@ -202,7 +228,7 @@ def checkpoint_wrapper(
     """

     return CheckpointWrapper(
-        module, checkpoint_impl, offload_to_cpu, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
+        module, checkpoint_impl, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
     )


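A short usage sketch for the wrapper returned here; the module and shapes are illustrative. With ``CheckpointImpl.NO_REENTRANT``, the default ``torch.utils.checkpoint.checkpoint`` is invoked with ``use_reentrant=False``::

    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        checkpoint_wrapper,
    )

    block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
    wrapped = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    out = wrapped(torch.randn(8, 64, requires_grad=True))
    out.sum().backward()  # activations inside `block` are recomputed here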
@@ -219,13 +245,16 @@ def apply_activation_checkpointing(
         their checkpoint-wrapped modules.
     Note::
         This function will not wrap the overall root module. If this is needed, please directly use
-        :class:`CheckpointWrapper`.
+        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
     Usage::
         model = nn.Sequential(
             nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
         )
         check_fn = lambda l: isinstance(l, nn.Linear)
+        # Checkpoint activations
         apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
+        # Or offload activations to CPU
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
     Args:
         model (nn.Module):
             The model whose submodules should be wrapped with activation checkpointing.
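Since the docstrings above note that offload wrappers compose with recomputation-based checkpoint wrappers, here is a hedged sketch of stacking them directly on one submodule; this is one reasonable ordering, shown only as an illustration::

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        offload_wrapper,
    )

    block = nn.Linear(10, 10)
    # Recompute inside the block during backward, and offload whatever is
    # still saved for backward to pinned CPU memory.
    wrapped = offload_wrapper(checkpoint_wrapper(block))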