Implement sparse SGMV (#64)
tgaddair committed Nov 28, 2023
1 parent 90b2362 commit 54fafb9
Showing 11 changed files with 110 additions and 112 deletions.
5 changes: 4 additions & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -773,9 +773,12 @@ def load_batched_adapter_weights(

lora_a = module_map[weight_name]["lora_A"].to(base_device, base_weight.dtype)
lora_b = module_map[weight_name]["lora_B"].to(base_device, base_weight.dtype)
scale = adapter_config.lora_alpha / adapter_config.r

# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

q_lora_merged = MergedLoraWeights(lora_a_list, lora_b_list, adapter_config, layer_type, self.process_group)
q_lora_weights = self.batched_lora_weights[layer_type]
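
The comment above relies on associativity of matrix multiplication to fold the `lora_alpha / r` scaling into `lora_b` once at load time instead of applying it on every forward pass. A minimal sketch of that equivalence, with made-up dimensions:

```python
import torch

# Hypothetical dimensions, for illustration only.
h, r, alpha = 64, 8, 16
scale = alpha / r

x = torch.randn(4, h)
lora_a = torch.randn(h, r)
lora_b = torch.randn(r, h)

# Applying the scale after the full LoRA product...
out_late = (x @ lora_a @ lora_b) * scale
# ...matches pre-scaling lora_b, by associativity of matrix multiplication.
out_folded = x @ lora_a @ (lora_b * scale)

assert torch.allclose(out_late, out_folded, atol=1e-4)
```
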
35 changes: 18 additions & 17 deletions server/lorax_server/utils/layers.py
@@ -352,23 +352,26 @@ def forward_layer_type(
if (
has_sgmv() and
self.process_group.size() == 1 and
data is not None and data.can_vectorize
data is not None
):
proj = torch.zeros_like(result[:, start_idx:end_idx])

lora_a_ptr = data.lora_a_ptr
lora_b_ptr = data.lora_b_ptr
if lora_a_ptr is not None and lora_b_ptr is not None:
add_lora_sgmv_cutlass(
proj,
input,
lora_a_ptr,
lora_b_ptr,
adapter_data.meta.adapter_segments,
self.layer_id,
data.rank,
)
result[:, start_idx:end_idx] += proj * data.scaling
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is not None and lora_b_ptr is not None:
add_lora_sgmv_cutlass(
proj,
input,
lora_a_ptr,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)

result[:, start_idx:end_idx] += proj
else:
for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
@@ -385,8 +388,6 @@ def forward_lora(
adapter_index: int,
adapter_mask: torch.Tensor,
) -> torch.Tensor:
scaling = data.scaling_for_adapter(adapter_index)

lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, data.adapter_index_configs[adapter_index].r)
a_out = input @ lora_a
@@ -395,7 +396,7 @@
a_out = self.collect_lora_a(a_out)

lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
result = (a_out @ lora_b) * scaling * adapter_mask
result = (a_out @ lora_b) * adapter_mask
return result

def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
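
For reference, a pure-PyTorch emulation of what the per-rank loop in `forward_layer_type` computes, without the SGMV kernel. The segment/weight bookkeeping here is a simplified stand-in for `rank_data`, not the server's actual data structure:

```python
import torch
from typing import Dict, List, Tuple

# Each rank maps to a list of (start, end, lora_a, lora_b) segments for this
# layer; a simplified stand-in for the server's rank_data structure.
RankSegmentMap = Dict[int, List[Tuple[int, int, torch.Tensor, torch.Tensor]]]

def apply_lora_segments(result: torch.Tensor, x: torch.Tensor, rank_segments: RankSegmentMap) -> torch.Tensor:
    # result: [B, H_out] slice being updated; x: [B, H_in] layer input.
    proj = torch.zeros_like(result)
    for r, segments in rank_segments.items():
        # The real code issues one add_lora_sgmv_cutlass call per rank here.
        for start, end, lora_a, lora_b in segments:
            # lora_a: [H_in, r], lora_b: [r, H_out] (scaling already folded into lora_b).
            proj[start:end] += x[start:end] @ lora_a @ lora_b
    result += proj
    return result
```
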
69 changes: 27 additions & 42 deletions server/lorax_server/utils/lora.py
@@ -1,3 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Set

@@ -29,45 +30,24 @@


@dataclass
class AdapterWeightData:
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor


@dataclass
class AdapterWeightData:
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]

r: Set[int]
alpha: Set[int]
adapter_index_configs: Dict[int, LoraConfig]

@property
def can_vectorize(self) -> bool:
# Currently we can only use the SGMV kernel when the following criteria are met:
# 1. All adapters have the same r
# 2. All adapters have the same alpha
# 3. The base model (no adapter) is not contained in the batch
#
# TODO(travis): we should remove 3 as a constraint as quickly as possible,
# as many requests will likely come in for the base model in parallel with
# adapters. One solution is to create a zeroed out tensor with the same shape,
# the other is to rework the kernel to handle this case as a missing segment.
return len(self.r) == 1 and len(self.alpha) == 1 and None not in self.r
rank_data: Dict[int, RankSegments]

def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs

@property
def rank(self) -> int:
return next(iter(self.r))

@property
def scaling(self) -> float:
alpha = next(iter(self.alpha))
return alpha / self.rank

def scaling_for_adapter(self, adapter_idx: int) -> float:
cfg = self.adapter_index_configs[adapter_idx]
return cfg.lora_alpha / cfg.r


@dataclass
class AdapterBatchMetadata:
@@ -172,26 +152,31 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
device=device,
)

r = set([
(self.lora_weights[idx].adapter_config.r if idx in self.lora_weights else None)
for idx in segment_indices
])
alpha = set([
(self.lora_weights[idx].adapter_config.lora_alpha if idx in self.lora_weights else None)
for idx in segment_indices
])
adapter_index_configs = {
idx: self.lora_weights[idx].adapter_config
for idx in segment_indices
if idx in self.lora_weights
}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in self.lora_weights:
continue
rank_indices[self.lora_weights[adapter_idx].adapter_config.r].append(segment_idx)

rank_data = {}
for rank, indices in rank_indices.items():
rank_data[rank] = RankSegments(
rank=rank,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=meta.adapter_segments[indices],
segment_ends=meta.adapter_segments[[i+1 for i in indices]],
)

return AdapterWeightData(
lora_a_ptr=lora_a_ptr,
lora_b_ptr=lora_b_ptr,
lora_a=lora_a,
lora_b=lora_b,
r=r,
alpha=alpha,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
)
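
A toy walk-through of the rank-grouping logic in `get_data`, with hypothetical adapter ranks and segment boundaries standing in for `adapter_config.r` and `meta.adapter_segments`:

```python
from collections import defaultdict

import torch

# Hypothetical inputs: one adapter index per contiguous batch segment, a rank
# per adapter (stand-in for adapter_config.r), and segment boundaries of
# length S + 1 (stand-in for meta.adapter_segments).
segment_indices = [0, 1, 0, 2]
adapter_ranks = {0: 16, 1: 16, 2: 64}
adapter_segments = torch.tensor([0, 2, 5, 7, 9], dtype=torch.int32)

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[adapter_ranks[adapter_idx]].append(segment_idx)

rank_data = {}
for rank, indices in rank_indices.items():
    rank_data[rank] = {
        "segment_starts": adapter_segments[indices],
        "segment_ends": adapter_segments[[i + 1 for i in indices]],
    }

# rank 16 covers segments 0, 1, 2 -> starts [0, 2, 5], ends [2, 5, 7]
# rank 64 covers segment 3        -> starts [7],       ends [9]
print(rank_data)
```
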
42 changes: 22 additions & 20 deletions server/lorax_server/utils/sgmv.py
@@ -31,49 +31,51 @@ def add_lora_sgmv_cutlass(
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s: torch.Tensor,
s_start: torch.Tensor,
s_end: torch.Tensor,
layer_idx: int,
lora_rank: int,
):
"""
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s: Shape: `[S+1]`, DType: torch.int32. Indptr of the weight matrices.\
`s[0] == 0`, `s[-1] == B`.
layer_idx: Layer index of the weight matrices.
"""
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
layer_idx: Layer index of the weight matrices.
"""
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
# Custom SGMV shrink only supports rank 16, 32, 64, 128
_add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s, layer_idx, lora_rank)
_add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
return

tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_shrink(v, x, wa_ptr, s, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s, tmp2, layer_idx)
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)


def _add_lora_sgmv_cutlass_legacy(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s: torch.IntTensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
):
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_cutlass(v, x, wa_ptr, s, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s, tmp, layer_idx)
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
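
A plain PyTorch reference for the semantics documented in the docstring above, using per-segment weight tensors in place of the device pointers behind `wa_ptr`/`wb_ptr` (so no `layer_idx` indexing here); a sketch, not the kernel:

```python
from typing import List

import torch

def add_lora_sgmv_reference(
    y: torch.Tensor,         # [B, H2], updated in place
    x: torch.Tensor,         # [B, H1]
    wa: List[torch.Tensor],  # per-segment [R, H1] matrices
    wb: List[torch.Tensor],  # per-segment [R, H2] matrices
    s_start: torch.Tensor,   # [S] int32 segment start rows
    s_end: torch.Tensor,     # [S] int32 segment end rows
) -> None:
    # y[s_start[i]:s_end[i]] += x[s_start[i]:s_end[i]] @ wa[i].T @ wb[i]
    for i in range(s_start.size(0)):
        lo, hi = int(s_start[i]), int(s_end[i])
        y[lo:hi] += x[lo:hi] @ wa[i].T @ wb[i]
```
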
30 changes: 17 additions & 13 deletions server/punica_kernels/punica_kernels/punica_ops.cc
@@ -322,49 +322,53 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,

//====== sgmv ======

void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x,
torch::Tensor w_ptr, torch::Tensor s,
void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
torch::Tensor s_start, torch::Tensor s_end,
torch::Tensor tmp, int layer_idx) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w_ptr);
CHECK_INPUT(s);
CHECK_INPUT(s_start);
CHECK_INPUT(s_end);
CHECK_INPUT(tmp);

CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(1, w_ptr);
CHECK_DIM(1, s);
CHECK_DIM(1, s_start);
CHECK_DIM(1, s_end);
CHECK_DIM(1, tmp);

int num_problems = s.size(0) - 1;
int num_problems = s_start.size(0);
int d_in = x.size(1);
int d_out = y.size(1);
CHECK_EQ(tmp.size(0), static_cast<int64_t>(sgmv_tmp_size(num_problems)));
bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] {
return sgmv<c_type>((c_type*)y.data_ptr(), (c_type*)x.data_ptr(),
(c_type**)w_ptr.data_ptr(), s.data_ptr<int32_t>(),
return sgmv<c_type>((c_type*)y.data_ptr(), (c_type*)x.data_ptr(), (c_type**)w_ptr.data_ptr(),
s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(),
tmp.data_ptr<uint8_t>(), num_problems, d_in, d_out,
layer_idx);
});
TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type());
}

void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
torch::Tensor s, torch::Tensor tmp, int layer_idx) {
torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int layer_idx) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w_ptr);
CHECK_INPUT(s);
CHECK_INPUT(s_start);
CHECK_INPUT(s_end);
CHECK_INPUT(tmp);

CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(1, w_ptr);
CHECK_DIM(1, s);
CHECK_DIM(1, s_start);
CHECK_DIM(1, s_end);
CHECK_DIM(1, tmp);

uint32_t num_problems = s.size(0) - 1;
uint32_t num_problems = s_start.size(0);
uint32_t d_in = x.size(1);
uint32_t d_out = y.size(1);
CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte);
@@ -374,7 +378,7 @@ void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
case D_OUT: \
return sgmv_shrink<c_type, D_OUT>( \
(c_type*)y.data_ptr(), (c_type*)x.data_ptr(), \
(c_type**)w_ptr.data_ptr(), s.data_ptr<int32_t>(), \
(c_type**)w_ptr.data_ptr(), s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(), \
tmp.data_ptr<uint8_t>(), num_problems, d_in, layer_idx);

bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&] {
@@ -435,4 +439,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sgmv_cutlass_tmp_size", &sgmv_tmp_size, "");
m.def("sgmv_shrink", &dispatch_sgmv_shrink, "");
m.def("rms_norm", &dispatch_rms_norm, "");
}
}
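
The switch from a single indptr `s` (length `S + 1`) to explicit `s_start`/`s_end` (each length `S`) is what makes the SGMV "sparse": segments that belong to the base model can simply be left out, and the dispatch derives `num_problems` from `s_start.size(0)` rather than `s.size(0) - 1`. A small illustration with made-up boundaries:

```python
import torch

# Old interface: one indptr `s` of length S + 1 covering every row in the batch.
s = torch.tensor([0, 2, 5, 9], dtype=torch.int32)
num_problems_old = s.size(0) - 1

# Equivalent start/end form used by the new interface.
s_start, s_end = s[:-1], s[1:]

# The split form can also leave gaps: here rows 5..9 belong to the base model
# and are simply not listed as a problem at all.
sparse_start = torch.tensor([0, 2], dtype=torch.int32)
sparse_end = torch.tensor([2, 5], dtype=torch.int32)
num_problems_new = sparse_start.size(0)  # mirrors s_start.size(0) in the C++ dispatch
```
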
4 changes: 2 additions & 2 deletions server/punica_kernels/punica_kernels/sgmv/sgmv.h
@@ -1,5 +1,5 @@
template <typename DType>
bool sgmv(DType *y, DType *x, DType **w, int32_t *s, void *tmp_d,
int num_problems, int d_in, int d_out, int layer_idx);
bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx);

size_t sgmv_tmp_size(int num_problems);
8 changes: 5 additions & 3 deletions server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cu
@@ -3,10 +3,12 @@

#include "sgmv_cutlass.cuh"

template bool sgmv<nv_half>(nv_half *y, nv_half *x, nv_half **w, int32_t *s,
template bool sgmv<nv_half>(nv_half *y, nv_half *x, nv_half **w,
int32_t *s_start, int32_t *s_end,
void *tmp_d, int num_problems, int d_in, int d_out,
int layer_idx);

template bool sgmv<nv_bfloat16>(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w,
int32_t *s, void *tmp_d, int num_problems,
int d_in, int d_out, int layer_idx);
int32_t *s_start, int32_t *s_end,
void *tmp_d, int num_problems, int d_in, int d_out,
int layer_idx);
