Skip to content

Commit 67b21ae

Browse files
authored
Support fusion moe (#10507)
* support fusion moe * fix * fix * fix
1 parent 5406b5e commit 67b21ae

File tree

5 files changed

+133
-330
lines changed

5 files changed

+133
-330
lines changed

paddlenlp/transformers/deepseek_v2/fp8_linear.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,22 +187,22 @@ def kitchen_quant(x, backend=None, is_1d_scaled=True, return_transpose=False):
187187
return (qresult_ref.data, qresult_ref.scale)
188188

189189

190-
def kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled):
190+
def kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, rtn_dtype=paddle.float32):
191191
if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
192192
y = kitchen.ops.fp8_gemm_blockwise(
193193
a=x_fp8,
194194
a_decode_scale=x_scale,
195195
b=w_fp8,
196196
b_decode_scale=w_scale,
197-
out_dtype=paddle.bfloat16,
197+
out_dtype=rtn_dtype,
198198
out=None,
199199
accumulate=False,
200200
use_split_accumulator=True,
201201
is_a_1d_scaled=is_a_1d_scaled,
202202
is_b_1d_scaled=is_b_1d_scaled,
203203
)
204204
else:
205-
y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.bfloat16)
205+
y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
206206
return y
207207

208208

0 commit comments

Comments
 (0)