Closed
Changes from all commits
39 commits
738dbf4  [BE] Enabled mypy in `common_fsdp.py` (Jan 31, 2024)
86df81c  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
2809517  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
9c11fa0  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
b559540  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
9aa732d  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
d82f962  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Jan 31, 2024)
43543ad  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
14260d8  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
993a0f2  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
f962e4c  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
f14905c  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
40b4742  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
c2397c9  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
fd73368  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 1, 2024)
1e9562a  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 2, 2024)
430cb25  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 2, 2024)
ad2b69f  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 2, 2024)
710f169  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 2, 2024)
1c606f1  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 2, 2024)
897a99a  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 3, 2024)
8e3e9c3  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 3, 2024)
646c567  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 5, 2024)
7576de1  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 5, 2024)
4886dff  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 5, 2024)
8ef1053  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
4238be0  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
0f4bcc7  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
ad6cd3d  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
528715f  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
15c6805  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 6, 2024)
13a09a2  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 7, 2024)
319a768  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 8, 2024)
ccebf28  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 8, 2024)
a7b93f3  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 8, 2024)
d9edb9f  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 9, 2024)
c40cff6  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 13, 2024)
d511503  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 13, 2024)
fe502d8  Update on "[BE] Enabled mypy in `common_fsdp.py`" (Feb 13, 2024)
57 changes: 35 additions & 22 deletions torch/testing/_internal/common_fsdp.py
@@ -1,5 +1,3 @@
# mypy: ignore-errors

# Owner(s): ["oncall: distributed"]

import contextlib
@@ -11,7 +9,17 @@
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
no_type_check,
Optional,
Tuple,
Type,
Union,
)
from unittest import mock

import torch
@@ -82,15 +90,7 @@ def run_backward(self, loss) -> None:

@staticmethod
@abstractmethod
def init(
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
*init_args: Any,
cuda_init_mode: CUDAInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
deterministic: bool = False,
**init_kwargs: Any,
) -> nn.Module:
def init(*args: Any, **kwargs: Any) -> nn.Module:
"""Initializes an instance of this model."""
...
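
For context on the relaxed signature: mypy enforces override compatibility (Liskov) on subclasses, but it accepts any override of a base method declared as `(*args: Any, **kwargs: Any)`, which is presumably why the abstract `init` was loosened here. A self-contained sketch with hypothetical class names, not part of the diff:

from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn


class _TestModel(ABC):  # hypothetical stand-in for the abstract base class
    @staticmethod
    @abstractmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        """Initializes an instance of this model."""
        ...


class _NestedModel(_TestModel):
    # The concrete subclass declares its own parameters; with the permissive
    # base signature, mypy does not flag the override as incompatible.
    @staticmethod
    def init(deterministic: bool = False, **kwargs: Any) -> nn.Module:
        return nn.Linear(2, 2)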

@@ -119,7 +119,9 @@ def _assert_module_states(
olist = [None for _ in range(world_size)]
dist.all_gather_object(olist, named_module_states, group=process_group)
Collaborator Author:

`all_gather_object` fills in `olist` destructively. Another approach could be to initialize `olist` with some dummy object of the expected type.

Contributor:

Would it be a good BE item (maybe for myself) to allow `[]`?

olist = []
dist.all_gather_object(olist

Collaborator Author:

This seems reasonable to me: `all_gather_object` could destructively modify `olist`, as it already does today, and append `world_size`-many elements.

This sounds like a good BE task!

rank0_states = olist[0]
assert rank0_states is not None # mypy
for state in olist[1:]:
assert state is not None # mypy
for (_, p1), (_, p2) in zip(rank0_states, state):
assert_fn(p1, p2)
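
A minimal standalone sketch of the gathering pattern discussed in the thread above, with a hypothetical helper name and assuming an initialized process group:

from typing import Any, List, Optional

import torch.distributed as dist


def gather_module_states(
    named_module_states: Any, process_group: dist.ProcessGroup
) -> List[Any]:
    # all_gather_object fills the pre-sized output list in place, so it must
    # already hold world_size entries; the None placeholders are why the typed
    # code above follows up with per-element `assert ... is not None`.
    world_size = dist.get_world_size(process_group)
    olist: List[Optional[Any]] = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    return [o for o in olist if o is not None]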

@@ -210,7 +212,7 @@ def allreduce(self, *args, **kwargs):
dist_wait = mock.Mock()

def get_future():
future = torch.futures.Future()
future: torch.futures.Future = torch.futures.Future()
Collaborator Author:

For some reason mypy wanted this type annotation.

future.set_result(1)
return future
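
The annotation most likely satisfies mypy's `var-annotated` check, since the type parameter of the generic `torch.futures.Future` cannot be inferred from a bare constructor call. A small sketch, outside the diff:

import torch

# Without the annotation, mypy has no way to infer the Future's type
# parameter and asks for an explicit variable annotation.
future: torch.futures.Future = torch.futures.Future()
future.set_result(1)
assert future.wait() == 1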

@@ -475,8 +477,9 @@ def init(
wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
policy.
"""
super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
model = super_.init(
model = super(
AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
).init(
group=group,
fsdp_init_mode=FSDPInitMode.NO_FSDP,
cuda_init_mode=cuda_init_mode,
@@ -486,6 +489,7 @@ def init(
if fsdp_init_mode == FSDPInitMode.NO_FSDP:
return model
elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
fsdp_kwargs = fsdp_kwargs or {}
fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
fsdp_model = fsdp_model.cuda()
@@ -660,7 +664,7 @@ def init(

class NestedWrappedModuleWithDelay(ModuleWithDelay):
@staticmethod
def init(
def init( # type: ignore[override]
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
@@ -669,7 +673,7 @@ def init(
delay_after_loss_ms: int = 0,
delay_before_reduction_ms: int = 0,
):
return super(NestedWrappedModuleWithDelay, NestedWrappedModuleWithDelay).init(
return ModuleWithDelay.init(
NestedWrappedModule,
group=group,
fsdp_init_mode=fsdp_init_mode,
@@ -771,8 +775,9 @@ def run_backward(self, loss):
for p in self.parameters():
if hasattr(p, "expert"):
continue # these params don't need grad reduction
p.grad.div_(self.world_size)
torch.distributed.all_reduce(p.grad, group=self.group)
if p.grad is not None:
p.grad.div_(self.world_size)
torch.distributed.all_reduce(p.grad, group=self.group)
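
The added guard reflects that `Parameter.grad` is typed as `Optional[Tensor]`: the `is not None` check both skips parameters without gradients and narrows the type for mypy. A single-process sketch of the same narrowing, with a hypothetical helper and no collectives:

import torch
import torch.nn as nn


def scale_grads(model: nn.Module, world_size: int) -> None:
    for p in model.parameters():
        # p.grad is Optional[Tensor]; checking for None narrows it to Tensor
        # before the in-place division.
        if p.grad is not None:
            p.grad.div_(world_size)


model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
scale_grads(model, world_size=2)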

@staticmethod
def init(
@@ -893,6 +898,7 @@ def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
dist.reduce_scatter_tensor = orig_reduce_scatter


@no_type_check
Collaborator Author:

mypy complains about these patching methods assigning to a method, e.g. `FSDPParamGroup.unshard = new_unshard`. Since we have two such assignments per patch context (one to set the new method and one to restore the old), I preferred to just ignore type checking here, as it is not too valuable.

@contextlib.contextmanager
def patch_unshard(new_unshard: Callable):
orig_unshard = FSDPParamGroup.unshard
@@ -903,6 +909,7 @@ def patch_unshard(new_unshard: Callable):
FSDPParamGroup.unshard = orig_unshard
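
A self-contained sketch of the patching pattern and of why `@no_type_check` helps, using a dummy class in place of `FSDPParamGroup`; without the decorator, mypy reports an error along the lines of "Cannot assign to a method" for both assignments:

import contextlib
from typing import Callable, Iterator, no_type_check


class _ParamGroup:  # dummy stand-in for FSDPParamGroup
    def unshard(self) -> None:
        ...


@no_type_check  # skip checking the method reassignments below
@contextlib.contextmanager
def patch_unshard(new_unshard: Callable) -> Iterator[None]:
    orig_unshard = _ParamGroup.unshard
    _ParamGroup.unshard = new_unshard
    try:
        yield
    finally:
        _ParamGroup.unshard = orig_unshard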


@no_type_check
@contextlib.contextmanager
def patch_post_backward(new_post_backward: Callable):
orig_post_backward = FSDPParamGroup.post_backward
@@ -913,6 +920,7 @@ def patch_post_backward(new_post_backward: Callable):
FSDPParamGroup.post_backward = orig_post_backward


@no_type_check
@contextlib.contextmanager
def patch_register_post_backward_hook_backward(new_backward: Callable):
orig_backward = RegisterPostBackwardFunction.backward
@@ -927,8 +935,8 @@ def reduce_scatter_with_assert(
cls,
orig_reduce_scatter: Callable,
assert_fn: Callable, # `assert_fn(output: Tensor)`
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
):
if len(args) > 0:
output = args[0]
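
The signature change follows the usual convention for variadic parameters: the annotation describes each element, not the packed tuple or dict. A minimal sketch with a hypothetical function:

from typing import Any, Callable


# `*args: Any` means "each positional argument is Any"; `*args: Tuple[Any, ...]`
# would instead claim that every positional argument is itself a tuple
# (and likewise for `**kwargs: Dict[str, Any]`).
def call_with_assert(assert_fn: Callable, *args: Any, **kwargs: Any) -> None:
    if len(args) > 0:
        assert_fn(args[0])
    elif "output" in kwargs:
        assert_fn(kwargs["output"])
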
@@ -960,6 +968,7 @@ def check_1d_sharded_parity(
clean_sharded_name = clean_sharded_name.replace(prefix, "")
cls.assertEqual(replicated_name, clean_sharded_name)
cls.assertIsInstance(sharded_param, DTensor)
assert isinstance(sharded_param, DTensor) # mypy
param_chunks = torch.chunk(replicated_param, world_size, dim=0)
cls.assertEqual(sharded_param._local_tensor, param_chunks[rank])
if not check_grads:
@@ -969,6 +978,8 @@
continue
cls.assertIsNotNone(sharded_param.grad)
grad_chunks = torch.chunk(replicated_param.grad, world_size, dim=0)
cls.assertIsInstance(sharded_param.grad, DTensor)
assert isinstance(sharded_param.grad, DTensor) # mypy
cls.assertEqual(sharded_param.grad._local_tensor, grad_chunks[rank])
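
The paired `assertIsInstance` / `assert isinstance` lines exist because unittest assertions are opaque to mypy; only the plain `assert` narrows the static type. A small standalone sketch:

import torch
from torch.distributed._tensor import DTensor


def local_chunk(param: torch.Tensor) -> torch.Tensor:
    # assertIsInstance(param, DTensor) would verify this at runtime but leave
    # the static type as Tensor; the plain assert narrows it to DTensor so the
    # DTensor-specific attribute below can be used.
    assert isinstance(param, DTensor)
    return param._local_tensor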


@@ -1150,6 +1161,7 @@ def _train_for_several_steps(
self.assertEqual(loss.dtype, torch.float16)
# FSDP loss is fp16, DDP AMP loss is fp32
elif isinstance(model, FSDP):
assert mixed_precision is not None # mypy
self.assertEqual(loss.dtype, mixed_precision.param_dtype)
else:
self.assertEqual(loss.dtype, torch.float32)
@@ -1173,7 +1185,7 @@ def _train_for_several_steps(

if isinstance(model, FSDP):
model._assert_state(TrainingState.IDLE)
return loss.detach()
return loss.detach() # type: ignore[possibly-undefined]
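
The `# type: ignore[possibly-undefined]` addresses mypy's optional possibly-undefined check: `loss` is only bound inside the training loop, so mypy cannot prove it is defined at the return. A reduced sketch of the same shape, with a hypothetical helper:

import torch
import torch.nn as nn


def train_for_several_steps(model: nn.Module, num_steps: int) -> torch.Tensor:
    for _ in range(num_steps):
        loss = model(torch.randn(2, 4)).sum()
        loss.backward()
    # If num_steps were 0, `loss` would never be assigned, which is exactly
    # what the possibly-undefined error code guards against.
    return loss.detach()  # type: ignore[possibly-undefined]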

def _test_fsdp_parity(
self,
@@ -1316,6 +1328,7 @@ def _test_fsdp_parity(
# Check parameter devices are CPU if offloading to CPU before calling
# `get_full_params()`, which will cast the parameters to FP32
if offload_params:
cpu_device = torch.device("cpu")
for param in fsdp_model.parameters():
self.assertEqual(param.device, cpu_device)
fsdp_loss = fsdp_loss.cuda()