benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py (125 changes: 78 additions & 47 deletions)
@@ -42,8 +42,10 @@ class ExperimentConfig:

@dataclass(frozen=True)
class ExperimentResult:
bf16_ms: float
mxfp8_ms: float
fwd_bf16_ms: float
fwd_mxfp8_ms: float
bwd_bf16_ms: float
bwd_mxfp8_ms: float


@dataclass(frozen=True)
@@ -55,6 +57,10 @@ class Experiment:
def get_configs() -> List[ExperimentConfig]:
# (batch_size, seq_len, dim)
input_shapes = [
(1, 8192, 5120),
(2, 8192, 5120),
(4, 8192, 5120),
(8, 8192, 5120),
(16, 8192, 5120),
]
configs = []
@@ -67,9 +73,8 @@ def get_configs() -> List[ExperimentConfig]:
return configs


def default_a2a_fwd_bwd(
def default_a2a_fwd(
routed_input: torch.Tensor,
labels: torch.Tensor,
output_splits_list: list[int],
input_splits_list: list[int],
device_mesh: DeviceMesh,
@@ -81,17 +86,12 @@ def default_a2a_fwd_bwd(
device_mesh.get_group(),
)
routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)

loss = F.mse_loss(routed_input, labels)
loss.backward()

torch.cuda.synchronize()
return routed_input


def mxfp8_a2a_fwd_bwd(
def mxfp8_a2a_fwd(
routed_input: torch.Tensor,
labels: torch.Tensor,
output_splits_list: list[int],
input_splits_list: list[int],
device_mesh: DeviceMesh,
@@ -102,16 +102,22 @@ def mxfp8_a2a_fwd_bwd(
input_splits_list,
device_mesh.get_group(),
)
torch.cuda.synchronize()
return routed_input


def mse_loss_and_bwd(
routed_input: torch.Tensor,
labels: torch.Tensor,
):
loss = F.mse_loss(routed_input, labels)
loss.backward()
torch.cuda.synchronize()
return routed_input


# Compile target funcs
default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
mse_loss_and_bwd_compiled = torch.compile(mse_loss_and_bwd)


def run_experiment(
@@ -129,82 +135,105 @@ def run_experiment(
# Set up device mesh
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Max output tokens per rank is worst case where one rank receives all tokens
input_tokens_per_rank = batch_size * seq_len

def warmup(func_no_args):
for _ in range(2):
func_no_args()

input_tokens_per_rank = batch_size * seq_len
num_experts_per_rank = 2
num_splits = dist.get_world_size() * num_experts_per_rank
input_splits = generate_split_sizes(
num_splits, input_tokens_per_rank, device=device
)
num_experts = dist.get_world_size() * num_experts_per_rank
input_tokens_per_expert = input_tokens_per_rank // num_experts
input_splits = torch.tensor(
input_tokens_per_expert, dtype=torch.int32, device=device
).repeat(num_experts)
input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)

# Generate labels
labels_shape = (sum(output_splits_list), dim)
labels = x.new_ones(*labels_shape)

# Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: default_a2a_sync_compiled(
ref_x, labels, output_splits_list, input_splits_list, mesh
)
)
# Bench default a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(lambda: default_a2a_fwd(ref_x, output_splits_list, input_splits_list, mesh))
start_sec = time.perf_counter()
default_a2a_sync_compiled(
ref_x, labels, output_splits_list, input_splits_list, mesh
bf16_routed_input = default_a2a_fwd(
ref_x, output_splits_list, input_splits_list, mesh
)
end_sec = time.perf_counter()
bf16_ms = (end_sec - start_sec) * 1e3
fwd_bf16_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
default_a2a_sync_compiled,
default_a2a_fwd,
ref_x,
labels,
output_splits_list,
input_splits_list,
mesh,
distributed=True,
profile_name="all_to_all_single_autograd",
profile_name="default_a2a_fwd",
)

# Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: mxfp8_a2a_sync_compiled(
x, labels, output_splits_list, input_splits_list, mesh
# Bench default a2a backward
warmup(lambda: mse_loss_and_bwd_compiled(bf16_routed_input, labels))
start_sec = time.perf_counter()
mse_loss_and_bwd_compiled(bf16_routed_input, labels)
end_sec = time.perf_counter()
bwd_bf16_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
mse_loss_and_bwd_compiled,
bf16_routed_input,
labels,
distributed=True,
profile_name="bf16_a2a_bwd",
)
)

# Bench mxfp8 sync a2a fwd (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(lambda: mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh))
start_sec = time.perf_counter()
mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
mxfp8_routed_input = mxfp8_a2a_fwd(x, output_splits_list, input_splits_list, mesh)
end_sec = time.perf_counter()
mxfp8_ms = (end_sec - start_sec) * 1e3
fwd_mxfp8_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
mxfp8_a2a_sync_compiled,
mxfp8_a2a_fwd,
x,
labels,
output_splits_list,
input_splits_list,
mesh,
distributed=True,
profile_name="to_mxfp8_a2a_dequant",
profile_name="mxfp8_a2a_fwd",
)

# Bench mxfp8 sync a2a backward
warmup(lambda: mse_loss_and_bwd_compiled(mxfp8_routed_input, labels))
start_sec = time.perf_counter()
mse_loss_and_bwd_compiled(mxfp8_routed_input, labels)
end_sec = time.perf_counter()
bwd_mxfp8_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
mse_loss_and_bwd_compiled,
mxfp8_routed_input,
labels,
distributed=True,
profile_name="mxfp8_a2a_bwd",
)

return ExperimentResult(
bf16_ms=bf16_ms,
mxfp8_ms=mxfp8_ms,
fwd_bf16_ms=fwd_bf16_ms,
fwd_mxfp8_ms=fwd_mxfp8_ms,
bwd_bf16_ms=bwd_bf16_ms,
bwd_mxfp8_ms=bwd_mxfp8_ms,
)


def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"num_splits",
"bf16_ms",
"mxfp8_ms",
"fwd_bf16_ms",
"fwd_mxfp8_ms",
"bwd_bf16_ms",
"bwd_mxfp8_ms",
]
rows = []
num_splits = dist.get_world_size()
@@ -213,8 +242,10 @@ def print_results(experiments: List[Experiment]):
[
str(experiment.config.input_shape),
num_splits,
experiment.result.bf16_ms,
experiment.result.mxfp8_ms,
experiment.result.fwd_bf16_ms,
experiment.result.fwd_mxfp8_ms,
experiment.result.bwd_bf16_ms,
experiment.result.bwd_mxfp8_ms,
]
)
print(tabulate(rows, headers=headers))
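
Note on the new split construction: the benchmark now routes an even number of tokens to every expert instead of drawing random split sizes. A minimal sketch of that construction is below, assuming a hypothetical world size of 8 and the (8, 8192, 5120) input shape; the benchmark's own `generate_split_sizes` and `get_split_lists` helpers are not reproduced here.

```python
import torch

# Hypothetical values for illustration; the real benchmark derives world_size
# from torch.distributed and sweeps several (batch_size, seq_len, dim) shapes.
world_size = 8
num_experts_per_rank = 2
batch_size, seq_len = 8, 8192

input_tokens_per_rank = batch_size * seq_len
num_experts = world_size * num_experts_per_rank

# Even routing: every expert receives the same number of tokens, so the timing
# no longer depends on a random (potentially worst-case) split distribution.
input_tokens_per_expert = input_tokens_per_rank // num_experts
input_splits = torch.tensor(input_tokens_per_expert, dtype=torch.int32).repeat(num_experts)

assert int(input_splits.sum()) == input_tokens_per_rank
print(input_splits)  # 16 entries of 4096 for this shape
```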
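The forward/backward timing split follows a simple pattern: warm up each phase, then wrap a call that ends in `torch.cuda.synchronize()` with `time.perf_counter()`. A minimal single-GPU sketch of that pattern for the compiled loss-plus-backward phase is shown below; the tensor shape and dtype are illustrative, and the all-to-all collective itself is omitted.

```python
import time
import torch
import torch.nn.functional as F

def mse_loss_and_bwd(routed_input: torch.Tensor, labels: torch.Tensor):
    # Same structure as the benchmark's backward phase: loss + backward,
    # then a device sync so perf_counter captures the full GPU time.
    loss = F.mse_loss(routed_input, labels)
    loss.backward()
    torch.cuda.synchronize()

mse_loss_and_bwd_compiled = torch.compile(mse_loss_and_bwd)

# Illustrative shape: one rank's routed tokens for the (8, 8192, 5120) config.
routed_input = torch.randn(
    65536, 5120, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
labels = torch.ones_like(routed_input)

for _ in range(2):  # warmup, matching the benchmark's warmup() helper
    mse_loss_and_bwd_compiled(routed_input, labels)

start = time.perf_counter()
mse_loss_and_bwd_compiled(routed_input, labels)
print(f"bwd: {(time.perf_counter() - start) * 1e3:.3f} ms")
```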