Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Unify ways to set external array args #5565

Merged
merged 1 commit into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cpp_examples/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ void autograd() {
auto ctx_backward = kernel_backward->make_launch_context();
auto ctx_ext = kernel_ext->make_launch_context();
std::vector<float> ext_a(n), ext_b(n), ext_c(n);
ctx_ext.set_arg_external_array(0, taichi::uint64(ext_a.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array(1, taichi::uint64(ext_b.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array(2, taichi::uint64(ext_c.data()), n,
/*is_device_allocation=*/false);
ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_a.data()), n,
{n});
ctx_ext.set_arg_external_array_with_shape(1, taichi::uint64(ext_b.data()), n,
{n});
ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n,
{n});

(*kernel_init)(ctx_init);
(*kernel_forward)(ctx_forward);
Expand Down
3 changes: 2 additions & 1 deletion cpp_examples/run_snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ void run_snode() {
auto ctx_ret = kernel_ret->make_launch_context();
auto ctx_ext = kernel_ext->make_launch_context();
std::vector<int> ext_arr(n);
ctx_ext.set_arg_external_array(0, taichi::uint64(ext_arr.data()), n, false);
ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()),
n, {n});

(*kernel_init)(ctx_init);
(*kernel_ret)(ctx_ret);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
device_buffers[i] = arg_buffers[i];
}
// device_buffers[i] saves a raw ptr on CUDA device.
context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
context.set_arg(i, (uint64)device_buffers[i]);

} else if (arr_sz > 0) {
// arg_buffers[i] is a DeviceAllocation*
Expand All @@ -857,7 +857,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
arg_buffers[i] = device_buffers[i];

// device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i]
context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
context.set_arg(i, (uint64)device_buffers[i]);
}
}
}
Expand Down
10 changes: 8 additions & 2 deletions taichi/program/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,17 @@ struct RuntimeContext {
set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture);
}

void set_arg_external_array(int arg_id, uintptr_t ptr, uint64 size) {
void set_arg_external_array(int arg_id,
uintptr_t ptr,
uint64 size,
ailzhang marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<int64> &shape) {
set_arg(arg_id, ptr);
set_array_runtime_size(arg_id, size);
set_array_device_allocation_type(arg_id,
RuntimeContext::DevAllocType::kNdarray);
RuntimeContext::DevAllocType::kNone);
for (uint64 i = 0; i < shape.size(); ++i) {
extra_args[arg_id][i] = shape[i];
}
}

void set_arg_ndarray(int arg_id,
Expand Down
22 changes: 3 additions & 19 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
ctx_->extra_args[i][j] = d;
}

void Kernel::LaunchContextBuilder::set_arg_external_array(
void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
uint64 size,
bool is_device_allocation) {
const std::vector<int64> &shape) {
TI_ASSERT_INFO(
kernel_->args[arg_id].is_array,
"Assigning external (numpy) array to scalar argument is not allowed.");
Expand All @@ -235,25 +235,9 @@ void Kernel::LaunchContextBuilder::set_arg_external_array(
ActionArg("address", fmt::format("0x{:x}", ptr)),
ActionArg("array_size_in_bytes", (int64)size)});

ctx_->set_arg(arg_id, ptr);
ctx_->set_array_runtime_size(arg_id, size);
ctx_->set_array_device_allocation_type(
arg_id, is_device_allocation ? RuntimeContext::DevAllocType::kNdarray
: RuntimeContext::DevAllocType::kNone);
}

void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape) {
this->set_arg_external_array(arg_id, ptr, size,
/*is_device_allocation=*/false);
TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
for (uint64 i = 0; i < shape.size(); ++i) {
this->set_extra_arg_int(arg_id, i, shape[i]);
}
ctx_->set_arg_external_array(arg_id, ptr, size, shape);
}

void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
Expand Down
5 changes: 0 additions & 5 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ class TI_DLL_EXPORT Kernel : public Callable {

void set_extra_arg_int(int i, int j, int32 d);

void set_arg_external_array(int arg_id,
uintptr_t ptr,
uint64 size,
bool is_device_allocation);

void set_arg_external_array_with_shape(int arg_id,
uintptr_t ptr,
uint64 size,
Expand Down
2 changes: 0 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,6 @@ void export_lang(py::module &m) {
py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext")
.def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int)
.def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float)
.def("set_arg_external_array",
&Kernel::LaunchContextBuilder::set_arg_external_array)
.def("set_arg_external_array_with_shape",
&Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
.def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray)
Expand Down
12 changes: 4 additions & 8 deletions tests/cpp/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ TEST(IRBuilder, ExternalPtr) {
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arg(get_data_type<int>(), /*is_array=*/true);
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array(/*arg_id=*/0, (uint64)array.get(), size,
/*is_device_allocation=*/false);
launch_ctx.set_arg_external_array_with_shape(
/*arg_id=*/0, (uint64)array.get(), size, {size});
(*ker)(launch_ctx);
EXPECT_EQ(array[0], 2);
EXPECT_EQ(array[1], 1);
Expand All @@ -139,19 +139,15 @@ TEST(IRBuilder, Ndarray) {
array.write_int({2}, 40);
auto ker1 = setup_kernel1(test_prog.prog());
auto launch_ctx1 = ker1->make_launch_context();
launch_ctx1.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
/*is_device_allocation=*/true);
launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array);
(*ker1)(launch_ctx1);
EXPECT_EQ(array.read_int({0}), 2);
EXPECT_EQ(array.read_int({1}), 1);
EXPECT_EQ(array.read_int({2}), 42);

auto ker2 = setup_kernel2(test_prog.prog());
auto launch_ctx2 = ker2->make_launch_context();
launch_ctx2.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
/*is_device_allocation=*/true);
launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array);
launch_ctx2.set_arg_int(/*arg_id=*/1, 3);
(*ker2)(launch_ctx2);
EXPECT_EQ(array.read_int({0}), 2);
Expand Down