[pipelining] Add grad test for interleaved schedules (#126931)
Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank, assigned round-robin (see the sketch after this block):
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```
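
For illustration, a minimal sketch (not part of the commit) of the wrapped, round-robin stage placement shown above; `world_size` and `stages_per_rank` mirror the two-rank test setup, and the loop reproduces the `stage_indices` computation from the diff below:

```python
# Stage i is placed on rank i % world_size ("wrapped" placement),
# so each rank owns stages_per_rank non-adjacent stages.
world_size = 2
stages_per_rank = 2

for rank in range(world_size):
    stage_indices = [rank + i * world_size for i in range(stages_per_rank)]
    print(f"Rank {rank} stages: {stage_indices}")

# Output:
# Rank 0 stages: [0, 2]
# Rank 1 stages: [1, 3]
```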

Pull Request resolved: #126931
Approved by: https://github.com/wconstab
ghstack dependencies: #126812, #126721, #126735, #126927
kwen2501 authored and pytorchmergebot committed May 23, 2024
1 parent c46b38b commit abf6d4e
Showing 1 changed file with 95 additions and 1 deletion: test/distributed/pipelining/test_schedule.py
```diff
@@ -15,6 +15,8 @@
     PipelineStage,
     Schedule1F1B,
     ScheduleGPipe,
+    ScheduleInterleaved1F1B,
+    ScheduleLoopedBFS,
 )
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -30,7 +32,6 @@
 
 d_hid = 512
 batch_size = 256
-chunks = 4
 
 torch.manual_seed(0)
 
@@ -63,6 +64,7 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         target = torch.randn(batch_size, d_hid, device=self.device)
         loss_fn = torch.nn.MSELoss(reduction="sum")
 
+        chunks = 4
         pipe = pipeline(
             mod,
             chunks,
@@ -123,6 +125,7 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass):
         ref_loss.backward()
 
         # Create a pipeline
+        chunks = 4
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
             mod,
@@ -200,6 +203,7 @@ def test_grad_with_manual(self, ScheduleClass):
         # Get a submodule, e.g. `layers.0` or `layers.1`
         submod_name = f"layers.{self.rank}"
         stage_module = full_mod.get_submodule(submod_name)
+        chunks = 4
         # Create a pipeline stage to wrap that submodule
         stage = ManualPipelineStage(
             stage_module,
@@ -247,6 +251,96 @@ def test_grad_with_manual(self, ScheduleClass):
                 print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                 raise
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
+    def test_grad_with_manual_interleaved(self, ScheduleClass):
+        stages_per_rank = 2
+        n_stages = stages_per_rank * self.world_size
+        full_mod = MultiMLP(d_hid, n_layers=n_stages)
+        full_mod.to(self.device)
+
+        ref_mod = copy.deepcopy(full_mod)
+        x = torch.randn(batch_size, d_hid, device=self.device)
+        with torch.no_grad():
+            y = ref_mod(x)
+            # Add a small perturbation
+            target = y + torch.randn(batch_size, d_hid, device=self.device)
+
+        loss_fn = torch.nn.MSELoss(reduction="sum")
+
+        # Run reference
+        for _ in range(2):
+            ref_mod.zero_grad()
+            ref_out = ref_mod(x)
+            ref_loss = loss_fn(ref_out, target)
+            ref_loss.backward()
+
+        # Get a submodule, e.g. `layers.0` or `layers.1`
+        stage_indices = [
+            self.rank + i * self.world_size for i in range(stages_per_rank)
+        ]
+        print(f"Rank {self.rank} stages: {stage_indices}")
+        submod_names = [f"layers.{i}" for i in stage_indices]
+        stage_modules = [
+            full_mod.get_submodule(submod_name) for submod_name in submod_names
+        ]
+        # Create a pipeline stage to wrap that submodule
+        chunks = 8
+        input_args = x.chunk(chunks)[0]
+        stages = [
+            ManualPipelineStage(
+                stage_module,
+                stage_idx,
+                n_stages,
+                self.device,
+                chunks,
+                input_args=input_args,
+            )
+            for stage_module, stage_idx in zip(stage_modules, stage_indices)
+        ]
+
+        # Attach to a schedule
+        schedule = ScheduleClass(stages, chunks, loss_fn=loss_fn)
+
+        # Run
+        for _ in range(2):
+            # Zero gradients
+            for stage_module in stage_modules:
+                stage_module.zero_grad()
+            if self.rank == 0:
+                schedule.step(x)
+            elif self.rank == self.world_size - 1:
+                losses = []
+                out = schedule.step(target=target, losses=losses)
+            else:
+                schedule.step()
+
+        dist.barrier()
+
+        # Last rank checks result
+        if self.rank == self.world_size - 1:
+            # Check output
+            torch.testing.assert_close(out, ref_out)
+            # Check loss
+            # Since the reduction used in the loss function above is "sum", we use
+            # "sum" here to reduce microbatch losses into a single value too.
+            pipe_loss = sum(losses)
+            torch.testing.assert_close(pipe_loss, ref_loss)
+
+        # Every rank checks gradients
+        for stage_module, submod_name in zip(stage_modules, submod_names):
+            # Get corresponding submodule from reference model
+            ref_submod = ref_mod.get_submodule(submod_name)
+            # Check gradients per parameter
+            for name, p in stage_module.named_parameters():
+                ref_p = ref_submod.get_parameter(name)
+                try:
+                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
+                except AssertionError:
+                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
+                    raise
+
 
 instantiate_parametrized_tests(ScheduleTest)
 
```
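A side note on the loss check in the new test: because the loss function uses `reduction="sum"`, summing the per-microbatch losses reproduces the full-batch loss exactly. A small self-contained sketch of that identity (plain tensors, no pipeline involved; shapes are arbitrary):

```python
import torch

loss_fn = torch.nn.MSELoss(reduction="sum")
x = torch.randn(8, 4)
target = torch.randn(8, 4)

# With "sum" reduction, squared errors simply add up, so splitting the
# batch into microbatches and summing per-chunk losses matches the
# full-batch loss (up to floating-point accumulation order).
full_loss = loss_fn(x, target)
micro_loss = sum(loss_fn(xc, tc) for xc, tc in zip(x.chunk(4), target.chunk(4)))
torch.testing.assert_close(micro_loss, full_loss)
```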

1 comment on commit abf6d4e

@pytorchmergebot (Collaborator) commented:
Reverted #126931 on behalf of https://github.com/clee2000 because the newly added test distributed/pipelining/test_schedule.py::ScheduleTest::test_grad_with_manual_interleaved_ScheduleClass0 fails (https://hud.pytorch.org/pytorch/pytorch/commit/abf6d4e6bc1a9a0e08bfc2204560ca7858fa90cd, https://github.com/pytorch/pytorch/actions/runs/9214413308/job/25352507591). The pull workflow failed on startup on the PR, so no distributed tests ran at all. (comment)
