-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
_runtime_utils.py
1638 lines (1483 loc) · 64.7 KB
/
_runtime_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
import functools
import logging
from enum import auto, Enum
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
_assert_in_training_states,
_FSDPState,
_get_module_fsdp_state,
_is_composable,
_log_post_backward_hook,
_no_dispatch_record_stream,
clean_tensor_name,
TrainingState,
)
from torch.distributed.fsdp._flat_param import (
FlatParameter,
FlatParamHandle,
HandleShardingStrategy,
HandleTrainingState,
RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.utils import (
_apply_to_tensors,
_cast_forward_inputs,
_p_assert,
_to_kwargs,
)
from torch.utils import _pytree as pytree
# Module-level logger; it is handed to `_log_post_backward_hook` and used for
# the debug message in `_pre_backward_hook`.
logger = logging.getLogger(__name__)
# Names of FSDP state attributes that must hold one single value across every
# FSDP state under a root: `_share_state_and_init_handle_attrs` collects the
# value of each name from all states and raises `ValueError` if it observes
# anything other than exactly one distinct value.
# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
    "_use_full_prec_in_eval",
)
class _PrefetchMode(Enum):
    """Tag passed to ``_prefetch_handle`` naming the pass that issues the prefetch."""

    # Passed from `_pre_backward_hook` and `_post_backward_reshard`
    BACKWARD = auto()
    # Passed from `_pre_forward_unshard`
    FORWARD = auto()
def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Collects the FSDP root states found under ``module`` together with the
    modules that own them.

    Returns a tuple containing:
    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
    ``module`` without any duplicates and following the ``module.modules()``
    traversal order (which is assumed to be depth-first).
    2. A corresponding list of the root modules owning the states in the first
    list.

    This is similar to :func:`_get_fsdp_states_with_modules` except that we
    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
    the FSDP root in case lazy initialization has not yet happened.
    """
    root_states: List[_FSDPState] = []
    root_modules: List[nn.Module] = []
    # Only states already accepted as roots are remembered here
    seen_root_states: Set[_FSDPState] = set()
    # NOTE: This function assumes that `module.modules()` proceeds top-down.
    for submodule in module.modules():
        candidate = _get_module_fsdp_state(submodule)
        if candidate is None or candidate in seen_root_states:
            continue
        if not _is_fsdp_root(candidate, submodule):
            continue
        seen_root_states.add(candidate)
        root_states.append(candidate)
        root_modules.append(submodule)
    return root_states, root_modules
def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """
    Returns only the root states from
    :func:`_get_fsdp_root_states_with_modules`, discarding the modules.
    """
    return _get_fsdp_root_states_with_modules(module)[0]
def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Reports whether ``state`` is the state of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same. For
    the non-wrapper code path, ``state`` should be ``module`` 's state.
    """
    # Root-ness is only decided during lazy initialization, so force it first
    _lazy_init(state, module)
    is_root = state._is_root
    assert is_root is not None  # mypy
    return is_root
@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.

    Args:
        state (_FSDPState): State to initialize; treated as the root if its
            ``_is_root`` is still ``None``.
        root_module (nn.Module): Module whose tree is traversed to find the
            FSDP states and handles to initialize.

    Returns:
        _FSDPState: ``state`` itself, both when initialization runs and when
        it was already done (previously the already-initialized path returned
        ``None``, contradicting the annotated return type).

    Raises:
        RuntimeError: If the state's device handle reports that the device is
            not available.
    """
    if state._is_root is not None:
        return state  # no-op: already lazily initialized
    if not state._device_handle.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
    # `_init_streams` reads `_all_fsdp_states`, so it must come after the
    # assignment above
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state
def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Validates the device of every ``FlatParameter`` in ``module`` 's tree for
    *lazy initialization*.

    A handle that offloads parameters must have its flat parameter on CPU; any
    other handle must have it on ``state.compute_device``.

    Raises:
        RuntimeError: If a flat parameter is on an unexpected device.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        offloads_params = handle._offload_params
        param_device = handle.flat_param.device
        if offloads_params:
            if param_device != cpu_device:
                raise RuntimeError(
                    "An FSDP-managed module with parameter CPU offloading enabled "
                    f"has parameters on {param_device}. Make sure to "
                    "not move the module from CPU when offloading parameters."
                )
        elif param_device != state.compute_device:
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{param_device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Propagates the root's shared data structures to every FSDP state in
    ``root_module`` 's module tree and initializes the handles' flat-parameter
    attributes. Both are done here so that the states are looped over once.

    Raises:
        RuntimeError: If a ``FlatParameter`` itself carries
            ``_in_backward_optimizers``.
        ValueError: If any name in ``HOMOGENEOUS_ATTR_NAMES`` does not have
            exactly one distinct value across the FSDP states.
    """
    root_handle = root_state._handle
    if root_handle:
        root_handle.init_flat_param_attributes()
    observed_values: Dict[str, Set[Any]] = {
        name: set() for name in HOMOGENEOUS_ATTR_NAMES
    }
    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
    # Update _has_optim_in_backward for each handle.
    for managed_handle in root_state._all_handles:
        flat_param = managed_handle.flat_param
        if hasattr(flat_param, "_in_backward_optimizers"):
            raise RuntimeError(
                "FSDP optimizer in backward only supported with use_orig_params=True!"
            )
        managed_handle._has_optim_in_backward = (
            flat_param._params is not None
            and any(
                hasattr(orig_param, "_in_backward_optimizers")
                for orig_param in flat_param._params
            )
        )
        if managed_handle._has_optim_in_backward:
            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
    # Attributes every non-root state takes over from the root, in this order
    shared_attr_names = (
        "_unshard_stream",
        "_post_backward_stream",
        "_pre_unshard_stream",
        "_all_reduce_stream",
        "_default_stream",
        "_exec_order_data",
        "_free_event_queue",
    )
    for fsdp_state in root_state._all_fsdp_states:
        for name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, name),
                f"FSDP state missing attribute {name}",
            )
            observed_values[name].add(getattr(fsdp_state, name))
        if fsdp_state is root_state:
            continue
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        _p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
        for shared_attr in shared_attr_names:
            setattr(fsdp_state, shared_attr, getattr(root_state, shared_attr))
        if fsdp_state._fsdp_extension is not None:
            fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
        child_handle = fsdp_state._handle
        if child_handle:
            child_handle.init_flat_param_attributes()
    for name, values in observed_values.items():
        if len(values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {name} but got {values}"
            )
@no_type_check
def _init_streams(
    state: _FSDPState,
) -> None:
    """
    Creates the streams used for overlapping communication, computation, and
    data transfers and stores them on the root ``state``. The streams should
    be shared across FSDP instances.
    """
    assert state._is_root
    device_handle = state._device_handle
    assert device_handle.is_available()
    uses_hybrid_sharding = any(
        each_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
        for each_state in state._all_fsdp_states
    )
    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
    # preserve the default priority of 0 otherwise
    stream_priority = 0
    if state.limit_all_gathers and uses_hybrid_sharding:
        stream_priority = -1
    # Default stream for computation
    default_stream = device_handle.current_stream()
    state._default_stream = default_stream
    extension = state._fsdp_extension
    if extension is not None:
        # set the compute stream to the FSDP extension
        extension.compute_stream = default_stream
    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves
    state._unshard_stream = device_handle.Stream(priority=stream_priority)
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation
    state._post_backward_stream = device_handle.Stream(priority=stream_priority)
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast)
    state._pre_unshard_stream = device_handle.Stream(priority=stream_priority)
    # Stream to run HSDP's all-reduce as async (if using HSDP)
    if uses_hybrid_sharding:
        state._all_reduce_stream = device_handle.Stream()
    else:
        state._all_reduce_stream = default_stream
@no_type_check
def _unshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
) -> None:
    """
    Unshards ``handle``. If the handle is in :meth:`summon_full_params` and is
    using mixed precision, then it is forced to full precision.

    Args:
        state (_FSDPState): State supplying the device handle, the
            ``limit_all_gathers`` flag, and the free-event queue.
        handle (FlatParamHandle): Handle to unshard; a falsy handle makes this
            call a no-op.
        unshard_stream (torch.Stream): Stream under which ``handle.unshard()``
            and ``handle.post_unshard()`` are issued.
        pre_unshard_stream (torch.Stream): Stream under which
            ``handle.pre_unshard()`` is issued.

    Postcondition: handle's ``FlatParameter`` 's data is the padded
    unsharded flat parameter on the compute device.
    """
    if not handle:
        return
    with state._device_handle.stream(pre_unshard_stream):
        ran_pre_unshard = handle.pre_unshard()
    # Only make the unshard stream wait if pre-unshard reported doing work
    if ran_pre_unshard:
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        # Rate limiter: block on a previously recorded free event (if the
        # queue hands one back) before issuing the next unshard
        event = state._free_event_queue.dequeue_if_needed()
        if event:
            with torch.profiler.record_function(
                "FullyShardedDataParallel.rate_limiter"
            ):
                event.synchronize()
    with state._device_handle.stream(unshard_stream):
        handle.unshard()
        handle.post_unshard()
@no_type_check
def _reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    free_unsharded_flat_param: bool,
):
    """
    Reshards ``handle``.

    Args:
        state (_FSDPState): State supplying ``limit_all_gathers``, the device
            handle, and the free-event queue.
        handle (FlatParamHandle): Handle to reshard.
        free_unsharded_flat_param (bool): Whether to free the handle's padded
            unsharded flat parameter.
    """
    handle.reshard(free_unsharded_flat_param)
    should_record_free_event = (
        state.limit_all_gathers
        and free_unsharded_flat_param
        # We don't run an event queue for freeing under torch compile atm
        # But maybe we need to? TODO(voz): Look into this
        and not torch.distributed._functional_collectives.is_torchdynamo_compiling()
    )
    if should_record_free_event:
        recorded_event = state._device_handle.Event()
        recorded_event.record()
        state._free_event_queue.enqueue(recorded_event)
    handle.post_reshard()
    # Flat parameter freed or not, we always have to "unshard" the parameter
    # upon next access to get its shape correct.
    handle._prefetched = False
def _unshard_grads(
handle: Optional[FlatParamHandle],
) -> None:
if handle:
handle.unshard_grad()
def _reshard_grads(
handle: Optional[FlatParamHandle],
) -> None:
if handle:
handle.reshard_grad()
@no_type_check
def _pre_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    unshard_fn: Callable,
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    also converts forward ``args`` and ``kwargs`` to the given precision.

    Args:
        state (_FSDPState): FSDP state whose training state, execution-order
            data, and mixed precision config are used.
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, or a falsy value if there is none.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding. It is
            called as ``unshard_fn(state, handle)``.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.

    Returns:
        Tuple[Tuple[Any, ...], Dict[str, Any]]: The (possibly cast) ``args``
        and ``kwargs``. They are returned untouched when the handle is in the
        ``BACKWARD_PRE`` state.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            # For both checkpoint implementations, we do not need to re-cast
            # inputs here since they will be checkpointed in the low precision
            # either by AC or normally by autograd as long as the AC region is
            # nested within FSDP
            return args, kwargs
        state.training_state = TrainingState.FORWARD_BACKWARD
        state._exec_order_data.record_pre_forward(handle, module.training)
        if handle:
            handle._training_state = HandleTrainingState.FORWARD
        if unshard_fn is not None:
            unshard_fn(state, handle)
        # Register post-backward hooks to reshard the parameters and reduce-scatter
        # their gradients. They must be re-registered every forward pass in case
        # the `grad_fn` is mutated.
        _register_post_backward_hook(state, handle)
        # We have to reallocate the _cpu_grad if optimizer overlap
        # set the grad to None in the backward pass.
        if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
            handle.flat_param._cpu_grad = torch.zeros_like(
                handle.flat_param._local_shard, device=torch.device("cpu")
            ).pin_memory(device=state.compute_device)
        # Note that this consults `state._handle`, not the `handle` argument
        should_cast_forward_inputs = (
            state._handle and not state._handle._force_full_precision
        )
        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
            # Recursively convert args and kwargs to specified precision.
            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
        # Registered after the cast so that it receives the final inputs
        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
        return args, kwargs
@no_type_check
def _pre_forward_unshard(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """
    Unshards parameters in the pre-forward.

    Unshards ``handle`` unless it was already prefetched, makes the current
    stream wait for the unshard to finish (skipped while torchdynamo is
    compiling), and then issues a forward prefetch. A falsy ``handle`` makes
    this a no-op.
    """
    if not handle:
        return
    # If the handles have been prefetched, then there is no need to call
    # `_unshard()` again
    if not handle._prefetched:
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._needs_pre_forward_unshard = False
    # Don't wait during trace
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        current_stream = state._device_handle.current_stream()
        # Wait on the recorded unshard event when there is one (and consume
        # it); otherwise wait on the whole unshard stream
        if state._unshard_event is not None:
            current_stream.wait_event(state._unshard_event)
            state._unshard_event = None
        else:
            current_stream.wait_stream(state._unshard_stream)
    with torch.profiler.record_function(
        "FullyShardedDataParallel._pre_forward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)
@no_type_check
def _post_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    reshard_fn: Callable,
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        state (_FSDPState): FSDP state whose training state and
            execution-order data are updated.
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, or a falsy value if there is none.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding. It is called as ``reshard_fn(state, handle)``.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Returns:
        Any: The value returned by ``_register_pre_backward_hooks`` for
        ``output``, or ``output`` untouched when the handle is in the
        ``BACKWARD_PRE`` state.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
    parameter.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            return output
        state._exec_order_data.record_post_forward(handle)
        if reshard_fn is not None:
            reshard_fn(state, handle)
        # Register pre-backward hooks to unshard the flat parameters for the
        # gradient computation (if needed)
        output = _register_pre_backward_hooks(state, module, output, handle)
        # The forward is over: put both the state and the handle back to IDLE
        state.training_state = TrainingState.IDLE
        if handle:
            handle._training_state = HandleTrainingState.IDLE
        return output
@no_type_check
def _post_forward_reshard(
state: _FSDPState,
handle: FlatParamHandle,
) -> None:
"""Reshards parameters in the post-forward."""
if not handle:
return
# Do not free the root's parameters in the post-forward for `FULL_SHARD`
# with the intention that they are immediately used for backward
# computation (though this may not be true)
free_unsharded_flat_param = (
not state._is_root
and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
_reshard(state, handle, free_unsharded_flat_param)
@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> Tuple[Any, Any]:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns directly.

    Args:
        state (_FSDPState): FSDP state of ``module``.
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then this method only (for the
            composable path) casts the forward inputs.
        args: Module forward positional arguments.
        kwargs: Module forward keyword arguments.

    Returns:
        Tuple[Any, Any]: The ``args`` and ``kwargs`` to run the forward with.
        For the root they have been passed through ``_to_kwargs`` with
        ``state.compute_device`` and then through
        ``_root_cast_forward_input``.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
        _lazy_init(state, module)
        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
        if not state._is_root:
            # Always cast forward inputs in the root of this local FSDP unit for mixed
            # precision, as this is where mixed precision could be configed.
            # This is more useful for auto wrapping that is recommended in composable path.
            # For manual wrapping, cast forward inputs on each local FSDP unit root will
            # increase some overhead, so not turned on for model wrapper path right now where
            # manual wrapping is more broadly used.
            if _is_composable(state):
                return _root_cast_forward_input(state, module, args, kwargs)
            return args, kwargs
        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
        # are in full precision and if we should cast them back to lower precision, which happens when
        # exiting eval() mode.
        handle = state._handle
        if handle:
            should_cast_buffers_to_full_prec = handle._force_full_precision
        else:
            # A root without a handle always takes the full-precision branch
            should_cast_buffers_to_full_prec = True
        if should_cast_buffers_to_full_prec:
            _cast_buffers_to_dtype_and_device(
                buffers=dict(module.named_buffers()).values(),
                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
                device=state.compute_device,
            )
            # This flag is only set when we cast buffers to full precision, to avoid the
            # CPU overhead that can stem from retrieving all buffers and their types in the
            # following else branch.
            state._needs_buffer_dtype_restore_check = True
        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
            # Check if buffers are in full precision and we need to cast them
            # back down.
            (
                buffers,
                buffer_dtypes_for_computation,
            ) = _get_buffers_and_dtypes_for_computation(state, module)
            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
                if any(
                    buffer.dtype != buffer_dtype_for_computation
                    for buffer, buffer_dtype_for_computation in zip(
                        buffers, buffer_dtypes_for_computation
                    )
                ):
                    # Assume we have to cast everything if there is one mismatch
                    _cast_buffers_to_dtype_and_device(
                        buffers, buffer_dtypes_for_computation, state.compute_device
                    )
            # We don't have to check this again until we cast buffers to full precision again.
            state._needs_buffer_dtype_restore_check = False
        if state.forward_prefetch:
            # Reset the per-iteration prefetch bookkeeping on every handle
            handles = []
            for fsdp_state in state._all_fsdp_states:
                if fsdp_state._handle:
                    handles.append(fsdp_state._handle)
            for handle in handles:
                handle._needs_pre_forward_unshard = True
                handle._prefetched = False
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
            state._pre_unshard_stream,
        )
        _reset_flat_param_grad_info_if_needed(state._all_handles)
        # Prepares the forward inputs by moving them to ``compute_device``
        # TODO: Do not use the side stream for tensor copies for now; investigate
        # the perf with/without it.
        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
            args_tuple, kwargs_tuple = _to_kwargs(
                args, kwargs, state.compute_device, False
            )
        # Only the first entry of each returned sequence is used
        args = args_tuple[0]
        kwargs = kwargs_tuple[0]
        return _root_cast_forward_input(state, module, args, kwargs)
@no_type_check
def _root_cast_forward_input(
state: _FSDPState, module: torch.nn.Module, args, kwargs
) -> Tuple[Any, Any]:
if state._handle:
force_full_precision = not state._handle._force_full_precision
else:
force_full_precision = True
should_cast_forward_inputs = (
(module.training or not state._use_full_prec_in_eval) and force_full_precision
) and state.mixed_precision.cast_root_forward_inputs
if should_cast_forward_inputs:
input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
return args, kwargs
@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    handle: FlatParamHandle,
    grad,
    *unused: Any,
) -> Any:
    """
    Prepares ``handle`` 's ``FlatParameter`` for gradient computation.

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
        handle (FlatParamHandle): Handle to unshard and prepare; may be falsy,
            in which case only the root's post-backward callback is queued and
            the training state is updated.
        grad: Value handed to the hook; it is returned unchanged on every path.
        *unused (Any): Ignored.

    Returns:
        Any: ``grad``, unmodified.
    """
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if (
        handle
        and hasattr(handle, "_ran_pre_backward_hook")
        and handle._ran_pre_backward_hook
    ):
        logger.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
        return grad
    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _reset_flat_param_grad_info_if_needed(state._all_handles)
        elif handle:
            # The composable path additionally tolerates already being in
            # `FORWARD_BACKWARD` when this hook fires
            allowed_states = [TrainingState.IDLE]
            if _is_composable(state):
                allowed_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, allowed_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not handle:
            return grad
        handle._training_state = HandleTrainingState.BACKWARD_PRE
        if handle._needs_pre_backward_unshard:
            # If the handles have been prefetched, then there is no need to
            # call `_unshard()` again
            if not handle._prefetched:
                _unshard(
                    state,
                    handle,
                    state._unshard_stream,
                    state._pre_unshard_stream,
                )
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._device_handle.current_stream().wait_stream(state._unshard_stream)
        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        handle._needs_pre_backward_unshard = False
        with torch.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_prefetch"
        ):
            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
        handle.prepare_gradient_for_backward()
        # Marks the hook as done so that the guard at the top short-circuits
        # any further invocation for this handle
        handle._ran_pre_backward_hook = True
        return grad
@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    flat_param,
    *unused: Any,
):
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Args:
        state (_FSDPState): FSDP state that owns ``handle``.
        handle (FlatParamHandle): Handle whose flat parameter's gradient is
            reduced.
        flat_param: Ignored; it is immediately rebound to
            ``handle.flat_param``.
        *unused (Any): Ignored.

    Raises:
        RuntimeError: If the flat parameter's gradient itself requires grad.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
    unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).
    """
    _log_post_backward_hook(state, handle, logger)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # For multiple applications of reentrant AC across submodules sharing
        # the same `FlatParameter`, the post-backward hook may run multiple
        # times in one backward, in which case we permit the state to already
        # be in `BACKWARD_POST`.
        _p_assert(
            handle._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST
        # Nothing to reduce (and no reshard is issued) without a gradient
        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")
        _post_backward_reshard(state, handle)
        if not state._sync_gradients:
            # Gradient reduction is skipped entirely on this path
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return
        # Wait for all ops in the current stream (e.g. gradient computation) to
        # finish before reduce-scattering the gradient
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            state._post_backward_stream.wait_stream(
                state._device_handle.current_stream()
            )
        with state._device_handle.stream(state._post_backward_stream):
            # Keep a reference to the gradient as computed by autograd, since
            # `.grad.data` may be replaced by a cast copy just below
            autograd_computed_grad = flat_param.grad.data
            if (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
                # If we are forcing full precision but communicating grads
                # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
                and not handle._force_full_precision
            ):
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            if handle.uses_sharded_strategy:
                _reduce_grad(state, handle)
            else:
                _reduce_grad_no_shard(state, handle)
            # Since the unsharded gradient is produced in the computation
            # stream and consumed in the post-backward stream, inform the
            # caching allocator (before it goes out of scope)
            _no_dispatch_record_stream(
                autograd_computed_grad, state._post_backward_stream
            )
def _post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Post-backward hook that only reshards ``handle`` (via
    :func:`_post_backward_reshard`) without reducing any gradient. Extra
    positional arguments are ignored.
    """
    profiler_label = "FullyShardedDataParallel._post_backward_hook_reshard_only"
    with torch.profiler.record_function(profiler_label):
        # `_pre_backward_hook` may not get executed
        # if forward output does not require grad
        # overwrite IDLE state for post-backward prefetching
        state.training_state = TrainingState.FORWARD_BACKWARD
        handle._training_state = HandleTrainingState.BACKWARD_POST
        _post_backward_reshard(state, handle)
def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    """
    Reshards ``handle`` in the post-backward, freeing the unsharded flat
    parameter when :func:`_should_free_in_backward` says so, and then issues a
    backward prefetch. Extra positional arguments are ignored.
    """
    should_free = _should_free_in_backward(state, handle)
    _reshard(state, handle, should_free)
    # TODO: Post-backward prefetching does not support the multiple handles
    # per module case since the post-backward hook runs per handle, not per
    # group of handles.
    prefetch_label = "FullyShardedDataParallel._post_backward_prefetch"
    with torch.profiler.record_function(prefetch_label):
        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
@no_type_check
def _should_free_in_backward(
state: _FSDPState,
handle: FlatParamHandle,
) -> bool:
"""
Returns whether FSDP should free the unsharded flat parameter in the
post-backward or not.
"""
if not handle.uses_sharded_strategy:
return False
# If not syncing gradients, then we do not free for strategies that do not
# reshard after forward as a *heuristic* to tradeoff higher memory for
# higher throughput.
return (
state._sync_gradients
or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
@no_type_check
def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    Reduce the unsharded gradient of ``handle`` for a sharded strategy.

    The gradient is reduce-scattered (through the registered communication
    hook when one is set), additionally all-reduced over
    ``state._inter_node_pg`` for the hybrid sharded strategies, accumulated
    into any existing sharded gradient, and finally passed to the
    post-reduction callback.
    """
    flat_param = handle.flat_param
    is_hybrid_strategy = handle._sharding_strategy in (
        HandleShardingStrategy.HYBRID_SHARD,
        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    )

    # Take the gradient off `.grad` before reducing so that multiple
    # backwards are permitted: the reduction is issued asynchronously in a
    # separate stream, so a second backward's computation could otherwise run
    # ahead of the first backward's reduction and the wrong gradient would be
    # reduced.
    full_grad = flat_param.grad.data
    flat_param.grad = None
    reduce_scatter_input, grad_shard = _get_reduce_scatter_tensors(state, full_grad)

    comm_hook = state._comm_hook
    if comm_hook is not None:
        # NOTE: HSDP variants do not support communication hook.
        comm_hook(state._comm_hook_state, reduce_scatter_input, grad_shard)
    else:  # default path
        _div_if_needed(reduce_scatter_input, state._gradient_predivide_factor)
        if handle._use_fake_reduce:
            reduce_scatter_group = handle._fake_process_group
        else:
            reduce_scatter_group = state.process_group
        dist.reduce_scatter_tensor(
            grad_shard,
            reduce_scatter_input,
            group=reduce_scatter_group,
        )
        if is_hybrid_strategy:
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._all_reduce_stream.wait_stream(state._post_backward_stream)
            with state._device_handle.stream(state._all_reduce_stream):
                # The gradient shard is produced in the post-backward stream
                # but consumed in the all-reduce stream, so the caching
                # allocator has to be told about the second stream.
                _no_dispatch_record_stream(grad_shard, state._all_reduce_stream)
                dist.all_reduce(grad_shard, group=state._inter_node_pg)
                _div_if_needed(grad_shard, state._gradient_postdivide_factor)
                hybrid_grad_to_offload = _accumulate_sharded_grad(
                    state, handle, grad_shard
                )
                _post_reduce_grad_callback(state, handle, hybrid_grad_to_offload)
                return
        _div_if_needed(grad_shard, state._gradient_postdivide_factor)

    # Shared tail for the comm-hook path and the non-hybrid default path.
    _post_reduce_grad_callback(
        state, handle, _accumulate_sharded_grad(state, handle, grad_shard)
    )
@no_type_check
def _get_reduce_scatter_tensors(
    state: _FSDPState, unsharded_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the input and output tensors for the gradient reduce-scatter.

    Returns a pair ``(input, output)``: the input is ``unsharded_grad``
    right-padded so that its numel equals ``state.world_size`` times the
    numel of its first chunk (or ``unsharded_grad`` itself when no padding is
    needed), and the output is an uninitialized tensor shaped like that first
    chunk.
    """
    world_size = state.world_size
    first_chunk = unsharded_grad.chunk(world_size)[0]
    numel_to_pad = world_size * first_chunk.numel() - unsharded_grad.numel()

    if numel_to_pad > 0:
        padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad])
    else:
        # Already evenly divisible: reuse the tensor as is
        padded_unsharded_grad = unsharded_grad

    new_sharded_grad = torch.empty_like(first_chunk)  # padded
    return padded_unsharded_grad, new_sharded_grad
@no_type_check
def _accumulate_sharded_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    sharded_grad: torch.Tensor,
) -> torch.Tensor:
    """
    Fold the freshly reduce-scattered ``sharded_grad`` into the flat
    parameter's ``_saved_grad_shard``.

    If a saved shard already exists, ``sharded_grad`` is added to it in
    place; otherwise ``sharded_grad`` becomes the saved shard. The saved
    shard is returned as the gradient to offload (relevant if CPU offloading
    is enabled).
    """
    param = handle.flat_param
    _cast_grad_to_param_dtype(state, sharded_grad, param)

    # `_saved_grad_shard` is what supports gradient accumulation -- with
    # multiple backwards, the gradient reductions may happen in arbitrary
    # order
    if hasattr(param, "_saved_grad_shard"):
        _check_grad_to_accumulate(sharded_grad, param._saved_grad_shard)
        param._saved_grad_shard += sharded_grad
    else:
        param._saved_grad_shard = sharded_grad
    return param._saved_grad_shard
@no_type_check
def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    Reduce the gradient of ``handle`` for the no-shard case and run the
    post-reduction callback.

    The reduction operates directly on ``flat_param.grad`` (which implicitly
    covers any gradient accumulation), either through the registered
    communication hook or through an all-reduce over ``state.process_group``.
    """
    param = handle.flat_param
    if state._comm_hook is not None:
        state._comm_hook(state._comm_hook_state, param.grad)
    else:  # default path
        _div_if_needed(param.grad, state._gradient_predivide_factor)
        dist.all_reduce(param.grad, group=state.process_group)
        _div_if_needed(param.grad, state._gradient_postdivide_factor)

    # For `NO_SHARD`, low precision gradients are kept by simply skipping the
    # cast altogether
    if not handle._keep_low_precision_grads:
        _cast_grad_to_param_dtype(state, param.grad, param)

    _post_reduce_grad_callback(state, handle, param.grad.data)
@no_type_check
def _post_reduce_grad_callback(
    state: _FSDPState,
    handle: FlatParamHandle,
    # Extra input required by the callback logic
    grad_to_offload: torch.Tensor,
):
    """
    Run the work that must follow a finished gradient reduction.

    At present that is two steps, in this order: offloading the gradient to
    CPU (only if CPU offloading is enabled) and switching to sharded gradient
    views (only if ``use_orig_params=True``).
    """
    _offload_grad(state=state, handle=handle, grad_to_offload=grad_to_offload)
    _post_backward_use_sharded_grad_views(handle=handle)
@no_type_check
def _offload_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    grad_to_offload: torch.Tensor,
):
    """
    Copy ``grad_to_offload`` into the flat parameter's ``_cpu_grad`` buffer.

    Does nothing unless ``handle._offload_params`` is set.
    """
    if not handle._offload_params:
        return

    # The gradient is moved to CPU so that parameters and gradients live on
    # the same device, as the optimizer requires.
    # TODO: Investigate why `NO_SHARD` breaks correctness when using
    # `non_blocking=True` here.
    # TODO (rohan-varma): When CPU offload and optimizer overlap,
    # non_blocking=True won't work since the copy may have not finished before
    # the optimizer step executes on CPU. If we want to use non-blocking=True
    # here, we'll have to synchronize before using result on CPU.
    copy_async = handle.uses_sharded_strategy
    if copy_async:
        copy_async = not handle._has_optim_in_backward

    cpu_grad = handle.flat_param._cpu_grad
    # synchronized in the post-backward callback
    cpu_grad.copy_(grad_to_offload.detach(), non_blocking=copy_async)

    # The offloaded gradient may have been produced in the computation stream
    # while it is consumed here in the post-backward stream, so the caching
    # allocator has to be informed.
    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)
@no_type_check
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
if not handle._use_orig_params:
return
# Since the handle's `FlatParameter` completed its gradient computation, we
# should reset the gradient noneness mask
handle._reset_is_grad_none()
# Delay using sharded gradient views until after the reduce-scatter instead
# of immediately after resharding
handle._use_sharded_grad_views()
if handle._has_optim_in_backward:
handle.prepare_gradient_for_optim()
for orig_param in handle.flat_param._params:
# Check for `None` gradient to filter parameters not in the rank
if orig_param.grad is not None and hasattr(
orig_param, "_in_backward_optimizers"
):
# TODO (rohan-varma): For CPU offload, this unfortunately