diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index dcf5324b14ecb..0965c1603e149 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -174,12 +174,12 @@ void autograd() { auto ctx_backward = kernel_backward->make_launch_context(); auto ctx_ext = kernel_ext->make_launch_context(); std::vector ext_a(n), ext_b(n), ext_c(n); - ctx_ext.set_arg_external_array(0, taichi::uint64(ext_a.data()), n, - /*is_device_allocation=*/false); - ctx_ext.set_arg_external_array(1, taichi::uint64(ext_b.data()), n, - /*is_device_allocation=*/false); - ctx_ext.set_arg_external_array(2, taichi::uint64(ext_c.data()), n, - /*is_device_allocation=*/false); + ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_a.data()), n, + {n}); + ctx_ext.set_arg_external_array_with_shape(1, taichi::uint64(ext_b.data()), n, + {n}); + ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n, + {n}); (*kernel_init)(ctx_init); (*kernel_forward)(ctx_forward); diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index 992f6ae1d79f2..a97122ed9e821 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -129,7 +129,8 @@ void run_snode() { auto ctx_ret = kernel_ret->make_launch_context(); auto ctx_ext = kernel_ext->make_launch_context(); std::vector ext_arr(n); - ctx_ext.set_arg_external_array(0, taichi::uint64(ext_arr.data()), n, false); + ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()), + n, {n}); (*kernel_init)(ctx_init); (*kernel_ret)(ctx_ret); diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 41b367e0dd72b..679421708f8d4 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -841,7 +841,7 @@ FunctionType CUDAModuleToFunctionConverter::convert( device_buffers[i] = arg_buffers[i]; } // device_buffers[i] saves a raw ptr on CUDA device. - context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz); + context.set_arg(i, (uint64)device_buffers[i]); } else if (arr_sz > 0) { // arg_buffers[i] is a DeviceAllocation* @@ -857,7 +857,7 @@ FunctionType CUDAModuleToFunctionConverter::convert( arg_buffers[i] = device_buffers[i]; // device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i] - context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz); + context.set_arg(i, (uint64)device_buffers[i]); } } } diff --git a/taichi/program/context.h b/taichi/program/context.h index 059f2ef39df07..ea2ecdb7c9f6a 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -88,11 +88,17 @@ struct RuntimeContext { set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture); } - void set_arg_external_array(int arg_id, uintptr_t ptr, uint64 size) { + void set_arg_external_array(int arg_id, + uintptr_t ptr, + uint64 size, + const std::vector &shape) { set_arg(arg_id, ptr); set_array_runtime_size(arg_id, size); set_array_device_allocation_type(arg_id, - RuntimeContext::DevAllocType::kNdarray); + RuntimeContext::DevAllocType::kNone); + for (uint64 i = 0; i < shape.size(); ++i) { + extra_args[arg_id][i] = shape[i]; + } } void set_arg_ndarray(int arg_id, diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 3d6c7cf98fa8c..931d54dab086f 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -220,11 +220,11 @@ void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) { ctx_->extra_args[i][j] = d; } -void Kernel::LaunchContextBuilder::set_arg_external_array( +void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape( int arg_id, uintptr_t ptr, uint64 size, - bool is_device_allocation) { + const std::vector &shape) { TI_ASSERT_INFO( kernel_->args[arg_id].is_array, "Assigning external (numpy) array to scalar argument is not allowed."); @@ -235,25 +235,9 @@ void Kernel::LaunchContextBuilder::set_arg_external_array( ActionArg("address", fmt::format("0x{:x}", ptr)), ActionArg("array_size_in_bytes", (int64)size)}); - ctx_->set_arg(arg_id, ptr); - ctx_->set_array_runtime_size(arg_id, size); - ctx_->set_array_device_allocation_type( - arg_id, is_device_allocation ? RuntimeContext::DevAllocType::kNdarray - : RuntimeContext::DevAllocType::kNone); -} - -void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape( - int arg_id, - uintptr_t ptr, - uint64 size, - const std::vector &shape) { - this->set_arg_external_array(arg_id, ptr, size, - /*is_device_allocation=*/false); TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices, "External array cannot have > {max_num_indices} indices"); - for (uint64 i = 0; i < shape.size(); ++i) { - this->set_extra_arg_int(arg_id, i, shape[i]); - } + ctx_->set_arg_external_array(arg_id, ptr, size, shape); } void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id, diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 2c4ec842c6129..a4c18284beb11 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -39,11 +39,6 @@ class TI_DLL_EXPORT Kernel : public Callable { void set_extra_arg_int(int i, int j, int32 d); - void set_arg_external_array(int arg_id, - uintptr_t ptr, - uint64 size, - bool is_device_allocation); - void set_arg_external_array_with_shape(int arg_id, uintptr_t ptr, uint64 size, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index f964051698586..acf3d855335ba 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -671,8 +671,6 @@ void export_lang(py::module &m) { py::class_(m, "KernelLaunchContext") .def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int) .def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float) - .def("set_arg_external_array", - &Kernel::LaunchContextBuilder::set_arg_external_array) .def("set_arg_external_array_with_shape", &Kernel::LaunchContextBuilder::set_arg_external_array_with_shape) .def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray) diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index 38f47d4d608a0..a7f60bc416a8e 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -114,8 +114,8 @@ TEST(IRBuilder, ExternalPtr) { auto ker = std::make_unique(*test_prog.prog(), std::move(block)); ker->insert_arg(get_data_type(), /*is_array=*/true); auto launch_ctx = ker->make_launch_context(); - launch_ctx.set_arg_external_array(/*arg_id=*/0, (uint64)array.get(), size, - /*is_device_allocation=*/false); + launch_ctx.set_arg_external_array_with_shape( + /*arg_id=*/0, (uint64)array.get(), size, {size}); (*ker)(launch_ctx); EXPECT_EQ(array[0], 2); EXPECT_EQ(array[1], 1); @@ -139,9 +139,7 @@ TEST(IRBuilder, Ndarray) { array.write_int({2}, 40); auto ker1 = setup_kernel1(test_prog.prog()); auto launch_ctx1 = ker1->make_launch_context(); - launch_ctx1.set_arg_external_array( - /*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size, - /*is_device_allocation=*/true); + launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array); (*ker1)(launch_ctx1); EXPECT_EQ(array.read_int({0}), 2); EXPECT_EQ(array.read_int({1}), 1); @@ -149,9 +147,7 @@ TEST(IRBuilder, Ndarray) { auto ker2 = setup_kernel2(test_prog.prog()); auto launch_ctx2 = ker2->make_launch_context(); - launch_ctx2.set_arg_external_array( - /*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size, - /*is_device_allocation=*/true); + launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array); launch_ctx2.set_arg_int(/*arg_id=*/1, 3); (*ker2)(launch_ctx2); EXPECT_EQ(array.read_int({0}), 2);