Merged
24 commits
bfab8cc
Save an initial version. Not capable of running at all.
vanbasten23 Oct 27, 2025
4921386
now the test can run to completion.
vanbasten23 Oct 29, 2025
8044fb7
ok. The case without lora passed.
vanbasten23 Oct 29, 2025
11c7ea2
ok, the test passed. Need to make it simpler next.
vanbasten23 Oct 29, 2025
2eae008
also check if the result is correct and the sharding is correct.
vanbasten23 Oct 29, 2025
2d87418
cleaned up
vanbasten23 Oct 29, 2025
c480658
ok, fixed the torchax.view.item() issue.
vanbasten23 Oct 30, 2025
cc47bab
add multi-chip test case
vanbasten23 Oct 30, 2025
c2082ff
fix the format
vanbasten23 Oct 30, 2025
d409a35
clean up
vanbasten23 Oct 31, 2025
7ea2939
Add lora unit tests to the CI
vanbasten23 Nov 5, 2025
ceaf1b2
refactored. The test still passed.
vanbasten23 Nov 5, 2025
2af4dae
added test for MergedQKVParallelLinearWithLoRA
vanbasten23 Nov 6, 2025
540e9db
Added the test for QKVParallelLinearWithLoRA
vanbasten23 Nov 6, 2025
26794f2
The ColumnParallelLinear test passed
vanbasten23 Nov 6, 2025
2be10fd
Finally fix the test.
vanbasten23 Nov 6, 2025
d0fca2b
merge conflicts
vanbasten23 Nov 10, 2025
83dd99b
fixed the test
vanbasten23 Nov 10, 2025
01ddba7
Start using bf16
vanbasten23 Nov 10, 2025
7560758
merged with main
vanbasten23 Nov 10, 2025
edb82e4
fix merge conflict
vanbasten23 Nov 11, 2025
735e1e4
also added the test for ReplicatedLinearWithLoRA
vanbasten23 Nov 11, 2025
1dbebe4
Added some comments
vanbasten23 Nov 12, 2025
e30dfc7
Merge branch 'main' into xiowei/add_qkv_parallel_linear_unit_tests
vanbasten23 Nov 12, 2025
294 changes: 259 additions & 35 deletions tests/lora/test_layers.py
@@ -11,12 +11,20 @@
from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA)
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LoRAMapping, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA)
# yapf: enable
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform

@@ -199,7 +207,7 @@ def create_random_inputs(

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 4, 9])
@pytest.mark.parametrize("repeats", [2])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("stage", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
set_random_seed(6)
@@ -210,7 +218,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
max_loras=max_loras,
max_lora_rank=max_lora_rank,
fully_sharded_loras=False,
lora_dtype=torch.float16,
lora_dtype=torch.bfloat16,
)
vllm_config = dist_init
vllm_config.lora_config = lora_config
@@ -220,6 +228,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
repeats, vllm_config, mesh)
_verify_lora_linear_layer(linear, lora_linear)

# After we create the lora_config, the linear layer and the lora layer,
# the remaining steps are:
# - create a punica wrapper.
# - associate the punica wrapper with the lora layer.
# - populate the lora matrices in the lora layer: use non-zero values to
#   test lora, and zero values to test the case where the layer has no lora.
# - create inputs and lora_mapping.
# - update the metadata of the punica wrapper.
# - convert the inputs to torchax tensors.
# - run a forward pass on the lora layer to get the actual output.
# - run a reference implementation to get the expected output.

# Create a punica wrapper and associate it with the lora linear layer.
max_num_batched_tokens = 8192
max_batches = 256
@@ -250,7 +269,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.float16,
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -297,7 +316,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.float16,
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
@@ -318,6 +337,173 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 4, 9])
@pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
@pytest.mark.parametrize("stage", [True, False])
def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
set_random_seed(6)

max_loras = 9
max_lora_rank = 8
lora_config = LoRAConfig(
max_loras=max_loras,
max_lora_rank=max_lora_rank,
fully_sharded_loras=False,
lora_dtype=torch.bfloat16,
)
vllm_config = dist_init
vllm_config.lora_config = lora_config

mesh = _create_mesh()
linear, lora_linear = _create_random_linear_parallel_layer(
layer_type, vllm_config, mesh)
_verify_lora_linear_layer(linear, lora_linear)

max_num_batched_tokens = 8192
max_batches = 256
with torchax.default_env():
punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches,
'jax',
max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_linear.set_mapping(punica_wrapper)

# Populate lora matrices (lora_a and lora_b) in the lora layer.
index_to_id = get_random_index_to_id(num_loras, max_loras)
# lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
lora_dict, sublora_dict = populate_loras(
index_to_id,
lora_layer=lora_linear,
baselayer_weights=linear.weight,
)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.bfloat16,
device='cpu')

_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
prompt_mapping, stage, index_to_id,
lora_config)

with torchax.default_env():
torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
actual_result = lora_linear(torchax_inputs)[0]

expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0]
lora = lora_dict[lora_id]
lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
result += lora_result
expected_results.append(result)
expected_result = torch.cat(expected_results)

rtol, atol = TOLERANCES[actual_result.dtype]
with torchax.default_env():
actual_result_cpu = actual_result.to('cpu')
torch.testing.assert_close(actual_result_cpu,
expected_result,
rtol=rtol,
atol=atol)

# Check that resetting the lora weights succeeds.
# Here we set all lora weights to be empty.
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)

inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0], # different from the above create_random_inputs
num_inputs=32,
input_size=(1, 64),
input_range=(0, 1),
input_type=torch.bfloat16,
device='cpu')
_update_punica_wrapper_metadata(punica_wrapper, index_mapping,
prompt_mapping, stage, index_to_id,
lora_config)

with torchax.default_env():
torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
actual_result = lora_linear(torchax_inputs)[0]
expected_result = linear(torch.cat(inputs))[0]

rtol, atol = TOLERANCES[actual_result.dtype]
with torchax.default_env():
actual_result_cpu = actual_result.to('cpu')
torch.testing.assert_close(actual_result_cpu,
expected_result,
rtol=rtol,
atol=atol)


def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
# We first create a base linear layer, then a lora layer to wrap it.
if layer_type == "row":

def _create_row_linear():
return RowParallelLinear(
64, # input_size
64, # output_size
bias=False,
params_dtype=torch.bfloat16)

linear = _create_row_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_row_linear()
lora_linear = _create_lora_wrapper(linear,
base_linear,
RowParallelLinearWithLoRA,
vllm_config=vllm_config,
mesh=mesh)
elif layer_type == "column":

def _create_column_linear():
return ColumnParallelLinear(64,
64,
bias=False,
params_dtype=torch.bfloat16)

linear = _create_column_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_column_linear()
lora_linear = _create_lora_wrapper(linear,
base_linear,
ColumnParallelLinearWithLoRA,
vllm_config=vllm_config,
mesh=mesh)

elif layer_type == "replicated":

def _create_replicated_linear():
return ReplicatedLinear(64,
64,
bias=False,
params_dtype=torch.bfloat16)

linear = _create_replicated_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_replicated_linear()
lora_linear = _create_lora_wrapper(linear,
base_linear,
ReplicatedLinearWithLoRA,
vllm_config=vllm_config,
mesh=mesh)

else:
raise NotImplementedError("Unknown layer type: {}".format(layer_type))

return linear, lora_linear


def _create_mesh():
axis_names = ("data", "model")
devices = jax.devices()
@@ -374,37 +560,75 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
# We first create a base linear layer, then a lora layer to wrap it.
if repeats == 2:
# In e2e, MergedColumnParallelLinear is created when we load the model.
# The base_layer weights are sharded and moved to TPU in
# VllmUnquantizedLinearMethod.process_weights_after_loading.
linear = MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.float16)
def _create_merged_column_linear():
return MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.bfloat16)

linear = _create_merged_column_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = MergedColumnParallelLinear(
64, # input_size
[64] * repeats, # output_size
bias=False,
params_dtype=torch.float16)
base_linear.weight.data = linear.weight.data
jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
linear_method = VllmUnquantizedLinearMethod(jax_config)
base_linear.quant_method = linear_method
linear_method.process_weights_after_loading(
base_linear
) # here base_linear.weight is moved to TPU and sharded.
assert jax_view(base_linear.weight).platform(
) == 'tpu', 'base_linear.weight should have been moved to TPU.'
assert not isinstance(
jax_view(
base_linear.weight).sharding, jax.sharding.SingleDeviceSharding
), 'base_linear.weight should have been sharded.'

lora_linear = MergedColumnParallelLinearWithLoRA(base_linear)
base_linear = _create_merged_column_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
MergedColumnParallelLinearWithLoRA,
vllm_config, mesh, repeats)
elif repeats == 3:
raise NotImplementedError("NYI: for MergedQKVParallelLinear case")

def _create_qkv_linear():
return QKVParallelLinear(64,
64,
32,
bias=False,
params_dtype=torch.bfloat16)

linear = _create_qkv_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_qkv_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
MergedQKVParallelLinearWithLoRA,
vllm_config, mesh, repeats)
else:
raise NotImplementedError("NYI: for QKVParallelLinear case")

def _create_qkv_linear():
return QKVParallelLinear(64,
64,
32,
bias=False,
params_dtype=torch.bfloat16)

linear = _create_qkv_linear()
linear.weight.data = torch.rand_like(linear.weight.data)

base_linear = _create_qkv_linear()
lora_linear = _create_lora_wrapper(linear, base_linear,
QKVParallelLinearWithLoRA,
vllm_config, mesh, repeats)

return linear, lora_linear


def _create_lora_wrapper(linear,
base_linear,
lora_cls,
vllm_config,
mesh,
repeats=1):
base_linear.weight.data = linear.weight.data
jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
linear_method = VllmUnquantizedLinearMethod(jax_config)
base_linear.quant_method = linear_method
linear_method.process_weights_after_loading(
base_linear) # here base_linear.weight is moved to TPU and sharded.
assert jax_view(base_linear.weight).platform(
) == 'tpu', 'base_linear.weight should have been moved to TPU.'
assert not isinstance(
jax_view(base_linear.weight).sharding, jax.sharding.
SingleDeviceSharding), 'base_linear.weight should have been sharded.'

lora_linear = lora_cls(base_linear)

lora_config = vllm_config.lora_config
max_loras = lora_config.max_loras
@@ -427,4 +651,4 @@ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
lora_linear.lora_b_stacked) == n_slices)

return linear, lora_linear
return lora_linear
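
For reference, the expected output that these tests compare against is the base linear output plus a low-rank LoRA update, as computed by the reference loop `lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling`. Below is a minimal standalone sketch in plain PyTorch, not part of the diff: it assumes 64-dimensional features, rank 8, and a placeholder scaling of 1.0, whereas the test takes these values from the LoRA config and from lora.scaling.

import torch

torch.manual_seed(0)

in_features, out_features, rank = 64, 64, 8
scaling = 1.0  # placeholder; the test uses lora.scaling

# Base layer weight and LoRA factors, shaped as in the test:
# lora_a is (rank, in_features), lora_b is (out_features, rank).
weight = torch.rand(out_features, in_features, dtype=torch.bfloat16)
lora_a = torch.rand(rank, in_features, dtype=torch.bfloat16)
lora_b = torch.rand(out_features, rank, dtype=torch.bfloat16)

x = torch.rand(1, in_features, dtype=torch.bfloat16)

base_out = x @ weight.T  # bias-free linear output, like linear(input_)[0]
lora_out = x @ lora_a.T @ lora_b.T * scaling  # low-rank LoRA update
expected = base_out + lora_out

# With the LoRA factors zeroed (the reset_lora case), the update vanishes
# and the expected output reduces to the base layer output alone.
zero_update = x @ torch.zeros_like(lora_a).T @ torch.zeros_like(lora_b).T * scaling
assert torch.equal(base_out + zero_update, base_out)

With the LoRA matrices zeroed, as after reset_lora, the update term drops out and the expected output is just the base layer output, which is what the second half of each test checks against.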