Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions test/distributed/fsdp/test_checkpoint_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)

class CheckpointWrapperTest(TestCase):
def setUp(self):
super().setUp()

def test_load_activation_checkpointed_module(self):
lin = nn.Linear(10, 10, bias=False)
lin = checkpoint_wrapper(lin)
state_dict = deepcopy(lin.state_dict())
# Load into non-checkpoint wrapped linear module
lin_new = nn.Linear(10, 10, bias=False)
lin_new.load_state_dict(state_dict)
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)
self.assertTrue(torch.allclose(p1, p2))

# Load non-checkpoint wrapped module into checkpoint wrapped one
# Make params different
for p in lin_new.parameters():
with torch.no_grad():
p.add_(0.5)

state_dict = deepcopy(lin_new.state_dict())
# Verify checkpoint wrapped linear can load unwrapped linear
lin.load_state_dict(state_dict)
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)


if __name__ == "__main__":
run_tests()
26 changes: 0 additions & 26 deletions test/distributed/fsdp/test_fsdp_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,32 +187,6 @@ def _validate_state_dict_contents(
if isinstance(model, FSDP):
self.assertEqual(fsdp_state_dict, {})

@skip_if_lt_x_gpu(2)
def test_load_activation_checkpointed_module(self):
# TODO: move this tests to checkpoint_wrapper tests once there is a dedicated
# test suite for them: https://github.com/pytorch/pytorch/issues/77478.
lin = nn.Linear(10, 10, bias=False).cuda()
lin = checkpoint_wrapper(lin)
state_dict = deepcopy(lin.state_dict())
# Load into non-checkpoint wrapped linear module
lin_new = nn.Linear(10, 10, bias=False).cuda()
lin_new.load_state_dict(state_dict)
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)

# Load non-checkpoint wrapped module into checkpoint wrapped one
# Make params different
for p in lin_new.parameters():
with torch.no_grad():
p.add_(0.5)

state_dict = deepcopy(lin_new.state_dict())
# Verify checkpoint wrapped linear can load unwrapped linear
lin.load_state_dict(state_dict)
print(type(lin))
for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
self.assertEqual(p1, p2)

@skip_if_lt_x_gpu(2)
@parametrize("checkpoint_wrap", ["first", "second", "both"])
def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
Expand Down