Skip to content

Commit

Permalink
Refactor2023: Rename *arg* to *param* in class Callable
Browse files Browse the repository at this point in the history
  • Loading branch information
PGZXB committed Dec 2, 2022
1 parent b5d6beb commit b943d3d
Show file tree
Hide file tree
Showing 18 changed files with 69 additions and 69 deletions.
2 changes: 1 addition & 1 deletion cpp_examples/aot_save.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void aot_save(taichi::Arch arch) {
builder.create_return(builder.create_local_load(sum));

kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
kernel_ret->insert_ret(PrimitiveType::i32);
kernel_ret->add_ret(PrimitiveType::i32);
}

aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
Expand Down
6 changes: 3 additions & 3 deletions cpp_examples/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ void autograd() {
}

kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
}

auto ctx_init = kernel_init->make_launch_context();
Expand Down
2 changes: 1 addition & 1 deletion cpp_examples/run_snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void run_snode() {
}

kernel_ext = std::make_unique<Kernel>(program, builder.extract_ir(), "ext");
kernel_ext->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {n});
kernel_ext->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {n});
}

auto ctx_init = kernel_init->make_launch_context();
Expand Down
12 changes: 6 additions & 6 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def decl_scalar_arg(dtype):
is_ref = True
dtype = dtype.tp
dtype = cook_dtype(dtype)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_arg(dtype)
arg_id = impl.get_runtime().compiling_callable.add_scalar_param(dtype)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))


Expand All @@ -71,15 +71,15 @@ def decl_sparse_matrix(dtype):
value_type = cook_dtype(dtype)
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().compiling_callable.insert_scalar_arg(ptr_type)
arg_id = impl.get_runtime().compiling_callable.add_scalar_param(ptr_type)
return SparseMatrixProxy(
_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)


def decl_ndarray_arg(dtype, dim, element_shape, layout):
dtype = cook_dtype(dtype)
element_dim = len(element_shape)
arg_id = impl.get_runtime().compiling_callable.insert_arr_arg(dtype, dim, element_shape)
arg_id = impl.get_runtime().compiling_callable.add_arr_param(dtype, dim, element_shape)
if layout == Layout.AOS:
element_dim = -element_dim
return AnyArray(
Expand All @@ -89,14 +89,14 @@ def decl_ndarray_arg(dtype, dim, element_shape, layout):

def decl_texture_arg(num_dimensions):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_texture_arg(f32)
arg_id = impl.get_runtime().compiling_callable.add_texture_param(f32)
return TextureSampler(
_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions)


def decl_rw_texture_arg(num_dimensions, num_channels, channel_format, lod):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_texture_arg(f32)
arg_id = impl.get_runtime().compiling_callable.add_texture_param(f32)
return RWTextureAccessor(
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, num_channels,
channel_format, lod), num_dimensions)
Expand All @@ -108,4 +108,4 @@ def decl_ret(dtype):
[dtype.n, dtype.m], dtype.dtype)
else:
dtype = cook_dtype(dtype)
return impl.get_runtime().compiling_callable.insert_ret(dtype)
return impl.get_runtime().compiling_callable.add_ret(dtype)
6 changes: 3 additions & 3 deletions taichi/codegen/spirv/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ KernelContextAttributes::KernelContextAttributes(
: args_bytes_(0),
rets_bytes_(0),
extra_args_bytes_(RuntimeContext::extra_args_size) {
arr_access.resize(kernel.args.size(), irpass::ExternalPtrAccess(0));
arg_attribs_vec_.reserve(kernel.args.size());
arr_access.resize(kernel.parameter_list.size(), irpass::ExternalPtrAccess(0));
arg_attribs_vec_.reserve(kernel.parameter_list.size());
// TODO: We should be able to limit Kernel args and rets to be primitive types
// as well but let's leave that as a follow-up PR.
for (const auto &ka : kernel.args) {
for (const auto &ka : kernel.parameter_list) {
ArgAttributes aa;
aa.dtype = ka.get_element_type()->as<PrimitiveType>()->type;
const size_t dt_bytes = ka.get_element_size();
Expand Down
22 changes: 11 additions & 11 deletions taichi/program/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,28 @@ Callable::Callable() = default;

Callable::~Callable() = default;

int Callable::insert_scalar_arg(const DataType &dt) {
args.emplace_back(dt->get_compute_type(), /*is_array=*/false);
return (int)args.size() - 1;
int Callable::add_scalar_param(const DataType &dt) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/false);
return (int)parameter_list.size() - 1;
}

int Callable::insert_ret(const DataType &dt) {
int Callable::add_ret(const DataType &dt) {
rets.emplace_back(dt->get_compute_type());
return (int)rets.size() - 1;
}

int Callable::insert_arr_arg(const DataType &dt,
int Callable::add_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape) {
args.emplace_back(dt->get_compute_type(), /*is_array=*/true, /*size=*/0,
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true, /*size=*/0,
total_dim, element_shape);
return (int)args.size() - 1;
return (int)parameter_list.size() - 1;
}

int Callable::insert_texture_arg(const DataType &dt) {
// FIXME: we shouldn't abuse is_array for texture args
args.emplace_back(dt->get_compute_type(), /*is_array=*/true);
return (int)args.size() - 1;
int Callable::add_texture_param(const DataType &dt) {
// FIXME: we shouldn't abuse is_array for texture parameters
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true);
return (int)parameter_list.size() - 1;
}

} // namespace taichi::lang
16 changes: 8 additions & 8 deletions taichi/program/callable.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TI_DLL_EXPORT Callable {
std::unique_ptr<IRNode> ir{nullptr};
std::unique_ptr<FrontendContext> context{nullptr};

struct Arg {
struct Parameter {
bool is_array{
false}; // This is true for both ndarray and external array args.
std::size_t total_dim{0}; // total dim of array
Expand All @@ -32,7 +32,7 @@ class TI_DLL_EXPORT Callable {
However we kept the interfaces unchanged temporarily, so as to minimize
possible regressions.
*/
explicit Arg(const DataType &dt = PrimitiveType::unknown,
explicit Parameter(const DataType &dt = PrimitiveType::unknown,
bool is_array = false,
std::size_t size_unused = 0,
int total_dim = 0,
Expand Down Expand Up @@ -69,27 +69,27 @@ class TI_DLL_EXPORT Callable {
DataType dt_;
};

struct Ret {
struct Ret {
DataType dt;

explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) {
}
};

std::vector<Arg> args;
std::vector<Parameter> parameter_list;
std::vector<Ret> rets;

Callable();
virtual ~Callable();

int insert_scalar_arg(const DataType &dt);
int add_scalar_param(const DataType &dt);

int insert_arr_arg(const DataType &dt,
int add_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape);
int insert_texture_arg(const DataType &dt);
int add_texture_param(const DataType &dt);

int insert_ret(const DataType &dt);
int add_ret(const DataType &dt);

[[nodiscard]] virtual std::string get_name() const = 0;
};
Expand Down
12 changes: 6 additions & 6 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel)
}

void LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {
TI_ASSERT_INFO(!kernel_->args[arg_id].is_array,
TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
"Assigning scalar value to external (numpy) array argument is "
"not allowed.");

Expand All @@ -80,7 +80,7 @@ void LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {
{ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
ActionArg("val", d)});

auto dt = kernel_->args[arg_id].get_dtype();
auto dt = kernel_->parameter_list[arg_id].get_dtype();
if (dt->is_primitive(PrimitiveTypeID::f32)) {
ctx_->set_arg(arg_id, (float32)d);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
Expand Down Expand Up @@ -110,7 +110,7 @@ void LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {
}

void LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
TI_ASSERT_INFO(!kernel_->args[arg_id].is_array,
TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
"Assigning scalar value to external (numpy) array argument is "
"not allowed.");

Expand All @@ -119,7 +119,7 @@ void LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
{ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
ActionArg("val", d)});

auto dt = kernel_->args[arg_id].get_dtype();
auto dt = kernel_->parameter_list[arg_id].get_dtype();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
ctx_->set_arg(arg_id, (int32)d);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
Expand Down Expand Up @@ -156,7 +156,7 @@ void LaunchContextBuilder::set_arg_external_array_with_shape(
uint64 size,
const std::vector<int64> &shape) {
TI_ASSERT_INFO(
kernel_->args[arg_id].is_array,
kernel_->parameter_list[arg_id].is_array,
"Assigning external (numpy) array to scalar argument is not allowed.");

ActionRecorder::get_instance().record(
Expand Down Expand Up @@ -188,7 +188,7 @@ void LaunchContextBuilder::set_arg_rw_texture(int arg_id, const Texture &tex) {
}

void LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) {
TI_ASSERT_INFO(!kernel_->args[arg_id].is_array,
TI_ASSERT_INFO(!kernel_->parameter_list[arg_id].is_array,
"Assigning scalar value to external (numpy) array argument is "
"not allowed.");

Expand Down
8 changes: 4 additions & 4 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ Kernel &Program::get_snode_reader(SNode *snode) {
ker.name = kernel_name;
ker.is_accessor = true;
for (int i = 0; i < snode->num_active_indices; i++)
ker.insert_scalar_arg(PrimitiveType::i32);
ker.insert_ret(snode->dt);
ker.add_scalar_param(PrimitiveType::i32);
ker.add_ret(snode->dt);
return ker;
}

Expand All @@ -393,8 +393,8 @@ Kernel &Program::get_snode_writer(SNode *snode) {
ker.name = kernel_name;
ker.is_accessor = true;
for (int i = 0; i < snode->num_active_indices; i++)
ker.insert_scalar_arg(PrimitiveType::i32);
ker.insert_scalar_arg(snode->dt);
ker.add_scalar_param(PrimitiveType::i32);
ker.add_scalar_param(snode->dt);
return ker;
}

Expand Down
16 changes: 8 additions & 8 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,10 @@ void export_lang(py::module &m) {
// TODO(#2193): Also apply to @ti.func?
self->no_activate.push_back(snode);
})
.def("insert_scalar_arg", &Kernel::insert_scalar_arg)
.def("insert_arr_arg", &Kernel::insert_arr_arg)
.def("insert_texture_arg", &Kernel::insert_texture_arg)
.def("insert_ret", &Kernel::insert_ret)
.def("add_scalar_param", &Kernel::add_scalar_param)
.def("add_arr_param", &Kernel::add_arr_param)
.def("add_texture_param", &Kernel::add_texture_param)
.def("add_ret", &Kernel::add_ret)
.def("get_ret_int", &Kernel::get_ret_int)
.def("get_ret_uint", &Kernel::get_ret_uint)
.def("get_ret_float", &Kernel::get_ret_float)
Expand Down Expand Up @@ -705,10 +705,10 @@ void export_lang(py::module &m) {
.def("set_extra_arg_int", &LaunchContextBuilder::set_extra_arg_int);

py::class_<Function>(m, "Function")
.def("insert_scalar_arg", &Function::insert_scalar_arg)
.def("insert_arr_arg", &Function::insert_arr_arg)
.def("insert_texture_arg", &Function::insert_texture_arg)
.def("insert_ret", &Function::insert_ret)
.def("add_scalar_param", &Function::add_scalar_param)
.def("add_arr_param", &Function::add_arr_param)
.def("add_texture_param", &Function::add_texture_param)
.def("add_ret", &Function::add_ret)
.def("set_function_body",
py::overload_cast<const std::function<void()> &>(
&Function::set_function_body))
Expand Down
4 changes: 2 additions & 2 deletions taichi/runtime/llvm/launch_arg_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ bool LlvmLaunchArgInfo::operator==(const LlvmLaunchArgInfo &other) const {

std::vector<LlvmLaunchArgInfo> infer_launch_args(const Kernel *kernel) {
std::vector<LlvmLaunchArgInfo> res;
res.reserve(kernel->args.size());
for (const auto &a : kernel->args) {
res.reserve(kernel->parameter_list.size());
for (const auto &a : kernel->parameter_list) {
res.push_back(LlvmLaunchArgInfo{a.is_array});
}
return res;
Expand Down
4 changes: 2 additions & 2 deletions taichi/runtime/metal/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ std::string KernelAttributes::debug_string() const {

KernelContextAttributes::KernelContextAttributes(const Kernel &kernel)
: ctx_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) {
arg_attribs_vec_.reserve(kernel.args.size());
for (const auto &ka : kernel.args) {
arg_attribs_vec_.reserve(kernel.parameter_list.size());
for (const auto &ka : kernel.parameter_list) {
ArgAttributes ma;
ma.dt = to_metal_type(ka.get_element_type());
const size_t dt_bytes = metal_data_type_bytes(ma.dt);
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class ConstantFold : public BasicStmtVisitor {
};

auto ker = std::make_unique<Kernel>(*program, func, kernel_name);
ker->insert_ret(id.ret);
ker->insert_scalar_arg(id.lhs);
ker->add_ret(id.ret);
ker->add_scalar_param(id.lhs);
if (id.is_binary)
ker->insert_scalar_arg(id.rhs);
ker->add_scalar_param(id.rhs);
ker->is_evaluator = true;

auto *ker_ptr = ker.get();
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class Inliner : public BasicStmtVisitor {
void visit(FuncCallStmt *stmt) override {
auto *func = stmt->func;
TI_ASSERT(func);
TI_ASSERT(func->args.size() == stmt->args.size());
TI_ASSERT(func->parameter_list.size() == stmt->args.size());
TI_ASSERT(func->ir->is<Block>());
TI_ASSERT(func->rets.size() <= 1);
auto inlined_ir = irpass::analysis::clone(func->ir.get());
if (!func->args.empty()) {
if (!func->parameter_list.empty()) {
irpass::replace_statements(
inlined_ir.get(),
/*filter=*/[&](Stmt *s) { return s->is<ArgLoadStmt>(); },
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/aot/dx12/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace fs = std::filesystem;

kernel_simple_ret =
std::make_unique<Kernel>(program, builder.extract_ir(), "simple_ret");
kernel_simple_ret->insert_ret(PrimitiveType::f32);
kernel_simple_ret->add_ret(PrimitiveType::f32);
}

{
Expand Down Expand Up @@ -94,7 +94,7 @@ namespace fs = std::filesystem;
builder.create_return(builder.create_local_load(sum));

kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
kernel_ret->insert_ret(PrimitiveType::i32);
kernel_ret->add_ret(PrimitiveType::i32);
}

aot_builder->add("simple_ret", kernel_simple_ret.get());
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ TEST(IRBuilder, ExternalPtr) {
builder.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
auto block = builder.extract_ir();
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {1});
ker->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array_with_shape(
/*arg_id=*/0, (uint64)array.get(), size, {size});
Expand Down Expand Up @@ -171,7 +171,7 @@ TEST(IRBuilder, AtomicOp) {
builder.create_atomic_add(a0ptr, one); // a[0] += 1
auto block = builder.extract_ir();
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {1});
ker->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array_with_shape(
/*arg_id=*/0, (uint64)array.get(), size, {size});
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/ir/ndarray_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
}
auto block = builder1.extract_ir();
auto ker1 = std::make_unique<Kernel>(*prog, std::move(block), "ker1");
ker1->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {1});
ker1->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
return ker1;
}

Expand All @@ -39,8 +39,8 @@ std::unique_ptr<Kernel> setup_kernel2(Program *prog) {
}
auto block2 = builder2.extract_ir();
auto ker2 = std::make_unique<Kernel>(*prog, std::move(block2), "ker2");
ker2->insert_arr_arg(get_data_type<int>(), /*total_dim=*/1, {1});
ker2->insert_scalar_arg(get_data_type<int>());
ker2->add_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker2->add_scalar_param(get_data_type<int>());
return ker2;
}
} // namespace taichi::lang
Loading

0 comments on commit b943d3d

Please sign in to comment.