Skip to content

Commit

Permalink
[refactor] Program::this_thread_config() -> Program::compile_config() (
Browse files Browse the repository at this point in the history
…taichi-dev#7199)

Issue: taichi-dev#7002 
Removed multi-thread version `Program::this_thread_config()` (see
taichi-dev#7159 (comment))

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent fa1900c commit 9018502
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 145 deletions.
2 changes: 1 addition & 1 deletion cpp_examples/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void autograd() {
using namespace lang;

auto program = Program(Arch::x64);
const auto &config = program.this_thread_config();
const auto &config = program.compile_config();

int n = 10;
program.materialize_runtime();
Expand Down
2 changes: 1 addition & 1 deletion cpp_examples/run_snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void run_snode() {
using namespace taichi;
using namespace lang;
auto program = Program(Arch::x64);
const auto &config = program.this_thread_config();
const auto &config = program.compile_config();
/*CompileConfig config_print_ir;
config_print_ir.print_ir = true;
prog_.config = config_print_ir;*/ // print_ir = True
Expand Down
2 changes: 1 addition & 1 deletion taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void CompiledGraph::run(
TI_ASSERT(dispatch.ti_kernel);
lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx);
auto *ker = dispatch.ti_kernel;
ker->operator()(ker->program->this_thread_config(), launch_ctx);
ker->operator()(ker->program->compile_config(), launch_ctx);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ DataType Expr::get_ret_type() const {
return expr->ret_type;
}

void Expr::type_check(CompileConfig *config) {
void Expr::type_check(const CompileConfig *config) {
expr->type_check(config);
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Expr {

DataType get_ret_type() const;

void type_check(CompileConfig *config);
void type_check(const CompileConfig *config);
};

// Value cast
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Expression {
stmt = nullptr;
}

virtual void type_check(CompileConfig *config) = 0;
virtual void type_check(const CompileConfig *config) = 0;

virtual void accept(ExpressionVisitor *visitor) = 0;

Expand Down
42 changes: 21 additions & 21 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void FrontendForStmt::add_loop_var(const Expr &loop_var) {
loop_var.expr->ret_type = PrimitiveType::i32;
}

void ArgLoadExpression::type_check(CompileConfig *) {
void ArgLoadExpression::type_check(const CompileConfig *) {
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
"Invalid dt [{}] for ArgLoadExpression", dt->to_string());
ret_type = dt;
Expand All @@ -120,7 +120,7 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void TexturePtrExpression::type_check(CompileConfig *config) {
void TexturePtrExpression::type_check(const CompileConfig *config) {
}

void TexturePtrExpression::flatten(FlattenContext *ctx) {
Expand All @@ -130,7 +130,7 @@ void TexturePtrExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void RandExpression::type_check(CompileConfig *) {
void RandExpression::type_check(const CompileConfig *) {
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
"Invalid dt [{}] for RandExpression", dt->to_string());
ret_type = dt;
Expand All @@ -142,7 +142,7 @@ void RandExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void UnaryOpExpression::type_check(CompileConfig *config) {
void UnaryOpExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(operand);

TI_ASSERT(config != nullptr);
Expand Down Expand Up @@ -238,7 +238,7 @@ std::tuple<Expr, Expr> unify_binop_operands(const Expr &e1, const Expr &e2) {
}
}

void BinaryOpExpression::type_check(CompileConfig *config) {
void BinaryOpExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(lhs);
TI_ASSERT_TYPE_CHECKED(rhs);
auto lhs_type = lhs->ret_type;
Expand Down Expand Up @@ -426,7 +426,7 @@ static std::tuple<Expr, Expr, Expr> unify_ternaryop_operands(const Expr &e1,
to_broadcast_tensor(e3, target_dtype));
}

void TernaryOpExpression::type_check(CompileConfig *config) {
void TernaryOpExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(op1);
TI_ASSERT_TYPE_CHECKED(op2);
TI_ASSERT_TYPE_CHECKED(op3);
Expand Down Expand Up @@ -509,7 +509,7 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) {
stmt->ret_type = ret_type;
}

void InternalFuncCallExpression::type_check(CompileConfig *) {
void InternalFuncCallExpression::type_check(const CompileConfig *) {
for (auto &arg : args) {
TI_ASSERT_TYPE_CHECKED(arg);
// no arg type compatibility check for now due to lack of specification
Expand Down Expand Up @@ -666,7 +666,7 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
shape, tb);
}

void MatrixExpression::type_check(CompileConfig *config) {
void MatrixExpression::type_check(const CompileConfig *config) {
TI_ASSERT(dt->as<TensorType>()->get_num_elements() == elements.size());

for (auto &arg : elements) {
Expand Down Expand Up @@ -754,7 +754,7 @@ static void field_validation(FieldExpression *field_expr, int index_dim) {
}
}

void IndexExpression::type_check(CompileConfig *) {
void IndexExpression::type_check(const CompileConfig *) {
// TODO: Change to type-based solution
// Currently, dimension compatibility check happens in Python
TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape),
Expand Down Expand Up @@ -847,7 +847,7 @@ void IndexExpression::flatten(FlattenContext *ctx) {
stmt->tb = tb;
}

void RangeAssumptionExpression::type_check(CompileConfig *) {
void RangeAssumptionExpression::type_check(const CompileConfig *) {
TI_ASSERT_TYPE_CHECKED(input);
TI_ASSERT_TYPE_CHECKED(base);
if (!input->ret_type->is<PrimitiveType>() ||
Expand All @@ -867,7 +867,7 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void LoopUniqueExpression::type_check(CompileConfig *) {
void LoopUniqueExpression::type_check(const CompileConfig *) {
TI_ASSERT_TYPE_CHECKED(input);
if (!input->ret_type->is<PrimitiveType>())
throw TaichiTypeError(
Expand All @@ -889,7 +889,7 @@ void IdExpression::flatten(FlattenContext *ctx) {
}
}

void AtomicOpExpression::type_check(CompileConfig *config) {
void AtomicOpExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(dest);
TI_ASSERT_TYPE_CHECKED(val);
auto error = [&]() {
Expand Down Expand Up @@ -968,7 +968,7 @@ SNodeOpExpression::SNodeOpExpression(SNode *snode,
this->values = values;
}

void SNodeOpExpression::type_check(CompileConfig *config) {
void SNodeOpExpression::type_check(const CompileConfig *config) {
if (op_type == SNodeOpType::get_addr) {
ret_type = PrimitiveType::u64;
} else {
Expand Down Expand Up @@ -1035,7 +1035,7 @@ TextureOpExpression::TextureOpExpression(TextureOpType op,
: op(op), texture_ptr(texture_ptr), args(args) {
}

void TextureOpExpression::type_check(CompileConfig *config) {
void TextureOpExpression::type_check(const CompileConfig *config) {
TI_ASSERT(texture_ptr.is<TexturePtrExpression>());
auto ptr = texture_ptr.cast<TexturePtrExpression>();
if (op == TextureOpType::kSampleLod) {
Expand Down Expand Up @@ -1125,7 +1125,7 @@ void TextureOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ConstExpression::type_check(CompileConfig *) {
void ConstExpression::type_check(const CompileConfig *) {
TI_ASSERT_INFO(
val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,
"Invalid dt [{}] for ConstExpression", val.dt->to_string());
Expand All @@ -1137,7 +1137,7 @@ void ConstExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ExternalTensorShapeAlongAxisExpression::type_check(CompileConfig *) {
void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) {
TI_ASSERT_INFO(
ptr.is<ExternalTensorExpression>() || ptr.is<TexturePtrExpression>(),
"Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression",
Expand All @@ -1152,7 +1152,7 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void GetElementExpression::type_check(CompileConfig *config) {
void GetElementExpression::type_check(const CompileConfig *config) {
TI_ASSERT_TYPE_CHECKED(src);

ret_type = src->ret_type->as<StructType>()->get_element_type(index);
Expand All @@ -1170,11 +1170,11 @@ void MeshPatchIndexExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void MeshPatchIndexExpression::type_check(CompileConfig *) {
void MeshPatchIndexExpression::type_check(const CompileConfig *) {
ret_type = PrimitiveType::i32;
}

void MeshRelationAccessExpression::type_check(CompileConfig *) {
void MeshRelationAccessExpression::type_check(const CompileConfig *) {
ret_type = PrimitiveType::i32;
}

Expand All @@ -1198,7 +1198,7 @@ MeshIndexConversionExpression::MeshIndexConversionExpression(
: mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) {
}

void MeshIndexConversionExpression::type_check(CompileConfig *) {
void MeshIndexConversionExpression::type_check(const CompileConfig *) {
ret_type = PrimitiveType::i32;
}

Expand All @@ -1208,7 +1208,7 @@ void MeshIndexConversionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ReferenceExpression::type_check(CompileConfig *) {
void ReferenceExpression::type_check(const CompileConfig *) {
ret_type = var->ret_type;
}

Expand Down
Loading

0 comments on commit 9018502

Please sign in to comment.