Skip to content

Commit

Permalink
[refactor] Use f16 function from external lib
Browse files Browse the repository at this point in the history
ghstack-source-id: 8f9dc182227dd22a25e14de3a0db8cf10241ef9e
Pull Request resolved: taichi-dev#7622
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Mar 22, 2023
1 parent d93516e commit e2455dc
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 89 deletions.
20 changes: 5 additions & 15 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,25 +323,15 @@ TypedConstant LaunchContextBuilder::fetch_ret_impl(int offset, const Type *dt) {
auto primitive_type = dt->as<PrimitiveType>();
char *ptr = result_buffer_.get() + offset;
switch (primitive_type->type) {
#define RETURN_PRIMITIVE(type, ctype) \
case PrimitiveTypeID::type: \
#define PER_C_TYPE(type, ctype) \
case PrimitiveTypeID::type: \
return TypedConstant(*(ctype *)ptr);

RETURN_PRIMITIVE(f32, float32);
RETURN_PRIMITIVE(f64, float64);
RETURN_PRIMITIVE(i8, int8);
RETURN_PRIMITIVE(i16, int16);
RETURN_PRIMITIVE(i32, int32);
RETURN_PRIMITIVE(i64, int64);
RETURN_PRIMITIVE(u8, uint8);
RETURN_PRIMITIVE(u16, uint16);
RETURN_PRIMITIVE(u32, uint32);
RETURN_PRIMITIVE(u64, uint64);
#undef RETURN_PRIMITIVE
#include "taichi/inc/data_type_with_c_type.inc.h"
#undef PER_C_TYPE
case PrimitiveTypeID::f16: {
// first fetch the data as u16, and then convert it to f32
uint16 half = *(uint16 *)ptr;
return TypedConstant(bit::half_to_float(half));
return TypedConstant(fp16_ieee_to_fp32_value(half));
}
default:
TI_NOT_IMPLEMENTED
Expand Down
4 changes: 0 additions & 4 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ class TI_DLL_EXPORT Program {

uint64 fetch_result_uint64(int i);

TypedConstant fetch_result(char *result_buffer, int offset, const Type *dt) {
return program_impl_->fetch_result(result_buffer, offset, dt);
}

template <typename T>
T fetch_result(int i) {
return taichi_union_cast_with_different_sizes<T>(fetch_result_uint64(i));
Expand Down
6 changes: 0 additions & 6 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,6 @@ class ProgramImpl {
return result_buffer[i];
}

virtual TypedConstant fetch_result(char *result_buffer,
int offset,
const Type *dt) {
TI_NOT_IMPLEMENTED;
}

virtual std::string get_kernel_return_data_layout() {
return "";
};
Expand Down
42 changes: 0 additions & 42 deletions taichi/runtime/program_impls/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,48 +150,6 @@ class LlvmProgramImpl : public ProgramImpl {
return runtime_exec_->fetch_result_uint64(i, result_buffer);
}

TypedConstant fetch_result(char *result_buffer,
int offset,
const Type *dt) override {
if (dt->is_primitive(PrimitiveTypeID::f32)) {
return TypedConstant(
runtime_exec_->fetch_result<float32>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return TypedConstant(
runtime_exec_->fetch_result<float64>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::i32)) {
return TypedConstant(
runtime_exec_->fetch_result<int32>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return TypedConstant(
runtime_exec_->fetch_result<int64>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
return TypedConstant(
runtime_exec_->fetch_result<int8>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return TypedConstant(
runtime_exec_->fetch_result<int16>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return TypedConstant(
runtime_exec_->fetch_result<uint8>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return TypedConstant(
runtime_exec_->fetch_result<uint16>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
return TypedConstant(
runtime_exec_->fetch_result<uint32>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
return TypedConstant(
runtime_exec_->fetch_result<uint64>(result_buffer, offset));
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
// first fetch the data as u16, and then convert it to f32
uint16 half = runtime_exec_->fetch_result<uint16>(result_buffer, offset);
return TypedConstant(bit::half_to_float(half));
} else {
TI_NOT_IMPLEMENTED
}
}

template <typename T, typename... Args>
T runtime_query(const std::string &key,
uint64 *result_buffer,
Expand Down
22 changes: 0 additions & 22 deletions taichi/util/bit.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@ TI_FORCE_INLINE constexpr bool is_power_of_two(uint64 x) {
return x != 0 && (x & (x - 1)) == 0;
}

TI_FORCE_INLINE uint32 as_uint(const float32 x) {
return *(uint32 *)&x;
}

TI_FORCE_INLINE float32 as_float(const uint32 x) {
return *(float32 *)&x;
}

TI_FORCE_INLINE float32 half_to_float(const uint16 x) {
// Reference: https://stackoverflow.com/a/60047308
const uint32 e = (x & 0x7C00) >> 10; // exponent
const uint32 m = (x & 0x03FF) << 13; // mantissa
const uint32 v =
as_uint((float32)m) >>
23; // evil log2 bit hack to count leading zeros in denormalized format
return as_float(
(x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) |
((e == 0) & (m != 0)) *
((v - 37) << 23 | ((m << (150 - v)) &
0x007FE000))); // sign : normalized : denormalized
}

template <int length>
struct Bits {
static_assert(is_power_of_two(length), "length must be a power of two");
Expand Down

0 comments on commit e2455dc

Please sign in to comment.