Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Commit

Permalink
Refactor the TypeConverter to support more robust type conversions:
Browse files Browse the repository at this point in the history
* Support for 1->0 type mappings, i.e. when the argument is being removed.
* Reordering types when converting a type signature.
* Adding new inputs when converting a type signature.

This cl also lays down the initial foundation for supporting 1->N type mappings, but full support will come in a followup.

Moving forward, function signature changes will be driven by populating a SignatureConversion instance. This class contains all of the necessary information for adding/removing/remapping function signatures; e.g. addInputs, addResults, remapInputs, etc.

PiperOrigin-RevId: 254064665
  • Loading branch information
River707 authored and joker-eph committed Jun 20, 2019
1 parent 80a29b8 commit 1fc8140
Show file tree
Hide file tree
Showing 9 changed files with 573 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class LLVMType;
/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
public:
using TypeConverter::convertType;

LLVMTypeConverter(MLIRContext *ctx);

/// Convert types to LLVM IR. This calls `convertAdditionalType` to convert
Expand All @@ -64,9 +66,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert function signatures to LLVM IR. In particular, convert functions
/// with multiple results into functions returning LLVM IR's structure type.
/// Use `convertType` to convert individual argument and result types.
FunctionType convertFunctionSignatureType(
FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override final;
LogicalResult convertSignature(FunctionType t,
ArrayRef<NamedAttributeList> argAttrs,
SignatureConversion &result) final;

/// LLVM IR module used to parse/create types.
llvm::Module *module;
Expand Down
6 changes: 4 additions & 2 deletions include/mlir/IR/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ class Block : public IRObjectWithUseList,
/// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);

/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index);
/// Erase the argument at 'index' and remove it from the argument list. If
/// 'updatePredTerms' is set to true, this argument is also removed from the
/// terminators of each predecessor to this block.
void eraseArgument(unsigned index, bool updatePredTerms = true);

unsigned getNumArguments() { return arguments.size(); }
BlockArgument *getArgument(unsigned i) { return arguments[i]; }
Expand Down
141 changes: 105 additions & 36 deletions include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,39 +113,109 @@ class TypeConverter {
public:
virtual ~TypeConverter() = default;

/// Derived classes must reimplement this hook if they need to convert
/// block or function argument types or function result types. If the target
/// dialect has support for custom first-class function types, convertType
/// should create those types for arguments of MLIR function type. It can be
/// used for values (constant, operands, results) of function type but not for
/// the function signatures. For the latter, convertFunctionSignatureType is
/// used instead.
///
/// For block attribute types, this function will be called for each attribute
/// individually.
///
/// If type conversion can fail, this function should return a
/// default-constructed Type. The failure will be then propagated to trigger
/// the pass failure.
/// This class provides all of the information necessary to convert a
/// FunctionType signature.
class SignatureConversion {
public:
SignatureConversion(unsigned numOrigInputs)
: remappedInputs(numOrigInputs) {}

/// This struct represents a range of new types that remap an existing
/// signature input.
struct InputMapping {
size_t inputNo, size;
};

/// Return the converted type signature.
FunctionType getConvertedType(MLIRContext *ctx) const {
return FunctionType::get(argTypes, resultTypes, ctx);
}

/// Return the argument types for the new signature.
ArrayRef<Type> getConvertedArgTypes() const { return argTypes; }

/// Return the result types for the new signature.
ArrayRef<Type> getConvertedResultTypes() const { return resultTypes; }

/// Returns the attributes for the arguments of the new signature.
ArrayRef<NamedAttributeList> getConvertedArgAttrs() const {
return argAttrs;
}

/// Get the input mapping for the given argument.
llvm::Optional<InputMapping> getInputMapping(unsigned input) const {
return remappedInputs[input];
}

//===------------------------------------------------------------------===//
// Conversion Hooks
//===------------------------------------------------------------------===//

/// Append new result types to the signature conversion.
void addResults(ArrayRef<Type> results);

/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
void addInputs(unsigned origInputNo, ArrayRef<Type> types,
ArrayRef<NamedAttributeList> attrs = llvm::None);

/// Append new input types to the signature conversion, this should only be
/// used if the new types are not intended to remap an existing input.
void addInputs(ArrayRef<Type> types,
ArrayRef<NamedAttributeList> attrs = llvm::None);

/// Remap an input of the original signature with a range of types in the
/// new signature.
void remapInput(unsigned origInputNo, unsigned newInputNo,
unsigned newInputCount = 1);

private:
/// The remapping information for each of the original arguments.
SmallVector<llvm::Optional<InputMapping>, 4> remappedInputs;

/// The set of argument and results types.
SmallVector<Type, 4> argTypes, resultTypes;

/// The set of attributes for each new argument type.
SmallVector<NamedAttributeList, 4> argAttrs;
};

/// This hooks allows for converting a type. This function should return
/// failure if no valid conversion exists, success otherwise. If the new set
/// of types is empty, the type is removed and any usages of the existing
/// value are expected to be removed during conversion.
virtual LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);

/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type convert to on success, and a null type on failure.
virtual Type convertType(Type t) { return t; }

/// Derived classes must reimplement this hook if they need to change the
/// function signature during conversion. This function will be called on
/// a function type corresponding to a function signature and must produce the
/// converted MLIR function type.
///
/// Note: even if some target dialects have first-class function types, they
/// cannot be used at the top level of MLIR function signature.
///
/// The default behavior of this function is to call convertType on individual
/// function operands and results, and then create a new MLIR function type
/// from those.
/// Convert the given FunctionType signature. This functions returns a valid
/// SignatureConversion on success, None otherwise.
llvm::Optional<SignatureConversion>
convertSignature(FunctionType type, ArrayRef<NamedAttributeList> argAttrs);
llvm::Optional<SignatureConversion> convertSignature(FunctionType type) {
SmallVector<NamedAttributeList, 4> argAttrs(type.getNumInputs());
return convertSignature(type, argAttrs);
}

/// This hook allows for changing a FunctionType signature. This function
/// should populate 'result' with the new arguments and result on success,
/// otherwise return failure.
///
/// Post-condition: if the returned optional attribute list does not have
/// a value then no argument attribute conversion happened.
virtual FunctionType convertFunctionSignatureType(
FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs);
/// The default behavior of this function is to call 'convertType' on
/// individual function operands and results. Any argument attributes are
/// dropped if the resultant conversion is not a 1->1 mapping.
virtual LogicalResult convertSignature(FunctionType type,
ArrayRef<NamedAttributeList> argAttrs,
SignatureConversion &result);

/// This hook allows for converting a specific argument of a signature. It
/// takes as inputs the original argument input number, type, and attributes.
/// On success, this function should populate 'result' with any new mappings.
virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
NamedAttributeList attrs,
SignatureConversion &result);
};

/// This class describes a specific conversion target.
Expand Down Expand Up @@ -253,16 +323,15 @@ class ConversionTarget {
};

/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
/// conversion object. This function returns failure if a type conversion
/// failed, potentially leaving the IR in an invalid state.
LLVM_NODISCARD LogicalResult applyConversionPatterns(
Module &module, ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns);

/// Convert the given functions with the provided conversion patterns. This will
/// convert as many of the operations within each function as possible given the
/// set of patterns. If conversion fails for specific functions, those functions
/// remains unmodified.
/// Convert the given functions with the provided conversion patterns. This
/// function returns failure if a type conversion failed, potentially leaving
/// the IR in an invalid state.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
ConversionTarget &target,
Expand Down
33 changes: 14 additions & 19 deletions lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,31 +985,26 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
}

// Convert function signatures using the stored LLVM IR module.
FunctionType LLVMTypeConverter::convertFunctionSignatureType(
FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {

convertedArgAttrs.reserve(argAttrs.size());
for (auto attr : argAttrs)
convertedArgAttrs.push_back(attr);

SmallVector<Type, 8> argTypes;
for (auto t : type.getInputs()) {
auto converted = convertType(t);
if (!converted)
return {};
argTypes.push_back(converted);
}
LogicalResult
LLVMTypeConverter::convertSignature(FunctionType type,
ArrayRef<NamedAttributeList> argAttrs,
SignatureConversion &result) {
// Convert the original function arguments.
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(convertSignatureArg(i, type.getInput(i), argAttrs[i], result)))
return failure();

// If function does not return anything, return immediately.
if (type.getNumResults() == 0)
return FunctionType::get(argTypes, {}, type.getContext());
return success();

// Otherwise pack the result types into a struct.
if (auto result = packFunctionResults(type.getResults()))
return FunctionType::get(argTypes, result, result.getContext());
if (auto packedRet = packFunctionResults(type.getResults())) {
result.addResults(packedRet);
return success();
}

return {};
return failure();
}

namespace {
Expand Down
6 changes: 5 additions & 1 deletion lib/IR/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,17 @@ auto Block::addArguments(ArrayRef<Type> types)
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
}

void Block::eraseArgument(unsigned index) {
void Block::eraseArgument(unsigned index, bool updatePredTerms) {
assert(index < arguments.size());

// Delete the argument.
delete arguments[index];
arguments.erase(arguments.begin() + index);

// If we aren't updating predecessors, there is nothing left to do.
if (!updatePredTerms)
return;

// Erase this argument from each of the predecessor's terminator.
for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
++predIt) {
Expand Down
Loading

0 comments on commit 1fc8140

Please sign in to comment.