[refactor] Unify ways to set ndarray args (#5559)
We have various ways to set ndarray args, and this PR aims to unify them
around RuntimeContext.
ailzhang committed Jul 29, 2022
1 parent 960fb8b commit db71d0c
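
At a glance, the call-site migration this commit performs, assembled from the diffs below (the surrounding setup is illustrative, not part of the commit):

  // Before: a free helper coupled to Ndarray and Program.
  taichi::lang::set_runtime_ctx_ndarray(&host_ctx, /*arg_id=*/0, &arr);

  // After: one RuntimeContext member; the devalloc pointer and shape are
  // passed explicitly, and the runtime size is derived from the shape.
  host_ctx.set_arg_ndarray(/*arg_id=*/0,
                           arr.get_device_allocation_ptr_as_int(), arr.shape);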
Showing 9 changed files with 29 additions and 93 deletions.
21 changes: 1 addition & 20 deletions c_api/src/taichi_core_impl.cpp
@@ -299,26 +299,7 @@ void ti_launch_kernel(TiRuntime runtime,
       std::vector<int> shape(ndarray.shape.dims,
                              ndarray.shape.dims + ndarray.shape.dim_count);
 
-      size_t total_array_size = 1;
-      for (const auto &val : shape) {
-        total_array_size *= val;
-      }
-
-      if (ndarray.elem_shape.dim_count != 0) {
-        std::vector<int> elem_shape(
-            ndarray.elem_shape.dims,
-            ndarray.elem_shape.dims + ndarray.elem_shape.dim_count);
-
-        for (const auto &val : elem_shape) {
-          total_array_size *= val;
-        }
-
-        runtime_context.set_arg_devalloc(i, *devalloc, shape, elem_shape);
-        runtime_context.set_array_runtime_size(i, total_array_size);
-      } else {
-        runtime_context.set_arg_devalloc(i, *devalloc, shape);
-        runtime_context.set_array_runtime_size(i, total_array_size);
-      }
+      runtime_context.set_arg_ndarray(i, (intptr_t)devalloc.get(), shape);
 
       devallocs.emplace_back(std::move(devalloc));
       break;
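The manual size bookkeeping deleted above now happens inside RuntimeContext::set_arg_ndarray (see taichi/program/context.h below). As a minimal, self-contained sketch of that product-of-extents computation (illustrative only; not code from this commit):

  #include <functional>
  #include <numeric>
  #include <vector>

  // Fold an ndarray's shape into its total element count, e.g. {2, 3, 4} -> 24.
  size_t total_elements(const std::vector<int> &shape) {
    return std::accumulate(shape.begin(), shape.end(), size_t{1},
                           std::multiplies<size_t>());
  }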
5 changes: 4 additions & 1 deletion taichi/aot/graph_data.cpp
@@ -2,6 +2,8 @@
#include "taichi/program/ndarray.h"
#include "taichi/program/texture.h"

#include <numeric>

namespace taichi {
namespace lang {
namespace aot {
@@ -37,7 +39,8 @@ void CompiledGraph::run(
           symbolic_arg.name, symbolic_arg.dtype().to_string(),
           arr->dtype.to_string());
 
-      set_runtime_ctx_ndarray(&ctx, i, arr);
+      ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
+                          arr->shape);
     } else if (ival.tag == aot::ArgKind::kScalar) {
       ctx.set_arg(i, ival.val);
     } else if (ival.tag == aot::ArgKind::kTexture) {
27 changes: 2 additions & 25 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -709,25 +709,6 @@ std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(
 }
 #endif // TI_WITH_LLVM
 
-static void set_arg_external_array(RuntimeContext *ctx,
-                                   const std::string &kernel_name,
-                                   int arg_id,
-                                   uintptr_t ptr,
-                                   uint64 size,
-                                   bool is_device_allocation) {
-  ActionRecorder::get_instance().record(
-      "set_kernel_arg_ext_ptr",
-      {ActionArg("kernel_name", kernel_name), ActionArg("arg_id", arg_id),
-       ActionArg("address", fmt::format("0x{:x}", ptr)),
-       ActionArg("array_size_in_bytes", (int64)size)});
-
-  ctx->set_arg(arg_id, ptr);
-  ctx->set_array_runtime_size(arg_id, size);
-  ctx->set_array_device_allocation_type(
-      arg_id, is_device_allocation ? RuntimeContext::DevAllocType::kNdarray
-                                   : RuntimeContext::DevAllocType::kNone);
-}
-
 LLVMCompiledData KernelCodeGenCUDA::modulegen(
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
@@ -860,9 +841,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
         device_buffers[i] = arg_buffers[i];
       }
       // device_buffers[i] saves a raw ptr on CUDA device.
-      set_arg_external_array(&context, kernel_name, i,
-                             (uint64)device_buffers[i], arr_sz,
-                             /*is_device_allocation=*/false);
+      context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
 
     } else if (arr_sz > 0) {
       // arg_buffers[i] is a DeviceAllocation*
@@ -878,9 +857,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(
         arg_buffers[i] = device_buffers[i];
 
         // device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i]
-        set_arg_external_array(&context, kernel_name, i,
-                               (uint64)device_buffers[i], arr_sz,
-                               /*is_device_allocation=*/false);
+        context.set_arg_external_array(i, (uint64)device_buffers[i], arr_sz);
       }
     }
   }
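One visible side effect: the deleted static helper also logged each external-array arg through ActionRecorder, and that logging goes away with it. The replacement is a direct member call; a minimal sketch with a hypothetical host buffer (RuntimeContext is default-constructed the same way in the tests below):

  RuntimeContext context;
  int32_t host_buf[32] = {0};  // hypothetical backing store
  context.set_arg_external_array(/*arg_id=*/0,
                                 reinterpret_cast<uintptr_t>(host_buf),
                                 /*size=*/sizeof(host_buf));  // size in bytes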
32 changes: 13 additions & 19 deletions taichi/program/context.h
@@ -88,32 +88,26 @@ struct RuntimeContext {
     set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture);
   }
 
-  void set_arg_devalloc(int arg_id,
-                        DeviceAllocation &alloc,
-                        const std::vector<int> &shape) {
-    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(&alloc);
-    set_array_device_allocation_type(arg_id, DevAllocType::kNdarray);
-    TI_ASSERT(shape.size() <= taichi_max_num_indices);
-    for (int i = 0; i < shape.size(); i++) {
-      extra_args[arg_id][i] = shape[i];
-    }
+  void set_arg_external_array(int arg_id, uintptr_t ptr, uint64 size) {
+    set_arg(arg_id, ptr);
+    set_array_runtime_size(arg_id, size);
+    set_array_device_allocation_type(arg_id,
+                                     RuntimeContext::DevAllocType::kNdarray);
   }
 
-  void set_arg_devalloc(int arg_id,
-                        DeviceAllocation &alloc,
-                        const std::vector<int> &shape,
-                        const std::vector<int> &element_shape) {
-    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(&alloc);
+  void set_arg_ndarray(int arg_id,
+                       intptr_t devalloc_ptr,
+                       const std::vector<int> &shape) {
+    args[arg_id] = taichi_union_cast_with_different_sizes<uint64>(devalloc_ptr);
     set_array_device_allocation_type(arg_id, DevAllocType::kNdarray);
-    TI_ASSERT(shape.size() + element_shape.size() <= taichi_max_num_indices);
+    TI_ASSERT(shape.size() <= taichi_max_num_indices);
+    size_t total_size = 1;
     for (int i = 0; i < shape.size(); i++) {
       extra_args[arg_id][i] = shape[i];
+      total_size *= shape[i];
     }
-    for (int i = 0; i < element_shape.size(); i++) {
-      extra_args[arg_id][i + shape.size()] = element_shape[i];
-    }
+    set_array_runtime_size(arg_id, total_size);
   }
 
 #endif
 };

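With both members in place, RuntimeContext is the single point where array-style args are recorded. A short sketch of how a caller exercises the new set_arg_ndarray (hypothetical devalloc pointer; shape values illustrative):

  std::vector<int> shape = {2, 3, 4};
  intptr_t devalloc_ptr = 0;  // hypothetical; normally from a DeviceAllocation
  // Pointer + shape in; the runtime size (2 * 3 * 4 = 24) is computed inside.
  ctx.set_arg_ndarray(/*arg_id=*/0, devalloc_ptr, shape);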
7 changes: 1 addition & 6 deletions taichi/program/kernel.cpp
@@ -259,14 +259,9 @@ void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape(
 void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id,
                                                    const Ndarray &arr) {
   intptr_t ptr = arr.get_device_allocation_ptr_as_int();
-  uint64 arr_size = arr.get_element_size() * arr.get_nelement();
-  this->set_arg_external_array(arg_id, ptr, arr_size,
-                               /*is_device_allocation=*/true);
   TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
                  "External array cannot have > {max_num_indices} indices");
-  for (uint64 i = 0; i < arr.shape.size(); ++i) {
-    this->set_extra_arg_int(arg_id, i, arr.shape[i]);
-  }
+  ctx_->set_arg_ndarray(arg_id, ptr, arr.shape);
 }
 
 void Kernel::LaunchContextBuilder::set_arg_texture(int arg_id,
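LaunchContextBuilder::set_arg_ndarray is now a thin forwarder into RuntimeContext. A hedged sketch of the resulting call chain (assumes a compiled Kernel k and an Ndarray arr; make_launch_context() is assumed to be the builder's existing entry point):

  auto launch_ctx = k.make_launch_context();
  launch_ctx.set_arg_ndarray(/*arg_id=*/0, arr);
  // forwards to: ctx_->set_arg_ndarray(0, arr.get_device_allocation_ptr_as_int(),
  //                                    arr.shape);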
14 changes: 0 additions & 14 deletions taichi/program/ndarray.cpp
@@ -190,19 +190,5 @@ void Ndarray::write_float(const std::vector<int> &i, float64 val) {
   write<float>(i, val);
 }
 
-void set_runtime_ctx_ndarray(RuntimeContext *ctx,
-                             int arg_id,
-                             Ndarray *ndarray) {
-  ctx->set_arg_devalloc(arg_id, ndarray->ndarray_alloc_, ndarray->shape);
-
-  uint64_t total_array_size = 1;
-  for (const auto &dim : ndarray->total_shape()) {
-    total_array_size *= dim;
-  }
-  total_array_size *= data_type_size(ndarray->dtype);
-
-  ctx->set_array_runtime_size(arg_id, total_array_size);
-}
-
 }  // namespace lang
 }  // namespace taichi
3 changes: 0 additions & 3 deletions taichi/program/ndarray.h
@@ -72,8 +72,5 @@ class TI_DLL_EXPORT Ndarray {
   Program *prog_{nullptr};
 };
 
-// TODO: move this as a method inside RuntimeContext once Ndarray is decoupled
-// with Program
-void set_runtime_ctx_ndarray(RuntimeContext *ctx, int arg_id, Ndarray *ndarray);
 }  // namespace lang
 }  // namespace taichi
10 changes: 6 additions & 4 deletions tests/cpp/aot/llvm/kernel_aot_test.cpp
@@ -30,6 +30,7 @@ TEST(LlvmAotTest, CpuKernel) {
   constexpr int kArrLen = 32;
   constexpr int kArrBytes = kArrLen * sizeof(int32_t);
   auto arr_devalloc = exec.allocate_memory_ndarray(kArrBytes, result_buffer);
+  Ndarray arr = Ndarray(arr_devalloc, PrimitiveType::i32, {kArrLen});
 
   cpu::AotModuleParams aot_params;
   const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
@@ -44,8 +45,8 @@
   RuntimeContext ctx;
   ctx.runtime = exec.get_llvm_runtime();
   ctx.set_arg(0, /*v=*/0);
-  ctx.set_arg_devalloc(/*arg_id=*/1, arr_devalloc, /*shape=*/{kArrLen});
-  ctx.set_array_runtime_size(/*arg_id=*/1, kArrBytes);
+  ctx.set_arg_ndarray(/*arg_id=*/1, arr.get_device_allocation_ptr_as_int(),
+                      /*shape=*/arr.shape);
   k_run->launch(&ctx);
 
   auto *data = reinterpret_cast<int32_t *>(
@@ -70,6 +71,7 @@ TEST(LlvmAotTest, CudaKernel) {
   constexpr int kArrLen = 32;
   constexpr int kArrBytes = kArrLen * sizeof(int32_t);
   auto arr_devalloc = exec.allocate_memory_ndarray(kArrBytes, result_buffer);
+  Ndarray arr = Ndarray(arr_devalloc, PrimitiveType::i32, {kArrLen});
 
   cuda::AotModuleParams aot_params;
   const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");
@@ -83,8 +85,8 @@
   RuntimeContext ctx;
   ctx.runtime = exec.get_llvm_runtime();
   ctx.set_arg(0, /*v=*/0);
-  ctx.set_arg_devalloc(/*arg_id=*/1, arr_devalloc, /*shape=*/{kArrLen});
-  ctx.set_array_runtime_size(/*arg_id=*/1, kArrBytes);
+  ctx.set_arg_ndarray(/*arg_id=*/1, arr.get_device_allocation_ptr_as_int(),
+                      /*shape=*/arr.shape);
   k_run->launch(&ctx);
 
   auto *data = reinterpret_cast<int32_t *>(
3 changes: 2 additions & 1 deletion tests/cpp/aot/vulkan/aot_save_load_test.cpp
@@ -272,7 +272,8 @@ TEST(AotSaveLoad, VulkanNdarray) {
   DeviceAllocation devalloc_arr_ =
       embedded_device->device()->allocate_memory(alloc_params);
   Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size});
-  taichi::lang::set_runtime_ctx_ndarray(&host_ctx, 0, &arr);
+  host_ctx.set_arg_ndarray(0, arr.get_device_allocation_ptr_as_int(),
+                           arr.shape);
   int src[size] = {0};
   src[0] = 2;
   src[2] = 40;
