Skip to content

Commit

Permalink
Update on "[dynamo][fsdp] Dont take unspecializedNNModuleVariable pat…
Browse files Browse the repository at this point in the history
…h for FSDP modules"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
  • Loading branch information
anijain2305 committed Jun 18, 2024
2 parents 45699a9 + ac1c7d0 commit 278e37e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
9 changes: 7 additions & 2 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@
GradSource,
LocalSource,
NNModuleSource,
NotNNModuleSource,
NumpyTensorSource,
ODictGetItemSource,
OptimizerSource,
ScriptObjectQualifiedNameSource,
ShapeEnvSource,
TupleIteratorGetItemSource,
TypeSource,
UnspecializedNNModuleSource,
WeakRefCallSource,
)
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
Expand Down Expand Up @@ -830,7 +830,12 @@ def get_guard_manager_from_source(self, source):
)
elif istype(
source,
(OptimizerSource, NNModuleSource, NotNNModuleSource, FSDPNNModuleSource),
(
OptimizerSource,
NNModuleSource,
UnspecializedNNModuleSource,
FSDPNNModuleSource,
),
):
assert base_guard_manager # to make mypy happy
out = base_guard_manager
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def name(self):


@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
class UnspecializedNNModuleSource(NNModuleSource):
def guard_source(self):
return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]

Expand Down
6 changes: 3 additions & 3 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
FSDPNNModuleSource,
GetItemSource,
NNModuleSource,
NotNNModuleSource,
UnspecializedNNModuleSource,
)
from ..utils import (
get_custom_getattr,
Expand Down Expand Up @@ -983,12 +983,12 @@ def __init__(self, value, **kwargs):

@staticmethod
def _wrap_source(source):
if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)):
if not isinstance(source, (FSDPNNModuleSource, UnspecializedNNModuleSource)):
if torch._dynamo.config.skip_fsdp_guards:
return FSDPNNModuleSource(source)
else:
# this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
return NotNNModuleSource(source)
return UnspecializedNNModuleSource(source)
else:
return source

Expand Down

0 comments on commit 278e37e

Please sign in to comment.