Skip to content

Commit

Permalink
[refactor] Unify ways to set external array args
Browse files Browse the repository at this point in the history
Also I introduced a typo in #5559 (`DevAllocType::kNdarray ->
DevAllocType::kNone`) and it's also fixed in this PR.
  • Loading branch information
Ailing Zhang committed Jul 29, 2022
1 parent db71d0c commit abec6f6
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 45 deletions.
12 changes: 6 additions & 6 deletions cpp_examples/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ void autograd() {
auto ctx_backward = kernel_backward->make_launch_context();
auto ctx_ext = kernel_ext->make_launch_context();
std::vector<float> ext_a(n), ext_b(n), ext_c(n);
ctx_ext.set_arg_external_array(0, taichi::uint64(ext_a.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array(1, taichi::uint64(ext_b.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array(2, taichi::uint64(ext_c.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_a.data()), n,
{n});
ctx_ext.set_arg_external_array_with_shape(1, taichi::uint64(ext_b.data()), n,
{n});
ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n,
{n});

(*kernel_init)(ctx_init);
(*kernel_forward)(ctx_forward);
Expand Down
3 changes: 2 additions & 1 deletion cpp_examples/run_snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ void run_snode() {
auto ctx_ret = kernel_ret->make_launch_context();
auto ctx_ext = kernel_ext->make_launch_context();
std::vector<int> ext_arr(n);
ctx_ext.set_arg_external_array(0, taichi::uint64(ext_arr.data()), n, false);
ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()),
n, {n});

(*kernel_init)(ctx_init);
(*kernel_ret)(ctx_ret);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
device_buffers[i] = arg_buffers[i];
}
// device_buffers[i] saves a raw ptr on CUDA device.
context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
context.set_arg(i, (uint64)device_buffers[i]);

} else if (arr_sz > 0) {
// arg_buffers[i] is a DeviceAllocation*
Expand All @@ -857,7 +857,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
arg_buffers[i] = device_buffers[i];

// device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i]
context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
context.set_arg(i, (uint64)device_buffers[i]);
}
}
}
Expand Down
10 changes: 8 additions & 2 deletions taichi/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,17 @@ struct RuntimeContext {
set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture);
}

void set_arg_external_array(int arg_id, uintptr_t ptr, uint64 size) {
void set_arg_external_array(int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape) {
set_arg(arg_id, ptr);
set_array_runtime_size(arg_id, size);
set_array_device_allocation_type(arg_id,
RuntimeContext::DevAllocType::kNdarray);
RuntimeContext::DevAllocType::kNone);
for (uint64 i = 0; i < shape.size(); ++i) {
extra_args[arg_id][i] = shape[i];
}
}

void set_arg_ndarray(int arg_id,
Expand Down
22 changes: 3 additions & 19 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
ctx_->extra_args[i][j] = d;
}

void Kernel::LaunchContextBuilder::set_arg_external_array(
void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
uint64 size,
bool is_device_allocation) {
const std::vector<int64> &shape) {
TI_ASSERT_INFO(
kernel_->args[arg_id].is_array,
"Assigning external (numpy) array to scalar argument is not allowed.");
Expand All @@ -235,25 +235,9 @@ void Kernel::LaunchContextBuilder::set_arg_external_array(
ActionArg("address", fmt::format("0x{:x}", ptr)),
ActionArg("array_size_in_bytes", (int64)size)});

ctx_->set_arg(arg_id, ptr);
ctx_->set_array_runtime_size(arg_id, size);
ctx_->set_array_device_allocation_type(
arg_id, is_device_allocation ? RuntimeContext::DevAllocType::kNdarray
: RuntimeContext::DevAllocType::kNone);
}

void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape) {
this->set_arg_external_array(arg_id, ptr, size,
/*is_device_allocation=*/false);
TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
for (uint64 i = 0; i < shape.size(); ++i) {
this->set_extra_arg_int(arg_id, i, shape[i]);
}
ctx_->set_arg_external_array(arg_id, ptr, size, shape);
}

void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
Expand Down
5 changes: 0 additions & 5 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ class TI_DLL_EXPORT Kernel : public Callable {

void set_extra_arg_int(int i, int j, int32 d);

void set_arg_external_array(int arg_id,
uintptr_t ptr,
uint64 size,
bool is_device_allocation);

void set_arg_external_array_with_shape(int arg_id,
uintptr_t ptr,
uint64 size,
Expand Down
2 changes: 0 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,6 @@ void export_lang(py::module &m) {
py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext")
.def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int)
.def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float)
.def("set_arg_external_array",
&Kernel::LaunchContextBuilder::set_arg_external_array)
.def("set_arg_external_array_with_shape",
&Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
.def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray)
Expand Down
12 changes: 4 additions & 8 deletions tests/cpp/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ TEST(IRBuilder, ExternalPtr) {
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arg(get_data_type<int>(), /*is_array=*/true);
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array(/*arg_id=*/0, (uint64)array.get(), size,
/*is_device_allocation=*/false);
launch_ctx.set_arg_external_array_with_shape(
/*arg_id=*/0, (uint64)array.get(), size, {size});
(*ker)(launch_ctx);
EXPECT_EQ(array[0], 2);
EXPECT_EQ(array[1], 1);
Expand All @@ -139,19 +139,15 @@ TEST(IRBuilder, Ndarray) {
array.write_int({2}, 40);
auto ker1 = setup_kernel1(test_prog.prog());
auto launch_ctx1 = ker1->make_launch_context();
launch_ctx1.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
/*is_device_allocation=*/true);
launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array);
(*ker1)(launch_ctx1);
EXPECT_EQ(array.read_int({0}), 2);
EXPECT_EQ(array.read_int({1}), 1);
EXPECT_EQ(array.read_int({2}), 42);

auto ker2 = setup_kernel2(test_prog.prog());
auto launch_ctx2 = ker2->make_launch_context();
launch_ctx2.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
/*is_device_allocation=*/true);
launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array);
launch_ctx2.set_arg_int(/*arg_id=*/1, 3);
(*ker2)(launch_ctx2);
EXPECT_EQ(array.read_int({0}), 2);
Expand Down

0 comments on commit abec6f6

Please sign in to comment.