[FSDP] fix: fix for fsdp zero2 validation error (pytorch#110139)
# Problem
When `sharding_strategy` is set to `SHARD_GRAD_OP` and `forward_prefetch` is turned on, validation run after training fails with an incorrect weight shape.
<img width="1508" alt="image" src="https://github.com/pytorch/pytorch/assets/41232043/57a9c3bb-cb5c-46df-ac26-922740686f9e">

# Analyze
When using `SHARD_GRAD_OP`, `free_unsharded_flat_param` in `_post_forward_reshard` is often `False`, so the handle's `_prefetched` flag is not reset to `False` after the forward.

The normal training phase resets this flag to `False` in `_post_backward_final_callback`, but the validation phase never runs that hook, so after the first validation iteration the `_prefetched` flag of the prefetched handle stays `True`.

This causes the handle to skip `_unshard` in the next `_pre_forward_unshard`, and `_prefetch_handle` does not prefetch it either, which results in an incorrect weight shape.
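
For reference, a minimal sketch of the failing sequence, distilled from the test added below. The tiny model, tensor shapes, and wrapping policy are illustrative stand-ins, and it assumes a process group and CUDA device are already initialized (e.g. via `torchrun`):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).cuda()
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # zero2: no reshard after forward
    forward_prefetch=True,  # pre-forward prefetch sets handle._prefetched = True
    auto_wrap_policy=ModuleWrapPolicy([nn.Linear]),
)

x = torch.randn(4, 8, device="cuda")

# Training iteration: _post_backward_final_callback resets _prefetched to False.
fsdp_model(x).sum().backward()

# Validation: there is no backward pass, so that callback never runs.
fsdp_model.eval()
with torch.no_grad():
    fsdp_model(x)  # first eval iter leaves _prefetched = True on the prefetched handle
    fsdp_model(x)  # before this fix: _unshard was skipped here -> incorrect weight shape
```
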
Pull Request resolved: pytorch#110139
Approved by: https://github.com/awgu
Edwiv authored and yeounoh committed Oct 16, 2023
1 parent 57bfbed commit 70a05f5
Showing 2 changed files with 104 additions and 0 deletions.
103 changes: 103 additions & 0 deletions test/distributed/fsdp/test_fsdp_misc.py
@@ -143,6 +143,109 @@ def _check_device_matches(module, device_id):
            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_zero2_eval_with_prefetch(self):
        # Test FSDP validation with SHARD_GRAD_OP and forward_prefetch

        class Mnist(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout(0.25)
                self.dropout2 = nn.Dropout(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)
                self.ln = nn.LayerNorm(9216)

            def forward(self, x, y):
                x = self.conv1(x)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.ln(x)
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = torch.nn.functional.log_softmax(x, dim=1)
                loss = torch.nn.functional.cross_entropy(output, y)
                return loss

        model = Mnist().cuda()
        model1 = Mnist().cuda()
        model1.load_state_dict(model.state_dict())
        fsdp_model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            forward_prefetch=True,
            use_orig_params=True,
            auto_wrap_policy=ModuleWrapPolicy([nn.Linear, nn.Conv2d]),
        )
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model1,
        )

        fsdp_opt = torch.optim.SGD(fsdp_model.parameters(), lr=1e-4)
        ddp_opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

        seed = self.rank + 20231010
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        losses = []
        grads = []
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

        with torch.no_grad():
            fsdp_model.eval()
            ddp_model.eval()
            for _ in range(5):
                x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
                y = torch.randint(low=0, high=9, size=(8,), device="cuda")
                fsdp_loss = fsdp_model(x, y)
                ddp_loss = ddp_model(x, y)
                assert torch.allclose(fsdp_loss, ddp_loss)

        fsdp_model.train()
        ddp_model.train()
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
1 change: 1 addition & 0 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -599,6 +599,7 @@ def _root_pre_forward(
                handles.append(fsdp_state._handle)
        for handle in handles:
            handle._needs_pre_forward_unshard = True
            handle._prefetched = False
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
