
Commit 5c482b6

zongwave and yanfeich authored
[INTEL_HPU] Enable FusedBlockMultiTransformerHPU (#10514)
* add FusedBlockMultiTransformerHPU and prepare_input_hpu
* attention output post-process
* FusedBlockMultiTransformerHPU optimization (for/if-else)
* kv_head update for HPU cache shape
* fix issue where the benchmark can't run with a bigger batch size

Co-authored-by: yanfeich <yanfei.cheng@intel.com>
1 parent 6bdb716 · commit 5c482b6

File tree

3 files changed: +245, -22 lines changed


llm/predict/predictor.py

Lines changed: 5 additions & 2 deletions
@@ -833,7 +833,10 @@ def __init__(
         BasePredictor.__init__(self, config, tokenizer, model)

         self.num_layers = len(self.cache_k_shapes)
-        self.num_key_value_heads = self.cache_k_shapes[0][-3]
+        if paddle.is_compiled_with_custom_device("intel_hpu"):
+            self.num_key_value_heads = self.cache_k_shapes[0][-2]
+        else:
+            self.num_key_value_heads = self.cache_k_shapes[0][-3]
         self.head_dim = self.cache_k_shapes[0][-1]
         self.max_block_nums = self.cache_k_shapes[0][0]
         self.batch_size = config.batch_size
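
On Intel HPU the block-wise KV cache uses the layout [max_block_nums, block_size, num_kv_heads, head_dim] instead of the default [max_block_nums, num_kv_heads, block_size, head_dim] (see the get_cache_kvs_shape change below), so the KV-head count moves from index -3 to index -2. A minimal sketch of that axis arithmetic, using hypothetical shape values:

# Sketch only (not PaddleNLP code); the shape values below are hypothetical.
def num_kv_heads_from_cache_shape(cache_k_shape, is_hpu):
    # default layout: [max_block_nums, num_kv_heads, block_size, head_dim] -> heads at [-3]
    # HPU layout:     [max_block_nums, block_size, num_kv_heads, head_dim] -> heads at [-2]
    return cache_k_shape[-2] if is_hpu else cache_k_shape[-3]

assert num_kv_heads_from_cache_shape([1024, 8, 64, 128], is_hpu=False) == 8
assert num_kv_heads_from_cache_shape([1024, 64, 8, 128], is_hpu=True) == 8
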
@@ -1215,7 +1218,7 @@ def predict_via_mq(self, input_texts: list[str], return_tokens=False):
         outputs = []
         output_tokens = []
         while len(outputs) < len(input_texts):
-            result = result_queue.get(timeout=1)
+            result = result_queue.get(timeout=10)
             outputs.append(result[-1])
             output_tokens.append(result[-2])
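
The larger timeout gives the worker processes more slack per poll when decoding bigger batches. A common alternative pattern (a sketch, not this repository's code; the helper name and retry count are made up) is to retry on queue.Empty rather than failing on the first slow poll:

import queue

def get_result_with_retry(result_queue, timeout=10, retries=3):
    # Hypothetical helper: poll a (multiprocessing) result queue, retrying on Empty.
    for _ in range(retries):
        try:
            return result_queue.get(timeout=timeout)
        except queue.Empty:
            continue
    raise TimeoutError("no result from predictor workers")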

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 174 additions & 0 deletions
@@ -6765,3 +6765,177 @@ def forward(
         out = src[:, -1:, :]
         out = out.squeeze(axis=1)
         return out, caches
+
+
+class FusedBlockMultiTransformerHPU(FusedBlockMultiTransformer):
+    def __init__(self, config: FusedMultiTransformerConfig):
+        super().__init__(config)
+
+        self.config = config
+
+    def forward(
+        self,
+        input_ids,
+        src,
+        cum_offsets=None,
+        padding_offset=None,
+        attn_mask=None,
+        caches=None,
+        pre_caches=None,
+        pre_caches_length=0,
+        rotary_embs=None,
+        rotary_emb_dims=0,
+        seq_lens=None,
+        time_step=None,
+        **kwargs,
+    ):
+        if caches is not None:
+            assert len(caches) == len(self.qkv_weights) or len(caches) == 2 * len(self.qkv_weights)
+
+        assert self.num_layers == len(self.qkv_weights)
+
+        seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
+        seq_lens_decoder = kwargs.get("seq_lens_decoder", None)
+        block_size = kwargs.get("block_size", None)
+        block_indices = kwargs.get("block_indices", None)
+        block_groups = kwargs.get("block_groups", None)
+        block_list = kwargs.get("block_list", None)
+        block_offsets = kwargs.get("block_offsets", None)
+        block_mapping = kwargs.get("block_mapping", None)
+        block_bias = kwargs.get("block_bias", None)
+
+        max_enc_len = paddle.max(seq_lens_encoder, axis=0).item()
+        max_dec_len = paddle.max(seq_lens_decoder, axis=0).item()
+
+        if len(src.shape) == 2:
+            src = src.unsqueeze(axis=1)
+
+        import paddlenlp_ops
+
+        if max_enc_len > 0:  # context
+            for i in range(self.num_layers):
+                residual_input = src
+                query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
+                    src,
+                    self.ln_scales[i],
+                    self.qkv_weights[i],
+                    rotary_embs,
+                    self._epsilon,
+                    self.head_dim,
+                    self.num_heads,
+                )
+                # Fused-OP-1 end
+
+                # Fused-OP-2 start
+                # write cache kv (inplace)
+                # [2, 8, 384, 32, 128]
+                # [2, 8, 6, 64, 32, 128]
+                # [2, 48, 64, 32, 128]
+                # --> [64][512, 64, 32, 128]
+                kv, B, BP_BS, M, H = key_value_states.shape
+                key_value_states_reshape = key_value_states.reshape([kv, -1, block_size, M, H])
+                key_states = key_value_states_reshape[0]
+                value_states = key_value_states_reshape[1]
+                k_cache = caches[2 * i]
+                v_cache = caches[2 * i + 1]
+                paddlenlp_ops.index_copy_(k_cache, block_indices, key_states, 0)
+                paddlenlp_ops.index_copy_(v_cache, block_indices, value_states, 0)
+
+                out_linear_out = paddlenlp_ops.fused_sdpa_proj_t(
+                    query_states,
+                    key_value_states,
+                    attn_mask,
+                    None,
+                    self.linear_weights[i],
+                    scaling_factor=self.head_dim**-0.5,
+                    causal=True,
+                )
+                # Fused-OP-2 end
+
+                # all_reduce
+                if self.tp_degree > 1:
+                    dist.all_reduce(out_linear_out)
+                out_linear_out = residual_input + out_linear_out
+                residual_input = out_linear_out
+
+                # Fused-OP-4 start
+                ffn2_out = paddlenlp_ops.fused_rms_mlp(
+                    out_linear_out,
+                    self.ffn_ln_scales[i],
+                    self.ffn1_weights[i],
+                    self.ffn2_weights[i],
+                    self._epsilon,
+                )
+                # Fused-OP-4 end
+
+                # all_reduce
+                if self.tp_degree > 1:
+                    dist.all_reduce(ffn2_out)
+                src = residual_input + ffn2_out
+            # end LlamaDecoderLayer
+        elif max_dec_len > 0:
+            for i in range(self.num_layers):
+                residual_input = src
+                query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
+                    src,
+                    self.ln_scales[i],
+                    self.qkv_weights[i],
+                    rotary_embs,
+                    self._epsilon,
+                    self.head_dim,
+                    self.num_heads,
+                )
+                # Fused-OP-1 end
+
+                # Fused-OP-2 start
+                # write cache kv (inplace)
+                # [2, B, 1, 32, 128] --> [64][max_block_num, blk_size, 32, 128]
+                key_states = key_value_states[0].squeeze(1)
+                value_states = key_value_states[1].squeeze(1)
+                k_cache = caches[2 * i]
+                v_cache = caches[2 * i + 1]
+                k_cache.index_put_((block_indices, block_offsets), key_states)
+                v_cache.index_put_((block_indices, block_offsets), value_states)
+
+                out_linear_out = paddlenlp_ops.fused_flatpa_proj(
+                    query_states,
+                    caches[2 * i],
+                    caches[2 * i + 1],
+                    block_groups,
+                    block_list,
+                    block_mapping,
+                    block_bias,
+                    self.linear_weights[i],
+                    scaling_factor=self.head_dim**-0.5,
+                )
+                # Fused-OP-2 end
+
+                # all_reduce
+                if self.tp_degree > 1:
+                    dist.all_reduce(out_linear_out)
+                out_linear_out = residual_input + out_linear_out
+                residual_input = out_linear_out
+
+                # Fused-OP-4 start
+                ffn2_out = paddlenlp_ops.fused_rms_mlp(
+                    out_linear_out,
+                    self.ffn_ln_scales[i],
+                    self.ffn1_weights[i],
+                    self.ffn2_weights[i],
+                    self._epsilon,
+                )
+                # Fused-OP-4 end
+
+                # all_reduce
+                if self.tp_degree > 1:
+                    dist.all_reduce(ffn2_out)
+                src = residual_input + ffn2_out
+            # end LlamaDecoderLayer
+
+        kwargs["time_step"] = time_step
+        kwargs["multi_block_output"] = src
+        kwargs["seq_lens"] = seq_lens
+        kwargs["input_ids"] = input_ids
+
+        out = self.post_process(**kwargs)
+        return out, caches
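
The forward pass above dispatches on max_enc_len / max_dec_len: the context (prefill) branch reshapes the packed K/V into block_size chunks and copies whole blocks into the cache along dim 0, while the decode branch writes a single new token per sequence at (block_indices, block_offsets). Below is a hedged NumPy stand-in for the two write patterns; it is not the paddlenlp_ops kernels, and the cache geometry and block numbers are made up for illustration:

import numpy as np

# Hypothetical cache geometry: 16 blocks of 4 tokens, 2 KV heads, head_dim 8.
max_block_nums, block_size, num_kv_heads, head_dim = 16, 4, 2, 8
k_cache = np.zeros((max_block_nums, block_size, num_kv_heads, head_dim), dtype=np.float32)

# Prefill: an 8-token sequence that owns cache blocks 3 and 4.
key_states = np.random.rand(8, num_kv_heads, head_dim).astype(np.float32)  # [tokens, heads, head_dim]
key_blocks = key_states.reshape(-1, block_size, num_kv_heads, head_dim)    # [n_blocks, block_size, heads, head_dim]
block_indices = np.array([3, 4])
k_cache[block_indices] = key_blocks   # ~ paddlenlp_ops.index_copy_(k_cache, block_indices, key_states, 0)

# Decode: the 9th token starts a new block (say block 7), slot 8 % block_size == 0.
new_key = np.random.rand(num_kv_heads, head_dim).astype(np.float32)
k_cache[7, 0] = new_key               # ~ k_cache.index_put_((block_indices, block_offsets), key_states)
assert np.allclose(k_cache[3], key_blocks[0]) and np.allclose(k_cache[7, 0], new_key)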

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 66 additions & 20 deletions
@@ -34,6 +34,7 @@
     FusedBlockMultiTransformer,
     FusedBlockMultiTransformerA8W8,
     FusedBlockMultiTransformerFP8,
+    FusedBlockMultiTransformerHPU,
     FusedBlockMultiTransformerWeightOnly,
     FusedMultiTransformerA8W8,
     FusedMultiTransformerAvx,
@@ -1406,6 +1407,8 @@ def set_transformer_block(self, transformer_config):
             self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config)
         elif "fp8" in self.quant_type:
             self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config)
+        elif paddle.is_compiled_with_custom_device("intel_hpu"):
+            self.transformer_block = FusedBlockMultiTransformerHPU(transformer_config)
         else:
             self.transformer_block = FusedBlockMultiTransformer(transformer_config)

@@ -1433,24 +1436,62 @@ def forward(
     ):
         seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
         rope_emb = kwargs.get("rope_emb", None)
-        draft_tokens = kwargs.get("draft_tokens", None)
-        seq_lens_encoder = kwargs.get("seq_lens_encoder", None)

-        # whether speculative decoding or not
-        if draft_tokens is None:
-            ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
-                input_ids, seq_lens_this_time
+        if paddle.is_compiled_with_custom_device("intel_hpu"):
+            from paddlenlp_ops import prepare_input_hpu
+
+            block_tables = kwargs.get("block_tables", None).to("CPU")
+            seq_lens_encoder = kwargs.get("seq_lens_encoder", None).to("CPU")
+            seq_lens_decoder = kwargs.get("seq_lens_decoder", None).to("CPU")
+            input_ids = input_ids.to("CPU")
+
+            (
+                ids_remove_padding,
+                rope_emb,
+                block_groups,
+                block_list,
+                block_indices,
+                block_offsets,
+                block_mapping,
+                attention_mask,
+                valid_seq_len,
+            ) = prepare_input_hpu(
+                input_ids,
+                rope_emb,
+                block_tables,
+                self.block_size,
+                seq_lens_encoder,
+                seq_lens_decoder,
+                paddle.get_default_dtype(),
             )
+            cum_offsets = None
+            kwargs["block_groups"] = block_groups
+            kwargs["block_list"] = block_list
+            kwargs["block_indices"] = block_indices
+            kwargs["block_offsets"] = block_offsets
+            kwargs["block_mapping"] = block_mapping
+            kwargs["block_bias"] = attention_mask
+            kwargs["block_size"] = self.block_size
+            kwargs["valid_seq_len"] = valid_seq_len
         else:
-            ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
-                input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder
-            )
+            draft_tokens = kwargs.get("draft_tokens", None)
+            seq_lens_encoder = kwargs.get("seq_lens_encoder", None)

-        kwargs["cu_seqlens_q"] = cu_seqlens_q
-        kwargs["cu_seqlens_k"] = cu_seqlens_k
-        kwargs["padding_offsets"] = padding_offset
-        kwargs["max_input_length"] = self.max_seq_len
-        kwargs["block_size"] = self.block_size
+            # whether speculative decoding or not
+            if draft_tokens is None:
+                ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
+                    input_ids, seq_lens_this_time
+                )
+            else:
+                ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
+                    input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder
+                )
+
+            kwargs["cu_seqlens_q"] = cu_seqlens_q
+            kwargs["cu_seqlens_k"] = cu_seqlens_k
+            kwargs["padding_offsets"] = padding_offset
+            kwargs["max_input_length"] = self.max_seq_len
+            kwargs["block_size"] = self.block_size

         inputs_embeds = self.embed_tokens(ids_remove_padding)

@@ -1958,15 +1999,20 @@ def get_cache_kvs_shape(
         else:
             max_block_nums = max_batch_size * max_block_per_seq

+        cache_kv_shape = [
+            max_block_nums,
+            config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
+            config.block_size,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        if paddle.is_compiled_with_custom_device("intel_hpu"):
+            # HPU block multi-transformer
+            # Use KV Cache shape [max_block_nums, seq_len, num_head, head_dim]
+            cache_kv_shape = [cache_kv_shape[i] for i in [0, 2, 1, 3]]
+
         cache_k_shapes = []
         cache_v_shapes = []
         for _ in range(config.num_hidden_layers):
-            cache_kv_shape = [
-                max_block_nums,
-                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
-                config.block_size,
-                config.hidden_size // config.num_attention_heads,
-            ]
             cache_k_shapes.append(cache_kv_shape)
             cache_v_shapes.append(cache_kv_shape)
         return cache_k_shapes, cache_v_shapes
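
The [0, 2, 1, 3] permutation above just swaps the head and sequence axes so the per-block cache becomes [max_block_nums, block_size, num_kv_heads, head_dim] on HPU. A tiny sketch with hypothetical config values:

# Hypothetical: 1024 blocks, 8 KV heads, tensor_parallel_degree 1, block_size 64, head_dim 128.
cache_kv_shape = [1024, 8 // max(1, 1), 64, 128]                # default: [blocks, kv_heads, block_size, head_dim]
hpu_cache_kv_shape = [cache_kv_shape[i] for i in [0, 2, 1, 3]]  # HPU:     [blocks, block_size, kv_heads, head_dim]
assert hpu_cache_kv_shape == [1024, 64, 8, 128]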
