Add adapter support for all linear layers in Llama and Mistral #75

Merged: 9 commits, Nov 28, 2023
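For orientation: this PR threads adapter (LoRA) weights through every linear projection in the Llama and Mistral models, not just the attention q/k/v/o projections. The changed files import a set of layer-type constants from lorax_server.utils.lora. Below is a minimal sketch of what those constants presumably look like; the string values and the ADAPTER_LAYERS ordering are assumptions inferred from the weight names used in the diff, not copied from the library.

```python
# Hypothetical sketch of the layer-type constants referenced throughout this PR.
# Values are assumed from weight names such as "model.layers.{i}.self_attn.q_proj".
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
LM_HEAD = "lm_head"

# Every layer type that can now carry adapter weights (ordering assumed).
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ,
                  GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
```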
server/lorax_server/models/custom_modeling/flash_llama_modeling.py (49 changes: 29 additions, 20 deletions)
@@ -24,7 +24,7 @@
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Set, Tuple
from typing import Optional, List, Tuple

# Flash attention imports
import dropout_layer_norm
@@ -41,7 +41,7 @@
TensorParallelHead,
get_linear,
)
from lorax_server.utils.lora import AdapterBatchData
from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData


class LlamaConfig(PretrainedConfig):
@@ -152,11 +152,13 @@ def forward(self, hidden_states, residual=None):
def load_attention(config, prefix, weights, layer_id):
base_layer = load_attention_multi(config, prefix, weights)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(base_layer, layer_id, sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
], process_group=weights.process_group)
return TensorParallelMultiAdapterLinear.load(
base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
], process_group=weights.process_group
)


def load_attention_multi(config, prefix, weights):
@@ -237,7 +239,7 @@ def __init__(
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
), layer_id, process_group=weights.process_group)
), layer_id, O_PROJ, process_group=weights.process_group)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -329,7 +331,7 @@ def forward(


class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_act
self.act = (
@@ -343,27 +345,34 @@ def __init__(self, prefix, config, weights):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[
config.intermediate_size,
config.intermediate_size,
], process_group=weights.process_group
)

self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
), layer_id, DOWN_PROJ, process_group=weights.process_group)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)

def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)


class FlashLlamaLayer(nn.Module):
@@ -373,7 +382,7 @@ def __init__(self, layer_id, config, weights):
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id)

self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -419,7 +428,7 @@ def forward(
attn_output, res
)

mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)

return mlp_output, attn_res

@@ -500,11 +509,11 @@ def __init__(self, config, weights):
super().__init__()

self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
), 0, LM_HEAD, process_group=weights.process_group)

def forward(
self,
@@ -532,5 +541,5 @@ def forward(
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states, adapter_data)
return logits
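The Llama changes above hand TensorParallelMultiAdapterLinear a list of layer types ([Q_PROJ, K_PROJ, V_PROJ] or [GATE_PROJ, UP_PROJ]) together with per-slice sizes, so one fused base projection can apply a different LoRA delta to each slice of its output. Below is a simplified, self-contained sketch of that mechanism; it illustrates the idea only and is not the lorax implementation, and the lora_weights dict is a hypothetical stand-in for AdapterBatchData.

```python
import torch

def multi_adapter_forward(base_out, hidden, layer_types, sizes, lora_weights):
    """Add a per-layer-type low-rank update to each slice of a fused projection.

    base_out:     [tokens, sum(sizes)] output of the fused base linear
    hidden:       [tokens, hidden_size] input that produced base_out
    lora_weights: hypothetical {layer_type: (lora_a, lora_b)} for the active adapter
    """
    offset = 0
    for layer_type, size in zip(layer_types, sizes):
        if layer_type in lora_weights:
            lora_a, lora_b = lora_weights[layer_type]   # [hidden_size, r], [r, size]
            delta = (hidden @ lora_a) @ lora_b          # low-rank update for this slice
            base_out[:, offset:offset + size] += delta
        offset += size
    return base_out

# Toy usage: q/k/v slices of width 16/8/8 with a rank-2 adapter on q_proj only.
tokens, hidden_size, r = 4, 16, 2
sizes = [16, 8, 8]
hidden = torch.randn(tokens, hidden_size)
base_out = torch.randn(tokens, sum(sizes))
lora = {"q_proj": (torch.randn(hidden_size, r), torch.randn(r, sizes[0]))}
out = multi_adapter_forward(base_out, hidden, ["q_proj", "k_proj", "v_proj"], sizes, lora)
```

The same shape bookkeeping is why load_attention passes head_size * config.num_attention_heads for the Q slice and head_size * config.num_key_value_heads for the K and V slices: those are the widths of the corresponding slices in the fused QKV output.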
server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -42,7 +42,7 @@
TensorParallelHead,
get_linear,
)
from lorax_server.utils.lora import AdapterBatchData
from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData

if not HAS_FLASH_ATTN_V2:
raise ImportError("Mistral model requires flash attn v2")
@@ -158,11 +158,13 @@ def forward(self, hidden_states, residual=None):
def load_attention(config, prefix, weights, layer_id):
base_layer = load_attention_multi(config, prefix, weights)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(base_layer, layer_id, sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
], process_group=weights.process_group)
return TensorParallelMultiAdapterLinear.load(
base_layer, layer_id, [Q_PROJ, K_PROJ, V_PROJ], sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
], process_group=weights.process_group
)


def load_attention_multi(config, prefix, weights):
@@ -246,7 +248,7 @@ def __init__(
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
), layer_id, process_group=weights.process_group)
), layer_id, O_PROJ, process_group=weights.process_group)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -345,7 +347,7 @@ def forward(


class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_act
self.act = (
@@ -359,27 +361,34 @@ def __init__(self, prefix, config, weights):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj, layer_id, [GATE_PROJ, UP_PROJ], sizes=[
config.intermediate_size,
config.intermediate_size,
], process_group=weights.process_group
)

self.down_proj = TensorParallelAdapterRowLinear.load(TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
), layer_id, DOWN_PROJ, process_group=weights.process_group)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)

def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)


class MistralLayer(nn.Module):
@@ -389,7 +398,7 @@ def __init__(self, layer_id, config, weights):
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, layer_id=layer_id,
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id)

self.input_layernorm = MistralRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -437,7 +446,7 @@ def forward(
attn_output, res
)

mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)

return mlp_output, attn_res

@@ -520,11 +529,11 @@ def __init__(self, config, weights):
super().__init__()

self.model = MistralModel(config, weights)
self.lm_head = TensorParallelHead.load(
self.lm_head = TensorParallelAdapterRowLinear.load(TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
), 0, LM_HEAD, process_group=weights.process_group)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
@@ -566,5 +575,5 @@ def forward(
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states, adapter_data)
return logits
server/lorax_server/models/flash_causal_lm.py (64 changes: 30 additions, 34 deletions)
@@ -30,7 +30,7 @@
from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.lora import K_PROJ, O_PROJ, Q_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.lora import ADAPTER_LAYERS, DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments

tracer = trace.get_tracer(__name__)
@@ -696,14 +696,23 @@ def __init__(
sliding_window=sliding_window,
)

weight_names = []
layer_weights = {}

# TODO(travis): generalize this
prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
weight_names.append(f"{prefix}.{i}.self_attn.{Q_PROJ}")
weight_names.append(f"{prefix}.{i}.self_attn.{K_PROJ}")
weight_names.append(f"{prefix}.{i}.self_attn.{V_PROJ}")
weight_names.append(f"{prefix}.{i}.self_attn.{O_PROJ}")
self.weight_names = tuple(weight_names)
layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)

layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)

layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)

self.layer_weights = layer_weights

@property
def supports_adapter_loading(self) -> bool:
@@ -733,12 +742,10 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):
return
elif adapter_id != BASE_MODEL_ADAPTER_ID:
logger.info(f"Loading adapter weights into model: {adapter_id}")
module_map, adapter_config = load_module_map(self.model_id, adapter_id, adapter_source, self.weight_names)

self.load_batched_adapter_weights(module_map, adapter_config, adapter_index, Q_PROJ)
self.load_batched_adapter_weights(module_map, adapter_config, adapter_index, V_PROJ)
self.load_batched_adapter_weights(module_map, adapter_config, adapter_index, K_PROJ)
self.load_batched_adapter_weights(module_map, adapter_config, adapter_index, O_PROJ)
weight_names = tuple([v[0] for v in self.layer_weights.values()])
module_map, adapter_config = load_module_map(self.model_id, adapter_id, adapter_source, weight_names)
for layer_name in ADAPTER_LAYERS:
self.load_batched_adapter_weights(module_map, adapter_config, adapter_index, layer_name)

self.adapter_id = adapter_id

@@ -749,29 +756,26 @@ def load_batched_adapter_weights(
adapter_index: int,
layer_type: str,
):
nlayers = len(self.model.model.layers)
nlayers = len(self.model.model.layers) if layer_type != LM_HEAD else 1
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers

prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
# TODO(travis): generalize this beyond qkv for accessing the layer_id
# This works for o_proj because they share the same id sequence, but may not
# extend to other layers.
layer = layer.self_attn.query_key_value
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = self.layer_weights[key]

base_weight = layer.base_layer.linear.weight
base_device = base_weight.device

weight_name = f"{prefix}.{i}.self_attn.{layer_type}"
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return

lora_a = module_map[weight_name]["lora_A"].to(base_device, base_weight.dtype)
lora_b = module_map[weight_name]["lora_B"].to(base_device, base_weight.dtype)

lora_a_list[layer.layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer.layer_id] = lora_b.transpose(0, 1)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1)

q_lora_merged = MergedLoraWeights(lora_a_list, lora_b_list, adapter_config, layer_type, self.process_group)
q_lora_weights = self.batched_lora_weights[layer_type]
@@ -794,17 +798,9 @@ def offload_adapter(self, adapter_id, adapter_source, adapter_index):
if adapter_id == BASE_MODEL_ADAPTER_ID:
return
else:
if Q_PROJ in self.batched_lora_weights:
self.batched_lora_weights[Q_PROJ].remove_adapter(adapter_index)

if V_PROJ in self.batched_lora_weights:
self.batched_lora_weights[V_PROJ].remove_adapter(adapter_index)

if K_PROJ in self.batched_lora_weights:
self.batched_lora_weights[K_PROJ].remove_adapter(adapter_index)

if O_PROJ in self.batched_lora_weights:
self.batched_lora_weights[O_PROJ].remove_adapter(adapter_index)
for layer_name in ADAPTER_LAYERS:
if layer_name in self.batched_lora_weights:
self.batched_lora_weights[layer_name].remove_adapter(adapter_index)

self.adapter_id = BASE_MODEL_ADAPTER_ID
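The server-side changes above replace the hard-coded q/k/v/o handling with a (layer_id, layer_type) registry, so loading and offloading iterate one loop over ADAPTER_LAYERS instead of four near-identical calls. Below is a toy sketch of that lookup pattern, using the assumed constants from the sketch at the top and placeholder strings standing in for modules; the real code resolves actual modules from the instantiated model as the diff shows.

```python
from typing import Dict, Tuple

NUM_LAYERS = 2  # toy decoder depth

# (layer_id, layer_type) -> (weight name in the adapter checkpoint, target module)
layer_weights: Dict[Tuple[int, str], Tuple[str, str]] = {}
for i in range(NUM_LAYERS):
    layer_weights[(i, "q_proj")] = (f"model.layers.{i}.self_attn.q_proj", f"qkv_{i}")
    layer_weights[(i, "down_proj")] = (f"model.layers.{i}.mlp.down_proj", f"down_{i}")
layer_weights[(0, "lm_head")] = ("lm_head", "lm_head_module")

# The adapter loader only needs the checkpoint names...
weight_names = tuple(name for name, _ in layer_weights.values())

# ...and a single loop over layer types replaces the per-projection calls.
ADAPTER_LAYERS = ["q_proj", "k_proj", "v_proj", "o_proj",
                  "gate_proj", "up_proj", "down_proj", "lm_head"]
for layer_type in ADAPTER_LAYERS:
    nlayers = 1 if layer_type == "lm_head" else NUM_LAYERS
    for layer_id in range(nlayers):
        key = (layer_id, layer_type)
        if key not in layer_weights:
            continue  # this toy registry only covers a few projections
        weight_name, module = layer_weights[key]
        # A real loader would move the lora_A/lora_B tensors for `weight_name`
        # onto the device/dtype of `module` and register them by adapter index.
```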
