diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp
index 623cbc58491aa..b6d11d8b90b0a 100644
--- a/cpp_examples/autograd.cpp
+++ b/cpp_examples/autograd.cpp
@@ -42,7 +42,7 @@ void autograd() {
   using namespace lang;
 
   auto program = Program(Arch::x64);
-  const auto &config = program.this_thread_config();
+  const auto &config = program.compile_config();
   int n = 10;
 
   program.materialize_runtime();
diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp
index 9c74b1a726353..ee40111e1a659 100644
--- a/cpp_examples/run_snode.cpp
+++ b/cpp_examples/run_snode.cpp
@@ -40,7 +40,7 @@ void run_snode() {
   using namespace taichi;
   using namespace lang;
   auto program = Program(Arch::x64);
-  const auto &config = program.this_thread_config();
+  const auto &config = program.compile_config();
   /*CompileConfig config_print_ir;
   config_print_ir.print_ir = true;
   prog_.config = config_print_ir;*/  // print_ir = True
diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp
index 9fc992f2d058f..38b1ce8f23031 100644
--- a/taichi/aot/graph_data.cpp
+++ b/taichi/aot/graph_data.cpp
@@ -94,7 +94,7 @@ void CompiledGraph::run(
       TI_ASSERT(dispatch.ti_kernel);
       lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx);
       auto *ker = dispatch.ti_kernel;
-      ker->operator()(ker->program->this_thread_config(), launch_ctx);
+      ker->operator()(ker->program->compile_config(), launch_ctx);
     }
   }
 }
diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp
index da2c541fdf1de..f8ef27bbdfc18 100644
--- a/taichi/ir/expr.cpp
+++ b/taichi/ir/expr.cpp
@@ -14,7 +14,7 @@ DataType Expr::get_ret_type() const {
   return expr->ret_type;
 }
 
-void Expr::type_check(CompileConfig *config) {
+void Expr::type_check(const CompileConfig *config) {
   expr->type_check(config);
 }
 
diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h
index b8058c7f1ad22..9b59dc036ae47 100644
--- a/taichi/ir/expr.h
+++ b/taichi/ir/expr.h
@@ -101,7 +101,7 @@ class Expr {
 
   DataType get_ret_type() const;
 
-  void type_check(CompileConfig *config);
+  void type_check(const CompileConfig *config);
 };
 
 // Value cast
diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h
index e0dfb1f328632..918cb6a9f9032 100644
--- a/taichi/ir/expression.h
+++ b/taichi/ir/expression.h
@@ -41,7 +41,7 @@ class Expression {
     stmt = nullptr;
   }
 
-  virtual void type_check(CompileConfig *config) = 0;
+  virtual void type_check(const CompileConfig *config) = 0;
 
   virtual void accept(ExpressionVisitor *visitor) = 0;
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index a576418c24a7c..46bf005fdf96f 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -108,7 +108,7 @@ void FrontendForStmt::add_loop_var(const Expr &loop_var) {
   loop_var.expr->ret_type = PrimitiveType::i32;
 }
 
-void ArgLoadExpression::type_check(CompileConfig *) {
+void ArgLoadExpression::type_check(const CompileConfig *) {
   TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
                  "Invalid dt [{}] for ArgLoadExpression", dt->to_string());
   ret_type = dt;
@@ -120,7 +120,7 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void TexturePtrExpression::type_check(CompileConfig *config) {
+void TexturePtrExpression::type_check(const CompileConfig *config) {
 }
 
 void TexturePtrExpression::flatten(FlattenContext *ctx) {
@@ -130,7 +130,7 @@ void TexturePtrExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void RandExpression::type_check(CompileConfig *) {
+void RandExpression::type_check(const CompileConfig *) {
   TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
                  "Invalid dt [{}] for RandExpression", dt->to_string());
   ret_type = dt;
@@ -142,7 +142,7 @@ void RandExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void UnaryOpExpression::type_check(CompileConfig *config) {
+void UnaryOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(operand);
   TI_ASSERT(config != nullptr);
@@ -238,7 +238,7 @@ std::tuple unify_binop_operands(const Expr &e1, const Expr &e2) {
   }
 }
 
-void BinaryOpExpression::type_check(CompileConfig *config) {
+void BinaryOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(lhs);
   TI_ASSERT_TYPE_CHECKED(rhs);
   auto lhs_type = lhs->ret_type;
@@ -426,7 +426,7 @@ static std::tuple unify_ternaryop_operands(const Expr &e1,
                            to_broadcast_tensor(e3, target_dtype));
 }
 
-void TernaryOpExpression::type_check(CompileConfig *config) {
+void TernaryOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(op1);
   TI_ASSERT_TYPE_CHECKED(op2);
   TI_ASSERT_TYPE_CHECKED(op3);
@@ -509,7 +509,7 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) {
   stmt->ret_type = ret_type;
 }
 
-void InternalFuncCallExpression::type_check(CompileConfig *) {
+void InternalFuncCallExpression::type_check(const CompileConfig *) {
   for (auto &arg : args) {
     TI_ASSERT_TYPE_CHECKED(arg);
     // no arg type compatibility check for now due to lack of specification
@@ -666,7 +666,7 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
                            shape, tb);
 }
 
-void MatrixExpression::type_check(CompileConfig *config) {
+void MatrixExpression::type_check(const CompileConfig *config) {
   TI_ASSERT(dt->as<TensorType>()->get_num_elements() == elements.size());
 
   for (auto &arg : elements) {
@@ -754,7 +754,7 @@ static void field_validation(FieldExpression *field_expr, int index_dim) {
   }
 }
 
-void IndexExpression::type_check(CompileConfig *) {
+void IndexExpression::type_check(const CompileConfig *) {
   // TODO: Change to type-based solution
   // Currently, dimension compatibility check happens in Python
   TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape),
@@ -847,7 +847,7 @@ void IndexExpression::flatten(FlattenContext *ctx) {
   stmt->tb = tb;
 }
 
-void RangeAssumptionExpression::type_check(CompileConfig *) {
+void RangeAssumptionExpression::type_check(const CompileConfig *) {
   TI_ASSERT_TYPE_CHECKED(input);
   TI_ASSERT_TYPE_CHECKED(base);
   if (!input->ret_type->is<PrimitiveType>() ||
@@ -867,7 +867,7 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void LoopUniqueExpression::type_check(CompileConfig *) {
+void LoopUniqueExpression::type_check(const CompileConfig *) {
   TI_ASSERT_TYPE_CHECKED(input);
   if (!input->ret_type->is<PrimitiveType>())
     throw TaichiTypeError(
@@ -889,7 +889,7 @@ void IdExpression::flatten(FlattenContext *ctx) {
   }
 }
 
-void AtomicOpExpression::type_check(CompileConfig *config) {
+void AtomicOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(dest);
   TI_ASSERT_TYPE_CHECKED(val);
   auto error = [&]() {
@@ -968,7 +968,7 @@ SNodeOpExpression::SNodeOpExpression(SNode *snode,
   this->values = values;
 }
 
-void SNodeOpExpression::type_check(CompileConfig *config) {
+void SNodeOpExpression::type_check(const CompileConfig *config) {
   if (op_type == SNodeOpType::get_addr) {
     ret_type = PrimitiveType::u64;
   } else {
@@ -1035,7 +1035,7 @@ TextureOpExpression::TextureOpExpression(TextureOpType op,
     : op(op), texture_ptr(texture_ptr), args(args) {
 }
 
-void TextureOpExpression::type_check(CompileConfig *config) {
+void TextureOpExpression::type_check(const CompileConfig *config) {
   TI_ASSERT(texture_ptr.is<TexturePtrExpression>());
   auto ptr = texture_ptr.cast<TexturePtrExpression>();
   if (op == TextureOpType::kSampleLod) {
@@ -1125,7 +1125,7 @@ void TextureOpExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void ConstExpression::type_check(CompileConfig *) {
+void ConstExpression::type_check(const CompileConfig *) {
   TI_ASSERT_INFO(
       val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,
       "Invalid dt [{}] for ConstExpression", val.dt->to_string());
@@ -1137,7 +1137,7 @@ void ConstExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void ExternalTensorShapeAlongAxisExpression::type_check(CompileConfig *) {
+void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) {
   TI_ASSERT_INFO(
       ptr.is<ExternalTensorExpression>() || ptr.is<FieldExpression>(),
      "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression",
@@ -1152,7 +1152,7 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void GetElementExpression::type_check(CompileConfig *config) {
+void GetElementExpression::type_check(const CompileConfig *config) {
   TI_ASSERT_TYPE_CHECKED(src);
 
   ret_type = src->ret_type->as<StructType>()->get_element_type(index);
@@ -1170,11 +1170,11 @@ void MeshPatchIndexExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void MeshPatchIndexExpression::type_check(CompileConfig *) {
+void MeshPatchIndexExpression::type_check(const CompileConfig *) {
   ret_type = PrimitiveType::i32;
 }
 
-void MeshRelationAccessExpression::type_check(CompileConfig *) {
+void MeshRelationAccessExpression::type_check(const CompileConfig *) {
   ret_type = PrimitiveType::i32;
 }
 
@@ -1198,7 +1198,7 @@ MeshIndexConversionExpression::MeshIndexConversionExpression(
     : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) {
 }
 
-void MeshIndexConversionExpression::type_check(CompileConfig *) {
+void MeshIndexConversionExpression::type_check(const CompileConfig *) {
   ret_type = PrimitiveType::i32;
 }
 
@@ -1208,7 +1208,7 @@ void MeshIndexConversionExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
-void ReferenceExpression::type_check(CompileConfig *) {
+void ReferenceExpression::type_check(const CompileConfig *) {
   ret_type = var->ret_type;
 }
 
diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h
index 76cc244f057ea..42e525567b2b7 100644
--- a/taichi/ir/frontend_ir.h
+++ b/taichi/ir/frontend_ir.h
@@ -290,7 +290,7 @@ class ArgLoadExpression : public Expression {
       : arg_id(arg_id), dt(dt), is_ptr(is_ptr) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -331,7 +331,7 @@ class TexturePtrExpression : public Expression {
         lod(lod) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -345,7 +345,7 @@ class RandExpression : public Expression {
   explicit RandExpression(DataType dt) : dt(dt) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -367,7 +367,7 @@ class UnaryOpExpression : public Expression {
       : type(type), operand(operand), cast_type(cast_type) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   bool is_cast() const;
 
@@ -385,7 +385,7 @@ class BinaryOpExpression : public Expression {
      : type(type), lhs(lhs), rhs(rhs) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -407,7 +407,7 @@ class TernaryOpExpression : public Expression {
     this->op3.set(op3);
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -429,7 +429,7 @@ class InternalFuncCallExpression : public Expression {
     }
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -472,7 +472,7 @@ class ExternalTensorExpression : public Expression {
     }
   }
 
-  ExternalTensorExpression(Expr *expr, bool is_grad = true) {
+  explicit ExternalTensorExpression(Expr *expr, bool is_grad = true) {
     auto ptr = expr->cast<ExternalTensorExpression>();
     init(ptr->dt, ptr->dim, ptr->arg_id, ptr->element_dim, is_grad);
   }
 
@@ -481,18 +481,18 @@ class ExternalTensorExpression : public Expression {
 
   TI_DEFINE_ACCEPT_FOR_EXPRESSION
 
-  CompileConfig *get_compile_config() {
+  const CompileConfig *get_compile_config() {
     TI_ASSERT(config_ != nullptr);
     return config_;
   }
 
-  void type_check(CompileConfig *config) override {
+  void type_check(const CompileConfig *config) override {
     ret_type = dt;
     config_ = config;
   }
 
 private:
-  CompileConfig *config_ = nullptr;
+  const CompileConfig *config_ = nullptr;
 
   void init(const DataType &dt,
             int dim,
@@ -524,7 +524,7 @@ class FieldExpression : public Expression {
   FieldExpression(DataType dt, const Identifier &ident) : ident(ident), dt(dt) {
   }
 
-  void type_check(CompileConfig *config) override {
+  void type_check(const CompileConfig *config) override {
   }
 
   void set_snode(SNode *snode) {
@@ -559,7 +559,7 @@ class MatrixFieldExpression : public Expression {
     }
   }
 
-  void type_check(CompileConfig *config) override {
+  void type_check(const CompileConfig *config) override {
   }
 
   TI_DEFINE_ACCEPT_FOR_EXPRESSION
 
@@ -581,7 +581,7 @@ class MatrixExpression : public Expression {
     this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -608,7 +608,7 @@ class IndexExpression : public Expression {
                   const std::vector<int> &ret_shape,
                   std::string tb = "");
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -642,7 +642,7 @@ class RangeAssumptionExpression : public Expression {
      : input(input), base(base), low(low), high(high) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -658,7 +658,7 @@ class LoopUniqueExpression : public Expression {
      : input(input), covers(covers) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -672,7 +672,7 @@ class IdExpression : public Expression {
   explicit IdExpression(const Identifier &id) : id(id) {
   }
 
-  void type_check(CompileConfig *config) override {
+  void type_check(const CompileConfig *config) override {
   }
 
   void flatten(FlattenContext *ctx) override;
 
@@ -698,7 +698,7 @@ class AtomicOpExpression : public Expression {
      : op_type(op_type), dest(dest), val(val) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -721,7 +721,7 @@ class SNodeOpExpression : public Expression {
                     const ExprGroup &indices,
                     const std::vector<Expr> &values);
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -738,7 +738,7 @@ class TextureOpExpression : public Expression {
                       Expr texture_ptr,
                       const ExprGroup &args);
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -758,7 +758,7 @@ class ConstExpression : public Expression {
     ret_type = dt;
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -774,7 +774,7 @@ class ExternalTensorShapeAlongAxisExpression : public Expression {
      : ptr(ptr), axis(axis) {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -807,7 +807,7 @@ class GetElementExpression : public Expression {
   Expr src;
   std::vector<int> index;
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   GetElementExpression(const Expr &src, std::vector<int> index)
      : src(src), index(index) {
@@ -825,7 +825,7 @@ class MeshPatchIndexExpression : public Expression {
   MeshPatchIndexExpression() {
   }
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   void flatten(FlattenContext *ctx) override;
 
@@ -839,7 +839,7 @@ class MeshRelationAccessExpression : public Expression {
   mesh::MeshElementType to_type;
   Expr neighbor_idx;
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   MeshRelationAccessExpression(mesh::Mesh *mesh,
                                const Expr mesh_idx,
@@ -869,7 +869,7 @@ class MeshIndexConversionExpression : public Expression {
   Expr idx;
   mesh::ConvType conv_type;
 
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   MeshIndexConversionExpression(mesh::Mesh *mesh,
                                 mesh::MeshElementType idx_type,
@@ -884,7 +884,7 @@ class MeshIndexConversionExpression : public Expression {
 class ReferenceExpression : public Expression {
  public:
   Expr var;
-  void type_check(CompileConfig *config) override;
+  void type_check(const CompileConfig *config) override;
 
   explicit ReferenceExpression(const Expr &expr) : var(expr) {
   }
diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp
index a9a50ae217498..bb002778dc90d 100644
--- a/taichi/program/function.cpp
+++ b/taichi/program/function.cpp
@@ -11,28 +11,27 @@ Function::Function(Program *program, const FunctionKey &func_key)
 }
 
 void Function::set_function_body(const std::function<void()> &func) {
-  context =
-      std::make_unique<FrontendContext>(program->this_thread_config().arch);
+  context = std::make_unique<FrontendContext>(program->compile_config().arch);
   ir = context->get_root();
   func();
 
-  if (program->this_thread_config().offline_cache) {  // For generating AST-Key
+  if (program->compile_config().offline_cache) {  // For generating AST-Key
     std::ostringstream oss;
     gen_offline_cache_key(program, ir.get(), &oss);
     ast_serialization_data_ = oss.str();
   }
 
-  irpass::compile_function(ir.get(), program->this_thread_config(), this,
+  irpass::compile_function(ir.get(), program->compile_config(), this,
                            /*autodiff_mode=*/AutodiffMode::kNone,
-                           /*verbose=*/program->this_thread_config().print_ir,
+                           /*verbose=*/program->compile_config().print_ir,
                            /*start_from_ast=*/true);
 }
 
 void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
   ir = std::move(func_body);
-  irpass::compile_function(ir.get(), program->this_thread_config(), this,
+  irpass::compile_function(ir.get(), program->compile_config(), this,
                            /*autodiff_mode=*/AutodiffMode::kNone,
-                           /*verbose=*/program->this_thread_config().print_ir,
+                           /*verbose=*/program->compile_config().print_ir,
                            /*start_from_ast=*/false);
 }
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index 49d6d58059b3b..8b1187e4c33d5 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -328,8 +328,7 @@ void Kernel::init(Program &program,
   is_accessor = false;
   is_evaluator = false;
   compiled_ = nullptr;
-  context =
-      std::make_unique<FrontendContext>(program.this_thread_config().arch);
+  context = std::make_unique<FrontendContext>(program.compile_config().arch);
   ir = context->get_root();
 
   ir_is_ast_ = true;
diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp
index cc0d94b0a7f52..7301a730b98be 100644
--- a/taichi/program/program.cpp
+++ b/taichi/program/program.cpp
@@ -75,12 +75,9 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) {
                        : "ri"(fpcr | (1 << 24)));  // Bit 24 is FZ
   __asm__ __volatile__("");
 #endif  // defined(__arm64__) || defined(__aarch64__)
-  main_thread_id_ = std::this_thread::get_id();
-  // Rehash in advance to avoid rehashing during compilation
-  configs.rehash(default_compile_config.num_compile_threads + 1);
-  configs[main_thread_id_] = default_compile_config;
-  configs[main_thread_id_].arch = desired_arch;
-  auto &config = this_thread_config();
+  auto &config = compile_config_;
+  config = default_compile_config;
+  config.arch = desired_arch;
   config.fit();
 
   profiler = make_profiler(config.arch, config.kernel_profiler);
@@ -205,10 +202,10 @@ void Program::materialize_runtime() {
 }
 
 void Program::destroy_snode_tree(SNodeTree *snode_tree) {
-  TI_ASSERT(arch_uses_llvm(this_thread_config().arch) ||
-            this_thread_config().arch == Arch::vulkan ||
-            this_thread_config().arch == Arch::dx11 ||
-            this_thread_config().arch == Arch::dx12);
+  TI_ASSERT(arch_uses_llvm(compile_config().arch) ||
+            compile_config().arch == Arch::vulkan ||
+            compile_config().arch == Arch::dx11 ||
+            compile_config().arch == Arch::dx12);
   program_impl_->destroy_snode_tree(snode_tree);
   free_snode_tree_ids_.push(snode_tree->id());
 }
@@ -259,7 +256,7 @@ Kernel &Program::get_snode_reader(SNode *snode) {
     ExprGroup indices;
     for (int i = 0; i < snode->num_active_indices; i++) {
       auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
-      argload_expr->type_check(&this->this_thread_config());
+      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
     }
     ASTBuilder &builder = kernel->context->builder();
@@ -282,7 +279,7 @@ Kernel &Program::get_snode_writer(SNode *snode) {
     ExprGroup indices;
     for (int i = 0; i < snode->num_active_indices; i++) {
      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
-      argload_expr->type_check(&this->this_thread_config());
+      argload_expr->type_check(&this->compile_config());
      indices.push_back(std::move(argload_expr));
     }
     ASTBuilder &builder = kernel->context->builder();
@@ -311,12 +308,11 @@ void Program::finalize() {
     return;
   }
   synchronize();
-  TI_ASSERT(std::this_thread::get_id() == main_thread_id_);
   TI_TRACE("Program finalizing...");
   synchronize();
   memory_pool_->terminate();
 
-  if (arch_uses_llvm(this_thread_config().arch)) {
+  if (arch_uses_llvm(compile_config().arch)) {
     program_impl_->finalize();
   }
 
@@ -325,8 +321,7 @@ void Program::finalize() {
   finalized_ = true;
   num_instances_ -= 1;
   program_impl_->dump_cache_data_to_disk();
-  configs.clear();
-  configs[main_thread_id_] = default_compile_config;
+  compile_config_ = default_compile_config;
 
   TI_TRACE("Program ({}) finalized_.", fmt::ptr(this));
 }
@@ -353,7 +348,7 @@ Ndarray *Program::create_ndarray(const DataType type,
                                  bool zero_fill) {
   auto arr = std::make_unique<Ndarray>(this, type, shape, layout);
   if (zero_fill) {
-    Arch arch = this_thread_config().arch;
+    Arch arch = compile_config().arch;
     if (arch_is_cpu(arch) || arch == Arch::cuda) {
       fill_ndarray_fast_u32(arr.get(), /*data=*/0);
     } else if (arch != Arch::dx12) {
@@ -412,8 +407,8 @@ Texture *Program::create_texture(const DataType type,
 
 intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
   uint64_t *data_ptr{nullptr};
-  if (arch_is_cpu(this_thread_config().arch) ||
-      this_thread_config().arch == Arch::cuda) {
+  if (arch_is_cpu(compile_config().arch) ||
+      compile_config().arch == Arch::cuda) {
     // For the LLVM backends, device allocation is a physical pointer.
     data_ptr =
        program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
@@ -489,17 +484,17 @@ std::unique_ptr Program::make_aot_module_builder(
   if (arch == Arch::wasm) {
     // Have to check WASM first, or it dispatches to the LlvmProgramImpl.
 #ifdef TI_WITH_LLVM
-    return std::make_unique(&this_thread_config());
+    return std::make_unique(&compile_config());
 #else
     TI_NOT_IMPLEMENTED
 #endif
   }
-  if (arch_uses_llvm(this_thread_config().arch) ||
-      this_thread_config().arch == Arch::metal ||
-      this_thread_config().arch == Arch::vulkan ||
-      this_thread_config().arch == Arch::opengl ||
-      this_thread_config().arch == Arch::gles ||
-      this_thread_config().arch == Arch::dx12) {
+  if (arch_uses_llvm(compile_config().arch) ||
+      compile_config().arch == Arch::metal ||
+      compile_config().arch == Arch::vulkan ||
+      compile_config().arch == Arch::opengl ||
+      compile_config().arch == Arch::gles ||
+      compile_config().arch == Arch::dx12) {
     return program_impl_->make_aot_module_builder(cfg);
   }
   return nullptr;
diff --git a/taichi/program/program.h b/taichi/program/program.h
index 6bd9238704a1d..67ccd899109ab 100644
--- a/taichi/program/program.h
+++ b/taichi/program/program.h
@@ -92,12 +92,6 @@ class StructCompiler;
 
 class TI_DLL_EXPORT Program {
  public:
   using Kernel = taichi::lang::Kernel;
-  // We let every thread has its own config because the constant folding pass
-  // wants to change the CompileConfig so that it can compile the evaluator,
-  // but we don't want it to change the global config. We will refactor it
-  // later when we make Taichi thread-safe.
-  std::unordered_map<std::thread::id, CompileConfig> configs;
-  std::thread::id main_thread_id_;
   uint64 *result_buffer{nullptr};  // Note result_buffer is used by all backends
 
@@ -120,22 +114,8 @@ class TI_DLL_EXPORT Program {
 
   ~Program();
 
-  CompileConfig &this_thread_config() {
-    // std::unordered_map is not thread safe even if we do the rehash in
-    // advance, so we need to add a lock to protect it.
-    std::shared_lock read_lock(config_map_mut);
-    auto thread_id = std::this_thread::get_id();
-    if (!configs.count(thread_id)) {
-      read_lock.unlock();
-      std::unique_lock write_lock(config_map_mut);
-      configs[thread_id] = configs[main_thread_id_];
-      return configs[thread_id];
-    }
-    return configs[thread_id];
-  }
-
-  const CompileConfig &config() {
-    return configs[main_thread_id_];
+  const CompileConfig &compile_config() const {
+    return compile_config_;
   }
 
   struct KernelProfilerQueryResult {
@@ -352,7 +332,7 @@ class TI_DLL_EXPORT Program {
    * Please limit its use to LLVM backend only
    */
   ProgramImpl *get_program_impl() {
-    TI_ASSERT(arch_uses_llvm(this_thread_config().arch));
+    TI_ASSERT(arch_uses_llvm(compile_config().arch));
     return program_impl_.get();
   }
 
@@ -364,6 +344,8 @@ class TI_DLL_EXPORT Program {
   // could store ProgramImpl rather than Program.
 
  private:
+  CompileConfig compile_config_;
+
   uint64 ndarray_writer_counter_{0};
   uint64 ndarray_reader_counter_{0};
   int global_id_counter_{0};
@@ -387,7 +369,6 @@ class TI_DLL_EXPORT Program {
   // TODO: Move ndarrays_ and textures_ to be managed by runtime
   std::unordered_map<Ndarray *, std::unique_ptr<Ndarray>> ndarrays_;
   std::vector<std::unique_ptr<Texture>> textures_;
-  std::shared_mutex config_map_mut;
 };
 
 }  // namespace taichi::lang
diff --git a/taichi/program/snode_rw_accessors_bank.cpp b/taichi/program/snode_rw_accessors_bank.cpp
index 1c5e2b318f6a4..dfe1f3006c7f0 100644
--- a/taichi/program/snode_rw_accessors_bank.cpp
+++ b/taichi/program/snode_rw_accessors_bank.cpp
@@ -41,14 +41,14 @@ void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_float(snode_->num_active_indices, val);
   prog_->synchronize();
-  (*writer_)(prog_->this_thread_config(), launch_ctx);
+  (*writer_)(prog_->compile_config(), launch_ctx);
 }
 
 float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  (*reader_)(prog_->this_thread_config(), launch_ctx);
+  (*reader_)(prog_->compile_config(), launch_ctx);
   prog_->synchronize();
   auto ret = reader_->get_ret_float(0);
   return ret;
@@ -61,14 +61,14 @@ void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_int(snode_->num_active_indices, val);
   prog_->synchronize();
-  (*writer_)(prog_->this_thread_config(), launch_ctx);
+  (*writer_)(prog_->compile_config(), launch_ctx);
 }
 
 int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  (*reader_)(prog_->this_thread_config(), launch_ctx);
+  (*reader_)(prog_->compile_config(), launch_ctx);
   prog_->synchronize();
   auto ret = reader_->get_ret_int(0);
   return ret;
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index a2ab8837747f0..6ff5a500e319e 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -328,7 +328,8 @@ void export_lang(py::module &m) {
 
   py::class_<Program>(m, "Program")
       .def(py::init<>())
-      .def("config", &Program::config)
+      .def("config", &Program::compile_config,
+           py::return_value_policy::reference)
       .def("sync_kernel_profiler",
           [](Program *program) { program->profiler->sync(); })
       .def("update_kernel_profiler",
@@ -392,8 +393,8 @@ void export_lang(py::module &m) {
       .def("create_sparse_matrix_builder",
           [](Program *program, int n, int m, uint64 max_num_entries,
              DataType dtype, const std::string &storage_format) {
-            TI_ERROR_IF(!arch_is_cpu(program->this_thread_config().arch) &&
-                            !arch_is_cuda(program->this_thread_config().arch),
+            TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
+                            !arch_is_cuda(program->compile_config().arch),
                         "SparseMatrix only supports CPU and CUDA for now.");
             return SparseMatrixBuilder(n, m, max_num_entries, dtype,
                                        storage_format, program);
@@ -401,18 +402,18 @@ void export_lang(py::module &m) {
       .def("create_sparse_matrix",
           [](Program *program, int n, int m, DataType dtype,
              std::string storage_format) {
-            TI_ERROR_IF(!arch_is_cpu(program->this_thread_config().arch) &&
-                            !arch_is_cuda(program->this_thread_config().arch),
+            TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
+                            !arch_is_cuda(program->compile_config().arch),
                         "SparseMatrix only supports CPU and CUDA for now.");
-            if (arch_is_cpu(program->this_thread_config().arch))
+            if (arch_is_cpu(program->compile_config().arch))
               return make_sparse_matrix(n, m, dtype, storage_format);
             else
               return make_cu_sparse_matrix(n, m, dtype);
           })
       .def("make_sparse_matrix_from_ndarray",
           [](Program *program, SparseMatrix &sm, const Ndarray &ndarray) {
-            TI_ERROR_IF(!arch_is_cpu(program->this_thread_config().arch) &&
-                            !arch_is_cuda(program->this_thread_config().arch),
+            TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
+                            !arch_is_cuda(program->compile_config().arch),
                         "SparseMatrix only supports CPU and CUDA for now.");
             return make_sparse_matrix_from_ndarray(program, sm, ndarray);
           })
@@ -689,11 +690,11 @@ void export_lang(py::module &m) {
              return &self->context->builder();
            },
            py::return_value_policy::reference)
-      .def("__call__", [](Kernel *kernel,
-                          Kernel::LaunchContextBuilder &launch_ctx) {
-        py::gil_scoped_release release;
-        kernel->operator()(kernel->program->this_thread_config(), launch_ctx);
-      });
+      .def("__call__",
+           [](Kernel *kernel, Kernel::LaunchContextBuilder &launch_ctx) {
+             py::gil_scoped_release release;
+             kernel->operator()(kernel->program->compile_config(), launch_ctx);
+           });
 
   py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext")
       .def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int)
diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp
index 8259d1227e237..fe5a7bb284774 100644
--- a/tests/cpp/ir/frontend_type_inference_test.cpp
+++ b/tests/cpp/ir/frontend_type_inference_test.cpp
@@ -43,7 +43,7 @@ TEST(FrontendTypeInference, BinaryOp) {
   auto const_f32 = value<float32>(5.0);
   const_f32->type_check(nullptr);
   auto truediv_f64 = expr_truediv(const_i32, const_f32);
-  truediv_f64->type_check(&prog->this_thread_config());
+  truediv_f64->type_check(&prog->compile_config());
   EXPECT_EQ(truediv_f64->ret_type, PrimitiveType::f64);
 }
 
@@ -62,7 +62,7 @@ TEST(FrontendTypeInference, UnaryOp) {
   bit_not_i16->type_check(&dummy_config);
   EXPECT_EQ(bit_not_i16->ret_type, PrimitiveType::i16);
   auto log_f64 = expr_log(const_i16);
-  log_f64->type_check(&prog->this_thread_config());
+  log_f64->type_check(&prog->compile_config());
   EXPECT_EQ(log_f64->ret_type, PrimitiveType::f64);
 }
 
diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp
index 0906ef08bc441..00a2a43c6c1a5 100644
--- a/tests/cpp/ir/ir_builder_test.cpp
+++ b/tests/cpp/ir/ir_builder_test.cpp
@@ -115,7 +115,7 @@ TEST(IRBuilder, ExternalPtr) {
   auto launch_ctx = ker->make_launch_context();
   launch_ctx.set_arg_external_array_with_shape(
       /*arg_id=*/0, (uint64)array.get(), size, {size});
-  (*ker)(test_prog.prog()->this_thread_config(), launch_ctx);
+  (*ker)(test_prog.prog()->compile_config(), launch_ctx);
   EXPECT_EQ(array[0], 2);
   EXPECT_EQ(array[1], 1);
   EXPECT_EQ(array[2], 42);
@@ -139,7 +139,7 @@ TEST(IRBuilder, Ndarray) {
   auto ker1 = setup_kernel1(test_prog.prog());
   auto launch_ctx1 = ker1->make_launch_context();
   launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array);
-  (*ker1)(test_prog.prog()->this_thread_config(), launch_ctx1);
+  (*ker1)(test_prog.prog()->compile_config(), launch_ctx1);
   EXPECT_EQ(array.read_int({0}), 2);
   EXPECT_EQ(array.read_int({1}), 1);
   EXPECT_EQ(array.read_int({2}), 42);
@@ -148,7 +148,7 @@ TEST(IRBuilder, Ndarray) {
   auto launch_ctx2 = ker2->make_launch_context();
   launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array);
   launch_ctx2.set_arg_int(/*arg_id=*/1, 3);
-  (*ker2)(test_prog.prog()->this_thread_config(), launch_ctx2);
+  (*ker2)(test_prog.prog()->compile_config(), launch_ctx2);
   EXPECT_EQ(array.read_int({0}), 2);
   EXPECT_EQ(array.read_int({1}), 3);
   EXPECT_EQ(array.read_int({2}), 42);
@@ -175,7 +175,7 @@ TEST(IRBuilder, AtomicOp) {
   auto launch_ctx = ker->make_launch_context();
   launch_ctx.set_arg_external_array_with_shape(
       /*arg_id=*/0, (uint64)array.get(), size, {size});
-  (*ker)(test_prog.prog()->this_thread_config(), launch_ctx);
+  (*ker)(test_prog.prog()->compile_config(), launch_ctx);
   EXPECT_EQ(array[0], 3);
 }
 
diff --git a/tests/cpp/transforms/alg_simp_test.cpp b/tests/cpp/transforms/alg_simp_test.cpp
index 25cea03786592..86f541b517c20 100644
--- a/tests/cpp/transforms/alg_simp_test.cpp
+++ b/tests/cpp/transforms/alg_simp_test.cpp
@@ -97,7 +97,6 @@ TEST_F(AlgebraicSimplicationTest, SimplifyMultiplyZeroFastMath) {
 
   CompileConfig config_without_fast_math;
   config_without_fast_math.fast_math = false;
-  kernel->program->this_thread_config() = config_without_fast_math;
 
   irpass::type_check(block.get(), config_without_fast_math);
   EXPECT_EQ(block->size(), 8);
@@ -133,7 +132,6 @@ TEST_F(AlgebraicSimplicationTest, SimplifyMultiplyZeroFastMath) {
 
   CompileConfig config_with_fast_math;
   config_with_fast_math.fast_math = true;
-  kernel->program->this_thread_config() = config_with_fast_math;
 
   irpass::alg_simp(block.get(), config_with_fast_math);  // should eliminate mul, add
diff --git a/tests/cpp/transforms/simplify_test.cpp b/tests/cpp/transforms/simplify_test.cpp
index d468276354ae6..231fdd1d51dc3 100644
--- a/tests/cpp/transforms/simplify_test.cpp
+++ b/tests/cpp/transforms/simplify_test.cpp
@@ -32,21 +32,21 @@ TEST(Simplify, SimplifyLinearizedWithTrivialInputs) {
   [[maybe_unused]] auto lookup2 = block->push_back(
       root.ch[0].get(), get_child, linearized_zero, true);
 
-  irpass::type_check(block.get(), kernel->program->this_thread_config());
+  irpass::type_check(block.get(), kernel->program->compile_config());
   EXPECT_EQ(block->size(), 7);
 
   irpass::simplify(
       block.get(),
-      kernel->program->this_thread_config());  // should lower linearized
+      kernel->program->compile_config());  // should lower linearized
   // EXPECT_EQ(block->size(), 11);  // not required to check size here
-  irpass::constant_fold(block.get(), kernel->program->this_thread_config(),
+  irpass::constant_fold(block.get(), kernel->program->compile_config(),
                         {kernel->program});
-  irpass::alg_simp(block.get(), kernel->program->this_thread_config());
+  irpass::alg_simp(block.get(), kernel->program->compile_config());
   irpass::die(block.get());  // should eliminate consts
-  irpass::simplify(block.get(), kernel->program->this_thread_config());
+  irpass::simplify(block.get(), kernel->program->compile_config());
   irpass::whole_kernel_cse(block.get());
-  if (kernel->program->this_thread_config().advanced_optimization) {
+  if (kernel->program->compile_config().advanced_optimization) {
     // get root, const 0, lookup, get child, lookup
     EXPECT_EQ(block->size(), 5);
   }