diff --git a/cinn/backends/CMakeLists.txt b/cinn/backends/CMakeLists.txt index 321c5b9682ebe..4100b36dab8cc 100644 --- a/cinn/backends/CMakeLists.txt +++ b/cinn/backends/CMakeLists.txt @@ -1,4 +1,6 @@ cc_library( backends SRCS outputs.cc + codegen_c.cc + DEPS ir ) diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc new file mode 100644 index 0000000000000..bae707d9bd911 --- /dev/null +++ b/cinn/backends/codegen_c.cc @@ -0,0 +1,143 @@ +#include "cinn/backends/codegen_c.h" + +namespace cinn { +namespace backends { + +CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {} + +void CodeGenC::Compile(const lang::Module &module) {} +void CodeGenC::Compile(const lang::LoweredFunc &function) { + os() << "void " << function.name; + + // output arguments + os() << "("; + + auto print_arg = [&](const lang::Argument &arg) { + if (arg.is_buffer()) { + os() << "struct cinn_buffer_t *"; + } else if (arg.is_scalar()) { + os() << PrintType(arg.type) << " "; + os() << arg.name; + } + os() << arg.name; + }; + + for (int i = 0; i < function.args.size() - 1; i++) { + print_arg(function.args[i]); + os() << ", "; + } + if (function.args.size() >= 1) { + print_arg(function.args.back()); + } + + os() << ")"; + + DoIndent(); + os() << "{\n"; + + Print(function.body); + + DoIndent(); + os() << "}"; +} +void CodeGenC::Compile(const ir::Buffer &buffer) {} +std::string CodeGenC::PrintType(Type type) { return std::__cxx11::string(); } +void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Add *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Sub *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mod *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::EQ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::NE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::And *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Or *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Not *op) { + os() << "(!"; + IrPrinter::Print(op->v); + os() << ")"; +} +void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v); } +void CodeGenC::Visit(const ir::For *op) { LOG(FATAL) << "Not Implemented"; } +void CodeGenC::Visit(const ir::PolyFor *op) { + os() << "for ("; + Print(op->init); + os() << "; "; + Print(op->condition); + os() << "; "; + Print(op->inc); + os() << ")"; + + Print(op->body); +} +void CodeGenC::Visit(const ir::Select *op) { + os() << "("; + os() << "("; + Print(op->condition); + os() << ") ? "; + Print(op->true_value); + os() << " : "; + Print(op->false_value); + os() << ")"; +} +void CodeGenC::Visit(const ir::IfThenElse *op) { + os() << "if ("; + Print(op->condition); + os() << ")"; + Print(op->true_case); + + if (op->false_case.defined()) { + os() << "else\n"; + Print(op->false_case); + } +} +void CodeGenC::Visit(const ir::Block *op) { + os() << "{\n"; + + IncIndent(); + + for (int i = 0; i < op->stmts.size() - 1; i++) { + DoIndent(); + Print(op->stmts[i]); + os() << ";\n"; + } + if (op->stmts.size() >= 1) { + DoIndent(); + Print(op->stmts.back()); + os() << ";"; + } + + DecIndent(); + os() << "\n"; + DoIndent(); + os() << "}"; +} +void CodeGenC::Visit(const ir::Call *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Module *op) { NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; } +void CodeGenC::Visit(const ir::Load *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Store *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Alloc *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Free *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Range_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_IterVar_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Buffer_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Tensor_ *op) { IrPrinter::Visit(op); } + +void CodeGenC::PrintCastExpr(const Type &type, Expr e) { + os() << PrintType(type) << "("; + Print(e); + os() << ")"; +} + +} // namespace backends +} // namespace cinn diff --git a/cinn/backends/codegen_c.h b/cinn/backends/codegen_c.h new file mode 100644 index 0000000000000..9fe240b086281 --- /dev/null +++ b/cinn/backends/codegen_c.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "cinn/common/common.h" +#include "cinn/ir/function.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/module.h" + +namespace cinn { + +namespace lang { +class Module; +} // namespace lang + +namespace backends { + +class CodeGenC : public ir::IrPrinter { + public: + CodeGenC(std::ostream& os, Target target); + + void Compile(const lang::Module& module); + + protected: + void Compile(const lang::LoweredFunc& function); + void Compile(const ir::Buffer& buffer); + + std::string PrintType(Type type); + void PrintCastExpr(const Type& type, Expr e); + +#define __DEFINE_VISIT(op__) void Visit(const ir::op__* op) override; + NODETY_FORALL(__DEFINE_VISIT) +#undef __DEFINE_VISIT + + private: + Target target_; +}; + +} // namespace backends +} // namespace cinn diff --git a/cinn/common/common.h b/cinn/common/common.h index eb0955d5166df..2da737baecb74 100644 --- a/cinn/common/common.h +++ b/cinn/common/common.h @@ -7,7 +7,6 @@ #include "cinn/common/shared.h" #include "cinn/common/target.h" #include "cinn/common/type.h" -#include "target.h" namespace cinn { @@ -25,5 +24,8 @@ using common::Int; using common::type_of; using common::Target; +using common::Type; + +#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"; } // namespace cinn diff --git a/cinn/ir/ir_printer.h b/cinn/ir/ir_printer.h index ac8cc5842cee3..200e310f7b629 100644 --- a/cinn/ir/ir_printer.h +++ b/cinn/ir/ir_printer.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include "cinn/ir/buffer.h" #include "cinn/ir/ir.h" @@ -32,6 +33,8 @@ struct IrPrinter : public IRVisitor { //! Decrease the indent size. void DecIndent(); + std::ostream &os() { return os_; } + void Visit(const IntImm *x) override; void Visit(const UIntImm *x) override; void Visit(const FloatImm *x) override; diff --git a/cinn/lang/module.h b/cinn/lang/module.h index 8dbec7a63d0f9..274ad9afb133b 100644 --- a/cinn/lang/module.h +++ b/cinn/lang/module.h @@ -49,5 +49,48 @@ class Module { Shared<_Module_> module_; }; +/** + * A struct representing an argument to a lowered function. Used for specifying the function signature of generated + * code. + */ +struct Argument { + //! The name of the argument. + std::string name; + + enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar}; + + //! Number of the dimensions of buffer. + uint32_t ndims{0}; + + //! The type of the buffer or scalar. + Type type; + + bool is_buffer() const { return kind == Kind::kBuffer; } + bool is_scalar() const { return kind == Kind::kScalar; } + + Argument() {} + Argument(const std::string& name, Kind kind, const Type& type, int ndims) + : name(name), kind(kind), type(type), ndims(ndims) {} + + explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {} +}; + +/** + * Definition of a lowered function. Note that, it should be functional. + */ +struct LoweredFunc { + //! The name of this function. + std::string name; + + //! The Arguments used in the body of the function. + std::vector args; + + //! Body of this function. + Expr body; + + LoweredFunc(const std::string& name, const std::vector& args, const Expr& body) + : name(name), args(args), body(body) {} +}; + } // namespace lang } // namespace cinn