Refactor2023: Refactor Kernel::get_ret* (bugs)
PGZXB committed Dec 5, 2022
1 parent 3577275 commit d87a8ba
Showing 18 changed files with 197 additions and 138 deletions.
4 changes: 3 additions & 1 deletion cpp_examples/run_snode.cpp
@@ -134,7 +134,9 @@ void run_snode() {

launch_kernel(&program, *kernel_init, ctx_init.get_context());
launch_kernel(&program, *kernel_ret, ctx_ret.get_context());
std::cout << program.fetch_result<int>(0) << std::endl;
std::cout << ctx_ret.get_ret<int>(program.get_compute_device(), 0)
<< std::endl;

launch_kernel(&program, *kernel_ext, ctx_ext.get_context());
for (int i = 0; i < n; i++)
std::cout << ext_arr[i] << " ";
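The example now fetches the kernel's return value through the launch context that issued the launch, keyed by the compute device, rather than through the removed Program::fetch_result. A minimal sketch of the full launch-and-fetch sequence using only calls that appear in this commit (kernel_ret is assumed to be a kernel returning a single i32, as in this example):

auto ctx_ret = kernel_ret->make_launch_context();
launch_kernel(&program, *kernel_ret, ctx_ret.get_context());
int ret = ctx_ret.get_ret<int>(program.get_compute_device(), 0);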
16 changes: 10 additions & 6 deletions python/taichi/lang/kernel_impl.py
@@ -813,21 +813,25 @@ def call_back():
if has_ret:
if id(ret_dt) in primitive_types.integer_type_ids:
if is_signed(cook_dtype(ret_dt)):
ret = t_kernel.get_ret_int(0)
print("$$$$$$--1")
launch_ctx.get_ret_int(impl.get_runtime().prog.get_compute_device(), 0)
print("$$$$$$--2")
ret = launch_ctx.get_ret_int(impl.get_runtime().prog.get_compute_device(), 0)
print("$$$$$$--3")
else:
ret = t_kernel.get_ret_uint(0)
ret = launch_ctx.get_ret_uint(impl.get_runtime().prog.get_compute_device(), 0)
elif id(ret_dt) in primitive_types.real_type_ids:
ret = t_kernel.get_ret_float(0)
ret = launch_ctx.get_ret_float(impl.get_runtime().prog.get_compute_device(), 0)
elif id(ret_dt.dtype) in primitive_types.integer_type_ids:
if is_signed(cook_dtype(ret_dt.dtype)):
it = iter(t_kernel.get_ret_int_tensor(0))
it = iter(launch_ctx.get_ret_int_tensor(impl.get_runtime().prog.get_compute_device(), 0))
else:
it = iter(t_kernel.get_ret_uint_tensor(0))
it = iter(launch_ctx.get_ret_uint_tensor(impl.get_runtime().prog.get_compute_device(), 0))
ret = Matrix([[next(it) for _ in range(ret_dt.m)]
for _ in range(ret_dt.n)],
ndim=getattr(ret_dt, 'ndim', 2))
else:
it = iter(t_kernel.get_ret_float_tensor(0))
it = iter(launch_ctx.get_ret_float_tensor(impl.get_runtime().prog.get_compute_device(), 0))
ret = Matrix([[next(it) for _ in range(ret_dt.m)]
for _ in range(ret_dt.n)],
ndim=getattr(ret_dt, 'ndim', 2))
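Each launch_ctx getter above is a LaunchContextBuilder method on the C++ side, taking the compute device and the return-slot index. A sketch of the call the signed-integer branch resolves to, assuming the usual pybind11 binding layer (not part of this commit):

Device *device = program.get_compute_device();
int64 ret = launch_ctx.get_ret_int(device, /*retNo=*/0);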
173 changes: 114 additions & 59 deletions taichi/program/kernel.cpp
@@ -201,86 +201,141 @@ void LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) {
ctx_->set_arg<uint64>(arg_id, d);
}

RuntimeContext &LaunchContextBuilder::get_context() {
kernel_->program->prepare_runtime_context(ctx_);
return *ctx_;
}

template <typename T>
T Kernel::fetch_ret(DataType dt, int i) {
if (dt->is_primitive(PrimitiveTypeID::f32)) {
return (T)program->fetch_result<float32>(i);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return (T)program->fetch_result<float64>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i32)) {
return (T)program->fetch_result<int32>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return (T)program->fetch_result<int64>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
return (T)program->fetch_result<int8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return (T)program->fetch_result<int16>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return (T)program->fetch_result<uint8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return (T)program->fetch_result<uint16>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
return (T)program->fetch_result<uint32>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
return (T)program->fetch_result<uint64>(i);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
// use f32 to interact with python
return (T)program->fetch_result<float32>(i);
} else {
TI_NOT_IMPLEMENTED
}
}

float64 Kernel::get_ret_float(int i) {
auto dt = rets[i].dt->get_compute_type();
return fetch_ret<float64>(dt, i);
// Refactor2023:FIXME: Bad smell. Use template function.
float64 LaunchContextBuilder::get_ret_float(Device *device,
unsigned retNo) const {
auto *dt = kernel_->rets[retNo].dt->get_compute_type();
return fetch_ret<float64>(dt, retNo, device, ctx_);
}

int64 Kernel::get_ret_int(int i) {
auto dt = rets[i].dt->get_compute_type();
return fetch_ret<int64>(dt, i);
int64 LaunchContextBuilder::get_ret_int(Device *device, unsigned retNo) const {
auto *dt = kernel_->rets[retNo].dt->get_compute_type();
return fetch_ret<int64>(dt, retNo, device, ctx_);
}

uint64 Kernel::get_ret_uint(int i) {
auto dt = rets[i].dt->get_compute_type();
return fetch_ret<uint64>(dt, i);
uint64 LaunchContextBuilder::get_ret_uint(Device *device,
unsigned retNo) const {
auto *dt = kernel_->rets[retNo].dt->get_compute_type();
return fetch_ret<uint64>(dt, retNo, device, ctx_);
}

std::vector<int64> Kernel::get_ret_int_tensor(int i) {
DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
int size = rets[i].dt->as<TensorType>()->get_num_elements();
std::vector<int64> LaunchContextBuilder::get_ret_int_tensor(
Device *device,
unsigned retNo) const {
auto *tensor_dt = kernel_->rets[retNo].dt->as<TensorType>();
TI_ASSERT(tensor_dt != nullptr);
DataType element_dt = tensor_dt->get_element_type();
int element_count = tensor_dt->get_num_elements();
TI_ASSERT(element_count >= 0);
std::vector<int64> res;
for (int j = 0; j < size; j++) {
res.emplace_back(fetch_ret<int64>(dt, j));
for (unsigned j = 0; j < (unsigned)element_count; ++j) {
res.push_back(fetch_ret<int64>(element_dt, j, device, ctx_));
}
return res;
}

std::vector<uint64> Kernel::get_ret_uint_tensor(int i) {
DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
int size = rets[i].dt->as<TensorType>()->get_num_elements();
std::vector<uint64> LaunchContextBuilder::get_ret_uint_tensor(
Device *device,
unsigned retNo) const {
auto *tensor_dt = kernel_->rets[retNo].dt->as<TensorType>();
TI_ASSERT(tensor_dt != nullptr);
DataType element_dt = tensor_dt->get_element_type();
int element_count = tensor_dt->get_num_elements();
TI_ASSERT(element_count >= 0);
std::vector<uint64> res;
for (int j = 0; j < size; j++) {
res.emplace_back(fetch_ret<uint64>(dt, j));
for (unsigned j = 0; j < (unsigned)element_count; ++j) {
res.push_back(fetch_ret<uint64>(element_dt, j, device, ctx_));
}
return res;
}

std::vector<float64> Kernel::get_ret_float_tensor(int i) {
DataType dt = rets[i].dt->as<TensorType>()->get_element_type();
int size = rets[i].dt->as<TensorType>()->get_num_elements();
std::vector<float64> LaunchContextBuilder::get_ret_float_tensor(
Device *device,
unsigned retNo) const {
auto *tensor_dt = kernel_->rets[retNo].dt->as<TensorType>();
TI_ASSERT(tensor_dt != nullptr);
DataType element_dt = tensor_dt->get_element_type();
int element_count = tensor_dt->get_num_elements();
TI_ASSERT(element_count >= 0);
std::vector<float64> res;
for (int j = 0; j < size; j++) {
res.emplace_back(fetch_ret<float64>(dt, j));
for (unsigned j = 0; j < (unsigned)element_count; ++j) {
res.push_back(fetch_ret<float64>(element_dt, j, device, ctx_));
}
return res;
}

RuntimeContext &LaunchContextBuilder::get_context() {
// Refactor2023:FIXME: Move to KernelLauncher
kernel_->program->prepare_runtime_context(ctx_);
return *ctx_;
}

// Local bit-reinterpret helper (same semantics as
// taichi_union_cast_with_different_sizes in the util headers).
template <typename T, typename G>
T union_cast_with_different_sizes(G g) {
union {
T t;
G g;
} u;
u.g = g;
return u.t;
}

template <typename T>
T LaunchContextBuilder::fetch_ret(DataType dt,
unsigned retNo,
Device *device,
RuntimeContext *rt_ctx) {
TI_ASSERT(device);
TI_ASSERT(rt_ctx->result_buffer);

auto *primitive_dt = dt->cast<PrimitiveType>();
if (!primitive_dt) {
TI_NOT_IMPLEMENTED;
}

#define FETCH_AND_CAST(dt_enum, dt_type) \
case dt_enum: { \
auto i = device->fetch_result_uint64(retNo, rt_ctx->result_buffer); \
return (T)union_cast_with_different_sizes<dt_type>(i); \
}

switch (primitive_dt->type) {
FETCH_AND_CAST(PrimitiveTypeID::f32, float32);
FETCH_AND_CAST(PrimitiveTypeID::f64, float64);
FETCH_AND_CAST(PrimitiveTypeID::i32, int32);
FETCH_AND_CAST(PrimitiveTypeID::i64, int64);
FETCH_AND_CAST(PrimitiveTypeID::i8, int8);
FETCH_AND_CAST(PrimitiveTypeID::i16, int16);
FETCH_AND_CAST(PrimitiveTypeID::u8, uint8);
FETCH_AND_CAST(PrimitiveTypeID::u16, uint16);
FETCH_AND_CAST(PrimitiveTypeID::u32, uint32);
FETCH_AND_CAST(PrimitiveTypeID::u64, uint64);
FETCH_AND_CAST(PrimitiveTypeID::f16, float32); // use f32 to interact with python
default:
TI_NOT_IMPLEMENTED;
}
#undef FETCH_AND_CAST
}

std::string Kernel::get_name() const {
return name;
}
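In the refactored path, every return value travels through a uint64 slot of the runtime context's result buffer: Device::fetch_result_uint64 reads slot retNo, the union cast reinterprets the stored bits as the kernel's primitive type, and the public getters widen the value to float64/int64/uint64. Tensor returns take the same path element by element, with element j read from slot j. A worked sketch for an f32 scalar return (variable names are local to this illustration):

uint64 raw = device->fetch_result_uint64(/*retNo=*/0, rt_ctx->result_buffer);
float32 bits = union_cast_with_different_sizes<float32>(raw); // reinterpret the low 32 bits
float64 ret = (float64)bits; // what get_ret_float ultimately returns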
44 changes: 24 additions & 20 deletions taichi/program/kernel.h
@@ -46,9 +46,26 @@ class LaunchContextBuilder {
// This ignores the underlying kernel's |arg_id|-th arg type.
void set_arg_raw(int arg_id, uint64 d);

template <typename T>
T get_ret(Device *device, unsigned retNo) const;

float64 get_ret_float(Device *device, unsigned retNo) const;
int64 get_ret_int(Device *device, unsigned retNo) const;
uint64 get_ret_uint(Device *device, unsigned retNo) const;
std::vector<int64> get_ret_int_tensor(Device *device, unsigned retNo) const;
std::vector<uint64> get_ret_uint_tensor(Device *device, unsigned retNo) const;
std::vector<float64> get_ret_float_tensor(Device *device,
unsigned retNo) const;

RuntimeContext &get_context();

private:
template <typename T>
static T fetch_ret(DataType dt,
unsigned retNo,
Device *device,
RuntimeContext *rt_ctx);

Kernel *kernel_;
std::unique_ptr<RuntimeContext> owned_ctx_;
// |ctx_| *almost* always points to |owned_ctx_|. However, it is possible
@@ -94,28 +111,9 @@ class TI_DLL_EXPORT Kernel : public Callable {
lowered_ = lowered;
}

// Refactor2023:FIXME: Move
// Refactor2023:FIXME: Move to KernelLauncher
LaunchContextBuilder make_launch_context();

// Refactor2023:FIXME: Move
template <typename T>
T fetch_ret(DataType dt, int i);

// Refactor2023:FIXME: Move
float64 get_ret_float(int i);

// Refactor2023:FIXME: Move
int64 get_ret_int(int i);
// Refactor2023:FIXME: Move
uint64 get_ret_uint(int i);

// Refactor2023:FIXME: Move
std::vector<int64> get_ret_int_tensor(int i);
// Refactor2023:FIXME: Move
std::vector<uint64> get_ret_uint_tensor(int i);
// Refactor2023:FIXME: Move
std::vector<float64> get_ret_float_tensor(int i);

// Refactor2023:FIXME: Pre-refactor & Remove
uint64 get_next_task_id() {
return task_counter_++;
@@ -178,4 +176,10 @@
// Refactor2023:FIXME: Remove
void launch_kernel(Program *prog, Kernel &kernel, RuntimeContext &ctx);

template <typename T>
T LaunchContextBuilder::get_ret(Device *device, unsigned retNo) const {
auto *dt = kernel_->rets[retNo].dt->get_compute_type();
return fetch_ret<T>(dt, retNo, device, ctx_);
}

} // namespace taichi::lang
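Note that the scalar getters intentionally return widened types no matter what the kernel's declared return type is; an f16 return, for instance, is fetched as f32 internally (see the FETCH_AND_CAST(PrimitiveTypeID::f16, float32) case above) and then widened. A sketch, assuming ctx is the LaunchContextBuilder of a kernel declared to return f16:

float64 v = ctx.get_ret_float(program.get_compute_device(), 0); // f16 -> f32 -> float64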
4 changes: 0 additions & 4 deletions taichi/program/program.cpp
@@ -398,10 +398,6 @@ Kernel &Program::get_snode_writer(SNode *snode) {
return ker;
}

uint64 Program::fetch_result_uint64(int i) {
return program_impl_->fetch_result_uint64(i, result_buffer);
}

void Program::finalize() {
if (finalized_) {
return;
7 changes: 0 additions & 7 deletions taichi/program/program.h
@@ -203,13 +203,6 @@ class TI_DLL_EXPORT Program {

Kernel &get_snode_writer(SNode *snode);

uint64 fetch_result_uint64(int i);

template <typename T>
T fetch_result(int i) {
return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
}

Arch get_host_arch() {
return host_arch();
}
4 changes: 0 additions & 4 deletions taichi/program/program_impl.h
@@ -150,10 +150,6 @@ class ProgramImpl {
virtual void finalize() {
}

virtual uint64 fetch_result_uint64(int i, uint64 *result_buffer) {
return result_buffer[i];
}

private:
};
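The removed ProgramImpl::fetch_result_uint64 hook is superseded by the Device-level read used by the new LaunchContextBuilder::fetch_ret; its replacement call, with the signature exactly as it appears in the FETCH_AND_CAST macro in kernel.cpp:

uint64 raw = device->fetch_result_uint64(retNo, rt_ctx->result_buffer);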

4 changes: 2 additions & 2 deletions taichi/program/snode_rw_accessors_bank.cpp
@@ -50,7 +50,7 @@ float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
lang::launch_kernel(reader_->program, *reader_, launch_ctx.get_context());
prog_->synchronize();
auto ret = reader_->get_ret_float(0);
auto ret = launch_ctx.get_ret_float(prog_->get_compute_device(), 0);
return ret;
}

@@ -70,7 +70,7 @@ int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
lang::launch_kernel(reader_->program, *reader_, launch_ctx.get_context());
prog_->synchronize();
auto ret = reader_->get_ret_int(0);
auto ret = launch_ctx.get_ret_int(prog_->get_compute_device(), 0);
return ret;
}
