diff --git a/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py b/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
index cac7a41..6aa1c64 100644
--- a/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
+++ b/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
@@ -17,37 +17,12 @@
 from safetensors import safe_open
 from safetensors.torch import save_file
 
-from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors
-
-def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
-    # for now, allowlist of recipes we know how to convert and hand convert
-    # them here
-    # for a production version, we'll need a more scalable way to do this
-
-    assert isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig), "unsupported"
-    assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"
-
-    ct_config = {
-        "format": "float-quantized",
-        "input_activations": {
-            "dynamic": True,
-            "num_bits": 8,
-            "strategy": "token",
-            "symmetric": True,
-            "type": "float",
-        },
-        "output_activations": None,
-        "targets": ["Linear"],
-        "weights": {
-            "dynamic": False,
-            "num_bits": 8,
-            "observer": "minmax",
-            "strategy": "channel",
-            "symmetric": True,
-            "type": "float",
-        },
-    }
-    return ct_config
+from utils import (
+    convert_pt_statedict_to_safetensors,
+    convert_pt_multifile_index_to_safetensors,
+    ao_config_to_compressed_tensors_config,
+)
+
 
 def run(
     # original torchao checkpoint
diff --git a/hf_torchao_vllm/inspect_torchao_output.py b/hf_torchao_vllm/inspect_torchao_output.py
index 16d6f05..282aa12 100644
--- a/hf_torchao_vllm/inspect_torchao_output.py
+++ b/hf_torchao_vllm/inspect_torchao_output.py
@@ -10,6 +10,13 @@
 
 from utils import inspect_model_state_dict
 
+# ensure NVFP4Tensor can be loaded
+import torchao.prototype.mx_formats.inference_workflow
+
+# TODO: ensure the line below happens in torchao
+import torchao
+torch.serialization.add_safe_globals([torchao.prototype.mx_formats.nvfp4_tensor.QuantizeTensorToNVFP4Kwargs])
+
 # not sure why I still need this
 torch.serialization.add_safe_globals([getattr])
 
diff --git a/hf_torchao_vllm/quantize_hf_model_with_torchao.py b/hf_torchao_vllm/quantize_hf_model_with_torchao.py
index 777084f..a6c6546 100644
--- a/hf_torchao_vllm/quantize_hf_model_with_torchao.py
+++ b/hf_torchao_vllm/quantize_hf_model_with_torchao.py
@@ -147,6 +147,15 @@ def get_quantization_config(args):
         single_config = NVFP4InferenceConfig(
             mm_config=NVFP4MMConfig.WEIGHT_ONLY,
             use_triton_kernel=False,
+            #
+            # weight_only with use_dynamic_per_tensor_scale=True works here,
+            # but produces garbage output in vLLM, probably because torchao
+            # currently has no way to enforce that the scales of attention and
+            # ffn weights which get fused for inference are the same
+            # TODO: file a torchao issue about this, and fix in torchao
+            #
+            # dynamic with use_dynamic_per_tensor_scale=False is not supported in torch._scaled_mm
+            # use_dynamic_per_tensor_scale=False,
         )
 
         if args.experts_only_qwen_1_5_moe_a_2_7b:
diff --git a/hf_torchao_vllm/run_quantized_model_in_vllm.py b/hf_torchao_vllm/run_quantized_model_in_vllm.py
index 2302498..3976c7e 100644
--- a/hf_torchao_vllm/run_quantized_model_in_vllm.py
+++ b/hf_torchao_vllm/run_quantized_model_in_vllm.py
@@ -45,6 +45,8 @@ def print_vllm_torchao_quant_info(model: torch.nn.Module):
     for name, mod in model.named_modules():
         if "Linear" not in str(type(mod)):
             continue
+        if not hasattr(mod, "weight"):
+            continue
         mod_and_weight_type = type(mod), type(mod.weight)
         if mod_and_weight_type in seen_types:
             continue
diff --git a/hf_torchao_vllm/utils.py b/hf_torchao_vllm/utils.py
index 365b26f..09eb03a 100644
--- a/hf_torchao_vllm/utils.py
+++ b/hf_torchao_vllm/utils.py
@@ -1,13 +1,21 @@
 import copy
 import json
 import os
-from typing import List
+from typing import List, Dict, Any
 import pathlib
 
 import safetensors
 from safetensors.torch import save_file
 import torch
+import torchao
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+
+from torchao.core.config import AOBaseConfig, config_from_dict
+from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
+from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+from torchao.prototype.mx_formats.utils import from_blocked
 
 from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 
 
@@ -80,12 +88,42 @@ def convert_pt_statedict_to_safetensors(
             v = copy.deepcopy(v)
             new_state_dict[k] = v
+
         elif type(v) == Float8Tensor:
             new_state_dict[k] = v.qdata
             # for now, manually cast scale to bfloat16 to match current
             # llm-compressor script
             # TODO(future): prob needs to be user controllable
             new_state_dict[k + '_scale'] = v.scale.bfloat16()
+
+        elif type(v) == NVFP4Tensor:
+            # example checkpoint format: https://www.internalfb.com/phabricator/paste/view/P1981272933
+
+            # torchao does not support nvfp4 activation calibration yet,
+            # set activation global scale to 1.0
+            new_state_dict[k.replace('weight', 'input_global_scale')] = torch.tensor([1.0])
+            # torchao does not support fusion-aware nvfp4 weight global scale yet,
+            # set weight scale to 1.0
+            new_state_dict[k.replace('weight', 'weight_global_scale')] = torch.tensor([1.0])
+            new_state_dict[k + '_packed'] = v.qdata.view(torch.uint8)
+
+            # compressed-tensors stores the nvfp4 scale in row-major format,
+            # convert from swizzled to row-major
+            swizzled_scale = v._scale_e4m3
+            original_rows = v.qdata.shape[0]
+            # multiply by 2 to undo the packing, then divide by nvfp4 block size of 16
+            original_cols = v.qdata.shape[1] * 2 // 16
+            # TODO(future) also do the padding calculation here and remove the
+            # assertions
+            assert original_rows % 128 == 0, "unsupported"
+            assert original_cols % 4 == 0, "unsupported"
+            # import pdb; pdb.set_trace()
+            row_major_scale = from_blocked(
+                swizzled_scale,
+                original_rows,
+                original_cols,
+            )
+            new_state_dict[k + '_scale'] = row_major_scale
+
         else:
             raise AssertionError(f'unsupported type {type(v)}')
     save_file(new_state_dict, safetensors_statedict_filename)
@@ -145,3 +183,65 @@ def convert_pt_multifile_index_to_safetensors(
     # print(json.dumps(source_mapping, indent=2))
     with open(target_filename, 'w') as f:
         json.dump(source_mapping, f, indent=2)
+
+
+def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
+    # for now, allowlist of recipes we know how to convert and hand convert
+    # them here
+    # for a production version, we'll need a more scalable way to do this
+
+    if isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig):
+        assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"
+
+        ct_config = {
+            "format": "float-quantized",
+            "input_activations": {
+                "dynamic": True,
+                "num_bits": 8,
+                "strategy": "token",
+                "symmetric": True,
+                "type": "float",
+            },
+            "output_activations": None,
+            "targets": ["Linear"],
+            "weights": {
+                "dynamic": False,
+                "num_bits": 8,
+                "observer": "minmax",
+                "strategy": "channel",
+                "symmetric": True,
+                "type": "float",
+            },
+        }
+
+    elif isinstance(aobaseconfig, NVFP4InferenceConfig):
+
+        ct_config = {
+            "format": "nvfp4-pack-quantized",
+            "input_activations": {
+                "dynamic": "local",
+                "group_size": 16,
+                "num_bits": 4,
+                "observer": "minmax",
+                "observer_kwargs": {},
+                "strategy": "tensor_group",
+                "symmetric": True,
+                "type": "float",
+            },
+            "output_activations": None,
+            "targets": ["Linear"],
+            "weights": {
+                "dynamic": False,
+                "group_size": 16,
+                "num_bits": 4,
+                "observer": "minmax",
+                "observer_kwargs": {},
+                "strategy": "tensor_group",
+                "symmetric": True,
+                "type": "float"
+            },
+        }
+
+    else:
+        raise AssertionError(f"unsupported type {type(aobaseconfig)}")
+    return ct_config
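
For reference, below is a minimal sketch (not part of the repo; the helper name expected_nvfp4_scale_shape is hypothetical) of the shape math behind the swizzled-to-row-major NVFP4 scale conversion in convert_pt_statedict_to_safetensors above, assuming the converter's conventions: two fp4 values packed per uint8 byte and one e4m3 scale per block of 16 elements.

import torch

def expected_nvfp4_scale_shape(qdata_packed: torch.Tensor, block_size: int = 16):
    # hypothetical helper: given packed nvfp4 qdata, compute the shape of the
    # row-major scale tensor that the converter writes to the checkpoint
    rows = qdata_packed.shape[0]
    # two fp4 values per byte, one scale per block_size unpacked elements,
    # mirroring `original_cols = v.qdata.shape[1] * 2 // 16` above
    cols = qdata_packed.shape[1] * 2 // block_size
    # same divisibility constraints the converter currently asserts instead of padding
    assert rows % 128 == 0 and cols % 4 == 0, "unsupported"
    return rows, cols

# e.g. a 4096x4096 weight quantized to nvfp4: packed qdata is [4096, 2048] uint8,
# so the row-major scale stored under the '_scale' key is [4096, 256]
print(expected_nvfp4_scale_shape(torch.empty(4096, 2048, dtype=torch.uint8)))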