diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 3a15f8d560e6c..4e28cb981c901 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -208,6 +208,7 @@ class TaskCodegen : public IRVisitor { spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num); spirv::Value ptr_val = ir_->alloca_workgroup_array(arr_type); + shared_array_binds_.push_back(ptr_val); ir_->register_value(alloca->raw_name(), ptr_val); } else { // Alloca for a single variable @@ -1567,6 +1568,7 @@ class TaskCodegen : public IRVisitor { ir_->set_work_group_size(group_size); std::vector buffers; if (device_->get_cap(DeviceCapability::spirv_version) > 0x10300) { + buffers = shared_array_binds_; for (const auto &bb : task_attribs_.buffer_binds) { for (auto &it : buffer_value_map_) { if (it.first.first == bb.buffer) { @@ -2220,6 +2222,7 @@ class TaskCodegen : public IRVisitor { BufferInfoTypeTupleHasher> buffer_binding_map_; std::vector texture_binds_; + std::vector shared_array_binds_; spirv::Value kernel_function_; spirv::Label kernel_return_label_; bool gen_label_{false}; diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index 53b129d6e7b45..64a9a60f7383b 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -4,7 +4,7 @@ from tests import test_utils -@test_utils.test(arch=[ti.cuda, ti.vulkan], vk_api_version="1.0") +@test_utils.test(arch=[ti.cuda, ti.vulkan]) def test_shared_array_nested_loop(): block_dim = 128 nBlocks = 64