Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ auto Codegen::newUniqueSymbolName(std::string_view prefix) -> std::string {
void Codegen::branch(mlir::Location loc, mlir::Block* block,
mlir::ValueRange operands) {
if (currentBlockMightHaveTerminator()) return;
builder_.create<mlir::cf::BranchOp>(loc, block, operands);
mlir::cf::BranchOp::create(builder_, loc, block, operands);
}

auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
Expand All @@ -83,7 +83,7 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
auto ptrType = builder_.getType<mlir::cxx::PointerType>(type);

auto loc = getLocation(var->location());
auto allocaOp = builder_.create<mlir::cxx::AllocaOp>(loc, ptrType);
auto allocaOp = mlir::cxx::AllocaOp::create(builder_, loc, ptrType);

locals_.emplace(var, allocaOp);

Expand All @@ -93,7 +93,7 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
auto Codegen::newTemp(const Type* type, SourceLocation loc)
-> mlir::cxx::AllocaOp {
auto ptrType = builder_.getType<mlir::cxx::PointerType>(convertType(type));
return builder_.create<mlir::cxx::AllocaOp>(getLocation(loc), ptrType);
return mlir::cxx::AllocaOp::create(builder_, getLocation(loc), ptrType);
}

auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
Expand Down Expand Up @@ -145,8 +145,8 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)

builder_.setInsertionPointToStart(module_.getBody());

auto func = builder_.create<mlir::cxx::FuncOp>(
loc, name, funcType, mlir::ArrayAttr{}, mlir::ArrayAttr{});
auto func = mlir::cxx::FuncOp::create(builder_, loc, name, funcType,
mlir::ArrayAttr{}, mlir::ArrayAttr{});

funcOps_.insert_or_assign(functionSymbol, func);

Expand All @@ -165,14 +165,14 @@ auto Codegen::getLocation(SourceLocation location) -> mlir::Location {
auto Codegen::emitTodoStmt(SourceLocation location, std::string_view message)
-> mlir::cxx::TodoStmtOp {
const auto loc = getLocation(location);
auto op = builder_.create<mlir::cxx::TodoStmtOp>(loc, message);
auto op = mlir::cxx::TodoStmtOp::create(builder_, loc, message);
return op;
}

auto Codegen::emitTodoExpr(SourceLocation location, std::string_view message)
-> mlir::cxx::TodoExprOp {
const auto loc = getLocation(location);
auto op = builder_.create<mlir::cxx::TodoExprOp>(loc, message);
auto op = mlir::cxx::TodoExprOp::create(builder_, loc, message);
return op;
}

Expand Down
46 changes: 34 additions & 12 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <cxx/ast.h>
#include <cxx/control.h>
#include <cxx/external_name_encoder.h>
#include <cxx/names.h>
#include <cxx/symbols.h>
#include <cxx/translation_unit.h>
#include <cxx/types.h>
Expand All @@ -37,6 +38,17 @@

namespace cxx {

namespace {

[[nodiscard]] auto is_global_namespace(Symbol* symbol) -> bool {
if (!symbol) return false;
if (!symbol->isNamespace()) return false;
if (symbol->parent()) return false;
return true;
}

} // namespace

struct Codegen::DeclarationVisitor {
Codegen& gen;

Expand Down Expand Up @@ -205,8 +217,8 @@ auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast)

const auto elementType = gen.convertType(var->type());

gen.builder_.create<mlir::cxx::StoreOp>(loc, expressionResult.value,
local.value());
mlir::cxx::StoreOp::create(gen.builder_, loc, expressionResult.value,
local.value());
}

return {};
Expand Down Expand Up @@ -365,7 +377,17 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
gen.getLocation(ast->functionBody->firstSourceLocation());
auto exitValueType = gen.convertType(returnType);
auto ptrType = gen.builder_.getType<mlir::cxx::PointerType>(exitValueType);
exitValue = gen.builder_.create<mlir::cxx::AllocaOp>(exitValueLoc, ptrType);
exitValue =
mlir::cxx::AllocaOp::create(gen.builder_, exitValueLoc, ptrType);

auto id = name_cast<Identifier>(functionSymbol->name());
if (id && id->name() == "main" &&
is_global_namespace(functionSymbol->parent())) {
auto zeroOp = mlir::cxx::IntConstantOp::create(
gen.builder_, loc, gen.convertType(gen.control()->getIntType()), 0);

mlir::cxx::StoreOp::create(gen.builder_, exitValueLoc, zeroOp, exitValue);
}
}

std::unordered_map<Symbol*, mlir::Value> locals;
Expand All @@ -389,8 +411,8 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
thisValue = gen.newTemp(classSymbol->type(), ast->firstSourceLocation());

// store the `this` pointer in the entry block
gen.builder_.create<mlir::cxx::StoreOp>(
loc, gen.entryBlock_->getArgument(0), thisValue);
mlir::cxx::StoreOp::create(gen.builder_, loc,
gen.entryBlock_->getArgument(0), thisValue);
}

FunctionParametersSymbol* params = nullptr;
Expand All @@ -408,11 +430,11 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
auto ptrType = gen.builder_.getType<mlir::cxx::PointerType>(type);

auto loc = gen.getLocation(arg->location());
auto allocaOp = gen.builder_.create<mlir::cxx::AllocaOp>(loc, ptrType);
auto allocaOp = mlir::cxx::AllocaOp::create(gen.builder_, loc, ptrType);

auto value = args[argc];
++argc;
gen.builder_.create<mlir::cxx::StoreOp>(loc, value, allocaOp);
mlir::cxx::StoreOp::create(gen.builder_, loc, value, allocaOp);

gen.locals_.emplace(arg, allocaOp);
}
Expand All @@ -430,7 +452,7 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
const auto endLoc = gen.getLocation(ast->lastSourceLocation());

if (!gen.builder_.getBlock()->mightHaveTerminator()) {
gen.builder_.create<mlir::cf::BranchOp>(endLoc, gen.exitBlock_);
mlir::cf::BranchOp::create(gen.builder_, endLoc, gen.exitBlock_);
}

gen.builder_.setInsertionPointToEnd(gen.exitBlock_);
Expand All @@ -439,13 +461,13 @@ auto Codegen::DeclarationVisitor::operator()(FunctionDefinitionAST* ast)
// We need to return a value of the correct type.
auto elementType = gen.exitValue_.getType().getElementType();

auto value = gen.builder_.create<mlir::cxx::LoadOp>(endLoc, elementType,
gen.exitValue_);
auto value = mlir::cxx::LoadOp::create(gen.builder_, endLoc, elementType,
gen.exitValue_);

gen.builder_.create<mlir::cxx::ReturnOp>(endLoc, value->getResults());
mlir::cxx::ReturnOp::create(gen.builder_, endLoc, value->getResults());
} else {
// If the function returns void, we don't need to return anything.
gen.builder_.create<mlir::cxx::ReturnOp>(endLoc);
mlir::cxx::ReturnOp::create(gen.builder_, endLoc);
}

// restore the state
Expand Down
Loading