@@ -17,37 +17,12 @@
from safetensors import safe_open
from safetensors.torch import save_file

from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors

def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
# for now, allowlist of recipes we know how to convert and hand convert
# them here
# for a production version, we'll need a more scalable way to do this

assert isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig), "unsupported"
assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"

ct_config = {
"format": "float-quantized",
"input_activations": {
"dynamic": True,
"num_bits": 8,
"strategy": "token",
"symmetric": True,
"type": "float",
},
"output_activations": None,
"targets": ["Linear"],
"weights": {
"dynamic": False,
"num_bits": 8,
"observer": "minmax",
"strategy": "channel",
"symmetric": True,
"type": "float",
},
}
return ct_config
from utils import (
convert_pt_statedict_to_safetensors,
convert_pt_multifile_index_to_safetensors,
ao_config_to_compressed_tensors_config,
)


def run(
# original torchao checkpoint
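For orientation, a minimal usage sketch (not part of the diff) of the relocated helper as the conversion script would call it; the config construction below is an assumption that mirrors the per-row Float8 recipe the allowlist in utils.py accepts:

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from utils import ao_config_to_compressed_tensors_config

# torchao normalizes the granularity into an [activation, weight] pair; per-row
# for both is the only Float8 layout the helper currently converts
ao_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
ct_config = ao_config_to_compressed_tensors_config(ao_config)  # dict destined for the compressed-tensors config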
7 changes: 7 additions & 0 deletions hf_torchao_vllm/inspect_torchao_output.py
@@ -10,6 +10,13 @@

from utils import inspect_model_state_dict

# ensure NVFP4Tensor can be loaded
import torchao.prototype.mx_formats.inference_workflow

# TODO: ensure the line below happens in torchao
import torchao
torch.serialization.add_safe_globals([torchao.prototype.mx_formats.nvfp4_tensor.QuantizeTensorToNVFP4Kwargs])

# not sure why I still need this
torch.serialization.add_safe_globals([getattr])

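The registrations above matter because recent PyTorch defaults torch.load to weights_only=True, which refuses to unpickle classes that have not been allowlisted. A hedged sketch of the load they unblock (the checkpoint filename is illustrative):

import torch
import torchao
# importing the inference workflow registers the NVFP4 tensor subclass
import torchao.prototype.mx_formats.inference_workflow

torch.serialization.add_safe_globals(
    [torchao.prototype.mx_formats.nvfp4_tensor.QuantizeTensorToNVFP4Kwargs, getattr]
)
# without the allowlisting, this load fails with a weights-only unpickling error
state_dict = torch.load("pytorch_model.bin", weights_only=True, map_location="cpu")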
9 changes: 9 additions & 0 deletions hf_torchao_vllm/quantize_hf_model_with_torchao.py
@@ -147,6 +147,15 @@ def get_quantization_config(args):
single_config = NVFP4InferenceConfig(
mm_config=NVFP4MMConfig.WEIGHT_ONLY,
use_triton_kernel=False,
#
# WEIGHT_ONLY with use_dynamic_per_tensor_scale=True works here, but produces
# garbage output in vLLM, probably because torchao currently has no way to
# force the scales of attention and ffn weights that will be fused for
# inference to be identical
# TODO: file a torchao issue about this, and fix in torchao
#
# DYNAMIC with use_dynamic_per_tensor_scale=False is not supported by torch._scaled_mm
#
use_dynamic_per_tensor_scale=False,
)
if args.experts_only_qwen_1_5_moe_a_2_7b:
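For context, a minimal sketch of how the weight-only NVFP4 recipe above is typically applied with torchao's quantize_ API; the toy model and the NVFP4MMConfig import path are assumptions and may differ across torchao versions:

import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16))
config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.WEIGHT_ONLY,
    use_triton_kernel=False,
    use_dynamic_per_tensor_scale=False,  # per the fusion-scale caveat above
)
quantize_(model, config)  # each Linear weight is replaced by an NVFP4Tensor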
2 changes: 2 additions & 0 deletions hf_torchao_vllm/run_quantized_model_in_vllm.py
@@ -45,6 +45,8 @@ def print_vllm_torchao_quant_info(model: torch.nn.Module):
for name, mod in model.named_modules():
if "Linear" not in str(type(mod)):
continue
if not hasattr(mod, "weight"):
continue
mod_and_weight_type = type(mod), type(mod.weight)
if mod_and_weight_type in seen_types:
continue
102 changes: 101 additions & 1 deletion hf_torchao_vllm/utils.py
@@ -1,13 +1,21 @@
import copy
import json
import os
from typing import List
from typing import List, Dict, Any
import pathlib

import safetensors
from safetensors.torch import save_file

import torch
import torchao
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor

from torchao.core.config import AOBaseConfig, config_from_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.prototype.mx_formats.inference_workflow import NVFP4InferenceConfig
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
from torchao.prototype.mx_formats.utils import from_blocked
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor


@@ -80,12 +88,42 @@ def convert_pt_statedict_to_safetensors(
v = copy.deepcopy(v)

new_state_dict[k] = v

elif type(v) == Float8Tensor:
new_state_dict[k] = v.qdata
# for now, manually cast the scale to bfloat16 to match the current
# llm-compressor script
# TODO(future): probably needs to be user-controllable
new_state_dict[k + '_scale'] = v.scale.bfloat16()

elif type(v) == NVFP4Tensor:
# example checkpoint format: https://www.internalfb.com/phabricator/paste/view/P1981272933

# torchao does not support nvfp4 activation calibration yet,
# so set the activation global scale to 1.0
new_state_dict[k.replace('weight', 'input_global_scale')] = torch.tensor([1.0])
# torchao does not support a fusion-aware nvfp4 weight global scale yet,
# so set the weight global scale to 1.0
new_state_dict[k.replace('weight', 'weight_global_scale')] = torch.tensor([1.0])
new_state_dict[k + '_packed'] = v.qdata.view(torch.uint8)
# compressed-tensors stores the nvfp4 scale in row-major format,
# convert from swizzled to row-major
swizzled_scale = v._scale_e4m3
original_rows = v.qdata.shape[0]
# multiply by 2 to undo the packing, then divide by nvfp4 block size of 16
original_cols = v.qdata.shape[1] * 2 // 16
# TODO(future) also do the padding calculation here and remove the
# assertions
assert original_rows % 128 == 0, "unsupported"
assert original_cols % 4 == 0, "unsupported"
row_major_scale = from_blocked(
swizzled_scale,
original_rows,
original_cols,
)
new_state_dict[k + '_scale'] = row_major_scale

else:
raise AssertionError(f'unsupported type {type(v)}')
save_file(new_state_dict, safetensors_statedict_filename)
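To make the scale-shape arithmetic above concrete, a worked example assuming an illustrative 4096x4096 bfloat16 weight: two fp4 values are packed per uint8, and one e4m3 scale covers each 16-element block along a row.

rows, packed_cols = 4096, 2048            # qdata shape for a 4096x4096 weight
original_cols = packed_cols * 2 // 16     # undo the packing (x2), then 16-wide blocks -> 256
assert rows % 128 == 0 and original_cols % 4 == 0   # padded shapes are not handled yet
# from_blocked(swizzled_scale, rows, original_cols) then returns a row-major
# [4096, 256] scale tensor, stored under the weight_scale key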
@@ -145,3 +183,65 @@ def convert_pt_multifile_index_to_safetensors(
# print(json.dumps(source_mapping, indent=2))
with open(target_filename, 'w') as f:
json.dump(source_mapping, f, indent=2)


def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
# for now, allowlist of recipes we know how to convert and hand convert
# them here
# for a production version, we'll need a more scalable way to do this

if isinstance(aobaseconfig, Float8DynamicActivationFloat8WeightConfig):
assert aobaseconfig.granularity == [PerRow(), PerRow()], "unsupported"

ct_config = {
"format": "float-quantized",
"input_activations": {
"dynamic": True,
"num_bits": 8,
"strategy": "token",
"symmetric": True,
"type": "float",
},
"output_activations": None,
"targets": ["Linear"],
"weights": {
"dynamic": False,
"num_bits": 8,
"observer": "minmax",
"strategy": "channel",
"symmetric": True,
"type": "float",
},
}

elif isinstance(aobaseconfig, NVFP4InferenceConfig):

ct_config = {
"format": "nvfp4-pack-quantized",
"input_activations": {
"dynamic": "local",
"group_size": 16,
"num_bits": 4,
"observer": "minmax",
"observer_kwargs": {},
"strategy": "tensor_group",
"symmetric": True,
"type": "float",
},
"output_activations": None,
"targets": ["Linear"],
"weights": {
"dynamic": False,
"group_size": 16,
"num_bits": 4,
"observer": "minmax",
"observer_kwargs": {},
"strategy": "tensor_group",
"symmetric": True,
"type": "float"
},
}

else:
raise AssertionError(f"unsupported type {type(aobaseconfig)}")
return ct_config
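
For orientation, the returned dict describes a single config group rather than a complete quantization config; in compressed-tensors checkpoints it typically lands inside config.json roughly as sketched below. The envelope fields and the ignore list are illustrative assumptions, not something this PR writes:

ct_config = ao_config_to_compressed_tensors_config(
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": ct_config["format"],
    "config_groups": {"group_0": ct_config},
    "ignore": ["lm_head"],  # illustrative
}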