-
Notifications
You must be signed in to change notification settings - Fork 38
[Torchax] Add initial support for loading mxfp4 #1080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,192 @@ | ||
| import tempfile | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import pytest | ||
| import torch | ||
| import torchax | ||
| import utils as test_utils | ||
| from jax.sharding import NamedSharding, PartitionSpec | ||
| from torchax.interop import torch_view | ||
| from torchax.ops.mappings import j2t, t2j | ||
| from vllm.config import set_current_vllm_config | ||
| from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, | ||
| init_distributed_environment) | ||
| from vllm.engine.arg_utils import EngineArgs | ||
| from vllm.forward_context import set_forward_context | ||
| from vllm.model_executor.layers.fused_moe.layer import FusedMoE | ||
|
|
||
| from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config | ||
| from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config, | ||
| VllmMxfp4MoEMethod) | ||
|
|
||
| P = PartitionSpec | ||
| MODELS = ["openai/gpt-oss-20b"] | ||
| MXFP4_BLOCK_SIZE = 32 | ||
|
|
||
|
|
||
| def quantize_to_mxfp4(weight: torch.tensor): | ||
| # Utilize JAX because native support for e2m1 makes it easier to work with. | ||
| weight = t2j(weight) | ||
| e2m1_finfo = jnp.finfo(jnp.float4_e2m1fn) | ||
| dtype_min = float(e2m1_finfo.min) | ||
| dtype_max = float(e2m1_finfo.max) | ||
|
|
||
| # Do a subchannel quantization where block size is 32. | ||
| weight_shape = weight.shape | ||
| weight_block = weight.reshape(weight_shape[:-1] + (-1, MXFP4_BLOCK_SIZE)) | ||
| abs_max = jnp.max(jnp.abs(weight_block), axis=-1, keepdims=True) | ||
| scale = abs_max / dtype_max | ||
|
|
||
| weight_q = jnp.clip(weight_block / scale, dtype_min, dtype_max) | ||
| weight_q = weight_q.astype(jnp.float4_e2m1fn).reshape(weight_shape[:-1] + | ||
| (-1, 2)) | ||
| weight_packed = jax.lax.bitcast_convert_type(weight_q, jnp.uint8) | ||
|
|
||
| # We convert scale into e8m0 manually because there is no hardware support. | ||
| e8m0_finfo = jnp.finfo(jnp.float8_e8m0fnu) | ||
| _, scale_exp = jnp.frexp(scale.squeeze(axis=-1)) | ||
| # Subtract by one sinced e8m0 has no decimal | ||
| scale_exp -= 1 | ||
| scale_exp = (scale_exp - e8m0_finfo.minexp).astype(jnp.uint8) | ||
|
|
||
| return j2t(weight_packed), j2t(scale_exp) | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def setup_environment(): | ||
| # This is a fake config used for init dist env. | ||
| # RowParallelLinear needs dist env to be initialized. | ||
| engine_args = EngineArgs( | ||
| model=MODELS[0], | ||
| max_model_len=64, | ||
| max_num_batched_tokens=64, | ||
| max_num_seqs=4, | ||
| load_format='dummy', | ||
| ) | ||
|
|
||
| vllm_config = engine_args.create_engine_config() | ||
|
|
||
| with set_current_vllm_config(vllm_config): | ||
| temp_file = tempfile.mkstemp()[1] | ||
| init_distributed_environment( | ||
| 1, | ||
| 0, | ||
| local_rank=0, | ||
| distributed_init_method=f"file://{temp_file}", | ||
| backend="gloo") | ||
| ensure_model_parallel_initialized(1, 1) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("model", MODELS) | ||
| @pytest.mark.parametrize("mesh", [ | ||
| test_utils.get_spmd_mesh(1), | ||
| test_utils.get_spmd_mesh(jax.local_device_count()) | ||
| ]) | ||
| def test_quant_override(model, mesh): | ||
|
|
||
| engine_args = EngineArgs( | ||
| model=model, | ||
| max_model_len=64, | ||
| max_num_batched_tokens=64, | ||
| max_num_seqs=4, | ||
| load_format='dummy', | ||
| ) | ||
| vllm_config = engine_args.create_engine_config() | ||
| vllm_config.model_config.dtype = torch.bfloat16 | ||
|
|
||
| quant_config = get_tpu_quantization_config(vllm_config, mesh) | ||
| assert isinstance(quant_config, VllmMxfp4Config) | ||
| assert quant_config.vllm_config == vllm_config | ||
| assert quant_config.mesh == mesh | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("mesh", [ | ||
| test_utils.get_spmd_mesh(1), | ||
| test_utils.get_spmd_mesh(jax.local_device_count()) | ||
| ]) | ||
| @pytest.mark.parametrize("num_tokens", [8]) | ||
| @pytest.mark.parametrize("intermediate_size", [1024]) | ||
| @pytest.mark.parametrize("hidden_size", [128]) | ||
| @pytest.mark.parametrize("num_experts", [8]) | ||
| @pytest.mark.parametrize("topk", [2]) | ||
| def test_fused_moe_bias(mesh, num_tokens, intermediate_size, hidden_size, | ||
| num_experts, topk): | ||
| torch.manual_seed(42) | ||
| dtype = torch.bfloat16 | ||
|
|
||
| a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10 | ||
| w1 = torch.randn( | ||
| (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10 | ||
| w2 = torch.randn( | ||
| (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10 | ||
| w1_weight, w1_weight_scale = quantize_to_mxfp4(w1) | ||
| w2_weight, w2_weight_scale = quantize_to_mxfp4(w2) | ||
|
|
||
| print(f'kky {w1_weight.shape=} {w1_weight_scale.shape=}') | ||
|
|
||
| w1_bias = torch.randn( | ||
| (num_experts, 2 * intermediate_size), dtype=dtype) / 10 | ||
| w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10 | ||
| score = torch.randn((num_tokens, num_experts), dtype=dtype) | ||
|
|
||
| engine_args = EngineArgs( | ||
| model=MODELS[0], | ||
| max_model_len=64, | ||
| max_num_batched_tokens=64, | ||
| max_num_seqs=4, | ||
| load_format='dummy', | ||
| ) | ||
| vllm_config = engine_args.create_engine_config() | ||
| vllm_config.model_config.dtype = dtype | ||
|
|
||
| quant_config = get_tpu_quantization_config(vllm_config, mesh) | ||
| with set_current_vllm_config(vllm_config): | ||
| vllm_fused_moe = FusedMoE( | ||
| num_experts=num_experts, | ||
| top_k=topk, | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| reduce_results=False, | ||
| renormalize=False, | ||
| tp_size=1, | ||
| dp_size=1, | ||
| quant_config=quant_config, | ||
| has_bias=True, | ||
| ) | ||
| vllm_fused_moe.w13_weight.data = w1_weight | ||
| vllm_fused_moe.w2_weight.data = w2_weight | ||
| vllm_fused_moe.w13_weight_scale.data = w1_weight_scale | ||
| vllm_fused_moe.w2_weight_scale.data = w2_weight_scale | ||
| vllm_fused_moe.w13_bias.data = w1_bias | ||
| vllm_fused_moe.w2_bias.data = w2_bias | ||
|
|
||
| with torchax.default_env(), set_forward_context(None, vllm_config): | ||
| assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod) | ||
|
|
||
| jax_a = a.to('jax') | ||
| jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) | ||
| score = torch_view(t2j(score)) | ||
| score.apply_jax_(jax.device_put, NamedSharding(mesh, P(None, None))) | ||
|
|
||
| vllm_fused_moe.quant_method.process_weights_after_loading( | ||
| vllm_fused_moe) | ||
|
|
||
| # Because we are dequantizing mxfp4 weights for now, we verify if | ||
| # dequantized weights matches with the original weights. | ||
| # Due to NaN, comparing two values are difficult. Therefore, we utilize | ||
| # nanmean instead. | ||
| torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w13_weight), | ||
| torch.nanmean(w1), | ||
| check_device=False, | ||
| equal_nan=True, | ||
| rtol=0.2, | ||
| atol=0.1) | ||
| torch.testing.assert_close(torch.nanmean(vllm_fused_moe.w2_weight), | ||
| torch.nanmean(w2), | ||
| check_device=False, | ||
| equal_nan=True, | ||
| rtol=0.2, | ||
| atol=0.1) | ||
|
|
||
| vllm_fused_moe(jax_a, score) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.