From 48ccef7cdee20685a406ed8e4c6fcec38568c5f5 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 15 Mar 2023 17:27:04 -0400 Subject: [PATCH 1/6] [NFC] Store the list of opened pack parameters as a flat array in opened generic environments Finding these is very hot for these environments, so doing it once is a pretty nice win in both speed and code complexity. I'm not actually using this yet. --- include/swift/AST/GenericEnvironment.h | 9 +++++++++ lib/AST/ASTContext.cpp | 6 ++++-- lib/AST/GenericEnvironment.cpp | 28 +++++++++++++++++++++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/include/swift/AST/GenericEnvironment.h b/include/swift/AST/GenericEnvironment.h index 70257d9326875..ac4963fc14d8f 100644 --- a/include/swift/AST/GenericEnvironment.h +++ b/include/swift/AST/GenericEnvironment.h @@ -129,6 +129,12 @@ class alignas(1 << DeclAlignInBits) GenericEnvironment final /// generic signature. ArrayRef getContextTypes() const; + /// Retrieve the array of opened pack parameters for this opened-element + /// environment. This is parallel to the array of element parameters, + /// i.e. the innermost generic context. + MutableArrayRef getOpenedPackParams(); + ArrayRef getOpenedPackParams() const; + /// Get the nested type storage, allocating it if required. NestedTypeStorage &getOrCreateNestedTypeStorage(); @@ -201,6 +207,9 @@ class alignas(1 << DeclAlignInBits) GenericEnvironment final /// Retrieve the UUID for an opened element environment. UUID getOpenedElementUUID() const; + /// Return the number of opened pack parameters. + unsigned getNumOpenedPackParams() const; + void forEachPackElementArchetype( llvm::function_ref function) const; diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 0a154a5640405..be8926291d5d9 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -5164,10 +5164,12 @@ GenericEnvironment::forOpenedElement(GenericSignature signature, // Allocate and construct the new environment. unsigned numGenericParams = signature.getGenericParams().size(); + unsigned numOpenedParams = signature.getInnermostGenericParams().size(); size_t bytes = totalSizeToAlloc( - 0, 0, 1, numGenericParams); + OpenedElementEnvironmentData, + Type>( + 0, 0, 1, numGenericParams + numOpenedParams); void *mem = ctx.Allocate(bytes, alignof(GenericEnvironment)); auto *genericEnv = new (mem) GenericEnvironment(signature, uuid, shapeClass, diff --git a/lib/AST/GenericEnvironment.cpp b/lib/AST/GenericEnvironment.cpp index 2790a3976d725..7f1d24fba2063 100644 --- a/lib/AST/GenericEnvironment.cpp +++ b/lib/AST/GenericEnvironment.cpp @@ -63,7 +63,8 @@ size_t GenericEnvironment::numTrailingObjects( } size_t GenericEnvironment::numTrailingObjects(OverloadToken) const { - return getGenericParams().size(); + return getGenericParams().size() + + (getKind() == Kind::OpenedElement ? getNumOpenedPackParams() : 0); } /// Retrieve the array containing the context types associated with the @@ -82,6 +83,21 @@ ArrayRef GenericEnvironment::getContextTypes() const { getGenericParams().size()); } +unsigned GenericEnvironment::getNumOpenedPackParams() const { + assert(getKind() == Kind::OpenedElement); + return getGenericSignature().getInnermostGenericParams().size(); +} + +MutableArrayRef GenericEnvironment::getOpenedPackParams() { + auto begin = getTrailingObjects() + getGenericParams().size(); + return MutableArrayRef(begin, getNumOpenedPackParams()); +} + +ArrayRef GenericEnvironment::getOpenedPackParams() const { + auto begin = getTrailingObjects() + getGenericParams().size(); + return ArrayRef(begin, getNumOpenedPackParams()); +} + TypeArrayView GenericEnvironment::getGenericParams() const { return getGenericSignature().getGenericParams(); @@ -230,6 +246,16 @@ GenericEnvironment::GenericEnvironment(GenericSignature signature, // Clear out the memory that holds the context types. std::uninitialized_fill(getContextTypes().begin(), getContextTypes().end(), Type()); + + // Fill in the array of opened pack parameters. + auto openedPacksBuffer = getOpenedPackParams(); + unsigned i = 0; + for (auto param : signature.getGenericParams()) { + if (!param->isParameterPack()) continue; + if (!signature->haveSameShape(param, shapeClass)) continue; + openedPacksBuffer[i++] = param; + } + assert(i == openedPacksBuffer.size()); } void GenericEnvironment::addMapping(GenericParamKey key, From 7a8d8b4997ac861f8dd2c28b7f415221186c2530 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 15 Mar 2023 17:45:55 -0400 Subject: [PATCH 2/6] Fix the mapping of pack types into opened element environments. First, we need this to work on both lowered and unlowered types, so Type::subst is problematic: it'll assert if it sees a type like SILFunctionType. In this case, the substitution is simple enough that that's never a problem, but Type::subst doesn't know that, and the assertion is generally a good one. Second, we need this to not recurse into nested pack expansions. Third, we need this to not mess around with any existing element archetypes we might see in the type, so mapping in and out of context is not really okay. Fortunately, because we're mapping between structures (pack and element archetypes) that are guaranteed to have the same constraints, this transformation is really easy and we can just do it with transformRec. --- include/swift/AST/GenericEnvironment.h | 16 +++ lib/AST/GenericEnvironment.cpp | 147 ++++++++++++++++++------- 2 files changed, 123 insertions(+), 40 deletions(-) diff --git a/include/swift/AST/GenericEnvironment.h b/include/swift/AST/GenericEnvironment.h index ac4963fc14d8f..e2b441aed4883 100644 --- a/include/swift/AST/GenericEnvironment.h +++ b/include/swift/AST/GenericEnvironment.h @@ -294,10 +294,26 @@ class alignas(1 << DeclAlignInBits) GenericEnvironment final /// Map a contextual type containing parameter packs to a contextual /// type in the opened element generic context. + /// + /// This operation only makes sense if the generic environment that the + /// pack archetypes are contextual in matches the generic signature + /// of this environment. That will be true for opened element + /// environments coming straight out of the type checker, such as + /// the one in a PackExpansionExpr, or opened element environments + /// created directly from the current environment. It is not + /// reliable for opened element environments in arbitrary SIL functions. Type mapContextualPackTypeIntoElementContext(Type type) const; /// Map a contextual type containing parameter packs to a contextual /// type in the opened element generic context. + /// + /// This operation only makes sense if the generic environment that the + /// pack archetypes are contextual in matches the generic signature + /// of this environment. That will be true for opened element + /// environments coming straight out of the type checker, such as + /// the one in a PackExpansionExpr, or opened element environments + /// created directly from the current environment. It is not + /// reliable for opened element environments in arbitrary SIL functions. CanType mapContextualPackTypeIntoElementContext(CanType type) const; /// Map a type containing pack element type parameters to a contextual diff --git a/lib/AST/GenericEnvironment.cpp b/lib/AST/GenericEnvironment.cpp index 7f1d24fba2063..2ced66b92fdca 100644 --- a/lib/AST/GenericEnvironment.cpp +++ b/lib/AST/GenericEnvironment.cpp @@ -148,6 +148,53 @@ UUID GenericEnvironment::getOpenedElementUUID() const { return getTrailingObjects()->uuid; } +namespace { + +struct FindOpenedElementParam { + ArrayRef openedPacks; + TypeArrayView packElementParams; + + FindOpenedElementParam(const GenericEnvironment *env, + ArrayRef openedPacks) + : openedPacks(openedPacks), + packElementParams( + env->getGenericSignature().getInnermostGenericParams()) { + assert(openedPacks.size() == packElementParams.size()); + } + + GenericTypeParamType *operator()(Type packParam) { + for (auto i : indices(openedPacks)) { + if (openedPacks[i]->isEqual(packParam)) + return packElementParams[i]; + } + llvm_unreachable("parameter was not an opened pack parameter"); + } +}; + +struct FindElementArchetypeForOpenedPackParam { + FindOpenedElementParam findElementParam; + QueryInterfaceTypeSubstitutions getElementArchetype; + + FindElementArchetypeForOpenedPackParam(const GenericEnvironment *env, + ArrayRef openedPacks) + : findElementParam(env, openedPacks), getElementArchetype(env) {} + + + ElementArchetypeType *operator()(Type interfaceType) { + assert(interfaceType->isTypeParameter()); + if (auto member = interfaceType->getAs()) { + auto baseArchetype = (*this)(member->getBase()); + return baseArchetype->getNestedType(member->getAssocType()) + ->castTo(); + } + assert(interfaceType->is()); + return getElementArchetype(findElementParam(interfaceType)) + ->castTo(); + } +}; + +} + void GenericEnvironment::forEachPackElementArchetype( llvm::function_ref function) const { auto packElements = getGenericSignature().getInnermostGenericParams(); @@ -613,25 +660,51 @@ Type GenericEnvironment::mapTypeIntoContext(GenericTypeParamType *type) const { Type GenericEnvironment::mapContextualPackTypeIntoElementContext(Type type) const { + assert(getKind() == Kind::OpenedElement); + assert(!type->hasTypeParameter() && "expected contextual type"); + if (!type->hasArchetype()) return type; - // FIXME: this is potentially wrong if there are multiple - // openings in play at once, because we really shouldn't touch - // other element archetypes. - return mapPackTypeIntoElementContext(type->mapTypeOutOfContext()); + auto sig = getGenericSignature(); + auto shapeClass = getOpenedElementShapeClass(); + + FindElementArchetypeForOpenedPackParam + findElementArchetype(this, getOpenedPackParams()); + + return type.transformRec([&](TypeBase *ty) -> Optional { + // We're only directly substituting pack archetypes. + auto archetype = ty->getAs(); + if (!archetype) { + // Don't recurse into nested pack expansions. + if (ty->is()) + return Type(ty); + + // Recurse into any other type. + return None; + } + + auto rootArchetype = cast(archetype->getRoot()); + + // TODO: assert that the generic environment of the pack archetype + // matches the signature that was originally opened to make this + // environment. Unfortunately, that isn't a trivial check because of + // the extra opened-element parameters. + + // If the archetype isn't the shape that was opened by this + // environment, ignore it. + auto rootParam = cast( + rootArchetype->getInterfaceType().getPointer()); + assert(rootParam->isParameterPack()); + if (!sig->haveSameShape(rootParam, shapeClass)) + return Type(ty); + + return Type(findElementArchetype(archetype->getInterfaceType())); + }); } CanType GenericEnvironment::mapContextualPackTypeIntoElementContext(CanType type) const { - if (!type->hasArchetype()) return type; - - // FIXME: this is potentially wrong if there are multiple - // openings in play at once, because we really shouldn't touch - // other element archetypes. - // FIXME: if we do this properly, there's no way for this rewrite - // to produce a non-canonical type. - return mapPackTypeIntoElementContext(type->mapTypeOutOfContext()) - ->getCanonicalType(); + return CanType(mapContextualPackTypeIntoElementContext(Type(type))); } Type @@ -641,40 +714,34 @@ GenericEnvironment::mapPackTypeIntoElementContext(Type type) const { auto sig = getGenericSignature(); auto shapeClass = getOpenedElementShapeClass(); - QueryInterfaceTypeSubstitutions substitutions(this); - llvm::SmallDenseMap elementParamForPack; - auto packElements = sig.getInnermostGenericParams(); - auto elementDepth = packElements.front()->getDepth(); - - for (auto *genericParam : sig.getGenericParams()) { - if (genericParam->getDepth() == elementDepth) - break; - - if (!genericParam->isParameterPack()) - continue; - - if (!sig->haveSameShape(genericParam, shapeClass)) - continue; - - auto elementIndex = elementParamForPack.size(); - elementParamForPack[{genericParam}] = packElements[elementIndex]; - } + FindElementArchetypeForOpenedPackParam + findElementArchetype(this, getOpenedPackParams()); // Map the interface type to the element type by stripping // away the isParameterPack bit before mapping type parameters // to archetypes. - return type.subst([&](SubstitutableType *type) { - auto *genericParam = type->getAs(); - if (!genericParam) - return Type(); + return type.transformRec([&](TypeBase *ty) -> Optional { + // We're only directly substituting pack parameters. + if (!ty->isTypeParameter()) { + // Don't recurse into nested pack expansions; just map it into + // context. + if (ty->is()) + return mapTypeIntoContext(ty); + + // Recurse into any other type. + return None; + } - if (auto *elementParam = elementParamForPack[{genericParam}]) - return substitutions(elementParam); + // Just do normal mapping for types that are not rooted in + // opened type parameters. + auto rootParam = ty->getRootGenericParam(); + if (!rootParam->isParameterPack() || + !sig->haveSameShape(rootParam, shapeClass)) + return mapTypeIntoContext(ty); - return substitutions(genericParam); - }, LookUpConformanceInSignature(sig.getPointer())); + return Type(findElementArchetype(ty)); + }); } Type From 9ab4dc494cd6024fad863f985e7c695030f5df1f Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 15 Mar 2023 23:01:25 -0400 Subject: [PATCH 3/6] [NFC] Add better APIs for parallel destructuring of orig+subst types As I've been iterating on this work, I've been gradually mulling these over, and I think this is the way to go for now. These should make it a lot less cumbersome to write these kinds of traversals correctly. The intent is to the sunset the existing expanded-components stuff after I do a similar pass for function parameters. --- include/swift/SIL/AbstractionPattern.h | 60 +++++++- lib/SIL/IR/AbstractionPattern.cpp | 195 ++++++++++++++++--------- lib/SIL/IR/SILFunctionType.cpp | 91 +++++------- lib/SILGen/ResultPlan.cpp | 60 +++----- lib/SILGen/ResultPlan.h | 2 +- lib/SILGen/SILGenApply.cpp | 18 ++- lib/SILGen/SILGenProlog.cpp | 86 +++++------ lib/SILGen/SILGenStmt.cpp | 84 +++++------ 8 files changed, 336 insertions(+), 260 deletions(-) diff --git a/include/swift/SIL/AbstractionPattern.h b/include/swift/SIL/AbstractionPattern.h index 660fa5ac6595c..8353b47027ac6 100644 --- a/include/swift/SIL/AbstractionPattern.h +++ b/include/swift/SIL/AbstractionPattern.h @@ -913,6 +913,9 @@ class AbstractionPattern { bool hasCachingKey() const { // Only the simplest Kind::Type pattern has a caching key; we // don't want to try to unique by Clang node. + // + // Even if we support Clang nodes someday, we *cannot* cache + // by the open-coded patterns like Tuple and PackExpansion. return getKind() == Kind::Type || getKind() == Kind::Opaque || getKind() == Kind::Discard; } @@ -1216,9 +1219,7 @@ class AbstractionPattern { case Kind::Invalid: llvm_unreachable("querying invalid abstraction pattern!"); case Kind::Opaque: - return typename CanTypeWrapperTraits::type(); case Kind::Tuple: - return typename CanTypeWrapperTraits::type(); case Kind::OpaqueFunction: case Kind::OpaqueDerivativeFunction: return typename CanTypeWrapperTraits::type(); @@ -1275,7 +1276,7 @@ class AbstractionPattern { /// Is the given tuple type a valid substitution of this abstraction /// pattern? - bool matchesTuple(CanTupleType substType); + bool matchesTuple(CanTupleType substType) const; bool isTuple() const { switch (getKind()) { @@ -1346,6 +1347,40 @@ class AbstractionPattern { return { { this, 0 }, { this, getNumTupleElements() } }; } + /// Perform a parallel visitation of the elements of a tuple type, + /// preserving structure about where pack expansions appear in the + /// original type and how many elements of the substituted type they + /// expand to. + /// + /// This pattern must be a tuple pattern. + /// + /// Calls handleScalar or handleExpansion as appropriate for each + /// element of the original tuple, in order. + void forEachTupleElement(CanTupleType substType, + llvm::function_ref + handleScalar, + llvm::function_ref + handleExpansion) const; + + /// Perform a parallel visitation of the elements of a tuple type, + /// expanding the elements of the type. This preserves the structure + /// of the *substituted* tuple type: it will be called once per element + /// of the substituted type, in order. The original element trappings + /// are also provided for convenience. + /// + /// This pattern must match the substituted type, but it may be an + /// opaque pattern. + void forEachExpandedTupleElement(CanTupleType substType, + llvm::function_ref handleElement) const; + /// Is the given pack type a valid substitution of this abstraction /// pattern? bool matchesPack(CanPackType substType); @@ -1420,13 +1455,20 @@ class AbstractionPattern { /// the abstraction pattern for an element type. AbstractionPattern getPackElementType(unsigned index) const; - /// Give that the value being abstracted is a pack expansion type, return the - /// underlying pattern type. + /// Given that the value being abstracted is a pack expansion type, + /// return the underlying pattern type. + /// + /// If you're looking for getPackExpansionCountType(), it deliberately + /// does not exist. Count types are not lowered types, and the original + /// count types are not relevant to lowering. Only the substituted + /// components and expansion counts are significant. AbstractionPattern getPackExpansionPatternType() const; - /// Give that the value being abstracted is a pack expansion type, return the - /// underlying count type. - AbstractionPattern getPackExpansionCountType() const; + /// Given that the value being abstracted is a pack expansion type, + /// return the appropriate pattern type for the given expansion + /// component. + AbstractionPattern getPackExpansionComponentType(CanType substType) const; + AbstractionPattern getPackExpansionComponentType(bool isExpansion) const; /// Given that the value being abstracted is a function, return the /// abstraction pattern for its result type. @@ -1486,6 +1528,8 @@ class AbstractionPattern { void forEachPackExpandedComponent( llvm::function_ref fn) const; + size_t getNumPackExpandedComponents() const; + SmallVector getPackExpandedComponents() const; /// If this pattern refers to a foreign ObjC method that was imported as diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index ac1f9d2bfd56c..00a6acb6cece8 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -282,7 +282,7 @@ LayoutConstraint AbstractionPattern::getLayoutConstraint() const { } } -bool AbstractionPattern::matchesTuple(CanTupleType substType) { +bool AbstractionPattern::matchesTuple(CanTupleType substType) const { switch (getKind()) { case Kind::Invalid: llvm_unreachable("querying invalid abstraction pattern!"); @@ -311,26 +311,25 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) { LLVM_FALLTHROUGH; case Kind::Tuple: { size_t nextSubstIndex = 0; - auto nextComponentIsAcceptable = - [&](AbstractionPattern origComponentType) -> bool { + auto nextComponentIsAcceptable = [&](bool isPackExpansion) -> bool { if (nextSubstIndex == substType->getNumElements()) return false; auto substComponentType = substType.getElementType(nextSubstIndex++); - return (origComponentType.isPackExpansion() == - isa(substComponentType)); + return (isPackExpansion == isa(substComponentType)); }; - for (size_t i = 0, n = getNumTupleElements(); i != n; ++i) { - auto elt = getTupleElementType(i); - if (elt.isPackExpansion()) { - bool fail = false; - elt.forEachPackExpandedComponent([&](AbstractionPattern component) { - if (!nextComponentIsAcceptable(component)) - fail = true; - }); - if (fail) return false; - } else { - if (!nextComponentIsAcceptable(elt)) - return false; + for (auto elt : getTupleElementTypes()) { + bool isPackExpansion = elt.isPackExpansion(); + if (isPackExpansion && elt.GenericSubs) { + auto origExpansion = cast(elt.getType()); + auto substShape = cast( + origExpansion.getCountType().subst(elt.GenericSubs) + ->getCanonicalType()); + for (auto shapeElt : substShape.getElementTypes()) { + if (!nextComponentIsAcceptable(isa(shapeElt))) + return false; + } + } else if (!nextComponentIsAcceptable(isPackExpansion)) { + return false; } } return nextSubstIndex == substType->getNumElements(); @@ -469,6 +468,87 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const { llvm_unreachable("bad kind"); } +void AbstractionPattern::forEachTupleElement(CanTupleType substType, + llvm::function_ref + handleScalar, + llvm::function_ref + handleExpansion) const { + assert(isTuple() && "can only call on a tuple expansion"); + assert(matchesTuple(substType)); + + size_t substEltIndex = 0; + auto substEltTypes = substType.getElementTypes(); + for (size_t origEltIndex : range(getNumTupleElements())) { + auto origEltType = getTupleElementType(origEltIndex); + if (!origEltType.isPackExpansion()) { + handleScalar(origEltIndex, substEltIndex, + origEltType, substEltTypes[substEltIndex]); + substEltIndex++; + } else { + auto numComponents = origEltType.getNumPackExpandedComponents(); + handleExpansion(origEltIndex, substEltIndex, origEltType, + substEltTypes.slice(substEltIndex, numComponents)); + substEltIndex += numComponents; + } + } + assert(substEltIndex == substEltTypes.size()); +} + +void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType, + llvm::function_ref + handleElement) const { + assert(matchesTuple(substType)); + + auto substEltTypes = substType.getElementTypes(); + + // Handle opaque patterns by just iterating the substituted components. + if (!isTuple()) { + for (auto i : indices(substEltTypes)) { + handleElement(getTupleElementType(i), substEltTypes[i], + substType->getElement(i)); + } + return; + } + + // For non-opaque patterns, we have to iterate the original components + // in order to match things up properly, but we'll still end up calling + // once per substituted element. + size_t substEltIndex = 0; + for (size_t origEltIndex : range(getNumTupleElements())) { + auto origEltType = getTupleElementType(origEltIndex); + if (!origEltType.isPackExpansion()) { + handleElement(origEltType, substEltTypes[substEltIndex], + substType->getElement(substEltIndex)); + substEltIndex++; + } else { + auto origPatternType = origEltType.getPackExpansionPatternType(); + for (auto i : range(origEltType.getNumPackExpandedComponents())) { + (void) i; + auto substEltType = substEltTypes[substEltIndex]; + // When the substituted type is a pack expansion, pass down + // the original element type so that it's *also* a pack expansion. + // Clients expect to look through this structure in parallel on + // both types. The count is misleading, but normal usage won't + // access it, and there's nothing we could provide that *wouldn't* + // be misleading in one way or another. + handleElement(isa(substEltType) + ? origEltType : origPatternType, + substEltType, substType->getElement(substEltIndex)); + substEltIndex++; + } + } + } + assert(substEltIndex == substEltTypes.size()); +} + static CanType getCanPackElementType(CanType type, unsigned index) { return cast(type).getElementType(index); } @@ -541,6 +621,17 @@ bool AbstractionPattern::matchesPack(CanPackType substType) { llvm_unreachable("bad kind"); } +AbstractionPattern +AbstractionPattern::getPackExpansionComponentType(CanType substType) const { + return getPackExpansionComponentType(isa(substType)); +} + +AbstractionPattern +AbstractionPattern::getPackExpansionComponentType(bool isExpansion) const { + assert(isPackExpansion()); + return isExpansion ? *this : getPackExpansionPatternType(); +} + static CanType getPackExpansionPatternType(CanType type) { return cast(type).getPatternType(); } @@ -584,49 +675,6 @@ AbstractionPattern AbstractionPattern::getPackExpansionPatternType() const { llvm_unreachable("bad kind"); } -static CanType getPackExpansionCountType(CanType type) { - return cast(type).getCountType(); -} - -AbstractionPattern AbstractionPattern::getPackExpansionCountType() const { - switch (getKind()) { - case Kind::Invalid: - llvm_unreachable("querying invalid abstraction pattern!"); - case Kind::ObjCMethodType: - case Kind::CurriedObjCMethodType: - case Kind::PartialCurriedObjCMethodType: - case Kind::CFunctionAsMethodType: - case Kind::CurriedCFunctionAsMethodType: - case Kind::PartialCurriedCFunctionAsMethodType: - case Kind::CXXMethodType: - case Kind::CurriedCXXMethodType: - case Kind::PartialCurriedCXXMethodType: - case Kind::Tuple: - case Kind::OpaqueFunction: - case Kind::OpaqueDerivativeFunction: - case Kind::ObjCCompletionHandlerArgumentsType: - case Kind::ClangType: - llvm_unreachable("pattern for function or tuple cannot be for " - "pack expansion type"); - - case Kind::Opaque: - return *this; - - case Kind::Type: - if (isTypeParameterOrOpaqueArchetype()) - return AbstractionPattern::getOpaque(); - return AbstractionPattern(getGenericSubstitutions(), - getGenericSignature(), - ::getPackExpansionCountType(getType())); - - case Kind::Discard: - return AbstractionPattern::getDiscard( - getGenericSubstitutions(), getGenericSignature(), - ::getPackExpansionCountType(getType())); - } - llvm_unreachable("bad kind"); -} - SmallVector AbstractionPattern::getPackExpandedComponents() const { SmallVector result; @@ -636,6 +684,21 @@ AbstractionPattern::getPackExpandedComponents() const { return result; } +size_t AbstractionPattern::getNumPackExpandedComponents() const { + assert(isPackExpansion()); + assert(getKind() == Kind::Type || getKind() == Kind::Discard); + + // If we don't have substitutions, we should be walking parallel + // structure; take a single element. + if (!GenericSubs) return 1; + + // Otherwise, substitute the expansion shape. + auto origExpansion = cast(getType()); + auto substShape = cast( + origExpansion.getCountType().subst(GenericSubs)->getCanonicalType()); + return substShape->getNumElements(); +} + void AbstractionPattern::forEachPackExpandedComponent( llvm::function_ref fn) const { assert(isPackExpansion()); @@ -665,7 +728,7 @@ void AbstractionPattern::forEachPackExpandedComponent( } default: - return fn(*this); + llvm_unreachable("not a pack expansion"); } llvm_unreachable("bad kind"); } @@ -692,8 +755,9 @@ AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const { llvm_unreachable("not handled yet"); case Kind::Discard: llvm_unreachable("operation not needed on discarded abstractions yet"); - case Kind::Opaque: case Kind::Tuple: + llvm_unreachable("cannot apply move-only wrappers to open-coded patterns"); + case Kind::Opaque: case Kind::Type: if (auto mvi = dyn_cast(getType())) { return AbstractionPattern(getGenericSubstitutions(), @@ -727,8 +791,9 @@ AbstractionPattern AbstractionPattern::addingMoveOnlyWrapper() const { llvm_unreachable("not handled yet"); case Kind::Discard: llvm_unreachable("operation not needed on discarded abstractions yet"); - case Kind::Opaque: case Kind::Tuple: + llvm_unreachable("cannot add move only wrapper to open-coded pattern"); + case Kind::Opaque: case Kind::Type: if (isa(getType())) return *this; @@ -1686,7 +1751,7 @@ AbstractionPattern::operator==(const AbstractionPattern &other) const { } } return true; - + case Kind::Type: case Kind::Discard: return OrigType == other.OrigType @@ -1996,7 +2061,7 @@ class SubstFunctionTypePatternVisitor auto substPatternType = visit(pack->getPatternType(), pattern.getPackExpansionPatternType()); auto substCountType = visit(pack->getCountType(), - pattern.getPackExpansionCountType()); + AbstractionPattern::getOpaque()); SmallVector rootParameterPacks; substPatternType->getTypeParameterPacks(rootParameterPacks); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 2b1ed944fa47f..dc459162950ad 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -1272,43 +1272,34 @@ class DestructureResults { // Recur into tuples. if (origType.isTuple()) { auto substTupleType = cast(substType); - size_t substEltIndex = 0; - for (size_t origEltIndex = 0, n = origType.getNumTupleElements(); - origEltIndex != n; ++origEltIndex) { - AbstractionPattern origEltType = - origType.getTupleElementType(origEltIndex); - + origType.forEachTupleElement(substTupleType, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origEltType, CanType substEltType) { // If the original element type is not a pack expansion, just // pull off the next substituted element type. - if (!origEltType.isPackExpansion()) { - CanType substEltType = - substTupleType.getElementType(substEltIndex++); - destructure(origEltType, substEltType); - continue; - } + destructure(origEltType, substEltType); + }, [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { // If the original element type is a pack expansion, build a // lowered pack type for the substituted components it expands to. - bool indirect = origEltType.arePackElementsPassedIndirectly(TC); - SmallVector packElts; + bool indirect = origExpansionType.arePackElementsPassedIndirectly(TC); - origEltType.forEachPackExpandedComponent( - [&](AbstractionPattern origComponentType) { - CanType substEltType = - substTupleType.getElementType(substEltIndex++); - SILType substEltTy = - TC.getLoweredType(origComponentType, substEltType, context); - packElts.push_back(substEltTy.getASTType()); - }); + SmallVector packElts; + for (auto substEltType : substEltTypes) { + auto origComponentType + = origExpansionType.getPackExpansionComponentType(substEltType); + CanType loweredEltTy = + TC.getLoweredRValueType(context, origComponentType, substEltType); + packElts.push_back(loweredEltTy); + }; SILPackType::ExtInfo extInfo(indirect); auto packType = SILPackType::get(TC.Context, extInfo, packElts); SILResultInfo result(packType, ResultConvention::Pack); Results.push_back(result); - } - - assert(substEltIndex == substTupleType->getNumElements() && - "didn't exhaust the substituted type"); + }); return; } @@ -1684,36 +1675,30 @@ class DestructureInputs { bool isNonDifferentiable) { assert(ownership != ValueOwnership::InOut); assert(origType.isTuple()); - assert(origType.matchesTuple(substType)); - - unsigned numOrigElts = origType.getNumTupleElements(); - unsigned nextSubstEltIndex = 0; - for (unsigned i = 0; i != numOrigElts; ++i) { - auto origEltType = origType.getTupleElementType(i); - if (!origEltType.isPackExpansion()) { - auto substEltType = - substType.getElementType(nextSubstEltIndex++); - visit(ownership, forSelf, origEltType, substEltType, - isNonDifferentiable); - } else { - SmallVector packElts; - origEltType.forEachPackExpandedComponent( - [&](AbstractionPattern origEltComponentType) { - auto substEltType = substType.getElementType(nextSubstEltIndex++); - auto eltTy = TC.getLoweredType(origEltComponentType, substEltType, - expansion); - packElts.push_back(eltTy.getASTType()); - }); - bool indirect = origEltType.arePackElementsPassedIndirectly(TC); - SILPackType::ExtInfo extInfo(/*address*/ indirect); - auto packTy = SILPackType::get(TC.Context, extInfo, packElts); + origType.forEachTupleElement(substType, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origEltType, CanType substEltType) { + visit(ownership, forSelf, origEltType, substEltType, + isNonDifferentiable); + }, [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + SmallVector packElts; + for (auto substEltType : substEltTypes) { + auto origComponentType + = origExpansionType.getPackExpansionComponentType(substEltType); + auto loweredEltTy = + TC.getLoweredRValueType(expansion, origComponentType, substEltType); + packElts.push_back(loweredEltTy); + }; - addPackParameter(packTy, ownership, isNonDifferentiable); - } - } + bool indirect = origExpansionType.arePackElementsPassedIndirectly(TC); + SILPackType::ExtInfo extInfo(/*address*/ indirect); + auto packTy = SILPackType::get(TC.Context, extInfo, packElts); - assert(nextSubstEltIndex == substType->getNumElements()); + addPackParameter(packTy, ownership, isNonDifferentiable); + }); } /// Add a parameter that we derived from deconstructing the diff --git a/lib/SILGen/ResultPlan.cpp b/lib/SILGen/ResultPlan.cpp index 570a32212adc6..5f0e10b2ea774 100644 --- a/lib/SILGen/ResultPlan.cpp +++ b/lib/SILGen/ResultPlan.cpp @@ -414,27 +414,27 @@ class PackExpansionResultPlan : public ResultPlan { PackExpansionResultPlan(ResultPlanBuilder &builder, SILValue packAddr, MutableArrayRef inits, - ArrayRef origTypes, + AbstractionPattern origExpansionType, CanTupleEltTypeArrayRef substEltTypes) : PackAddr(packAddr) { auto packTy = packAddr->getType().castTo(); auto formalPackType = CanPackType::get(packTy->getASTContext(), substEltTypes); + auto origPatternType = origExpansionType.getPackExpansionPatternType(); ComponentPlans.reserve(inits.size()); for (auto i : indices(inits)) { auto &init = inits[i]; - auto origType = origTypes[i]; CanType substEltType = substEltTypes[i]; if (isa(substEltType)) { ComponentPlans.emplace_back( builder.buildPackExpansionIntoPack(packAddr, formalPackType, i, - init.get(), origType)); + init.get(), origPatternType)); } else { ComponentPlans.emplace_back( builder.buildScalarIntoPack(packAddr, formalPackType, i, - init.get(), origType)); + init.get(), origPatternType)); } } } @@ -616,30 +616,21 @@ class TupleInitializationResultPlan final : public ResultPlan { // Create plans for all the sub-initializations. eltPlans.reserve(origType.getNumTupleElements()); - auto substEltTypes = substType.getElementTypes(); - - size_t nextSubstEltIndex = 0; - - for (auto origEltType : origType.getTupleElementTypes()) { - if (origEltType.isPackExpansion()) { - auto origComponentTypes = origEltType.getPackExpandedComponents(); - auto numComponents = origComponentTypes.size(); - auto i = nextSubstEltIndex; - nextSubstEltIndex += numComponents; - auto componentInits = eltInits.slice(i, numComponents); - auto substComponentTypes = substEltTypes.slice(i, numComponents); - eltPlans.push_back(builder.buildForPackExpansion(componentInits, - origComponentTypes, - substComponentTypes)); - } else { - auto i = nextSubstEltIndex++; - CanType substEltType = substEltTypes[i]; - Initialization *eltInit = eltInits[i].get(); - eltPlans.push_back(builder.build(eltInit, origEltType, substEltType)); - } - } - - assert(nextSubstEltIndex == substType->getNumElements()); + origType.forEachTupleElement(substType, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origEltType, + CanType substEltType) { + Initialization *eltInit = eltInits[substEltIndex].get(); + eltPlans.push_back(builder.build(eltInit, origEltType, substEltType)); + }, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + auto componentInits = eltInits.slice(substEltIndex, substEltTypes.size()); + eltPlans.push_back(builder.buildForPackExpansion(componentInits, + origExpansionType, + substEltTypes)); + }); } RValue finish(SILGenFunction &SGF, SILLocation loc, @@ -1151,10 +1142,9 @@ ResultPlanPtr ResultPlanBuilder::buildForScalar(Initialization *init, ResultPlanPtr ResultPlanBuilder:: buildForPackExpansion(MutableArrayRef inits, - ArrayRef origTypes, + AbstractionPattern origExpansionType, CanTupleEltTypeArrayRef substTypes) { - assert(inits.size() == origTypes.size() && - inits.size() == substTypes.size()); + assert(inits.size() == substTypes.size()); // Pack expansions in the original result type always turn into // a single @pack_out result. @@ -1172,7 +1162,7 @@ ResultPlanPtr ResultPlanBuilder:: SGF.emitTemporaryPackAllocation(loc, packTy.getObjectType()); return ResultPlanPtr(new PackExpansionResultPlan(*this, packAddr, inits, - origTypes, substTypes)); + origExpansionType, substTypes)); } ResultPlanPtr @@ -1180,8 +1170,7 @@ ResultPlanBuilder::buildPackExpansionIntoPack(SILValue packAddr, CanPackType formalPackType, unsigned componentIndex, Initialization *init, - AbstractionPattern origType) { - assert(origType.isPackExpansion()); + AbstractionPattern origPatternType) { assert(init && init->canPerformPackExpansionInitialization()); // Create an opened-element environment sufficient for working with @@ -1232,10 +1221,9 @@ ResultPlanBuilder::buildPackExpansionIntoPack(SILValue packAddr, }); // The result plan will write into `init` during finish(). - origType = origType.getPackExpansionPatternType(); return ResultPlanPtr( new PackTransformResultPlan(packAddr, formalPackType, - componentIndex, init, origType, + componentIndex, init, origPatternType, calleeTypeInfo.getOverrideRep())); } diff --git a/lib/SILGen/ResultPlan.h b/lib/SILGen/ResultPlan.h index d936e37a794c4..3322cc8ba1cc1 100644 --- a/lib/SILGen/ResultPlan.h +++ b/lib/SILGen/ResultPlan.h @@ -93,7 +93,7 @@ struct ResultPlanBuilder { AbstractionPattern origType, CanTupleType substType); ResultPlanPtr buildForPackExpansion(MutableArrayRef inits, - ArrayRef origTypes, + AbstractionPattern origPatternType, CanTupleEltTypeArrayRef substTypes); ResultPlanPtr buildPackExpansionIntoPack(SILValue packAddr, CanPackType formalPackType, diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 98450810cac66..37ccee01ec6aa 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -2203,10 +2203,20 @@ static unsigned getFlattenedValueCount(AbstractionPattern origType, // Otherwise, add up the elements. unsigned count = 0; - for (auto i : indices(substTuple.getElementTypes())) { - count += getFlattenedValueCount(origType.getTupleElementType(i), - substTuple.getElementType(i)); - } + origType.forEachTupleElement(substTuple, + [&](unsigned origEltIndex, + unsigned substEltIndex, + AbstractionPattern origEltType, + CanType substEltType) { + // Recursively expand scalar components. + count += getFlattenedValueCount(origEltType, substEltType); + }, [&](unsigned origEltIndex, + unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + // Expansion components turn into a single parameter. + count++; + }); return count; } diff --git a/lib/SILGen/SILGenProlog.cpp b/lib/SILGen/SILGenProlog.cpp index 6066b60482315..023cfc5acb993 100644 --- a/lib/SILGen/SILGenProlog.cpp +++ b/lib/SILGen/SILGenProlog.cpp @@ -241,31 +241,31 @@ class EmitBBArguments : public CanTypeVisitor elements; - for (AbstractionPattern origEltType : orig.getTupleElementTypes()) { - ManagedValue elt; - - // Reabstraction can give us original types that are pack - // expansions without having pack expansions in the result. - // In this case, we do not need to force emission into a pack - // expansion. - if (origEltType.isPackExpansion()) { - assert(init); - expandPack(origEltType, t, nextSubstEltIndex, eltInits, elements); - } else { - size_t i = nextSubstEltIndex++; - elt = visit(t.getElementType(i), origEltType, - init ? eltInits[i].get() : nullptr); - assert((init != nullptr) == (elt.isInContext())); - if (!elt.isInContext()) - elements.push_back(elt); - - if (elt.hasCleanup()) - canBeGuaranteed = false; - } - } - assert(nextSubstEltIndex == t->getNumElements()); + orig.forEachTupleElement(t, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origEltType, + CanType substEltType) { + auto elt = visit(substEltType, origEltType, + init ? eltInits[substEltIndex].get() : nullptr); + assert((init != nullptr) == (elt.isInContext())); + if (!elt.isInContext()) + elements.push_back(elt); + + if (elt.hasCleanup()) + canBeGuaranteed = false; + }, [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + assert(init); + expandPack(origExpansionType, substEltTypes, substEltIndex, + eltInits, elements); + }); // If we emitted into a context, we're done. if (init) { @@ -312,11 +312,13 @@ class EmitBBArguments : public CanTypeVisitor eltInits, SmallVectorImpl &eltMVs) { + assert(substEltTypes.size() == eltInits.size()); + // The next parameter is a pack which corresponds to some number of // components in the tuple. Some of them may be pack expansions. // Either copy/move them into the tuple (necessary if there are any @@ -331,20 +333,15 @@ class EmitBBArguments : public CanTypeVisitorgetType().castTo(); + auto origPatternType = origExpansionType.getPackExpansionPatternType(); + auto inducedPackType = - substTupleType.getInducedPackType(nextSubstEltIndex, - packTy->getNumElements()); - - unsigned nextPackIndex = 0; - origEltType.forEachPackExpandedComponent( - [&](AbstractionPattern origComponentType) { - size_t substEltIndex = nextSubstEltIndex++; - CanType substComponentType = - substTupleType.getElementType(substEltIndex); - Initialization *componentInit = - eltInits.empty() ? nullptr : eltInits[substEltIndex].get(); + CanPackType::get(SGF.getASTContext(), substEltTypes); - auto packComponentIndex = nextPackIndex++; + for (auto packComponentIndex : indices(substEltTypes)) { + CanType substComponentType = substEltTypes[packComponentIndex]; + Initialization *componentInit = + eltInits.empty() ? nullptr : eltInits[packComponentIndex].get(); auto packComponentTy = packTy->getSILElementType(packComponentIndex); auto substExpansionType = @@ -359,13 +356,13 @@ class EmitBBArguments : public CanTypeVisitormapContextualPackTypeIntoElementContext(substEltType); } - auto result = handleScalar(eltAddrMV, origEltType, substEltType, + auto result = handleScalar(eltAddrMV, origPatternType, substEltType, eltInit, /*inout*/ false); assert(result.isInContext()); (void) result; }); }); componentInit->finishInitialization(SGF); - }); - assert(nextPackIndex == packTy->getNumElements()); + } } }; } // end anonymous namespace diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 248d4ca71f6fd..49d0ecb27bfc3 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -498,47 +498,37 @@ createIndirectResultInit(SILGenFunction &SGF, SILValue addr, static void preparePackResultInit(SILGenFunction &SGF, SILLocation loc, AbstractionPattern origExpansionType, - CanTupleType resultTupleType, - size_t &nextResultEltIndex, + CanTupleEltTypeArrayRef resultEltTypes, SILArgument *packAddr, SmallVectorImpl &cleanups, SmallVectorImpl &inits) { - assert(origExpansionType.isPackExpansion()); - auto origComponentTypes = origExpansionType.getPackExpandedComponents(); - auto loweredPackType = packAddr->getType().castTo(); - assert(loweredPackType->getNumElements() == origComponentTypes.size() && + assert(loweredPackType->getNumElements() == resultEltTypes.size() && "mismatched pack components; possible missing substitutions on orig type?"); // If the pack expanded to nothing, there shouldn't be any initializers // for it in our context. - if (origComponentTypes.empty()) { + if (resultEltTypes.empty()) { return; } + auto origPatternType = origExpansionType.getPackExpansionPatternType(); + // Induce a formal pack type from the slice of the tuple elements. CanPackType formalPackType = - resultTupleType.getInducedPackType(nextResultEltIndex, - origComponentTypes.size()); - nextResultEltIndex += origComponentTypes.size(); + CanPackType::get(SGF.getASTContext(), resultEltTypes); - for (auto componentIndex : indices(origComponentTypes)) { - auto origComponentType = origComponentTypes[componentIndex]; + for (auto componentIndex : indices(resultEltTypes)) { auto resultComponentType = formalPackType.getElementType(componentIndex); auto loweredComponentType = loweredPackType->getElementType(componentIndex); - assert(origComponentType.isPackExpansion() + assert(isa(loweredComponentType) == isa(resultComponentType) && "need expansions in similar places"); - assert(origComponentType.isPackExpansion() - == isa(loweredComponentType) && - "need expansions in similar places"); // If we have a pack expansion, the initializer had better be a // pack expansion expression, and we'll generate a loop for it. // Preserve enough information to do this properly. - if (origComponentType.isPackExpansion()) { - auto origPatternType = - origComponentType.getPackExpansionPatternType(); + if (isa(resultComponentType)) { auto resultPatternType = cast(resultComponentType).getPatternType(); auto expectedPatternTy = SILType::getPrimitiveAddressType( @@ -568,7 +558,7 @@ preparePackResultInit(SILGenFunction &SGF, SILLocation loc, SILType::getPrimitiveAddressType(loweredComponentType)); inits.push_back(createIndirectResultInit(SGF, eltAddr, - origComponentType, + origPatternType, resultComponentType, cleanups)); } @@ -590,33 +580,33 @@ prepareIndirectResultInit(SILGenFunction &SGF, SILLocation loc, auto tupleInit = new TupleInitialization(resultTupleType); tupleInit->SubInitializations.reserve(resultTupleType->getNumElements()); - size_t nextResultEltIndex = 0; - for (size_t origEltIndex = 0, e = origResultType.getNumTupleElements(); - origEltIndex < e; ++origEltIndex) { - auto origEltType = origResultType.getTupleElementType(origEltIndex); - if (origEltType.isPackExpansion()) { - assert(allResults[0].isPack()); - assert(SGF.silConv.isSILIndirect(allResults[0])); - allResults = allResults.slice(1); - - auto packAddr = indirectResultAddrs[0]; - indirectResultAddrs = indirectResultAddrs.slice(1); - - preparePackResultInit(SGF, loc, origEltType, resultTupleType, - nextResultEltIndex, packAddr, - cleanups, tupleInit->SubInitializations); - } else { - auto substEltType = - resultTupleType.getElementType(nextResultEltIndex++); - auto eltInit = prepareIndirectResultInit(SGF, loc, fnTypeForResults, - origEltType, substEltType, - allResults, - directResults, - indirectResultAddrs, cleanups); - tupleInit->SubInitializations.push_back(std::move(eltInit)); - } - } - assert(nextResultEltIndex == resultTupleType->getNumElements()); + origResultType.forEachTupleElement(resultTupleType, + [&](unsigned origEltIndex, + unsigned substEltIndex, + AbstractionPattern origEltType, + CanType substEltType) { + auto eltInit = prepareIndirectResultInit(SGF, loc, fnTypeForResults, + origEltType, substEltType, + allResults, + directResults, + indirectResultAddrs, cleanups); + tupleInit->SubInitializations.push_back(std::move(eltInit)); + }, + [&](unsigned origEltIndex, + unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + assert(allResults[0].isPack()); + assert(SGF.silConv.isSILIndirect(allResults[0])); + allResults = allResults.slice(1); + + auto packAddr = indirectResultAddrs[0]; + indirectResultAddrs = indirectResultAddrs.slice(1); + + preparePackResultInit(SGF, loc, origExpansionType, substEltTypes, + packAddr, + cleanups, tupleInit->SubInitializations); + }); return InitializationPtr(tupleInit); } From d524c7d230c89c02bb20ceeb6f10e67061e7dc02 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 15 Mar 2023 23:03:55 -0400 Subject: [PATCH 4/6] Use correct parallel destructuring when lowering tuple types. --- lib/SIL/IR/TypeLowering.cpp | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index 8d20fec031fb3..fe71adc24b2d0 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -2534,24 +2534,21 @@ static CanTupleType computeLoweredTupleType(TypeConverter &tc, TypeExpansionContext context, AbstractionPattern origType, CanTupleType substType) { - assert(origType.matchesTuple(substType)); + if (substType->getNumElements() == 0) return substType; - // Does the lowered tuple type differ from the substituted type in - // any interesting way? bool changed = false; SmallVector loweredElts; loweredElts.reserve(substType->getNumElements()); - for (auto i : indices(substType->getElementTypes())) { - auto origEltType = origType.getTupleElementType(i); - auto substEltType = substType.getElementType(i); - - CanType loweredTy = - tc.getLoweredRValueType(context, origEltType, substEltType); - changed = (changed || substEltType != loweredTy); - - loweredElts.push_back(substType->getElement(i).getWithType(loweredTy)); - } + origType.forEachExpandedTupleElement(substType, + [&](AbstractionPattern origEltType, + CanType substEltType, + const TupleTypeElt &elt) { + auto loweredTy = + tc.getLoweredRValueType(context, origEltType, substEltType); + if (loweredTy != substEltType) changed = true; + loweredElts.push_back(elt.getWithType(loweredTy)); + }); if (!changed) return substType; @@ -3014,12 +3011,8 @@ TypeConverter::computeLoweredRValueType(TypeExpansionContext forExpansion, substPatternType); changed |= (loweredSubstPatternType != substPatternType); - CanType substCountType = substPackExpansionType.getCountType(); - CanType loweredSubstCountType = TC.getLoweredRValueType( - forExpansion, - origType.getPackExpansionCountType(), - substCountType); - changed |= (loweredSubstCountType != substCountType); + // Count types are AST types and are not lowered. + CanType loweredSubstCountType = substPackExpansionType.getCountType(); if (!changed) return substPackExpansionType; From 4499e3d0551774b6d9638a4409a891d0696c3ef3 Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 16 Mar 2023 01:14:23 -0400 Subject: [PATCH 5/6] [NFC] Introduce new APIs for traversing orig/subst parameters in parallel --- include/swift/SIL/AbstractionPattern.h | 24 +++++++++ lib/SIL/IR/AbstractionPattern.cpp | 50 +++++++++++++++++ lib/SIL/IR/SILFunctionType.cpp | 74 +++++++++++++------------- lib/SILGen/SILGenApply.cpp | 29 +++++----- 4 files changed, 124 insertions(+), 53 deletions(-) diff --git a/include/swift/SIL/AbstractionPattern.h b/include/swift/SIL/AbstractionPattern.h index 8353b47027ac6..0cea229f56207 100644 --- a/include/swift/SIL/AbstractionPattern.h +++ b/include/swift/SIL/AbstractionPattern.h @@ -1488,6 +1488,30 @@ class AbstractionPattern { /// parameters in the pattern. unsigned getNumFunctionParams() const; + /// Perform a parallel visitation of the parameters of a function. + /// + /// If this is a function pattern, calls handleScalar or + /// handleExpansion as appropriate for each parameter of the + /// original function, in order. + /// + /// If this is not a function pattern, calls handleScalar for each + /// parameter of the substituted function type. Functions with + /// pack expansions cannot be abstracted legally this way. + void forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams, + bool ignoreFinalParam, + llvm::function_ref + handleScalar, + llvm::function_ref + handleExpansion) const; + /// Given that the value being abstracted is optional, return the /// abstraction pattern for its object type. AbstractionPattern getOptionalObjectType() const; diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 00a6acb6cece8..c7b76d661df00 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -1221,6 +1221,56 @@ unsigned AbstractionPattern::getNumFunctionParams() const { return cast(getType()).getParams().size(); } +void AbstractionPattern:: +forEachFunctionParam(AnyFunctionType::CanParamArrayRef substParams, + bool ignoreFinalParam, + llvm::function_ref + handleScalar, + llvm::function_ref + handleExpansion) const { + // Honor ignoreFinalParam for the substituted parameters on all paths. + if (ignoreFinalParam) substParams = substParams.drop_back(); + + // If this isn't a function type, use the substituted type. + if (isTypeParameterOrOpaqueArchetype()) { + for (auto substParamIndex : indices(substParams)) { + handleScalar(substParamIndex, substParamIndex, + substParams[substParamIndex].getParameterFlags(), + *this, substParams[substParamIndex]); + } + return; + } + + size_t numOrigParams = getNumFunctionParams(); + if (ignoreFinalParam) numOrigParams--; + + size_t substParamIndex = 0; + for (auto origParamIndex : range(numOrigParams)) { + auto origParamType = getFunctionParamType(origParamIndex); + if (origParamType.isPackExpansion()) { + unsigned numComponents = origParamType.getNumPackExpandedComponents(); + handleExpansion(origParamIndex, substParamIndex, + getFunctionParamFlags(origParamIndex), origParamType, + substParams.slice(substParamIndex, numComponents)); + substParamIndex += numComponents; + } else { + handleScalar(origParamIndex, substParamIndex, + getFunctionParamFlags(origParamIndex), origParamType, + substParams[substParamIndex]); + substParamIndex++; + } + } + assert(substParamIndex == substParams.size()); +} + static CanType getOptionalObjectType(CanType type) { auto objectType = type.getOptionalObjectType(); assert(objectType && "type was not optional"); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index dc459162950ad..e0016afa19c93 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -1525,12 +1525,18 @@ class DestructureInputs { origType.isTypeParameter() ? params.size() : origType.getNumFunctionParams(); - unsigned nextParamIndex = 0; + // If we're importing a freestanding foreign function as a member + // function, the formal types (subst and orig) will conspire to + // pretend that there is a self parameter in the position Swift + // expects it: the end of the parameter lists. In the lowered type, + // we need to put this in its proper place, which for static methods + // generally means dropping it entirely. + bool hasForeignSelf = Foreign.self.isImportAsMember(); + + // Is there a self parameter in the formal parameter lists? bool hasSelf = - (extInfoBuilder.hasSelfParam() || Foreign.self.isImportAsMember()); - unsigned numOrigNonSelfParams = - (hasSelf ? numOrigParams - 1 : numOrigParams); + (extInfoBuilder.hasSelfParam() || hasForeignSelf); TopLevelOrigType = origType; // If we have a foreign self parameter, set up the ForeignSelfInfo @@ -1538,7 +1544,7 @@ class DestructureInputs { if (Foreign.self.isInstance()) { assert(hasSelf && numOrigParams > 0); ForeignSelf = ForeignSelfInfo{ - origType.getFunctionParamType(numOrigNonSelfParams), + origType.getFunctionParamType(numOrigParams - 1), params.back() }; } @@ -1549,55 +1555,47 @@ class DestructureInputs { maybeAddForeignParameters(); // Process all the non-self parameters. - for (unsigned i = 0; i != numOrigNonSelfParams; ++i) { - auto origParamType = origType.getFunctionParamType(i); - + origType.forEachFunctionParam(params, hasSelf, + [&](unsigned origParamIndex, unsigned substParamIndex, + ParameterTypeFlags origFlags, + AbstractionPattern origParamType, + AnyFunctionType::CanParam substParam) { // If the parameter is not a pack expansion, just pull off the // next parameter and destructure it in parallel with the abstraction // pattern for the type. - if (!origParamType.isPackExpansion()) { - visit(origParamType, params[nextParamIndex++], /*forSelf*/false); - continue; - } - + visit(origParamType, substParam, /*forSelf*/false); + }, [&](unsigned origParamIndex, unsigned substParamIndex, + ParameterTypeFlags origFlags, + AbstractionPattern origExpansionType, + AnyFunctionType::CanParamArrayRef substParams) { // Otherwise, collect the substituted components into a pack. - - // If the parameter *is* a pack expansion, it must not be an - // opaque pattern, so we can safely call this. - auto origFlags = origType.getFunctionParamFlags(i); - SmallVector packElts; - origParamType.forEachPackExpandedComponent( - [&](AbstractionPattern origParamComponent) { - auto substParam = params[nextParamIndex++]; + for (auto substParam : substParams) { auto substParamType = substParam.getParameterType(); + auto origParamType = + origExpansionType.getPackExpansionComponentType(substParamType); + auto loweredParamTy = TC.getLoweredRValueType(expansion, + origParamType, substParamType); + packElts.push_back(loweredParamTy); + } - auto substTy = TC.getLoweredType(origParamType, substParamType, - expansion); - packElts.push_back(substTy.getASTType()); - }); - - bool indirect = origParamType.arePackElementsPassedIndirectly(TC); + bool indirect = origExpansionType.arePackElementsPassedIndirectly(TC); SILPackType::ExtInfo extInfo(/*address*/ indirect); auto packTy = SILPackType::get(TC.Context, extInfo, packElts); addPackParameter(packTy, origFlags.getValueOwnership(), origFlags.isNoDerivative()); - } + }); - // Process the self parameter. - if (hasSelf && Foreign.self.isImportAsMember()) { - // Drop the formal foreign self parameter at this point if we - // set it up earlier. - nextParamIndex++; - } else if (hasSelf) { - auto origParamType = origType.getFunctionParamType(numOrigNonSelfParams); - auto substParam = params[nextParamIndex++]; + // Process the self parameter. But if we have a formal foreign self + // parameter, we should have processed it earlier in a call to + // maybeAddForeignParameters(). + if (hasSelf && !hasForeignSelf) { + auto origParamType = origType.getFunctionParamType(numOrigParams - 1); + auto substParam = params.back(); visit(origParamType, substParam, /*forSelf*/true); } - assert(nextParamIndex == params.size()); - TopLevelOrigType = AbstractionPattern::getInvalid(); ForeignSelf = None; } diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 37ccee01ec6aa..1ec8bbf912e71 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -3239,13 +3239,13 @@ class ArgEmitter { // Otherwise we need to emit a pack argument. } else { - auto origPackEltPatterns = - origFormalParamType.getPackExpandedComponents(); + auto numComponents = + origFormalParamType.getNumPackExpandedComponents(); auto argSourcesSlice = - argSources.slice(nextArgSourceIndex, origPackEltPatterns.size()); - emitPackArg(argSourcesSlice, origPackEltPatterns); - nextArgSourceIndex += origPackEltPatterns.size(); + argSources.slice(nextArgSourceIndex, numComponents); + emitPackArg(argSourcesSlice, origFormalParamType); + nextArgSourceIndex += numComponents; } } @@ -3690,7 +3690,7 @@ class ArgEmitter { } void emitPackArg(MutableArrayRef args, - ArrayRef origFormalTypes) { + AbstractionPattern origExpansionType) { // Adjust for the foreign error or async argument if necessary. maybeEmitForeignArgument(); @@ -3722,7 +3722,7 @@ class ArgEmitter { auto formalPackType = getFormalPackType(args); bool consumed = param.getConvention() == ParameterConvention::Pack_Owned; - emitIndirectIntoPack(args, origFormalTypes, pack, formalPackType, + emitIndirectIntoPack(args, origExpansionType, pack, formalPackType, consumed); } @@ -3735,12 +3735,10 @@ class ArgEmitter { } void emitIndirectIntoPack(MutableArrayRef args, - ArrayRef origFormalTypes, + AbstractionPattern origExpansionType, SILValue packAddr, CanPackType formalPackType, bool consumed) { - assert(args.size() == origFormalTypes.size()); - auto packTy = packAddr->getType().castTo(); assert(packTy->getNumElements() == args.size() && "wrong pack shape for arguments"); @@ -3749,11 +3747,14 @@ class ArgEmitter { for (auto i : indices(args)) { ArgumentSource &&arg = std::move(args[i]); - const AbstractionPattern &origFormalType = origFormalTypes[i]; auto expectedEltTy = packTy->getSILElementType(i); + bool isPackExpansion = expectedEltTy.is(); + AbstractionPattern origFormalType = + origExpansionType.getPackExpansionComponentType(isPackExpansion); + auto cleanup = CleanupHandle::invalid(); - if (origFormalType.isPackExpansion()) { + if (isPackExpansion) { cleanup = emitPackExpansionIntoPack(std::move(arg), origFormalType, expectedEltTy, consumed, @@ -4400,9 +4401,7 @@ struct ParamLowering { auto origParamType = origFormalType.getFunctionParamType(i); if (origParamType.isPackExpansion()) { count++; - origParamType.forEachPackExpandedComponent([&](AbstractionPattern) { - nextSubstParamIndex++; - }); + nextSubstParamIndex += origParamType.getNumPackExpandedComponents(); } else { auto substParam = substParams[nextSubstParamIndex++]; if (substParam.isInOut()) { From f3e7daf4785cc722d5e657e99d804932ec384bf5 Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 16 Mar 2023 01:21:20 -0400 Subject: [PATCH 6/6] [NFC] Remove the now-dead PackExpanded accessors from AbstractionPattern --- include/swift/SIL/AbstractionPattern.h | 18 +++++------ lib/SIL/IR/AbstractionPattern.cpp | 43 -------------------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/include/swift/SIL/AbstractionPattern.h b/include/swift/SIL/AbstractionPattern.h index 0cea229f56207..4507031111267 100644 --- a/include/swift/SIL/AbstractionPattern.h +++ b/include/swift/SIL/AbstractionPattern.h @@ -1543,19 +1543,15 @@ class AbstractionPattern { AbstractionPattern getObjCMethodAsyncCompletionHandlerType( CanType swiftCompletionHandlerType) const; - /// Given that this is a pack expansion, invoke the given callback for - /// each component of the substituted expansion of this pattern. The - /// pattern will be for a pack expansion type over a contextual type if - /// the substituted component is still a pack expansion. If there aren't - /// substitutions available, this will just invoke the callback with the - /// component. - void forEachPackExpandedComponent( - llvm::function_ref fn) const; - + /// Given that this is a pack expansion, return the number of components + /// that it should expand to. This, and the general correctness of + /// traversing variadically generic tuple and function types under + /// substitution, relies on substitutions having been set properly + /// on the abstraction pattern; without that, AbstractionPattern assumes + /// that every component expands to a single pack expansion component, + /// which will generally only work in specific situations. size_t getNumPackExpandedComponents() const; - SmallVector getPackExpandedComponents() const; - /// If this pattern refers to a foreign ObjC method that was imported as /// async, return the bridged-back-to-ObjC completion handler type. CanType getObjCMethodAsyncCompletionHandlerForeignType( diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index c7b76d661df00..5849d2fccd77a 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -675,15 +675,6 @@ AbstractionPattern AbstractionPattern::getPackExpansionPatternType() const { llvm_unreachable("bad kind"); } -SmallVector -AbstractionPattern::getPackExpandedComponents() const { - SmallVector result; - forEachPackExpandedComponent([&](AbstractionPattern pattern) { - result.push_back(pattern); - }); - return result; -} - size_t AbstractionPattern::getNumPackExpandedComponents() const { assert(isPackExpansion()); assert(getKind() == Kind::Type || getKind() == Kind::Discard); @@ -699,40 +690,6 @@ size_t AbstractionPattern::getNumPackExpandedComponents() const { return substShape->getNumElements(); } -void AbstractionPattern::forEachPackExpandedComponent( - llvm::function_ref fn) const { - assert(isPackExpansion()); - - switch (getKind()) { - case Kind::Type: - case Kind::Discard: { - // If we don't have generic substitutions, just produce this pattern. - if (!GenericSubs) return fn(*this); - auto origExpansion = cast(getType()); - - // Substitute the expansion shape. - auto substShape = cast( - origExpansion.getCountType().subst(GenericSubs)->getCanonicalType()); - - // Call the callback with each component of the substituted shape. - for (auto substShapeElt : substShape.getElementTypes()) { - CanType origEltType = origExpansion.getPatternType(); - if (auto substShapeEltExpansion = - dyn_cast(substShapeElt)) { - origEltType = CanPackExpansionType::get(origEltType, - substShapeEltExpansion.getCountType()); - } - fn(AbstractionPattern(GenericSubs, GenericSig, origEltType)); - } - return; - } - - default: - llvm_unreachable("not a pack expansion"); - } - llvm_unreachable("bad kind"); -} - AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const { switch (getKind()) { case Kind::Invalid: