[pipelining] Consolidate test models into a registry
[ghstack-poisoned]
kwen2501 committed May 13, 2024
1 parent 71d4ab7 · commit 2b8cbca
Showing 5 changed files with 153 additions and 66 deletions.
Empty file.
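In short: model definitions duplicated across the pipelining tests move into a shared model_registry.py, and each test now imports just what it needs, for example:

from model_registry import ExampleCode, MLPModule, MultiMLP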
test/distributed/pipelining/model_registry.py — 61 changes: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.distributed.pipelining import pipe_split


class ExampleCode(torch.nn.Module):
default_dhid = 512
default_batch_size = 256

def __init__(self, d_hid: int = default_dhid):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.register_buffer("cval", torch.randn((d_hid,), requires_grad=False))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x, y=torch.zeros(default_batch_size, default_dhid)):
x = torch.mm(x, self.mm_param0)
x = x + y
x = torch.relu(x)
# try passing a value that doesn't require_grad across skip boundaries
a_constant = self.cval.clone()
x = self.lin0(x)
pipe_split()
x = torch.relu(x) + a_constant
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
return x


# MLP Layer
class MLPModule(torch.nn.Module):
def __init__(self, d_hid):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
def __init__(self, d_hid):
super().__init__()
self.mlp0 = MLPModule(d_hid)
self.mlp1 = MLPModule(d_hid)

def forward(self, x):
x = self.mlp0(x)
pipe_split()
x = self.mlp1(x)
return x
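For context, a minimal sketch of how these registry models are consumed, mirroring the pipeline() call in test_schedule.py below (the chunks value is illustrative, not taken from this commit):

import torch

from model_registry import MultiMLP
from torch.distributed.pipelining import pipeline

d_hid, batch_size, chunks = 512, 256, 4  # chunks chosen for illustration
mod = MultiMLP(d_hid)
x = torch.randn(batch_size, d_hid)
# pipe_split() inside MultiMLP.forward marks the stage boundary, so the
# traced pipeline has two stages, each run over `chunks` microbatches.
pipe = pipeline(mod, chunks, example_args=(x,))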
test/distributed/pipelining/test_pipe.py — 35 changes: 12 additions & 23 deletions
@@ -1,8 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch

from model_registry import MLPModule
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)


d_hid = 512
@@ -39,21 +46,6 @@ def forward(self, x, y):
return x


# MLP example
class MLPModule(torch.nn.Module):
def __init__(self, d_hid):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
return x


class MultiMLP(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -74,8 +66,9 @@ def forward(self, x, y):


class PipeTests(TestCase):
def _test_model_split(self, model_class):
mod = model_class()
@parametrize("ModelClass", [ExampleCode, MultiMLP])
def test_model_split(self, ModelClass):
mod = ModelClass()
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)

@@ -108,12 +101,8 @@ def _test_model_split(self, model_class):
"""
print("Qualname check passed")

def test_example_code(self):
self._test_model_split(ExampleCode)

def test_multi_mlp(self):
self._test_model_split(MultiMLP)

instantiate_parametrized_tests(PipeTests)

if __name__ == "__main__":
run_tests()
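The parametrization helpers used above live in torch.testing._internal.common_utils; a self-contained sketch of the pattern, with a hypothetical test class:

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


class DemoTests(TestCase):
    @parametrize("n", [1, 2, 3])
    def test_square_nonnegative(self, n):
        self.assertGreaterEqual(n * n, 0)


# Expands the parametrized method into one concrete test per value of `n`.
instantiate_parametrized_tests(DemoTests)

if __name__ == "__main__":
    run_tests()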
test/distributed/pipelining/test_schedule.py — 105 changes: 78 additions & 27 deletions
@@ -1,13 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import os
import sys
import tempfile

import torch
import torch.distributed as dist

from model_registry import ExampleCode, MultiMLP
from torch.distributed.pipelining import (
pipe_split,
pipeline,
PipelineStage,
Schedule1F1B,
@@ -32,30 +34,6 @@
torch.manual_seed(0)


class ExampleCode(torch.nn.Module):
def __init__(self):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.register_buffer("cval", torch.randn((d_hid,), requires_grad=False))
self.lin0 = torch.nn.Linear(d_hid, d_hid)
self.lin1 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x, y=torch.zeros(batch_size, d_hid)):
x = torch.mm(x, self.mm_param0)
x = x + y
x = torch.relu(x)
# try passing a value that doesn't require_grad across skip boundaries
a_constant = self.cval.clone()
x = self.lin0(x)
pipe_split()
x = torch.relu(x) + a_constant
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
x = torch.relu(x)
return x


class ScheduleTest(MultiProcContinousTest):
@classmethod
def backend_str(cls) -> str:
@@ -78,7 +56,7 @@ def test_ec_forward(self):
# Setting this flag for numerical stability
torch.distributed.pipelining.microbatch._debug_mask_minibatches = True

mod = ExampleCode()
mod = ExampleCode(d_hid)
mod.to(self.device)

x = torch.randn(batch_size, d_hid, device=self.device)
@@ -125,7 +103,7 @@ def test_ec_forward(self):
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_ec_backward(self, ScheduleClass):
mod = ExampleCode()
mod = ExampleCode(d_hid)
mod.to(self.device)

x = torch.randn(batch_size, d_hid, device=self.device)
@@ -168,6 +146,79 @@ def test_ec_backward(self, ScheduleClass):
torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3)
torch.testing.assert_close(pipe_loss, ref_loss)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
def test_grad(self, ScheduleClass):
mod = MultiMLP(d_hid)
mod.to(self.device)

ref_mod = copy.deepcopy(mod)
x = torch.randn(batch_size, d_hid, device=self.device)
with torch.no_grad():
y = ref_mod(x)
# Add a small perturbation
target = y + torch.randn(batch_size, d_hid, device=self.device)

loss_fn = torch.nn.MSELoss(reduction="sum")

# Run reference
for _ in range(2):
ref_mod.zero_grad()
ref_out = ref_mod(x)
ref_loss = loss_fn(ref_out, target)
ref_loss.backward()

# Create a pipeline
pipe = pipeline(
mod,
chunks,
example_args=(x,),
)

stage = PipelineStage(
pipe,
self.rank,
device=self.device,
)

# Attach to a schedule
schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn)

# Run
stage_module = pipe.get_stage_module(self.rank)
for _ in range(2):
# Zero gradients
stage_module.zero_grad()
if self.rank == 0:
schedule.step(x)
elif self.rank == self.world_size - 1:
losses = []
out = schedule.step(target=target, losses=losses)
else:
schedule.step()

dist.barrier()

# Last rank checks result
if self.rank == self.world_size - 1:
# Check output
torch.testing.assert_close(out, ref_out)
# Check loss
# Since the reduction used in the loss function above is "sum", we use
# "sum" here to reduce microbatch losses into a single value too.
pipe_loss = sum(losses)
torch.testing.assert_close(pipe_loss, ref_loss)

# Every rank checks gradients
for name, p in stage_module.named_parameters():
ref_p = ref_mod.get_parameter(name)
try:
torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
except AssertionError:
print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
raise


instantiate_parametrized_tests(ScheduleTest)
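A side note on the loss check in test_grad above: because the loss uses reduction="sum", the full-batch loss decomposes exactly into the sum of per-microbatch losses, which is why pipe_loss = sum(losses) can be compared against ref_loss directly. A standalone check of that identity (shapes are hypothetical):

import torch

loss_fn = torch.nn.MSELoss(reduction="sum")
out = torch.randn(8, 4)
target = torch.randn(8, 4)
full_loss = loss_fn(out, target)
# Split the batch into 4 "microbatches" and sum their losses.
micro_loss = sum(loss_fn(o, t) for o, t in zip(out.chunk(4), target.chunk(4)))
torch.testing.assert_close(micro_loss, full_loss)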

test/distributed/pipelining/test_stage_backward.py — 18 changes: 2 additions & 16 deletions
@@ -3,6 +3,8 @@
import copy

import torch

from model_registry import MLPModule
from torch.distributed.pipelining._backward import stage_backward
from torch.testing._internal.common_utils import run_tests, TestCase

@@ -11,20 +13,6 @@
batch_size = 256


class MLPModule(torch.nn.Module):
def __init__(self, d_hid):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid, d_hid)

def forward(self, x):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
return x


class StageBackwardTests(TestCase):
def test_stage_backward(self):
# MLP as a stage module
@@ -65,8 +53,6 @@ def test_stage_backward(self):
print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
raise

print("Stage backward test passed")


if __name__ == "__main__":
run_tests()
