From d3b9ceb0c52e4a5b2bf5535389a47f3cdaad82a1 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 8 Jul 2025 06:22:26 -0700 Subject: [PATCH] [ET-VK][ez] Fix to codegen caching mechanism ## Changes * Fixed a bug with caching the generated GLSL file too early in `gen_vulkan_spv.py` ## Context Currently, the `gen_vulkan_spv.py` script saves the generated GLSL file to the cache immediately after generation. Then, when compiling the GLSL to SPIR-V, it checks the current generated GLSL file against the one in the cache. However, because of the early caching, this check will always pass, even when the GLSL template was updated and a SPIR-V recompilation is needed. The fix is to only store the generated GLSL file after the SPIR-V compilation succeeds. Differential Revision: [D77933112](https://our.internmc.facebook.com/intern/diff/D77933112/) ghstack-source-id: 294852284 Pull Request resolved: https://github.com/pytorch/executorch/pull/12272 --- backends/vulkan/runtime/gen_vulkan_spv.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index dc8275bc099..cc17445005d 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -861,10 +861,6 @@ def generate_src_file(shader_paths_pair): # Construct generated file name gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}") - # Construct path of cached generated file - cached_gen_out_path = os.path.join( - cache_dir, f"{src_file_name}.{out_file_ext}" - ) # Execute codegen to generate the output file with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file: @@ -875,10 +871,6 @@ def generate_src_file(shader_paths_pair): with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file: output_file.write(output_text) - if cache_dir is not None: - # Store the generated file in the cache for SPIR-V compilation - shutil.copyfile(gen_out_path, cached_gen_out_path) - def compile_spirv(shader_paths_pair): # Extract components from the input tuple # name of generated .glsl, .glslh, or .h @@ -959,7 +951,10 @@ def compile_spirv(shader_paths_pair): else: raise RuntimeError(f"{err_msg_base} {e.stderr}") from e + # If compilation was successful, store the source GLSL file and the + # compiled SPIR-V file in the cache for future comparison. if cache_dir is not None: + shutil.copyfile(gen_out_path, cached_gen_out_path) shutil.copyfile(spv_out_path, cached_spv_out_path) return (spv_out_path, gen_out_path)