
Commit 46fd43d

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 25, 2022
1 parent 6f898e3 commit 46fd43d
Showing 8 changed files with 81 additions and 68 deletions.
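All of the hunks below are mechanical re-wraps of declarations, constructor initializer lists, and comments produced by the repository's C++ formatter (clang-format, judging by the style); no behavior changes. The recurring pattern: when a signature no longer fits the column limit, its parameters are either re-aligned under the opening parenthesis or, when even that would overflow, the line is broken directly after the parenthesis and each parameter gets its own indented line. Below is a minimal, declaration-only sketch of the second case, using hypothetical names rather than code from this commit, and assuming the settings this repository appears to use (an 80-column limit and one parameter per line once wrapping is needed):

#include <memory>

// Stand-in forward declarations so the sketch is self-contained; the real
// classes live elsewhere in the Taichi codebase.
class Kernel;
class IRNode;
class TaskCodeGenLLVM;

// Hand-written form: the whole declaration sits on one line and overflows
// the assumed 80-column limit.
std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm_for_cpu_offloaded_task(Kernel *kernel, IRNode *ir);

// What clang-format produces under the assumed settings: break right after
// '(' and indent each parameter by four spaces on its own line.
std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm_for_cpu_offloaded_task(
    Kernel *kernel,
    IRNode *ir);

If the hooks are configured the way pre-commit.ci usually expects, the same fixes can be reproduced locally with "pre-commit run --all-files".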
10 changes: 6 additions & 4 deletions taichi/codegen/cpu/codegen_cpu.cpp
@@ -225,8 +225,9 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {

#ifdef TI_WITH_LLVM
// static
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCPU::make_codegen_llvm(Kernel *kernel,
IRNode *ir) {
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCPU::make_codegen_llvm(
Kernel *kernel,
IRNode *ir) {
return std::make_unique<TaskCodeGenCPU>(kernel, ir);
}

@@ -274,8 +275,9 @@ FunctionType CPUModuleToFunctionConverter::convert(
};
}

LLVMCompiledData KernelCodeGenCPU::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
LLVMCompiledData KernelCodeGenCPU::modulegen(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenCPU gen(kernel, stmt);
return gen.run_compilation();
}
5 changes: 3 additions & 2 deletions taichi/codegen/cpu/codegen_cpu.h
@@ -11,13 +11,14 @@ TLANG_NAMESPACE_BEGIN

class KernelCodeGenCPU : public KernelCodeGen {
public:
KernelCodeGenCPU(Kernel *kernel, IRNode *ir = nullptr) : KernelCodeGen(kernel, ir) {
KernelCodeGenCPU(Kernel *kernel, IRNode *ir = nullptr)
: KernelCodeGen(kernel, ir) {
}

// TODO: Stop defining this macro guards in the headers
#ifdef TI_WITH_LLVM
static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
IRNode *ir);
IRNode *ir);

bool supports_offline_cache() const override {
return true;
9 changes: 5 additions & 4 deletions taichi/codegen/cuda/codegen_cuda.cpp
@@ -699,8 +699,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {

#ifdef TI_WITH_LLVM
// static
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(Kernel *kernel,
IRNode *ir) {
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(
Kernel *kernel,
IRNode *ir) {
return std::make_unique<TaskCodeGenCUDA>(kernel, ir);
}
#endif // TI_WITH_LLVM
@@ -823,8 +824,8 @@ FunctionType CUDAModuleToFunctionConverter::convert(

} else if (arr_sz > 0) {
// arg_buffers[i] is a DeviceAllocation*
// TODO: Unwraps DeviceAllocation* can be done at TaskCodeGenLLVM since
// it's shared by cpu and cuda.
// TODO: Unwraps DeviceAllocation* can be done at TaskCodeGenLLVM
// since it's shared by cpu and cuda.
DeviceAllocation *ptr =
static_cast<DeviceAllocation *>(arg_buffers[i]);
device_buffers[i] = executor->get_ndarray_alloc_info_ptr(*ptr);
2 changes: 1 addition & 1 deletion taichi/codegen/cuda/codegen_cuda.h
@@ -16,7 +16,7 @@ class KernelCodeGenCUDA : public KernelCodeGen {
// TODO: Stop defining this macro guards in the headers
#ifdef TI_WITH_LLVM
static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
IRNode *ir);
IRNode *ir);
#endif // TI_WITH_LLVM

bool supports_offline_cache() const override {
56 changes: 31 additions & 25 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -260,8 +260,8 @@ std::unique_ptr<RuntimeObject> TaskCodeGenLLVM::emit_struct_meta_object(
}

void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
llvm::Value *node_meta,
SNode *snode) {
llvm::Value *node_meta,
SNode *snode) {
RuntimeObject common("StructMeta", this, builder.get(), node_meta);
std::size_t element_size;
if (snode->type == SNodeType::dense) {
@@ -312,8 +312,8 @@ void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
}

TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
IRNode *ir,
std::unique_ptr<llvm::Module> &&module)
IRNode *ir,
std::unique_ptr<llvm::Module> &&module)
// TODO: simplify LLVMModuleBuilder ctor input
: LLVMModuleBuilder(
module == nullptr ? get_llvm_program(kernel->program)
@@ -738,8 +738,8 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
}

llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
DataType dt,
llvm::Value *value) {
DataType dt,
llvm::Value *value) {
if (!arch_is_cpu(kernel->arch)) {
TI_WARN("print not supported on arch {}", arch_name(kernel->arch));
return nullptr;
@@ -759,7 +759,8 @@ llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
return create_call(runtime_printf, func_type_func->getFunctionType(), args);
}

llvm::Value *TaskCodeGenLLVM::create_print(std::string tag, llvm::Value *value) {
llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
llvm::Value *value) {
if (value->getType() == llvm::Type::getFloatTy(*llvm_context))
return create_print(
tag,
@@ -923,8 +924,8 @@ void TaskCodeGenLLVM::visit(WhileStmt *stmt) {
}

llvm::Value *TaskCodeGenLLVM::cast_pointer(llvm::Value *val,
std::string dest_ty_name,
int addr_space) {
std::string dest_ty_name,
int addr_space) {
return builder->CreateBitCast(
val, llvm::PointerType::get(get_runtime_type(dest_ty_name), addr_space));
}
@@ -949,13 +950,14 @@ void TaskCodeGenLLVM::emit_gc(OffloadedStmt *stmt) {
}

llvm::Value *TaskCodeGenLLVM::create_call(llvm::Function *func,
llvm::ArrayRef<llvm::Value *> args) {
llvm::ArrayRef<llvm::Value *> args) {
return create_call(func, func->getFunctionType(), args);
}

llvm::Value *TaskCodeGenLLVM::create_call(llvm::Value *func,
llvm::FunctionType *func_ty,
llvm::ArrayRef<llvm::Value *> args_arr) {
llvm::Value *TaskCodeGenLLVM::create_call(
llvm::Value *func,
llvm::FunctionType *func_ty,
llvm::ArrayRef<llvm::Value *> args_arr) {
std::vector<llvm::Value *> args = args_arr;
check_func_call_signature(func_ty, func->getName(), args, builder.get());
#ifdef TI_LLVM_15
@@ -966,7 +968,7 @@ llvm::Value *TaskCodeGenLLVM::create_call(llvm::Value *func,
}

llvm::Value *TaskCodeGenLLVM::create_call(std::string func_name,
llvm::ArrayRef<llvm::Value *> args) {
llvm::ArrayRef<llvm::Value *> args) {
auto func = get_runtime_function(func_name);
return create_call(func, args);
}
@@ -1062,7 +1064,8 @@ void TaskCodeGenLLVM::visit(RangeForStmt *for_stmt) {
create_naive_range_for(for_stmt);
}

llvm::Value *TaskCodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {
llvm::Value *TaskCodeGenLLVM::bitcast_from_u64(llvm::Value *val,
DataType type) {
llvm::Type *dest_ty = nullptr;
TI_ASSERT(!type->is<PointerType>());
if (auto qit = type->cast<QuantIntType>()) {
@@ -1442,12 +1445,12 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
}

llvm::Value *TaskCodeGenLLVM::create_intrinsic_load(llvm::Value *ptr,
llvm::Type *ty) {
llvm::Type *ty) {
TI_NOT_IMPLEMENTED;
}

void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
bool should_cache_as_read_only) {
bool should_cache_as_read_only) {
auto ptr = llvm_val[stmt->src];
auto ptr_type = stmt->src->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
@@ -1527,10 +1530,11 @@ std::string TaskCodeGenLLVM::get_runtime_snode_name(SNode *snode) {
}
}

llvm::Value *TaskCodeGenLLVM::call(SNode *snode,
llvm::Value *node_ptr,
const std::string &method,
const std::vector<llvm::Value *> &arguments) {
llvm::Value *TaskCodeGenLLVM::call(
SNode *snode,
llvm::Value *node_ptr,
const std::string &method,
const std::vector<llvm::Value *> &arguments) {
auto prefix = get_runtime_snode_name(snode);
auto s = emit_struct_meta(snode);
auto s_ptr =
@@ -1583,7 +1587,7 @@ void TaskCodeGenLLVM::visit(LinearizeStmt *stmt) {
void TaskCodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *TaskCodeGenLLVM::create_bit_ptr(llvm::Value *byte_ptr,
llvm::Value *bit_offset) {
llvm::Value *bit_offset) {
// 1. define the bit pointer struct (X=8/16/32/64)
// struct bit_pointer_X {
// iX* byte_ptr;
@@ -1798,7 +1802,7 @@ void TaskCodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) {
}

std::string TaskCodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt,
std::string suffix) {
std::string suffix) {
current_loop_reentry = nullptr;
current_while_after_loop = nullptr;

@@ -1887,7 +1891,8 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
return std::tuple(begin, end);
}

void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
bool spmd) {
using namespace llvm;
// TODO: instead of constructing tons of LLVM IR, writing the logic in
// runtime.cpp may be a cleaner solution. See
@@ -2616,7 +2621,8 @@ llvm::Value *TaskCodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
return xlogue;
}

llvm::Value *TaskCodeGenLLVM::create_mesh_xlogue(std::unique_ptr<Block> &block) {
llvm::Value *TaskCodeGenLLVM::create_mesh_xlogue(
std::unique_ptr<Block> &block) {
llvm::Value *xlogue;

auto xlogue_type = get_mesh_xlogue_function_type();
7 changes: 4 additions & 3 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -25,7 +25,8 @@ class FunctionCreationGuard {
llvm::BasicBlock *old_entry, *allocas, *entry, *old_final, *final;
llvm::IRBuilder<>::InsertPoint ip;

FunctionCreationGuard(TaskCodeGenLLVM *mb, std::vector<llvm::Type *> arguments);
FunctionCreationGuard(TaskCodeGenLLVM *mb,
std::vector<llvm::Type *> arguments);

~FunctionCreationGuard();
};
@@ -66,8 +67,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
using LLVMModuleBuilder::call;

TaskCodeGenLLVM(Kernel *kernel,
IRNode *ir = nullptr,
std::unique_ptr<llvm::Module> &&module = nullptr);
IRNode *ir = nullptr,
std::unique_ptr<llvm::Module> &&module = nullptr);

Arch current_arch() {
return kernel->arch;
55 changes: 28 additions & 27 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
@@ -16,10 +16,10 @@ inline void update_mask(uint64 &mask, uint32 num_bits, uint32 offset) {
} // namespace

llvm::Value *TaskCodeGenLLVM::atomic_add_quant_int(llvm::Value *ptr,
llvm::Type *physical_type,
QuantIntType *qit,
llvm::Value *value,
bool value_is_signed) {
llvm::Type *physical_type,
QuantIntType *qit,
llvm::Value *value,
bool value_is_signed) {
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
return create_call(
fmt::format("atomic_add_partial_bits_b{}",
@@ -29,9 +29,9 @@ llvm::Value *TaskCodeGenLLVM::atomic_add_quant_int(llvm::Value *ptr,
}

llvm::Value *TaskCodeGenLLVM::atomic_add_quant_fixed(llvm::Value *ptr,
llvm::Type *physical_type,
QuantFixedType *qfxt,
llvm::Value *value) {
llvm::Type *physical_type,
QuantFixedType *qfxt,
llvm::Value *value) {
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
auto qit = qfxt->get_digits_type()->as<QuantIntType>();
auto val_store = to_quant_fixed(value, qfxt);
@@ -43,7 +43,7 @@ llvm::Value *TaskCodeGenLLVM::atomic_add_quant_fixed(llvm::Value *ptr,
}

llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,
QuantFixedType *qfxt) {
QuantFixedType *qfxt) {
// Compute int(real * (1.0 / scale) + 0.5)
auto compute_type = qfxt->get_compute_type();
auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()),
Expand All @@ -65,10 +65,10 @@ llvm::Value *TaskCodeGenLLVM::to_quant_fixed(llvm::Value *real,
}

void TaskCodeGenLLVM::store_quant_int(llvm::Value *ptr,
llvm::Type *physical_type,
QuantIntType *qit,
llvm::Value *value,
bool atomic) {
llvm::Type *physical_type,
QuantIntType *qit,
llvm::Value *value,
bool atomic) {
auto [byte_ptr, bit_offset] = load_bit_ptr(ptr);
// TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers.
// Try to support 8/16-bit physical types.
@@ -79,20 +79,20 @@ void TaskCodeGenLLVM::store_quant_int(llvm::Value *ptr,
}

void TaskCodeGenLLVM::store_quant_fixed(llvm::Value *ptr,
llvm::Type *physical_type,
QuantFixedType *qfxt,
llvm::Value *value,
bool atomic) {
llvm::Type *physical_type,
QuantFixedType *qfxt,
llvm::Value *value,
bool atomic) {
store_quant_int(ptr, physical_type,
qfxt->get_digits_type()->as<QuantIntType>(),
to_quant_fixed(value, qfxt), atomic);
}

void TaskCodeGenLLVM::store_masked(llvm::Value *ptr,
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic) {
llvm::Type *ty,
uint64 mask,
llvm::Value *value,
bool atomic) {
if (!mask) {
// do not store anything
return;
@@ -110,7 +110,7 @@ void TaskCodeGenLLVM::store_masked(llvm::Value *ptr,
}

llvm::Value *TaskCodeGenLLVM::get_exponent_offset(llvm::Value *exponent,
QuantFloatType *qflt) {
QuantFloatType *qflt) {
// Since we have fewer bits in the exponent type than in f32, an
// offset is necessary to make sure the stored exponent values are
// representable by the exponent quant int type.
@@ -392,7 +392,8 @@ llvm::Value *TaskCodeGenLLVM::extract_exponent_from_f32(llvm::Value *f) {
return builder->CreateAnd(exp_bits, tlctx->get_constant((1 << 8) - 1));
}

llvm::Value *TaskCodeGenLLVM::extract_digits_from_f32(llvm::Value *f, bool full) {
llvm::Value *TaskCodeGenLLVM::extract_digits_from_f32(llvm::Value *f,
bool full) {
TI_ASSERT(f->getType() == llvm::Type::getFloatTy(*llvm_context));
f = builder->CreateBitCast(f, llvm::Type::getInt32Ty(*llvm_context));
auto digits = builder->CreateAnd(f, tlctx->get_constant((1 << 23) - 1));
@@ -428,8 +429,8 @@ llvm::Value *TaskCodeGenLLVM::extract_digits_from_f32_with_shared_exponent(
}

llvm::Value *TaskCodeGenLLVM::extract_quant_float(llvm::Value *physical_value,
BitStructType *bit_struct,
int digits_id) {
BitStructType *bit_struct,
int digits_id) {
auto qflt = bit_struct->get_member_type(digits_id)->as<QuantFloatType>();
auto exponent_id = bit_struct->get_member_exponent(digits_id);
auto exponent_bit_offset = bit_struct->get_member_bit_offset(exponent_id);
@@ -445,8 +446,8 @@ llvm::Value *TaskCodeGenLLVM::extract_quant_float(llvm::Value *physical_value,
}

llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
QuantIntType *qit) {
llvm::Value *bit_offset,
QuantIntType *qit) {
auto physical_type = physical_value->getType();
// bit shifting
// first left shift `physical_type - (offset + num_bits)`
@@ -473,7 +474,7 @@ llvm::Value *TaskCodeGenLLVM::extract_quant_int(llvm::Value *physical_value,
}

llvm::Value *TaskCodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt) {
QuantFixedType *qfxt) {
// Compute float(digits) * scale
llvm::Value *cast = nullptr;
auto compute_type = qfxt->get_compute_type()->as<PrimitiveType>();
5 changes: 3 additions & 2 deletions taichi/codegen/wasm/codegen_wasm.cpp
@@ -255,8 +255,9 @@ FunctionType KernelCodeGenWASM::codegen() {
};
}

LLVMCompiledData KernelCodeGenWASM::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
LLVMCompiledData KernelCodeGenWASM::modulegen(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
bool init_flag = module == nullptr;
std::vector<OffloadedTask> name_list;
auto gen = std::make_unique<TaskCodeGenWASM>(kernel, ir, std::move(module));
