diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp
index c06752ea4fbca..a59f2dbca0f09 100644
--- a/c_api/src/taichi_core_impl.cpp
+++ b/c_api/src/taichi_core_impl.cpp
@@ -299,26 +299,7 @@ void ti_launch_kernel(TiRuntime runtime,
         std::vector<int> shape(ndarray.shape.dims,
                                ndarray.shape.dims + ndarray.shape.dim_count);
-        size_t total_array_size = 1;
-        for (const auto &val : shape) {
-          total_array_size *= val;
-        }
-
-        if (ndarray.elem_shape.dim_count != 0) {
-          std::vector<int> elem_shape(
-              ndarray.elem_shape.dims,
-              ndarray.elem_shape.dims + ndarray.elem_shape.dim_count);
-
-          for (const auto &val : elem_shape) {
-            total_array_size *= val;
-          }
-
-          runtime_context.set_arg_devalloc(i, *devalloc, shape, elem_shape);
-          runtime_context.set_array_runtime_size(i, total_array_size);
-        } else {
-          runtime_context.set_arg_devalloc(i, *devalloc, shape);
-          runtime_context.set_array_runtime_size(i, total_array_size);
-        }
+        runtime_context.set_arg_ndarray(i, (intptr_t)devalloc.get(), shape);
 
         devallocs.emplace_back(std::move(devalloc));
         break;
diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp
index 11cc3804b2a43..69fb590d4af5e 100644
--- a/taichi/aot/graph_data.cpp
+++ b/taichi/aot/graph_data.cpp
@@ -2,6 +2,8 @@
 #include "taichi/program/ndarray.h"
 #include "taichi/program/texture.h"
 
+#include
+
 namespace taichi {
 namespace lang {
 namespace aot {
@@ -37,7 +39,8 @@ void CompiledGraph::run(
                   symbolic_arg.name, symbolic_arg.dtype().to_string(),
                   arr->dtype.to_string());
-      set_runtime_ctx_ndarray(&ctx, i, arr);
+      ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
+                          arr->shape);
     } else if (ival.tag == aot::ArgKind::kScalar) {
       ctx.set_arg(i, ival.val);
     } else if (ival.tag == aot::ArgKind::kTexture) {
diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp
index 6f5d3712c3ae9..41b367e0dd72b 100644
--- a/taichi/codegen/cuda/codegen_cuda.cpp
+++ b/taichi/codegen/cuda/codegen_cuda.cpp
@@ -709,25 +709,6 @@ std::unique_ptr KernelCodeGenCUDA::make_codegen_llvm(
 }
 #endif  // TI_WITH_LLVM
 
-static void set_arg_external_array(RuntimeContext *ctx,
-                                   const std::string &kernel_name,
-                                   int arg_id,
-                                   uintptr_t ptr,
-                                   uint64 size,
-                                   bool is_device_allocation) {
-  ActionRecorder::get_instance().record(
-      "set_kernel_arg_ext_ptr",
-      {ActionArg("kernel_name", kernel_name), ActionArg("arg_id", arg_id),
-       ActionArg("address", fmt::format("0x{:x}", ptr)),
-       ActionArg("array_size_in_bytes", (int64)size)});
-
-  ctx->set_arg(arg_id, ptr);
-  ctx->set_array_runtime_size(arg_id, size);
-  ctx->set_array_device_allocation_type(
-      arg_id, is_device_allocation ? RuntimeContext::DevAllocType::kNdarray
-                                   : RuntimeContext::DevAllocType::kNone);
-}
-
 LLVMCompiledData KernelCodeGenCUDA::modulegen(
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
@@ -860,9 +841,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
         device_buffers[i] = arg_buffers[i];
       }
       // device_buffers[i] saves a raw ptr on CUDA device.
-      set_arg_external_array(&context, kernel_name, i,
-                             (uint64)device_buffers[i], arr_sz,
-                             /*is_device_allocation=*/false);
+      context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
 
     } else if (arr_sz > 0) {
       // arg_buffers[i] is a DeviceAllocation*
@@ -878,9 +857,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
       arg_buffers[i] = device_buffers[i];
 
       // device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i]
-      set_arg_external_array(&context, kernel_name, i,
-                             (uint64)device_buffers[i], arr_sz,
-                             /*is_device_allocation=*/false);
+      context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
     }
   }
 }
diff --git a/taichi/program/context.h b/taichi/program/context.h
index 4b4c981107f13..059f2ef39df07 100644
--- a/taichi/program/context.h
+++ b/taichi/program/context.h
@@ -88,32 +88,26 @@ struct RuntimeContext {
     set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture);
   }
 
-  void set_arg_devalloc(int arg_id,
-                        DeviceAllocation &alloc,
-                        const std::vector<int> &shape) {
-    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(&alloc);
-    set_array_device_allocation_type(arg_id, DevAllocType::kNdarray);
-    TI_ASSERT(shape.size() <= taichi_max_num_indices);
-    for (int i = 0; i < shape.size(); i++) {
-      extra_args[arg_id][i] = shape[i];
-    }
+  void set_arg_external_array(int arg_id, uintptr_t ptr, uint64 size) {
+    set_arg(arg_id, ptr);
+    set_array_runtime_size(arg_id, size);
+    set_array_device_allocation_type(arg_id,
+                                     RuntimeContext::DevAllocType::kNdarray);
   }
 
-  void set_arg_devalloc(int arg_id,
-                        DeviceAllocation &alloc,
-                        const std::vector<int> &shape,
-                        const std::vector<int> &element_shape) {
-    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(&alloc);
+  void set_arg_ndarray(int arg_id,
+                       intptr_t devalloc_ptr,
+                       const std::vector<int> &shape) {
+    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(devalloc_ptr);
     set_array_device_allocation_type(arg_id, DevAllocType::kNdarray);
-    TI_ASSERT(shape.size() + element_shape.size() <= taichi_max_num_indices);
+    TI_ASSERT(shape.size() <= taichi_max_num_indices);
+    size_t total_size = 1;
     for (int i = 0; i < shape.size(); i++) {
       extra_args[arg_id][i] = shape[i];
+      total_size *= shape[i];
     }
-    for (int i = 0; i < element_shape.size(); i++) {
-      extra_args[arg_id][i + shape.size()] = element_shape[i];
-    }
+    set_array_runtime_size(arg_id, total_size);
   }
-
 #endif
 };
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index 82a9c1bc597dc..3d6c7cf98fa8c 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -259,14 +259,9 @@ void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
 void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
                                                    const Ndarray &arr) {
   intptr_t ptr = arr.get_device_allocation_ptr_as_int();
-  uint64 arr_size = arr.get_element_size() * arr.get_nelement();
-  this->set_arg_external_array(arg_id, ptr, arr_size,
-                               /*is_device_allocation=*/true);
   TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
                  "External array cannot have > {max_num_indices} indices");
-  for (uint64 i = 0; i < arr.shape.size(); ++i) {
-    this->set_extra_arg_int(arg_id, i, arr.shape[i]);
-  }
+  ctx_->set_arg_ndarray(arg_id, ptr, arr.shape);
 }
 
 void Kernel::LaunchContextBuilder::set_arg_texture(int arg_id,
diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp
index 46c201b1749c5..11382b205bb87 100644
--- a/taichi/program/ndarray.cpp
+++ b/taichi/program/ndarray.cpp
@@ -190,19 +190,5 @@ void Ndarray::write_float(const std::vector<int> &i, float64 val) {
   write<float64>(i, val);
 }
 
-void set_runtime_ctx_ndarray(RuntimeContext *ctx,
-                             int arg_id,
-                             Ndarray *ndarray) {
-  ctx->set_arg_devalloc(arg_id, ndarray->ndarray_alloc_, ndarray->shape);
-
-  uint64_t total_array_size = 1;
-  for (const auto &dim : ndarray->total_shape()) {
-    total_array_size *= dim;
-  }
-  total_array_size *= data_type_size(ndarray->dtype);
-
-  ctx->set_array_runtime_size(arg_id, total_array_size);
-}
-
 }  // namespace lang
 }  // namespace taichi
diff --git a/taichi/program/ndarray.h b/taichi/program/ndarray.h
index 6f5c30f23b298..746515428dcc6 100644
--- a/taichi/program/ndarray.h
+++ b/taichi/program/ndarray.h
@@ -72,8 +72,5 @@ class TI_DLL_EXPORT Ndarray {
   Program *prog_{nullptr};
 };
 
-// TODO: move this as a method inside RuntimeContext once Ndarray is decoupled
-// with Program
-void set_runtime_ctx_ndarray(RuntimeContext *ctx, int arg_id, Ndarray *ndarray);
 }  // namespace lang
 }  // namespace taichi
diff --git a/tests/cpp/aot/llvm/kernel_aot_test.cpp b/tests/cpp/aot/llvm/kernel_aot_test.cpp
index 93458b7c84613..d0783938f98ab 100644
--- a/tests/cpp/aot/llvm/kernel_aot_test.cpp
+++ b/tests/cpp/aot/llvm/kernel_aot_test.cpp
@@ -30,6 +30,7 @@ TEST(LlvmAotTest, CpuKernel) {
   constexpr int kArrLen = 32;
   constexpr int kArrBytes = kArrLen * sizeof(int32_t);
   auto arr_devalloc = exec.allocate_memory_ndarray(kArrBytes, result_buffer);
+  Ndarray arr = Ndarray(arr_devalloc, PrimitiveType::i32, {kArrLen});
 
   cpu::AotModuleParams aot_params;
   const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
@@ -44,8 +45,8 @@ TEST(LlvmAotTest, CpuKernel) {
   RuntimeContext ctx;
   ctx.runtime = exec.get_llvm_runtime();
   ctx.set_arg(0, /*v=*/0);
-  ctx.set_arg_devalloc(/*arg_id=*/1, arr_devalloc, /*shape=*/{kArrLen});
-  ctx.set_array_runtime_size(/*arg_id=*/1, kArrBytes);
+  ctx.set_arg_ndarray(/*arg_id=*/1, arr.get_device_allocation_ptr_as_int(),
+                      /*shape=*/arr.shape);
   k_run->launch(&ctx);
 
   auto *data = reinterpret_cast<int32_t *>(
@@ -70,6 +71,7 @@ TEST(LlvmAotTest, CudaKernel) {
     constexpr int kArrLen = 32;
     constexpr int kArrBytes = kArrLen * sizeof(int32_t);
     auto arr_devalloc = exec.allocate_memory_ndarray(kArrBytes, result_buffer);
+    Ndarray arr = Ndarray(arr_devalloc, PrimitiveType::i32, {kArrLen});
 
     cuda::AotModuleParams aot_params;
     const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
@@ -83,8 +85,8 @@ TEST(LlvmAotTest, CudaKernel) {
     RuntimeContext ctx;
     ctx.runtime = exec.get_llvm_runtime();
    ctx.set_arg(0, /*v=*/0);
-    ctx.set_arg_devalloc(/*arg_id=*/1, arr_devalloc, /*shape=*/{kArrLen});
-    ctx.set_array_runtime_size(/*arg_id=*/1, kArrBytes);
+    ctx.set_arg_ndarray(/*arg_id=*/1, arr.get_device_allocation_ptr_as_int(),
+                        /*shape=*/arr.shape);
     k_run->launch(&ctx);
 
     auto *data = reinterpret_cast<int32_t *>(
diff --git a/tests/cpp/aot/vulkan/aot_save_load_test.cpp b/tests/cpp/aot/vulkan/aot_save_load_test.cpp
index 601bbc650a754..2004aaf324a18 100644
--- a/tests/cpp/aot/vulkan/aot_save_load_test.cpp
+++ b/tests/cpp/aot/vulkan/aot_save_load_test.cpp
@@ -272,7 +272,8 @@ TEST(AotSaveLoad, VulkanNdarray) {
   DeviceAllocation devalloc_arr_ =
       embedded_device->device()->allocate_memory(alloc_params);
   Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size});
-  taichi::lang::set_runtime_ctx_ndarray(&host_ctx, 0, &arr);
+  host_ctx.set_arg_ndarray(0, arr.get_device_allocation_ptr_as_int(),
+                           arr.shape);
   int src[size] = {0};
   src[0] = 2;
   src[2] = 40;
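
For reference, the consolidated call path can be exercised roughly as follows. This is a minimal sketch modeled on the updated kernel_aot_test.cpp hunks above and is not part of the diff; the surrounding setup (exec, result_buffer, k_run) is assumed to exist as in those tests, and raw_ptr / size_in_bytes in the trailing comment are placeholder names.

    constexpr int kArrLen = 32;
    constexpr int kArrBytes = kArrLen * sizeof(int32_t);
    auto arr_devalloc = exec.allocate_memory_ndarray(kArrBytes, result_buffer);
    Ndarray arr = Ndarray(arr_devalloc, PrimitiveType::i32, {kArrLen});

    RuntimeContext ctx;
    ctx.runtime = exec.get_llvm_runtime();
    ctx.set_arg(0, /*v=*/0);
    // One call now stores the devalloc pointer, copies the shape into
    // extra_args, and derives the runtime size from the shape product.
    ctx.set_arg_ndarray(/*arg_id=*/1, arr.get_device_allocation_ptr_as_int(),
                        /*shape=*/arr.shape);
    k_run->launch(&ctx);

    // For a raw pointer with a known byte size (the CUDA converter path),
    // the size and allocation-type bookkeeping is done by set_arg_external_array:
    //   ctx.set_arg_external_array(/*arg_id=*/1, (uint64)raw_ptr, size_in_bytes);

The point of the refactor is that RuntimeContext now owns the shape-to-runtime-size bookkeeping, so callers no longer pair set_arg_devalloc with a separate set_array_runtime_size call.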