Skip to content

Commit

Permalink
Fix shared array for all Vulkan versions. (#5721)
Browse files Browse the repository at this point in the history
  • Loading branch information
turbo0628 committed Aug 10, 2022
1 parent f591023 commit 6fac8fd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class TaskCodegen : public IRVisitor {

spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
spirv::Value ptr_val = ir_->alloca_workgroup_array(arr_type);
shared_array_binds_.push_back(ptr_val);
ir_->register_value(alloca->raw_name(), ptr_val);
} else {
// Alloca for a single variable
Expand Down Expand Up @@ -1567,6 +1568,7 @@ class TaskCodegen : public IRVisitor {
ir_->set_work_group_size(group_size);
std::vector<spirv::Value> buffers;
if (device_->get_cap(DeviceCapability::spirv_version) > 0x10300) {
buffers = shared_array_binds_;
for (const auto &bb : task_attribs_.buffer_binds) {
for (auto &it : buffer_value_map_) {
if (it.first.first == bb.buffer) {
Expand Down Expand Up @@ -2220,6 +2222,7 @@ class TaskCodegen : public IRVisitor {
BufferInfoTypeTupleHasher>
buffer_binding_map_;
std::vector<TextureBind> texture_binds_;
std::vector<spirv::Value> shared_array_binds_;
spirv::Value kernel_function_;
spirv::Label kernel_return_label_;
bool gen_label_{false};
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_shared_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tests import test_utils


@test_utils.test(arch=[ti.cuda, ti.vulkan], vk_api_version="1.0")
@test_utils.test(arch=[ti.cuda, ti.vulkan])
def test_shared_array_nested_loop():
block_dim = 128
nBlocks = 64
Expand Down

0 comments on commit 6fac8fd

Please sign in to comment.