98 changes: 97 additions & 1 deletion test/dynamo/test_guard_serialization.py
@@ -5,6 +5,7 @@
import pickle
import sys
import types
import unittest
from collections.abc import Iterator
from unittest.mock import patch

@@ -23,6 +24,8 @@
)
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext, tracing
from torch.overrides import TorchFunctionMode
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils import _pytree as pytree


@@ -43,6 +46,13 @@ def global_func(x):
return x + 1


class GlobalTorchFunctionMode(TorchFunctionMode):
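    """A no-op TorchFunctionMode defined at module scope; guard serialization
    rejects modes defined in a local scope (see test_torch_function_state below)."""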
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)


class SubclassWithMeta(torch.Tensor):
@staticmethod
def __new__(cls, a, extra, outer_size=None, outer_stride=None):
@@ -292,7 +302,7 @@ def transform(instructions: list, code_options: dict[str, object]):
self._frame_state.f_globals,
self._frame_state.f_builtins,
fn.__closure__ or (),
-                [],  # TODO tf_mode_stack,
+                torch.overrides._get_current_function_mode_stack(),  # pass the ambient TorchFunctionMode stack
code_options,
torch._dynamo.lookup_backend("eager"),
one_graph=False,
@@ -1100,6 +1110,92 @@ def fn(x):
self._test_check_fn(ref, loaded, {"m": m, "x": x}, False)
h.remove()

def test_grad_mode(self):
def fn(x):
return x + 1

x = torch.randn(3, 2)
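        # Serialized under enable_grad, the GRAD_MODE guard should fail when
        # grad is disabled and pass when it is enabled again.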
with torch.enable_grad():
ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
self._test_check_fn(ref, loaded, {"x": x}, True)

def test_deterministic_algorithms(self):
def fn(x):
return x + 1

deterministic_restore = torch.are_deterministic_algorithms_enabled()
try:
x = torch.randn(3, 2)
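            # Serialize with deterministic algorithms enabled; the guard should
            # fail once the flag is flipped off and pass when it is turned back on.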
torch.use_deterministic_algorithms(True)
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
torch.use_deterministic_algorithms(False)
self._test_check_fn(ref, loaded, {"x": x}, False)
torch.use_deterministic_algorithms(True)
self._test_check_fn(ref, loaded, {"x": x}, True)
finally:
torch.use_deterministic_algorithms(deterministic_restore)

def test_torch_function_state(self):
def fn(x):
return x + 1

x = torch.randn(3, 2)
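        # The TORCH_FUNCTION_STATE guard captures the active TorchFunctionMode
        # stack: it should pass only when the same global mode is installed and
        # torch function is not disabled.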

class LocalTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)

with GlobalTorchFunctionMode():
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(ref, loaded, {"x": x}, False)
with GlobalTorchFunctionMode():
with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False)
with self.assertRaisesRegex(
RuntimeError,
"defined in local scope. Please define the class at global scope",
):
with LocalTorchFunctionMode():
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_training_state(self):
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup

param_group = FSDPParamGroup(
[], # params: List[nn.Parameter],
(torch.nn.Linear(1, 1),), # module: nn.Module,
None, # mesh_info: FSDPMeshInfo,
None, # post_forward_mesh_info: Optional[FSDPMeshInfo],
torch.device("cpu"), # device: torch.device,
None, # shard_placement_fn: Optional[Callable],
None, # mp_policy: MixedPrecisionPolicy,
None, # offload_policy: OffloadPolicy,
)
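        # fn branches on param_group._training_state, which use_training_state
        # sets to FORWARD for the duration of the block.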

def fn(x):
with param_group.use_training_state(TrainingState.FORWARD):
if param_group._training_state == TrainingState.FORWARD:
return x + 1
else:
return x - 1

x = torch.randn(3, 2)
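        # Mirrors test_grad_mode: serialize under enable_grad, then expect the
        # check to fail under no_grad and pass under enable_grad.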

with torch.enable_grad():
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
self._test_check_fn(ref, loaded, {"x": x}, True)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests