From 2a21eddf7d67b17c0268615dc236491d2afee330 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Fri, 21 Feb 2020 21:50:23 +0800 Subject: [PATCH 1/2] update --- cinn/backends/CMakeLists.txt | 2 + cinn/backends/codegen_c.cc | 143 +++++++++++++++++++++++++++++++++++ cinn/backends/codegen_c.h | 42 ++++++++++ cinn/common/common.h | 4 +- cinn/ir/ir_printer.h | 3 + cinn/lang/module.h | 43 +++++++++++ 6 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 cinn/backends/codegen_c.cc create mode 100644 cinn/backends/codegen_c.h diff --git a/cinn/backends/CMakeLists.txt b/cinn/backends/CMakeLists.txt index 321c5b9682ebe..4100b36dab8cc 100644 --- a/cinn/backends/CMakeLists.txt +++ b/cinn/backends/CMakeLists.txt @@ -1,4 +1,6 @@ cc_library( backends SRCS outputs.cc + codegen_c.cc + DEPS ir ) diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc new file mode 100644 index 0000000000000..bae707d9bd911 --- /dev/null +++ b/cinn/backends/codegen_c.cc @@ -0,0 +1,143 @@ +#include "cinn/backends/codegen_c.h" + +namespace cinn { +namespace backends { + +CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {} + +void CodeGenC::Compile(const lang::Module &module) {} +void CodeGenC::Compile(const lang::LoweredFunc &function) { + os() << "void " << function.name; + + // output arguments + os() << "("; + + auto print_arg = [&](const lang::Argument &arg) { + if (arg.is_buffer()) { + os() << "struct cinn_buffer_t *"; + } else if (arg.is_scalar()) { + os() << PrintType(arg.type) << " "; + os() << arg.name; + } + os() << arg.name; + }; + + for (int i = 0; i < function.args.size() - 1; i++) { + print_arg(function.args[i]); + os() << ", "; + } + if (function.args.size() >= 1) { + print_arg(function.args.back()); + } + + os() << ")"; + + DoIndent(); + os() << "{\n"; + + Print(function.body); + + DoIndent(); + os() << "}"; +} +void CodeGenC::Compile(const ir::Buffer &buffer) {} +std::string CodeGenC::PrintType(Type type) { return std::__cxx11::string(); } +void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Add *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Sub *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mod *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::EQ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::NE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::And *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Or *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Not *op) { + os() << "(!"; + IrPrinter::Print(op->v); + os() << ")"; +} +void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v); } +void CodeGenC::Visit(const ir::For *op) { LOG(FATAL) << "Not Implemented"; } +void CodeGenC::Visit(const ir::PolyFor *op) { + os() << "for ("; + Print(op->init); + os() << "; "; + Print(op->condition); + os() << "; "; + Print(op->inc); + os() << ")"; + + Print(op->body); +} +void CodeGenC::Visit(const ir::Select *op) { + os() << "("; + os() << "("; + Print(op->condition); + os() << ") ? "; + Print(op->true_value); + os() << " : "; + Print(op->false_value); + os() << ")"; +} +void CodeGenC::Visit(const ir::IfThenElse *op) { + os() << "if ("; + Print(op->condition); + os() << ")"; + Print(op->true_case); + + if (op->false_case.defined()) { + os() << "else\n"; + Print(op->false_case); + } +} +void CodeGenC::Visit(const ir::Block *op) { + os() << "{\n"; + + IncIndent(); + + for (int i = 0; i < op->stmts.size() - 1; i++) { + DoIndent(); + Print(op->stmts[i]); + os() << ";\n"; + } + if (op->stmts.size() >= 1) { + DoIndent(); + Print(op->stmts.back()); + os() << ";"; + } + + DecIndent(); + os() << "\n"; + DoIndent(); + os() << "}"; +} +void CodeGenC::Visit(const ir::Call *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Module *op) { NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; } +void CodeGenC::Visit(const ir::Load *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Store *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Alloc *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Free *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Range_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_IterVar_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Buffer_ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::_Tensor_ *op) { IrPrinter::Visit(op); } + +void CodeGenC::PrintCastExpr(const Type &type, Expr e) { + os() << PrintType(type) << "("; + Print(e); + os() << ")"; +} + +} // namespace backends +} // namespace cinn diff --git a/cinn/backends/codegen_c.h b/cinn/backends/codegen_c.h new file mode 100644 index 0000000000000..9fe240b086281 --- /dev/null +++ b/cinn/backends/codegen_c.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "cinn/common/common.h" +#include "cinn/ir/function.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/module.h" + +namespace cinn { + +namespace lang { +class Module; +} // namespace lang + +namespace backends { + +class CodeGenC : public ir::IrPrinter { + public: + CodeGenC(std::ostream& os, Target target); + + void Compile(const lang::Module& module); + + protected: + void Compile(const lang::LoweredFunc& function); + void Compile(const ir::Buffer& buffer); + + std::string PrintType(Type type); + void PrintCastExpr(const Type& type, Expr e); + +#define __DEFINE_VISIT(op__) void Visit(const ir::op__* op) override; + NODETY_FORALL(__DEFINE_VISIT) +#undef __DEFINE_VISIT + + private: + Target target_; +}; + +} // namespace backends +} // namespace cinn diff --git a/cinn/common/common.h b/cinn/common/common.h index eb0955d5166df..2da737baecb74 100644 --- a/cinn/common/common.h +++ b/cinn/common/common.h @@ -7,7 +7,6 @@ #include "cinn/common/shared.h" #include "cinn/common/target.h" #include "cinn/common/type.h" -#include "target.h" namespace cinn { @@ -25,5 +24,8 @@ using common::Int; using common::type_of; using common::Target; +using common::Type; + +#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"; } // namespace cinn diff --git a/cinn/ir/ir_printer.h b/cinn/ir/ir_printer.h index ac8cc5842cee3..200e310f7b629 100644 --- a/cinn/ir/ir_printer.h +++ b/cinn/ir/ir_printer.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include "cinn/ir/buffer.h" #include "cinn/ir/ir.h" @@ -32,6 +33,8 @@ struct IrPrinter : public IRVisitor { //! Decrease the indent size. void DecIndent(); + std::ostream &os() { return os_; } + void Visit(const IntImm *x) override; void Visit(const UIntImm *x) override; void Visit(const FloatImm *x) override; diff --git a/cinn/lang/module.h b/cinn/lang/module.h index 8dbec7a63d0f9..274ad9afb133b 100644 --- a/cinn/lang/module.h +++ b/cinn/lang/module.h @@ -49,5 +49,48 @@ class Module { Shared<_Module_> module_; }; +/** + * A struct representing an argument to a lowered function. Used for specifying the function signature of generated + * code. + */ +struct Argument { + //! The name of the argument. + std::string name; + + enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar}; + + //! Number of the dimensions of buffer. + uint32_t ndims{0}; + + //! The type of the buffer or scalar. + Type type; + + bool is_buffer() const { return kind == Kind::kBuffer; } + bool is_scalar() const { return kind == Kind::kScalar; } + + Argument() {} + Argument(const std::string& name, Kind kind, const Type& type, int ndims) + : name(name), kind(kind), type(type), ndims(ndims) {} + + explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {} +}; + +/** + * Definition of a lowered function. Note that, it should be functional. + */ +struct LoweredFunc { + //! The name of this function. + std::string name; + + //! The Arguments used in the body of the function. + std::vector args; + + //! Body of this function. + Expr body; + + LoweredFunc(const std::string& name, const std::vector& args, const Expr& body) + : name(name), args(args), body(body) {} +}; + } // namespace lang } // namespace cinn From feca5495834229aab0d6767e6db5b3baba01a5c3 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 23 Feb 2020 08:44:50 +0800 Subject: [PATCH 2/2] init runtime --- cinn/backends/outputs.cc | 2 +- cinn/runtime/CMakeLists.txt | 1 + cinn/runtime/cinn_runtime.cc | 0 cinn/runtime/cinn_runtime.h | 120 +++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 cinn/runtime/cinn_runtime.cc create mode 100644 cinn/runtime/cinn_runtime.h diff --git a/cinn/backends/outputs.cc b/cinn/backends/outputs.cc index 732c41632024e..e1c23fd72e421 100644 --- a/cinn/backends/outputs.cc +++ b/cinn/backends/outputs.cc @@ -1,4 +1,4 @@ -#include "cinn/lang/outputs.h" +#include "cinn/backends/outputs.h" namespace cinn { namespace lang {} // namespace lang diff --git a/cinn/runtime/CMakeLists.txt b/cinn/runtime/CMakeLists.txt index 20f344b472f34..ee0a5c71da1b0 100644 --- a/cinn/runtime/CMakeLists.txt +++ b/cinn/runtime/CMakeLists.txt @@ -1,4 +1,5 @@ cc_library(runtime SRCS intrinsic.cc buffer.cc + cinn_runtime.cc DEPS common ir) diff --git a/cinn/runtime/cinn_runtime.cc b/cinn/runtime/cinn_runtime.cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/cinn/runtime/cinn_runtime.h b/cinn/runtime/cinn_runtime.h new file mode 100644 index 0000000000000..a5a6d4b2b5396 --- /dev/null +++ b/cinn/runtime/cinn_runtime.h @@ -0,0 +1,120 @@ +#ifndef CINN_RUNTIME_H_ +#define CINN_RUNTIME_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define CINN_ALWAYS_INLINE __attribute__((always_inline)) inline + +typedef enum cinn_type_code_t { + cinn_type_int = 0, //! signed int + cinn_type_uint = 1, //! unsigned int + cinn_type_float = 1, //! floating point + cinn_type_handle = 1 //! void* +} cinn_type_code_t; + +#ifndef CINN_ATTRIBUTE_ALIGN +#define CINN_ATTRIBUTE_ALIGN(n) __attribute__((aligned(n))) +#endif + +/** + * A tuntime tag for type in CINN system. + */ +typedef struct cinn_type_t { +#if __cplusplus >= 201103L + CINN_ATTRIBUTE_ALIGN(1) cinn_type_code_t code; +#else + uint8_t code; +#endif + + //! Number of bits. + uint8_t bits; + + //! Number of elements in a vector, 1 for scalar. + uint16_t lanes; + +#ifdef __cplusplus + CINN_ALWAYS_INLINE cinn_type_t() : code(cinn_type_int), bits(0), lanes(0) {} + CINN_ALWAYS_INLINE cinn_type_t(cinn_type_code_t code, uint8_t bits, uint16_t lanes = 1) + : code(code), bits(bits), lanes(lanes) {} + CINN_ALWAYS_INLINE bool operator==(const cinn_type_t& other) const { + return code == other.code && bits == other.bits && lanes == other.lanes; + } + CINN_ALWAYS_INLINE bool operator!=(const cinn_type_t& other) const { return !(*this == other); } + CINN_ALWAYS_INLINE uint16_t bytes() const { return (bits + 7) / 8; } +#endif // __cplusplus + +} cinn_type_t; + +//! Help to define the size of a dimension, due to polyhedral representation, we no need to record the extend or +//! min(default to 0). +typedef int cinn_dimension_t; + +//! Help to tell where the buffer locates. +typedef enum cinn_buffer_kind_t { + cinn_buffer_on_host = 0, //! buffer on host + cinn_buffer_on_device = 1 << 1 // ! buffer on device e.g. GPU. +} cinn_buffer_kind_t; + +//! The raw representation of a buffer,used in the generated code/lib. +typedef struct cinn_buffer_t { + //! A device handle. + uint64_t device; + + //! A pointer to the memory in host. + uint8_t* host_memory; + + //! Extra flags. + uint64_t flag; + + //! Data type. + cinn_type_t type; + + //! Number of dimensions. + int32_t ndims; + cinn_buffer_t* dims; + +#ifdef __cplusplus + int num_elements() const { + int res = 1; + for (int i = 0; i < ndims; i++) { + res *= dims[i]; + } + return res; + } + + CINN_ALWAYS_INLINE bool on_host() const { return get_flag(cinn_buffer_on_host); } + CINN_ALWAYS_INLINE bool on_device() const { return get_flag(cinn_buffer_on_device); } + CINN_ALWAYS_INLINE void set_on_host(bool x = true) { + if (x) { + set_flag(cinn_buffer_on_host); + } else { + flag &= ~cinn_buffer_on_host; + } + } + CINN_ALWAYS_INLINE void set_on_device(bool x = true) { + if (x) { + set_flag(cinn_buffer_on_device); + } else { + flag &= ~cinn_buffer_on_device; + } + } + CINN_ALWAYS_INLINE uint8_t *begin() const { + } + + CINN_ALWAYS_INLINE bool get_flag(cinn_buffer_kind_t flag) const { return (this->flag & flag) != 0; } + CINN_ALWAYS_INLINE void set_flag(cinn_buffer_kind_t flag) { this->flag |= flag; } + +#endif // __cplusplus + +} cinn_buffer_t; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // CINN_RUNTIME_H_