diff --git a/src/ast/analysis/Type.cpp b/src/ast/analysis/Type.cpp index 1ba9272fcd8..28e718c5351 100644 --- a/src/ast/analysis/Type.cpp +++ b/src/ast/analysis/Type.cpp @@ -884,26 +884,31 @@ bool TypeAnalysis::isMultiResultFunctor(const Functor& functor) { fatal("Missing functor type."); } -IntrinsicFunctors TypeAnalysis::validOverloads(const IntrinsicFunctor& inf) const { - auto typeAttrs = [&](const Argument* arg) -> std::set { - std::set tyAttrs; - if (const auto* inf = dynamic_cast(arg)) { - if (hasValidTypeInfo(inf)) { - tyAttrs.insert(getFunctorReturnType(inf)); - return tyAttrs; - } +std::set TypeAnalysis::getTypeAttributes(const Argument* arg) const { + std::set typeAttributes; + + if (const auto* inf = dynamic_cast(arg)) { + // intrinsic functor type is its return type if its set + if (hasValidTypeInfo(inf)) { + typeAttributes.insert(getFunctorReturnType(inf)); + return typeAttributes; } - auto&& types = getTypes(arg); - if (types.isAll()) - return {TypeAttribute::Signed, TypeAttribute::Unsigned, TypeAttribute::Float, - TypeAttribute::Symbol, TypeAttribute::Record}; - - for (auto&& ty : types) - tyAttrs.insert(getTypeAttribute(ty)); - return tyAttrs; - }; - auto retTys = typeAttrs(&inf); - auto argTys = map(inf.getArguments(), typeAttrs); + } + + const auto& types = getTypes(arg); + if (types.isAll()) { + return {TypeAttribute::Signed, TypeAttribute::Unsigned, TypeAttribute::Float, TypeAttribute::Symbol, + TypeAttribute::Record}; + } + for (const auto& type : types) { + typeAttributes.insert(getTypeAttribute(type)); + } + return typeAttributes; +} + +IntrinsicFunctors TypeAnalysis::validOverloads(const IntrinsicFunctor& inf) const { + auto retTys = getTypeAttributes(&inf); + auto argTys = map(inf.getArguments(), [&](const Argument* arg) { return getTypeAttributes(arg); }); IntrinsicFunctors functorInfos = contains(functorInfo, &inf) ? functorBuiltIn(getPolymorphicOperator(&inf)) diff --git a/src/ast/analysis/Type.h b/src/ast/analysis/Type.h index ff0a4835db3..7f37da705d1 100644 --- a/src/ast/analysis/Type.h +++ b/src/ast/analysis/Type.h @@ -26,6 +26,7 @@ #include "ast/analysis/Analysis.h" #include "ast/analysis/TypeSystem.h" #include +#include #include #include #include @@ -67,9 +68,10 @@ class TypeAnalysis : public Analysis { // Checks whether an argument has been assigned a valid type bool hasValidTypeInfo(const Argument* argument) const; + std::set getTypeAttributes(const Argument* arg) const; + /** -- Functor-related methods -- */ IntrinsicFunctors validOverloads(const ast::IntrinsicFunctor& inf) const; - TypeAttribute getFunctorReturnType(const Functor* functor) const; TypeAttribute getFunctorArgType(const Functor* functor, const size_t idx) const; const std::vector& getFunctorArgTypes(const UserDefinedFunctor& udf) const;