autobucketing does not support deepseekv3 #2037

@BoyuanFeng

Repro:

NGPU=4 CONFIG_FILE=/data/users/boyuan/torchtitan/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh \
  --model.name compiler_toolkit.deepseek_v3 \
  --parallelism.data_parallel_shard_degree=2 \
  --parallelism.tensor_parallel_degree=2 \
  --parallelism.expert_parallel_degree=2 \
  --activation_checkpoint.mode none \
  --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \
  --compile.passes autobucketing_reordering

Error:

    File "/data/users/boyuan/torchtitan/torchtitan/experiments/compiler_toolkit/passes.py", line 28, in autobucketing_reordering_pass
      schedule_overlap_bucketing(gm, collective_bucketing=True, collective_estimator = "benchmark",)
    File "/data/users/boyuan/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 1032, in schedule_overlap_bucketing
      return OverlapScheduler(
             ^^^^^^^^^^^^^^^^^
    File "/data/users/boyuan/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 281, in __init__
      self._identify_collectives()
    File "/data/users/boyuan/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 323, in _identify_collectives
      coll_time_ms = estimate_collective_time(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/boyuan/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 65, in estimate_collective_time
      return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/boyuan/pytorch/torch/_inductor/comm_analysis.py", line 397, in estimate_nccl_collective_runtime_from_fx_node
      coll = get_collective_type_from_kernel_name(fx_node.target.name())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/data/users/boyuan/pytorch/torch/_inductor/comm_analysis.py", line 59, in get_collective_type_from_kernel_name
      raise ValueError(f"Unsupported collective kernel: {kernel_name}")
  ValueError: Unsupported collective kernel: _c10d_functional::all_to_all_single
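
To make the failure mode concrete, here is a minimal, hypothetical model of a kernel-name to collective-type lookup. It is not the actual torch._inductor.comm_analysis code; the enum and function names are illustrative assumptions. It only shows why a kernel name that matches none of the known patterns, such as the all_to_all_single presumably introduced by expert parallelism (--parallelism.expert_parallel_degree=2), ends in this ValueError.

```python
# Hypothetical, simplified model of a kernel-name -> collective-type lookup.
# NOT the real torch._inductor.comm_analysis implementation; illustration only.
from enum import Enum


class CollType(Enum):
    ALL_REDUCE = "all_reduce"
    ALL_GATHER = "all_gather"
    REDUCE_SCATTER = "reduce_scatter"


# Substrings this toy analytical model knows how to cost (assumed set).
_KNOWN = {
    "all_reduce": CollType.ALL_REDUCE,
    "all_gather": CollType.ALL_GATHER,
    "reduce_scatter": CollType.REDUCE_SCATTER,
}


def collective_type_from_kernel_name(kernel_name: str) -> CollType:
    for pattern, coll in _KNOWN.items():
        if pattern in kernel_name:
            return coll
    # No pattern matched, e.g. an all-to-all kernel: the model cannot cost it.
    raise ValueError(f"Unsupported collective kernel: {kernel_name}")
```

In this toy, "_c10d_functional::all_reduce" resolves to CollType.ALL_REDUCE, while "_c10d_functional::all_to_all_single" raises the same message as the traceback.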

I tried the analytical mode for autobucketing; it fails because all_to_all is not supported.

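I also tried benchmark mode (the traceback above, with collective_estimator = "benchmark"), and it fails with the same all_to_all error. After a quick look: regardless of the value of collective_estimator, OverlapScheduler always calls _identify_collectives() from its constructor, and that path ultimately goes through the analytical model (estimate_nccl_collective_runtime_from_fx_node).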
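
The control flow described above can be sketched with a toy scheduler. This is a hypothetical model, not the real OverlapScheduler API: ToyOverlapScheduler, analytical_estimate_ms, benchmark_estimate_ms, and the "analytical" mode string are all assumed names, and the costs are placeholders. The point is that the analytical estimate runs unconditionally during identification, so choosing "benchmark" cannot avoid the unsupported-kernel error.

```python
# Hypothetical model of the reported control flow; names are illustrative.
from typing import Dict, List

# Patterns the toy analytical model supports (assumed; all-to-all is absent).
ANALYTICAL_SUPPORTED = ("all_reduce", "all_gather", "reduce_scatter")


def analytical_estimate_ms(kernel_name: str) -> float:
    if not any(p in kernel_name for p in ANALYTICAL_SUPPORTED):
        raise ValueError(f"Unsupported collective kernel: {kernel_name}")
    return 1.0  # placeholder analytical cost


def benchmark_estimate_ms(kernel_name: str) -> float:
    return 2.0  # placeholder for a measured runtime


class ToyOverlapScheduler:
    def __init__(self, collectives: List[str], collective_estimator: str):
        self.collective_estimator = collective_estimator
        self.times: Dict[str, float] = {}
        # Identification always runs in the constructor.
        self._identify_collectives(collectives)

    def _identify_collectives(self, collectives: List[str]) -> None:
        for name in collectives:
            # The analytical estimate is computed first, whatever the mode...
            t = analytical_estimate_ms(name)
            # ...and only then possibly replaced by the benchmark estimate.
            if self.collective_estimator == "benchmark":
                t = benchmark_estimate_ms(name)
            self.times[name] = t
```

With either mode, constructing the toy scheduler on a graph containing "_c10d_functional::all_to_all_single" raises ValueError, which matches what is reported here.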

cc @eellison @ruisizhang123 @tianyu-l
