Commit 07519c8

add hot model start

1 parent a3b7a01

4 files changed: 242 additions, 113 deletions

csrc/gpu/cpp_extensions.cu

Lines changed: 8 additions & 10 deletions
@@ -236,13 +236,11 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
                                   const paddle::optional<paddle::Tensor>& draft_tokens,
                                   const paddle::optional<paddle::Tensor>& seq_lens_encoder);
 
-void SaveOutMmsg(const paddle::Tensor& x,
+void SaveOutMmsgStatic(const paddle::Tensor& x,
                  const paddle::Tensor& not_need_stop, // cpu
-                 const paddle::Tensor& msg_queue_id, // cpu
                  int64_t rank_id);
 
-void GetOutput(const paddle::Tensor& x,
-               const paddle::Tensor& msg_queue_id, // cpu
+void GetOutputStatic(const paddle::Tensor& x,
                int64_t rank_id,
                bool wait_flag);
 
@@ -301,8 +299,8 @@ PYBIND11_MODULE(paddlenlp_ops, m) {
   m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
   m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
   m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
-  m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
-  m.def("f_get_output", &GetOutput, "GetOutput");
+  m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
+  m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
   m.def("f_step_paddle", &StepPaddle, "StepPaddle");
   m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
   // m.def("f_cutlass_fp8_fp8_half_block_gemm_fused", &cutlass_fp8_fp8_half_block_gemm_fused_func, "cutlass_fp8_fp8_half_block_gemm_fused_func");
@@ -331,8 +329,8 @@ PYBIND11_MODULE(paddlenlp_ops_80, m) {
   m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
   m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
   m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
-  m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
-  m.def("f_get_output", &GetOutput, "GetOutput");
+  m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
+  m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
   m.def("f_step_paddle", &StepPaddle, "StepPaddle");
   m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
 }
@@ -360,8 +358,8 @@ PYBIND11_MODULE(paddlenlp_ops_90, m) {
   m.def("f_per_token_group_quant", &PerTokenGroupQuant, "PerTokenGroupQuant");
   m.def("f_per_tensor_quant_fp8", &PerTensorQuantFp8, "PerTensorQuantFp8");
   m.def("f_get_padding_offset_v2", &GetPaddingOffsetV2, "GetPaddingOffsetV2");
-  m.def("f_save_output", &SaveOutMmsg, "SaveOutMmsg");
-  m.def("f_get_output", &GetOutput, "GetOutput");
+  m.def("f_save_output", &SaveOutMmsgStatic, "SaveOutMmsgStatic");
+  m.def("f_get_output", &GetOutputStatic, "GetOutputStatic");
   m.def("f_step_paddle", &StepPaddle, "StepPaddle");
   m.def("f_save_output_dygraph", &SaveOutputDygraph, "SaveOutputDygraph");
 }
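
Net effect for callers: the Python-visible op names (f_save_output, f_get_output) stay the same across all three modules; they now bind to the renamed SaveOutMmsgStatic / GetOutputStatic, whose signatures drop the msg_queue_id tensor. A minimal caller-side sketch, assuming the extension is importable as paddlenlp_ops; the tensor shapes and dtypes below are illustrative, not taken from this commit:

import paddle
import paddlenlp_ops

# Hypothetical buffers; the real shapes/dtypes are defined by the serving code.
x = paddle.zeros([1, 256], dtype="int64")              # output-token buffer
not_need_stop = paddle.zeros([1], dtype="bool").cpu()  # the header marks this // cpu

# Before this commit: f_save_output(x, not_need_stop, msg_queue_id, rank_id)
paddlenlp_ops.f_save_output(x, not_need_stop, 0)       # rank_id = 0

# Before this commit: f_get_output(x, msg_queue_id, rank_id, wait_flag)
paddlenlp_ops.f_get_output(x, 0, True)                 # wait_flag=True blocks until output arrives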

csrc/setup_cuda.py

Lines changed: 63 additions & 63 deletions
@@ -130,14 +130,14 @@ def get_gencode_flags():
     "./gpu/speculate_decoding_kernels/speculate_save_output.cc",
     "./gpu/speculate_decoding_kernels/speculate_get_output.cc",
     "./gpu/save_output_dygraph.cu",
-    "./gpu/cpp_extensions.cu",
+    # "./gpu/cpp_extensions.cu",
     "./gpu/all_reduce.cu",
     "./gpu/quantization/per_token_group_quant.cu",
     "./gpu/quantization/per_tensor_quant_fp8.cu",
 ]
 sources += find_end_files("./gpu/speculate_decoding_kernels", ".cu")
-sources += find_end_files("./gpu/moe/fused_moe/cutlass_kernels/moe_gemm/", ".cu")
-sources += find_end_files("./gpu/moe/fused_moe/", ".cu")
+# sources += find_end_files("./gpu/moe/fused_moe/cutlass_kernels/moe_gemm/", ".cu")
+# sources += find_end_files("./gpu/moe/fused_moe/", ".cu")
 
 nvcc_compile_args = gencode_flags
 update_git_submodule()
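
The commented-out lines above prune cpp_extensions.cu and the MoE GEMM kernels from the build. For reference, find_end_files is this setup script's own helper for gathering sources by suffix; a plausible sketch of such a helper (an assumption, the in-repo version may differ) clarifies what each disabled line would have contributed:

import os

def find_end_files(directory, end_str):
    # Recursively collect every file under `directory` whose name
    # ends with `end_str`, e.g. all ".cu" kernels in a subtree.
    matched = []
    for root, _, files in os.walk(directory):
        for name in files:
            if name.endswith(end_str):
                matched.append(os.path.join(root, name))
    return matched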
@@ -167,66 +167,66 @@ def get_gencode_flags():
 cuda_version = float(paddle.version.cuda())
 nvcc_version = get_nvcc_cuda_version(os.environ.get("CUDA_HOME", "/usr/local/cuda"))
 
-if cc >= 80:
-    sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
-
-    sources += ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu"]
-
-    sources += find_end_files("./gpu/append_attn", ".cu")
-    sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
-
-
-fp8_auto_gen_directory = "gpu/cutlass_kernels/fp8_gemm_fused/autogen"
-if os.path.isdir(fp8_auto_gen_directory):
-    shutil.rmtree(fp8_auto_gen_directory)
-
-
-if cc == 89 and cuda_version >= 12.4:
-    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py --cuda_arch 89")
-    os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py --cuda_arch 89")
-    sources += find_end_files(fp8_auto_gen_directory, ".cu")
-    sources += [
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
-    ]
-
-if cc >= 80 and nvcc_version >= Version("12.4"):
-    os.environ.pop('PADDLE_CUDA_ARCH_LIST', None)
-    nvcc_compile_args += [
-        "-std=c++17",
-        "--use_fast_math",
-        "--threads=8",
-        "-D_GLIBCXX_USE_CXX11_ABI=1",
-    ]
-    sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
-    if cc >= 80 and cc < 89:
-        sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"]
-        nvcc_compile_args += ["-gencode", "arch=compute_80,code=compute_80"]
-    elif cc >= 89 and cc < 90:
-        sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"]
-        nvcc_compile_args += ["-gencode", "arch=compute_89,code=compute_89"]
-    elif cc >= 90:
-        sources += [
-            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
-            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu",
-        ]
-        nvcc_compile_args += ["-gencode", "arch=compute_90a,code=compute_90a"]
-
-if cc >= 90 and cuda_version >= 12.0:
-    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py --cuda_arch 90")
-    os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_ptr_scale_sm90.py --cuda_arch 90")
-    os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py --cuda_arch 90")
-    os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py --cuda_arch 90")
-    sources += find_end_files(fp8_auto_gen_directory, ".cu")
-    sources += [
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
-        "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_ptr_scale.cu",
-    ]
-    sources += find_end_files("./gpu/mla_attn", ".cu")
+# if cc >= 80:
+#     sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
+
+#     sources += ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu"]
+
+#     sources += find_end_files("./gpu/append_attn", ".cu")
+#     sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")
+
+
+# fp8_auto_gen_directory = "gpu/cutlass_kernels/fp8_gemm_fused/autogen"
+# if os.path.isdir(fp8_auto_gen_directory):
+#     shutil.rmtree(fp8_auto_gen_directory)
+
+
+# if cc == 89 and cuda_version >= 12.4:
+#     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py --cuda_arch 89")
+#     os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py --cuda_arch 89")
+#     sources += find_end_files(fp8_auto_gen_directory, ".cu")
+#     sources += [
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
+#     ]
+
+# if cc >= 80 and nvcc_version >= Version("12.4"):
+#     os.environ.pop('PADDLE_CUDA_ARCH_LIST', None)
+#     nvcc_compile_args += [
+#         "-std=c++17",
+#         "--use_fast_math",
+#         "--threads=8",
+#         "-D_GLIBCXX_USE_CXX11_ABI=1",
+#     ]
+#     sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
+#     if cc >= 80 and cc < 89:
+#         sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"]
+#         nvcc_compile_args += ["-gencode", "arch=compute_80,code=compute_80"]
+#     elif cc >= 89 and cc < 90:
+#         sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"]
+#         nvcc_compile_args += ["-gencode", "arch=compute_89,code=compute_89"]
+#     elif cc >= 90:
+#         sources += [
+#             "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
+#             "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu",
+#         ]
+#         nvcc_compile_args += ["-gencode", "arch=compute_90a,code=compute_90a"]
+
+# if cc >= 90 and cuda_version >= 12.0:
+#     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py --cuda_arch 90")
+#     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_ptr_scale_sm90.py --cuda_arch 90")
+#     os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py --cuda_arch 90")
+#     os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py --cuda_arch 90")
+#     sources += find_end_files(fp8_auto_gen_directory, ".cu")
+#     sources += [
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
+#         "gpu/fp8_gemm_with_cutlass/fp8_fp8_half_gemm_ptr_scale.cu",
+#     ]
+#     sources += find_end_files("./gpu/mla_attn", ".cu")
 
 ops_name = f"paddlenlp_ops_{sm_version}" if sm_version != 0 else "paddlenlp_ops"
 
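The block disabled above implemented compute-capability gating: it picked kernel sources and -gencode flags from the detected SM version and the CUDA toolkit version. A condensed sketch of that pattern (illustrative only; the commented code above remains the authoritative source list, and the cc derivation here is an assumption about how the script computes it):

import paddle

# Derive the SM number, e.g. 80 (A100), 89 (L40/RTX 4090), 90 (H100).
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor

sources, nvcc_compile_args = [], []
if 80 <= cc < 89:
    sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"]
    nvcc_compile_args += ["-gencode", "arch=compute_80,code=compute_80"]
elif 89 <= cc < 90:
    sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"]
    nvcc_compile_args += ["-gencode", "arch=compute_89,code=compute_89"]
elif cc >= 90:
    sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu"]
    nvcc_compile_args += ["-gencode", "arch=compute_90a,code=compute_90a"]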