Skip to content

Commit

Permalink
[Bug] [vulkan] Fix data type alignment for arguments and return values (
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier authored and sjwsl committed Nov 21, 2021
1 parent 813ddae commit 787271b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions taichi/backends/vulkan/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
// Put scalar args in the memory first
for (int i : scalar_indices) {
auto &attribs = (*vec)[i];
const size_t dt_bytes = vk_data_type_size(attribs.dt);
// Align bytes to the nearest multiple of dt_bytes
bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
attribs.offset_in_mem = bytes;
bytes += attribs.stride;
TI_TRACE(" at={} scalar offset_in_mem={} stride={}", i,
Expand All @@ -103,6 +106,8 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
// Then the array args
for (int i : array_indices) {
auto &attribs = (*vec)[i];
const size_t dt_bytes = vk_data_type_size(attribs.dt);
bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
attribs.offset_in_mem = bytes;
bytes += attribs.stride;
TI_TRACE(" at={} array offset_in_mem={} stride={}", i,
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_arg_alignment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import taichi as ti


@ti.test(exclude=[ti.opengl, ti.vulkan])
@ti.test(exclude=[ti.opengl])
def test_ret_write():
@ti.kernel
def func(a: ti.i16) -> ti.f32:
Expand All @@ -10,7 +10,7 @@ def func(a: ti.i16) -> ti.f32:
assert func(255) == 3.0


@ti.test(exclude=[ti.opengl, ti.vulkan])
@ti.test(exclude=[ti.opengl])
def test_arg_read():
x = ti.field(ti.i32, shape=())

Expand Down

0 comments on commit 787271b

Please sign in to comment.