Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine kitchen use #10207

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions paddlenlp/transformers/deepseek_v2/fp8_linear.py
Original file line number Diff line number Diff line change
@@ -227,7 +227,7 @@ def forward(ctx, x, weight):
if x_t.shape[-1] % 8 != 0:
x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
x_t_quant, x_t_scale = kitchen_quant(
x_t.contiguous(), backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
x_t.contiguous(), backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
ctx.save_for_backward(
x_t_quant, x_t_scale, weight, paddle.to_tensor(x_t_shape, dtype="int64", place=paddle.CPUPlace())
@@ -267,6 +267,7 @@ def backward(ctx, dout):
dweight = kitchen_fp8_gemm(x_t_quant, x_t_scale, dout_t_quant, dout_t_scale, True, True)
return dx, dweight


class LinearFP8KeepXFunc(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight):
@@ -287,30 +288,25 @@ def forward(ctx, x, weight):
deep_gemm.gemm_fp8_fp8_bf16_nt((x_quant, x_scale), (w_quant, w_scale), out)
out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])


ctx.save_for_backward(
x, weight
)
ctx.save_for_backward(x, weight)
return out

@staticmethod
def backward(ctx, dout):
x, weight= ctx.saved_tensor()
x, weight = ctx.saved_tensor()

# padding
x_t = x.T.contiguous()
if x_t.shape[-1] % 8 != 0:
x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
x_t_quant, x_t_scale = kitchen_quant(
x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
x_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)


x_t_shape = x_t_shape.numpy()
# compute dx = mm(dout, w)
dx = paddle.empty(x.shape, dout.dtype)
dx_orig_shape = x.shape

dout_quant, dout_scale = kitchen_quant(
dout.reshape([-1, dout.shape[-1]]),
backend=kitchen.ops.Backend.CUTLASS,
@@ -337,8 +333,6 @@ def backward(ctx, dout):
return dx, dweight




class FP8Linear(paddle.nn.Layer):
def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
super().__init__()
@@ -353,6 +347,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
def forward(self, x):
return LinearFP8Func.apply(x, self.weight)


class FP8KeepXLinear(paddle.nn.Layer):
def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None:
super().__init__()
@@ -365,8 +360,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
)

def forward(self, x):
return LinearFP8KeepXFunc.apply(x, self.weight)

return LinearFP8KeepXFunc.apply(x, self.weight)


class Fuse_FFN_FP8_Func(paddle.autograd.PyLayer):
@@ -418,7 +412,7 @@ def forward(ctx, x, w1, w2):
axis=1,
)
x_t_fp8, x_t_scale = kitchen_quant(
x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
x_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)

ctx.save_for_backward(
@@ -448,7 +442,7 @@ def backward(ctx, do3):
o2 = swiglu(o1)
o2_t = o2.T.contiguous()
o2_t_fp8, o2_t_scale = kitchen_quant(
o2_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
o2_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)

# ===== do2 = deep_gemm(do3_fp8, w2_fp8)
@@ -472,7 +466,7 @@ def backward(ctx, do3):
axis=-1,
)
o2_t_fp8, o2_t_scale = kitchen_quant(
o2_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
o2_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
do3_t = do3.T.contiguous()
if do3_t.shape[-1] % 128 != 0 or do3_t.shape[-1] % 512 != 0:
@@ -489,7 +483,7 @@ def backward(ctx, do3):
)

do3_t_fp8, do3_t_scale = kitchen_quant(
do3_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
do3_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
dw2 = kitchen_fp8_gemm(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True)

6 changes: 3 additions & 3 deletions paddlenlp/transformers/fp8_utils.py
Original file line number Diff line number Diff line change
@@ -134,7 +134,7 @@ def forward(self, hs_out, hs_scale_out, tokens_per_expert):
axis=1,
)
x_t_fp8, x_t_scale = kitchen_quant(
x_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
x_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
self.x_t_fp8s += [x_t_fp8]
self.x_t_scales += [x_t_scale]
@@ -232,7 +232,7 @@ def bwd_down_weight(self, do3_fp8, do3_scale, o1, dw2=None):
axis=-1,
)
o2_t_fp8, o2_t_scale = kitchen_quant(
o2_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
o2_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)

do3_t = dequantize_fp8_to_fp32(do3_fp8, do3_scale).T.contiguous()
@@ -249,7 +249,7 @@ def bwd_down_weight(self, do3_fp8, do3_scale, o1, dw2=None):
axis=-1,
)
do3_t_fp8, do3_t_scale = kitchen_quant(
do3_t, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
do3_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
)
dw2 = kitchen_fp8_gemm(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, dw2)
return dw2
Loading
Oops, something went wrong.