
Commit 33f6064

Merge branch 'develop' of https://github.com/gzy19990617/PaddleNLP into develop

2 parents: 131bb5b + d3ee14f

8 files changed: +429 −273 lines

csrc/gpu/set_preids_token_penalty_multi_scores.cu

Lines changed: 5 additions & 5 deletions

```diff
@@ -41,11 +41,11 @@ __global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags
   T *logits_now = logits + bi * length;
   int tid = threadIdx.x;
 
-  if (tid < bs && !stop_flags[tid]) {
-    int64_t *pre_ids_now = pre_ids + tid * length_id;
-    const int64_t *input_ids_now = input_ids + tid * length_input_ids;
-    const int seq_len_dec = seq_lens_decoder[tid];
-    const int seq_len_enc = seq_lens_encoder[tid];
+  if (bi < bs && !stop_flags[bi]) {
+    int64_t *pre_ids_now = pre_ids + bi * length_id;
+    const int64_t *input_ids_now = input_ids + bi * length_input_ids;
+    const int seq_len_dec = seq_lens_decoder[bi];
+    const int seq_len_enc = seq_lens_encoder[bi];
     if (seq_len_dec == 0 && seq_len_enc == 0) return;  // stopped
 
     const int step_idx_now = step_idx[bi];
```
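For context: per-sequence arrays in this kernel must be indexed by the batch index `bi` (set earlier in the kernel, as `logits_now = logits + bi * length` suggests), not by `threadIdx.x`; guarding on `tid < bs` also made every thread in the block repeat the update against the wrong rows. A minimal Python sketch of the intended per-row semantics (function and argument names mirror the kernel's; the penalty bookkeeping body is elided):

```python
def set_pre_ids_rowwise(pre_ids, input_ids, seq_lens_decoder, seq_lens_encoder, stop_flags):
    """Reference semantics: every per-sequence array is indexed by the
    batch row bi, never by a thread index, so rows stay independent."""
    bs = pre_ids.shape[0]
    for bi in range(bs):  # bi plays the role of the kernel's per-block batch index
        if stop_flags[bi]:
            continue
        row_pre_ids = pre_ids[bi]      # pre_ids + bi * length_id
        row_input_ids = input_ids[bi]  # input_ids + bi * length_input_ids
        if seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0:
            continue  # row already stopped
        # ... per-row penalty bookkeeping reads row_input_ids and writes row_pre_ids ...
```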

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -539,6 +539,7 @@ def main():
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
     config.sharding_parallel_degree = training_args.sharding_parallel_degree
+    config.to_static = training_args.to_static
 
     if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
         pipeline = training_args.strategy.pipeline
@@ -556,6 +557,11 @@ def main():
 
     print("Final pre-training config:", config)
 
+    if "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config and config.tensor_parallel_degree > 1 and config.to_static is False:
+        from llm.utils.replace_ops import replace_cross_entropy
+
+        replace_cross_entropy()
+
     # # Set the dtype for loading model
     # dtype = "float32"
     # if training_args.fp16_opt_level == "O2":
```
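The gate requires all three conditions because `replace_cross_entropy()` rebinds a Python-level module attribute at runtime: it can only affect dynamic-graph runs (`to_static is False`), and it is only useful when the logits are actually sharded across an mp group. A minimal sketch of the binding caveat, assuming the module path from this diff; only call sites that look the function up through the module after the patch see the replacement:

```python
import paddle.nn.functional as F
from paddle.nn.functional import cross_entropy  # bound by name *before* the patch

from llm.utils.replace_ops import replace_cross_entropy

replace_cross_entropy()

assert F.cross_entropy.__name__ == "parallel_cross_entropy"  # module lookup sees the patch
assert cross_entropy.__name__ == "cross_entropy"             # the earlier binding does not
```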

llm/predict/predictor.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -1311,6 +1311,8 @@ def insert_task(self, pos, task_id, repeat_num):
         self.model_inputs["stop_flags"][pos] = False
         self.model_inputs["result_id"][pos][0] = task_id
         self.model_inputs["step_idx"][pos, 0] = 1
+        self.model_inputs["pre_ids"][pos][0] = self.input_ids[query_id][-1]
+        self.model_inputs["pre_ids"][pos][1:] = -1
         self.model_inputs["not_need_stop"][0] = True
 
         num_prefill_blocks = length // self.block_size
```
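The two added lines re-seed the `pre_ids` buffer when a batch slot is reused: position 0 gets the last prompt token of the incoming query (`query_id` comes from the surrounding method, not shown here) and the remainder is cleared to -1, so the token-penalty kernel above no longer sees the previous task's decode history. A hedged NumPy sketch of the reset, with illustrative buffer shapes:

```python
import numpy as np

max_batch, max_dec_len = 8, 1024
pre_ids = np.full((max_batch, max_dec_len), -1, dtype=np.int64)

def reset_slot(pos, prompt_ids):
    pre_ids[pos, 0] = prompt_ids[-1]  # seed with the last prompt token
    pre_ids[pos, 1:] = -1             # wipe stale history from the prior task

reset_slot(3, np.array([101, 2023, 42], dtype=np.int64))
```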

llm/utils/replace_ops.py

Lines changed: 240 additions & 0 deletions (new file)

```python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Literal, TypeAlias

import paddle
from paddle import Tensor, _C_ops
from paddle.distributed import collective
from paddle.distributed.fleet.base import topology as tp
from paddle.tensor.manipulation import reshape

_ReduceMode: TypeAlias = Literal["mean", "sum", "none"]


# TODO: this function is rewritten from paddle.nn.functional.cross_entropy;
# it would be better to merge the two into one.
def parallel_cross_entropy(
    input: Tensor,
    label: Tensor,
    weight: Tensor | None = None,
    ignore_index: int = -100,
    reduction: _ReduceMode = "mean",
    soft_label: bool = False,
    axis: int = -1,
    use_softmax: bool = True,
    label_smoothing: float = 0.0,
    name: str | None = None,
) -> Tensor:
    if reduction not in ["sum", "mean", "none"]:
        raise ValueError(
            "The value of 'reduction' in softmax_cross_entropy "
            f"should be 'sum', 'mean' or 'none', but received {reduction}, which is not allowed."
        )
    if ignore_index > 0 and soft_label:
        raise ValueError(
            "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy "
            f"should be '-100', but received {ignore_index}, which is not allowed."
        )

    input_dims = len(list(input.shape))
    if input_dims == 0:
        raise ValueError("The dimension of input should be larger than zero!")

    label_dims = len(list(label.shape))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=axis)

    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            f"Expected input_dims - 1 == label_dims or input_dims == label_dims "
            f"(got input_dims {input_dims}, label_dims {label_dims})"
        )

    if label_smoothing > 0.0:
        soft_label = True
        # Convert the label to one-hot encoding:
        # 1d case: label's shape goes from [N] to [N, C];
        # 2d case: label's shape goes from [N, d_1, ..., d_k] to [N, d_1, ..., d_k, C].
        if input_dims - 1 == label_dims:
            label = paddle.squeeze(label, axis=axis)
            label = paddle.nn.functional.one_hot(label, input.shape[-1])

        label = paddle.nn.functional.label_smooth(label, epsilon=label_smoothing)
        label = label.astype(input.dtype)
        label_dims = len(list(label.shape))

    if not soft_label:
        valid_label = paddle.cast(label != ignore_index, dtype=label.dtype) * label

    if not soft_label and is_tensor_sharded(input):
        group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
        ring_id = group.id
        nranks = group.nranks
        global_rank = collective._get_global_env().rank
        rank = group.get_group_rank(global_rank)
        _, out = _C_ops.c_softmax_with_cross_entropy(input, label, ignore_index, ring_id, rank, nranks)
    else:
        from paddlenlp.utils.log import logger

        logger.warning(
            "Failed to replace CrossEntropyLoss with ParallelCrossEntropyLoss. Please ensure: \n"
            "1. soft_label=False is set for parallel computation (current value: {}) \n"
            "2. Input tensor is properly sharded (current sharding status: {}) \n".format(
                soft_label,
                input.placements if input.is_dist() else None,
            )
        )

        _, out = _C_ops.cross_entropy_with_softmax(input, label, soft_label, use_softmax, True, ignore_index, axis)

    if weight is not None:
        # Translate weight from per-class to per-sample, shape [N] or [N, H, W]
        # for the 1d and 2d cases respectively.
        if soft_label:
            # weight's shape is [C], where C is the class count.
            # 1d case: label's shape is [N, C], weight_gather's shape is [N].
            # 2d case: label's shape is [N, H, W, C], weight_gather's shape is [N, H, W].
            weight_gather = paddle.matmul(
                x=paddle.cast(label, weight.dtype),
                y=weight,
                transpose_x=False,
                transpose_y=True,
            )
            out_shape = list(out.shape)
            weight_gather_reshape = reshape(weight_gather, shape=out_shape)
            out = paddle.cast(out, weight_gather_reshape.dtype)

            out = _C_ops.multiply(out, weight_gather_reshape)
        else:
            if input.shape[axis] != weight.shape[-1]:
                raise ValueError(
                    f"input's class_dimension({input.shape[axis]}) must equal to "
                    f"weight's class_dimension({weight.shape[-1]}) "
                    "when weight is provided"
                )

            ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype)
            if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[axis] == 1:
                # TODO: temporarily use squeeze instead of squeeze_
                ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
            if axis != -1 and axis != valid_label.ndim - 1:
                temp_perm = (
                    list(range(axis % valid_label.ndim))
                    + list(range((axis % valid_label.ndim + 1), valid_label.ndim))
                    + [axis % valid_label.ndim]
                )
                weight_gather = _C_ops.gather_nd(weight, valid_label.transpose(temp_perm))
            else:
                weight_gather = _C_ops.gather_nd(weight, valid_label)
            weight_gather = _C_ops.multiply(weight_gather, ignore_weight_mask)
            input_shape = list(label.shape)
            weight_gather_reshape = reshape(weight_gather, shape=input_shape)
            out = paddle.cast(out, weight_gather_reshape.dtype)
            out = _C_ops.multiply(out, weight_gather_reshape)

    if reduction == "sum":
        # The loss of any sample with class_index == ignore_index is already 0
        # in the op's output, so a plain reduce_sum is correct.
        return _C_ops.sum(out, [], None, False)
    elif reduction == "mean":
        # 1. weight is None:
        #    numerator: reduce_sum over all losses (safe, see the note above);
        #    denominator: count of samples with class_index != ignore_index.
        # 2. otherwise:
        #    numerator: weighted sum of losses;
        #    denominator: sum of weights where class_index != ignore_index.
        if ignore_index >= 0:  # ignore label
            out_sum = _C_ops.sum(out, [], None, False)
            # mask[i] = 0 if label[i] == ignore_index, else 1
            mask = label != ignore_index
            if weight is None:
                mask = paddle.cast(mask, dtype=out_sum.dtype)
                count = _C_ops.sum(mask, [], None, False)
                ret = out_sum / (count + (count == 0.0).astype(count.dtype))
            else:
                mask = paddle.cast(mask, weight_gather_reshape.dtype)
                weight_ignored = _C_ops.multiply(mask, weight_gather_reshape)
                weight_sum = _C_ops.sum(weight_ignored, [], None, False)
                ret = out_sum / (weight_sum + (weight_sum == 0.0).astype(weight_sum.dtype))
            return ret
        elif weight is not None:
            out_sum = _C_ops.sum(out, [], None, False)
            total_weight = _C_ops.sum(weight_gather_reshape, [], None, False)
            return out_sum / (total_weight + (total_weight == 0.0).astype(total_weight.dtype))
        else:
            return _C_ops.mean_all(out)

    else:  # reduction == "none"
        if input_dims - 1 == label_dims:
            out = paddle.squeeze(out, axis=axis)
        return out


# TODO: placements[1] may not be the mp axis.
def is_tensor_sharded(tensor):
    if not tensor.is_dist():
        return False

    placements = tensor.placements
    return placements[1].is_shard()


def replace_cross_entropy():
    paddle.nn.functional.cross_entropy = parallel_cross_entropy
```
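A hedged usage sketch of the new helper: after `replace_cross_entropy()`, lookups of `paddle.nn.functional.cross_entropy` resolve to `parallel_cross_entropy`, which uses the mp-sharded `c_softmax_with_cross_entropy` kernel when the logits are a dist tensor sharded along the class axis, and otherwise logs a warning and falls back to the regular kernel. The single-process fallback path can be exercised like this (shapes are illustrative):

```python
import paddle
import paddle.nn.functional as F

from llm.utils.replace_ops import replace_cross_entropy

replace_cross_entropy()  # rebinds the module attribute

logits = paddle.randn([4, 128])          # dense, non-distributed logits
labels = paddle.randint(0, 128, [4])
loss = F.cross_entropy(logits, labels)   # warns, then uses the regular kernel
print(float(loss))
```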

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 7 additions & 9 deletions

```diff
@@ -1131,13 +1131,13 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
             qkv_out = paddle.add(qkv_out, self.qkv_biases[i])
         return qkv_out
 
-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         ln_out = self.compute_layernorm_before_qkv(src, i)
 
         if self.config.mla_config.use_absorb():
             qkv_out = ln_out
         else:
-            qkv_out = self.compute_qkv_linear(ln_out, i)
+            qkv_out = self.compute_qkv_linear(ln_out, i, **kwargs)
 
         return qkv_out, residual_input
@@ -1523,7 +1523,7 @@ def forward(
 
         residual_input = src
         for i in range(self.num_layers):
-            qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
+            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, **kwargs)
             fmha_out = self.compute_attn(
                 time_step,
                 qkv_out,
@@ -1596,7 +1596,7 @@ class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
 
-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         qkv_out = self.compute_qkv_linear(src, i)
         return qkv_out, src
@@ -2055,9 +2055,7 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
             epsilon=self._epsilon,
             begin_norm_axis=1,
         )[0]
-        query_pe, key_pe = self.config.rotary_emb(
-            self.position_ids[0 : kwargs.get("seq_lens_encoder", None).sum()], query_pe, key_pe
-        )
+        query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
         if self.config.mla_config.use_absorb():
             from paddlenlp_ops import prefill_mla_write_cache
@@ -2689,7 +2687,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
@@ -5140,7 +5138,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
```
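The recurring signature change threads the caller's `**kwargs` (for example `seq_lens_encoder`) from `forward` through `compute_qkv` into the `compute_qkv_linear` overrides; separately, the MLA rotary-embedding call now passes the full `self.position_ids` instead of slicing it by `kwargs.get("seq_lens_encoder", None).sum()`. A minimal sketch of the plumbing pattern, with illustrative class and method names rather than the actual PaddleNLP classes:

```python
class FusedLayersSketch:
    """Illustrative only: shows **kwargs flowing from forward()
    through compute_qkv() into compute_qkv_linear() overrides."""

    num_layers = 2

    def forward(self, src, **kwargs):
        residual = src
        out = src
        for i in range(self.num_layers):
            out, residual = self.compute_qkv(src, residual, i, **kwargs)
        return out

    def compute_qkv(self, src, residual, i, **kwargs):
        # accept and forward kwargs unchanged so subclasses overriding
        # compute_qkv_linear can consume extras without signature churn
        return self.compute_qkv_linear(src, i, **kwargs), residual

    def compute_qkv_linear(self, ln_out, i, **kwargs):
        seq_lens = kwargs.get("seq_lens_encoder")  # optional per-request info
        return (ln_out, seq_lens)
```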
