feat: support lazy loading the lora module for reducing the loading p… #434

Merged
server/lorax_server/adapters/lora.py (5 changes: 3 additions & 2 deletions)
@@ -17,6 +17,7 @@
     pad_rank,
     use_cutlass_shrink,
 )
+from lorax_server.utils.weights import load_module_weight

 if TYPE_CHECKING:
     from lorax_server.models.model import Model
@@ -166,10 +167,10 @@ def load(
             return None

         lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-        lora_a = lora_a.to(base_device, model.dtype)
+        lora_a = load_module_weight(lora_a_name, lora_a, base_device, model.dtype)

         lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-        lora_b = lora_b.to(base_device, model.dtype)
+        lora_b = load_module_weight(lora_b_name, lora_b, base_device, model.dtype)

         scale = get_scaling_factor(
             config.lora_alpha,
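After this change, the value stored in `module_map` for each target module can take one of two shapes: a materialized tensor (eager path) or the path of the safetensors file that holds it (lazy path), and `load_module_weight` resolves either one. A minimal sketch of the two shapes, with a made-up key and dimensions:

```python
# Illustrative only, not LoRAX code: the two states a module_map entry
# can be in after this PR. The weight name and shapes are invented.
import torch

weight_name = "model.layers.0.self_attn.q_proj"  # hypothetical key
module_map = {
    weight_name: {
        # eager: (tensor, tensor_name)
        "lora_A": (torch.zeros(16, 4096), f"{weight_name}.lora_A.weight"),
        # lazy: (filename, tensor_name)
        "lora_B": ("adapter_model.safetensors", f"{weight_name}.lora_B.weight"),
    }
}

lora_a, lora_a_name = module_map[weight_name]["lora_A"]
# load_module_weight(lora_a_name, lora_a, device, dtype) handles both cases.
```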
server/lorax_server/utils/adapter.py (13 changes: 12 additions & 1 deletion)
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Set, Tuple

 from loguru import logger
+from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

@@ -78,6 +79,7 @@ def _load_and_merge(
             weight_names,
             api_token,
             trust_remote_code,
+            False,
         )

         adapters_to_merge.append((module_map, adapter_config))
@@ -136,6 +138,7 @@ def load_module_map(
     weight_names: Tuple[str],
     api_token: str,
     trust_remote_code: bool = False,
+    lazy_load_weights: bool = True,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
     # lorax_server/utils/adapter.py::create_merged_weight_files
@@ -157,7 +160,15 @@
     adapter_filenames = source.weight_files()
     adapter_weights = {}
     for filename in adapter_filenames:
-        adapter_weights.update(load_file(filename))
+        if lazy_load_weights:
+            result = {}
+            # record only the tensor names and the file that holds them;
+            # tensor data is read later, on first use
+            with safe_open(filename, framework="pt") as f:
+                for k in f.keys():
+                    result[k] = filename
+            adapter_weights.update(result)
+        else:
+            adapter_weights.update(load_file(filename))

     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
     module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
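With `lazy_load_weights=True`, the loop above records only which file holds each tensor; `safe_open` parses just the safetensors header, so no tensor data is read at this point. (`_load_and_merge` passes `False` in the earlier hunk, presumably because merging adapters requires the actual tensor values.) A standalone sketch of the indexing step, with a hypothetical helper name and a placeholder path:

```python
# Sketch of the lazy indexing step, assuming safetensors is installed;
# index_adapter_weights and the path below are illustrative, not LoRAX API.
from safetensors import safe_open

def index_adapter_weights(filenames):
    """Map tensor name -> filename without reading any tensor data."""
    index = {}
    for filename in filenames:
        # safe_open reads only the file header, so this stays cheap
        # even for large adapters.
        with safe_open(filename, framework="pt") as f:
            for key in f.keys():
                index[key] = filename
    return index

# index = index_adapter_weights(["adapter_model.safetensors"])
# Every value is a path; tensors are materialized later, on demand.
```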
server/lorax_server/utils/weights.py (19 changes: 16 additions & 3 deletions)
@@ -267,9 +267,7 @@ def get_slice(self, tensor_name: str) -> torch.Tensor:

     def get_tensor(self, tensor_name: str) -> torch.Tensor:
         tensor = self.weights[tensor_name]
-        tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
-        return tensor
+        return load_module_weight(tensor_name, tensor, self.device, self.dtype)

     def get_slice_shape(self, slice) -> torch.Size:
         return slice.shape
@@ -542,3 +540,18 @@ def download_weights(
         discard_names = []
     # Convert pytorch weights to safetensors
     utils.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtype):
+    if isinstance(module, torch.Tensor):
+        return module.to(device, dtype)
+
+    if isinstance(device, torch.device):
+        if device.type == "cuda":
+            device = device.index
+        elif device.type == "cpu":
+            device = "cpu"
+
+    # if lazy loading happened earlier, `module` is just the safetensors filename
+    with safe_open(module, framework="pt", device=device) as f:
+        return f.get_tensor(name).to(dtype)
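The helper dispatches on the type of `module`: a tensor is simply moved and cast, while a string is treated as the safetensors path recorded during lazy indexing, and `safe_open` then reads only the requested tensor from disk. A `torch.device` is first normalized to the bare CUDA index or `"cpu"` string that `safe_open` accepts. A self-contained usage sketch exercising both branches; the file and tensor names are invented, and `load_module_weight` is assumed importable as below:

```python
import torch
from safetensors.torch import save_file

from lorax_server.utils.weights import load_module_weight

name = "q_proj.lora_A.weight"  # hypothetical tensor name
save_file({name: torch.zeros(16, 64)}, "demo_adapter.safetensors")

# Eager branch: the module is already a tensor, so it is moved and cast.
eager = load_module_weight(name, torch.zeros(16, 64), "cpu", torch.float16)

# Lazy branch: the module is the filename recorded during indexing;
# only the requested tensor is materialized from disk.
lazy = load_module_weight(name, "demo_adapter.safetensors", "cpu", torch.float16)

assert eager.dtype == lazy.dtype == torch.float16
assert torch.equal(eager, lazy)
```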