diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 000000000..d573070de --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,32 @@ +import tempfile + +import pytest +from vllm.config import set_current_vllm_config +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.engine.arg_utils import EngineArgs + + +@pytest.fixture +def dist_init(): + engine_args = EngineArgs( + model="Qwen/Qwen2-1.5B-Instruct", + max_model_len=64, + max_num_batched_tokens=64, + max_num_seqs=4, + ) + + vllm_config = engine_args.create_engine_config() + + with set_current_vllm_config(vllm_config): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + 1, + 0, + local_rank=0, + distributed_init_method=f"file://{temp_file}", + backend="gloo") + ensure_model_parallel_initialized(1, 1) + yield vllm_config + cleanup_dist_env_and_memory(shutdown_ray=True) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 000000000..98e984899 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,400 @@ +import random +from typing import Optional + +import jax +import pytest +import torch +import torchax +from jax.sharding import NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import j2t, t2j +from vllm.config import LoRAConfig +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import LoRAMapping, MergedColumnParallelLinearWithLoRA +# yapf: enable +from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.model_executor.layers.linear import MergedColumnParallelLinear +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform + +from tpu_commons.lora.layers import (TorchaxBaseLayerWithLoRA, + TorchaxMergedColumnParallelLinearWithLoRA) +from tpu_commons.models.vllm.sharding import shard_parallel_layers_to_tpu + +from .utils import DummyLoRAManager + +# TODO(xiowei): +# - add test for multi-chip. +# - add equivalent test for ColumnParallelLinearWithShardedLoRA. + +P = PartitionSpec + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + +pytestmark = pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test is only for TPU platform.") + +# prefill stage(True) or decode stage(False) +STAGES = [True, False] + + +def check_punica_wrapper(punica_wrapper) -> bool: + from tpu_commons.lora.torch_punica_tpu import PunicaWrapperTPU + return type(punica_wrapper) is PunicaWrapperTPU + + +def get_random_index_to_id(num_loras: int, + num_slots: int, + log: bool = True) -> list[Optional[int]]: + """Creates a random index_to_lora_id mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. 
" + "num_loras must be less than or equal to num_slots.") + + slots: list[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + index_to_id: list[Optional[int]], + layer: TorchaxBaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, + bias_enabled: bool = False, +) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]: + """This method populates the lora layers (TorchaxBaseLayerWithLoRA) with lora weights. + + Args: + index_to_id: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: dict[int, list[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(index_to_id): + if lora_id is not None: + subloras: list[LoRALayerWeights] = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + if bias_enabled: + sublora.bias = sublora.bias[(sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env. + with torchax.default_env(), jax.default_device( + jax.devices("tpu")[0]): + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + lora_bias=lora.bias if bias_enabled else None, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: list[int], + num_inputs: int, + input_size: tuple[int, ...], + input_range: tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "cpu", +) -> tuple[list[torch.Tensor], list[int], list[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. Or the number of requests. + input_size: the size of each individual input. Or the number of tokens. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + + returns: + inputs: a list of torch tensors of size num_inputs. Each input has shape `input_size`. + index_mapping: maps each input token to a lora ID. + prompt_mapping: maps each request to a lora ID. 
+ """ + + low, high = input_range + + inputs: list[torch.Tensor] = [] + index_mapping: list[int] = [] + prompt_mapping: list[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 9]) +@pytest.mark.parametrize("repeats", [2]) +@pytest.mark.parametrize("fully_shard", [False]) # TODO(xiowei): add "True". +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("stage", [True, False]) +@pytest.mark.parametrize("bias_enabled", [True, False]) +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device, stage, bias_enabled) -> None: + max_loras = 9 + max_lora_rank = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=max_lora_rank, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + bias_enabled=bias_enabled) + + axis_names = ("data", "model") + mesh_shape = ( + 1, 1 + ) # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices())) + mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices()) + + def create_column_parallel_packed_layer(): + # Step 1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper. + if repeats == 2: + linear = MergedColumnParallelLinear( + 256, + [256] * repeats, # input_size, output_size + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedColumnParallelLinearWithLoRA( + linear + ) # TODO(xiowei): add test for MergedColumnParallelLinearWithShardedLoRA (fully_shard == True) + elif repeats == 3: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for MergedQKVParallelLinear case") + else: + # TODO(xiowei): add test for this case. + raise NotImplementedError("NYI: for QKVParallelLinear case") + + n_slices = repeats + # create_lora_weights creates global shape weight. + lora_linear.create_lora_weights(max_loras, lora_config) + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + # Then we replace the base layer (e.g. MergedColumnParallelLinear) with the torchax one (e.g. JaxMergedColumnParallelLinear). + vllm_config = dist_init + shard_parallel_layers_to_tpu(lora_linear, mesh, vllm_config) + + # replace the LoRA wrapper with our own wrapper (e.g. TorchaxMergedColumnParallelLinearWithLoRA) + torchax_lora_linear = TorchaxMergedColumnParallelLinearWithLoRA( + lora_linear, mesh) + + return linear, torchax_lora_linear + + set_random_seed(6) + + linear, torchax_lora_linear = create_column_parallel_packed_layer() + # linear.weight has type torch.nn.Parameter, lora_linear.weight has type torchax.tensor.Tensor + # BaseLinearLayerWithLoRA.weight property guarantees this. 
+ with torchax.default_env(): + assert torch.equal(linear.weight.data, j2t(torchax_lora_linear.weight)) + + max_num_batched_tokens = 8192 + max_batches = 256 + punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches, + device, + max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) + torchax_lora_linear.set_mapping(punica_wrapper) + + # load the lora weight, shard it, and send it to TPU. + # create a lora slot index to lora id mapping. + index_to_id = get_random_index_to_id(num_loras, max_loras) + # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights + lora_dict, sublora_dict = populate_loras( + index_to_id, + layer=torchax_lora_linear, + layer_weights=linear.weight, + repeats=repeats, + bias_enabled=bias_enabled, + ) + + # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256]. + # index_mapping: list[int] + # prompt_mapping: list[int] + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'` + # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'` + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = torchax_lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + + expected_results: list[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + # linear(input_) returns (output, output_bias) so we only need the first one. + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + sublora.scaling) + if bias_enabled: + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += sublora.bias + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) + + # Check that resetting the lora weights succeeds + # Here we set all lora weight to be empty. 
+ for slot_idx in range(max_loras): + torchax_lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], # different from the above create_random_inputs + num_inputs=32, + input_size=(1, 256), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + punica_wrapper.update_metadata( + lora_mapping, + index_to_id, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + punica_wrapper.move_to_device(mesh) + + jax_inputs = [] + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + for input in inputs: + jax_input = torch_view(t2j(input)) + jax_input.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None, None))) + jax_inputs.append(jax_input) + with torchax.default_env(): + lora_result = torchax_lora_linear(torch.cat(jax_inputs))[0] + lora_result = j2t(lora_result) + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + print( + f'Output max diff: {torch.max(torch.abs(expected_result - lora_result))}' + ) + print( + f'Output mean diff: {torch.mean(torch.abs(expected_result - lora_result))}' + ) diff --git a/tests/lora/test_torch_lora_ops.py b/tests/lora/test_torch_lora_ops.py new file mode 100644 index 000000000..350079d9d --- /dev/null +++ b/tests/lora/test_torch_lora_ops.py @@ -0,0 +1,32 @@ +import jax +import torch +import torchax + +from tpu_commons.lora.torch_lora_ops import bgmv_torch + + +def test_bgmv_torch(): + num_tokens = 16 + hidden_size = 128 + max_loras = 9 + max_lora_rank = 8 + + with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]): + inputs = torch.rand(num_tokens, hidden_size, device='jax') + loras = torch.rand(max_loras, + 1, + max_lora_rank, + hidden_size, + device='jax') + idxs = torch.randint(0, max_loras, (num_tokens, ), device='jax') + + actual = bgmv_torch(inputs, loras, idxs) + expected = _ref_bgmv_torch(inputs, loras, idxs) + torch.testing.assert_close(actual, expected, atol=3e-2, rtol=1e-3) + + +def _ref_bgmv_torch(inputs, loras, idxs): + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + selected_loras = loras[idxs] + return torch.einsum('td,tld->tl', inputs, selected_loras) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 000000000..74e0feb5b --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights + + +# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py +class DummyLoRAManager: + + def __init__(self, device: torch.device = "cuda:0"): + super().__init__() + self._loras: dict[str, LoRALayerWeights] = {} + self._device = device + + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> LoRALayerWeights: + return self._loras[module_name] + + def init_random_lora( + self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + 
lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device=self._device), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device=self._device), + bias=torch.rand([weight.shape[0]], + dtype=weight.dtype, + device=self._device), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand( + 5, + generate_embeddings_tensor, + dtype=weight.dtype, + device=self._device, + ) + self.set_module_lora(module_name, lora) + + return lora + + def init_lora( + self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None, + ): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: list[int], + noop_lora_index: Optional[list[int]] = None, + rank: int = 8, + ): + base_loras: list[LoRALayerWeights] = [] + noop_lora_index_set = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index_set, + ) + base_loras.append(base_lora) + packed_lora = PackedLoRALayerWeights.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/tpu_commons/lora/__init__.py b/tpu_commons/lora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_commons/lora/layers.py b/tpu_commons/lora/layers.py new file mode 100644 index 000000000..b10be77f6 --- /dev/null +++ b/tpu_commons/lora/layers.py @@ -0,0 +1,233 @@ +from typing import TYPE_CHECKING, Optional, Union, cast + +import jax +import torch +import torch.nn as nn +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from torchax.interop import torch_view +from torchax.ops.mappings import t2j +from transformers import PretrainedConfig +from vllm.config import LoRAConfig +# yapf: enable +from vllm.lora.layers import BaseLayerWithLoRA +# yapf: disable +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + +P = PartitionSpec + +""" +Step1: create a base layer (e.g. MergedColumnParallelLinear) and a vLLM LoRA wrapper (via load_lora_model()) +Here, we add a LoRA wrapper so the model becomes: +LinearWithLoRA { + base_layer: Linear +} +The lora weight and linear weight should have been initialized. +Step2: shard_parallel_layers_to_tpu() +Here we replace the linear layer with Torchax layer +LinearWithLora { + base_layer: JaxLinear +} +then we load the linear weight, shard it, but keep it on CPU. +Step3: move the linear weight to TPU. +Step4: write a function to replace the LoRA wrapper with our own wrapper JaxLinearWithLoRA: +JaxLinearWithLoRA { + base_layer: JaxLinear +} +Step5: load the lora weight, shard it, and send it to TPU. +""" + +class TorchaxBaseLayerWithLoRA(nn.Module): + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... 
+
+    def set_mapping(
+        self,
+        punica_wrapper,
+    ):
+        self.punica_wrapper: PunicaWrapperBase = punica_wrapper
+
+    @classmethod
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        """Returns True if the layer can be replaced by this LoRA layer."""
+        raise NotImplementedError
+
+    # create_lora_weights is not needed because we reuse the initialized
+    # weights from the base LoRA layer.
+
+class TorchaxBaseLinearLayerWithLoRA(TorchaxBaseLayerWithLoRA):
+
+    def __init__(self, base_lora_layer: BaseLayerWithLoRA, mesh: Mesh):
+        super().__init__()
+        self.base_lora_layer = base_lora_layer
+        self.base_layer = base_lora_layer.base_layer
+        self.mesh = mesh
+
+        # self.lora_a_stacked, self.lora_b_stacked, self.lora_bias_stacked and
+        # self.lora_config are initialized in the original LoRA wrapper's
+        # create_lora_weights().
+        self.lora_config = base_lora_layer.lora_config
+        self.lora_a_stacked = tuple(
+            torch_view(t2j(lora_a))
+            for lora_a in base_lora_layer.lora_a_stacked)
+        self.lora_b_stacked = tuple(
+            torch_view(t2j(lora_b))
+            for lora_b in base_lora_layer.lora_b_stacked)
+        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
+        if self.lora_config.bias_enabled:
+            self.lora_bias_stacked = tuple(
+                torch_view(t2j(lora_bias))
+                for lora_bias in base_lora_layer.lora_bias_stacked)
+
+        self.output_slices: tuple[int, ...]
+        self.n_slices: int
+
+    def reset_lora(self, index: int):
+        # lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features)
+        # lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank)
+        for s_index in range(self.n_slices):
+            self.lora_a_stacked[s_index][index] = 0
+            self.lora_b_stacked[s_index][index] = 0
+            if self.lora_config.bias_enabled:
+                # Make mypy happy
+                self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
+                                              self.lora_bias_stacked)
+                self.lora_bias_stacked[s_index][index] = 0
+
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # self.base_layer (e.g. JaxMergedColumnParallelLinear) returns
+        # (output, output_bias); we only need the first element.
+        # x: [bs, in_features]
+        output = self.base_layer(x)[0]
+
+        # In the transformers backend, x and output have an extra batch
+        # dimension like (1, seq_len, hidden_dim), while punica expects
+        # (seq_len, hidden_dim), therefore we need to flatten the batch
+        # dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
+        lora_output: Optional[
+            torch.Tensor] = self.punica_wrapper.add_lora_linear(
+                output, x, self.lora_a_stacked, self.lora_b_stacked,
+                self.lora_bias_stacked, 1.0, self.output_slices)
+        if not current_platform.can_update_inplace():
+            output = lora_output
+
+        return output
+
+    @property
+    def weight(self) -> torch.Tensor:
+        if hasattr(self.base_layer, "weight"):
+            return self.base_layer.weight
+        else:
+            raise ValueError(f"Unsupported base layer: {self.base_layer}")
+
+    @property
+    def bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.base_layer, "bias"):
+            return self.base_layer.bias
+        else:
+            return None
+
+
+class TorchaxMergedColumnParallelLinearWithLoRA(TorchaxBaseLinearLayerWithLoRA):
+    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
+    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
+
+    This means we have 2 LoRAs, each applied to one half of the layer.
+
+    Both slices must have the same size.
+ """ + + def __init__(self, base_lora_layer: BaseLayerWithLoRA, mesh: Mesh) -> None: + # TODO(xiowei): add mesh to the __init__. + super().__init__(base_lora_layer, mesh) + output_sizes = self.base_layer.output_sizes + self.output_slices = output_sizes + self.n_slices = len(output_sizes) + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + # print(f"{self.base_layer.gather_output=}, {self.base_layer.return_bias=}") print False and False + if self.base_layer.gather_output: + # All-gather across the partitions. + # output = tensor_model_parallel_all_gather(output_parallel) + raise NotImplementedError("NYI: TorchaxMergedColumnParallelLinearWithLoRA.forward when self.base_layer.gather_output is true.") + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, # tuple of (max_loras, 1, max_lora_rank, in_features) + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + for i in range(self.n_slices): + self.lora_a_stacked[i].apply_jax_(jax.device_put, NamedSharding(self.mesh, P())) + self.lora_b_stacked[i].apply_jax_(jax.device_put, NamedSharding(self.mesh, P())) + + lora_a_i = torch_view(t2j(lora_a[i])).apply_jax_(jax.device_put, NamedSharding(self.mesh, P())) + if lora_a_i is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + + lora_b_i = torch_view(t2j(lora_b[i])).apply_jax_(jax.device_put, NamedSharding(self.mesh, P())) + if lora_b_i is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + + if not self.lora_config.bias_enabled: + assert lora_bias is None, "lora_bias is not None but the lora bias is disabled." 
+        if lora_bias is not None:
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
+                                          self.lora_bias_stacked)
+            for i in range(self.n_slices):
+                self.lora_bias_stacked[i].apply_jax_(
+                    jax.device_put, NamedSharding(self.mesh, P()))
+                lora_bias_i = torch_view(t2j(lora_bias[i])).apply_jax_(
+                    jax.device_put, NamedSharding(self.mesh, P()))
+                if lora_bias_i is not None:
+                    self.lora_bias_stacked[i][index,
+                                              0, :lora_bias_i.shape[0]].copy_(
+                                                  lora_bias_i.T,
+                                                  non_blocking=True)
diff --git a/tpu_commons/lora/torch_lora_ops.py b/tpu_commons/lora/torch_lora_ops.py
new file mode 100644
index 000000000..23a9e541a
--- /dev/null
+++ b/tpu_commons/lora/torch_lora_ops.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import jax
+import jax.numpy as jnp
+import torch
+import torch.nn.functional as F
+from torchax.interop import call_jax
+
+
+@jax.jit
+def bgmv_jax(
+        inputs,  # [num_tokens, hidden_size]
+        loras,  # [num_loras, lora_rank, hidden_size]
+        idxs,  # [num_tokens]
+):
+    return jnp.einsum(
+        "td,tX,Xld->tl",
+        inputs,
+        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
+        loras,
+    )
+
+
+def bgmv_torch(
+        inputs,  # [num_tokens, hidden_size]
+        loras,  # [num_loras, 1, lora_rank, hidden_size]
+        idxs,  # [num_tokens]
+):  # [num_tokens, lora_rank]
+    # TODO(xiowei): use the one_hot implementation below (added in
+    # https://github.com/pytorch/xla/pull/9523) once the torchax version is
+    # upgraded.
+    # if len(loras.shape) == 4:
+    #     loras = loras.squeeze(axis=1)
+    # return torch.einsum(
+    #     "td,tX,Xld->tl",
+    #     inputs,
+    #     torch.nn.functional.one_hot(idxs.long(), loras.shape[0]),
+    #     loras,
+    # )  # [num_tokens, lora_rank]
+
+    if len(loras.shape) == 4:
+        loras = loras.squeeze(axis=1)
+    return call_jax(bgmv_jax, inputs, loras, idxs)
+
+
+def bgmv_shrink(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+):
+    """
+    Args:
+        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
+        lora_b_weights (torch.Tensor): Stacked LoRA weights of shape
+            [max_loras, 1, max_lora_rank, hidden_size] (the LoRA-A matrices
+            when called from add_shrink).
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+            indicating which LoRA matrix to use for each token.
+        scaling (float, optional): Scalar multiplier applied to the output.
+    """
+    return scaling * bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+
+def bgmv_expand_slice(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+):
+    """
+    Args:
+        inputs (torch.Tensor): Input tensor of shape [num_tokens, lora_rank].
+        lora_b_weights (torch.Tensor): LoRA weights of shape
+            [num_loras, 1, out_features, lora_rank].
+        output_tensor (torch.Tensor): Output tensor of shape
+            [num_tokens, out_features * num_slices].
+        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
+            indicating which LoRA matrix to use for each token.
+        add_inputs (bool): Whether or not to add the input tensor to the
+            output tensor.
+    """
+    outputs = bgmv_torch(inputs, lora_b_weights, lora_indices_tensor)
+
+    outputs = F.pad(
+        outputs,
+        (
+            slice_offset,
+            output_tensor.shape[1] - (slice_offset + slice_size),
+            0,
+            0,
+        ),
+    )
+
+    if add_inputs:
+        return output_tensor + outputs
+    else:
+        return outputs
diff --git a/tpu_commons/lora/torch_punica_tpu.py b/tpu_commons/lora/torch_punica_tpu.py
new file mode 100644
index 000000000..2b0f0a7cb
--- /dev/null
+++ b/tpu_commons/lora/torch_punica_tpu.py
@@ -0,0 +1,362 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+from typing import TYPE_CHECKING, Optional, Union
+
+import jax
+import torch
+import torch.nn.functional as F
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from vllm.lora.punica_wrapper.utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid a circular import
+    from vllm.lora.layers import LoRAMapping
+
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+from tpu_commons.lora.torch_lora_ops import bgmv_expand_slice, bgmv_shrink
+
+
+class PunicaWrapperTPU(PunicaWrapperBase):
+    """
+    PunicaWrapperTPU is designed to manage and provide metadata for the punica
+    kernel. Its main function is to maintain the state information for
+    Multi-LoRA and to provide the interface for the PyTorch punica ops.
+    """
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                   device)
+
+        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
+        # isn't supported by the TPU, so convert those tensors to int32.
+        # Not all of them are used on the TPU, so only convert the useful ones.
+        self._token_lora_indices = self._token_lora_indices.to(
+            dtype=torch.int32)  # map from token to LoRA index.
+        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
+        self._sampler_indices_padded = self._sampler_indices_padded.to(
+            dtype=torch.int32)
+
+    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+    @property
+    def embeddings_indices(self) -> torch.Tensor:
+        """
+        This property provides access to the indices used for lora embeddings,
+        specifically for VocabParallelEmbeddingWithLoRA.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.embeddings_indices.")
+
+    @property
+    def sampler_indices_padded(self) -> torch.Tensor:
+        """
+        This property provides access to padded sampler indices.
+        """
+        raise NotImplementedError(
+            "NYI: torch_punica_tpu.PunicaWrapperTPU.sampler_indices_padded.")
+
+    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
+                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
+                   scale: float, **kwargs) -> Optional[torch.Tensor]:
+        """
+        Performs GEMM for multiple slices of lora_a.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (x @ lora_a_stacked[i]) * scale
+
+        Args:
+            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors.
+                (n_slices, num_tokens, r)
+            x (torch.Tensor): Input tensor. (num_tokens, in_features)
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights.
lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features) + scale (float): Scaling factor for the operation + """ + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): + lora_s = lora_a_stacked[slice_idx] + y_s = bgmv_shrink(x, lora_s, self._get_token_lora_indices(x), + scale) + y[slice_idx, :, :] = y_s # type: ignore[index] + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> torch.Tensor: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. (num_tokens, out_features) + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors. (n_slices, num_tokens, r) + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + bias's weight + output_slices (tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_orig = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + for slice_idx in range(len(lora_b_stacked)): + y = bgmv_expand_slice(x[slice_idx], lora_b_stacked[slice_idx], y, + self._get_token_lora_indices(x[slice_idx]), + offset_left, output_slices[slice_idx], + add_inputs) + offset_left += output_slices[slice_idx] + return y.view(y_orig.shape) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + raise NotImplementedError( + "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_embedding.") + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> torch.Tensor: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor (bs, out_features). Will not be changed in-place. + x (torch.Tensor): Input tensor (bs, in_features) + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight of length n_slices. lora_a_stacked[i]: (max_loras, 1, max_lora_rank, in_features) + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight of length n_slices. lora_b_stacked[i]: (max_loras, 1, out_features, max_lora_rank) + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. lora_bias_stacked[i]: (max_loras, 1, out_features) + scale (float): Scaling factor. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. 
+ """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + + if buffer is None: + max_lora_rank = lora_b_stacked[0].size(-1) + num_tokens = x.size(0) + buffer = torch.zeros( + (len(output_slices), num_tokens, max_lora_rank), + dtype=x.dtype, + device=x.device, + ) + buffer = self.add_shrink( + buffer, x, lora_a_stacked, scale, + **kwargs) # (n_slices, num_tokens, max_lora_rank) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + raise NotImplementedError( + "NYI: torch_punica_tpu.PunicaWrapperTPU.add_lora_logits.") + + def _apply_bias( + self, + token_lora_indices: torch.Tensor, + output: torch.Tensor, + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + indices: (batch_size) + output: (batch_size, output_size) e.g. (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + lora_bias_stacked: 3 element tuple of (max_loras, 1, output_dim). Length of the tuple is n_slices. + """ + orig_output = output + output = output.view(-1, output.shape[-1]) + token_lora_indices = token_lora_indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[token_lora_indices] # [num_tokens, out_features] + bias = torch.where(token_lora_indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view(orig_output.shape) + + # This performs the same tensor ops as the base method, except it does them + # on the CPU then transfers the results to the TPU + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + ): + # Pad the prompt mapping to avoid running into recompiles on the TPU + # TODO: Should this happen inside mapping internally? If so how can we + # avoid having backend specific LoRAMapping classes? 
+ mapping.prompt_mapping = self._pad_prompt_mapping( + mapping.prompt_mapping) + + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + "cpu", + ) + self._token_lora_indices = self._pad_to_shape( + base_indices, self._token_lora_indices.shape, + dims=1).to(self.device) + self._sampler_indices = self._pad_to_shape(sampler_indices, + self._sampler_indices.shape, + dims=1).to(self.device) + self._sampler_indices_padded = self._pad_to_shape( + sampler_indices_padded, self._sampler_indices_padded.shape, + dims=1).to(self.device) + self._embeddings_indices = self._pad_to_shape( + embeddings_indices, self._embeddings_indices.shape, + dims=2).to(self.device) + self.indices_len[:] = indices_len + + def move_to_device(self, mesh: Mesh): + self._token_lora_indices = self._token_lora_indices.to('jax') + self._token_lora_indices.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None))) + + self._sampler_indices = self._sampler_indices.to('jax') + self._sampler_indices.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None))) + + self._sampler_indices_padded = self._sampler_indices_padded.to('jax') + self._sampler_indices_padded.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None))) + + self._embeddings_indices = self._embeddings_indices.to('jax') + self._embeddings_indices.apply_jax_(jax.device_put, + NamedSharding(mesh, P(None))) + + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self. + batch_size] = token_lora_tensor[:self. + batch_size] + + def _pad_prompt_mapping( + self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: + num_reqs = len(prompt_mapping) + + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # import + MIN_NUM_SEQS = 8 + + padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + pad_len = padded_num_reqs - num_reqs + + padding = [-1] * pad_len + return tuple(list(prompt_mapping) + padding) + + def _pad_to_shape(self, src, target_shape, dims=1): + if dims == 1: + pad_len = target_shape[0] - src.shape[0] + return F.pad(src, (0, pad_len), value=0).to(torch.int32) + else: + pad_rows = target_shape[0] - src.shape[0] + pad_cols = target_shape[1] - src.shape[1] + return F.pad(src, (0, pad_cols, 0, pad_rows), + value=0).to(torch.int32) diff --git a/tpu_commons/models/vllm/sharding.py b/tpu_commons/models/vllm/sharding.py index 176b1cf18..410bd5c9b 100644 --- a/tpu_commons/models/vllm/sharding.py +++ b/tpu_commons/models/vllm/sharding.py @@ -55,11 +55,16 @@ def shard_qkv_parallel_linear(layer: torch.nn.Module, mesh: Mesh, def shard_merged_column_parallel_linear(layer: torch.nn.Module, mesh: Mesh, vllm_config: VllmConfig): + fuse_matmuls = get_model_matmul_fusion_assignment( + vllm_config.model_config.model, + vllm_config.scheduler_config.max_num_batched_tokens, + vllm_config.parallel_config.tensor_parallel_size, + "MergedColumnParallelLinear") assert isinstance(layer, MergedColumnParallelLinear) jax_layer = JaxMergedColumnParallelLinear( layer, mesh, - shard_merged_column_parallel_linear.fuse_matmuls, + fuse_matmuls, enable_sequence_parallelism=vllm_config.compilation_config.pass_config. 
enable_sequence_parallelism) return jax_layer @@ -159,10 +164,6 @@ def _move_to_tpu_replicated(x): vllm_config.model_config.model, vllm_config.scheduler_config.max_num_batched_tokens, tp_size, "QKVParallelLinear") - shard_merged_column_parallel_linear.fuse_matmuls = get_model_matmul_fusion_assignment( - vllm_config.model_config.model, - vllm_config.scheduler_config.max_num_batched_tokens, tp_size, - "MergedColumnParallelLinear") with jax.default_device(jax.devices("cpu")[0]), torchax.default_env(): shard_parallel_layers_to_tpu(model, mesh, vllm_config) diff --git a/tpu_commons/platforms/tpu_jax.py b/tpu_commons/platforms/tpu_jax.py index 8d923ed31..58619d5e3 100644 --- a/tpu_commons/platforms/tpu_jax.py +++ b/tpu_commons/platforms/tpu_jax.py @@ -88,7 +88,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: @classmethod def get_punica_wrapper(cls) -> str: - return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + return "tpu_commons.lora.torch_punica_tpu.PunicaWrapperTPU" @classmethod def get_infinity_values(cls, dtype: jnp.dtype) -> Tuple[float, float]:
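
Note (illustrative, not part of the diff): the core contraction behind bgmv_torch/bgmv_jax selects one LoRA matrix per token and contracts it with that token's activations. Below is a minimal pure-PyTorch sketch mirroring the _ref_bgmv_torch helper in tests/lora/test_torch_lora_ops.py; the function name ref_bgmv and the shapes are example values chosen for illustration only.

import torch

def ref_bgmv(inputs: torch.Tensor,   # [num_tokens, hidden_size]
             loras: torch.Tensor,    # [max_loras, 1, lora_rank, hidden_size]
             idxs: torch.Tensor):    # [num_tokens]
    if loras.dim() == 4:
        loras = loras.squeeze(1)     # -> [max_loras, lora_rank, hidden_size]
    selected = loras[idxs]           # per-token LoRA: [num_tokens, lora_rank, hidden_size]
    return torch.einsum("td,tld->tl", inputs, selected)

x = torch.rand(4, 16)                # 4 tokens, hidden size 16
w = torch.rand(3, 1, 2, 16)          # 3 LoRA slots, rank 2
ids = torch.tensor([0, 2, 1, 0])     # which LoRA slot each token uses
print(ref_bgmv(x, w, ids).shape)     # torch.Size([4, 2])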
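
Likewise, bgmv_expand_slice places each slice's LoRA output into the full output tensor by zero-padding the feature dimension so the result lands at columns [slice_offset, slice_offset + slice_size). A toy sketch of that placement follows; all names and shapes here are made up for illustration.

import torch
import torch.nn.functional as F

num_tokens, out_features = 4, 12
slice_offset, slice_size = 4, 4

partial = torch.ones(num_tokens, slice_size)  # stand-in for one slice's bgmv result
placed = F.pad(partial, (slice_offset,
                         out_features - (slice_offset + slice_size), 0, 0))
output = torch.zeros(num_tokens, out_features) + placed
print(output[0])  # non-zero only in columns 4..7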
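
Finally, a small standalone sketch of the request-count padding done by PunicaWrapperTPU._pad_prompt_mapping: prompt mappings are padded with -1 up to the next power of two, with a floor of MIN_NUM_SEQS = 8, so the TPU sees a bounded set of shapes and avoids recompiles. The standalone helper name below is hypothetical; it simply restates the logic in the diff.

import math

MIN_NUM_SEQS = 8  # mirrors the constant copied from tpu_model_runner

def pad_prompt_mapping(prompt_mapping: tuple) -> tuple:
    num_reqs = len(prompt_mapping)
    padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
    return tuple(list(prompt_mapping) + [-1] * (padded_num_reqs - num_reqs))

print(pad_prompt_mapping((3, 1, 2)))             # padded with -1 up to 8 entries
print(len(pad_prompt_mapping(tuple(range(9)))))  # 16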