[Kernel] Add extra punica sizes to support bigger vocabs #4015

Merged (2 commits) on Apr 11, 2024
csrc/punica/bgmv/bgmv_config.h (12 changes: 11 additions & 1 deletion)
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
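The nine added narrow/wide sizes follow one pattern: a base vocabulary size (64000, 102400, or 128000) plus a LoRA extra-vocab padding of 0, 256, or 512. A minimal sketch of that enumeration; reading the list this way is an interpretation of the values above, not something stated in the diff:

```python
# Sketch only: reproduce the nine new FOR_BGMV_WIDE entries as
# base vocab size + extra-vocab padding (interpretation, not PR code).
BASE_VOCAB_SIZES = [64000, 102400, 128000]
EXTRA_VOCAB_PADDINGS = [0, 256, 512]

new_sizes = sorted(base + extra
                   for base in BASE_VOCAB_SIZES
                   for extra in EXTRA_VOCAB_PADDINGS)
print(new_sizes)
# [64000, 64256, 64512, 102400, 102656, 102912, 128000, 128256, 128512]
```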
csrc/punica/punica_ops.cc (14 changes: 7 additions & 7 deletions)
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
}
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint64_t(a) << 32) | uint64_t(b);
}

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features,
uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) {
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \
case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
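launch_bgmv_kernel dispatches on the (in_features, out_features) pair packed into a single switch key. Feature sizes above 65535 no longer fit in a uint16_t, so the packing is widened from two 16-bit values in a uint32_t to two 32-bit values in a uint64_t, and the size guard is raised to the largest instantiated width, 128512. A short Python illustration of the collision the old packing would produce; it mirrors the arithmetic rather than the C++ code itself:

```python
# Illustration only: 128256 exceeds 16 bits, so a 16-bit packing truncates
# it modulo 65536 and two different widths can map to the same switch key.
def pack_u16(a: int, b: int) -> int:
    return ((a & 0xFFFF) << 16) | (b & 0xFFFF)

def pack_u32(a: int, b: int) -> int:
    return ((a & 0xFFFFFFFF) << 32) | (b & 0xFFFFFFFF)

assert pack_u16(16, 128256) == pack_u16(16, 128256 - 65536)  # collision
assert pack_u32(16, 128256) != pack_u32(16, 128256 - 65536)  # unique key
```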
tests/lora/test_layers.py (78 changes: 44 additions & 34 deletions)
@@ -170,7 +170,8 @@ def create_random_inputs(
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:

torch.set_default_device(device)
max_loras = 8
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)

def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data)
embedding.weight.data[512:, :] = 0
embedding.weight.data[vocab_size:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config)

@@ -203,12 +204,13 @@ def create_random_embedding_layer():
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)

lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def create_random_embedding_layer():
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def create_random_embedding_layer():
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size) -> None:

torch.set_default_device(device)
max_loras = 8
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)

def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding = VocabParallelEmbedding(vocab_size, 256)
embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data
embedding.weight.data[512:, :] = 0
embedding.weight.data[vocab_size:, :] = 0
expanded_embedding = VocabParallelEmbedding(
512 + lora_config.lora_extra_vocab_size * max_loras,
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
256,
org_num_embeddings=512)
expanded_embedding.weight.data[:512, :] = embedding_data
org_num_embeddings=vocab_size)
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def create_random_embedding_layer():
id_to_index,
layer=lora_embedding,
layer_weights=torch.zeros(
(256, 512 + lora_config.lora_extra_vocab_size)),
(256, vocab_size + lora_config.lora_extra_vocab_size)),
generate_embeddings_tensor=256,
)

@@ -316,7 +321,7 @@ def create_random_embedding_layer():
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

@@ -327,16 +332,18 @@ def create_random_embedding_layer():
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
embedding_id = lora_id - 1
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
original_input_[-1] = 512
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = 512 + embeddings_tensor_len - 1
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
original_input_[-1] = vocab_size
input_[-2] = vocab_size + (
(embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

expanded_embedding.weight[512:512 +
expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors)

@@ -370,14 +377,15 @@ def create_random_embedding_layer():
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

original_inputs = deepcopy(inputs)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )

lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def create_random_embedding_layer():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device,
vocab_size) -> None:

torch.set_default_device(device)
max_loras = 8
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)

def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000)
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
1024, vocab_size)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0
linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor(
32000 + lora_config.lora_extra_vocab_size, 32000)
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def _pretest():
lora_mapping,
id_to_index,
max_loras,
32000,
vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,19 +470,19 @@ def _pretest():
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor

logits_processor.org_vocab_size = (32000 +
logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size)
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = 32000
logits_processor.org_vocab_size = vocab_size

# Check that resetting the lora weights succeeds

@@ -489,14 +499,14 @@ def _pretest():
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000,
vocab_size,
lora_config.lora_extra_vocab_size)
lora_logits_processor.set_mapping(*mapping_info, )

lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
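In the new-embeddings test, the last two token ids of each prompt are moved into the LoRA extra-vocab region: with embedding_id = lora_id - 1, adapter k owns the ids from vocab_size + k * embeddings_tensor_len up to (but not including) vocab_size + (k + 1) * embeddings_tensor_len. A small helper that spells out this index arithmetic; the function is illustrative and not part of the test:

```python
def extra_vocab_range(vocab_size: int, embedding_id: int,
                      embeddings_tensor_len: int) -> range:
    """Token ids owned by one adapter's added embeddings (illustrative)."""
    start = vocab_size + embedding_id * embeddings_tensor_len
    return range(start, start + embeddings_tensor_len)

# With vocab_size=128000 and embeddings_tensor_len=256 (the test's
# generate_embeddings_tensor value), adapter 0 owns ids 128000..128255
# and adapter 1 owns ids 128256..128511.
```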
tests/lora/test_punica.py (49 changes: 45 additions & 4 deletions)
@@ -43,10 +43,51 @@ def _lora_ref_impl(


H1 = H2 = [
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
128,
256,
512,
1024,
1152,
1280,
1536,
2048,
2304,
2560,
2752,
3072,
3456,
3584,
4096,
4608,
5120,
5504,
5632,
6144,
6848,
6912,
7168,
8192,
9216,
10240,
11008,
13824,
14336,
22016,
24576,
27392,
32000,
32256,
32512,
32768,
33024,
36864,
49152,
64000,
64256,
102400,
102656,
128000,
128256,
]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
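H1 and H2 are the input/output feature sizes the punica bgmv kernels are tested against, now extended through 128256 to cover the new vocabulary widths. Conceptually the operation under test is per-token LoRA: each batch element selects one adapter by index and accumulates x @ A @ B, scaled, into its output row. A plain-PyTorch reference written from that description; it is not a copy of _lora_ref_impl and the shapes are assumptions:

```python
import torch

def bgmv_ref(y: torch.Tensor, x: torch.Tensor, w_a: torch.Tensor,
             w_b: torch.Tensor, lora_indices: torch.Tensor,
             scale: float) -> torch.Tensor:
    """y[i] += scale * x[i] @ A[idx] @ B[idx], skipping idx < 0 (no LoRA).

    Assumed shapes: x (batch, h_in), y (batch, h_out),
    w_a (num_loras, h_in, rank), w_b (num_loras, rank, h_out),
    lora_indices (batch,).
    """
    for i, idx in enumerate(lora_indices.tolist()):
        if idx >= 0:
            y[i] += scale * (x[i] @ w_a[idx] @ w_b[idx])
    return y
```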
vllm/lora/layers.py (4 changes: 2 additions & 2 deletions)
@@ -939,9 +939,9 @@ def create_lora_weights(
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024:
if 32000 < self.base_layer.vocab_size > 128512:
raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 33024")
"32000 >= vocab_size <= 128512")
self.lora_a_stacked = torch.zeros(
(
max_loras,
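A note on the guard above: in Python, 32000 < self.base_layer.vocab_size > 128512 chains to 32000 < vocab_size and vocab_size > 128512, so the ValueError fires only for vocabularies above 128512 and a vocabulary at or below 32000 passes silently, while the message's "32000 >= vocab_size <= 128512" reads like a typo for a <= lower bound. If the intent is to enforce the full range the kernels are instantiated for, a guard along these lines would express it; reading that as the intent is an assumption:

```python
# Hypothetical variant, not code from this PR: enforce the full range that
# csrc/punica/bgmv/bgmv_config.h instantiates kernels for.
if not (32000 <= self.base_layer.vocab_size <= 128512):
    raise ValueError("When using LoRA, vocab size must satisfy "
                     "32000 <= vocab_size <= 128512")
```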