diff --git a/torchao_hf_vllm/torchao_hf_script.py b/torchao_hf_vllm/torchao_hf_script.py
index d94e4f0..ef8f78d 100644
--- a/torchao_hf_vllm/torchao_hf_script.py
+++ b/torchao_hf_vllm/torchao_hf_script.py
@@ -31,7 +31,11 @@
     CutlassInt4PackedLayout,
 )
 from torchao.quantization import ModuleFqnToConfig
-from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
 from torchao.prototype.mx_formats import MXGemmKernelChoice
 from jsonargparse import CLI, Namespace
 from rich import print
@@ -134,6 +138,38 @@ def get_quantization_config(args):
                     single_config,
                     modules_to_not_convert=modules_to_not_convert,
                 )
+        case "nvfp4":
+            single_config = NVFP4InferenceConfig(
+                mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+                use_triton_kernel=False,
+                use_dynamic_per_tensor_scale=False,
+            )
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                modules_to_not_convert = []
+                if args.skip_gate_qwen_1_5_moe_a_2_7b:
+                    for layer_idx in range(24):
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.gate")
+                        modules_to_not_convert.append(f"model.layers.{layer_idx}.mlp.shared_expert_gate")
+                modules_to_not_convert.append(f"lm_head")
+                return TorchAoConfig(
+                    single_config,
+                    modules_to_not_convert=modules_to_not_convert,
+                )
         case _:
             raise ValueError(f"Unsupported quantization type: {args.quant_type}")
 
@@ -182,6 +218,7 @@ def main(
         "A8W4",
         "fp8",
         "mxfp4",
+        "nvfp4",
     ] = "fp8",
     granularity: Literal["per_row", "per_tensor"] = "per_row",
     min_sqnr: Optional[float] = None,
@@ -238,9 +275,9 @@ def main(
     print(f"{args=}")
 
     if args.experts_only_qwen_1_5_moe_a_2_7b:
-        assert args.quant_type in ("fp8", "mxfp4"), "unsupported"
+        assert args.quant_type in ("fp8", "mxfp4", "nvfp4"), "unsupported"
 
-    assert not args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b, "unsupported"
+    assert not (args.skip_gate_qwen_1_5_moe_a_2_7b and args.experts_only_qwen_1_5_moe_a_2_7b), "unsupported"
 
     # Create output directory
     output_dir = Path(args.output_dir)
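
For reference, a minimal sketch of how the TorchAoConfig produced by get_quantization_config above is typically consumed through the Hugging Face transformers torchao integration. The model id and the args object here are assumptions for illustration; the real script wires this up itself from its own CLI arguments.

    # Hedged sketch: load a model with the quantization config built above.
    # Assumes args.quant_type == "nvfp4" and the Qwen1.5-MoE-A2.7B checkpoint
    # implied by the flag names in the diff.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    quantization_config = get_quantization_config(args)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-MoE-A2.7B",  # assumed model id
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")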