192 changes: 192 additions & 0 deletions tests/layers/vllm/test_mxfp4.py
@@ -0,0 +1,192 @@
import tempfile

import jax
import jax.numpy as jnp
import pytest
import torch
import torchax
import utils as test_utils
from jax.sharding import NamedSharding, PartitionSpec
from torchax.interop import torch_view
from torchax.ops.mappings import j2t, t2j
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
VllmMxfp4MoEMethod)

P = PartitionSpec
MODELS = ["openai/gpt-oss-20b"]
MXFP4_BLOCK_SIZE = 32


def quantize_to_mxfp4(weight: torch.Tensor):
    # Use JAX because its native support for e2m1 makes it easier to work with.
weight = t2j(weight)
e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn)
dtype_min = float(e2m1_finfo.min)
dtype_max = float(e2m1_finfo.max)

# Do a subchannel quantization where block size is 32.
weight_shape = weight.shape
weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE))
abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True)
scale = abs_max / dtype_max

weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max)
weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] +
(-1, 2))
weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8)

# We convert scale into e8m0 manually because there is no hardware support.
e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
_, scale_exp = jnp.frexp(scale.squeeze(axis=-1))
    # Subtract one since e8m0 has no mantissa bits.
scale_exp -= 1
scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8)

return j2t(weight_packed), j2t(scale_exp)
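

# For reference, a minimal sketch that approximately inverts quantize_to_mxfp4
# above (MXFP4 is lossy, so the round trip is not exact). This is an
# illustration, not part of this PR: the helper name dequantize_from_mxfp4 is
# hypothetical, and it assumes jax.lax.bitcast_convert_type also supports the
# uint8 -> float4_e2m1fn direction (the mirror of the float4 -> uint8 bitcast
# used above). The PR's actual dequantization presumably happens inside
# VllmMxfp4MoEMethod.process_weights_after_loading, which the MoE test below
# relies on.
def dequantize_from_mxfp4(weight_packed: torch.Tensor, scale_exp: torch.Tensor):
    # Split each uint8 byte back into two float4_e2m1fn values.
    weight_q = jax.lax.bitcast_convert_type(t2j(weight_packed),
                                            jnp.float4_e2m1fn)
    # Regroup into blocks of 32 so each block lines up with one e8m0 exponent.
    weight_block = weight_q.reshape(weight_q.shape[:-2] +
                                    (-1, MXFP4_BLOCK_SIZE))
    # Undo the bias applied in quantize_to_mxfp4: the stored byte equals
    # (exponent - 1) - minexp, so the per-block scale is 2**(byte + minexp).
    e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu)
    scale = jnp.exp2(t2j(scale_exp).astype(jnp.float32) + e8m0_finfo.minexp)
    weight = weight_block.astype(jnp.float32) * scale[..., None]
    # Collapse the block dimension back into the original last axis.
    return j2t(weight.reshape(weight_q.shape[:-2] + (-1,)))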


@pytest.fixture(autouse=True)
def setup_environment():
# This is a fake config used for init dist env.
# RowParallelLinear needs dist env to be initialized.
engine_args = EngineArgs(
model=MODELS[0],
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
load_format='dummy',
)

vllm_config = engine_args.create_engine_config()

with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
1,
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("mesh", [
test_utils.get_spmd_mesh(1),
test_utils.get_spmd_mesh(jax.local_device_count())
])
def test_quant_override(model, mesh):

engine_args = EngineArgs(
model=model,
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
load_format='dummy',
)
vllm_config = engine_args.create_engine_config()
vllm_config.model_config.dtype = torch.bfloat16

quant_config = get_tpu_quantization_config(vllm_config, mesh)
assert isinstance(quant_config, VllmMxfp4Config)
assert quant_config.vllm_config == vllm_config
assert quant_config.mesh == mesh


@pytest.mark.parametrize("mesh", [
test_utils.get_spmd_mesh(1),
test_utils.get_spmd_mesh(jax.local_device_count())
])
@pytest.mark.parametrize("num_tokens", [8])
@pytest.mark.parametrize("intermediate_size", [1024])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("topk", [2])
def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size,
num_experts, topk):
torch.manual_seed(42)
dtype = torch.bfloat16

a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
w1 = torch.randn(
(num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
w2 = torch.randn(
(num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
w1_weight, w1_weight_scale = quantize_to_mxfp4(w1)
w2_weight, w2_weight_scale = quantize_to_mxfp4(w2)

w1_bias = torch.randn(
(num_experts, 2 * intermediate_size), dtype=dtype) / 10
w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
score = torch.randn((num_tokens, num_experts), dtype=dtype)

engine_args = EngineArgs(
model=MODELS[0],
max_model_len=64,
max_num_batched_tokens=64,
max_num_seqs=4,
load_format='dummy',
)
vllm_config = engine_args.create_engine_config()
vllm_config.model_config.dtype = dtype

quant_config = get_tpu_quantization_config(vllm_config, mesh)
with set_current_vllm_config(vllm_config):
vllm_fused_moe = FusedMoE(
num_experts=num_experts,
top_k=topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=False,
tp_size=1,
dp_size=1,
quant_config=quant_config,
has_bias=True,
)
vllm_fused_moe.w13_weight.data = w1_weight
vllm_fused_moe.w2_weight.data = w2_weight
vllm_fused_moe.w13_weight_scale.data = w1_weight_scale
vllm_fused_moe.w2_weight_scale.data = w2_weight_scale
vllm_fused_moe.w13_bias.data = w1_bias
vllm_fused_moe.w2_bias.data = w2_bias

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)

jax_a = a.to('jax')
jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))
score = torch_view(t2j(score))
score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None)))

vllm_fused_moe.quant_method.process_weights_after_loading(
vllm_fused_moe)

        # Because we dequantize the mxfp4 weights for now, verify that the
        # dequantized weights match the original weights. NaNs make a direct
        # element-wise comparison difficult, so we compare nanmean values
        # instead.
torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w13_weight),
torch.nanmean(w1),
check_device=False,
equal_nan=True,
rtol=0.2,
atol=0.1)
torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w2_weight),
torch.nanmean(w2),
check_device=False,
equal_nan=True,
rtol=0.2,
atol=0.1)

vllm_fused_moe(jax_a, score)
5 changes: 4 additions & 1 deletion tpu_inference/layers/vllm/quantization/__init__.py
@@ -9,6 +9,7 @@
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
VllmCompressedTensorsConfig # noqa: E501
from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
from tpu_inference.layers.vllm.quantization.unquantized import \
VllmUnquantizedConfig

@@ -21,6 +22,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
None: VllmUnquantizedConfig,
"compressed-tensors": VllmCompressedTensorsConfig,
"awq": VllmAWQConfig,
"mxfp4": VllmMxfp4Config,
}
if model_config.quantization not in method_to_config:
raise NotImplementedError(
@@ -30,6 +32,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
assert issubclass(quant_config, JaxCommonConfig)
quant_config.set_configs(vllm_config, mesh)

model_config.quantization = quant_config.get_name()
    # TODO(kyuyeunk): Create a more programmatic way to handle this.
model_config.quantization = "tpu-" + quant_config.get_name()
return VllmConfig.get_quantization_config(model_config,
vllm_config.load_config)
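
The "tpu-" rename above works because each TPU config class registers itself with vLLM's quantization registry under the prefixed name, so VllmConfig.get_quantization_config resolves the TPU subclass rather than the upstream implementation. As a rough sketch of how VllmMxfp4Config is presumably declared in mxfp4.py, mirroring the tpu-awq and tpu-compressed-tensors registrations below (the base class Mxfp4Config and the import paths are assumptions, since mxfp4.py itself is not shown in this diff):

from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config

from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig


# Registering under "tpu-mxfp4" is what lets get_tpu_quantization_config rename
# model_config.quantization to "tpu-mxfp4" and have vLLM return this subclass.
@register_quantization_config("tpu-mxfp4")
class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
    ...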
6 changes: 1 addition & 5 deletions tpu_inference/layers/vllm/quantization/awq.py
@@ -29,13 +29,9 @@
logger = init_logger(__name__)


@register_quantization_config("jax-awq")
@register_quantization_config("tpu-awq")
class VllmAWQConfig(AWQConfig, JaxCommonConfig):

@classmethod
def get_name(cls) -> str:
return "jax-awq"

def get_supported_act_dtypes(self) -> list[torch.dtype]:
# NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
        # bfloat16 is significantly preferred over float16. This might lead to
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -30,13 +30,9 @@
logger = init_logger(__name__)


@register_quantization_config("jax-compressed-tensors")
@register_quantization_config("tpu-compressed-tensors")
class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):

@classmethod
def get_name(cls) -> str:
return "jax-compressed-tensors"

def get_scheme(self,
layer: torch.nn.Module,
layer_name: Optional[str] = None