[BE] Enabled mypy in common_fsdp.py
#118755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 # Owner(s): ["oncall: distributed"]

 import contextlib
@@ -11,7 +9,17 @@
 from contextlib import nullcontext
 from copy import deepcopy
 from enum import auto, Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    no_type_check,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 from unittest import mock

 import torch
@@ -82,15 +90,7 @@ def run_backward(self, loss) -> None:

     @staticmethod
     @abstractmethod
-    def init(
-        group: dist.ProcessGroup,
-        fsdp_init_mode: FSDPInitMode,
-        *init_args: Any,
-        cuda_init_mode: CUDAInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
-        deterministic: bool = False,
-        **init_kwargs: Any,
-    ) -> nn.Module:
+    def init(*args: Any, **kwargs: Any) -> nn.Module:
         """Initializes an instance of this model."""
         ...
@@ -119,7 +119,9 @@ def _assert_module_states
     olist = [None for _ in range(world_size)]
     dist.all_gather_object(olist, named_module_states, group=process_group)
     rank0_states = olist[0]
+    assert rank0_states is not None  # mypy
     for state in olist[1:]:
+        assert state is not None  # mypy
         for (_, p1), (_, p2) in zip(rank0_states, state):
             assert_fn(p1, p2)
@@ -210,7 +212,7 @@ def allreduce(self, *args, **kwargs):
        dist_wait = mock.Mock()

        def get_future():
-           future = torch.futures.Future()
+           future: torch.futures.Future = torch.futures.Future()
            future.set_result(1)
            return future

Review comment: For some reason, mypy wanted this type annotation.
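A hedged guess at the "some reason" above (not part of the PR): `torch.futures.Future` is a generic class, and calling its constructor with no arguments gives mypy nothing to infer the type parameter from, which is the usual trigger for a "need type annotation" error. A minimal, self-contained sketch of the same situation with an illustrative `Box` class:

```python
# Illustrative sketch, not PR code: a generic constructed with no arguments
# leaves its type parameter uninferable, so mypy asks for an annotation.
from typing import Generic, TypeVar

T = TypeVar("T")

class Box(Generic[T]):
    def set_result(self, value: T) -> None:
        self.value = value

box: Box[int] = Box()  # without the annotation: error: Need type annotation for "box"
box.set_result(1)
```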
@@ -475,8 +477,9 @@ def init(
        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
        policy.
        """
-       super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
-       model = super_.init(
+       model = super(
+           AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
+       ).init(
            group=group,
            fsdp_init_mode=FSDPInitMode.NO_FSDP,
            cuda_init_mode=cuda_init_mode,
@@ -486,6 +489,7 @@ def init(
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+           fsdp_kwargs = fsdp_kwargs or {}
            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
@@ -660,7 +664,7 @@ def init(


 class NestedWrappedModuleWithDelay(ModuleWithDelay):
     @staticmethod
-    def init(
+    def init(  # type: ignore[override]
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
@@ -669,7 +673,7 @@ def init(
        delay_after_loss_ms: int = 0,
        delay_before_reduction_ms: int = 0,
    ):
-       return super(NestedWrappedModuleWithDelay, NestedWrappedModuleWithDelay).init(
+       return ModuleWithDelay.init(
            NestedWrappedModule,
            group=group,
            fsdp_init_mode=fsdp_init_mode,
@@ -771,8 +775,9 @@ def run_backward(self, loss):
        for p in self.parameters():
            if hasattr(p, "expert"):
                continue  # these params don't need grad reduction
-           p.grad.div_(self.world_size)
-           torch.distributed.all_reduce(p.grad, group=self.group)
+           if p.grad is not None:
+               p.grad.div_(self.world_size)
+               torch.distributed.all_reduce(p.grad, group=self.group)

    @staticmethod
    def init(
@@ -893,6 +898,7 @@ def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
    dist.reduce_scatter_tensor = orig_reduce_scatter


+@no_type_check
 @contextlib.contextmanager
 def patch_unshard(new_unshard: Callable):
    orig_unshard = FSDPParamGroup.unshard

Review comment: These patching methods complain about assigning to a method like `FSDPParamGroup.unshard`.
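A hedged sketch of the pattern the comment refers to (the class and method names below are illustrative, not the FSDP internals): assigning a replacement to a method attribute is what mypy flags ("Cannot assign to a method"), so the helpers are wrapped in `@no_type_check`.

```python
# Illustrative sketch, not the real FSDP helpers: temporarily patch a method
# for the duration of a context, restoring the original afterwards.
import contextlib
from typing import Callable, no_type_check

class Worker:
    def step(self) -> int:
        return 1

@no_type_check  # mypy otherwise flags "Cannot assign to a method"
@contextlib.contextmanager
def patch_step(new_step: Callable):
    orig_step = Worker.step
    Worker.step = new_step  # the assignment mypy complains about
    try:
        yield
    finally:
        Worker.step = orig_step

# Usage: inside the context, Worker().step() calls the replacement.
with patch_step(lambda self: 42):
    assert Worker().step() == 42
```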
@@ -903,6 +909,7 @@ def patch_unshard(new_unshard: Callable):
    FSDPParamGroup.unshard = orig_unshard


+@no_type_check
 @contextlib.contextmanager
 def patch_post_backward(new_post_backward: Callable):
    orig_post_backward = FSDPParamGroup.post_backward
@@ -913,6 +920,7 @@ def patch_post_backward(new_post_backward: Callable):
    FSDPParamGroup.post_backward = orig_post_backward


+@no_type_check
 @contextlib.contextmanager
 def patch_register_post_backward_hook_backward(new_backward: Callable):
    orig_backward = RegisterPostBackwardFunction.backward
@@ -927,8 +935,8 @@ def reduce_scatter_with_assert(
        cls,
        orig_reduce_scatter: Callable,
        assert_fn: Callable,  # `assert_fn(output: Tensor)`
-       *args: Tuple[Any, ...],
-       **kwargs: Dict[str, Any],
+       *args: Any,
+       **kwargs: Any,
    ):
        if len(args) > 0:
            output = args[0]
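The `*args`/`**kwargs` change above follows the typing rule that the annotation describes each element, not the collected tuple or dict. A minimal sketch (illustrative names, not PR code):

```python
# `*args: Any` means "each positional argument is Any"; annotating
# `*args: Tuple[Any, ...]` would instead claim every argument is itself a
# tuple (and similarly for `**kwargs: Dict[str, Any]`).
from typing import Any, Callable

def call_and_return(fn: Callable, *args: Any, **kwargs: Any) -> Any:
    # args is a tuple of Any values, kwargs is a dict mapping str -> Any
    return fn(*args, **kwargs)

call_and_return(print, "hello", "world", sep=" ")
```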
@@ -960,6 +968,7 @@ def check_1d_sharded_parity(
            clean_sharded_name = clean_sharded_name.replace(prefix, "")
            cls.assertEqual(replicated_name, clean_sharded_name)
            cls.assertIsInstance(sharded_param, DTensor)
+           assert isinstance(sharded_param, DTensor)  # mypy
            param_chunks = torch.chunk(replicated_param, world_size, dim=0)
            cls.assertEqual(sharded_param._local_tensor, param_chunks[rank])
            if not check_grads:
@@ -969,6 +978,8 @@ def check_1d_sharded_parity(
                continue
            cls.assertIsNotNone(sharded_param.grad)
            grad_chunks = torch.chunk(replicated_param.grad, world_size, dim=0)
+           cls.assertIsInstance(sharded_param.grad, DTensor)
+           assert isinstance(sharded_param.grad, DTensor)  # mypy
            cls.assertEqual(sharded_param.grad._local_tensor, grad_chunks[rank])
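Several of the changes above pair a unittest-style assertion with a bare `assert`: unittest's `assertIsInstance`/`assertIsNotNone` do not narrow types for mypy, while a plain `assert` does. A small sketch of the same idea (illustrative, not PR code):

```python
# Narrowing an Optional for mypy: the unittest assertion checks at runtime,
# but only the bare assert changes what mypy believes about the type.
from typing import Optional

import torch

def local_numel(t: Optional[torch.Tensor]) -> int:
    # self.assertIsNotNone(t) would not narrow the type for mypy,
    # but a bare assert does:
    assert t is not None  # mypy now treats t as torch.Tensor
    return t.numel()

print(local_numel(torch.ones(2, 3)))  # 6
```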
@@ -1150,6 +1161,7 @@ def _train_for_several_steps(
                self.assertEqual(loss.dtype, torch.float16)
                # FSDP loss is fp16, DDP AMP loss is fp32
            elif isinstance(model, FSDP):
+               assert mixed_precision is not None  # mypy
                self.assertEqual(loss.dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(loss.dtype, torch.float32)
@@ -1173,7 +1185,7 @@ def _train_for_several_steps(

        if isinstance(model, FSDP):
            model._assert_state(TrainingState.IDLE)
-       return loss.detach()
+       return loss.detach()  # type: ignore[possibly-undefined]

    def _test_fsdp_parity(
        self,
@@ -1316,6 +1328,7 @@ def _test_fsdp_parity(
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
+           cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
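The last two hunks handle mypy's possibly-undefined checks in two different ways: define the variable unconditionally before its conditional use (`cpu_device`), or keep the conditional assignment and suppress the check with `# type: ignore[possibly-undefined]` when control flow guarantees the binding. A hedged sketch of both (illustrative functions, not PR code):

```python
# Two ways to satisfy mypy's possibly-undefined error code.
import torch

def pick_device(offload_params: bool) -> torch.device:
    cpu_device = torch.device("cpu")  # defined before the conditional use
    if offload_params:
        return cpu_device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def last_loss(num_steps: int) -> torch.Tensor:
    for _ in range(num_steps):
        loss = torch.tensor(0.0)
    # assumes num_steps >= 1, so `loss` is always bound at this point
    return loss.detach()  # type: ignore[possibly-undefined]
```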
Review discussion:

`all_gather_object` fills in the `olist` destructively. Another approach could be to initialize `olist` with some dummy object of the expected type.

Would it be a good BE item (maybe for myself) to allow passing `[]`?

This seems reasonable to me. `all_gather_object` can destructively modify the `olist`, as it already does today, and append `world_size` many elements. This sounds like a good BE task!
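A hedged sketch of the pattern under discussion (the function and types below are illustrative, not PR code): `dist.all_gather_object` fills a pre-sized object list in place, so typing the entries as `Optional[...]` and asserting after the collective keeps mypy satisfied.

```python
# Illustrative sketch, assuming an initialized process group:
# all_gather_object fills `olist` in place, so the entries start as None
# and are asserted non-None afterwards for mypy.
from typing import Any, List, Optional

import torch.distributed as dist

def gather_from_all_ranks(obj: Any, process_group: dist.ProcessGroup) -> List[Any]:
    world_size = dist.get_world_size(process_group)
    olist: List[Optional[Any]] = [None for _ in range(world_size)]
    dist.all_gather_object(olist, obj, group=process_group)
    gathered: List[Any] = []
    for state in olist:
        assert state is not None  # mypy: filled in by all_gather_object
        gathered.append(state)
    return gathered
```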