Support both medusa v1 and v2 (#421)
tgaddair committed Apr 18, 2024
1 parent 2d33ee9 commit 0a3c627
Showing 4 changed files with 267 additions and 143 deletions.
86 changes: 81 additions & 5 deletions server/lorax_server/adapters/medusa.py
@@ -2,11 +2,12 @@
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch
+import torch.distributed

from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import MEDUSA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
-from lorax_server.utils.layers import FastLinear
+from lorax_server.utils.layers import FastLinear, TensorParallelColumnLinear
from lorax_server.utils.weights import AbstractWeights, InMemoryWeights

if TYPE_CHECKING:
@@ -18,6 +19,10 @@ class MedusaConfig(AdapterConfig):
medusa_num_heads: int
medusa_num_layers: int

+@property
+def quantize(self) -> Optional[str]:
+return None

def map_weights_for_model(
self,
adapter_weights: Dict,
@@ -62,6 +67,7 @@ def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
self.act = torch.nn.SiLU()
+self.scaling = 1

def forward(self, x):
return x + self.act(self.linear(x))
@@ -83,22 +89,92 @@ def forward(self, x):
return x


-class MedusaModel(torch.nn.Module):
+class MedusaV1(torch.nn.Module):
def __init__(self, config: MedusaConfig, weights: AbstractWeights):
super().__init__()
self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config.medusa_num_heads)]
)

-def forward(self, x):
+def forward(self, x, lm_head):
+logits = lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-return speculative_logits
+return logits, speculative_logits


class MedusaV2(torch.nn.Module):
def __init__(self, config: MedusaConfig, weights: AbstractWeights):
super().__init__()
self.n_medusa_heads = config.medusa_num_heads

assert config.medusa_num_layers == 1
self.linear = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
dim=0,
weights=weights,
bias=True,
)
self.process_group = weights.process_group
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()

self.act = torch.nn.SiLU()

def forward(self, x, lm_head):
# If we have too many tokens, we skip speculative logits
if x.shape[0] > 128:
logits = lm_head(x)
return logits, None

size = x.shape[-1]
block_size = (size + self.world_size - 1) // self.world_size
start = self.rank * block_size
stop = (self.rank + 1) * block_size

x_block = x[:, start:stop]

# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res = self.act(self.linear(x)).reshape(*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1])

# Apply all residual medusa heads
output = x[:, start:stop].unsqueeze(-2) + medusa_res

# Gather medusa heads
world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)

# Stack x and medusa residual x
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

# Compute lm head on x + medusa residual x
logits = lm_head(stacked_x)

# Finally, split logits from speculative logits
logits, speculative_logits = torch.split(logits, [1, self.n_medusa_heads], dim=-2)
# Squeeze added dimension
logits = logits.squeeze(-2)

return logits, speculative_logits


class MedusaModel(torch.nn.Module):
def __init__(self, config: MedusaConfig, weights: AbstractWeights):
super().__init__()
if config.medusa_num_layers > 1 or weights.has_tensor(f"0.{config.medusa_num_layers}.weight"):
self.medusa = MedusaV1(config, weights)
else:
self.medusa = MedusaV2(config, weights)

def forward(self, x, lm_head):
return self.medusa(x, lm_head)


class MedusaWeights(AdapterWeights):
def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"):
self.config = config
-self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype))
+self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype, model.process_group))

@classmethod
def get_batch_type(cls) -> BatchAdapterWeights:
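For readers following the new MedusaV2 path, here is a minimal, hypothetical sketch of its shape flow in plain torch on a single process (so the per-rank block split and all_gather collapse to the full hidden dimension); the variable names mirror the diff, but the sizes and the fused_linear module are illustrative stand-ins, not LoRAX code.

    import torch

    # Illustrative sizes (not from the commit).
    num_tokens, hidden_size, vocab_size, n_medusa_heads = 4, 64, 1000, 3

    x = torch.randn(num_tokens, hidden_size)
    lm_head_weight = torch.randn(vocab_size, hidden_size)
    lm_head = lambda h: torch.nn.functional.linear(h, lm_head_weight)  # stand-in for the real lm_head

    # Fused medusa projection: one matmul produces all heads, then reshape to [tokens, heads, hidden].
    fused_linear = torch.nn.Linear(hidden_size, n_medusa_heads * hidden_size, bias=True)
    medusa_res = torch.nn.SiLU()(fused_linear(x)).reshape(num_tokens, n_medusa_heads, hidden_size)

    # Residual connection per head, then prepend the base hidden states.
    output = x.unsqueeze(-2) + medusa_res                     # [tokens, heads, hidden]
    stacked_x = torch.cat([x.unsqueeze(-2), output], dim=-2)  # [tokens, 1 + heads, hidden]

    # A single lm_head call covers the base and speculative positions, then they are split apart.
    logits = lm_head(stacked_x)                               # [tokens, 1 + heads, vocab]
    logits, speculative_logits = torch.split(logits, [1, n_medusa_heads], dim=-2)
    logits = logits.squeeze(-2)                               # [tokens, vocab]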
10 changes: 5 additions & 5 deletions server/lorax_server/models/flash_cohere.py
@@ -14,14 +14,13 @@
initialize_torch_distributed,
weight_files,
)
-from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ
+from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ

tracer = trace.get_tracer(__name__)


-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
-ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ]
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ}
+ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
+ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}


class FlashCohere(FlashCausalLM):
@@ -117,14 +116,15 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.gate_up_proj", layer.mlp.gate_up_proj)
layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)

+layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
return layer_weights

@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS

def get_num_layers_for_type(self, layer_type: str) -> int:
-return len(self.model.model.layers)
+return 1 if layer_type == LM_HEAD else len(self.model.model.layers)

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
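With LM_HEAD back in ADAPTER_LAYERS, the two helpers above treat it as a single, row-parallel target. A small hypothetical usage check, assuming model is a constructed FlashCohere instance:

    from lorax_server.utils.lora import LM_HEAD, Q_PROJ

    model.get_num_layers_for_type(Q_PROJ)   # one target per decoder layer
    model.get_num_layers_for_type(LM_HEAD)  # 1: a single lm_head for the whole model
    model.is_row_parallel(LM_HEAD)          # True: LM_HEAD joins O_PROJ and DOWN_PROJ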
14 changes: 9 additions & 5 deletions server/lorax_server/utils/layers.py
@@ -229,7 +229,9 @@ def __init__(
index=None,
):
super().__init__()
-assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+assert (
+not memory_efficient_backward
+), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index

@@ -677,24 +679,26 @@ class MultiAdapterHead(TensorParallelAdapterRowLinear):
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-result = super().forward(input, adapter_data)

# Medusa
data = adapter_data.data.get(self.layer_name)
data: Optional["BatchMedusaWeights"] = data.get(MEDUSA) if data is not None else None

speculative_logits = None
if data is not None and data.default_medusa is not None:
-speculative_logits = data.default_medusa.model(input)
+forward = super().forward
+lm_head = lambda x: forward(x, adapter_data) # noqa: E731
+logits, speculative_logits = data.default_medusa.model(input, lm_head)

# TODO(travis): support multiple medusa adapters with masking:
# for adapter_index in adapter_data.meta.adapter_set:
# if data.has_adapter(adapter_index):
# adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1)
# speculative_logits = data.adapter_to_medusa[adapter_index].model(input)
# ...
+else:
+logits = super().forward(input, adapter_data)

-return result, speculative_logits
+return logits, speculative_logits


class TensorParallelRowLinear(SuperLayer):
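The change to MultiAdapterHead is easiest to read as a control-flow inversion: instead of computing the base logits up front, the head hands the Medusa model a callable so that, in the MedusaV2 case, the expensive vocabulary projection runs once over the base and speculative hidden states together. A minimal, hypothetical illustration of that wrapper pattern, independent of the LoRAX types:

    import torch
    from typing import Callable, Optional, Tuple

    class Head(torch.nn.Module):
        """Hypothetical stand-in for MultiAdapterHead (illustration only)."""

        def __init__(self, hidden_size: int, vocab_size: int):
            super().__init__()
            self.proj = torch.nn.Linear(hidden_size, vocab_size, bias=False)
            self.medusa: Optional[torch.nn.Module] = None  # set when a Medusa adapter is active

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            lm_head: Callable[[torch.Tensor], torch.Tensor] = lambda h: self.proj(h)  # noqa: E731
            if self.medusa is not None:
                # The speculator decides when and on what shape to call lm_head,
                # e.g. once over [tokens, 1 + n_heads, hidden] in the MedusaV2 case.
                return self.medusa(x, lm_head)
            return lm_head(x), None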