diff --git a/torchao_hf_vllm/torchao_hf_script.py b/torchao_hf_vllm/torchao_hf_script.py
index 3f449a1..d94e4f0 100644
--- a/torchao_hf_vllm/torchao_hf_script.py
+++ b/torchao_hf_vllm/torchao_hf_script.py
@@ -30,6 +30,7 @@
     Int8DynamicActivationInt4WeightConfig,
     CutlassInt4PackedLayout,
 )
+from torchao.quantization import ModuleFqnToConfig
 from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
@@ -67,7 +68,6 @@ def get_quantization_config(args):
         case "fp8":
             single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
             if args.experts_only_qwen_1_5_moe_a_2_7b:
-                from torchao.quantization import ModuleFqnToConfig
                 expert_fqn_to_config = {}
                 # TODO(future PR): this is annoying, I should be able to use a regex here
                 for layer_idx in range(24):
@@ -101,14 +101,39 @@ def get_quantization_config(args):
         case "mxfp8":
             return TorchAoConfig(MXFPInferenceConfig())
         case "mxfp4":
-            return TorchAoConfig(
-                MXFPInferenceConfig(
-                    activation_dtype=torch.float4_e2m1fn_x2,
-                    weight_dtype=torch.float4_e2m1fn_x2,
-                    block_size=32,
-                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
-                )
+            single_config = MXFPInferenceConfig(
+                activation_dtype=torch.float4_e2m1fn_x2,
+                weight_dtype=torch.float4_e2m1fn_x2,
+                block_size=32,
+                # gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+                gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
             )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                    modules_to_not_convert.append(f"lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")

@@ -165,6 +190,7 @@
     bench_tokens: int = 100,
     device_map: str = "cuda",
     experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    skip_gate_qwen_1_5_moe_a_2_7b: bool = False,
     save_model_to_disk: bool = True,
 ):
     """
@@ -182,6 +208,7 @@
         bench_tokens: Number of tokens to generate for benchmarking
         device_map: Device mapping strategy
         experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
+        skip_gate_qwen_1_5_moe_a_2_7b: if True, skips gate quantization for Qwen1.5-MoE-A2.7B model
         save_model_to_disk: if True, saves quantized model to local disk
     """
     # Set seed before creating the model
@@ -206,11 +233,14 @@
         device_map=device_map,
         experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
         save_model_to_disk=save_model_to_disk,
+        skip_gate_qwen_1_5_moe_a_2_7b=skip_gate_qwen_1_5_moe_a_2_7b,
     )
     print(f"{args=}")

     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type == "fp8", "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+
+    assert not (args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b), "unsupported"

     # Create output directory
     output_dir = Path(args.output_dir)