Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,25 @@ def Cxx_ClassType : Cxx_Type<"Class", "class", [MutableType]> {

}

def Cxx_FunctionType : Cxx_Type<"Function", "function"> {
let parameters = (ins
ArrayRefParameter<"mlir::Type">:$inputs
, ArrayRefParameter<"mlir::Type">:$results
, "bool":$variadic
);

let assemblyFormat = "`<` $inputs `,` $results `,` $variadic `>`";

let extraClassDeclaration = [{
auto clone(mlir::TypeRange inputs, mlir::TypeRange results) const -> FunctionType;
}];
}

// ops

def Cxx_FuncOp : Cxx_Op<"func", [FunctionOpInterface, IsolatedFromAbove]> {
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
TypeAttrOf<Cxx_FunctionType>:$function_type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);

Expand Down Expand Up @@ -134,10 +148,9 @@ def Cxx_ReturnOp : Cxx_Op<"return", [Pure, HasParent<"FuncOp">, Terminator]> {

def Cxx_CallOp : Cxx_Op<"call"> {
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
FlatSymbolRefAttr:$callee
, Variadic<AnyType>:$inputs
, OptionalAttr<TypeAttrOf<Cxx_FunctionType>>:$var_callee_type
);

let results = (outs Optional<AnyType>:$result);
Expand Down
4 changes: 3 additions & 1 deletion src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
resultTypes.push_back(convertType(returnType));
}

auto funcType = builder_.getFunctionType(inputTypes, resultTypes);
auto funcType =
mlir::cxx::FunctionType::get(builder_.getContext(), inputTypes,
resultTypes, functionType->isVariadic());

std::string name;

Expand Down
8 changes: 6 additions & 2 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,12 @@ auto Codegen::ExpressionVisitor::operator()(CallExpressionAST* ast)
}

auto op = gen.builder_.create<mlir::cxx::CallOp>(
loc, resultTypes, funcOp.getSymName(), arguments, mlir::ArrayAttr{},
mlir::ArrayAttr{});
loc, resultTypes, funcOp.getSymName(), arguments, mlir::TypeAttr{});

if (functionType->isVariadic()) {
op.setVarCalleeType(
cast<mlir::cxx::FunctionType>(gen.convertType(functionType)));
}

return ExpressionResult{op.getResult()};
};
Expand Down
11 changes: 10 additions & 1 deletion src/mlir/cxx/mlir/convert_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,16 @@ auto Codegen::ConvertType::operator()(const RvalueReferenceType* type)
}

auto Codegen::ConvertType::operator()(const FunctionType* type) -> mlir::Type {
return getExprType();
mlir::SmallVector<mlir::Type> inputs;
for (auto argType : type->parameterTypes()) {
inputs.push_back(gen.convertType(argType));
}
mlir::SmallVector<mlir::Type> results;
if (!control()->is_void(type->returnType())) {
results.push_back(gen.convertType(type->returnType()));
}
return gen.builder_.getType<mlir::cxx::FunctionType>(inputs, results,
type->isVariadic());
}

auto Codegen::ConvertType::operator()(const ClassType* type) -> mlir::Type {
Expand Down
11 changes: 9 additions & 2 deletions src/mlir/cxx/mlir/cxx_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ void CxxDialect::initialize() {
}

void FuncOp::print(OpAsmPrinter &p) {
const auto isVariadic = getFunctionType().getVariadic();
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
p, *this, isVariadic, getFunctionTypeAttrName(), getArgAttrsAttrName(),
getResAttrsAttrName());
}

auto FuncOp::parse(OpAsmParser &parser, OperationState &result) -> ParseResult {
Expand Down Expand Up @@ -160,6 +161,12 @@ auto StoreOp::verify() -> LogicalResult {
return success();
}

auto FunctionType::clone(TypeRange inputs, TypeRange results) const
-> FunctionType {
return get(getContext(), llvm::to_vector(inputs), llvm::to_vector(results),
getVariadic());
}

auto ClassType::getNamed(MLIRContext *context, StringRef name) -> ClassType {
return Base::get(context, name);
}
Expand Down
30 changes: 29 additions & 1 deletion src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class FuncOpLowering : public OpConversionPattern<cxx::FuncOp> {
? LLVM::LLVMVoidType::get(getContext())
: resultTypes.front();

const auto isVarArg = false;
const auto isVarArg = funcType.getVariadic();

auto llvmFuncType =
LLVM::LLVMFunctionType::get(returnType, argumentTypes, isVarArg);
Expand Down Expand Up @@ -159,6 +159,12 @@ class CallOpLowering : public OpConversionPattern<cxx::CallOp> {
auto llvmCallOp = rewriter.create<LLVM::CallOp>(
op.getLoc(), resultTypes, adaptor.getCallee(), adaptor.getInputs());

if (op.getVarCalleeType().has_value()) {
auto varCalleeType =
typeConverter->convertType(op.getVarCalleeType().value());
llvmCallOp.setVarCalleeType(cast<LLVM::LLVMFunctionType>(varCalleeType));
}

rewriter.replaceOp(op, llvmCallOp);
return success();
}
Expand Down Expand Up @@ -1251,6 +1257,28 @@ void CxxToLLVMLoweringPass::runOnOperation() {
return LLVM::LLVMArrayType::get(elementType, size);
});

typeConverter.addConversion([&](cxx::FunctionType type) -> Type {
SmallVector<Type> inputs;
for (auto argType : type.getInputs()) {
auto convertedType = typeConverter.convertType(argType);
inputs.push_back(convertedType);
}
SmallVector<Type> results;
for (auto resultType : type.getResults()) {
auto convertedType = typeConverter.convertType(resultType);
results.push_back(convertedType);
}
if (results.size() > 1) {
return {};
}
if (results.empty()) {
results.push_back(LLVM::LLVMVoidType::get(type.getContext()));
}
auto context = type.getContext();
return LLVM::LLVMFunctionType::get(context, results.front(), inputs,
type.getVariadic());
});

DenseMap<cxx::ClassType, Type> convertedClassTypes;
typeConverter.addConversion([&](cxx::ClassType type) -> Type {
if (auto it = convertedClassTypes.find(type);
Expand Down
Loading