From e2455dc6a084499f85f2aa3d32e4fee6dfcb1425 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 21 Mar 2023 17:52:47 +0800 Subject: [PATCH] [refactor] Use f16 function from external lib ghstack-source-id: 8f9dc182227dd22a25e14de3a0db8cf10241ef9e Pull Request resolved: https://github.com/taichi-dev/taichi/pull/7622 --- taichi/program/launch_context_builder.cpp | 20 +++------ taichi/program/program.h | 4 -- taichi/program/program_impl.h | 6 --- .../runtime/program_impls/llvm/llvm_program.h | 42 ------------------- taichi/util/bit.h | 22 ---------- 5 files changed, 5 insertions(+), 89 deletions(-) diff --git a/taichi/program/launch_context_builder.cpp b/taichi/program/launch_context_builder.cpp index ee12095796fed..27cfaa61bd9d6 100644 --- a/taichi/program/launch_context_builder.cpp +++ b/taichi/program/launch_context_builder.cpp @@ -323,25 +323,15 @@ TypedConstant LaunchContextBuilder::fetch_ret_impl(int offset, const Type *dt) { auto primitive_type = dt->as(); char *ptr = result_buffer_.get() + offset; switch (primitive_type->type) { -#define RETURN_PRIMITIVE(type, ctype) \ - case PrimitiveTypeID::type: \ +#define PER_C_TYPE(type, ctype) \ + case PrimitiveTypeID::type: \ return TypedConstant(*(ctype *)ptr); - - RETURN_PRIMITIVE(f32, float32); - RETURN_PRIMITIVE(f64, float64); - RETURN_PRIMITIVE(i8, int8); - RETURN_PRIMITIVE(i16, int16); - RETURN_PRIMITIVE(i32, int32); - RETURN_PRIMITIVE(i64, int64); - RETURN_PRIMITIVE(u8, uint8); - RETURN_PRIMITIVE(u16, uint16); - RETURN_PRIMITIVE(u32, uint32); - RETURN_PRIMITIVE(u64, uint64); -#undef RETURN_PRIMITIVE +#include "taichi/inc/data_type_with_c_type.inc.h" +#undef PER_C_TYPE case PrimitiveTypeID::f16: { // first fetch the data as u16, and then convert it to f32 uint16 half = *(uint16 *)ptr; - return TypedConstant(bit::half_to_float(half)); + return TypedConstant(fp16_ieee_to_fp32_value(half)); } default: TI_NOT_IMPLEMENTED diff --git a/taichi/program/program.h b/taichi/program/program.h index 88d2af13f4417..a33519362d7c2 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -185,10 +185,6 @@ class TI_DLL_EXPORT Program { uint64 fetch_result_uint64(int i); - TypedConstant fetch_result(char *result_buffer, int offset, const Type *dt) { - return program_impl_->fetch_result(result_buffer, offset, dt); - } - template T fetch_result(int i) { return taichi_union_cast_with_different_sizes(fetch_result_uint64(i)); diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index 0db486ed11891..75d13b6a2e9dc 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -156,12 +156,6 @@ class ProgramImpl { return result_buffer[i]; } - virtual TypedConstant fetch_result(char *result_buffer, - int offset, - const Type *dt) { - TI_NOT_IMPLEMENTED; - } - virtual std::string get_kernel_return_data_layout() { return ""; }; diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index 678c64f01c484..e198bb7c7dca3 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -150,48 +150,6 @@ class LlvmProgramImpl : public ProgramImpl { return runtime_exec_->fetch_result_uint64(i, result_buffer); } - TypedConstant fetch_result(char *result_buffer, - int offset, - const Type *dt) override { - if (dt->is_primitive(PrimitiveTypeID::f32)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::i32)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::i64)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::i8)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::i16)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::u8)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::u16)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::u32)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::u64)) { - return TypedConstant( - runtime_exec_->fetch_result(result_buffer, offset)); - } else if (dt->is_primitive(PrimitiveTypeID::f16)) { - // first fetch the data as u16, and then convert it to f32 - uint16 half = runtime_exec_->fetch_result(result_buffer, offset); - return TypedConstant(bit::half_to_float(half)); - } else { - TI_NOT_IMPLEMENTED - } - } - template T runtime_query(const std::string &key, uint64 *result_buffer, diff --git a/taichi/util/bit.h b/taichi/util/bit.h index a33aae5cf9378..6d98026136f52 100644 --- a/taichi/util/bit.h +++ b/taichi/util/bit.h @@ -25,28 +25,6 @@ TI_FORCE_INLINE constexpr bool is_power_of_two(uint64 x) { return x != 0 && (x & (x - 1)) == 0; } -TI_FORCE_INLINE uint32 as_uint(const float32 x) { - return *(uint32 *)&x; -} - -TI_FORCE_INLINE float32 as_float(const uint32 x) { - return *(float32 *)&x; -} - -TI_FORCE_INLINE float32 half_to_float(const uint16 x) { - // Reference: https://stackoverflow.com/a/60047308 - const uint32 e = (x & 0x7C00) >> 10; // exponent - const uint32 m = (x & 0x03FF) << 13; // mantissa - const uint32 v = - as_uint((float32)m) >> - 23; // evil log2 bit hack to count leading zeros in denormalized format - return as_float( - (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | - ((e == 0) & (m != 0)) * - ((v - 37) << 23 | ((m << (150 - v)) & - 0x007FE000))); // sign : normalized : denormalized -} - template struct Bits { static_assert(is_power_of_two(length), "length must be a power of two");