[INTEL_HPU] Enable FusedBlockMultiTransformerHPU #10514

Merged
1 commit merged on May 19, 2025

7 changes: 5 additions & 2 deletions llm/predict/predictor.py
@@ -833,7 +833,10 @@ def __init__(
BasePredictor.__init__(self, config, tokenizer, model)

self.num_layers = len(self.cache_k_shapes)
self.num_key_value_heads = self.cache_k_shapes[0][-3]
if paddle.is_compiled_with_custom_device("intel_hpu"):
self.num_key_value_heads = self.cache_k_shapes[0][-2]
else:
self.num_key_value_heads = self.cache_k_shapes[0][-3]
self.head_dim = self.cache_k_shapes[0][-1]
self.max_block_nums = self.cache_k_shapes[0][0]
self.batch_size = config.batch_size
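
Note on the indexing change above: the block KV cache built for Intel HPU uses a different axis order than the default layout (see the get_cache_kvs_shape change in modeling.py further down), so the key/value head count moves from index -3 to index -2 of the cache shape. A minimal sketch of the two layouts, with hypothetical sizes:

# Hypothetical sizes, for illustration only.
max_block_nums, num_kv_heads, block_size, head_dim = 16, 8, 64, 128

# Default block KV-cache layout (see get_cache_kvs_shape in modeling.py):
default_shape = [max_block_nums, num_kv_heads, block_size, head_dim]

# HPU layout, obtained by permuting the default shape with [0, 2, 1, 3]:
hpu_shape = [default_shape[i] for i in [0, 2, 1, 3]]
# -> [max_block_nums, block_size, num_kv_heads, head_dim]

# Hence the kv-head count sits at index -3 by default but at index -2 on intel_hpu:
assert default_shape[-3] == hpu_shape[-2] == num_kv_heads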
@@ -1215,7 +1218,7 @@ def predict_via_mq(self, input_texts: list[str], return_tokens=False):
outputs = []
output_tokens = []
while len(outputs) < len(input_texts):
result = result_queue.get(timeout=1)
result = result_queue.get(timeout=10)
outputs.append(result[-1])
output_tokens.append(result[-2])
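
The second change in this file raises the per-poll timeout on the result queue from 1 s to 10 s. queue.get raises queue.Empty once the timeout expires, so a larger timeout only tolerates longer gaps between finished results (for example, slower prefill on a new backend) without changing the loop logic. A minimal sketch of that polling pattern (the function name and signature are hypothetical):

import queue

def drain(result_queue, expected, per_get_timeout=10):
    # Collect `expected` results, allowing up to `per_get_timeout` seconds between them.
    outputs, output_tokens = [], []
    while len(outputs) < expected:
        # Raises queue.Empty if nothing arrives within the timeout.
        result = result_queue.get(timeout=per_get_timeout)
        outputs.append(result[-1])
        output_tokens.append(result[-2])
    return outputs, output_tokens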

174 changes: 174 additions & 0 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -6765,3 +6765,177 @@ def forward(
out = src[:, -1:, :]
out = out.squeeze(axis=1)
return out, caches


class FusedBlockMultiTransformerHPU(FusedBlockMultiTransformer):
def __init__(self, config: FusedMultiTransformerConfig):
super().__init__(config)

self.config = config

def forward(
self,
input_ids,
src,
cum_offsets=None,
padding_offset=None,
attn_mask=None,
caches=None,
pre_caches=None,
pre_caches_length=0,
rotary_embs=None,
rotary_emb_dims=0,
seq_lens=None,
time_step=None,
**kwargs,
):
if caches is not None:
assert len(caches) == len(self.qkv_weights) or len(caches) == 2 * len(self.qkv_weights)

assert self.num_layers == len(self.qkv_weights)

seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
seq_lens_decoder = kwargs.get("seq_lens_decoder", None)
block_size = kwargs.get("block_size", None)
block_indices = kwargs.get("block_indices", None)
block_groups = kwargs.get("block_groups", None)
block_list = kwargs.get("block_list", None)
block_offsets = kwargs.get("block_offsets", None)
block_mapping = kwargs.get("block_mapping", None)
block_bias = kwargs.get("block_bias", None)

max_enc_len = paddle.max(seq_lens_encoder, axis=0).item()
max_dec_len = paddle.max(seq_lens_decoder, axis=0).item()

if len(src.shape) == 2:
src = src.unsqueeze(axis=1)

import paddlenlp_ops

if max_enc_len > 0: # context
for i in range(self.num_layers):
residual_input = src
query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
src,
self.ln_scales[i],
self.qkv_weights[i],
rotary_embs,
self._epsilon,
self.head_dim,
self.num_heads,
)
# Fused-OP-1 end

# Fused-OP-2 start
# write cache kv (inplace)
# [2, 8, 384, 32, 128]
# [2, 8, 6, 64, 32, 128]
# [2, 48, 64, 32, 128]
# --> [64][512, 64, 32, 128]
kv, B, BP_BS, M, H = key_value_states.shape
key_value_states_reshape = key_value_states.reshape([kv, -1, block_size, M, H])
key_states = key_value_states_reshape[0]
value_states = key_value_states_reshape[1]
k_cache = caches[2 * i]
v_cache = caches[2 * i + 1]
paddlenlp_ops.index_copy_(k_cache, block_indices, key_states, 0)
paddlenlp_ops.index_copy_(v_cache, block_indices, value_states, 0)

out_linear_out = paddlenlp_ops.fused_sdpa_proj_t(
query_states,
key_value_states,
attn_mask,
None,
self.linear_weights[i],
scaling_factor=self.head_dim**-0.5,
causal=True,
)
# Fused-OP-2 end

# all_reduce
if self.tp_degree > 1:
dist.all_reduce(out_linear_out)
out_linear_out = residual_input + out_linear_out
residual_input = out_linear_out

# Fused-OP-4 start
ffn2_out = paddlenlp_ops.fused_rms_mlp(
out_linear_out,
self.ffn_ln_scales[i],
self.ffn1_weights[i],
self.ffn2_weights[i],
self._epsilon,
)
# Fused-OP-4 end

# all_reduce
if self.tp_degree > 1:
dist.all_reduce(ffn2_out)
src = residual_input + ffn2_out
# end LlamaDecoderLayer
elif max_dec_len > 0:
for i in range(self.num_layers):
residual_input = src
query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
src,
self.ln_scales[i],
self.qkv_weights[i],
rotary_embs,
self._epsilon,
self.head_dim,
self.num_heads,
)
# Fused-OP-1 end

# Fused-OP-2 start
# write cache kv (inplace)
# [2, B, 1, 32, 128] --> [64][max_block_num, blk_size, 32, 128]
key_states = key_value_states[0].squeeze(1)
value_states = key_value_states[1].squeeze(1)
k_cache = caches[2 * i]
v_cache = caches[2 * i + 1]
k_cache.index_put_((block_indices, block_offsets), key_states)
v_cache.index_put_((block_indices, block_offsets), value_states)

out_linear_out = paddlenlp_ops.fused_flatpa_proj(
query_states,
caches[2 * i],
caches[2 * i + 1],
block_groups,
block_list,
block_mapping,
block_bias,
self.linear_weights[i],
scaling_factor=self.head_dim**-0.5,
)
# Fused-OP-2 end

# all_reduce
if self.tp_degree > 1:
dist.all_reduce(out_linear_out)
out_linear_out = residual_input + out_linear_out
residual_input = out_linear_out

# Fused-OP-4 start
ffn2_out = paddlenlp_ops.fused_rms_mlp(
out_linear_out,
self.ffn_ln_scales[i],
self.ffn1_weights[i],
self.ffn2_weights[i],
self._epsilon,
)
# Fused-OP-4 end

# all_reduce
if self.tp_degree > 1:
dist.all_reduce(ffn2_out)
src = residual_input + ffn2_out
# end LlamaDecoderLayer

kwargs["time_step"] = time_step
kwargs["multi_block_output"] = src
kwargs["seq_lens"] = seq_lens
kwargs["input_ids"] = input_ids

out = self.post_process(**kwargs)
return out, caches
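
Both cache-write paths above rely on custom kernels (paddlenlp_ops.index_copy_ in the context branch, Tensor.index_put_ in the decode branch) whose semantics are not spelled out in this diff. The sketch below emulates what they are assumed to do using stock Paddle ops against the HPU cache layout [max_block_nums, block_size, num_kv_heads, head_dim]; it is an illustration under that assumption, not the actual kernels:

import paddle

# Hypothetical sizes, kept small for illustration only.
num_blocks, block_size, num_kv_heads, head_dim = 8, 4, 2, 8
k_cache = paddle.zeros([num_blocks, block_size, num_kv_heads, head_dim])

# Context/prefill path: index_copy_(k_cache, block_indices, key_states, 0) is assumed
# to copy whole blocks into the cache along dim 0. A non-inplace equivalent:
block_indices = paddle.to_tensor([3, 5])                      # blocks owned by the batch
key_states = paddle.rand([2, block_size, num_kv_heads, head_dim])
k_cache = paddle.scatter(k_cache, block_indices, key_states)  # overwrite is the default

# Decode path: k_cache.index_put_((block_indices, block_offsets), key_states) is assumed
# to write one token per sequence at (block, slot). Emulated through a flattened view:
block_offsets = paddle.to_tensor([1, 0])                      # slot inside each block
new_tokens = paddle.rand([2, num_kv_heads, head_dim])
flat = k_cache.reshape([num_blocks * block_size, num_kv_heads, head_dim])
flat = paddle.scatter(flat, block_indices * block_size + block_offsets, new_tokens)
k_cache = flat.reshape([num_blocks, block_size, num_kv_heads, head_dim])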
86 changes: 66 additions & 20 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -34,6 +34,7 @@
FusedBlockMultiTransformer,
FusedBlockMultiTransformerA8W8,
FusedBlockMultiTransformerFP8,
FusedBlockMultiTransformerHPU,
FusedBlockMultiTransformerWeightOnly,
FusedMultiTransformerA8W8,
FusedMultiTransformerAvx,
@@ -1406,6 +1407,8 @@ def set_transformer_block(self, transformer_config):
self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config)
elif "fp8" in self.quant_type:
self.transformer_block = FusedBlockMultiTransformerFP8(transformer_config)
elif paddle.is_compiled_with_custom_device("intel_hpu"):
self.transformer_block = FusedBlockMultiTransformerHPU(transformer_config)
else:
self.transformer_block = FusedBlockMultiTransformer(transformer_config)

@@ -1433,24 +1436,62 @@ def forward(
):
seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
rope_emb = kwargs.get("rope_emb", None)
draft_tokens = kwargs.get("draft_tokens", None)
seq_lens_encoder = kwargs.get("seq_lens_encoder", None)

# whether speculative decoding or not
if draft_tokens is None:
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
input_ids, seq_lens_this_time
if paddle.is_compiled_with_custom_device("intel_hpu"):
from paddlenlp_ops import prepare_input_hpu

block_tables = kwargs.get("block_tables", None).to("CPU")
seq_lens_encoder = kwargs.get("seq_lens_encoder", None).to("CPU")
seq_lens_decoder = kwargs.get("seq_lens_decoder", None).to("CPU")
input_ids = input_ids.to("CPU")

(
ids_remove_padding,
rope_emb,
block_groups,
block_list,
block_indices,
block_offsets,
block_mapping,
attention_mask,
valid_seq_len,
) = prepare_input_hpu(
input_ids,
rope_emb,
block_tables,
self.block_size,
seq_lens_encoder,
seq_lens_decoder,
paddle.get_default_dtype(),
)
cum_offsets = None
kwargs["block_groups"] = block_groups
kwargs["block_list"] = block_list
kwargs["block_indices"] = block_indices
kwargs["block_offsets"] = block_offsets
kwargs["block_mapping"] = block_mapping
kwargs["block_bias"] = attention_mask
kwargs["block_size"] = self.block_size
kwargs["valid_seq_len"] = valid_seq_len
else:
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder
)
draft_tokens = kwargs.get("draft_tokens", None)
seq_lens_encoder = kwargs.get("seq_lens_encoder", None)

kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len
kwargs["block_size"] = self.block_size
# whether speculative decoding or not
if draft_tokens is None:
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
input_ids, seq_lens_this_time
)
else:
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder
)

kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len
kwargs["block_size"] = self.block_size

inputs_embeds = self.embed_tokens(ids_remove_padding)
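
prepare_input_hpu is a custom op, so its exact outputs are not spelled out here. Judging by how block_indices and block_offsets are later consumed by index_put_ in the decode path, they appear to follow standard paged-KV bookkeeping: the per-sequence block table maps logical blocks to physical cache blocks, and the write slot is the token's position inside its current block. A minimal sketch of that mapping, with hypothetical numbers (this is an assumption about the op, not its implementation):

import paddle

# One sequence with 5 tokens already cached, block_size = 4, and a block table that
# maps its logical blocks 0 and 1 to physical cache blocks 7 and 2.
block_size = 4
block_table = paddle.to_tensor([7, 2])
seq_len_decoder = 5

# The next token (position 5) lands in logical block 1, slot 1 of that block:
logical_block = seq_len_decoder // block_size   # 1
block_index = int(block_table[logical_block])   # physical block 2
block_offset = seq_len_decoder % block_size     # 1
print(block_index, block_offset)                # 2 1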

@@ -1958,15 +1999,20 @@ def get_cache_kvs_shape(
else:
max_block_nums = max_batch_size * max_block_per_seq

cache_kv_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.hidden_size // config.num_attention_heads,
]
if paddle.is_compiled_with_custom_device("intel_hpu"):
# HPU block multi-transformer
# Use KV Cache shape [max_block_nums, seq_len, num_head, head_dim]
cache_kv_shape = [cache_kv_shape[i] for i in [0, 2, 1, 3]]

cache_k_shapes = []
cache_v_shapes = []
for _ in range(config.num_hidden_layers):
cache_kv_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.hidden_size // config.num_attention_heads,
]
cache_k_shapes.append(cache_kv_shape)
cache_v_shapes.append(cache_kv_shape)
return cache_k_shapes, cache_v_shapes