Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#34 from Superjomn/fea/init-codegen-c
Browse files Browse the repository at this point in the history
fea/init codegen c
  • Loading branch information
Superjomn committed Feb 23, 2020
2 parents 78be2c3 + feca549 commit 73ec6d2
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 2 deletions.
2 changes: 2 additions & 0 deletions cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
cc_library(
backends
SRCS outputs.cc
codegen_c.cc
DEPS ir
)
143 changes: 143 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#include "cinn/backends/codegen_c.h"

namespace cinn {
namespace backends {

CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {}

void CodeGenC::Compile(const lang::Module &module) {}
void CodeGenC::Compile(const lang::LoweredFunc &function) {
os() << "void " << function.name;

// output arguments
os() << "(";

auto print_arg = [&](const lang::Argument &arg) {
if (arg.is_buffer()) {
os() << "struct cinn_buffer_t *";
} else if (arg.is_scalar()) {
os() << PrintType(arg.type) << " ";
os() << arg.name;
}
os() << arg.name;
};

for (int i = 0; i < function.args.size() - 1; i++) {
print_arg(function.args[i]);
os() << ", ";
}
if (function.args.size() >= 1) {
print_arg(function.args.back());
}

os() << ")";

DoIndent();
os() << "{\n";

Print(function.body);

DoIndent();
os() << "}";
}
void CodeGenC::Compile(const ir::Buffer &buffer) {}
std::string CodeGenC::PrintType(Type type) { return std::__cxx11::string(); }
void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Add *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Sub *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Mod *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::EQ *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::NE *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::LT *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::LE *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::GT *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::GE *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::And *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Or *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Not *op) {
os() << "(!";
IrPrinter::Print(op->v);
os() << ")";
}
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v); }
void CodeGenC::Visit(const ir::For *op) { LOG(FATAL) << "Not Implemented"; }
void CodeGenC::Visit(const ir::PolyFor *op) {
os() << "for (";
Print(op->init);
os() << "; ";
Print(op->condition);
os() << "; ";
Print(op->inc);
os() << ")";

Print(op->body);
}
void CodeGenC::Visit(const ir::Select *op) {
os() << "(";
os() << "(";
Print(op->condition);
os() << ") ? ";
Print(op->true_value);
os() << " : ";
Print(op->false_value);
os() << ")";
}
void CodeGenC::Visit(const ir::IfThenElse *op) {
os() << "if (";
Print(op->condition);
os() << ")";
Print(op->true_case);

if (op->false_case.defined()) {
os() << "else\n";
Print(op->false_case);
}
}
void CodeGenC::Visit(const ir::Block *op) {
os() << "{\n";

IncIndent();

for (int i = 0; i < op->stmts.size() - 1; i++) {
DoIndent();
Print(op->stmts[i]);
os() << ";\n";
}
if (op->stmts.size() >= 1) {
DoIndent();
Print(op->stmts.back());
os() << ";";
}

DecIndent();
os() << "\n";
DoIndent();
os() << "}";
}
void CodeGenC::Visit(const ir::Call *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Module *op) { NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; }
void CodeGenC::Visit(const ir::Load *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Store *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Alloc *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Free *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::_Range_ *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::_IterVar_ *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::_Buffer_ *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::_Tensor_ *op) { IrPrinter::Visit(op); }

void CodeGenC::PrintCastExpr(const Type &type, Expr e) {
os() << PrintType(type) << "(";
Print(e);
os() << ")";
}

} // namespace backends
} // namespace cinn
42 changes: 42 additions & 0 deletions cinn/backends/codegen_c.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <string>
#include <vector>

#include "cinn/common/common.h"
#include "cinn/ir/function.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/lang/module.h"

namespace cinn {

namespace lang {
class Module;
} // namespace lang

namespace backends {

class CodeGenC : public ir::IrPrinter {
public:
CodeGenC(std::ostream& os, Target target);

void Compile(const lang::Module& module);

protected:
void Compile(const lang::LoweredFunc& function);
void Compile(const ir::Buffer& buffer);

std::string PrintType(Type type);
void PrintCastExpr(const Type& type, Expr e);

#define __DEFINE_VISIT(op__) void Visit(const ir::op__* op) override;
NODETY_FORALL(__DEFINE_VISIT)
#undef __DEFINE_VISIT

private:
Target target_;
};

} // namespace backends
} // namespace cinn
2 changes: 1 addition & 1 deletion cinn/backends/outputs.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "cinn/lang/outputs.h"
#include "cinn/backends/outputs.h"

namespace cinn {
namespace lang {} // namespace lang
Expand Down
4 changes: 3 additions & 1 deletion cinn/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "cinn/common/shared.h"
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "target.h"

namespace cinn {

Expand All @@ -25,5 +24,8 @@ using common::Int;
using common::type_of;

using common::Target;
using common::Type;

#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented";

} // namespace cinn
3 changes: 3 additions & 0 deletions cinn/ir/ir_printer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <string>
#include <vector>

#include "cinn/ir/buffer.h"
#include "cinn/ir/ir.h"
Expand Down Expand Up @@ -32,6 +33,8 @@ struct IrPrinter : public IRVisitor {
//! Decrease the indent size.
void DecIndent();

std::ostream &os() { return os_; }

void Visit(const IntImm *x) override;
void Visit(const UIntImm *x) override;
void Visit(const FloatImm *x) override;
Expand Down
43 changes: 43 additions & 0 deletions cinn/lang/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,48 @@ class Module {
Shared<_Module_> module_;
};

/**
* A struct representing an argument to a lowered function. Used for specifying the function signature of generated
* code.
*/
struct Argument {
//! The name of the argument.
std::string name;

enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar};

//! Number of the dimensions of buffer.
uint32_t ndims{0};

//! The type of the buffer or scalar.
Type type;

bool is_buffer() const { return kind == Kind::kBuffer; }
bool is_scalar() const { return kind == Kind::kScalar; }

Argument() {}
Argument(const std::string& name, Kind kind, const Type& type, int ndims)
: name(name), kind(kind), type(type), ndims(ndims) {}

explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {}
};

/**
* Definition of a lowered function. Note that, it should be functional.
*/
struct LoweredFunc {
//! The name of this function.
std::string name;

//! The Arguments used in the body of the function.
std::vector<Argument> args;

//! Body of this function.
Expr body;

LoweredFunc(const std::string& name, const std::vector<Argument>& args, const Expr& body)
: name(name), args(args), body(body) {}
};

} // namespace lang
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
cc_library(runtime SRCS
intrinsic.cc
buffer.cc
cinn_runtime.cc
DEPS common ir)
Empty file added cinn/runtime/cinn_runtime.cc
Empty file.

0 comments on commit 73ec6d2

Please sign in to comment.