Skip to content

Commit

Permalink
[type] [refactor] Decouple quant from SNode 6/n: Rewrite extract_quant_float() without SNode (#5448)
Browse files Browse the repository at this point in the history

* [type] [refactor] Rewrite extract_quant_float() without SNode

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] committed Jul 19, 2022
1 parent 0b5794b commit 2e1675e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 132 deletions.
21 changes: 7 additions & 14 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,21 +535,14 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
return true; // on CUDA, pass the argument by value
}

llvm::Value *create_intrinsic_load(const DataType &dtype,
llvm::Value *data_ptr) override {
// Issue an CUDA "__ldg" instruction so that data are cached in
// the CUDA read-only data cache.
auto llvm_dtype = llvm_type(dtype);
auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0);
llvm::Intrinsic::ID intrin;
if (is_real(dtype)) {
intrin = llvm::Intrinsic::nvvm_ldg_global_f;
} else {
intrin = llvm::Intrinsic::nvvm_ldg_global_i;
}
llvm::Value *create_intrinsic_load(llvm::Value *ptr,
llvm::Type *ty) override {
// Issue an "__ldg" instruction to cache data in the read-only data cache.
auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f
: llvm::Intrinsic::nvvm_ldg_global_i;
return builder->CreateIntrinsic(
intrin, {llvm_dtype, llvm_dtype_ptr},
{data_ptr, tlctx->get_constant(data_type_size(dtype))});
intrin, {ty, llvm::PointerType::get(ty, 0)},
{ptr, tlctx->get_constant(ty->getScalarSizeInBits())});
}

void visit(GlobalLoadStmt *stmt) override {
Expand Down
34 changes: 15 additions & 19 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1416,8 +1416,8 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
}
}

llvm::Value *CodeGenLLVM::create_intrinsic_load(const DataType &dtype,
llvm::Value *data_ptr) {
llvm::Value *CodeGenLLVM::create_intrinsic_load(llvm::Value *ptr,
llvm::Type *ty) {
TI_NOT_IMPLEMENTED;
}

Expand All @@ -1428,24 +1428,28 @@ void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
auto get_ch = stmt->src->as<GetChStmt>();
auto physical_type = get_ch->input_snode->physical_type;
auto physical_type = llvm_type(get_ch->input_snode->physical_type);
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = should_cache_as_read_only
? create_intrinsic_load(byte_ptr, physical_type)
: builder->CreateLoad(physical_type, byte_ptr);
if (auto qit = val_type->cast<QuantIntType>()) {
llvm_val[stmt] =
load_quant_int(ptr, qit, physical_type, should_cache_as_read_only);
llvm_val[stmt] = extract_quant_int(physical_value, bit_offset, qit);
} else if (auto qfxt = val_type->cast<QuantFixedType>()) {
llvm_val[stmt] =
load_quant_fixed(ptr, qfxt, physical_type, should_cache_as_read_only);
qit = qfxt->get_digits_type()->as<QuantIntType>();
auto digits = extract_quant_int(physical_value, bit_offset, qit);
llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
} else {
TI_ASSERT(val_type->is<QuantFloatType>());
TI_ASSERT(get_ch->input_snode->dt->is<BitStructType>());
llvm_val[stmt] = load_quant_float(
ptr, get_ch->input_snode->dt->as<BitStructType>(),
get_ch->output_snode->id_in_bit_struct, should_cache_as_read_only);
llvm_val[stmt] = extract_quant_float(
physical_value, get_ch->input_snode->dt->as<BitStructType>(),
get_ch->output_snode->id_in_bit_struct);
}
} else {
// Byte pointer case.
if (should_cache_as_read_only) {
llvm_val[stmt] = create_intrinsic_load(stmt->ret_type, ptr);
llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type));
} else {
llvm_val[stmt] =
builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr);
Expand Down Expand Up @@ -1610,14 +1614,6 @@ std::tuple<llvm::Value *, llvm::Value *> CodeGenLLVM::load_bit_ptr(
return std::make_tuple(byte_ptr, bit_offset);
}

llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *bit_ptr,
int bit_offset_delta) {
auto [byte_ptr, bit_offset] = load_bit_ptr(bit_ptr);
auto new_bit_offset =
builder->CreateAdd(bit_offset, tlctx->get_constant(bit_offset_delta));
return create_bit_ptr(byte_ptr, new_bit_offset);
}

void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
llvm::Value *parent = nullptr;
parent = llvm_val[stmt->input_snode];
Expand Down
34 changes: 5 additions & 29 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,46 +250,24 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void store_quant_floats_with_shared_exponents(BitStructStoreStmt *stmt);

llvm::Value *extract_quant_float(llvm::Value *local_bit_struct,
SNode *digits_snode);

virtual llvm::Value *create_intrinsic_load(const DataType &dtype,
llvm::Value *data_ptr);

llvm::Value *load_quant_int(llvm::Value *ptr,
QuantIntType *qit,
Type *physical_type,
bool should_cache_as_read_only);
llvm::Value *extract_quant_float(llvm::Value *physical_value,
BitStructType *bit_struct,
int digits_id);

llvm::Value *extract_quant_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
QuantIntType *qit);

llvm::Value *load_quant_fixed(llvm::Value *ptr,
QuantFixedType *qfxt,
Type *physical_type,
bool should_cache_as_read_only);

llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only);

llvm::Value *load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
bool shared_exponent);

llvm::Value *reconstruct_quant_float(llvm::Value *input_digits,
llvm::Value *input_exponent_val,
QuantFloatType *qflt,
bool shared_exponent);

virtual llvm::Value *create_intrinsic_load(llvm::Value *ptr, llvm::Type *ty);

void create_global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only);

void visit(GlobalLoadStmt *stmt) override;
Expand All @@ -308,8 +286,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

std::tuple<llvm::Value *, llvm::Value *> load_bit_ptr(llvm::Value *bit_ptr);

llvm::Value *offset_bit_ptr(llvm::Value *bit_ptr, int bit_offset_delta);

void visit(SNodeLookupStmt *stmt) override;

void visit(GetChStmt *stmt) override;
Expand Down
88 changes: 18 additions & 70 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
BitStructStoreStmt *stmt) {
// handle each exponent separately
auto snode = stmt->get_bit_struct_snode();
auto bit_struct_physical_type =
snode->dt->as<BitStructType>()->get_physical_type();
auto bit_struct = snode->dt->as<BitStructType>();
auto bit_struct_physical_type = bit_struct->get_physical_type();
auto local_bit_struct = builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(bit_struct_physical_type),
Expand All @@ -352,7 +352,8 @@ void CodeGenLLVM::store_quant_floats_with_shared_exponents(
input != stmt->ch_ids.end()) {
floats.push_back(llvm_val[stmt->values[input - stmt->ch_ids.begin()]]);
} else {
floats.push_back(extract_quant_float(local_bit_struct, user));
floats.push_back(
extract_quant_float(local_bit_struct, bit_struct, ch_id));
}
}
// convert to i32 for bit operations
Expand Down Expand Up @@ -479,34 +480,21 @@ llvm::Value *CodeGenLLVM::extract_digits_from_f32_with_shared_exponent(
return builder->CreateLShr(digits, exp_offset);
}

llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *local_bit_struct,
SNode *digits_snode) {
auto qflt = digits_snode->dt->as<QuantFloatType>();
auto exponent_type = qflt->get_exponent_type()->as<QuantIntType>();
auto digits_type = qflt->get_digits_type()->as<QuantIntType>();
auto digits = extract_quant_int(local_bit_struct,
tlctx->get_constant(digits_snode->bit_offset),
digits_type);
llvm::Value *CodeGenLLVM::extract_quant_float(llvm::Value *physical_value,
BitStructType *bit_struct,
int digits_id) {
auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
auto exponent_id = bit_struct->get_member_exponent(digits_id);
auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id);
auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id);
auto digits =
extract_quant_int(physical_value, tlctx->get_constant(digits_bit_offset),
qflt->get_digits_type()->as<QuantIntType>());
auto exponent = extract_quant_int(
local_bit_struct,
tlctx->get_constant(digits_snode->exp_snode->bit_offset), exponent_type);
return reconstruct_quant_float(digits, exponent, qflt,
digits_snode->owns_shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_int(llvm::Value *ptr,
QuantIntType *qit,
Type *physical_type,
bool should_cache_as_read_only) {
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto physical_value = should_cache_as_read_only
? create_intrinsic_load(physical_type, byte_ptr)
: builder->CreateLoad(
#ifdef TI_LLVM_15
llvm_type(physical_type),
#endif
byte_ptr);
return extract_quant_int(physical_value, bit_offset, qit);
physical_value, tlctx->get_constant(exponent_bit_offset),
qflt->get_exponent_type()->as<QuantIntType>());
return reconstruct_quant_float(digits, exponent, qflt, shared_exponent);
}

llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
Expand Down Expand Up @@ -537,15 +525,6 @@ llvm::Value *CodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
qit->get_is_signed());
}

llvm::Value *CodeGenLLVM::load_quant_fixed(llvm::Value *ptr,
QuantFixedType *qfxt,
Type *physical_type,
bool should_cache_as_read_only) {
auto digits = load_quant_int(ptr, qfxt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
return reconstruct_quant_fixed(digits, qfxt);
}

llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt) {
// Compute float(digits) * scale
Expand All @@ -561,37 +540,6 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
return builder->CreateFMul(cast, s);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
BitStructType *bit_struct,
int digits_id,
bool should_cache_as_read_only) {
auto exponent_id = bit_struct->get_member_exponent(digits_id);
auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
auto digits_bit_offset = bit_struct->get_member_bit_offset(digits_id);
auto bit_offset_delta = exponent_bit_offset - digits_bit_offset;
auto exponent_ptr = offset_bit_ptr(digits_ptr, bit_offset_delta);
auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
auto physical_type = bit_struct->get_physical_type();
auto shared_exponent = bit_struct->get_member_owns_shared_exponent(digits_id);
return load_quant_float(digits_ptr, exponent_ptr, qflt, physical_type,
should_cache_as_read_only, shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_ptr,
llvm::Value *exponent_ptr,
QuantFloatType *qflt,
Type *physical_type,
bool should_cache_as_read_only,
bool shared_exponent) {
auto digits =
load_quant_int(digits_ptr, qflt->get_digits_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
auto exponent_val = load_quant_int(
exponent_ptr, qflt->get_exponent_type()->as<QuantIntType>(),
physical_type, should_cache_as_read_only);
return reconstruct_quant_float(digits, exponent_val, qflt, shared_exponent);
}

llvm::Value *CodeGenLLVM::reconstruct_quant_float(
llvm::Value *input_digits,
llvm::Value *input_exponent_val,
Expand Down

0 comments on commit 2e1675e

Please sign in to comment.