Guard kernel loading in generate_load_kernel_once with a nullptr check.

Both branches of generate_load_kernel_once (the V.graph.aot_mode branch and
the else branch) previously emitted a bare `<var> = loadKernel(...);` line.
With this patch each branch emits three lines instead:

    if (<var> == nullptr) {
     <var> = loadKernel(...);
    }

where <var> is `kernels.{name}` in the aot_mode branch and `{name}` otherwise.
The loadKernel arguments themselves are unchanged.

NOTE(review): the line structure of this patch was restored from a copy whose
newlines had been collapsed. Line counts match the hunk header (13 old / 17
new), but context-line indentation and the whitespace inside the added
f-string literals should be confirmed against the target file before applying.

diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 6d11dc891bef..a8bb4c301058 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -2021,13 +2021,17 @@ def generate_load_kernel_once(
         self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
     ):
         if V.graph.aot_mode:
+            self.writeline(f"if (kernels.{name} == nullptr) {{")
             self.writeline(
-                f"""kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
+                f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
             )
+            self.writeline("}")
         else:
+            self.writeline(f"if ({name} == nullptr) {{")
             self.writeline(
-                f"""{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
+                f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
             )
+            self.writeline("}")
 
     def generate_args_decl(self, call_args):
         dynamic_symbols = V.graph.sizevars.free_symbols()