diff --git a/docs/SIL.rst b/docs/SIL.rst index 34143b4724f97..ab25e2cce843c 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5148,6 +5148,21 @@ unchecked_bitwise_cast Bitwise copies an object of type ``A`` into a new object of type ``B`` of the same size or smaller. +unchecked_value_cast +```````````````````` +:: + + sil-instruction ::= 'unchecked_value_cast' sil-operand 'to' sil-type + + %1 = unchecked_value_cast %0 : $A to $B + +Bitwise copies an object of type ``A`` into a new layout-compatible object of +type ``B`` of the same size. + +This instruction is assumed to forward a fixed ownership (set upon its +construction) and lowers to 'unchecked_bitwise_cast' in non-ossa code. This +causes the cast to lose its guarantee of layout-compatibility. + ref_to_raw_pointer `````````````````` :: diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 10422ec523278..8eafabda8a62d 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -1102,6 +1102,12 @@ class SILBuilder { getSILDebugLocation(Loc), Op, Ty, getFunction(), C.OpenedArchetypes)); } + UncheckedValueCastInst *createUncheckedValueCast(SILLocation Loc, SILValue Op, + SILType Ty) { + return insert(UncheckedValueCastInst::create( + getSILDebugLocation(Loc), Op, Ty, getFunction(), C.OpenedArchetypes)); + } + RefToBridgeObjectInst *createRefToBridgeObject(SILLocation Loc, SILValue Ref, SILValue Bits) { auto Ty = SILType::getBridgeObjectType(getASTContext()); @@ -1847,7 +1853,18 @@ class SILBuilder { // Unchecked cast helpers //===--------------------------------------------------------------------===// - // Create the appropriate cast instruction based on result type. + /// Create the appropriate cast instruction based on result type. + /// + /// NOTE: We allow for non-layout compatible casts that shrink the underlying + /// type we are bit casting! 
+ SingleValueInstruction * + createUncheckedReinterpretCast(SILLocation Loc, SILValue Op, SILType Ty); + + /// Create an appropriate cast instruction based on result type. + /// + /// NOTE: This assumes that the input and the result cast are layout + /// compatible. Reduces to createUncheckedReinterpretCast when ownership is + /// disabled. SingleValueInstruction *createUncheckedBitCast(SILLocation Loc, SILValue Op, SILType Ty); diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 24d4fadca8dbe..e7c544992dddc 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -1538,6 +1538,16 @@ visitUncheckedBitwiseCastInst(UncheckedBitwiseCastInst *Inst) { getOpType(Inst->getType()))); } +template +void SILCloner::visitUncheckedValueCastInst( + UncheckedValueCastInst *Inst) { + getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); + recordClonedInstruction(Inst, getBuilder().createUncheckedValueCast( + getOpLocation(Inst->getLoc()), + getOpValue(Inst->getOperand()), + getOpType(Inst->getType()))); +} + template void SILCloner:: diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index d57e8454a096e..dd9e3731b9ac9 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -4576,6 +4576,23 @@ class UncheckedBitwiseCastInst final SILFunction &F, SILOpenedArchetypesState &OpenedArchetypes); }; +/// Bitwise copy a value into another value of the same size. 
+class UncheckedValueCastInst final + : public UnaryInstructionWithTypeDependentOperandsBase< + SILInstructionKind::UncheckedValueCastInst, UncheckedValueCastInst, + OwnershipForwardingConversionInst> { + friend SILBuilder; + + UncheckedValueCastInst(SILDebugLocation DebugLoc, SILValue Operand, + ArrayRef TypeDependentOperands, SILType Ty) + : UnaryInstructionWithTypeDependentOperandsBase( + DebugLoc, Operand, TypeDependentOperands, Ty, + Operand.getOwnershipKind()) {} + static UncheckedValueCastInst * + create(SILDebugLocation DebugLoc, SILValue Operand, SILType Ty, + SILFunction &F, SILOpenedArchetypesState &OpenedArchetypes); +}; + /// Build a Builtin.BridgeObject from a heap object reference by bitwise-or-ing /// in bits from a word. class RefToBridgeObjectInst diff --git a/include/swift/SIL/SILNodes.def b/include/swift/SIL/SILNodes.def index 352589fde1a81..d26054e66aed4 100644 --- a/include/swift/SIL/SILNodes.def +++ b/include/swift/SIL/SILNodes.def @@ -500,6 +500,8 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction) ConversionInst, None, DoesNotRelease) SINGLE_VALUE_INST(UncheckedBitwiseCastInst, unchecked_bitwise_cast, ConversionInst, None, DoesNotRelease) + SINGLE_VALUE_INST(UncheckedValueCastInst, unchecked_value_cast, + ConversionInst, None, DoesNotRelease) SINGLE_VALUE_INST(RefToRawPointerInst, ref_to_raw_pointer, ConversionInst, None, DoesNotRelease) SINGLE_VALUE_INST(RawPointerToRefInst, raw_pointer_to_ref, diff --git a/include/swift/SILOptimizer/Differentiation/ADContext.h b/include/swift/SILOptimizer/Differentiation/ADContext.h index a5bf5e86f1505..4fc2656ad404e 100644 --- a/include/swift/SILOptimizer/Differentiation/ADContext.h +++ b/include/swift/SILOptimizer/Differentiation/ADContext.h @@ -357,7 +357,7 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, return diagnose(loc, diag::autodiff_when_differentiating_function_call); } } - llvm_unreachable("invalid invoker"); + llvm_unreachable("Invalid invoker kind"); // 
silences MSVC C4715 } } // end namespace autodiff diff --git a/include/swift/SILOptimizer/Differentiation/JVPCloner.h b/include/swift/SILOptimizer/Differentiation/JVPCloner.h index 3e0ba11e5beb6..5a07003bc4f33 100644 --- a/include/swift/SILOptimizer/Differentiation/JVPCloner.h +++ b/include/swift/SILOptimizer/Differentiation/JVPCloner.h @@ -18,394 +18,35 @@ #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPCLONER_H #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_JVPCLONER_H -#include "swift/SILOptimizer/Differentiation/AdjointValue.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" -#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" - -#include "swift/SIL/SILValue.h" -#include "swift/SIL/TypeSubstCloner.h" -#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" -#include "llvm/ADT/DenseMap.h" namespace swift { -class SILArgument; -class SILDifferentiabilityWitness; -class SILBasicBlock; class SILFunction; -class SILInstruction; -class SILOptFunctionBuilder; +class SILDifferentiabilityWitness; namespace autodiff { class ADContext; -class JVPCloner final - : public TypeSubstCloner { -private: - /// The global context. - ADContext &context; - - /// The original function. - SILFunction *const original; - - /// The witness. - SILDifferentiabilityWitness *const witness; - - /// The JVP function. - SILFunction *const jvp; - - llvm::BumpPtrAllocator allocator; - - /// The differentiation invoker. - DifferentiationInvoker invoker; - - /// Info from activity analysis on the original function. - const DifferentiableActivityInfo &activityInfo; - - /// The differential info. - LinearMapInfo differentialInfo; - - bool errorOccurred = false; - - //--------------------------------------------------------------------------// - // Differential generation related fields - //--------------------------------------------------------------------------// - - /// The builder for the differential function. 
- SILBuilder differentialBuilder; - - /// Mapping from original basic blocks to corresponding differential basic - /// blocks. - llvm::DenseMap diffBBMap; - - /// Mapping from original basic blocks and original values to corresponding - /// tangent values. - llvm::DenseMap tangentValueMap; - - /// Mapping from original basic blocks and original buffers to corresponding - /// tangent buffers. - llvm::DenseMap, SILValue> bufferMap; - - /// Mapping from differential basic blocks to differential struct arguments. - llvm::DenseMap differentialStructArguments; - - /// Mapping from differential struct field declarations to differential struct - /// elements destructured from the linear map basic block argument. In the - /// beginning of each differential basic block, the block's differential - /// struct is destructured into the individual elements stored here. - llvm::DenseMap differentialStructElements; - - /// An auxiliary differential local allocation builder. - SILBuilder diffLocalAllocBuilder; - - /// Stack buffers allocated for storing local tangent values. - SmallVector differentialLocalAllocations; - - /// Mapping from original blocks to differential values. Used to build - /// differential struct instances. 
- llvm::DenseMap> differentialValues; - - //--------------------------------------------------------------------------// - // Getters - //--------------------------------------------------------------------------// - - ASTContext &getASTContext() const { return jvp->getASTContext(); } - SILModule &getModule() const { return jvp->getModule(); } - const SILAutoDiffIndices getIndices() const { - return witness->getSILAutoDiffIndices(); - } - SILBuilder &getDifferentialBuilder() { return differentialBuilder; } - SILFunction &getDifferential() { return differentialBuilder.getFunction(); } - SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { -#ifndef NDEBUG - auto *diffStruct = differentialStructArguments[origBB] - ->getType() - .getStructOrBoundGenericStruct(); - assert(diffStruct == differentialInfo.getLinearMapStruct(origBB)); -#endif - return differentialStructArguments[origBB]; - } - - //--------------------------------------------------------------------------// - // Initialization helpers - //--------------------------------------------------------------------------// - - static SubstitutionMap getSubstitutionMap(SILFunction *original, - SILFunction *jvp); - - /// Returns the activity info about the SILValues in the original function. 
- static const DifferentiableActivityInfo & - getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *jvp); - - //--------------------------------------------------------------------------// - // Differential struct mapping - //--------------------------------------------------------------------------// - - void initializeDifferentialStructElements(SILBasicBlock *origBB, - SILInstructionResultArray values); - - SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field); - - //--------------------------------------------------------------------------// - // General utilities - //--------------------------------------------------------------------------// - - SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint(); - - /// Get the lowered SIL type of the given AST type. - SILType getLoweredType(Type type); - - /// Get the lowered SIL type of the given nominal type declaration. - SILType getNominalDeclLoweredType(NominalTypeDecl *nominal); - - /// Build a differential struct value for the original block corresponding to - /// the given terminator. 
- StructInst *buildDifferentialValueStructValue(TermInst *termInst); - - //--------------------------------------------------------------------------// - // Tangent value factory methods - //--------------------------------------------------------------------------// - - AdjointValue makeZeroTangentValue(SILType type); - - AdjointValue makeConcreteTangentValue(SILValue value); - - //--------------------------------------------------------------------------// - // Tangent materialization - //--------------------------------------------------------------------------// - - void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc); - - SILValue emitZeroDirect(CanType type, SILLocation loc); - - SILValue materializeTangentDirect(AdjointValue val, SILLocation loc); - - SILValue materializeTangent(AdjointValue val, SILLocation loc); - - //--------------------------------------------------------------------------// - // Tangent buffer mapping - //--------------------------------------------------------------------------// - - void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, - SILValue tangentBuffer); - - SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer); - - //--------------------------------------------------------------------------// - // Differential type calculations - //--------------------------------------------------------------------------// +/// A helper class for generating JVP functions. +class JVPCloner final { + class Implementation; + Implementation &impl; - /// Substitutes all replacement types of the given substitution map using the - /// tangent function's substitution map. - SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap); - - /// Remap any archetypes into the differential function's context. - Type remapTypeInDifferential(Type ty); - - /// Remap any archetypes into the differential function's context. 
- SILType remapSILTypeInDifferential(SILType ty); - - /// Find the tangent space of a given canonical type. - Optional getTangentSpace(CanType type); - - /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated tangent space SIL type. - SILType getRemappedTangentType(SILType type); - - //--------------------------------------------------------------------------// - // Tangent value mapping - //--------------------------------------------------------------------------// - - /// Get the tangent for an original value. The given value must be in the - /// original function. - /// - /// This method first tries to find an entry in `tangentValueMap`. If an entry - /// doesn't exist, create a zero tangent. - AdjointValue getTangentValue(SILValue originalValue); - - /// Map the tangent value to the given original value. - void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, - AdjointValue newTangentValue); - - //--------------------------------------------------------------------------// - // Tangent emission helpers - //--------------------------------------------------------------------------// public: -#define CLONE_AND_EMIT_TANGENT(INST, ID) \ - void visit##INST##Inst(INST##Inst *inst); \ - void emitTangentFor##INST##Inst(INST##Inst *(ID)) - - CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi); - CLONE_AND_EMIT_TANGENT(EndBorrow, ebi); - CLONE_AND_EMIT_TANGENT(DestroyValue, dvi); - CLONE_AND_EMIT_TANGENT(CopyValue, cvi); - - /// Handle `load` instruction. - /// Original: y = load x - /// Tangent: tan[y] = load tan[x] - CLONE_AND_EMIT_TANGENT(Load, li); - - /// Handle `load_borrow` instruction. - /// Original: y = load_borrow x - /// Tangent: tan[y] = load_borrow tan[x] - CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi); - - /// Handle `store` instruction in the differential. 
- /// Original: store x to y - /// Tangent: store tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(Store, si); - - /// Handle `store_borrow` instruction in the differential. - /// Original: store_borrow x to y - /// Tangent: store_borrow tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi); - - /// Handle `copy_addr` instruction. - /// Original: copy_addr x to y - /// Tangent: copy_addr tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(CopyAddr, cai); - - /// Handle `unconditional_checked_cast_addr` instruction. - /// Original: unconditional_checked_cast_addr $X in x to $Y in y - /// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x] - /// to $Y.Tan in tan[y] - CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai); - - /// Handle `begin_access` instruction (and do differentiability checks). - /// Original: y = begin_access x - /// Tangent: tan[y] = begin_access tan[x] - CLONE_AND_EMIT_TANGENT(BeginAccess, bai); - - /// Handle `end_access` instruction. - /// Original: begin_access x - /// Tangent: end_access tan[x] - CLONE_AND_EMIT_TANGENT(EndAccess, eai); - - /// Handle `alloc_stack` instruction. - /// Original: y = alloc_stack $T - /// Tangent: tan[y] = alloc_stack $T.Tangent - CLONE_AND_EMIT_TANGENT(AllocStack, asi); - - /// Handle `dealloc_stack` instruction. - /// Original: dealloc_stack x - /// Tangent: dealloc_stack tan[x] - CLONE_AND_EMIT_TANGENT(DeallocStack, dsi); - - /// Handle `destroy_addr` instruction. - /// Original: destroy_addr x - /// Tangent: destroy_addr tan[x] - CLONE_AND_EMIT_TANGENT(DestroyAddr, dai); - - /// Handle `struct` instruction. - /// Original: y = struct $T (x0, x1, x2, ...) - /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) - CLONE_AND_EMIT_TANGENT(Struct, si); - - /// Handle `struct_extract` instruction. 
- /// Original: y = struct_extract x, #field - /// Tangent: tan[y] = struct_extract tan[x], #field' - /// ^~~~~~~ - /// field in tangent space corresponding to #field - CLONE_AND_EMIT_TANGENT(StructExtract, sei); - - /// Handle `struct_element_addr` instruction. - /// Original: y = struct_element_addr x, #field - /// Tangent: tan[y] = struct_element_addr tan[x], #field' - /// ^~~~~~~ - /// field in tangent space corresponding to #field - CLONE_AND_EMIT_TANGENT(StructElementAddr, seai); - - /// Handle `tuple` instruction. - /// Original: y = tuple (x0, x1, x2, ...) - /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) - /// ^~~ - /// excluding non-differentiable elements - CLONE_AND_EMIT_TANGENT(Tuple, ti); - - /// Handle `tuple_extract` instruction. - /// Original: y = tuple_extract x, - /// Tangent: tan[y] = tuple_extract tan[x], - /// ^~~~ - /// tuple tangent space index corresponding to n - CLONE_AND_EMIT_TANGENT(TupleExtract, tei); - - /// Handle `tuple_element_addr` instruction. - /// Original: y = tuple_element_addr x, - /// Tangent: tan[y] = tuple_element_addr tan[x], - /// ^~~~ - /// tuple tangent space index corresponding to n - CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai); - - /// Handle `destructure_tuple` instruction. - /// Original: (y0, y1, ...) = destructure_tuple x, - /// Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x], - /// ^~~~ - /// tuple tangent space index corresponding to n - CLONE_AND_EMIT_TANGENT(DestructureTuple, dti); - -#undef CLONE_AND_EMIT_TANGENT - - /// Handle `apply` instruction, given: - /// - The minimal indices for differentiating the `apply`. - /// - The original non-reabstracted differential type. + /// Creates a JVP cloner. /// - /// Original: y = apply f(x0, x1, ...) - /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...) 
- void emitTangentForApplyInst(ApplyInst *ai, SILAutoDiffIndices applyIndices, - CanSILFunctionType originalDifferentialType); - - /// Generate a `return` instruction in the current differential basic block. - void emitReturnInstForDifferential(); - -private: - /// Set up the differential function. This includes: - /// - Creating all differential blocks. - /// - Creating differential entry block arguments based on the function type. - /// - Creating tangent value mapping for original/differential parameters. - /// - Checking for unvaried result and emitting related warnings. - void prepareForDifferentialGeneration(); - -public: + /// The parent JVP cloner stores the original function and an empty + /// to-be-generated differential function. explicit JVPCloner(ADContext &context, SILFunction *original, SILDifferentiabilityWitness *witness, SILFunction *jvp, DifferentiationInvoker invoker); + ~JVPCloner(); - static SILFunction * - createEmptyDifferential(ADContext &context, - SILDifferentiabilityWitness *witness, - LinearMapInfo *linearMapInfo); - - /// Run JVP generation. Returns true on error. + /// Performs JVP generation on the empty JVP function. Returns true if any + /// error occurs. bool run(); - - void postProcess(SILInstruction *orig, SILInstruction *cloned); - - /// Remap original basic blocks. - SILBasicBlock *remapBasicBlock(SILBasicBlock *bb); - - /// General visitor for all instructions. If any error is emitted by previous - /// visits, bail out. - void visit(SILInstruction *inst); - - void visitSILInstruction(SILInstruction *inst); - - void visitInstructionsInBlock(SILBasicBlock *bb); - - // If an `apply` has active results or active inout parameters, replace it - // with an `apply` of its JVP. 
- void visitApplyInst(ApplyInst *ai); - - void visitReturnInst(ReturnInst *ri); - - void visitBranchInst(BranchInst *bi); - - void visitCondBranchInst(CondBranchInst *cbi); - - void visitSwitchEnumInst(SwitchEnumInst *sei); - - void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); }; } // end namespace autodiff diff --git a/include/swift/SILOptimizer/Differentiation/PullbackCloner.h b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h index 37d666e16463c..e6caa786ef04e 100644 --- a/include/swift/SILOptimizer/Differentiation/PullbackCloner.h +++ b/include/swift/SILOptimizer/Differentiation/PullbackCloner.h @@ -18,29 +18,14 @@ #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_PULLBACKCLONER_H -#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" -#include "swift/SILOptimizer/Differentiation/AdjointValue.h" -#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" -#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" - -#include "swift/SIL/TypeSubstCloner.h" -#include "llvm/ADT/DenseMap.h" - namespace swift { - -class SILDifferentiabilityWitness; -class SILBasicBlock; -class SILFunction; -class SILInstruction; - namespace autodiff { -class ADContext; class VJPCloner; /// A helper class for generating pullback functions. 
class PullbackCloner final { - struct Implementation; + class Implementation; Implementation &impl; public: diff --git a/include/swift/SILOptimizer/Differentiation/VJPCloner.h b/include/swift/SILOptimizer/Differentiation/VJPCloner.h index 346bf2f490c80..ff8ce874021e6 100644 --- a/include/swift/SILOptimizer/Differentiation/VJPCloner.h +++ b/include/swift/SILOptimizer/Differentiation/VJPCloner.h @@ -22,149 +22,41 @@ #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" -#include "swift/SIL/TypeSubstCloner.h" -#include "llvm/ADT/DenseMap.h" - namespace swift { - -class SILDifferentiabilityWitness; -class SILBasicBlock; -class SILFunction; -class SILInstruction; - namespace autodiff { class ADContext; class PullbackCloner; -class VJPCloner final - : public TypeSubstCloner { - friend class PullbackCloner; - -private: - /// The global context. - ADContext &context; - - /// The original function. - SILFunction *const original; - - /// The differentiability witness. - SILDifferentiabilityWitness *const witness; - - /// The VJP function. - SILFunction *const vjp; - - /// The pullback function. - SILFunction *pullback; - - /// The differentiation invoker. - DifferentiationInvoker invoker; - - /// Info from activity analysis on the original function. - const DifferentiableActivityInfo &activityInfo; - - /// The linear map info. - LinearMapInfo pullbackInfo; - - /// Caches basic blocks whose phi arguments have been remapped (adding a - /// predecessor enum argument). - SmallPtrSet remappedBasicBlocks; - - bool errorOccurred = false; - - /// Mapping from original blocks to pullback values. Used to build pullback - /// struct instances. 
- llvm::DenseMap> pullbackValues; - - ASTContext &getASTContext() const { return vjp->getASTContext(); } - SILModule &getModule() const { return vjp->getModule(); } - const SILAutoDiffIndices getIndices() const { - return witness->getSILAutoDiffIndices(); - } - - static SubstitutionMap getSubstitutionMap(SILFunction *original, - SILFunction *vjp); - - static const DifferentiableActivityInfo & - getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *vjp); +/// A helper class for generating VJP functions. +class VJPCloner final { + class Implementation; + Implementation &impl; public: + /// Creates a VJP cloner. + /// + /// The parent VJP cloner stores the original function and an empty + /// to-be-generated pullback function. explicit VJPCloner(ADContext &context, SILFunction *original, SILDifferentiabilityWitness *witness, SILFunction *vjp, DifferentiationInvoker invoker); - - SILFunction *createEmptyPullback(); - - /// Run VJP generation. Returns true on error. + ~VJPCloner(); + + ADContext &getContext() const; + SILModule &getModule() const; + SILFunction &getOriginal() const; + SILFunction &getVJP() const; + SILFunction &getPullback() const; + SILDifferentiabilityWitness *getWitness() const; + const SILAutoDiffIndices getIndices() const; + DifferentiationInvoker getInvoker() const; + LinearMapInfo &getPullbackInfo() const; + const DifferentiableActivityInfo &getActivityInfo() const; + + /// Performs VJP generation on the empty VJP function. Returns true if any + /// error occurs. bool run(); - - void postProcess(SILInstruction *orig, SILInstruction *cloned); - - /// Remap original basic blocks, adding predecessor enum arguments. - SILBasicBlock *remapBasicBlock(SILBasicBlock *bb); - - /// General visitor for all instructions. If any error is emitted by previous - /// visits, bail out. 
- void visit(SILInstruction *inst); - - void visitSILInstruction(SILInstruction *inst); - -private: - /// Get the lowered SIL type of the given AST type. - SILType getLoweredType(Type type); - - /// Get the lowered SIL type of the given nominal type declaration. - SILType getNominalDeclLoweredType(NominalTypeDecl *nominal); - - // Creates a trampoline block for given original terminator instruction, the - // pullback struct value for its parent block, and a successor basic block. - // - // The trampoline block has the same arguments as and branches to the remapped - // successor block, but drops the last predecessor enum argument. - // - // Used for cloning branching terminator instructions with specific - // requirements on successor block arguments, where an additional predecessor - // enum argument is not acceptable. - SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, - StructInst *pbStructVal, - SILBasicBlock *succBB); - - /// Build a pullback struct value for the given original terminator - /// instruction. - StructInst *buildPullbackValueStructValue(TermInst *termInst); - - /// Build a predecessor enum instance using the given builder for the given - /// original predecessor/successor blocks and pullback struct value. 
- EnumInst *buildPredecessorEnumValue(SILBuilder &builder, - SILBasicBlock *predBB, - SILBasicBlock *succBB, - SILValue pbStructVal); - -public: - void visitReturnInst(ReturnInst *ri); - - void visitBranchInst(BranchInst *bi); - - void visitCondBranchInst(CondBranchInst *cbi); - - void visitSwitchEnumInstBase(SwitchEnumInstBase *inst); - - void visitSwitchEnumInst(SwitchEnumInst *sei); - - void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai); - - void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi); - - void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi); - - void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi); - - // If an `apply` has active results or active inout arguments, replace it - // with an `apply` of its VJP. - void visitApplyInst(ApplyInst *ai); - - void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); }; } // end namespace autodiff diff --git a/include/swift/SILOptimizer/Utils/CFGOptUtils.h b/include/swift/SILOptimizer/Utils/CFGOptUtils.h index 120d19ec9eb9e..35de0e33cc6f8 100644 --- a/include/swift/SILOptimizer/Utils/CFGOptUtils.h +++ b/include/swift/SILOptimizer/Utils/CFGOptUtils.h @@ -107,6 +107,10 @@ bool splitCriticalEdgesFrom(SILBasicBlock *fromBB, DominanceInfo *domInfo = nullptr, SILLoopInfo *loopInfo = nullptr); +/// Splits all critical edges that have `toBB` as a destination. +bool splitCriticalEdgesTo(SILBasicBlock *toBB, DominanceInfo *domInfo = nullptr, + SILLoopInfo *loopInfo = nullptr); + /// Splits the edges between two basic blocks. /// /// Updates dominance information and loop information if not null. 
diff --git a/lib/FrontendTool/FrontendTool.cpp b/lib/FrontendTool/FrontendTool.cpp index da9d2c7bc740f..06347e600b319 100644 --- a/lib/FrontendTool/FrontendTool.cpp +++ b/lib/FrontendTool/FrontendTool.cpp @@ -427,6 +427,7 @@ static bool emitLoadedModuleTraceIfNeeded(ModuleDecl *mainModule, ModuleDecl::ImportFilter filter = ModuleDecl::ImportFilterKind::Public; filter |= ModuleDecl::ImportFilterKind::Private; filter |= ModuleDecl::ImportFilterKind::ImplementationOnly; + filter |= ModuleDecl::ImportFilterKind::SPIAccessControl; filter |= ModuleDecl::ImportFilterKind::ShadowedBySeparateOverlay; SmallVector imports; mainModule->getImportedModules(imports, filter); diff --git a/lib/IDE/CompletionInstance.cpp b/lib/IDE/CompletionInstance.cpp index 95fb4a2e4efda..78b8c4eabf967 100644 --- a/lib/IDE/CompletionInstance.cpp +++ b/lib/IDE/CompletionInstance.cpp @@ -89,6 +89,7 @@ template Decl *getElementAt(const Range &Decls, unsigned N) { /// config block, this function returns \c false. static DeclContext *getEquivalentDeclContextFromSourceFile(DeclContext *DC, SourceFile *SF) { + PrettyStackTraceDeclContext trace("getting equivalent decl context for", DC); auto *newDC = DC; // NOTE: Shortcut for DC->getParentSourceFile() == SF case is not needed // because they should be always different. @@ -115,13 +116,17 @@ static DeclContext *getEquivalentDeclContextFromSourceFile(DeclContext *DC, D = storage; } - if (auto parentSF = dyn_cast(parentDC)) + if (auto parentSF = dyn_cast(parentDC)) { N = findIndexInRange(D, parentSF->getTopLevelDecls()); - else if (auto parentIDC = - dyn_cast(parentDC->getAsDecl())) + } else if (auto parentIDC = dyn_cast_or_null( + parentDC->getAsDecl())) { N = findIndexInRange(D, parentIDC->getMembers()); - else + } else { +#ifndef NDEBUG llvm_unreachable("invalid DC kind for finding equivalent DC (indexpath)"); +#endif + return nullptr; + } // Not found in the decl context tree. // FIXME: Probably DC is in an inactive #if block. 
diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index 8f489458e455d..ad2d0dfc79c44 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -1026,6 +1026,9 @@ class IRGenSILFunction : void visitUncheckedAddrCastInst(UncheckedAddrCastInst *i); void visitUncheckedTrivialBitCastInst(UncheckedTrivialBitCastInst *i); void visitUncheckedBitwiseCastInst(UncheckedBitwiseCastInst *i); + void visitUncheckedValueCastInst(UncheckedValueCastInst *i) { + llvm_unreachable("Should never be seen in Lowered SIL"); + } void visitRefToRawPointerInst(RefToRawPointerInst *i); void visitRawPointerToRefInst(RawPointerToRefInst *i); void visitThinToThickFunctionInst(ThinToThickFunctionInst *i); diff --git a/lib/IRGen/LoadableByAddress.cpp b/lib/IRGen/LoadableByAddress.cpp index 01abf45d6e0e7..8ec81868638ce 100644 --- a/lib/IRGen/LoadableByAddress.cpp +++ b/lib/IRGen/LoadableByAddress.cpp @@ -1544,7 +1544,7 @@ void LoadableStorageAllocation:: storageType); if (pass.containsDifferentFunctionSignature(pass.F->getLoweredFunctionType(), storageType)) { - auto *castInstr = argBuilder.createUncheckedBitCast( + auto *castInstr = argBuilder.createUncheckedReinterpretCast( RegularLocation(const_cast(arg->getDecl())), arg, newSILType); arg->replaceAllUsesWith(castInstr); @@ -1912,8 +1912,8 @@ static void castTupleInstr(SingleValueInstruction *instr, IRGenModule &Mod, switch (instr->getKind()) { // Add cast to the new sil function type: case SILInstructionKind::TupleExtractInst: { - castInstr = castBuilder.createUncheckedBitCast(instr->getLoc(), instr, - newSILType.getObjectType()); + castInstr = castBuilder.createUncheckedReinterpretCast( + instr->getLoc(), instr, newSILType.getObjectType()); break; } case SILInstructionKind::TupleElementAddrInst: { @@ -2471,8 +2471,8 @@ getOperandTypeWithCastIfNecessary(SILInstruction *containingInstr, SILValue op, } assert(currSILType.isObject() && "Expected an object type"); if (newSILType != currSILType) { - auto castInstr = 
builder.createUncheckedBitCast(containingInstr->getLoc(), - op, newSILType); + auto castInstr = builder.createUncheckedReinterpretCast( + containingInstr->getLoc(), op, newSILType); return castInstr; } } @@ -2653,8 +2653,8 @@ bool LoadableByAddress::recreateUncheckedEnumDataInstr( auto *takeEnum = enumBuilder.createUncheckedEnumData( enumInstr->getLoc(), enumInstr->getOperand(), enumInstr->getElement(), caseTy); - newInstr = enumBuilder.createUncheckedBitCast(enumInstr->getLoc(), takeEnum, - newType); + newInstr = enumBuilder.createUncheckedReinterpretCast(enumInstr->getLoc(), + takeEnum, newType); } else { newInstr = enumBuilder.createUncheckedEnumData( enumInstr->getLoc(), enumInstr->getOperand(), enumInstr->getElement(), @@ -2708,7 +2708,7 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr( if (destType.getObjectType() != srcType) { // Add cast to destType SILBuilderWithScope castBuilder(instr); - auto *castInstr = castBuilder.createUncheckedBitCast( + auto *castInstr = castBuilder.createUncheckedReinterpretCast( instr->getLoc(), src, destType.getObjectType()); instr->setOperand(StoreInst::Src, castInstr); } diff --git a/lib/SIL/IR/OperandOwnership.cpp b/lib/SIL/IR/OperandOwnership.cpp index 0246ff0d18219..26aecc97afdf5 100644 --- a/lib/SIL/IR/OperandOwnership.cpp +++ b/lib/SIL/IR/OperandOwnership.cpp @@ -350,6 +350,7 @@ FORWARD_ANY_OWNERSHIP_INST(DestructureTuple) FORWARD_ANY_OWNERSHIP_INST(InitExistentialRef) FORWARD_ANY_OWNERSHIP_INST(DifferentiableFunction) FORWARD_ANY_OWNERSHIP_INST(LinearFunction) +FORWARD_ANY_OWNERSHIP_INST(UncheckedValueCast) #undef FORWARD_ANY_OWNERSHIP_INST // An instruction that forwards a constant ownership or trivial ownership. 
diff --git a/lib/SIL/IR/SILBuilder.cpp b/lib/SIL/IR/SILBuilder.cpp index 37b3ce61d0d79..037bfc9b3e6ec 100644 --- a/lib/SIL/IR/SILBuilder.cpp +++ b/lib/SIL/IR/SILBuilder.cpp @@ -141,7 +141,8 @@ SILBuilder::createClassifyBridgeObject(SILLocation Loc, SILValue value) { // Create the appropriate cast instruction based on result type. SingleValueInstruction * -SILBuilder::createUncheckedBitCast(SILLocation Loc, SILValue Op, SILType Ty) { +SILBuilder::createUncheckedReinterpretCast(SILLocation Loc, SILValue Op, + SILType Ty) { assert(isLoadableOrOpaque(Ty)); if (Ty.isTrivial(getFunction())) return insert(UncheckedTrivialBitCastInst::create( @@ -156,6 +157,26 @@ SILBuilder::createUncheckedBitCast(SILLocation Loc, SILValue Op, SILType Ty) { getSILDebugLocation(Loc), Op, Ty, getFunction(), C.OpenedArchetypes)); } +// Create the appropriate cast instruction based on result type. +SingleValueInstruction * +SILBuilder::createUncheckedBitCast(SILLocation Loc, SILValue Op, SILType Ty) { + // Without ownership, delegate to unchecked reinterpret cast. + if (!hasOwnership()) + return createUncheckedReinterpretCast(Loc, Op, Ty); + + assert(isLoadableOrOpaque(Ty)); + if (Ty.isTrivial(getFunction())) + return insert(UncheckedTrivialBitCastInst::create( + getSILDebugLocation(Loc), Op, Ty, getFunction(), C.OpenedArchetypes)); + + if (SILType::canRefCast(Op->getType(), Ty, getModule())) + return createUncheckedRefCast(Loc, Op, Ty); + + // The destination type is nontrivial, and may be smaller than the source + // type, so RC identity cannot be assumed. 
+ return createUncheckedValueCast(Loc, Op, Ty); +} + BranchInst *SILBuilder::createBranch(SILLocation Loc, SILBasicBlock *TargetBlock, OperandValueArrayRef Args) { diff --git a/lib/SIL/IR/SILInstructions.cpp b/lib/SIL/IR/SILInstructions.cpp index 8ad354575fca6..85e21ea609f21 100644 --- a/lib/SIL/IR/SILInstructions.cpp +++ b/lib/SIL/IR/SILInstructions.cpp @@ -2217,6 +2217,21 @@ UncheckedRefCastInst::create(SILDebugLocation DebugLoc, SILValue Operand, TypeDependentOperands, Ty); } +UncheckedValueCastInst * +UncheckedValueCastInst::create(SILDebugLocation DebugLoc, SILValue Operand, + SILType Ty, SILFunction &F, + SILOpenedArchetypesState &OpenedArchetypes) { + SILModule &Mod = F.getModule(); + SmallVector TypeDependentOperands; + collectTypeDependentOperands(TypeDependentOperands, OpenedArchetypes, F, + Ty.getASTType()); + unsigned size = + totalSizeToAlloc(1 + TypeDependentOperands.size()); + void *Buffer = Mod.allocateInst(size, alignof(UncheckedValueCastInst)); + return ::new (Buffer) + UncheckedValueCastInst(DebugLoc, Operand, TypeDependentOperands, Ty); +} + UncheckedAddrCastInst * UncheckedAddrCastInst::create(SILDebugLocation DebugLoc, SILValue Operand, SILType Ty, SILFunction &F, diff --git a/lib/SIL/IR/SILPrinter.cpp b/lib/SIL/IR/SILPrinter.cpp index 763c0122779b0..2225c047abd3a 100644 --- a/lib/SIL/IR/SILPrinter.cpp +++ b/lib/SIL/IR/SILPrinter.cpp @@ -1620,6 +1620,9 @@ class SILPrinter : public SILInstructionVisitor { void visitUncheckedBitwiseCastInst(UncheckedBitwiseCastInst *CI) { printUncheckedConversionInst(CI, CI->getOperand()); } + void visitUncheckedValueCastInst(UncheckedValueCastInst *CI) { + printUncheckedConversionInst(CI, CI->getOperand()); + } void visitRefToRawPointerInst(RefToRawPointerInst *CI) { printUncheckedConversionInst(CI, CI->getOperand()); } diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index 06a4273f92dc1..2d5535bb421ef 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -1453,6 
+1453,7 @@ namespace { if (D->isCxxNotTriviallyCopyable()) { properties.setAddressOnly(); + properties.setNonTrivial(); } auto subMap = structType->getContextSubstitutionMap(&TC.M, D); diff --git a/lib/SIL/IR/ValueOwnership.cpp b/lib/SIL/IR/ValueOwnership.cpp index 4acf9d278c6f5..be815d3cd89de 100644 --- a/lib/SIL/IR/ValueOwnership.cpp +++ b/lib/SIL/IR/ValueOwnership.cpp @@ -254,6 +254,7 @@ FORWARDING_OWNERSHIP_INST(Tuple) FORWARDING_OWNERSHIP_INST(UncheckedRefCast) FORWARDING_OWNERSHIP_INST(UnconditionalCheckedCast) FORWARDING_OWNERSHIP_INST(Upcast) +FORWARDING_OWNERSHIP_INST(UncheckedValueCast) FORWARDING_OWNERSHIP_INST(UncheckedEnumData) FORWARDING_OWNERSHIP_INST(SelectEnum) FORWARDING_OWNERSHIP_INST(Enum) diff --git a/lib/SIL/Parser/ParseSIL.cpp b/lib/SIL/Parser/ParseSIL.cpp index 8a9df9d667734..b237e70782399 100644 --- a/lib/SIL/Parser/ParseSIL.cpp +++ b/lib/SIL/Parser/ParseSIL.cpp @@ -3238,6 +3238,7 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B, case SILInstructionKind::UncheckedAddrCastInst: case SILInstructionKind::UncheckedTrivialBitCastInst: case SILInstructionKind::UncheckedBitwiseCastInst: + case SILInstructionKind::UncheckedValueCastInst: case SILInstructionKind::UpcastInst: case SILInstructionKind::AddressToPointerInst: case SILInstructionKind::BridgeObjectToRefInst: @@ -3307,6 +3308,9 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B, case SILInstructionKind::UncheckedBitwiseCastInst: ResultVal = B.createUncheckedBitwiseCast(InstLoc, Val, Ty); break; + case SILInstructionKind::UncheckedValueCastInst: + ResultVal = B.createUncheckedValueCast(InstLoc, Val, Ty); + break; case SILInstructionKind::UpcastInst: ResultVal = B.createUpcast(InstLoc, Val, Ty); break; diff --git a/lib/SIL/Utils/OwnershipUtils.cpp b/lib/SIL/Utils/OwnershipUtils.cpp index d696a9b1db227..8ece70a783b50 100644 --- a/lib/SIL/Utils/OwnershipUtils.cpp +++ b/lib/SIL/Utils/OwnershipUtils.cpp @@ -34,6 +34,7 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind 
kind) { case SILNodeKind::LinearFunctionInst: case SILNodeKind::OpenExistentialRefInst: case SILNodeKind::UpcastInst: + case SILNodeKind::UncheckedValueCastInst: case SILNodeKind::UncheckedRefCastInst: case SILNodeKind::ConvertFunctionInst: case SILNodeKind::RefToBridgeObjectInst: diff --git a/lib/SIL/Verifier/SILVerifier.cpp b/lib/SIL/Verifier/SILVerifier.cpp index cca9546938961..c8ad63ff8be84 100644 --- a/lib/SIL/Verifier/SILVerifier.cpp +++ b/lib/SIL/Verifier/SILVerifier.cpp @@ -1904,6 +1904,12 @@ class SILVerifier : public SILVerifierBase { "Inst with qualified ownership in a function that is not qualified"); } + void checkUncheckedValueCastInst(UncheckedValueCastInst *) { + require( + F.hasOwnership(), + "Inst with qualified ownership in a function that is not qualified"); + } + template void checkAccessEnforcement(AI *AccessInst) { if (AccessInst->getModule().getStage() != SILStage::Raw) { diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 5829066d560c6..255bdeb4bafa9 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -1340,8 +1340,8 @@ class SILGenApply : public Lowering::ExprVisitor { && loweredResultTy.getASTType()->hasDynamicSelfType()) { assert(selfMetaTy); selfValue = SGF.emitManagedRValueWithCleanup( - SGF.B.createUncheckedBitCast(loc, selfValue.forward(SGF), - loweredResultTy)); + SGF.B.createUncheckedReinterpretCast(loc, selfValue.forward(SGF), + loweredResultTy)); } else { selfValue = SGF.emitManagedRValueWithCleanup( SGF.B.createUpcast(loc, selfValue.forward(SGF), loweredResultTy)); diff --git a/lib/SILGen/SILGenBuilder.cpp b/lib/SILGen/SILGenBuilder.cpp index bf0e64a90cc79..b5ab0668810a6 100644 --- a/lib/SILGen/SILGenBuilder.cpp +++ b/lib/SILGen/SILGenBuilder.cpp @@ -610,7 +610,7 @@ ManagedValue SILGenBuilder::createUncheckedBitCast(SILLocation loc, ManagedValue value, SILType type) { CleanupCloner cloner(*this, value); - SILValue cast = createUncheckedBitCast(loc, value.getValue(), type); + 
SILValue cast = createUncheckedReinterpretCast(loc, value.getValue(), type); // Currently createUncheckedBitCast only produces these // instructions. We assert here to make sure if this changes, this code is diff --git a/lib/SILGen/SILGenBuilder.h b/lib/SILGen/SILGenBuilder.h index b5b6c3567d213..1267b37d9e8af 100644 --- a/lib/SILGen/SILGenBuilder.h +++ b/lib/SILGen/SILGenBuilder.h @@ -279,7 +279,7 @@ class SILGenBuilder : public SILBuilder { ManagedValue createUncheckedAddrCast(SILLocation loc, ManagedValue op, SILType resultTy); - using SILBuilder::createUncheckedBitCast; + using SILBuilder::createUncheckedReinterpretCast; ManagedValue createUncheckedBitCast(SILLocation loc, ManagedValue original, SILType type); diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index dfd28c4511918..1523b5fc9b737 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -2066,8 +2066,8 @@ RValue RValueEmitter::visitUnderlyingToOpaqueExpr(UnderlyingToOpaqueExpr *E, if (value.getType() == opaqueTL.getLoweredType()) return RValue(SGF, E, value); - auto cast = SGF.B.createUncheckedBitCast(E, value.forward(SGF), - opaqueTL.getLoweredType()); + auto cast = SGF.B.createUncheckedReinterpretCast(E, value.forward(SGF), + opaqueTL.getLoweredType()); value = SGF.emitManagedRValueWithCleanup(cast); return RValue(SGF, E, value); diff --git a/lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp b/lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp index 133980174900d..aeb7fb43b474a 100644 --- a/lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp +++ b/lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp @@ -36,7 +36,7 @@ SourceLoc DifferentiationInvoker::getLocation() const { ->getLocation() .getSourceLoc(); } - llvm_unreachable("invalid differentation invoker kind"); + llvm_unreachable("Invalid invoker kind"); // silences MSVC C4715 } void DifferentiationInvoker::print(llvm::raw_ostream &os) const { diff --git 
a/lib/SILOptimizer/Differentiation/JVPCloner.cpp b/lib/SILOptimizer/Differentiation/JVPCloner.cpp index 0b33bbf439b02..e860562c9b2da 100644 --- a/lib/SILOptimizer/Differentiation/JVPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/JVPCloner.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// // -// This file defines a helper class for generating VJP functions for automatic +// This file defines a helper class for generating JVP functions for automatic // differentiation. // //===----------------------------------------------------------------------===// @@ -18,848 +18,1378 @@ #define DEBUG_TYPE "differentiation" #include "swift/SILOptimizer/Differentiation/JVPCloner.h" +#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" +#include "swift/SILOptimizer/Differentiation/AdjointValue.h" +#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" +#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" +#include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" +#include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" +#include "llvm/ADT/DenseMap.h" namespace swift { namespace autodiff { -//--------------------------------------------------------------------------// -// Initialization helpers -//--------------------------------------------------------------------------// +class JVPCloner::Implementation final + : public TypeSubstCloner { +private: + /// The global context. 
+ ADContext &context; -/*static*/ -SubstitutionMap JVPCloner::getSubstitutionMap(SILFunction *original, - SILFunction *jvp) { - auto substMap = original->getForwardingSubstitutionMap(); - if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { - auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); - substMap = SubstitutionMap::get( - jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, - LookUpConformanceInSubstitutionMap(jvpSubstMap)); - } - return substMap; -} + /// The original function. + SILFunction *const original; -/*static*/ -const DifferentiableActivityInfo & -JVPCloner::getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *jvp) { - // Get activity info of the original function. - auto &passManager = context.getPassManager(); - auto *activityAnalysis = - passManager.getAnalysis(); - auto &activityCollection = *activityAnalysis->get(original); - auto &activityInfo = activityCollection.getActivityInfo( - jvp->getLoweredFunctionType()->getSubstGenericSignature(), - AutoDiffDerivativeFunctionKind::JVP); - LLVM_DEBUG(activityInfo.dump(indices, getADDebugStream())); - return activityInfo; -} + /// The witness. + SILDifferentiabilityWitness *const witness; -JVPCloner::JVPCloner(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *jvp, - DifferentiationInvoker invoker) - : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), - context(context), original(original), witness(witness), jvp(jvp), - invoker(invoker), - activityInfo(getActivityInfo(context, original, - witness->getSILAutoDiffIndices(), jvp)), - differentialInfo(context, AutoDiffLinearMapKind::Differential, original, - jvp, witness->getSILAutoDiffIndices(), activityInfo), - differentialBuilder(SILBuilder( - *createEmptyDifferential(context, witness, &differentialInfo))), - diffLocalAllocBuilder(getDifferential()) { - // Create empty differential function. 
- context.recordGeneratedFunction(&getDifferential()); -} + /// The JVP function. + SILFunction *const jvp; -//--------------------------------------------------------------------------// -// Differential struct mapping -//--------------------------------------------------------------------------// + llvm::BumpPtrAllocator allocator; -void JVPCloner::initializeDifferentialStructElements( - SILBasicBlock *origBB, SILInstructionResultArray values) { - auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); - assert(diffStructDecl->getStoredProperties().size() == values.size() && - "The number of differential struct fields must equal the number of " - "differential struct element values"); - for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) { - assert(std::get<1>(pair).getOwnershipKind() != - ValueOwnershipKind::Guaranteed && - "Differential struct elements must be @owned"); - auto insertion = differentialStructElements.insert( - {std::get<0>(pair), std::get<1>(pair)}); - (void)insertion; - assert(insertion.second && - "A differential struct element mapping already exists!"); - } -} + /// The differentiation invoker. + DifferentiationInvoker invoker; -SILValue JVPCloner::getDifferentialStructElement(SILBasicBlock *origBB, - VarDecl *field) { - assert(differentialInfo.getLinearMapStruct(origBB) == - cast(field->getDeclContext())); - assert(differentialStructElements.count(field) && - "Differential struct element for this field does not exist!"); - return differentialStructElements.lookup(field); -} + /// Info from activity analysis on the original function. + const DifferentiableActivityInfo &activityInfo; -//--------------------------------------------------------------------------// -// General utilities -//--------------------------------------------------------------------------// + /// The differential info. 
+ LinearMapInfo differentialInfo; -SILBasicBlock::iterator -JVPCloner::getNextDifferentialLocalAllocationInsertionPoint() { - // If there are no local allocations, insert at the beginning of the tangent - // entry. - if (differentialLocalAllocations.empty()) - return getDifferential().getEntryBlock()->begin(); - // Otherwise, insert before the last local allocation. Inserting before - // rather than after ensures that allocation and zero initialization - // instructions are grouped together. - auto lastLocalAlloc = differentialLocalAllocations.back(); - auto it = lastLocalAlloc->getDefiningInstruction()->getIterator(); - return it; -} + bool errorOccurred = false; -SILType JVPCloner::getLoweredType(Type type) { - auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); - Lowering::AbstractionPattern pattern(jvpGenSig, - type->getCanonicalType(jvpGenSig)); - return jvp->getLoweredType(pattern, type); -} + //--------------------------------------------------------------------------// + // Differential generation related fields + //--------------------------------------------------------------------------// -SILType JVPCloner::getNominalDeclLoweredType(NominalTypeDecl *nominal) { - auto nominalType = - getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); - return getLoweredType(nominalType); -} + /// The builder for the differential function. 
+ SILBuilder differentialBuilder; -StructInst *JVPCloner::buildDifferentialValueStructValue(TermInst *termInst) { - assert(termInst->getFunction() == original); - auto loc = termInst->getFunction()->getLocation(); - auto *origBB = termInst->getParent(); - auto *jvpBB = BBMap[origBB]; - assert(jvpBB && "Basic block mapping should exist"); - auto *diffStruct = differentialInfo.getLinearMapStruct(origBB); - assert(diffStruct && "The differential struct should have been declared"); - auto structLoweredTy = getNominalDeclLoweredType(diffStruct); - auto bbDifferentialValues = differentialValues[origBB]; - if (!origBB->isEntry()) { - auto *enumArg = jvpBB->getArguments().back(); - bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg); - } - return getBuilder().createStruct(loc, structLoweredTy, bbDifferentialValues); -} + /// Mapping from original basic blocks to corresponding differential basic + /// blocks. + llvm::DenseMap diffBBMap; -//--------------------------------------------------------------------------// -// Tangent value factory methods -//--------------------------------------------------------------------------// + /// Mapping from original basic blocks and original values to corresponding + /// tangent values. + llvm::DenseMap tangentValueMap; -AdjointValue JVPCloner::makeZeroTangentValue(SILType type) { - return AdjointValue::createZero(allocator, remapSILTypeInDifferential(type)); -} + /// Mapping from original basic blocks and original buffers to corresponding + /// tangent buffers. + llvm::DenseMap, SILValue> bufferMap; -AdjointValue JVPCloner::makeConcreteTangentValue(SILValue value) { - return AdjointValue::createConcrete(allocator, value); -} + /// Mapping from differential basic blocks to differential struct arguments. 
+ llvm::DenseMap differentialStructArguments; -//--------------------------------------------------------------------------// -// Tangent materialization -//--------------------------------------------------------------------------// + /// Mapping from differential struct field declarations to differential struct + /// elements destructured from the linear map basic block argument. In the + /// beginning of each differential basic block, the block's differential + /// struct is destructured into the individual elements stored here. + llvm::DenseMap differentialStructElements; + + /// An auxiliary differential local allocation builder. + SILBuilder diffLocalAllocBuilder; + + /// Stack buffers allocated for storing local tangent values. + SmallVector differentialLocalAllocations; + + /// Mapping from original blocks to differential values. Used to build + /// differential struct instances. + llvm::DenseMap> differentialValues; + + //--------------------------------------------------------------------------// + // Getters + //--------------------------------------------------------------------------// + + ASTContext &getASTContext() const { return jvp->getASTContext(); } + SILModule &getModule() const { return jvp->getModule(); } + const SILAutoDiffIndices getIndices() const { + return witness->getSILAutoDiffIndices(); + } + SILBuilder &getDifferentialBuilder() { return differentialBuilder; } + SILFunction &getDifferential() { return differentialBuilder.getFunction(); } + SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { +#ifndef NDEBUG + auto *diffStruct = differentialStructArguments[origBB] + ->getType() + .getStructOrBoundGenericStruct(); + assert(diffStruct == differentialInfo.getLinearMapStruct(origBB)); +#endif + return differentialStructArguments[origBB]; + } + + //--------------------------------------------------------------------------// + // Differential struct mapping + 
//--------------------------------------------------------------------------// + + void initializeDifferentialStructElements(SILBasicBlock *origBB, + SILInstructionResultArray values); + + SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field); + + //--------------------------------------------------------------------------// + // General utilities + //--------------------------------------------------------------------------// + + SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() { + // If there are no local allocations, insert at the beginning of the tangent + // entry. + if (differentialLocalAllocations.empty()) + return getDifferential().getEntryBlock()->begin(); + // Otherwise, insert before the last local allocation. Inserting before + // rather than after ensures that allocation and zero initialization + // instructions are grouped together. + auto lastLocalAlloc = differentialLocalAllocations.back(); + auto it = lastLocalAlloc->getDefiningInstruction()->getIterator(); + return it; + } + + /// Get the lowered SIL type of the given AST type. 
+ SILType getLoweredType(Type type) { + auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); + Lowering::AbstractionPattern pattern(jvpGenSig, + type->getCanonicalType(jvpGenSig)); + return jvp->getLoweredType(pattern, type); + } -void JVPCloner::emitZeroIndirect(CanType type, SILValue bufferAccess, - SILLocation loc) { - auto builder = getDifferentialBuilder(); - auto tangentSpace = getTangentSpace(type); - assert(tangentSpace && "No tangent space for this type"); - switch (tangentSpace->getKind()) { - case TangentSpace::Kind::TangentVector: - emitZeroIntoBuffer(builder, type, bufferAccess, loc); - return; - case TangentSpace::Kind::Tuple: { - auto tupleType = tangentSpace->getTuple(); - SmallVector zeroElements; - for (unsigned i : range(tupleType->getNumElements())) { - auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); - emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), - eltAddr, loc); + /// Get the lowered SIL type of the given nominal type declaration. + SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { + auto nominalType = + getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); + return getLoweredType(nominalType); + } + + /// Build a differential struct value for the original block corresponding to + /// the given terminator. 
+ StructInst *buildDifferentialValueStructValue(TermInst *termInst) { + assert(termInst->getFunction() == original); + auto loc = termInst->getFunction()->getLocation(); + auto *origBB = termInst->getParent(); + auto *jvpBB = BBMap[origBB]; + assert(jvpBB && "Basic block mapping should exist"); + auto *diffStruct = differentialInfo.getLinearMapStruct(origBB); + assert(diffStruct && "The differential struct should have been declared"); + auto structLoweredTy = getNominalDeclLoweredType(diffStruct); + auto bbDifferentialValues = differentialValues[origBB]; + if (!origBB->isEntry()) { + auto *enumArg = jvpBB->getArguments().back(); + bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg); } - return; + return getBuilder().createStruct(loc, structLoweredTy, + bbDifferentialValues); } + + //--------------------------------------------------------------------------// + // Tangent value factory methods + //--------------------------------------------------------------------------// + + AdjointValue makeZeroTangentValue(SILType type) { + return AdjointValue::createZero(allocator, + remapSILTypeInDifferential(type)); } -} -SILValue JVPCloner::emitZeroDirect(CanType type, SILLocation loc) { - auto diffBuilder = getDifferentialBuilder(); - auto silType = getModule().Types.getLoweredLoadableType( - type, TypeExpansionContext::minimal(), getModule()); - auto *buffer = diffBuilder.createAllocStack(loc, silType); - emitZeroIndirect(type, buffer, loc); - auto loaded = diffBuilder.emitLoadValueOperation( - loc, buffer, LoadOwnershipQualifier::Take); - diffBuilder.createDeallocStack(loc, buffer); - return loaded; -} + AdjointValue makeConcreteTangentValue(SILValue value) { + return AdjointValue::createConcrete(allocator, value); + } -SILValue JVPCloner::materializeTangentDirect(AdjointValue val, - SILLocation loc) { - assert(val.getType().isObject()); - LLVM_DEBUG(getADDebugStream() - << "Materializing tangents for " << val << '\n'); - switch (val.getKind()) { - case 
AdjointValueKind::Zero: { - auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); - return zeroVal; - } - case AdjointValueKind::Aggregate: - llvm_unreachable( - "Tuples and structs are not supported in forward mode yet."); - case AdjointValueKind::Concrete: - return val.getConcreteValue(); - } - llvm_unreachable("invalid value kind"); -} + //--------------------------------------------------------------------------// + // Tangent materialization + //--------------------------------------------------------------------------// + + void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc) { + auto builder = getDifferentialBuilder(); + auto tangentSpace = getTangentSpace(type); + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { + case TangentSpace::Kind::TangentVector: + emitZeroIntoBuffer(builder, type, bufferAccess, loc); + return; + case TangentSpace::Kind::Tuple: { + auto tupleType = tangentSpace->getTuple(); + SmallVector zeroElements; + for (unsigned i : range(tupleType->getNumElements())) { + auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); + emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), + eltAddr, loc); + } + return; + } + } + } + + SILValue emitZeroDirect(CanType type, SILLocation loc) { + auto diffBuilder = getDifferentialBuilder(); + auto silType = getModule().Types.getLoweredLoadableType( + type, TypeExpansionContext::minimal(), getModule()); + auto *buffer = diffBuilder.createAllocStack(loc, silType); + emitZeroIndirect(type, buffer, loc); + auto loaded = diffBuilder.emitLoadValueOperation( + loc, buffer, LoadOwnershipQualifier::Take); + diffBuilder.createDeallocStack(loc, buffer); + return loaded; + } -SILValue JVPCloner::materializeTangent(AdjointValue val, SILLocation loc) { - if (val.isConcrete()) { + SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) { + assert(val.getType().isObject()); LLVM_DEBUG(getADDebugStream() - << 
"Materializing tangent: Value is concrete.\n"); - return val.getConcreteValue(); + << "Materializing tangents for " << val << '\n'); + switch (val.getKind()) { + case AdjointValueKind::Zero: { + auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); + return zeroVal; + } + case AdjointValueKind::Aggregate: + llvm_unreachable( + "Tuples and structs are not supported in forward mode yet."); + case AdjointValueKind::Concrete: + return val.getConcreteValue(); + } + llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 } - LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " - "non-concrete. Materializing directly.\n"); - return materializeTangentDirect(val, loc); -} -//--------------------------------------------------------------------------// -// Tangent buffer mapping -//--------------------------------------------------------------------------// + SILValue materializeTangent(AdjointValue val, SILLocation loc) { + if (val.isConcrete()) { + LLVM_DEBUG(getADDebugStream() + << "Materializing tangent: Value is concrete.\n"); + return val.getConcreteValue(); + } + LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " + "non-concrete. Materializing directly.\n"); + return materializeTangentDirect(val, loc); + } -void JVPCloner::setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, - SILValue tangentBuffer) { - assert(originalBuffer->getType().isAddress()); - auto insertion = - bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); - assert(insertion.second && "tangent buffer already exists."); - (void)insertion; -} + //--------------------------------------------------------------------------// + // Tangent value mapping + //--------------------------------------------------------------------------// + + /// Get the tangent for an original value. The given value must be in the + /// original function. + /// + /// This method first tries to find an entry in `tangentValueMap`. 
If an entry + /// doesn't exist, create a zero tangent. + AdjointValue getTangentValue(SILValue originalValue) { + assert(originalValue->getType().isObject()); + assert(originalValue->getFunction() == original); + auto insertion = tangentValueMap.try_emplace( + originalValue, + makeZeroTangentValue(getRemappedTangentType(originalValue->getType()))); + return insertion.first->getSecond(); + } -SILValue &JVPCloner::getTangentBuffer(SILBasicBlock *origBB, - SILValue originalBuffer) { - assert(originalBuffer->getType().isAddress()); - assert(originalBuffer->getFunction() == original); - auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, SILValue()); - assert(!insertion.second && "tangent buffer should already exist"); - return insertion.first->getSecond(); -} + /// Map the tangent value to the given original value. + void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, + AdjointValue newTangentValue) { +#ifndef NDEBUG + if (auto *defInst = originalValue->getDefiningInstruction()) { + bool isTupleTypedApplyResult = + isa(defInst) && originalValue->getType().is(); + assert(!isTupleTypedApplyResult && + "Should not set tangent value for tuple-typed result from `apply` " + "instruction; use `destructure_tuple` on `apply` result and set " + "tangent value for `destructure_tuple` results instead."); + } +#endif + assert(originalValue->getType().isObject()); + assert(newTangentValue.getType().isObject()); + assert(originalValue->getFunction() == original); + LLVM_DEBUG(getADDebugStream() << "Adding tangent for " << originalValue); + // The tangent value must be in the tangent space. 
+ assert(newTangentValue.getType() == + getRemappedTangentType(originalValue->getType())); + auto insertion = + tangentValueMap.try_emplace(originalValue, newTangentValue); + (void)insertion; + assert(insertion.second && "The tangent value should not already exist."); + } -//--------------------------------------------------------------------------// -// Differential type calculations -//--------------------------------------------------------------------------// + //--------------------------------------------------------------------------// + // Tangent buffer mapping + //--------------------------------------------------------------------------// -/// Substitutes all replacement types of the given substitution map using the -/// tangent function's substitution map. -SubstitutionMap -JVPCloner::remapSubstitutionMapInDifferential(SubstitutionMap substMap) { - return substMap.subst(getDifferential().getForwardingSubstitutionMap()); -} + void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, + SILValue tangentBuffer) { + assert(originalBuffer->getType().isAddress()); + auto insertion = + bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); + assert(insertion.second && "Tangent buffer already exists"); + (void)insertion; + } -Type JVPCloner::remapTypeInDifferential(Type ty) { - if (ty->hasArchetype()) - return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); - return getDifferential().mapTypeIntoContext(ty); -} + SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { + assert(originalBuffer->getType().isAddress()); + assert(originalBuffer->getFunction() == original); + auto insertion = + bufferMap.try_emplace({origBB, originalBuffer}, SILValue()); + assert(!insertion.second && "Tangent buffer should already exist"); + return insertion.first->getSecond(); + } -SILType JVPCloner::remapSILTypeInDifferential(SILType ty) { - if (ty.hasArchetype()) - return 
getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); - return getDifferential().mapTypeIntoContext(ty); -} + //--------------------------------------------------------------------------// + // Differential type calculations + //--------------------------------------------------------------------------// -Optional JVPCloner::getTangentSpace(CanType type) { - // Use witness generic signature to remap types. - if (auto witnessGenSig = witness->getDerivativeGenericSignature()) - type = witnessGenSig->getCanonicalTypeInContext(type); - return type->getAutoDiffTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())); -} + /// Substitutes all replacement types of the given substitution map using the + /// tangent function's substitution map. + SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { + return substMap.subst(getDifferential().getForwardingSubstitutionMap()); + } -SILType JVPCloner::getRemappedTangentType(SILType type) { - return SILType::getPrimitiveType( - getTangentSpace(remapSILTypeInDifferential(type).getASTType()) - ->getCanonicalType(), - type.getCategory()); -} + /// Remap any archetypes into the differential function's context. + Type remapTypeInDifferential(Type ty) { + if (ty->hasArchetype()) + return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); + return getDifferential().mapTypeIntoContext(ty); + } -//--------------------------------------------------------------------------// -// Tangent value mapping -//--------------------------------------------------------------------------// + /// Remap any archetypes into the differential function's context. 
+ SILType remapSILTypeInDifferential(SILType ty) { + if (ty.hasArchetype()) + return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); + return getDifferential().mapTypeIntoContext(ty); + } -AdjointValue JVPCloner::getTangentValue(SILValue originalValue) { - assert(originalValue->getType().isObject()); - assert(originalValue->getFunction() == original); - auto insertion = tangentValueMap.try_emplace( - originalValue, - makeZeroTangentValue(getRemappedTangentType(originalValue->getType()))); - return insertion.first->getSecond(); -} + /// Find the tangent space of a given canonical type. + Optional getTangentSpace(CanType type) { + // Use witness generic signature to remap types. + if (auto witnessGenSig = witness->getDerivativeGenericSignature()) + type = witnessGenSig->getCanonicalTypeInContext(type); + return type->getAutoDiffTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())); + } -void JVPCloner::setTangentValue(SILBasicBlock *origBB, SILValue originalValue, - AdjointValue newTangentValue) { -#ifndef NDEBUG - if (auto *defInst = originalValue->getDefiningInstruction()) { - bool isTupleTypedApplyResult = - isa(defInst) && originalValue->getType().is(); - assert(!isTupleTypedApplyResult && - "Should not set tangent value for tuple-typed result from `apply` " - "instruction; use `destructure_tuple` on `apply` result and set " - "tangent value for `destructure_tuple` results instead."); + /// Assuming the given type conforms to `Differentiable` after remapping, + /// returns the associated tangent space SIL type. + SILType getRemappedTangentType(SILType type) { + return SILType::getPrimitiveType( + getTangentSpace(remapSILTypeInDifferential(type).getASTType()) + ->getCanonicalType(), + type.getCategory()); + } + + /// Set up the differential function. This includes: + /// - Creating all differential blocks. + /// - Creating differential entry block arguments based on the function type. 
+ /// - Creating tangent value mapping for original/differential parameters. + /// - Checking for unvaried result and emitting related warnings. + void prepareForDifferentialGeneration(); + +public: + explicit Implementation(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, + SILFunction *jvp, DifferentiationInvoker invoker); + + static SILFunction * + createEmptyDifferential(ADContext &context, + SILDifferentiabilityWitness *witness, + LinearMapInfo *linearMapInfo); + + /// Run JVP generation. Returns true on error. + bool run(); + + void postProcess(SILInstruction *orig, SILInstruction *cloned) { + if (errorOccurred) + return; + SILClonerWithScopes::postProcess(orig, cloned); + } + + /// Remap original basic blocks. + SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { + auto *jvpBB = BBMap[bb]; + return jvpBB; } + + /// General visitor for all instructions. If any error is emitted by previous + /// visits, bail out. + void visit(SILInstruction *inst) { + if (errorOccurred) + return; + if (differentialInfo.shouldDifferentiateInstruction(inst)) { + LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst); +#ifndef NDEBUG + auto diffBuilder = getDifferentialBuilder(); + auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); #endif - assert(originalValue->getType().isObject()); - assert(newTangentValue.getType().isObject()); - assert(originalValue->getFunction() == original); - LLVM_DEBUG(getADDebugStream() << "Adding tangent for " << originalValue); - // The tangent value must be in the tangent space. 
- assert(newTangentValue.getType() == - getRemappedTangentType(originalValue->getType())); - auto insertion = tangentValueMap.try_emplace(originalValue, newTangentValue); - (void)insertion; - assert(insertion.second && "The tangent value should not already exist."); -} + TypeSubstCloner::visit(inst); + LLVM_DEBUG({ + auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n"; + auto afterInsertion = diffBuilder.getInsertionPoint(); + for (auto it = ++beforeInsertion; it != afterInsertion; ++it) + s << *it; + }); + } else { + TypeSubstCloner::visit(inst); + } + } -//--------------------------------------------------------------------------// -// Tangent emission helpers -//--------------------------------------------------------------------------// + void visitSILInstruction(SILInstruction *inst) { + context.emitNondifferentiabilityError( + inst, invoker, diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + } + + void visitInstructionsInBlock(SILBasicBlock *bb) { + // Destructure the differential struct to get the elements. + auto &diffBuilder = getDifferentialBuilder(); + auto diffLoc = getDifferential().getLocation(); + auto *diffBB = diffBBMap.lookup(bb); + auto *mainDifferentialStruct = diffBB->getArguments().back(); + diffBuilder.setInsertionPoint(diffBB); + auto *dsi = + diffBuilder.createDestructureStruct(diffLoc, mainDifferentialStruct); + initializeDifferentialStructElements(bb, dsi->getResults()); + TypeSubstCloner::visitInstructionsInBlock(bb); + } + + // If an `apply` has active results or active inout parameters, replace it + // with an `apply` of its JVP. + void visitApplyInst(ApplyInst *ai) { + // If the function should not be differentiated or its the array literal + // initialization intrinsic, just do standard cloning. 
+ if (!differentialInfo.shouldDifferentiateApplySite(ai) || + ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { + LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + + // Diagnose functions with active inout arguments. + // TODO(TF-129): Support `inout` argument differentiation. + for (auto inoutArg : ai->getInoutArguments()) { + if (activityInfo.isActive(inoutArg, getIndices())) { + context.emitNondifferentiabilityError( + ai, invoker, + diag::autodiff_cannot_differentiate_through_inout_arguments); + errorOccurred = true; + return; + } + } + + auto loc = ai->getLoc(); + auto &builder = getBuilder(); + auto origCallee = getOpValue(ai->getCallee()); + auto originalFnTy = origCallee->getType().castTo(); + + LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); + + // Get the minimal parameter and result indices required for differentiating + // this `apply`. + SmallVector allResults; + SmallVector activeParamIndices; + SmallVector activeResultIndices; + collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, + allResults, activeParamIndices, + activeResultIndices); + assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); + assert(!activeResultIndices.empty() && "Result indices cannot be empty"); + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + llvm::interleave( + activeParamIndices.begin(), activeParamIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}, results={"; llvm::interleave( + activeResultIndices.begin(), activeResultIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}\n";); + + // Form expected indices. 
+ auto numResults = + ai->getSubstCalleeType()->getNumResults() + + ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); + SILAutoDiffIndices indices( + IndexSubset::get(getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices), + IndexSubset::get(getASTContext(), numResults, activeResultIndices)); + + // Emit the JVP. + SILValue jvpValue; + // If functionSource is a `@differentiable` function, just extract it. + if (originalFnTy->isDifferentiable()) { + auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); + for (auto i : indices.parameters->getIndices()) { + if (!paramIndices->contains(i)) { + context.emitNondifferentiabilityError( + origCallee, invoker, + diag:: + autodiff_function_noderivative_parameter_not_differentiable); + errorOccurred = true; + return; + } + } + auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee); + jvpValue = builder.createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent::JVP, + borrowedDiffFunc); + jvpValue = builder.emitCopyValueOperation(loc, jvpValue); + } + + // If JVP has not yet been found, emit an `differentiable_function` + // instruction on the remapped function operand and + // an `differentiable_function_extract` instruction to get the JVP. + // The `differentiable_function` instruction will be canonicalized during + // the transform main loop. + if (!jvpValue) { + // FIXME: Handle indirect differentiation invokers. This may require some + // redesign: currently, each original function + witness pair is mapped + // only to one invoker. + /* + DifferentiationInvoker indirect(ai, attr); + auto insertion = + context.getInvokers().try_emplace({original, attr}, indirect); + auto &invoker = insertion.first->getSecond(); + invoker = indirect; + */ + + // If the original `apply` instruction has a substitution map, then the + // applied function is specialized. + // In the JVP, specialization is also necessary for parity. 
The original + // function operand is specialized with a remapped version of same + // substitution map using an argument-less `partial_apply`. + if (ai->getSubstitutionMap().empty()) { + origCallee = builder.emitCopyValueOperation(loc, origCallee); + } else { + auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); + auto jvpPartialApply = getBuilder().createPartialApply( + ai->getLoc(), origCallee, substMap, {}, + ParameterConvention::Direct_Guaranteed); + origCallee = jvpPartialApply; + } + + // Check and diagnose non-differentiable original function type. + auto diagnoseNondifferentiableOriginalFunctionType = + [&](CanSILFunctionType origFnTy) { + // Check and diagnose non-differentiable arguments. + for (auto paramIndex : indices.parameters->getIndices()) { + if (!originalFnTy->getParameters()[paramIndex] + .getSILStorageInterfaceType() + .isDifferentiable(getModule())) { + context.emitNondifferentiabilityError( + ai->getArgumentsWithoutIndirectResults()[paramIndex], + invoker, diag::autodiff_nondifferentiable_argument); + errorOccurred = true; + return true; + } + } + // Check and diagnose non-differentiable results. 
+ for (auto resultIndex : indices.results->getIndices()) { + SILType remappedResultType; + if (resultIndex >= originalFnTy->getNumResults()) { + auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); + auto inoutArg = + *std::next(ai->getInoutArguments().begin(), inoutArgIdx); + remappedResultType = inoutArg->getType(); + } else { + remappedResultType = originalFnTy->getResults()[resultIndex] + .getSILStorageInterfaceType(); + } + if (!remappedResultType.isDifferentiable(getModule())) { + context.emitNondifferentiabilityError( + origCallee, invoker, + diag::autodiff_nondifferentiable_result); + errorOccurred = true; + return true; + } + } + return false; + }; + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) + return; + + auto *diffFuncInst = context.createDifferentiableFunction( + builder, loc, indices.parameters, indices.results, origCallee); + + // Record the `differentiable_function` instruction. + context.addDifferentiableFunctionInstToWorklist(diffFuncInst); + + auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); + auto extractedJVP = builder.createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedADFunc); + jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); + builder.emitEndBorrowOperation(loc, borrowedADFunc); + builder.emitDestroyValueOperation(loc, diffFuncInst); + } + + // Call the JVP using the original parameters. + SmallVector jvpArgs; + auto jvpFnTy = getOpType(jvpValue->getType()).castTo(); + auto numJVPArgs = + jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); + jvpArgs.reserve(numJVPArgs); + // Collect substituted arguments. + for (auto origArg : ai->getArguments()) + jvpArgs.push_back(getOpValue(origArg)); + assert(jvpArgs.size() == numJVPArgs); + // Apply the JVP. + // The JVP should be specialized, so no substitution map is necessary. 
+ auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(), + jvpArgs, ai->isNonThrowing()); + LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); + + // Release the differentiable function. + builder.emitDestroyValueOperation(loc, jvpValue); + + // Get the JVP results (original results and differential). + SmallVector jvpDirectResults; + extractAllElements(jvpCall, builder, jvpDirectResults); + auto originalDirectResults = + ArrayRef(jvpDirectResults).drop_back(1); + auto originalDirectResult = + joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc()); + + mapValue(ai, originalDirectResult); + + // Some instructions that produce the callee may have been cloned. + // If the original callee did not have any users beyond this `apply`, + // recursively kill the cloned callee. + if (auto *origCallee = cast_or_null( + ai->getCallee()->getDefiningInstruction())) + if (origCallee->hasOneUse()) + recursivelyDeleteTriviallyDeadInstructions( + getOpValue(origCallee)->getDefiningInstruction()); + + // Add the differential function for when we create the struct we partially + // apply to the differential we are generating. + auto differential = jvpDirectResults.back(); + auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); + auto originalDifferentialType = + getOpType(differential->getType()).getAs(); + auto loweredDifferentialType = + getOpType(getLoweredType(differentialDecl->getInterfaceType())) + .castTo(); + // If actual differential type does not match lowered differential type, + // reabstract the differential using a thunk. 
+ if (!loweredDifferentialType->isEqual(originalDifferentialType)) { + SILOptFunctionBuilder fb(context.getTransform()); + differential = reabstractFunction( + builder, fb, loc, differential, loweredDifferentialType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->getOpSubstitutionMap(subs); + }); + } + differentialValues[ai->getParent()].push_back(differential); + + // Differential emission. + emitTangentForApplyInst(ai, indices, originalDifferentialType); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto *origExit = ri->getParent(); + auto &builder = getBuilder(); + auto *diffStructVal = buildDifferentialValueStructValue(ri); + + // Get the JVP value corresponding to the original functions's return value. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the differential. + auto jvpGenericEnv = jvp->getGenericEnvironment(); + auto jvpSubstMap = jvpGenericEnv + ? 
jvpGenericEnv->getForwardingSubstitutionMap() + : jvp->getForwardingSubstitutionMap(); + auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); + auto *differentialPartialApply = builder.createPartialApply( + loc, differentialRef, jvpSubstMap, {diffStructVal}, + ParameterConvention::Direct_Guaranteed); + + auto differentialType = jvp->getLoweredFunctionType() + ->getResults() + .back() + .getSILStorageInterfaceType(); + differentialType = differentialType.substGenericArgs( + getModule(), jvpSubstMap, TypeExpansionContext::minimal()); + differentialType = differentialType.subst(getModule(), jvpSubstMap); + auto differentialFnType = differentialType.castTo(); + + auto differentialSubstType = + differentialPartialApply->getType().castTo(); + SILValue differentialValue; + if (differentialSubstType == differentialFnType) { + differentialValue = differentialPartialApply; + } else if (differentialSubstType + ->isABICompatibleWith(differentialFnType, *jvp) + .isCompatible()) { + differentialValue = builder.createConvertFunction( + loc, differentialPartialApply, differentialType, + /*withoutActuallyEscaping*/ false); + } else { + // When `diag::autodiff_loadable_value_addressonly_tangent_unsupported` + // applies, the return type may be ABI-incomaptible with the type of the + // partially applied differential. In these cases, produce an undef and + // rely on other code to emit a diagnostic. + differentialValue = SILUndef::get(differentialType, *jvp); + } + + // Return a tuple of the original result and differential. 
+ SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(differentialValue); + builder.createReturn(ri->getLoc(), + joinElements(directResults, builder, loc)); + } + + void visitBranchInst(BranchInst *bi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitCondBranchInst(CondBranchInst *cbi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitSwitchEnumInst(SwitchEnumInst *sei) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { + // Clone `differentiable_function` from original to JVP, then add the cloned + // instruction to the `differentiable_function` worklist. + TypeSubstCloner::visitDifferentiableFunctionInst(dfi); + auto *newDFI = cast(getOpValue(dfi)); + context.addDifferentiableFunctionInstToWorklist(newDFI); + } + + //--------------------------------------------------------------------------// + // Tangent emission helpers + //--------------------------------------------------------------------------// #define CLONE_AND_EMIT_TANGENT(INST, ID) \ - void JVPCloner::visit##INST##Inst(INST##Inst *inst) { \ + void visit##INST##Inst(INST##Inst *inst) { \ TypeSubstCloner::visit##INST##Inst(inst); \ if (differentialInfo.shouldDifferentiateInstruction(inst)) \ emitTangentFor##INST##Inst(inst); \ } \ - void JVPCloner::emitTangentFor##INST##Inst(INST##Inst *(ID)) - -CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = bbi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); - auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); - setTangentValue(bbi->getParent(), bbi, - makeConcreteTangentValue(tanValBorrow)); -} - -CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = ebi->getLoc(); - auto tanVal = 
materializeTangent(getTangentValue(ebi->getOperand()), loc); - diffBuilder.emitEndBorrowOperation(loc, tanVal); -} + void emitTangentFor##INST##Inst(INST##Inst *(ID)) + + CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = bbi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); + auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); + setTangentValue(bbi->getParent(), bbi, + makeConcreteTangentValue(tanValBorrow)); + } -CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = dvi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); - diffBuilder.emitDestroyValue(loc, tanVal); -} + CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = ebi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); + diffBuilder.emitEndBorrowOperation(loc, tanVal); + } -CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { - auto &diffBuilder = getDifferentialBuilder(); - auto tan = getTangentValue(cvi->getOperand()); - auto tanVal = materializeTangent(tan, cvi->getLoc()); - auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); - setTangentValue(cvi->getParent(), cvi, makeConcreteTangentValue(tanValCopy)); -} + CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = dvi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); + diffBuilder.emitDestroyValue(loc, tanVal); + } -/// Handle `load` instruction. 
-/// Original: y = load x -/// Tangent: tan[y] = load tan[x] -CLONE_AND_EMIT_TANGENT(Load, li) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = li->getParent(); - auto loc = li->getLoc(); - auto tanBuf = getTangentBuffer(bb, li->getOperand()); - auto tanVal = diffBuilder.emitLoadValueOperation(loc, tanBuf, - li->getOwnershipQualifier()); - setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); -} + CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tan = getTangentValue(cvi->getOperand()); + auto tanVal = materializeTangent(tan, cvi->getLoc()); + auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); + setTangentValue(cvi->getParent(), cvi, + makeConcreteTangentValue(tanValCopy)); + } -/// Handle `load_borrow` instruction. -/// Original: y = load_borrow x -/// Tangent: tan[y] = load_borrow tan[x] -CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = lbi->getParent(); - auto loc = lbi->getLoc(); - auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); - auto tanVal = diffBuilder.emitLoadBorrowOperation(loc, tanBuf); - setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); -} + /// Handle `load` instruction. + /// Original: y = load x + /// Tangent: tan[y] = load tan[x] + CLONE_AND_EMIT_TANGENT(Load, li) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = li->getParent(); + auto loc = li->getLoc(); + auto tanBuf = getTangentBuffer(bb, li->getOperand()); + auto tanVal = diffBuilder.emitLoadValueOperation( + loc, tanBuf, li->getOwnershipQualifier()); + setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); + } -/// Handle `store` instruction in the differential. 
-/// Original: store x to y -/// Tangent: store tan[x] to tan[y] -CLONE_AND_EMIT_TANGENT(Store, si) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = si->getLoc(); - auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); - auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); - diffBuilder.emitStoreValueOperation(loc, tanValSrc, tanValDest, - si->getOwnershipQualifier()); -} + /// Handle `load_borrow` instruction. + /// Original: y = load_borrow x + /// Tangent: tan[y] = load_borrow tan[x] + CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = lbi->getParent(); + auto loc = lbi->getLoc(); + auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); + auto tanVal = diffBuilder.emitLoadBorrowOperation(loc, tanBuf); + setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); + } -/// Handle `store_borrow` instruction in the differential. -/// Original: store_borrow x to y -/// Tangent: store_borrow tan[x] to tan[y] -CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = sbi->getLoc(); - auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); - auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); - diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); -} + /// Handle `store` instruction in the differential. + /// Original: store x to y + /// Tangent: store tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(Store, si) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = si->getLoc(); + auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); + auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); + diffBuilder.emitStoreValueOperation(loc, tanValSrc, tanValDest, + si->getOwnershipQualifier()); + } -/// Handle `copy_addr` instruction. 
-/// Original: copy_addr x to y -/// Tangent: copy_addr tan[x] to tan[y] -CLONE_AND_EMIT_TANGENT(CopyAddr, cai) { - auto diffBuilder = getDifferentialBuilder(); - auto loc = cai->getLoc(); - auto *bb = cai->getParent(); - auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); - auto tanDest = getTangentBuffer(bb, cai->getDest()); - - diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), - cai->isInitializationOfDest()); -} + /// Handle `store_borrow` instruction in the differential. + /// Original: store_borrow x to y + /// Tangent: store_borrow tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = sbi->getLoc(); + auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); + auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); + diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); + } -/// Handle `unconditional_checked_cast_addr` instruction. -/// Original: unconditional_checked_cast_addr $X in x to $Y in y -/// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x] -/// to $Y.Tan in tan[y] -CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) { - auto diffBuilder = getDifferentialBuilder(); - auto loc = uccai->getLoc(); - auto *bb = uccai->getParent(); - auto &tanSrc = getTangentBuffer(bb, uccai->getSrc()); - auto tanDest = getTangentBuffer(bb, uccai->getDest()); - - diffBuilder.createUnconditionalCheckedCastAddr( - loc, tanSrc, tanSrc->getType().getASTType(), tanDest, - tanDest->getType().getASTType()); -} + /// Handle `copy_addr` instruction. 
+ /// Original: copy_addr x to y + /// Tangent: copy_addr tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(CopyAddr, cai) { + auto diffBuilder = getDifferentialBuilder(); + auto loc = cai->getLoc(); + auto *bb = cai->getParent(); + auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); + auto tanDest = getTangentBuffer(bb, cai->getDest()); -/// Handle `begin_access` instruction (and do differentiability checks). -/// Original: y = begin_access x -/// Tangent: tan[y] = begin_access tan[x] -CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { - // Check for non-differentiable writes. - if (bai->getAccessKind() == SILAccessKind::Modify) { - if (auto *gai = dyn_cast(bai->getSource())) { - context.emitNondifferentiabilityError( - bai, invoker, - diag::autodiff_cannot_differentiate_writes_to_global_variables); - errorOccurred = true; - return; - } - if (auto *pbi = dyn_cast(bai->getSource())) { - context.emitNondifferentiabilityError( - bai, invoker, - diag::autodiff_cannot_differentiate_writes_to_mutable_captures); - errorOccurred = true; - return; - } + diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), + cai->isInitializationOfDest()); } - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = bai->getParent(); + /// Handle `unconditional_checked_cast_addr` instruction. 
+ /// Original: unconditional_checked_cast_addr $X in x to $Y in y + /// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x] + /// to $Y.Tan in tan[y] + CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) { + auto diffBuilder = getDifferentialBuilder(); + auto loc = uccai->getLoc(); + auto *bb = uccai->getParent(); + auto &tanSrc = getTangentBuffer(bb, uccai->getSrc()); + auto tanDest = getTangentBuffer(bb, uccai->getDest()); + + diffBuilder.createUnconditionalCheckedCastAddr( + loc, tanSrc, tanSrc->getType().getASTType(), tanDest, + tanDest->getType().getASTType()); + } - auto tanSrc = getTangentBuffer(bb, bai->getSource()); - auto *tanDest = diffBuilder.createBeginAccess( - bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), - bai->hasNoNestedConflict(), bai->isFromBuiltin()); - setTangentBuffer(bb, bai, tanDest); -} + /// Handle `begin_access` instruction (and do differentiability checks). + /// Original: y = begin_access x + /// Tangent: tan[y] = begin_access tan[x] + CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { + // Check for non-differentiable writes. + if (bai->getAccessKind() == SILAccessKind::Modify) { + if (auto *gai = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError( + bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_global_variables); + errorOccurred = true; + return; + } + if (auto *pbi = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError( + bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_mutable_captures); + errorOccurred = true; + return; + } + } -/// Handle `end_access` instruction. 
-/// Original: begin_access x -/// Tangent: end_access tan[x] -CLONE_AND_EMIT_TANGENT(EndAccess, eai) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = eai->getParent(); - auto loc = eai->getLoc(); - auto tanSrc = getTangentBuffer(bb, eai->getOperand()); - diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting()); -} + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = bai->getParent(); -/// Handle `alloc_stack` instruction. -/// Original: y = alloc_stack $T -/// Tangent: tan[y] = alloc_stack $T.Tangent -CLONE_AND_EMIT_TANGENT(AllocStack, asi) { - auto &diffBuilder = getDifferentialBuilder(); - auto *mappedAllocStackInst = diffBuilder.createAllocStack( - asi->getLoc(), getRemappedTangentType(asi->getElementType()), - asi->getVarInfo()); - bufferMap.try_emplace({asi->getParent(), asi}, mappedAllocStackInst); -} + auto tanSrc = getTangentBuffer(bb, bai->getSource()); + auto *tanDest = diffBuilder.createBeginAccess( + bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), + bai->hasNoNestedConflict(), bai->isFromBuiltin()); + setTangentBuffer(bb, bai, tanDest); + } -/// Handle `dealloc_stack` instruction. -/// Original: dealloc_stack x -/// Tangent: dealloc_stack tan[x] -CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { - auto &diffBuilder = getDifferentialBuilder(); - auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); - diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); -} + /// Handle `end_access` instruction. + /// Original: begin_access x + /// Tangent: end_access tan[x] + CLONE_AND_EMIT_TANGENT(EndAccess, eai) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = eai->getParent(); + auto loc = eai->getLoc(); + auto tanSrc = getTangentBuffer(bb, eai->getOperand()); + diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting()); + } -/// Handle `destroy_addr` instruction. 
-/// Original: destroy_addr x -/// Tangent: destroy_addr tan[x] -CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { - auto &diffBuilder = getDifferentialBuilder(); - auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); - diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); -} + /// Handle `alloc_stack` instruction. + /// Original: y = alloc_stack $T + /// Tangent: tan[y] = alloc_stack $T.Tangent + CLONE_AND_EMIT_TANGENT(AllocStack, asi) { + auto &diffBuilder = getDifferentialBuilder(); + auto *mappedAllocStackInst = diffBuilder.createAllocStack( + asi->getLoc(), getRemappedTangentType(asi->getElementType()), + asi->getVarInfo()); + bufferMap.try_emplace({asi->getParent(), asi}, mappedAllocStackInst); + } -/// Handle `struct` instruction. -/// Original: y = struct $T (x0, x1, x2, ...) -/// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) -CLONE_AND_EMIT_TANGENT(Struct, si) { - auto &diffBuilder = getDifferentialBuilder(); - SmallVector tangentElements; - for (auto elem : si->getElements()) - tangentElements.push_back(getTangentValue(elem).getConcreteValue()); - auto tanExtract = diffBuilder.createStruct( - si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); - setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); -} + /// Handle `dealloc_stack` instruction. + /// Original: dealloc_stack x + /// Tangent: dealloc_stack tan[x] + CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); + diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); + } -/// Handle `struct_extract` instruction. 
-/// Original: y = struct_extract x, #field -/// Tangent: tan[y] = struct_extract tan[x], #field' -/// ^~~~~~~ -/// field in tangent space corresponding to #field -CLONE_AND_EMIT_TANGENT(StructExtract, sei) { - assert(!sei->getField()->getAttrs().hasAttribute() && - "`struct_extract` with `@noDerivative` field should not be " - "differentiated; activity analysis should not marked as varied."); - auto diffBuilder = getDifferentialBuilder(); - auto loc = getValidLocation(sei); - // Find the corresponding field in the tangent space. - auto structType = - remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType(); - auto *tanField = getTangentStoredProperty(context, sei, structType, invoker); - if (!tanField) { - errorOccurred = true; - return; - } - // Emit tangent `struct_extract`. - auto tanStruct = materializeTangent(getTangentValue(sei->getOperand()), loc); - auto tangentInst = diffBuilder.createStructExtract(loc, tanStruct, tanField); - // Update tangent value mapping for `struct_extract` result. - auto tangentResult = makeConcreteTangentValue(tangentInst); - setTangentValue(sei->getParent(), sei, tangentResult); -} + /// Handle `destroy_addr` instruction. + /// Original: destroy_addr x + /// Tangent: destroy_addr tan[x] + CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { + auto &diffBuilder = getDifferentialBuilder(); + auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); + diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); + } -/// Handle `struct_element_addr` instruction. 
-/// Original: y = struct_element_addr x, #field -/// Tangent: tan[y] = struct_element_addr tan[x], #field' -/// ^~~~~~~ -/// field in tangent space corresponding to #field -CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { - assert(!seai->getField()->getAttrs().hasAttribute() && - "`struct_element_addr` with `@noDerivative` field should not be " - "differentiated; activity analysis should not marked as varied."); - auto diffBuilder = getDifferentialBuilder(); - auto *bb = seai->getParent(); - auto loc = getValidLocation(seai); - // Find the corresponding field in the tangent space. - auto structType = - remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType(); - auto *tanField = getTangentStoredProperty(context, seai, structType, invoker); - if (!tanField) { - errorOccurred = true; - return; - } - // Emit tangent `struct_element_addr`. - auto tanOperand = getTangentBuffer(bb, seai->getOperand()); - auto tangentInst = - diffBuilder.createStructElementAddr(loc, tanOperand, tanField); - // Update tangent buffer map for `struct_element_addr`. - setTangentBuffer(bb, seai, tangentInst); -} + /// Handle `struct` instruction. + /// Original: y = struct $T (x0, x1, x2, ...) + /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) + CLONE_AND_EMIT_TANGENT(Struct, si) { + auto &diffBuilder = getDifferentialBuilder(); + SmallVector tangentElements; + for (auto elem : si->getElements()) + tangentElements.push_back(getTangentValue(elem).getConcreteValue()); + auto tanExtract = diffBuilder.createStruct( + si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); + setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); + } -/// Handle `tuple` instruction. -/// Original: y = tuple (x0, x1, x2, ...) -/// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) 
-/// ^~~ -/// excluding non-differentiable elements -CLONE_AND_EMIT_TANGENT(Tuple, ti) { - auto diffBuilder = getDifferentialBuilder(); - - // Get the tangents of all the tuple elements. - SmallVector tangentTupleElements; - for (auto elem : ti->getElements()) { - if (!getTangentSpace(elem->getType().getASTType())) - continue; - tangentTupleElements.push_back( - materializeTangent(getTangentValue(elem), ti->getLoc())); + /// Handle `struct_extract` instruction. + /// Original: y = struct_extract x, #field + /// Tangent: tan[y] = struct_extract tan[x], #field' + /// ^~~~~~~ + /// field in tangent space corresponding to #field + CLONE_AND_EMIT_TANGENT(StructExtract, sei) { + assert(!sei->getField()->getAttrs().hasAttribute() && + "`struct_extract` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied."); + auto diffBuilder = getDifferentialBuilder(); + auto loc = getValidLocation(sei); + // Find the corresponding field in the tangent space. + auto structType = + remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(context, sei, structType, invoker); + if (!tanField) { + errorOccurred = true; + return; + } + // Emit tangent `struct_extract`. + auto tanStruct = + materializeTangent(getTangentValue(sei->getOperand()), loc); + auto tangentInst = + diffBuilder.createStructExtract(loc, tanStruct, tanField); + // Update tangent value mapping for `struct_extract` result. + auto tangentResult = makeConcreteTangentValue(tangentInst); + setTangentValue(sei->getParent(), sei, tangentResult); } - // Emit the instruction and add the tangent mapping. - auto tanTuple = joinElements(tangentTupleElements, diffBuilder, ti->getLoc()); - setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); -} + /// Handle `struct_element_addr` instruction. 
+ /// Original: y = struct_element_addr x, #field + /// Tangent: tan[y] = struct_element_addr tan[x], #field' + /// ^~~~~~~ + /// field in tangent space corresponding to #field + CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { + assert(!seai->getField()->getAttrs().hasAttribute() && + "`struct_element_addr` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied."); + auto diffBuilder = getDifferentialBuilder(); + auto *bb = seai->getParent(); + auto loc = getValidLocation(seai); + // Find the corresponding field in the tangent space. + auto structType = + remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType(); + auto *tanField = + getTangentStoredProperty(context, seai, structType, invoker); + if (!tanField) { + errorOccurred = true; + return; + } + // Emit tangent `struct_element_addr`. + auto tanOperand = getTangentBuffer(bb, seai->getOperand()); + auto tangentInst = + diffBuilder.createStructElementAddr(loc, tanOperand, tanField); + // Update tangent buffer map for `struct_element_addr`. + setTangentBuffer(bb, seai, tangentInst); + } -/// Handle `tuple_extract` instruction. -/// Original: y = tuple_extract x, -/// Tangent: tan[y] = tuple_extract tan[x], -/// ^~~~ -/// tuple tangent space index corresponding to n -CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = tei->getLoc(); - auto origTupleTy = tei->getOperand()->getType().castTo(); - unsigned tanIndex = 0; - for (unsigned i : range(tei->getFieldNo())) { - if (getTangentSpace( - origTupleTy->getElement(i).getType()->getCanonicalType())) - ++tanIndex; - } - auto tanType = getRemappedTangentType(tei->getType()); - auto tanSource = materializeTangent(getTangentValue(tei->getOperand()), loc); - SILValue tanBuf; - // If the tangent buffer of the source does not have a tuple type, then - // it must represent a "single element tuple type". Use it directly. 
- if (!tanSource->getType().is()) { - setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanSource)); - } else { - tanBuf = diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); - bufferMap.try_emplace({tei->getParent(), tei}, tanBuf); + /// Handle `tuple` instruction. + /// Original: y = tuple (x0, x1, x2, ...) + /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) + /// ^~~ + /// excluding non-differentiable elements + CLONE_AND_EMIT_TANGENT(Tuple, ti) { + auto diffBuilder = getDifferentialBuilder(); + // Get the tangents of all the tuple elements. + SmallVector tangentTupleElements; + for (auto elem : ti->getElements()) { + if (!getTangentSpace(elem->getType().getASTType())) + continue; + tangentTupleElements.push_back( + materializeTangent(getTangentValue(elem), ti->getLoc())); + } + // Emit the instruction and add the tangent mapping. + auto tanTuple = + joinElements(tangentTupleElements, diffBuilder, ti->getLoc()); + setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); } -} -/// Handle `tuple_element_addr` instruction. -/// Original: y = tuple_element_addr x, -/// Tangent: tan[y] = tuple_element_addr tan[x], -/// ^~~~ -/// tuple tangent space index corresponding to n -CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) { - auto &diffBuilder = getDifferentialBuilder(); - auto origTupleTy = teai->getOperand()->getType().castTo(); - unsigned tanIndex = 0; - for (unsigned i : range(teai->getFieldNo())) { - if (getTangentSpace( - origTupleTy->getElement(i).getType()->getCanonicalType())) - ++tanIndex; - } - auto tanType = getRemappedTangentType(teai->getType()); - auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand()); - SILValue tanBuf; - // If the tangent buffer of the source does not have a tuple type, then - // it must represent a "single element tuple type". Use it directly. 
- if (!tanSource->getType().is()) { - tanBuf = tanSource; - } else { - tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource, - tanIndex, tanType); + /// Handle `tuple_extract` instruction. + /// Original: y = tuple_extract x, + /// Tangent: tan[y] = tuple_extract tan[x], + /// ^~~~ + /// tuple tangent space index corresponding to n + CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = tei->getLoc(); + auto origTupleTy = tei->getOperand()->getType().castTo(); + unsigned tanIndex = 0; + for (unsigned i : range(tei->getFieldNo())) { + if (getTangentSpace( + origTupleTy->getElement(i).getType()->getCanonicalType())) + ++tanIndex; + } + auto tanType = getRemappedTangentType(tei->getType()); + auto tanSource = + materializeTangent(getTangentValue(tei->getOperand()), loc); + SILValue tanBuf; + // If the tangent buffer of the source does not have a tuple type, then + // it must represent a "single element tuple type". Use it directly. + if (!tanSource->getType().is()) { + setTangentValue(tei->getParent(), tei, + makeConcreteTangentValue(tanSource)); + } else { + tanBuf = + diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); + bufferMap.try_emplace({tei->getParent(), tei}, tanBuf); + } } - bufferMap.try_emplace({teai->getParent(), teai}, tanBuf); -} -/// Handle `destructure_tuple` instruction. -/// Original: (y0, y1, ...) = destructure_tuple x, -/// Tangent: (tan[y0], tan[y1], ...) 
= destructure_tuple tan[x], -/// ^~~~ -/// tuple tangent space index corresponding to n -CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { - assert(llvm::any_of(dti->getResults(), - [&](SILValue elt) { - return activityInfo.isActive(elt, getIndices()); - }) && - "`destructure_tuple` should have at least one active result"); - - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = dti->getParent(); - auto loc = dti->getLoc(); - - auto tanTuple = materializeTangent(getTangentValue(dti->getOperand()), loc); - SmallVector tanElts; - if (tanTuple->getType().is()) { - auto *tanDti = diffBuilder.createDestructureTuple(loc, tanTuple); - tanElts.append(tanDti->getResults().begin(), tanDti->getResults().end()); - } else { - tanElts.push_back(tanTuple); + /// Handle `tuple_element_addr` instruction. + /// Original: y = tuple_element_addr x, + /// Tangent: tan[y] = tuple_element_addr tan[x], + /// ^~~~ + /// tuple tangent space index corresponding to n + CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) { + auto &diffBuilder = getDifferentialBuilder(); + auto origTupleTy = teai->getOperand()->getType().castTo(); + unsigned tanIndex = 0; + for (unsigned i : range(teai->getFieldNo())) { + if (getTangentSpace( + origTupleTy->getElement(i).getType()->getCanonicalType())) + ++tanIndex; + } + auto tanType = getRemappedTangentType(teai->getType()); + auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand()); + SILValue tanBuf; + // If the tangent buffer of the source does not have a tuple type, then + // it must represent a "single element tuple type". Use it directly. 
+ if (!tanSource->getType().is()) { + tanBuf = tanSource; + } else { + tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource, + tanIndex, tanType); + } + bufferMap.try_emplace({teai->getParent(), teai}, tanBuf); } - unsigned tanIdx = 0; - for (auto i : range(dti->getNumResults())) { - auto origElt = dti->getResult(i); - if (!getTangentSpace(origElt->getType().getASTType())) - continue; - setTangentValue(bb, origElt, makeConcreteTangentValue(tanElts[tanIdx++])); + + /// Handle `destructure_tuple` instruction. + /// Original: (y0, y1, ...) = destructure_tuple x, + /// Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x], + /// ^~~~ + /// tuple tangent space index corresponding to n + CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { + assert(llvm::any_of(dti->getResults(), + [&](SILValue elt) { + return activityInfo.isActive(elt, getIndices()); + }) && + "`destructure_tuple` should have at least one active result"); + + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = dti->getParent(); + auto loc = dti->getLoc(); + + auto tanTuple = materializeTangent(getTangentValue(dti->getOperand()), loc); + SmallVector tanElts; + if (tanTuple->getType().is()) { + auto *tanDti = diffBuilder.createDestructureTuple(loc, tanTuple); + tanElts.append(tanDti->getResults().begin(), tanDti->getResults().end()); + } else { + tanElts.push_back(tanTuple); + } + unsigned tanIdx = 0; + for (auto i : range(dti->getNumResults())) { + auto origElt = dti->getResult(i); + if (!getTangentSpace(origElt->getType().getASTType())) + continue; + setTangentValue(bb, origElt, makeConcreteTangentValue(tanElts[tanIdx++])); + } } -} #undef CLONE_AND_EMIT_TANGENT -/// Handle `apply` instruction. -/// Original: y = apply f(x0, x1, ...) -/// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...) 
-void JVPCloner::emitTangentForApplyInst( - ApplyInst *ai, SILAutoDiffIndices applyIndices, - CanSILFunctionType originalDifferentialType) { - assert(differentialInfo.shouldDifferentiateApplySite(ai)); - auto *bb = ai->getParent(); - auto loc = ai->getLoc(); - auto &diffBuilder = getDifferentialBuilder(); - - // Get the differential value. - auto *field = differentialInfo.lookUpLinearMapDecl(ai); - assert(field); - SILValue differential = getDifferentialStructElement(bb, field); - auto differentialType = remapSILTypeInDifferential(differential->getType()) - .castTo(); - - // Get the differential arguments. - SmallVector diffArgs; - - for (auto indRes : ai->getIndirectSILResults()) - diffArgs.push_back(getTangentBuffer(bb, indRes)); - - auto paramArgs = ai->getArgumentsWithoutIndirectResults(); - // Get the tangent value of the original arguments. - for (auto i : indices(paramArgs)) { - auto origArg = paramArgs[i]; - // If the argument is not active: - // - Skip the element, if it is not differentiable. - // - Otherwise, add a zero value to that location. - if (!activityInfo.isActive(origArg, getIndices())) { - auto origCalleeType = ai->getSubstCalleeType(); - if (!origCalleeType->isDifferentiable()) - continue; - auto actualOrigCalleeIndices = - origCalleeType->getDifferentiabilityParameterIndices(); - if (actualOrigCalleeIndices->contains(i)) { + /// Handle `apply` instruction, given: + /// - The minimal indices for differentiating the `apply`. + /// - The original non-reabstracted differential type. + /// + /// Original: y = apply f(x0, x1, ...) + /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...) + void emitTangentForApplyInst(ApplyInst *ai, SILAutoDiffIndices applyIndices, + CanSILFunctionType originalDifferentialType) { + assert(differentialInfo.shouldDifferentiateApplySite(ai)); + auto *bb = ai->getParent(); + auto loc = ai->getLoc(); + auto &diffBuilder = getDifferentialBuilder(); + + // Get the differential value. 
+ auto *field = differentialInfo.lookUpLinearMapDecl(ai); + assert(field); + SILValue differential = getDifferentialStructElement(bb, field); + auto differentialType = remapSILTypeInDifferential(differential->getType()) + .castTo(); + + // Get the differential arguments. + SmallVector diffArgs; + + for (auto indRes : ai->getIndirectSILResults()) + diffArgs.push_back(getTangentBuffer(bb, indRes)); + + auto paramArgs = ai->getArgumentsWithoutIndirectResults(); + // Get the tangent value of the original arguments. + for (auto i : indices(paramArgs)) { + auto origArg = paramArgs[i]; + // If the argument is not active: + // - Skip the element, if it is not differentiable. + // - Otherwise, add a zero value to that location. + if (!activityInfo.isActive(origArg, getIndices())) { + auto origCalleeType = ai->getSubstCalleeType(); + if (!origCalleeType->isDifferentiable()) + continue; + auto actualOrigCalleeIndices = + origCalleeType->getDifferentiabilityParameterIndices(); + if (actualOrigCalleeIndices->contains(i)) { + SILValue tanParam; + if (origArg->getType().isObject()) { + tanParam = emitZeroDirect( + getRemappedTangentType(origArg->getType()).getASTType(), loc); + diffArgs.push_back(tanParam); + } else { + tanParam = diffBuilder.createAllocStack( + loc, getRemappedTangentType(origArg->getType())); + emitZeroIndirect( + getRemappedTangentType(origArg->getType()).getASTType(), + tanParam, loc); + } + } + } + // Otherwise, if the argument is active, handle the argument normally by + // getting its tangent value. 
+ else { SILValue tanParam; if (origArg->getType().isObject()) { - tanParam = emitZeroDirect( - getRemappedTangentType(origArg->getType()).getASTType(), loc); - diffArgs.push_back(tanParam); + tanParam = materializeTangent(getTangentValue(origArg), loc); } else { - tanParam = diffBuilder.createAllocStack( - loc, getRemappedTangentType(origArg->getType())); - emitZeroIndirect( - getRemappedTangentType(origArg->getType()).getASTType(), tanParam, - loc); + tanParam = getTangentBuffer(ai->getParent(), origArg); } + diffArgs.push_back(tanParam); + if (errorOccurred) + return; } } - // Otherwise, if the argument is active, handle the argument normally by - // getting its tangent value. - else { - SILValue tanParam; - if (origArg->getType().isObject()) { - tanParam = materializeTangent(getTangentValue(origArg), loc); - } else { - tanParam = getTangentBuffer(ai->getParent(), origArg); - } - diffArgs.push_back(tanParam); - if (errorOccurred) - return; + + // If callee differential was reabstracted in JVP, reabstract the callee + // differential. + if (!differentialType->isEqual(originalDifferentialType)) { + SILOptFunctionBuilder fb(context.getTransform()); + differential = reabstractFunction( + diffBuilder, fb, loc, differential, originalDifferentialType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->getOpSubstitutionMap(subs); + }); } - } - // If callee differential was reabstracted in JVP, reabstract the callee - // differential. - if (!differentialType->isEqual(originalDifferentialType)) { - SILOptFunctionBuilder fb(context.getTransform()); - differential = reabstractFunction( - diffBuilder, fb, loc, differential, originalDifferentialType, - [this](SubstitutionMap subs) -> SubstitutionMap { - return this->getOpSubstitutionMap(subs); - }); - } - - // Call the differential. 
- auto *differentialCall = - diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs, - /*isNonThrowing*/ false); - diffBuilder.emitDestroyValueOperation(loc, differential); - - // Get the original `apply` results. - SmallVector origDirectResults; - forEachApplyDirectResult(ai, [&](SILValue directResult) { - origDirectResults.push_back(directResult); - }); - SmallVector origAllResults; - collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - - // Get the callee differential `apply` results. - SmallVector differentialDirectResults; - extractAllElements(differentialCall, getDifferentialBuilder(), - differentialDirectResults); - SmallVector differentialAllResults; - collectAllActualResultsInTypeOrder( - differentialCall, differentialDirectResults, differentialAllResults); - assert(applyIndices.results->getNumIndices() == - differentialAllResults.size()); - - // Set tangent values for original `apply` results. - unsigned differentialResultIndex = 0; - for (auto resultIndex : applyIndices.results->getIndices()) { - auto origResult = origAllResults[resultIndex]; - auto differentialResult = differentialAllResults[differentialResultIndex++]; - if (origResult->getType().isObject()) { - if (!origResult->getType().is()) { - setTangentValue(bb, origResult, - makeConcreteTangentValue(differentialResult)); - } else if (auto *dti = getSingleDestructureTupleUser(ai)) { - bool notSetValue = true; - for (auto result : dti->getResults()) { - if (activityInfo.isActive(result, getIndices())) { - assert(notSetValue && - "This was incorrectly set, should only have one active " - "result from the tuple."); - notSetValue = false; - setTangentValue(bb, result, - makeConcreteTangentValue(differentialResult)); + // Call the differential. 
+ auto *differentialCall = + diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs, + /*isNonThrowing*/ false); + diffBuilder.emitDestroyValueOperation(loc, differential); + + // Get the original `apply` results. + SmallVector origDirectResults; + forEachApplyDirectResult(ai, [&](SILValue directResult) { + origDirectResults.push_back(directResult); + }); + SmallVector origAllResults; + collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); + + // Get the callee differential `apply` results. + SmallVector differentialDirectResults; + extractAllElements(differentialCall, getDifferentialBuilder(), + differentialDirectResults); + SmallVector differentialAllResults; + collectAllActualResultsInTypeOrder( + differentialCall, differentialDirectResults, differentialAllResults); + assert(applyIndices.results->getNumIndices() == + differentialAllResults.size()); + + // Set tangent values for original `apply` results. + unsigned differentialResultIndex = 0; + for (auto resultIndex : applyIndices.results->getIndices()) { + auto origResult = origAllResults[resultIndex]; + auto differentialResult = + differentialAllResults[differentialResultIndex++]; + if (origResult->getType().isObject()) { + if (!origResult->getType().is()) { + setTangentValue(bb, origResult, + makeConcreteTangentValue(differentialResult)); + } else if (auto *dti = getSingleDestructureTupleUser(ai)) { + bool notSetValue = true; + for (auto result : dti->getResults()) { + if (activityInfo.isActive(result, getIndices())) { + assert(notSetValue && + "This was incorrectly set, should only have one active " + "result from the tuple."); + notSetValue = false; + setTangentValue(bb, result, + makeConcreteTangentValue(differentialResult)); + } } } } } } + + /// Generate a `return` instruction in the current differential basic block. 
+ void emitReturnInstForDifferential() { + auto &differential = getDifferential(); + auto diffLoc = differential.getLocation(); + auto &diffBuilder = getDifferentialBuilder(); + + // Collect original results. + SmallVector originalResults; + collectAllDirectResultsInTypeOrder(*original, originalResults); + // Collect differential return elements. + SmallVector retElts; + // for (auto origResult : originalResults) { + for (auto i : range(originalResults.size())) { + auto origResult = originalResults[i]; + if (!getIndices().results->contains(i)) + continue; + auto tanVal = materializeTangent(getTangentValue(origResult), diffLoc); + retElts.push_back(tanVal); + } + + diffBuilder.createReturn(diffLoc, + joinElements(retElts, diffBuilder, diffLoc)); + } +}; + +//--------------------------------------------------------------------------// +// Initialization +//--------------------------------------------------------------------------// + +/// Initialization helper function. +/// +/// Returns the substitution map used for type remapping. +static SubstitutionMap getSubstitutionMap(SILFunction *original, + SILFunction *jvp) { + auto substMap = original->getForwardingSubstitutionMap(); + if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { + auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); + substMap = SubstitutionMap::get( + jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, + LookUpConformanceInSubstitutionMap(jvpSubstMap)); + } + return substMap; } -/// Generate a `return` instruction in the current differential basic block. -void JVPCloner::emitReturnInstForDifferential() { - auto &differential = getDifferential(); - auto diffLoc = differential.getLocation(); - auto &diffBuilder = getDifferentialBuilder(); - - // Collect original results. - SmallVector originalResults; - collectAllDirectResultsInTypeOrder(*original, originalResults); - // Collect differential return elements. 
- SmallVector retElts; - // for (auto origResult : originalResults) { - for (auto i : range(originalResults.size())) { - auto origResult = originalResults[i]; - if (!getIndices().results->contains(i)) - continue; - auto tanVal = materializeTangent(getTangentValue(origResult), diffLoc); - retElts.push_back(tanVal); +/// Initialization helper function. +/// +/// Returns the activity info for the given original function, autodiff indices, +/// and JVP generic signature. +static const DifferentiableActivityInfo & +getActivityInfo(ADContext &context, SILFunction *original, + SILAutoDiffIndices indices, SILFunction *jvp) { + // Get activity info of the original function. + auto &passManager = context.getPassManager(); + auto *activityAnalysis = + passManager.getAnalysis(); + auto &activityCollection = *activityAnalysis->get(original); + auto &activityInfo = activityCollection.getActivityInfo( + jvp->getLoweredFunctionType()->getSubstGenericSignature(), + AutoDiffDerivativeFunctionKind::JVP); + LLVM_DEBUG(activityInfo.dump(indices, getADDebugStream())); + return activityInfo; +} + +JVPCloner::Implementation::Implementation(ADContext &context, + SILFunction *original, + SILDifferentiabilityWitness *witness, + SILFunction *jvp, + DifferentiationInvoker invoker) + : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), + context(context), original(original), witness(witness), jvp(jvp), + invoker(invoker), + activityInfo(getActivityInfo(context, original, + witness->getSILAutoDiffIndices(), jvp)), + differentialInfo(context, AutoDiffLinearMapKind::Differential, original, + jvp, witness->getSILAutoDiffIndices(), activityInfo), + differentialBuilder(SILBuilder( + *createEmptyDifferential(context, witness, &differentialInfo))), + diffLocalAllocBuilder(getDifferential()) { + // Create empty differential function. 
+ context.recordGeneratedFunction(&getDifferential()); +} + +JVPCloner::JVPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *jvp, + DifferentiationInvoker invoker) + : impl(*new Implementation(context, original, witness, jvp, invoker)) {} + +JVPCloner::~JVPCloner() { delete &impl; } + +//--------------------------------------------------------------------------// +// Differential struct mapping +//--------------------------------------------------------------------------// + +void JVPCloner::Implementation::initializeDifferentialStructElements( + SILBasicBlock *origBB, SILInstructionResultArray values) { + auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); + assert(diffStructDecl->getStoredProperties().size() == values.size() && + "The number of differential struct fields must equal the number of " + "differential struct element values"); + for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) { + assert(std::get<1>(pair).getOwnershipKind() != + ValueOwnershipKind::Guaranteed && + "Differential struct elements must be @owned"); + auto insertion = differentialStructElements.insert( + {std::get<0>(pair), std::get<1>(pair)}); + (void)insertion; + assert(insertion.second && + "A differential struct element mapping already exists!"); } +} - diffBuilder.createReturn(diffLoc, - joinElements(retElts, diffBuilder, diffLoc)); +SILValue +JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB, + VarDecl *field) { + assert(differentialInfo.getLinearMapStruct(origBB) == + cast(field->getDeclContext())); + assert(differentialStructElements.count(field) && + "Differential struct element for this field does not exist!"); + return differentialStructElements.lookup(field); } -void JVPCloner::prepareForDifferentialGeneration() { +//--------------------------------------------------------------------------// +// Tangent emission helpers 
+//--------------------------------------------------------------------------// + +void JVPCloner::Implementation::prepareForDifferentialGeneration() { // Create differential blocks and arguments. auto &differential = getDifferential(); auto *origEntry = original->getEntryBlock(); @@ -957,10 +1487,9 @@ void JVPCloner::prepareForDifferentialGeneration() { setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]); } -/*static*/ SILFunction * -JVPCloner::createEmptyDifferential(ADContext &context, - SILDifferentiabilityWitness *witness, - LinearMapInfo *linearMapInfo) { +/*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential( + ADContext &context, SILDifferentiabilityWitness *witness, + LinearMapInfo *linearMapInfo) { auto &module = context.getModule(); auto *original = witness->getOriginalFunction(); auto *jvp = witness->getJVP(); @@ -1067,8 +1596,7 @@ JVPCloner::createEmptyDifferential(ADContext &context, return differential; } -/// Run JVP generation. Returns true on error. -bool JVPCloner::run() { +bool JVPCloner::Implementation::run() { PrettyStackTraceSILFunction trace("generating JVP and differential for", original); LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() @@ -1094,363 +1622,7 @@ bool JVPCloner::run() { return errorOccurred; } -void JVPCloner::postProcess(SILInstruction *orig, SILInstruction *cloned) { - if (errorOccurred) - return; - SILClonerWithScopes::postProcess(orig, cloned); -} - -/// Remap original basic blocks. 
-SILBasicBlock *JVPCloner::remapBasicBlock(SILBasicBlock *bb) { - auto *jvpBB = BBMap[bb]; - return jvpBB; -} - -void JVPCloner::visit(SILInstruction *inst) { - if (errorOccurred) - return; - if (differentialInfo.shouldDifferentiateInstruction(inst)) { - LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst); -#ifndef NDEBUG - auto diffBuilder = getDifferentialBuilder(); - auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); -#endif - TypeSubstCloner::visit(inst); - LLVM_DEBUG({ - auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n"; - auto afterInsertion = diffBuilder.getInsertionPoint(); - for (auto it = ++beforeInsertion; it != afterInsertion; ++it) - s << *it; - }); - } else { - TypeSubstCloner::visit(inst); - } -} - -void JVPCloner::visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError( - inst, invoker, diag::autodiff_expression_not_differentiable_note); - errorOccurred = true; -} - -void JVPCloner::visitInstructionsInBlock(SILBasicBlock *bb) { - // Destructure the differential struct to get the elements. - auto &diffBuilder = getDifferentialBuilder(); - auto diffLoc = getDifferential().getLocation(); - auto *diffBB = diffBBMap.lookup(bb); - auto *mainDifferentialStruct = diffBB->getArguments().back(); - diffBuilder.setInsertionPoint(diffBB); - auto *dsi = - diffBuilder.createDestructureStruct(diffLoc, mainDifferentialStruct); - initializeDifferentialStructElements(bb, dsi->getResults()); - TypeSubstCloner::visitInstructionsInBlock(bb); -} - -// If an `apply` has active results or active inout parameters, replace it -// with an `apply` of its JVP. -void JVPCloner::visitApplyInst(ApplyInst *ai) { - // If the function should not be differentiated or its the array literal - // initialization intrinsic, just do standard cloning. 
- if (!differentialInfo.shouldDifferentiateApplySite(ai) || - ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { - LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); - TypeSubstCloner::visitApplyInst(ai); - return; - } - - // Diagnose functions with active inout arguments. - // TODO(TF-129): Support `inout` argument differentiation. - for (auto inoutArg : ai->getInoutArguments()) { - if (activityInfo.isActive(inoutArg, getIndices())) { - context.emitNondifferentiabilityError( - ai, invoker, - diag::autodiff_cannot_differentiate_through_inout_arguments); - errorOccurred = true; - return; - } - } - - auto loc = ai->getLoc(); - auto &builder = getBuilder(); - auto origCallee = getOpValue(ai->getCallee()); - auto originalFnTy = origCallee->getType().castTo(); - - LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); - - // Get the minimal parameter and result indices required for differentiating - // this `apply`. - SmallVector allResults; - SmallVector activeParamIndices; - SmallVector activeResultIndices; - collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, - allResults, activeParamIndices, - activeResultIndices); - assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); - assert(!activeResultIndices.empty() && "Result indices cannot be empty"); - LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; - llvm::interleave( - activeParamIndices.begin(), activeParamIndices.end(), - [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}, results={"; llvm::interleave( - activeResultIndices.begin(), activeResultIndices.end(), - [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}\n";); - - // Form expected indices. 
- auto numResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); - SILAutoDiffIndices indices( - IndexSubset::get(getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), - activeParamIndices), - IndexSubset::get(getASTContext(), numResults, activeResultIndices)); - - // Emit the JVP. - SILValue jvpValue; - // If functionSource is a `@differentiable` function, just extract it. - if (originalFnTy->isDifferentiable()) { - auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); - for (auto i : indices.parameters->getIndices()) { - if (!paramIndices->contains(i)) { - context.emitNondifferentiabilityError( - origCallee, invoker, - diag::autodiff_function_noderivative_parameter_not_differentiable); - errorOccurred = true; - return; - } - } - auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee); - jvpValue = builder.createDifferentiableFunctionExtract( - loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedDiffFunc); - jvpValue = builder.emitCopyValueOperation(loc, jvpValue); - } - - // If JVP has not yet been found, emit an `differentiable_function` - // instruction on the remapped function operand and - // an `differentiable_function_extract` instruction to get the JVP. - // The `differentiable_function` instruction will be canonicalized during - // the transform main loop. - if (!jvpValue) { - // FIXME: Handle indirect differentiation invokers. This may require some - // redesign: currently, each original function + witness pair is mapped - // only to one invoker. - /* - DifferentiationInvoker indirect(ai, attr); - auto insertion = - context.getInvokers().try_emplace({original, attr}, indirect); - auto &invoker = insertion.first->getSecond(); - invoker = indirect; - */ - - // If the original `apply` instruction has a substitution map, then the - // applied function is specialized. - // In the JVP, specialization is also necessary for parity. 
The original - // function operand is specialized with a remapped version of same - // substitution map using an argument-less `partial_apply`. - if (ai->getSubstitutionMap().empty()) { - origCallee = builder.emitCopyValueOperation(loc, origCallee); - } else { - auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); - auto jvpPartialApply = getBuilder().createPartialApply( - ai->getLoc(), origCallee, substMap, {}, - ParameterConvention::Direct_Guaranteed); - origCallee = jvpPartialApply; - } - - // Check and diagnose non-differentiable original function type. - auto diagnoseNondifferentiableOriginalFunctionType = - [&](CanSILFunctionType origFnTy) { - // Check and diagnose non-differentiable arguments. - for (auto paramIndex : indices.parameters->getIndices()) { - if (!originalFnTy->getParameters()[paramIndex] - .getSILStorageInterfaceType() - .isDifferentiable(getModule())) { - context.emitNondifferentiabilityError( - ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker, - diag::autodiff_nondifferentiable_argument); - errorOccurred = true; - return true; - } - } - // Check and diagnose non-differentiable results. 
- for (auto resultIndex : indices.results->getIndices()) { - SILType remappedResultType; - if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); - } else { - remappedResultType = originalFnTy->getResults()[resultIndex] - .getSILStorageInterfaceType(); - } - if (!remappedResultType.isDifferentiable(getModule())) { - context.emitNondifferentiabilityError( - origCallee, invoker, diag::autodiff_nondifferentiable_result); - errorOccurred = true; - return true; - } - } - return false; - }; - if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) - return; - - auto *diffFuncInst = context.createDifferentiableFunction( - builder, loc, indices.parameters, indices.results, origCallee); - - // Record the `differentiable_function` instruction. - context.addDifferentiableFunctionInstToWorklist(diffFuncInst); - - auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); - auto extractedJVP = builder.createDifferentiableFunctionExtract( - loc, NormalDifferentiableFunctionTypeComponent::JVP, borrowedADFunc); - jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); - builder.emitEndBorrowOperation(loc, borrowedADFunc); - builder.emitDestroyValueOperation(loc, diffFuncInst); - } - - // Call the JVP using the original parameters. - SmallVector jvpArgs; - auto jvpFnTy = getOpType(jvpValue->getType()).castTo(); - auto numJVPArgs = - jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); - jvpArgs.reserve(numJVPArgs); - // Collect substituted arguments. - for (auto origArg : ai->getArguments()) - jvpArgs.push_back(getOpValue(origArg)); - assert(jvpArgs.size() == numJVPArgs); - // Apply the JVP. - // The JVP should be specialized, so no substitution map is necessary. 
- auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(), - jvpArgs, ai->isNonThrowing()); - LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); - - // Release the differentiable function. - builder.emitDestroyValueOperation(loc, jvpValue); - - // Get the JVP results (original results and differential). - SmallVector jvpDirectResults; - extractAllElements(jvpCall, builder, jvpDirectResults); - auto originalDirectResults = - ArrayRef(jvpDirectResults).drop_back(1); - auto originalDirectResult = - joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc()); - - mapValue(ai, originalDirectResult); - - // Some instructions that produce the callee may have been cloned. - // If the original callee did not have any users beyond this `apply`, - // recursively kill the cloned callee. - if (auto *origCallee = cast_or_null( - ai->getCallee()->getDefiningInstruction())) - if (origCallee->hasOneUse()) - recursivelyDeleteTriviallyDeadInstructions( - getOpValue(origCallee)->getDefiningInstruction()); - - // Add the differential function for when we create the struct we partially - // apply to the differential we are generating. - auto differential = jvpDirectResults.back(); - auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); - auto originalDifferentialType = - getOpType(differential->getType()).getAs(); - auto loweredDifferentialType = - getOpType(getLoweredType(differentialDecl->getInterfaceType())) - .castTo(); - // If actual differential type does not match lowered differential type, - // reabstract the differential using a thunk. 
- if (!loweredDifferentialType->isEqual(originalDifferentialType)) { - SILOptFunctionBuilder fb(context.getTransform()); - differential = reabstractFunction( - builder, fb, loc, differential, loweredDifferentialType, - [this](SubstitutionMap subs) -> SubstitutionMap { - return this->getOpSubstitutionMap(subs); - }); - } - differentialValues[ai->getParent()].push_back(differential); - - // Differential emission. - emitTangentForApplyInst(ai, indices, originalDifferentialType); -} - -void JVPCloner::visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); - auto &builder = getBuilder(); - auto *diffStructVal = buildDifferentialValueStructValue(ri); - - // Get the JVP value corresponding to the original functions's return value. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Get and partially apply the differential. - auto jvpGenericEnv = jvp->getGenericEnvironment(); - auto jvpSubstMap = jvpGenericEnv - ? 
jvpGenericEnv->getForwardingSubstitutionMap() - : jvp->getForwardingSubstitutionMap(); - auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); - auto *differentialPartialApply = builder.createPartialApply( - loc, differentialRef, jvpSubstMap, {diffStructVal}, - ParameterConvention::Direct_Guaranteed); - - auto differentialType = jvp->getLoweredFunctionType() - ->getResults() - .back() - .getSILStorageInterfaceType(); - differentialType = differentialType.substGenericArgs( - getModule(), jvpSubstMap, TypeExpansionContext::minimal()); - differentialType = differentialType.subst(getModule(), jvpSubstMap); - auto differentialFnType = differentialType.castTo(); - - auto differentialSubstType = - differentialPartialApply->getType().castTo(); - SILValue differentialValue; - if (differentialSubstType == differentialFnType) { - differentialValue = differentialPartialApply; - } else if (differentialSubstType - ->isABICompatibleWith(differentialFnType, *jvp) - .isCompatible()) { - differentialValue = builder.createConvertFunction( - loc, differentialPartialApply, differentialType, - /*withoutActuallyEscaping*/ false); - } else { - // When `diag::autodiff_loadable_value_addressonly_tangent_unsupported` - // applies, the return type may be ABI-incomaptible with the type of the - // partially applied differential. In these cases, produce an undef and rely - // on other code to emit a diagnostic. - differentialValue = SILUndef::get(differentialType, *jvp); - } - - // Return a tuple of the original result and differential. 
- SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - directResults.push_back(differentialValue); - builder.createReturn(ri->getLoc(), joinElements(directResults, builder, loc)); -} - -void JVPCloner::visitBranchInst(BranchInst *bi) { - llvm_unreachable("Unsupported SIL instruction."); -} - -void JVPCloner::visitCondBranchInst(CondBranchInst *cbi) { - llvm_unreachable("Unsupported SIL instruction."); -} - -void JVPCloner::visitSwitchEnumInst(SwitchEnumInst *sei) { - llvm_unreachable("Unsupported SIL instruction."); -} - -void JVPCloner::visitDifferentiableFunctionInst( - DifferentiableFunctionInst *dfi) { - // Clone `differentiable_function` from original to JVP, then add the cloned - // instruction to the `differentiable_function` worklist. - TypeSubstCloner::visitDifferentiableFunctionInst(dfi); - auto *newDFI = cast(getOpValue(dfi)); - context.addDifferentiableFunctionInstToWorklist(newDFI); -} +bool JVPCloner::run() { return impl.run(); } } // end namespace autodiff } // end namespace swift diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 73f5efba6666b..cc2f80b8c44ae 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -18,7 +18,11 @@ #define DEBUG_TYPE "differentiation" #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" +#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" +#include "swift/SILOptimizer/Differentiation/AdjointValue.h" +#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" +#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" #include "swift/SILOptimizer/Differentiation/VJPCloner.h" @@ -27,8 +31,10 @@ #include "swift/AST/TypeCheckRequests.h" #include "swift/SIL/InstructionUtils.h" #include 
"swift/SIL/Projection.h" +#include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" +#include "llvm/ADT/DenseMap.h" namespace swift { @@ -49,8 +55,12 @@ class VJPCloner; /// instructions per basic block in reverse order. This visitation order is /// necessary for generating pullback functions, whose control flow graph is /// ~a transposed version of the original function's control flow graph. -struct PullbackCloner::Implementation final +class PullbackCloner::Implementation final : public SILInstructionVisitor { + +public: + explicit Implementation(VJPCloner &vjpCloner); + private: /// The parent VJP cloner. VJPCloner &vjpCloner; @@ -124,23 +134,21 @@ struct PullbackCloner::Implementation final bool errorOccurred = false; - ADContext &getContext() const { return vjpCloner.context; } + ADContext &getContext() const { return vjpCloner.getContext(); } SILModule &getModule() const { return getContext().getModule(); } ASTContext &getASTContext() const { return getPullback().getASTContext(); } - SILFunction &getOriginal() const { return *vjpCloner.original; } - SILFunction &getPullback() const { return *vjpCloner.pullback; } - SILDifferentiabilityWitness *getWitness() const { return vjpCloner.witness; } - DifferentiationInvoker getInvoker() const { return vjpCloner.invoker; } - LinearMapInfo &getPullbackInfo() { return vjpCloner.pullbackInfo; } + SILFunction &getOriginal() const { return vjpCloner.getOriginal(); } + SILFunction &getPullback() const { return vjpCloner.getPullback(); } + SILDifferentiabilityWitness *getWitness() const { + return vjpCloner.getWitness(); + } + DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); } + LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); } const SILAutoDiffIndices getIndices() const { return vjpCloner.getIndices(); } const DifferentiableActivityInfo &getActivityInfo() const { - 
return vjpCloner.activityInfo; + return vjpCloner.getActivityInfo(); } -public: - explicit Implementation(VJPCloner &vjpCloner); - -private: //--------------------------------------------------------------------------// // Pullback struct mapping //--------------------------------------------------------------------------// @@ -1520,9 +1528,10 @@ PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner) auto *domAnalysis = passManager.getAnalysis(); auto *postDomAnalysis = passManager.getAnalysis(); auto *postOrderAnalysis = passManager.getAnalysis(); - domInfo = domAnalysis->get(vjpCloner.original); - postDomInfo = postDomAnalysis->get(vjpCloner.original); - postOrderInfo = postOrderAnalysis->get(vjpCloner.original); + auto *original = &vjpCloner.getOriginal(); + domInfo = domAnalysis->get(original); + postDomInfo = postDomAnalysis->get(original); + postOrderInfo = postOrderAnalysis->get(original); } PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index cdaa189bf382f..933b706baff57 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -18,20 +18,620 @@ #define DEBUG_TYPE "differentiation" #include "swift/SILOptimizer/Differentiation/VJPCloner.h" +#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" +#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" +#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" +#include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" +#include "llvm/ADT/DenseMap.h" namespace swift { 
namespace autodiff { -/*static*/ -SubstitutionMap VJPCloner::getSubstitutionMap(SILFunction *original, - SILFunction *vjp) { +class VJPCloner::Implementation final + : public TypeSubstCloner { + friend class VJPCloner; + friend class PullbackCloner; + + /// The parent VJP cloner. + VJPCloner &cloner; + + /// The global context. + ADContext &context; + + /// The original function. + SILFunction *const original; + + /// The differentiability witness. + SILDifferentiabilityWitness *const witness; + + /// The VJP function. + SILFunction *const vjp; + + /// The pullback function. + SILFunction *pullback; + + /// The differentiation invoker. + DifferentiationInvoker invoker; + + /// Info from activity analysis on the original function. + const DifferentiableActivityInfo &activityInfo; + + /// The linear map info. + LinearMapInfo pullbackInfo; + + /// Caches basic blocks whose phi arguments have been remapped (adding a + /// predecessor enum argument). + SmallPtrSet remappedBasicBlocks; + + bool errorOccurred = false; + + /// Mapping from original blocks to pullback values. Used to build pullback + /// struct instances. + llvm::DenseMap> pullbackValues; + + ASTContext &getASTContext() const { return vjp->getASTContext(); } + SILModule &getModule() const { return vjp->getModule(); } + const SILAutoDiffIndices getIndices() const { + return witness->getSILAutoDiffIndices(); + } + + Implementation(VJPCloner &parent, ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *vjp, + DifferentiationInvoker invoker); + + /// Creates an empty pullback function, to be filled in by `PullbackCloner`. + SILFunction *createEmptyPullback(); + + /// Run VJP generation. Returns true on error. + bool run(); + + /// Get the lowered SIL type of the given AST type. 
+ SILType getLoweredType(Type type) { + auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); + Lowering::AbstractionPattern pattern(vjpGenSig, + type->getCanonicalType(vjpGenSig)); + return vjp->getLoweredType(pattern, type); + } + + /// Get the lowered SIL type of the given nominal type declaration. + SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { + auto nominalType = + getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); + return getLoweredType(nominalType); + } + + // Creates a trampoline block for given original terminator instruction, the + // pullback struct value for its parent block, and a successor basic block. + // + // The trampoline block has the same arguments as and branches to the remapped + // successor block, but drops the last predecessor enum argument. + // + // Used for cloning branching terminator instructions with specific + // requirements on successor block arguments, where an additional predecessor + // enum argument is not acceptable. + SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, + StructInst *pbStructVal, + SILBasicBlock *succBB); + + /// Build a pullback struct value for the given original terminator + /// instruction. + StructInst *buildPullbackValueStructValue(TermInst *termInst); + + /// Build a predecessor enum instance using the given builder for the given + /// original predecessor/successor blocks and pullback struct value. + EnumInst *buildPredecessorEnumValue(SILBuilder &builder, + SILBasicBlock *predBB, + SILBasicBlock *succBB, + SILValue pbStructVal); + +public: + /// Remap original basic blocks, adding predecessor enum arguments. + SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { + auto *vjpBB = BBMap[bb]; + // If error has occurred, or if block has already been remapped, return + // remapped, return remapped block. + if (errorOccurred || remappedBasicBlocks.count(bb)) + return vjpBB; + // Add predecessor enum argument to the remapped block. 
+ auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb); + auto enumTy = + getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType()); + auto enumLoweredTy = context.getTypeConverter().getLoweredType( + enumTy, TypeExpansionContext::minimal()); + vjpBB->createPhiArgument(enumLoweredTy, ValueOwnershipKind::Owned); + remappedBasicBlocks.insert(bb); + return vjpBB; + } + + /// General visitor for all instructions. If any error is emitted by previous + /// visits, bail out. + void visit(SILInstruction *inst) { + if (errorOccurred) + return; + TypeSubstCloner::visit(inst); + } + + void visitSILInstruction(SILInstruction *inst) { + context.emitNondifferentiabilityError( + inst, invoker, diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + } + + void postProcess(SILInstruction *orig, SILInstruction *cloned) { + if (errorOccurred) + return; + SILClonerWithScopes::postProcess(orig, cloned); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto &builder = getBuilder(); + + // Build pullback struct value for original block. + auto *origExit = ri->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(ri); + + // Get the value in the VJP corresponding to the original result. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the pullback. + auto vjpGenericEnv = vjp->getGenericEnvironment(); + auto vjpSubstMap = vjpGenericEnv + ? 
vjpGenericEnv->getForwardingSubstitutionMap() + : vjp->getForwardingSubstitutionMap(); + auto *pullbackRef = builder.createFunctionRef(loc, pullback); + auto *pullbackPartialApply = + builder.createPartialApply(loc, pullbackRef, vjpSubstMap, {pbStructVal}, + ParameterConvention::Direct_Guaranteed); + auto pullbackType = vjp->getLoweredFunctionType() + ->getResults() + .back() + .getSILStorageInterfaceType(); + pullbackType = pullbackType.substGenericArgs( + getModule(), vjpSubstMap, TypeExpansionContext::minimal()); + pullbackType = pullbackType.subst(getModule(), vjpSubstMap); + auto pullbackFnType = pullbackType.castTo(); + + auto pullbackSubstType = + pullbackPartialApply->getType().castTo(); + SILValue pullbackValue; + if (pullbackSubstType == pullbackFnType) { + pullbackValue = pullbackPartialApply; + } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp) + .isCompatible()) { + pullbackValue = + builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, + /*withoutActuallyEscaping*/ false); + } else { + // When `diag::autodiff_loadable_value_addressonly_tangent_unsupported` + // applies, the return type may be ABI-incomaptible with the type of the + // partially applied pullback. In these cases, produce an undef and rely + // on other code to emit a diagnostic. + pullbackValue = SILUndef::get(pullbackType, *vjp); + } + + // Return a tuple of the original result and pullback. + SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(pullbackValue); + builder.createReturn(ri->getLoc(), + joinElements(directResults, builder, loc)); + } + + void visitBranchInst(BranchInst *bi) { + // Build pullback struct value for original block. + // Build predecessor enum value for destination block. 
+ auto *origBB = bi->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(bi); + auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, + bi->getDestBB(), pbStructVal); + + // Remap arguments, appending the new enum values. + SmallVector args; + for (auto origArg : bi->getArgs()) + args.push_back(getOpValue(origArg)); + args.push_back(enumVal); + + // Create a new `br` instruction. + getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(bi->getDestBB()), + args); + } + + void visitCondBranchInst(CondBranchInst *cbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(cbi); + // Create a new `cond_br` instruction. + getBuilder().createCondBranch( + cbi->getLoc(), getOpValue(cbi->getCondition()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()), + createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB())); + } + + void visitSwitchEnumInstBase(SwitchEnumInstBase *inst) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(inst); + + // Create trampoline successor basic blocks. + SmallVector, 4> caseBBs; + for (unsigned i : range(inst->getNumCases())) { + auto caseBB = inst->getCase(i); + auto *trampolineBB = + createTrampolineBasicBlock(inst, pbStructVal, caseBB.second); + caseBBs.push_back({caseBB.first, trampolineBB}); + } + // Create trampoline default basic block. + SILBasicBlock *newDefaultBB = nullptr; + if (auto *defaultBB = inst->getDefaultBBOrNull().getPtrOrNull()) + newDefaultBB = createTrampolineBasicBlock(inst, pbStructVal, defaultBB); + + // Create a new `switch_enum` instruction. 
+ switch (inst->getKind()) { + case SILInstructionKind::SwitchEnumInst: + getBuilder().createSwitchEnum(inst->getLoc(), + getOpValue(inst->getOperand()), + newDefaultBB, caseBBs); + break; + case SILInstructionKind::SwitchEnumAddrInst: + getBuilder().createSwitchEnumAddr(inst->getLoc(), + getOpValue(inst->getOperand()), + newDefaultBB, caseBBs); + break; + default: + llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`"); + } + } + + void visitSwitchEnumInst(SwitchEnumInst *sei) { + visitSwitchEnumInstBase(sei); + } + + void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { + visitSwitchEnumInstBase(seai); + } + + void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccbi); + // Create a new `checked_cast_branch` instruction. + getBuilder().createCheckedCastBranch( + ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), + getOpType(ccbi->getTargetLoweredType()), + getOpASTType(ccbi->getTargetFormalType()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()), + createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()), + ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); + } + + void visitCheckedCastValueBranchInst(CheckedCastValueBranchInst *ccvbi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccvbi); + // Create a new `checked_cast_value_branch` instruction. 
+ getBuilder().createCheckedCastValueBranch( + ccvbi->getLoc(), getOpValue(ccvbi->getOperand()), + getOpASTType(ccvbi->getSourceFormalType()), + getOpType(ccvbi->getTargetLoweredType()), + getOpASTType(ccvbi->getTargetFormalType()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()), + createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB())); + } + + void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi) { + // Build pullback struct value for original block. + auto *pbStructVal = buildPullbackValueStructValue(ccabi); + // Create a new `checked_cast_addr_branch` instruction. + getBuilder().createCheckedCastAddrBranch( + ccabi->getLoc(), ccabi->getConsumptionKind(), + getOpValue(ccabi->getSrc()), getOpASTType(ccabi->getSourceFormalType()), + getOpValue(ccabi->getDest()), + getOpASTType(ccabi->getTargetFormalType()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()), + createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()), + ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); + } + + // If an `apply` has active results or active inout arguments, replace it + // with an `apply` of its VJP. + void visitApplyInst(ApplyInst *ai) { + // If callee should not be differentiated, do standard cloning. + if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { + LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + // If callee is `array.uninitialized_intrinsic`, do standard cloning. + // `array.unininitialized_intrinsic` differentiation is handled separately. + if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { + LLVM_DEBUG(getADDebugStream() + << "Cloning `array.unininitialized_intrinsic` `apply`:\n" + << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + // If callee is `array.finalize_intrinsic`, do standard cloning. 
+ // `array.finalize_intrinsic` has special-case pullback generation. + if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { + LLVM_DEBUG(getADDebugStream() + << "Cloning `array.finalize_intrinsic` `apply`:\n" + << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + // If the original function is a semantic member accessor, do standard + // cloning. Semantic member accessors have special pullback generation + // logic, so all `apply` instructions can be directly cloned to the VJP. + if (isSemanticMemberAccessor(original)) { + LLVM_DEBUG(getADDebugStream() + << "Cloning `apply` in semantic member accessor:\n" + << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + + auto loc = ai->getLoc(); + auto &builder = getBuilder(); + auto origCallee = getOpValue(ai->getCallee()); + auto originalFnTy = origCallee->getType().castTo(); + + LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); + + // Get the minimal parameter and result indices required for differentiating + // this `apply`. + SmallVector allResults; + SmallVector activeParamIndices; + SmallVector activeResultIndices; + collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, + allResults, activeParamIndices, + activeResultIndices); + assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); + assert(!activeResultIndices.empty() && "Result indices cannot be empty"); + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + llvm::interleave( + activeParamIndices.begin(), activeParamIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}, results={"; llvm::interleave( + activeResultIndices.begin(), activeResultIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}\n";); + + // Form expected indices. 
+ auto numSemanticResults = + ai->getSubstCalleeType()->getNumResults() + + ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); + SILAutoDiffIndices indices( + IndexSubset::get(getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices), + IndexSubset::get(getASTContext(), numSemanticResults, + activeResultIndices)); + + // Emit the VJP. + SILValue vjpValue; + // If functionSource is a `@differentiable` function, just extract it. + if (originalFnTy->isDifferentiable()) { + auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); + for (auto i : indices.parameters->getIndices()) { + if (!paramIndices->contains(i)) { + context.emitNondifferentiabilityError( + origCallee, invoker, + diag:: + autodiff_function_noderivative_parameter_not_differentiable); + errorOccurred = true; + return; + } + } + auto origFnType = origCallee->getType().castTo(); + auto origFnUnsubstType = origFnType->getUnsubstitutedType(getModule()); + if (origFnType != origFnUnsubstType) { + origCallee = builder.createConvertFunction( + loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType), + /*withoutActuallyEscaping*/ false); + } + auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee); + vjpValue = builder.createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent::VJP, + borrowedDiffFunc); + vjpValue = builder.emitCopyValueOperation(loc, vjpValue); + auto vjpFnType = vjpValue->getType().castTo(); + auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule()); + if (vjpFnType != vjpFnUnsubstType) { + vjpValue = builder.createConvertFunction( + loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType), + /*withoutActuallyEscaping*/ false); + } + } + + // Check and diagnose non-differentiable original function type. + auto diagnoseNondifferentiableOriginalFunctionType = + [&](CanSILFunctionType origFnTy) { + // Check and diagnose non-differentiable arguments. 
+ for (auto paramIndex : indices.parameters->getIndices()) {
+ if (!originalFnTy->getParameters()[paramIndex]
+ .getSILStorageInterfaceType()
+ .isDifferentiable(getModule())) {
+ context.emitNondifferentiabilityError(
+ ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker,
+ diag::autodiff_nondifferentiable_argument);
+ errorOccurred = true;
+ return true;
+ }
+ }
+ // Check and diagnose non-differentiable results.
+ for (auto resultIndex : indices.results->getIndices()) {
+ SILType remappedResultType;
+ if (resultIndex >= originalFnTy->getNumResults()) {
+ auto inoutArgIdx = resultIndex - originalFnTy->getNumResults();
+ auto inoutArg =
+ *std::next(ai->getInoutArguments().begin(), inoutArgIdx);
+ remappedResultType = inoutArg->getType();
+ } else {
+ remappedResultType = originalFnTy->getResults()[resultIndex]
+ .getSILStorageInterfaceType();
+ }
+ if (!remappedResultType.isDifferentiable(getModule())) {
+ context.emitNondifferentiabilityError(
+ origCallee, invoker, diag::autodiff_nondifferentiable_result);
+ errorOccurred = true;
+ return true;
+ }
+ }
+ return false;
+ };
+ if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
+ return;
+
+ // If VJP has not yet been found, emit a `differentiable_function`
+ // instruction on the remapped original function operand and
+ // a `differentiable_function_extract` instruction to get the VJP.
+ // The `differentiable_function` instruction will be canonicalized during
+ // the transform main loop.
+ if (!vjpValue) {
+ // FIXME: Handle indirect differentiation invokers. This may require some
+ // redesign: currently, each original function + witness pair is mapped
+ // only to one invoker. 
+ /* + DifferentiationInvoker indirect(ai, attr); + auto insertion = + context.getInvokers().try_emplace({original, attr}, indirect); + auto &invoker = insertion.first->getSecond(); + invoker = indirect; + */ + + // If the original `apply` instruction has a substitution map, then the + // applied function is specialized. + // In the VJP, specialization is also necessary for parity. The original + // function operand is specialized with a remapped version of same + // substitution map using an argument-less `partial_apply`. + if (ai->getSubstitutionMap().empty()) { + origCallee = builder.emitCopyValueOperation(loc, origCallee); + } else { + auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); + auto vjpPartialApply = getBuilder().createPartialApply( + ai->getLoc(), origCallee, substMap, {}, + ParameterConvention::Direct_Guaranteed); + origCallee = vjpPartialApply; + originalFnTy = origCallee->getType().castTo(); + // Diagnose if new original function type is non-differentiable. + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) + return; + } + + auto *diffFuncInst = context.createDifferentiableFunction( + getBuilder(), loc, indices.parameters, indices.results, origCallee); + + // Record the `differentiable_function` instruction. + context.addDifferentiableFunctionInstToWorklist(diffFuncInst); + + auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); + auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( + loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); + vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); + builder.emitEndBorrowOperation(loc, borrowedADFunc); + builder.emitDestroyValueOperation(loc, diffFuncInst); + } + + // Record desired/actual VJP indices. + // Temporarily set original pullback type to `None`. 
+ NestedApplyInfo info{indices, /*originalPullbackType*/ None}; + auto insertion = context.getNestedApplyInfo().try_emplace(ai, info); + auto &nestedApplyInfo = insertion.first->getSecond(); + nestedApplyInfo = info; + + // Call the VJP using the original parameters. + SmallVector vjpArgs; + auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); + auto numVJPArgs = + vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); + vjpArgs.reserve(numVJPArgs); + // Collect substituted arguments. + for (auto origArg : ai->getArguments()) + vjpArgs.push_back(getOpValue(origArg)); + assert(vjpArgs.size() == numVJPArgs); + // Apply the VJP. + // The VJP should be specialized, so no substitution map is necessary. + auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(), + vjpArgs, ai->isNonThrowing()); + LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); + builder.emitDestroyValueOperation(loc, vjpValue); + + // Get the VJP results (original results and pullback). + SmallVector vjpDirectResults; + extractAllElements(vjpCall, getBuilder(), vjpDirectResults); + ArrayRef originalDirectResults = + ArrayRef(vjpDirectResults).drop_back(1); + SILValue originalDirectResult = + joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc()); + SILValue pullback = vjpDirectResults.back(); + { + auto pullbackFnType = pullback->getType().castTo(); + auto pullbackUnsubstFnType = + pullbackFnType->getUnsubstitutedType(getModule()); + if (pullbackFnType != pullbackUnsubstFnType) { + pullback = builder.createConvertFunction( + loc, pullback, + SILType::getPrimitiveObjectType(pullbackUnsubstFnType), + /*withoutActuallyEscaping*/ false); + } + } + + // Store the original result to the value map. + mapValue(ai, originalDirectResult); + + // Checkpoint the pullback. + auto *pullbackDecl = pullbackInfo.lookUpLinearMapDecl(ai); + + // If actual pullback type does not match lowered pullback type, reabstract + // the pullback using a thunk. 
+ auto actualPullbackType = + getOpType(pullback->getType()).getAs(); + auto loweredPullbackType = + getOpType(getLoweredType(pullbackDecl->getInterfaceType())) + .castTo(); + if (!loweredPullbackType->isEqual(actualPullbackType)) { + // Set non-reabstracted original pullback type in nested apply info. + nestedApplyInfo.originalPullbackType = actualPullbackType; + SILOptFunctionBuilder fb(context.getTransform()); + pullback = reabstractFunction( + getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType, + [this](SubstitutionMap subs) -> SubstitutionMap { + return this->getOpSubstitutionMap(subs); + }); + } + pullbackValues[ai->getParent()].push_back(pullback); + + // Some instructions that produce the callee may have been cloned. + // If the original callee did not have any users beyond this `apply`, + // recursively kill the cloned callee. + if (auto *origCallee = cast_or_null( + ai->getCallee()->getDefiningInstruction())) + if (origCallee->hasOneUse()) + recursivelyDeleteTriviallyDeadInstructions( + getOpValue(origCallee)->getDefiningInstruction()); + } + + void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { + // Clone `differentiable_function` from original to VJP, then add the cloned + // instruction to the `differentiable_function` worklist. + TypeSubstCloner::visitDifferentiableFunctionInst(dfi); + auto *newDFI = cast(getOpValue(dfi)); + context.addDifferentiableFunctionInstToWorklist(newDFI); + } +}; + +/// Initialization helper function. +/// +/// Returns the substitution map used for type remapping. 
+static SubstitutionMap getSubstitutionMap(SILFunction *original, + SILFunction *vjp) { auto substMap = original->getForwardingSubstitutionMap(); if (auto *vjpGenEnv = vjp->getGenericEnvironment()) { auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap(); @@ -42,10 +642,13 @@ SubstitutionMap VJPCloner::getSubstitutionMap(SILFunction *original, return substMap; } -/*static*/ -const DifferentiableActivityInfo & -VJPCloner::getActivityInfo(ADContext &context, SILFunction *original, - SILAutoDiffIndices indices, SILFunction *vjp) { +/// Initialization helper function. +/// +/// Returns the activity info for the given original function, autodiff indices, +/// and VJP generic signature. +static const DifferentiableActivityInfo & +getActivityInfoHelper(ADContext &context, SILFunction *original, + SILAutoDiffIndices indices, SILFunction *vjp) { // Get activity info of the original function. auto &passManager = context.getPassManager(); auto *activityAnalysis = @@ -58,14 +661,16 @@ VJPCloner::getActivityInfo(ADContext &context, SILFunction *original, return activityInfo; } -VJPCloner::VJPCloner(ADContext &context, SILFunction *original, - SILDifferentiabilityWitness *witness, SILFunction *vjp, - DifferentiationInvoker invoker) +VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context, + SILFunction *original, + SILDifferentiabilityWitness *witness, + SILFunction *vjp, + DifferentiationInvoker invoker) : TypeSubstCloner(*vjp, *original, getSubstitutionMap(original, vjp)), - context(context), original(original), witness(witness), vjp(vjp), - invoker(invoker), - activityInfo(getActivityInfo(context, original, - witness->getSILAutoDiffIndices(), vjp)), + cloner(cloner), context(context), original(original), witness(witness), + vjp(vjp), invoker(invoker), + activityInfo(getActivityInfoHelper( + context, original, witness->getSILAutoDiffIndices(), vjp)), pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, 
witness->getSILAutoDiffIndices(), activityInfo) { // Create empty pullback function. @@ -73,7 +678,32 @@ VJPCloner::VJPCloner(ADContext &context, SILFunction *original, context.recordGeneratedFunction(pullback); } -SILFunction *VJPCloner::createEmptyPullback() { +VJPCloner::VJPCloner(ADContext &context, SILFunction *original, + SILDifferentiabilityWitness *witness, SILFunction *vjp, + DifferentiationInvoker invoker) + : impl(*new Implementation(*this, context, original, witness, vjp, + invoker)) {} + +VJPCloner::~VJPCloner() { delete &impl; } + +ADContext &VJPCloner::getContext() const { return impl.context; } +SILModule &VJPCloner::getModule() const { return impl.getModule(); } +SILFunction &VJPCloner::getOriginal() const { return *impl.original; } +SILFunction &VJPCloner::getVJP() const { return *impl.vjp; } +SILFunction &VJPCloner::getPullback() const { return *impl.pullback; } +SILDifferentiabilityWitness *VJPCloner::getWitness() const { + return impl.witness; +} +const SILAutoDiffIndices VJPCloner::getIndices() const { + return impl.getIndices(); +} +DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; } +LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; } +const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const { + return impl.activityInfo; +} + +SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto &module = context.getModule(); auto origTy = original->getLoweredFunctionType(); // Get witness generic signature for remapping types. @@ -259,32 +889,8 @@ SILFunction *VJPCloner::createEmptyPullback() { return pullback; } -void VJPCloner::postProcess(SILInstruction *orig, SILInstruction *cloned) { - if (errorOccurred) - return; - SILClonerWithScopes::postProcess(orig, cloned); -} - -SILBasicBlock *VJPCloner::remapBasicBlock(SILBasicBlock *bb) { - auto *vjpBB = BBMap[bb]; - // If error has occurred, or if block has already been remapped, return - // remapped, return remapped block. 
- if (errorOccurred || remappedBasicBlocks.count(bb)) - return vjpBB; - // Add predecessor enum argument to the remapped block. - auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb); - auto enumTy = - getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType()); - auto enumLoweredTy = context.getTypeConverter().getLoweredType( - enumTy, TypeExpansionContext::minimal()); - vjpBB->createPhiArgument(enumLoweredTy, ValueOwnershipKind::Owned); - remappedBasicBlocks.insert(bb); - return vjpBB; -} - -SILBasicBlock *VJPCloner::createTrampolineBasicBlock(TermInst *termInst, - StructInst *pbStructVal, - SILBasicBlock *succBB) { +SILBasicBlock *VJPCloner::Implementation::createTrampolineBasicBlock( + TermInst *termInst, StructInst *pbStructVal, SILBasicBlock *succBB) { assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != termInst->getSuccessorBlocks().end() && "Basic block is not a successor of terminator instruction"); @@ -307,32 +913,8 @@ SILBasicBlock *VJPCloner::createTrampolineBasicBlock(TermInst *termInst, return trampolineBB; } -void VJPCloner::visit(SILInstruction *inst) { - if (errorOccurred) - return; - TypeSubstCloner::visit(inst); -} - -void VJPCloner::visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError( - inst, invoker, diag::autodiff_expression_not_differentiable_note); - errorOccurred = true; -} - -SILType VJPCloner::getLoweredType(Type type) { - auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); - Lowering::AbstractionPattern pattern(vjpGenSig, - type->getCanonicalType(vjpGenSig)); - return vjp->getLoweredType(pattern, type); -} - -SILType VJPCloner::getNominalDeclLoweredType(NominalTypeDecl *nominal) { - auto nominalType = - getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); - return getLoweredType(nominalType); -} - -StructInst *VJPCloner::buildPullbackValueStructValue(TermInst *termInst) { +StructInst * 
+VJPCloner::Implementation::buildPullbackValueStructValue(TermInst *termInst) { assert(termInst->getFunction() == original); auto loc = RegularLocation::getAutoGeneratedLocation(); auto origBB = termInst->getParent(); @@ -348,10 +930,9 @@ StructInst *VJPCloner::buildPullbackValueStructValue(TermInst *termInst) { return getBuilder().createStruct(loc, structLoweredTy, bbPullbackValues); } -EnumInst *VJPCloner::buildPredecessorEnumValue(SILBuilder &builder, - SILBasicBlock *predBB, - SILBasicBlock *succBB, - SILValue pbStructVal) { +EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue( + SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, + SILValue pbStructVal) { auto loc = RegularLocation::getAutoGeneratedLocation(); auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB); auto enumLoweredTy = getNominalDeclLoweredType(succEnum); @@ -374,457 +955,7 @@ EnumInst *VJPCloner::buildPredecessorEnumValue(SILBuilder &builder, return builder.createEnum(loc, newBox, enumEltDecl, enumLoweredTy); } -void VJPCloner::visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto &builder = getBuilder(); - - // Build pullback struct value for original block. - auto *origExit = ri->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(ri); - - // Get the value in the VJP corresponding to the original result. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Get and partially apply the pullback. - auto vjpGenericEnv = vjp->getGenericEnvironment(); - auto vjpSubstMap = vjpGenericEnv - ? 
vjpGenericEnv->getForwardingSubstitutionMap() - : vjp->getForwardingSubstitutionMap(); - auto *pullbackRef = builder.createFunctionRef(loc, pullback); - auto *pullbackPartialApply = - builder.createPartialApply(loc, pullbackRef, vjpSubstMap, {pbStructVal}, - ParameterConvention::Direct_Guaranteed); - auto pullbackType = vjp->getLoweredFunctionType() - ->getResults() - .back() - .getSILStorageInterfaceType(); - pullbackType = pullbackType.substGenericArgs(getModule(), vjpSubstMap, - TypeExpansionContext::minimal()); - pullbackType = pullbackType.subst(getModule(), vjpSubstMap); - auto pullbackFnType = pullbackType.castTo(); - - auto pullbackSubstType = - pullbackPartialApply->getType().castTo(); - SILValue pullbackValue; - if (pullbackSubstType == pullbackFnType) { - pullbackValue = pullbackPartialApply; - } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp) - .isCompatible()) { - pullbackValue = - builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, - /*withoutActuallyEscaping*/ false); - } else { - // When `diag::autodiff_loadable_value_addressonly_tangent_unsupported` - // applies, the return type may be ABI-incomaptible with the type of the - // partially applied pullback. In these cases, produce an undef and rely on - // other code to emit a diagnostic. - pullbackValue = SILUndef::get(pullbackType, *vjp); - } - - // Return a tuple of the original result and pullback. - SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - directResults.push_back(pullbackValue); - builder.createReturn(ri->getLoc(), joinElements(directResults, builder, loc)); -} - -void VJPCloner::visitBranchInst(BranchInst *bi) { - // Build pullback struct value for original block. - // Build predecessor enum value for destination block. 
- auto *origBB = bi->getParent(); - auto *pbStructVal = buildPullbackValueStructValue(bi); - auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, - bi->getDestBB(), pbStructVal); - - // Remap arguments, appending the new enum values. - SmallVector args; - for (auto origArg : bi->getArgs()) - args.push_back(getOpValue(origArg)); - args.push_back(enumVal); - - // Create a new `br` instruction. - getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(bi->getDestBB()), - args); -} - -void VJPCloner::visitCondBranchInst(CondBranchInst *cbi) { - // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(cbi); - // Create a new `cond_br` instruction. - getBuilder().createCondBranch( - cbi->getLoc(), getOpValue(cbi->getCondition()), - createTrampolineBasicBlock(cbi, pbStructVal, cbi->getTrueBB()), - createTrampolineBasicBlock(cbi, pbStructVal, cbi->getFalseBB())); -} - -void VJPCloner::visitSwitchEnumInstBase(SwitchEnumInstBase *sei) { - // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(sei); - - // Create trampoline successor basic blocks. - SmallVector, 4> caseBBs; - for (unsigned i : range(sei->getNumCases())) { - auto caseBB = sei->getCase(i); - auto *trampolineBB = - createTrampolineBasicBlock(sei, pbStructVal, caseBB.second); - caseBBs.push_back({caseBB.first, trampolineBB}); - } - // Create trampoline default basic block. - SILBasicBlock *newDefaultBB = nullptr; - if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) - newDefaultBB = createTrampolineBasicBlock(sei, pbStructVal, defaultBB); - - // Create a new `switch_enum` instruction. 
- switch (sei->getKind()) { - case SILInstructionKind::SwitchEnumInst: - getBuilder().createSwitchEnum(sei->getLoc(), getOpValue(sei->getOperand()), - newDefaultBB, caseBBs); - break; - case SILInstructionKind::SwitchEnumAddrInst: - getBuilder().createSwitchEnumAddr( - sei->getLoc(), getOpValue(sei->getOperand()), newDefaultBB, caseBBs); - break; - default: - llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`"); - } -} - -void VJPCloner::visitSwitchEnumInst(SwitchEnumInst *sei) { - visitSwitchEnumInstBase(sei); -} - -void VJPCloner::visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { - visitSwitchEnumInstBase(seai); -} - -void VJPCloner::visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { - // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccbi); - // Create a new `checked_cast_branch` instruction. - getBuilder().createCheckedCastBranch( - ccbi->getLoc(), ccbi->isExact(), getOpValue(ccbi->getOperand()), - getOpType(ccbi->getTargetLoweredType()), - getOpASTType(ccbi->getTargetFormalType()), - createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getSuccessBB()), - createTrampolineBasicBlock(ccbi, pbStructVal, ccbi->getFailureBB()), - ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); -} - -void VJPCloner::visitCheckedCastValueBranchInst( - CheckedCastValueBranchInst *ccvbi) { - // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccvbi); - // Create a new `checked_cast_value_branch` instruction. 
- getBuilder().createCheckedCastValueBranch( - ccvbi->getLoc(), getOpValue(ccvbi->getOperand()), - getOpASTType(ccvbi->getSourceFormalType()), - getOpType(ccvbi->getTargetLoweredType()), - getOpASTType(ccvbi->getTargetFormalType()), - createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getSuccessBB()), - createTrampolineBasicBlock(ccvbi, pbStructVal, ccvbi->getFailureBB())); -} - -void VJPCloner::visitCheckedCastAddrBranchInst( - CheckedCastAddrBranchInst *ccabi) { - // Build pullback struct value for original block. - auto *pbStructVal = buildPullbackValueStructValue(ccabi); - // Create a new `checked_cast_addr_branch` instruction. - getBuilder().createCheckedCastAddrBranch( - ccabi->getLoc(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()), - getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()), - getOpASTType(ccabi->getTargetFormalType()), - createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()), - createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()), - ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); -} - -void VJPCloner::visitApplyInst(ApplyInst *ai) { - // If callee should not be differentiated, do standard cloning. - if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { - LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); - TypeSubstCloner::visitApplyInst(ai); - return; - } - // If callee is `array.uninitialized_intrinsic`, do standard cloning. - // `array.unininitialized_intrinsic` differentiation is handled separately. - if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { - LLVM_DEBUG(getADDebugStream() - << "Cloning `array.unininitialized_intrinsic` `apply`:\n" - << *ai << '\n'); - TypeSubstCloner::visitApplyInst(ai); - return; - } - // If callee is `array.finalize_intrinsic`, do standard cloning. - // `array.finalize_intrinsic` has special-case pullback generation. 
- if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { - LLVM_DEBUG(getADDebugStream() - << "Cloning `array.finalize_intrinsic` `apply`:\n" - << *ai << '\n'); - TypeSubstCloner::visitApplyInst(ai); - return; - } - // If the original function is a semantic member accessor, do standard - // cloning. Semantic member accessors have special pullback generation logic, - // so all `apply` instructions can be directly cloned to the VJP. - if (isSemanticMemberAccessor(original)) { - LLVM_DEBUG(getADDebugStream() - << "Cloning `apply` in semantic member accessor:\n" - << *ai << '\n'); - TypeSubstCloner::visitApplyInst(ai); - return; - } - - auto loc = ai->getLoc(); - auto &builder = getBuilder(); - auto origCallee = getOpValue(ai->getCallee()); - auto originalFnTy = origCallee->getType().castTo(); - - LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); - - // Get the minimal parameter and result indices required for differentiating - // this `apply`. - SmallVector allResults; - SmallVector activeParamIndices; - SmallVector activeResultIndices; - collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, - allResults, activeParamIndices, - activeResultIndices); - assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); - assert(!activeResultIndices.empty() && "Result indices cannot be empty"); - LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; - llvm::interleave( - activeParamIndices.begin(), activeParamIndices.end(), - [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}, results={"; llvm::interleave( - activeResultIndices.begin(), activeResultIndices.end(), - [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}\n";); - - // Form expected indices. 
- auto numSemanticResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); - SILAutoDiffIndices indices( - IndexSubset::get(getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), - activeParamIndices), - IndexSubset::get(getASTContext(), numSemanticResults, - activeResultIndices)); - - // Emit the VJP. - SILValue vjpValue; - // If functionSource is a `@differentiable` function, just extract it. - if (originalFnTy->isDifferentiable()) { - auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); - for (auto i : indices.parameters->getIndices()) { - if (!paramIndices->contains(i)) { - context.emitNondifferentiabilityError( - origCallee, invoker, - diag::autodiff_function_noderivative_parameter_not_differentiable); - errorOccurred = true; - return; - } - } - auto origFnType = origCallee->getType().castTo(); - auto origFnUnsubstType = origFnType->getUnsubstitutedType(getModule()); - if (origFnType != origFnUnsubstType) { - origCallee = builder.createConvertFunction( - loc, origCallee, SILType::getPrimitiveObjectType(origFnUnsubstType), - /*withoutActuallyEscaping*/ false); - } - auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, origCallee); - vjpValue = builder.createDifferentiableFunctionExtract( - loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedDiffFunc); - vjpValue = builder.emitCopyValueOperation(loc, vjpValue); - auto vjpFnType = vjpValue->getType().castTo(); - auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule()); - if (vjpFnType != vjpFnUnsubstType) { - vjpValue = builder.createConvertFunction( - loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType), - /*withoutActuallyEscaping*/ false); - } - } - - // Check and diagnose non-differentiable original function type. - auto diagnoseNondifferentiableOriginalFunctionType = - [&](CanSILFunctionType origFnTy) { - // Check and diagnose non-differentiable arguments. 
- for (auto paramIndex : indices.parameters->getIndices()) { - if (!originalFnTy->getParameters()[paramIndex] - .getSILStorageInterfaceType() - .isDifferentiable(getModule())) { - context.emitNondifferentiabilityError( - ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker, - diag::autodiff_nondifferentiable_argument); - errorOccurred = true; - return true; - } - } - // Check and diagnose non-differentiable results. - for (auto resultIndex : indices.results->getIndices()) { - SILType remappedResultType; - if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); - } else { - remappedResultType = originalFnTy->getResults()[resultIndex] - .getSILStorageInterfaceType(); - } - if (!remappedResultType.isDifferentiable(getModule())) { - context.emitNondifferentiabilityError( - origCallee, invoker, diag::autodiff_nondifferentiable_result); - errorOccurred = true; - return true; - } - } - return false; - }; - if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) - return; - - // If VJP has not yet been found, emit an `differentiable_function` - // instruction on the remapped original function operand and - // an `differentiable_function_extract` instruction to get the VJP. - // The `differentiable_function` instruction will be canonicalized during - // the transform main loop. - if (!vjpValue) { - // FIXME: Handle indirect differentiation invokers. This may require some - // redesign: currently, each original function + witness pair is mapped - // only to one invoker. 
- /* - DifferentiationInvoker indirect(ai, attr); - auto insertion = - context.getInvokers().try_emplace({original, attr}, indirect); - auto &invoker = insertion.first->getSecond(); - invoker = indirect; - */ - - // If the original `apply` instruction has a substitution map, then the - // applied function is specialized. - // In the VJP, specialization is also necessary for parity. The original - // function operand is specialized with a remapped version of same - // substitution map using an argument-less `partial_apply`. - if (ai->getSubstitutionMap().empty()) { - origCallee = builder.emitCopyValueOperation(loc, origCallee); - } else { - auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); - auto vjpPartialApply = getBuilder().createPartialApply( - ai->getLoc(), origCallee, substMap, {}, - ParameterConvention::Direct_Guaranteed); - origCallee = vjpPartialApply; - originalFnTy = origCallee->getType().castTo(); - // Diagnose if new original function type is non-differentiable. - if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) - return; - } - - auto *diffFuncInst = context.createDifferentiableFunction( - getBuilder(), loc, indices.parameters, indices.results, origCallee); - - // Record the `differentiable_function` instruction. - context.addDifferentiableFunctionInstToWorklist(diffFuncInst); - - auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst); - auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( - loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); - vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); - builder.emitEndBorrowOperation(loc, borrowedADFunc); - builder.emitDestroyValueOperation(loc, diffFuncInst); - } - - // Record desired/actual VJP indices. - // Temporarily set original pullback type to `None`. 
- NestedApplyInfo info{indices, /*originalPullbackType*/ None}; - auto insertion = context.getNestedApplyInfo().try_emplace(ai, info); - auto &nestedApplyInfo = insertion.first->getSecond(); - nestedApplyInfo = info; - - // Call the VJP using the original parameters. - SmallVector vjpArgs; - auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); - auto numVJPArgs = - vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); - vjpArgs.reserve(numVJPArgs); - // Collect substituted arguments. - for (auto origArg : ai->getArguments()) - vjpArgs.push_back(getOpValue(origArg)); - assert(vjpArgs.size() == numVJPArgs); - // Apply the VJP. - // The VJP should be specialized, so no substitution map is necessary. - auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(), - vjpArgs, ai->isNonThrowing()); - LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); - builder.emitDestroyValueOperation(loc, vjpValue); - - // Get the VJP results (original results and pullback). - SmallVector vjpDirectResults; - extractAllElements(vjpCall, getBuilder(), vjpDirectResults); - ArrayRef originalDirectResults = - ArrayRef(vjpDirectResults).drop_back(1); - SILValue originalDirectResult = - joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc()); - SILValue pullback = vjpDirectResults.back(); - { - auto pullbackFnType = pullback->getType().castTo(); - auto pullbackUnsubstFnType = - pullbackFnType->getUnsubstitutedType(getModule()); - if (pullbackFnType != pullbackUnsubstFnType) { - pullback = builder.createConvertFunction( - loc, pullback, SILType::getPrimitiveObjectType(pullbackUnsubstFnType), - /*withoutActuallyEscaping*/ false); - } - } - - // Store the original result to the value map. - mapValue(ai, originalDirectResult); - - // Checkpoint the pullback. - auto *pullbackDecl = pullbackInfo.lookUpLinearMapDecl(ai); - - // If actual pullback type does not match lowered pullback type, reabstract - // the pullback using a thunk. 
- auto actualPullbackType = - getOpType(pullback->getType()).getAs(); - auto loweredPullbackType = - getOpType(getLoweredType(pullbackDecl->getInterfaceType())) - .castTo(); - if (!loweredPullbackType->isEqual(actualPullbackType)) { - // Set non-reabstracted original pullback type in nested apply info. - nestedApplyInfo.originalPullbackType = actualPullbackType; - SILOptFunctionBuilder fb(context.getTransform()); - pullback = reabstractFunction( - getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType, - [this](SubstitutionMap subs) -> SubstitutionMap { - return this->getOpSubstitutionMap(subs); - }); - } - pullbackValues[ai->getParent()].push_back(pullback); - - // Some instructions that produce the callee may have been cloned. - // If the original callee did not have any users beyond this `apply`, - // recursively kill the cloned callee. - if (auto *origCallee = cast_or_null( - ai->getCallee()->getDefiningInstruction())) - if (origCallee->hasOneUse()) - recursivelyDeleteTriviallyDeadInstructions( - getOpValue(origCallee)->getDefiningInstruction()); -} - -void VJPCloner::visitDifferentiableFunctionInst( - DifferentiableFunctionInst *dfi) { - // Clone `differentiable_function` from original to VJP, then add the cloned - // instruction to the `differentiable_function` worklist. - TypeSubstCloner::visitDifferentiableFunctionInst(dfi); - auto *newDFI = cast(getOpValue(dfi)); - context.addDifferentiableFunctionInstToWorklist(newDFI); -} - -bool VJPCloner::run() { +bool VJPCloner::Implementation::run() { PrettyStackTraceSILFunction trace("generating VJP for", original); LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() << " to vjp @" << vjp->getName() << '\n'); @@ -853,7 +984,7 @@ bool VJPCloner::run() { << *vjp); // Generate pullback code. 
- PullbackCloner PullbackCloner(*this); + PullbackCloner PullbackCloner(cloner); if (PullbackCloner.run()) { errorOccurred = true; return true; @@ -861,5 +992,7 @@ bool VJPCloner::run() { return errorOccurred; } +bool VJPCloner::run() { return impl.run(); } + } // end namespace autodiff } // end namespace swift diff --git a/lib/SILOptimizer/IPO/EagerSpecializer.cpp b/lib/SILOptimizer/IPO/EagerSpecializer.cpp index 7f47b83c5fcd9..f647a161b1164 100644 --- a/lib/SILOptimizer/IPO/EagerSpecializer.cpp +++ b/lib/SILOptimizer/IPO/EagerSpecializer.cpp @@ -36,8 +36,9 @@ #include "swift/AST/Type.h" #include "swift/SIL/SILFunction.h" #include "swift/SILOptimizer/PassManager/Transforms.h" -#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" +#include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/Generics.h" +#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/Support/Debug.h" using namespace swift; @@ -47,6 +48,16 @@ llvm::cl::opt EagerSpecializeFlag( "enable-eager-specializer", llvm::cl::init(true), llvm::cl::desc("Run the eager-specializer pass.")); +static void +cleanupCallArguments(SILBuilder &builder, SILLocation loc, + ArrayRef values, + ArrayRef valueIndicesThatNeedEndBorrow) { + for (int index : valueIndicesThatNeedEndBorrow) { + auto *lbi = cast(values[index]); + builder.createEndBorrow(loc, lbi); + } +} + /// Returns true if the given return or throw block can be used as a merge point /// for new return or error values. static bool isTrivialReturnBlock(SILBasicBlock *RetBB) { @@ -127,12 +138,16 @@ static void addReturnValueImpl(SILBasicBlock *RetBB, SILBasicBlock *NewRetBB, Builder.createBranch(Loc, MergedBB, {OldRetVal}); } } + // Create a CFG edge from NewRetBB to MergedBB. Builder.setInsertionPoint(NewRetBB); SmallVector BBArgs; if (!NewRetVal->getType().isVoid()) BBArgs.push_back(NewRetVal); Builder.createBranch(Loc, MergedBB, BBArgs); + + // Then split any critical edges we created to the merged block. 
+ splitCriticalEdgesTo(MergedBB); } /// Adds a CFG edge from the unterminated NewRetBB to a merged "return" block. @@ -154,14 +169,10 @@ static void addThrowValue(SILBasicBlock *NewThrowBB, SILValue NewErrorVal) { /// /// TODO: Move this to Utils. static SILValue -emitApplyWithRethrow(SILBuilder &Builder, - SILLocation Loc, - SILValue FuncRef, - CanSILFunctionType CanSILFuncTy, - SubstitutionMap Subs, +emitApplyWithRethrow(SILBuilder &Builder, SILLocation Loc, SILValue FuncRef, + CanSILFunctionType CanSILFuncTy, SubstitutionMap Subs, ArrayRef CallArgs, - void (*EmitCleanup)(SILBuilder&, SILLocation)) { - + ArrayRef CallArgIndicesThatNeedEndBorrow) { auto &F = Builder.getFunction(); SILFunctionConventions fnConv(CanSILFuncTy, Builder.getModule()); @@ -176,30 +187,31 @@ emitApplyWithRethrow(SILBuilder &Builder, SILValue Error = ErrorBB->createPhiArgument( fnConv.getSILErrorType(F.getTypeExpansionContext()), ValueOwnershipKind::Owned); - - EmitCleanup(Builder, Loc); + cleanupCallArguments(Builder, Loc, CallArgs, + CallArgIndicesThatNeedEndBorrow); addThrowValue(ErrorBB, Error); } + // Advance Builder to the fall-thru path and return a SILArgument holding the // result value. Builder.clearInsertionPoint(); Builder.emitBlock(NormalBB); - return Builder.getInsertionBB()->createPhiArgument( + SILValue finalArgument = Builder.getInsertionBB()->createPhiArgument( fnConv.getSILResultType(F.getTypeExpansionContext()), ValueOwnershipKind::Owned); + cleanupCallArguments(Builder, Loc, CallArgs, CallArgIndicesThatNeedEndBorrow); + return finalArgument; } /// Emits code to invoke the specified specialized CalleeFunc using the /// provided SILBuilder. /// /// TODO: Move this to Utils. 
-static SILValue -emitInvocation(SILBuilder &Builder, - const ReabstractionInfo &ReInfo, - SILLocation Loc, - SILFunction *CalleeFunc, - ArrayRef CallArgs, - void (*EmitCleanup)(SILBuilder&, SILLocation)) { +static SILValue emitInvocation(SILBuilder &Builder, + const ReabstractionInfo &ReInfo, SILLocation Loc, + SILFunction *CalleeFunc, + ArrayRef CallArgs, + ArrayRef ArgsNeedEndBorrow) { auto *FuncRefInst = Builder.createFunctionRef(Loc, CalleeFunc); auto CanSILFuncTy = CalleeFunc->getLoweredFunctionType(); @@ -256,14 +268,15 @@ emitInvocation(SILBuilder &Builder, // or de-facto? if (!CanSILFuncTy->hasErrorResult() || CalleeFunc->findThrowBB() == CalleeFunc->end()) { - return Builder.createApply(CalleeFunc->getLocation(), FuncRefInst, - Subs, CallArgs, isNonThrowing); + auto *AI = Builder.createApply(CalleeFunc->getLocation(), FuncRefInst, Subs, + CallArgs, isNonThrowing); + cleanupCallArguments(Builder, Loc, CallArgs, ArgsNeedEndBorrow); + return AI; } - return emitApplyWithRethrow(Builder, CalleeFunc->getLocation(), - FuncRefInst, CalleeSubstFnTy, Subs, - CallArgs, - EmitCleanup); + return emitApplyWithRethrow(Builder, CalleeFunc->getLocation(), FuncRefInst, + CalleeSubstFnTy, Subs, CallArgs, + ArgsNeedEndBorrow); } /// Returns the thick metatype for the given SILType. @@ -323,8 +336,11 @@ class EagerDispatch { SILValue emitArgumentCast(CanSILFunctionType CalleeSubstFnTy, SILFunctionArgument *OrigArg, unsigned Idx); - SILValue emitArgumentConversion(SmallVectorImpl &CallArgs); + SILValue + emitArgumentConversion(SmallVectorImpl &CallArgs, + SmallVectorImpl &ArgAtIndexNeedsEndBorrow); }; + } // end anonymous namespace /// Inserts type checks in the original generic function for dispatching to the @@ -335,6 +351,7 @@ void EagerDispatch::emitDispatchTo(SILFunction *NewFunc) { auto ReturnBB = GenericFunc->findReturnBB(); if (ReturnBB != GenericFunc->end()) OldReturnBB = &*ReturnBB; + // 1. Emit a cascading sequence of type checks blocks. 
// First split the entry BB, moving all instructions to the FailedTypeCheckBB. @@ -384,23 +401,24 @@ void EagerDispatch::emitDispatchTo(SILFunction *NewFunc) { // 2. Convert call arguments, casting and adjusting for calling convention. SmallVector CallArgs; - SILValue StoreResultTo = emitArgumentConversion(CallArgs); + SmallVector ArgAtIndexNeedsEndBorrow; + SILValue StoreResultTo = + emitArgumentConversion(CallArgs, ArgAtIndexNeedsEndBorrow); // 3. Emit an invocation of the specialized function. // Emit any rethrow with no cleanup since all args have been forwarded and // nothing has been locally allocated or copied. - auto NoCleanup = [](SILBuilder&, SILLocation){}; - SILValue Result = - emitInvocation(Builder, ReInfo, Loc, NewFunc, CallArgs, NoCleanup); + SILValue Result = emitInvocation(Builder, ReInfo, Loc, NewFunc, CallArgs, + ArgAtIndexNeedsEndBorrow); // 4. Handle the return value. auto VoidTy = Builder.getModule().Types.getEmptyTupleType(); if (StoreResultTo) { // Store the direct result to the original result address. - Builder.createStore(Loc, Result, StoreResultTo, - StoreOwnershipQualifier::Unqualified); + Builder.emitStoreValueOperation(Loc, Result, StoreResultTo, + StoreOwnershipQualifier::Init); // And return Void. Result = Builder.createTuple(Loc, VoidTy, { }); } @@ -416,7 +434,10 @@ void EagerDispatch::emitDispatchTo(SILFunction *NewFunc) { auto resultTy = GenericFunc->getConventions().getSILResultType( Builder.getTypeExpansionContext()); auto GenResultTy = GenericFunc->mapTypeIntoContext(resultTy); - auto CastResult = Builder.createUncheckedBitCast(Loc, Result, GenResultTy); + + SILValue CastResult = + Builder.createUncheckedBitCast(Loc, Result, GenResultTy); + addReturnValue(Builder.getInsertionBB(), OldReturnBB, CastResult); } } @@ -620,8 +641,9 @@ SILValue EagerDispatch::emitArgumentCast(CanSILFunctionType CalleeSubstFnTy, /// /// Returns the SILValue to store the result into if the specialized function /// has a direct result. 
-SILValue EagerDispatch:: -emitArgumentConversion(SmallVectorImpl &CallArgs) { +SILValue EagerDispatch::emitArgumentConversion( + SmallVectorImpl &CallArgs, + SmallVectorImpl &ArgAtIndexNeedsEndBorrow) { auto OrigArgs = GenericFunc->begin()->getSILFunctionArguments(); assert(OrigArgs.size() == substConv.getNumSILArguments() && "signature mismatch"); @@ -661,6 +683,7 @@ emitArgumentConversion(SmallVectorImpl &CallArgs) { CallArgs.push_back(CastArg); continue; } + if (ArgIdx < substConv.getSILArgIndexOfFirstParam()) { // Handle result arguments. unsigned formalIdx = @@ -672,29 +695,43 @@ emitArgumentConversion(SmallVectorImpl &CallArgs) { StoreResultTo = CastArg; continue; } + CallArgs.push_back(CastArg); + continue; + } + + // Handle arguments for formal parameters. + unsigned paramIdx = ArgIdx - substConv.getSILArgIndexOfFirstParam(); + if (!ReInfo.isParamConverted(paramIdx)) { + CallArgs.push_back(CastArg); + continue; + } + + // An argument is converted from indirect to direct. Instead of the + // address we pass the loaded value. + // + // FIXME: If type of CastArg is an archetype, but it is loadable because + // of a layout constraint on the caller side, we have a problem here + // We need to load the value on the caller side, but this archetype is + // not statically known to be loadable on the caller side (though we + // have proven dynamically that it has a fixed size). + // + // We can try to load it as an int value of width N, but then it is not + // clear how to convert it into a value of the archetype type, which is + // expected. May be we should pass it as @in parameter and make it + // loadable on the caller's side? + auto argConv = substConv.getSILArgumentConvention(ArgIdx); + SILValue Val; + if (!argConv.isGuaranteedConvention()) { + Val = Builder.emitLoadValueOperation(Loc, CastArg, + LoadOwnershipQualifier::Take); } else { - // Handle arguments for formal parameters. 
- unsigned paramIdx = ArgIdx - substConv.getSILArgIndexOfFirstParam(); - if (ReInfo.isParamConverted(paramIdx)) { - // An argument is converted from indirect to direct. Instead of the - // address we pass the loaded value. - // FIXME: If type of CastArg is an archetype, but it is loadable because - // of a layout constraint on the caller side, we have a problem here - // We need to load the value on the caller side, but this archetype is - // not statically known to be loadable on the caller side (though we - // have proven dynamically that it has a fixed size). - // We can try to load it as an int value of width N, but then it is not - // clear how to convert it into a value of the archetype type, which is - // expected. May be we should pass it as @in parameter and make it - // loadable on the caller's side? - SILValue Val = Builder.createLoad(Loc, CastArg, - LoadOwnershipQualifier::Unqualified); - CallArgs.push_back(Val); - continue; - } + Val = Builder.emitLoadBorrowOperation(Loc, CastArg); + if (Val.getOwnershipKind() == ValueOwnershipKind::Guaranteed) + ArgAtIndexNeedsEndBorrow.push_back(CallArgs.size()); } - CallArgs.push_back(CastArg); + CallArgs.push_back(Val); } + return StoreResultTo; } @@ -745,11 +782,9 @@ void EagerSpecializerTransform::run() { // Process functions in any order. for (auto &F : *getModule()) { - // TODO: we should support ownership here but first we'll have to support - // ownership in GenericFuncSpecializer. 
- if (!F.shouldOptimize() || F.hasOwnership()) { + if (!F.shouldOptimize()) { LLVM_DEBUG(llvm::dbgs() << " Cannot specialize function " << F.getName() - << " because it has ownership or is marked to be " + << " because it is marked to be " "excluded from optimizations.\n"); continue; } @@ -787,6 +822,7 @@ void EagerSpecializerTransform::run() { [&](const SILSpecializeAttr *SA, SILFunction *NewFunc, const ReabstractionInfo &ReInfo) { if (NewFunc) { + NewFunc->verify(); Changed = true; EagerDispatch(&F, ReInfo).emitDispatchTo(NewFunc); } @@ -800,6 +836,7 @@ void EagerSpecializerTransform::run() { // As specializations are created, the attributes should be removed. F.clearSpecializeAttrs(); + F.verify(); } } diff --git a/lib/SILOptimizer/Mandatory/OwnershipModelEliminator.cpp b/lib/SILOptimizer/Mandatory/OwnershipModelEliminator.cpp index 33d574faddb73..769ee6c2cef75 100644 --- a/lib/SILOptimizer/Mandatory/OwnershipModelEliminator.cpp +++ b/lib/SILOptimizer/Mandatory/OwnershipModelEliminator.cpp @@ -93,6 +93,16 @@ struct OwnershipModelEliminatorVisitor bool visitSwitchEnumInst(SwitchEnumInst *SWI); bool visitDestructureStructInst(DestructureStructInst *DSI); bool visitDestructureTupleInst(DestructureTupleInst *DTI); + + // We lower this to unchecked_bitwise_cast losing our assumption of layout + // compatibility. 
+ bool visitUncheckedValueCastInst(UncheckedValueCastInst *UVCI) { + auto *NewVal = B.createUncheckedBitwiseCast( + UVCI->getLoc(), UVCI->getOperand(), UVCI->getType()); + UVCI->replaceAllUsesWith(NewVal); + UVCI->eraseFromParent(); + return true; + } }; } // end anonymous namespace diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp index e71bba15b5f6c..9aa6581052f3f 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp @@ -190,7 +190,8 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI, auto UAC = Builder.createUncheckedAddrCast(AI.getLoc(), Op, NewOpType); Args.push_back(UAC); } else if (OldOpType.getASTType() != NewOpType.getASTType()) { - auto URC = Builder.createUncheckedBitCast(AI.getLoc(), Op, NewOpType); + auto URC = + Builder.createUncheckedReinterpretCast(AI.getLoc(), Op, NewOpType); Args.push_back(URC); } else { Args.push_back(Op); @@ -381,8 +382,8 @@ bool SILCombiner::tryOptimizeKeypathOffsetOf(ApplyInst *AI, return false; // Convert the projected address back to an optional integer. 
- SILValue offset = Builder.createUncheckedBitCast(loc, offsetPtr, - SILType::getPrimitiveObjectType(builtinIntTy)); + SILValue offset = Builder.createUncheckedReinterpretCast( + loc, offsetPtr, SILType::getPrimitiveObjectType(builtinIntTy)); SILValue offsetInt = Builder.createStruct(loc, intType, { offset }); result = Builder.createOptionalSome(loc, offsetInt, AI->getType()); } else { diff --git a/lib/SILOptimizer/Transforms/SpeculativeDevirtualizer.cpp b/lib/SILOptimizer/Transforms/SpeculativeDevirtualizer.cpp index ecd0e7db862ff..120cab28900ff 100644 --- a/lib/SILOptimizer/Transforms/SpeculativeDevirtualizer.cpp +++ b/lib/SILOptimizer/Transforms/SpeculativeDevirtualizer.cpp @@ -562,9 +562,9 @@ static bool tryToSpeculateTarget(FullApplySite AI, ClassHierarchyAnalysis *CHA, if (LastCCBI && SubTypeValue == LastCCBI->getOperand()) { // Remove last checked_cast_br, because it will always succeed. SILBuilderWithScope B(LastCCBI); - auto CastedValue = B.createUncheckedBitCast(LastCCBI->getLoc(), - LastCCBI->getOperand(), - LastCCBI->getTargetLoweredType()); + auto CastedValue = B.createUncheckedReinterpretCast( + LastCCBI->getLoc(), LastCCBI->getOperand(), + LastCCBI->getTargetLoweredType()); B.createBranch(LastCCBI->getLoc(), LastCCBI->getSuccessBB(), {CastedValue}); LastCCBI->eraseFromParent(); ORE.emit(RB); diff --git a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp index 1d17283837cf9..88b831e20f2ba 100644 --- a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp +++ b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp @@ -177,6 +177,7 @@ static bool hasOpaqueArchetype(TypeExpansionContext context, case SILInstructionKind::UncheckedAddrCastInst: case SILInstructionKind::UncheckedTrivialBitCastInst: case SILInstructionKind::UncheckedBitwiseCastInst: + case SILInstructionKind::UncheckedValueCastInst: case SILInstructionKind::RefToRawPointerInst: case SILInstructionKind::RawPointerToRefInst: #define 
LOADABLE_REF_STORAGE(Name, ...) \ diff --git a/lib/SILOptimizer/Utils/CFGOptUtils.cpp b/lib/SILOptimizer/Utils/CFGOptUtils.cpp index d40cb378b8a45..213ebe5f15cf9 100644 --- a/lib/SILOptimizer/Utils/CFGOptUtils.cpp +++ b/lib/SILOptimizer/Utils/CFGOptUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "swift/SILOptimizer/Utils/CFGOptUtils.h" +#include "swift/Basic/STLExtras.h" #include "swift/Demangling/ManglingMacros.h" #include "swift/SIL/BasicBlockUtils.h" #include "swift/SIL/Dominance.h" @@ -551,6 +552,20 @@ bool swift::splitCriticalEdgesFrom(SILBasicBlock *fromBB, return changed; } +bool swift::splitCriticalEdgesTo(SILBasicBlock *toBB, DominanceInfo *domInfo, + SILLoopInfo *loopInfo) { + bool changed = false; + unsigned numPreds = std::distance(toBB->pred_begin(), toBB->pred_end()); + + for (unsigned idx = 0; idx != numPreds; ++idx) { + SILBasicBlock *fromBB = *std::next(toBB->pred_begin(), idx); + auto *newBB = splitIfCriticalEdge(fromBB, toBB); + changed |= (newBB != nullptr); + } + + return changed; +} + bool swift::hasCriticalEdges(SILFunction &f, bool onlyNonCondBr) { for (SILBasicBlock &bb : f) { // Only consider critical edges for terminators that don't support block diff --git a/lib/SILOptimizer/Utils/InstOptUtils.cpp b/lib/SILOptimizer/Utils/InstOptUtils.cpp index 334cb559e142d..11a5bf7e79af4 100644 --- a/lib/SILOptimizer/Utils/InstOptUtils.cpp +++ b/lib/SILOptimizer/Utils/InstOptUtils.cpp @@ -853,7 +853,8 @@ swift::castValueToABICompatibleType(SILBuilder *builder, SILLocation loc, } // Cast between two metatypes and that's it. 
- return {builder->createUncheckedBitCast(loc, value, destTy), false}; + return {builder->createUncheckedReinterpretCast(loc, value, destTy), + false}; } } } diff --git a/lib/SILOptimizer/Utils/SILInliner.cpp b/lib/SILOptimizer/Utils/SILInliner.cpp index c8fc5e583298e..b08f9832e0237 100644 --- a/lib/SILOptimizer/Utils/SILInliner.cpp +++ b/lib/SILOptimizer/Utils/SILInliner.cpp @@ -714,6 +714,7 @@ InlineCost swift::instructionInlineCost(SILInstruction &I) { case SILInstructionKind::UncheckedAddrCastInst: case SILInstructionKind::UncheckedTrivialBitCastInst: case SILInstructionKind::UncheckedBitwiseCastInst: + case SILInstructionKind::UncheckedValueCastInst: case SILInstructionKind::RawPointerToRefInst: case SILInstructionKind::RefToRawPointerInst: diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index cd3fbe7dfb329..aff6561c3df55 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -6268,7 +6268,7 @@ performMemberLookup(ConstraintKind constraintKind, DeclNameRef memberName, // If this is an attempt to access read-only member via // writable key path, let's fail this choice early. 
auto &ctx = getASTContext(); - if (isReadOnlyKeyPathComponent(storage) && + if (isReadOnlyKeyPathComponent(storage, SourceLoc()) && (keyPath == ctx.getWritableKeyPathDecl() || keyPath == ctx.getReferenceWritableKeyPathDecl())) { result.addUnviable( @@ -8020,7 +8020,7 @@ ConstraintSystem::simplifyKeyPathConstraint( if (!storage) return SolutionKind::Error; - if (isReadOnlyKeyPathComponent(storage)) { + if (isReadOnlyKeyPathComponent(storage, component.getLoc())) { capability = ReadOnly; continue; } diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index 0b641c7f111c0..1001b42d10fc6 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -4671,7 +4671,8 @@ class ConstraintSystem { bool restoreOnFail, llvm::function_ref pred); - bool isReadOnlyKeyPathComponent(const AbstractStorageDecl *storage) { + bool isReadOnlyKeyPathComponent(const AbstractStorageDecl *storage, + SourceLoc referenceLoc) { // See whether key paths can store to this component. (Key paths don't // get any special power from being formed in certain contexts, such // as the ability to assign to `let`s in initialization contexts, so @@ -4692,6 +4693,17 @@ class ConstraintSystem { // a reference-writable component shows up later. return true; } + + // If the setter is unavailable, then the keypath ought to be read-only + // in this context. 
+ if (auto setter = storage->getOpaqueAccessor(AccessorKind::Set)) { + auto maybeUnavail = TypeChecker::checkDeclarationAvailability(setter, + referenceLoc, + DC); + if (maybeUnavail.hasValue()) { + return true; + } + } return false; } diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index f2af26aa9d614..0da74141a1a7c 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -2523,6 +2523,14 @@ EmittedMembersRequest::evaluate(Evaluator &evaluator, forceConformance(Context.getProtocol(KnownProtocolKind::Hashable)); forceConformance(Context.getProtocol(KnownProtocolKind::Differentiable)); + // If the class conforms to Encodable or Decodable, even via an extension, + // the CodingKeys enum is synthesized as a member of the type itself. + // Force it into existence. + (void) evaluateOrDefault(Context.evaluator, + ResolveImplicitMemberRequest{CD, + ImplicitMemberAction::ResolveCodingKeys}, + {}); + // If the class has a @main attribute, we need to force synthesis of the // $main function. (void) evaluateOrDefault(Context.evaluator, diff --git a/lib/Sema/TypeCheckStorage.cpp b/lib/Sema/TypeCheckStorage.cpp index 64448c4505e1a..47817266887a4 100644 --- a/lib/Sema/TypeCheckStorage.cpp +++ b/lib/Sema/TypeCheckStorage.cpp @@ -1659,7 +1659,7 @@ synthesizeSetterBody(AccessorDecl *setter, ASTContext &ctx) { case WriteImplKind::Modify: return synthesizeModifyCoroutineSetterBody(setter, ctx); } - llvm_unreachable("bad ReadImplKind"); + llvm_unreachable("bad WriteImplKind"); } static std::pair @@ -1940,7 +1940,75 @@ static AccessorDecl *createSetterPrototype(AbstractStorageDecl *storage, // All mutable storage requires a setter. assert(storage->requiresOpaqueAccessor(AccessorKind::Set)); + + // Copy availability from the accessor we'll synthesize the setter from. + SmallVector asAvailableAs; + + // That could be a property wrapper... 
+ if (auto var = dyn_cast(storage)) { + if (var->hasAttachedPropertyWrapper()) { + // The property wrapper info may not actually link back to a wrapper + // implementation, if there was a semantic error checking the wrapper. + auto info = var->getAttachedPropertyWrapperTypeInfo(0); + if (info.valueVar) { + if (auto setter = info.valueVar->getOpaqueAccessor(AccessorKind::Set)) { + asAvailableAs.push_back(setter); + } + } + } else if (auto wrapperSynthesizedKind + = var->getPropertyWrapperSynthesizedPropertyKind()) { + switch (*wrapperSynthesizedKind) { + case PropertyWrapperSynthesizedPropertyKind::Backing: + break; + + case PropertyWrapperSynthesizedPropertyKind::StorageWrapper: { + if (auto origVar = var->getOriginalWrappedProperty(wrapperSynthesizedKind)) { + // The property wrapper info may not actually link back to a wrapper + // implementation, if there was a semantic error checking the wrapper. + auto info = origVar->getAttachedPropertyWrapperTypeInfo(0); + if (info.projectedValueVar) { + if (auto setter + = info.projectedValueVar->getOpaqueAccessor(AccessorKind::Set)){ + asAvailableAs.push_back(setter); + } + } + } + break; + } + } + } + } + + // ...or another accessor. + switch (storage->getWriteImpl()) { + case WriteImplKind::Immutable: + llvm_unreachable("synthesizing setter from immutable storage"); + case WriteImplKind::Stored: + case WriteImplKind::StoredWithObservers: + case WriteImplKind::InheritedWithObservers: + case WriteImplKind::Set: + // Setter's availability shouldn't be externally influenced in these + // cases. 
+ break; + + case WriteImplKind::MutableAddress: + if (auto addr = storage->getOpaqueAccessor(AccessorKind::MutableAddress)) { + asAvailableAs.push_back(addr); + } + break; + case WriteImplKind::Modify: + if (auto mod = storage->getOpaqueAccessor(AccessorKind::Modify)) { + asAvailableAs.push_back(mod); + } + break; + } + + if (!asAvailableAs.empty()) { + AvailabilityInference::applyInferredAvailableAttrs( + setter, asAvailableAs, ctx); + } + finishImplicitAccessor(setter, ctx); return setter; diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 2ed57c9f0794b..86565d2ee9314 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1263,6 +1263,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, ONEOPERAND_ONETYPE_INST(UncheckedAddrCast) ONEOPERAND_ONETYPE_INST(UncheckedTrivialBitCast) ONEOPERAND_ONETYPE_INST(UncheckedBitwiseCast) + ONEOPERAND_ONETYPE_INST(UncheckedValueCast) ONEOPERAND_ONETYPE_INST(BridgeObjectToRef) ONEOPERAND_ONETYPE_INST(BridgeObjectToWord) ONEOPERAND_ONETYPE_INST(Upcast) diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 88c11081e8905..9bf7c6b8e4647 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 562; // base_addr_for_offset instruction +const uint16_t SWIFTMODULE_VERSION_MINOR = 563; // unchecked_value_cast /// A standard hash seed used for all string hashes in a serialized module. 
/// diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index af20a44e2c83a..285db801310ff 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -1557,6 +1557,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { case SILInstructionKind::UncheckedAddrCastInst: case SILInstructionKind::UncheckedTrivialBitCastInst: case SILInstructionKind::UncheckedBitwiseCastInst: + case SILInstructionKind::UncheckedValueCastInst: case SILInstructionKind::BridgeObjectToRefInst: case SILInstructionKind::BridgeObjectToWordInst: case SILInstructionKind::UpcastInst: diff --git a/stdlib/public/SwiftShims/LibcShims.h b/stdlib/public/SwiftShims/LibcShims.h index 9c30303ef2094..6abe6880d8ff8 100644 --- a/stdlib/public/SwiftShims/LibcShims.h +++ b/stdlib/public/SwiftShims/LibcShims.h @@ -32,26 +32,6 @@ extern "C" { #endif -// This declaration might not be universally correct. -// We verify its correctness for the current platform in the runtime code. 
-#if defined(__linux__) -# if defined(__ANDROID__) && !(defined(__aarch64__) || defined(__x86_64__)) -typedef __swift_uint16_t __swift_mode_t; -# else -typedef __swift_uint32_t __swift_mode_t; -# endif -#elif defined(__APPLE__) -typedef __swift_uint16_t __swift_mode_t; -#elif defined(_WIN32) -typedef __swift_int32_t __swift_mode_t; -#elif defined(__wasi__) -typedef __swift_uint32_t __swift_mode_t; -#elif defined(__OpenBSD__) -typedef __swift_uint32_t __swift_mode_t; -#else // just guessing -typedef __swift_uint16_t __swift_mode_t; -#endif - // Input/output SWIFT_RUNTIME_STDLIB_INTERNAL diff --git a/stdlib/public/stubs/LibcShims.cpp b/stdlib/public/stubs/LibcShims.cpp index e47222910b5dd..092581f8cedc6 100644 --- a/stdlib/public/stubs/LibcShims.cpp +++ b/stdlib/public/stubs/LibcShims.cpp @@ -31,11 +31,6 @@ #include "../SwiftShims/LibcShims.h" -#if !defined(_WIN32) || defined(__CYGWIN__) -static_assert(std::is_same::value, - "__swift_mode_t must be defined as equivalent to mode_t in LibcShims.h"); -#endif - SWIFT_RUNTIME_STDLIB_INTERNAL int _swift_stdlib_putchar_unlocked(int c) { #if defined(_WIN32) diff --git a/test/DebugInfo/local-vars.swift.gyb b/test/DebugInfo/local-vars.swift.gyb index 611ef2270dc12..4a3be6bb1c72d 100644 --- a/test/DebugInfo/local-vars.swift.gyb +++ b/test/DebugInfo/local-vars.swift.gyb @@ -82,10 +82,11 @@ type_initializer_map = [ ] import re -def derive_name((type, val)): +def derive_name(type, val): return (re.sub(r'[<> ,\.\?\(\)\[\]-]', '_', type), type, val) -for name, type, val in map(derive_name, type_initializer_map): +for (type, val) in type_initializer_map: + (name, type, val) = derive_name(type, val) generic = (type in ['T', 'U']) function = "->" in type }% diff --git a/test/Driver/linker.swift b/test/Driver/linker.swift index 339a3b964540d..e5d36c3c6bb59 100644 --- a/test/Driver/linker.swift +++ b/test/Driver/linker.swift @@ -1,5 +1,6 @@ // Must be able to run xcrun-return-self.sh // REQUIRES: shell +// REQUIRES: rdar65281056 // RUN: 
%swiftc_driver -sdk "" -driver-print-jobs -target x86_64-apple-macosx10.9 %s 2>&1 > %t.simple.txt // RUN: %FileCheck %s < %t.simple.txt // RUN: %FileCheck -check-prefix SIMPLE %s < %t.simple.txt diff --git a/test/Driver/loaded_module_trace_spi.swift b/test/Driver/loaded_module_trace_spi.swift new file mode 100644 index 0000000000000..1cf183e6f648f --- /dev/null +++ b/test/Driver/loaded_module_trace_spi.swift @@ -0,0 +1,20 @@ +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %s -emit-module -module-name Module0 -DSPI_DECLARING_MODULE -o %t/Module0.swiftmodule +// RUN: %target-build-swift %s -module-name Module1 -DSPI_IMPORTING_MODULE -emit-loaded-module-trace -o %t/spi_import -I %t +// RUN: %FileCheck %s < %t/Module1.trace.json + +#if SPI_DECLARING_MODULE + +@_spi(ForModule1) +public func f() {} + +#endif + +#if SPI_IMPORTING_MODULE + +@_spi(ForModule1) +import Module0 + +#endif + +// CHECK: "name":"Module0","path":"{{[^"]*}}","isImportedDirectly":true diff --git a/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-1distinct_use.swift index 5eef993d54324..75f5f86c789c5 100644 --- a/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-1distinct_use.swift @@ -75,7 +75,25 @@ doit() // CHECK: [[ARGUMENT_BUFFER:%[0-9]+]] = bitcast i8* [[ERASED_TABLE]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = load i8*, i8** [[ARGUMENT_BUFFER]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR]], %swift.protocol_conformance_descriptor* 
@"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS]] // CHECK: br i1 [[EQUAL_ARGUMENTS]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git 
a/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-public-1distinct_use.swift b/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-public-1distinct_use.swift index f41f665832d95..06214ca5a0df9 100644 --- a/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-public-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/enum-inmodule-1argument-1conformance-public-1distinct_use.swift @@ -75,7 +75,25 @@ doit() // CHECK: [[ARGUMENT_BUFFER:%[0-9]+]] = bitcast i8* [[ERASED_TABLE]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = load i8*, i8** [[ARGUMENT_BUFFER]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR]], %swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED:%[0-9]+]] = inttoptr i64 
[[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS]] // CHECK: br i1 [[EQUAL_ARGUMENTS]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-1conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-1conformance-1distinct_use.swift index ef64635e40c72..9e811bba493b3 100644 --- a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-1conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-1conformance-1distinct_use.swift @@ -58,7 +58,25 @@ doit() // CHECK: [[ARGUMENT_BUFFER:%[0-9]+]] = bitcast i8* [[ERASED_TABLE]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = load i8*, i8** [[ARGUMENT_BUFFER]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR]], 
%swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS]] // CHECK: br i1 [[EQUAL_ARGUMENTS]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git 
a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-2conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-2conformance-1distinct_use.swift index 719f33d6371d3..18faf8c91f44a 100644 --- a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-2conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-2conformance-1distinct_use.swift @@ -63,12 +63,48 @@ doit() // CHECK: [[UNERASED_TABLE_1:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_1]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_1]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_1]], %swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_1:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_1]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_1:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_1]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_1]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1:%[0-9]+]] = 
inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_1:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS_1]] // CHECK: [[UNERASED_TABLE_2:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_2]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_2]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_2]], %swift.protocol_conformance_descriptor* @"$sSi4main1QAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_2:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_2]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_2:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_2]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_2]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: 
[[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1QAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_2:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_1]], [[EQUAL_DESCRIPTORS_2]] // CHECK: br i1 [[EQUAL_ARGUMENTS_2]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-3conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-3conformance-1distinct_use.swift index 0e1974e1df5c3..43342719fb133 100644 --- a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-3conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-3conformance-1distinct_use.swift @@ -69,21 +69,75 @@ doit() // CHECK: [[UNERASED_TABLE_1:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_1]] to i8** // CHECK: 
[[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_1]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_1]], %swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_1:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_1]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_1:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_1]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_1]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: 
[[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_1:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS_1]] // CHECK: [[POINTER_TO_ERASED_TABLE_2:%[0-9]+]] = getelementptr i8*, i8** %1, i64 2 // CHECK: [[ERASED_TABLE_2:%"load argument at index 2 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_2]], align 1 // CHECK: [[UNERASED_TABLE_2:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_2]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_2]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_2]], %swift.protocol_conformance_descriptor* @"$sSi4main1QAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_2:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_2]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_2:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_2]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_2]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 
[[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1QAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_2:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_1]], [[EQUAL_DESCRIPTORS_2]] // CHECK: [[POINTER_TO_ERASED_TABLE_3:%[0-9]+]] = getelementptr i8*, i8** %1, i64 3 // CHECK: [[ERASED_TABLE_3:%"load argument at index 3 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_3]], align 1 // CHECK: [[UNERASED_TABLE_3:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_3]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_3]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_3]], %swift.protocol_conformance_descriptor* @"$sSi4main1RAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_3:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_3]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_3:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 
[[ERASED_TABLE_INT_3]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_3]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1RAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_3:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_2]], [[EQUAL_DESCRIPTORS_3]] // CHECK: br i1 [[EQUAL_ARGUMENTS_3]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-4conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-4conformance-1distinct_use.swift 
index d59b500546590..2a680000ea99e 100644 --- a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-4conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-4conformance-1distinct_use.swift @@ -73,28 +73,100 @@ doit() // CHECK: [[UNERASED_TABLE_1:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_1]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_1]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_1]], %swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_1:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_1]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_1:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_1]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_1]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 
@swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_1:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS_1]] // CHECK: [[POINTER_TO_ERASED_TABLE_2:%[0-9]+]] = getelementptr i8*, i8** %1, i64 2 // CHECK: [[ERASED_TABLE_2:%"load argument at index 2 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_2]], align 1 // CHECK: [[UNERASED_TABLE_2:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_2]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_2]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_2]], %swift.protocol_conformance_descriptor* @"$sSi4main1QAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_2:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_2]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_2:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_2]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_2]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2]] to 
%swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1QAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_2:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_1]], [[EQUAL_DESCRIPTORS_2]] // CHECK: [[POINTER_TO_ERASED_TABLE_3:%[0-9]+]] = getelementptr i8*, i8** %1, i64 3 // CHECK: [[ERASED_TABLE_3:%"load argument at index 3 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_3]], align 1 // CHECK: [[UNERASED_TABLE_3:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_3]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_3]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 
@swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_3]], %swift.protocol_conformance_descriptor* @"$sSi4main1RAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_3:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_3]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_3:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_3]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_3]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1RAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_3:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_2]], 
[[EQUAL_DESCRIPTORS_3]] // CHECK: [[POINTER_TO_ERASED_TABLE_4:%[0-9]+]] = getelementptr i8*, i8** %1, i64 4 // CHECK: [[ERASED_TABLE_4:%"load argument at index 4 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_4]], align 1 // CHECK: [[UNERASED_TABLE_4:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_4]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_4:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_4]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_4:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_4]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_4:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_4]], %swift.protocol_conformance_descriptor* @"$sSi4main1SAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_4:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_4]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_4:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_4]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_4:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_4]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_4:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_4]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_4:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_4]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_4:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_4]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_4:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_4]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_4:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: 
%swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_4]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1SAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_4:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_3]], [[EQUAL_DESCRIPTORS_4]] // CHECK: br i1 [[EQUAL_ARGUMENTS_4]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-5conformance-1distinct_use.swift b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-5conformance-1distinct_use.swift index 0f883e478e94f..3bfa8c5834da9 100644 --- a/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-5conformance-1distinct_use.swift +++ b/test/IRGen/prespecialized-metadata/struct-inmodule-1argument-5conformance-1distinct_use.swift @@ -77,35 +77,125 @@ doit() // CHECK: [[UNERASED_TABLE_1:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_1]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_1]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_1:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_1]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_1]], %swift.protocol_conformance_descriptor* @"$sSi4main1PAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_1:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_1]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_1:%[0-9]+]] = call i64 
@llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_1]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_1]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_1]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_1]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_1]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_1]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_1:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_1]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1PAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_1:%[0-9]+]] = and i1 [[EQUAL_TYPES]], [[EQUAL_DESCRIPTORS_1]] // CHECK: [[POINTER_TO_ERASED_TABLE_2:%[0-9]+]] = getelementptr i8*, i8** %1, i64 2 // CHECK: [[ERASED_TABLE_2:%"load argument at index 2 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_2]], align 1 // CHECK: [[UNERASED_TABLE_2:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_2]] to i8** // CHECK: 
[[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_2]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_2:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_2]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_2]], %swift.protocol_conformance_descriptor* @"$sSi4main1QAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_2:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_2]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_2:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_2]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_2]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_2]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_2]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_2]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_2]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_2:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_2]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: 
[[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1QAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_2:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_1]], [[EQUAL_DESCRIPTORS_2]] // CHECK: [[POINTER_TO_ERASED_TABLE_3:%[0-9]+]] = getelementptr i8*, i8** %1, i64 3 // CHECK: [[ERASED_TABLE_3:%"load argument at index 3 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_3]], align 1 // CHECK: [[UNERASED_TABLE_3:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_3]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_3]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_3:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_3]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_3]], %swift.protocol_conformance_descriptor* @"$sSi4main1RAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_3:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_3]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_3:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_3]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_3]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_3]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_3]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 
[[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_3]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_3]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_3:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_3]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1RAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_3:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_2]], [[EQUAL_DESCRIPTORS_3]] // CHECK: [[POINTER_TO_ERASED_TABLE_4:%[0-9]+]] = getelementptr i8*, i8** %1, i64 4 // CHECK: [[ERASED_TABLE_4:%"load argument at index 4 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_4]], align 1 // CHECK: [[UNERASED_TABLE_4:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_4]] to i8** // CHECK: [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_4:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_4]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_4:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_4]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_4:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_4]], %swift.protocol_conformance_descriptor* @"$sSi4main1SAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_4:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_4]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_4:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 
[[ERASED_TABLE_INT_4]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_4:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_4]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_4:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_4]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_4:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_4]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_4:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_4]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_4:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_4]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_4:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_4]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1SAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_4:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_3]], [[EQUAL_DESCRIPTORS_4]] // CHECK: [[POINTER_TO_ERASED_TABLE_5:%[0-9]+]] = getelementptr i8*, i8** %1, i64 5 // CHECK: [[ERASED_TABLE_5:%"load argument at index 5 from buffer"]] = load i8*, i8** [[POINTER_TO_ERASED_TABLE_5]], align 1 // CHECK: [[UNERASED_TABLE_5:%[0-9]+]] = bitcast i8* [[ERASED_TABLE_5]] to i8** // CHECK: 
[[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_5:%[0-9]+]] = load i8*, i8** [[UNERASED_TABLE_5]], align 1 // CHECK: [[PROVIDED_PROTOCOL_DESCRIPTOR_5:%[0-9]+]] = bitcast i8* [[UNCAST_PROVIDED_PROTOCOL_DESCRIPTOR_5]] to %swift.protocol_conformance_descriptor* -// CHECK: [[EQUAL_DESCRIPTORS_5:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors(%swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_5]], %swift.protocol_conformance_descriptor* @"$sSi4main1TAAMc") +// CHECK-arm64e: [[ERASED_TABLE_INT_5:%[0-9]+]] = ptrtoint i8* [[ERASED_TABLE_5]] to i64 +// CHECK-arm64e: [[TABLE_SIGNATURE_5:%[0-9]+]] = call i64 @llvm.ptrauth.blend.i64(i64 [[ERASED_TABLE_INT_5]], i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_5:%[0-9]+]] = call i64 @llvm.ptrauth.auth.i64(i64 %13, i32 2, i64 [[TABLE_SIGNATURE_5]]) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_5:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_5]] to %swift.protocol_conformance_descriptor* +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_5:%[0-9]+]] = ptrtoint %swift.protocol_conformance_descriptor* [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_5]] to i64 +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_5:%[0-9]+]] = call i64 @llvm.ptrauth.sign.i64(i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_AUTHED_PTR_INT_5]], i32 2, i64 50923) +// CHECK-arm64e: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_5:%[0-9]+]] = inttoptr i64 [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_INT_5]] to %swift.protocol_conformance_descriptor* +// CHECK: [[EQUAL_DESCRIPTORS_5:%[0-9]+]] = call swiftcc i1 @swift_compareProtocolConformanceDescriptors( +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-arm64e-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR_SIGNED_5]], +// CHECK-i386-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-x86_64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7s-SAME: 
[[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-armv7k-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-arm64-SAME: [[PROVIDED_PROTOCOL_DESCRIPTOR]], +// CHECK-SAME: %swift.protocol_conformance_descriptor* +// CHECK-SAME: $sSi4main1TAAMc +// CHECK-SAME: ) // CHECK: [[EQUAL_ARGUMENTS_5:%[0-9]+]] = and i1 [[EQUAL_ARGUMENTS_4]], [[EQUAL_DESCRIPTORS_5]] // CHECK: br i1 [[EQUAL_ARGUMENTS_5]], label %[[EXIT_PRESPECIALIZED:[0-9]+]], label %[[EXIT_NORMAL:[0-9]+]] // CHECK: [[EXIT_PRESPECIALIZED]]: diff --git a/test/Interop/Cxx/class/Inputs/module.modulemap b/test/Interop/Cxx/class/Inputs/module.modulemap index d4460da89db64..24cde705fcde9 100644 --- a/test/Interop/Cxx/class/Inputs/module.modulemap +++ b/test/Interop/Cxx/class/Inputs/module.modulemap @@ -2,8 +2,8 @@ module AccessSpecifiers { header "access-specifiers.h" } -module LoadableTypes { - header "loadable-types.h" +module TypeClassification { + header "type-classification.h" } module MemberwiseInitializer { diff --git a/test/Interop/Cxx/class/Inputs/loadable-types.h b/test/Interop/Cxx/class/Inputs/type-classification.h similarity index 99% rename from test/Interop/Cxx/class/Inputs/loadable-types.h rename to test/Interop/Cxx/class/Inputs/type-classification.h index 5f92b6ef7d089..785c5ef278a73 100644 --- a/test/Interop/Cxx/class/Inputs/loadable-types.h +++ b/test/Interop/Cxx/class/Inputs/type-classification.h @@ -80,7 +80,7 @@ struct StructWithSubobjectMoveAssignment { }; struct StructWithDestructor { - ~StructWithDestructor(){} + ~StructWithDestructor() {} }; struct StructWithInheritedDestructor : StructWithDestructor {}; diff --git a/test/Interop/Cxx/class/loadable-types-silgen.swift b/test/Interop/Cxx/class/type-classification-loadable-silgen.swift similarity index 99% rename from test/Interop/Cxx/class/loadable-types-silgen.swift rename to test/Interop/Cxx/class/type-classification-loadable-silgen.swift index 7611f5d73d747..9681aa3d9b333 100644 --- a/test/Interop/Cxx/class/loadable-types-silgen.swift +++ 
b/test/Interop/Cxx/class/type-classification-loadable-silgen.swift @@ -3,7 +3,7 @@ // This test checks that we classify C++ types as loadable and address-only // correctly. -import LoadableTypes +import TypeClassification // Tests for individual special members diff --git a/test/Interop/Cxx/class/type-classification-non-trivial-silgen.swift b/test/Interop/Cxx/class/type-classification-non-trivial-silgen.swift new file mode 100644 index 0000000000000..a4ba0bdaccbe0 --- /dev/null +++ b/test/Interop/Cxx/class/type-classification-non-trivial-silgen.swift @@ -0,0 +1,25 @@ +// RUN: %target-swift-frontend -I %S/Inputs -enable-cxx-interop -emit-sil %s | %FileCheck %s + +import TypeClassification + +// Make sure that "StructWithDestructor" is marked as non-trivial by checking for a +// "destroy_addr". +// CHECK-LABEL: @$s4main24testStructWithDestructoryyF +// CHECK: [[AS:%.*]] = alloc_stack $StructWithDestructor +// CHECK: destroy_addr [[AS]] +// CHECK-LABEL: end sil function '$s4main24testStructWithDestructoryyF' +public func testStructWithDestructor() { + let d = StructWithDestructor() +} + +// Make sure that "StructWithSubobjectDestructor" is marked as non-trivial by checking +// for a "destroy_addr". 
+// CHECK-LABEL: @$s4main33testStructWithSubobjectDestructoryyF +// CHECK: [[AS:%.*]] = alloc_stack $StructWithSubobjectDestructor +// CHECK: destroy_addr [[AS]] +// CHECK-LABEL: end sil function '$s4main33testStructWithSubobjectDestructoryyF' +public func testStructWithSubobjectDestructor() { + let d = StructWithSubobjectDestructor() +} + + diff --git a/test/ModuleInterface/swift_build_sdk_interfaces/compiler-crash-windows.test-sh b/test/ModuleInterface/swift_build_sdk_interfaces/compiler-crash-windows.test-sh index 1c0177dadb14d..022ffea70e1a5 100644 --- a/test/ModuleInterface/swift_build_sdk_interfaces/compiler-crash-windows.test-sh +++ b/test/ModuleInterface/swift_build_sdk_interfaces/compiler-crash-windows.test-sh @@ -4,7 +4,9 @@ # See rdar://59397376 REQUIRES: OS=windows-msvc -RUN: env SWIFT_EXEC=%swiftc_driver_plain not --crash %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/mock-sdk/ -o %t/output -debug-crash-compiler 2>&1 | %FileCheck %s +RUN: %empty-directory(%t) +RUN: env SWIFT_EXEC=%swiftc_driver_plain not --crash %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/mock-sdk/ -o %t/output -debug-crash-compiler -log-path %t +RUN: %FileCheck %s < %t/Flat-Flat-err.txt CHECK: Program arguments: CHECK-SAME: -debug-crash-immediately diff --git a/test/ModuleInterface/swift_build_sdk_interfaces/xfails.test-sh b/test/ModuleInterface/swift_build_sdk_interfaces/xfails.test-sh index d31cdcec7e789..98f8458bab8f8 100644 --- a/test/ModuleInterface/swift_build_sdk_interfaces/xfails.test-sh +++ b/test/ModuleInterface/swift_build_sdk_interfaces/xfails.test-sh @@ -1,8 +1,8 @@ RUN: %empty-directory(%t) -RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -o %t/output 2> %t/stderr.txt | %FileCheck %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt -RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py 
%mcp_opt -sdk %S/Inputs/xfails-sdk/ -v -o %t/output 2> %t/stderr.txt | %FileCheck -check-prefix CHECK-VERBOSE %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt +RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -o %t/output -log-path %t | %FileCheck %s +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt +RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -log-path %t | %FileCheck -check-prefix CHECK-VERBOSE %s +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt CHECK: # (FAIL) {{.+}}{{\\|/}}Bad.swiftinterface CHECK-VERBOSE-DAG: # (FAIL) {{.+}}{{\\|/}}Bad.swiftinterface @@ -13,10 +13,10 @@ HIDES-ERROR-NOT: cannot find 'garbage' in scope RUN: %empty-directory(%t) RUN: echo '["Good"]' > %t/xfails-good.json -RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -o %t/output -xfails %t/xfails-good.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-GOOD %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt +RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -o %t/output -xfails %t/xfails-good.json -log-path %t | %FileCheck -check-prefix=CHECK-XFAIL-GOOD %s +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt RUN: env SWIFT_EXEC=%swiftc_driver_plain not %{python} %utils/swift_build_sdk_interfaces.py %mcp_opt -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -xfails %t/xfails-good.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-GOOD %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt CHECK-XFAIL-GOOD-DAG: # (FAIL) {{.+}}{{\\|/}}Bad.swiftinterface CHECK-XFAIL-GOOD-DAG: # (UPASS) 
{{.+}}{{\\|/}}Good.swiftinterface @@ -25,8 +25,8 @@ RUN: %empty-directory(%t) RUN: echo '["Good", "Bad"]' > %t/xfails-good-and-bad.json RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -o %t/output -xfails %t/xfails-good-and-bad.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-GOOD-AND-BAD %s RUN: %FileCheck -check-prefix HIDES-ERROR -allow-empty %s < %t/stderr.txt -RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -xfails %t/xfails-good-and-bad.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-GOOD-AND-BAD %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt +RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -xfails %t/xfails-good-and-bad.json -log-path %t | %FileCheck -check-prefix=CHECK-XFAIL-GOOD-AND-BAD %s +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt CHECK-XFAIL-GOOD-AND-BAD-DAG: # (XFAIL) {{.+}}{{\\|/}}Bad.swiftinterface CHECK-XFAIL-GOOD-AND-BAD-DAG: # (UPASS) {{.+}}{{\\|/}}Good.swiftinterface @@ -35,8 +35,8 @@ RUN: %empty-directory(%t) RUN: echo '["Bad"]' > %t/xfails-bad.json RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -o %t/output -xfails %t/xfails-bad.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-BAD %s RUN: %FileCheck -check-prefix HIDES-ERROR -allow-empty %s < %t/stderr.txt -RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -xfails %t/xfails-bad.json 2> %t/stderr.txt | %FileCheck -check-prefix=CHECK-XFAIL-BAD-VERBOSE %s -RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/stderr.txt +RUN: %swift_build_sdk_interfaces -sdk %S/Inputs/xfails-sdk/ -v -o %t/output -xfails %t/xfails-bad.json -log-path %t | %FileCheck -check-prefix=CHECK-XFAIL-BAD-VERBOSE %s +RUN: %FileCheck -check-prefix PRINTS-ERROR %s < %t/Bad-Bad-err.txt CHECK-XFAIL-BAD: # (XFAIL) {{.+}}{{\\|/}}Bad.swiftinterface CHECK-XFAIL-BAD-VERBOSE-DAG: # (XFAIL) {{.+}}{{\\|/}}Bad.swiftinterface diff --git 
a/test/SIL/ownership-verifier/use_verifier.sil index 25e9865360868..e0ea189cea663 100644 --- a/test/SIL/ownership-verifier/use_verifier.sil +++ b/test/SIL/ownership-verifier/use_verifier.sil @@ -1307,3 +1307,24 @@ bb0(%0 : @guaranteed $ClassProtConformingRef, %1 : @owned $ClassProtConformingRe %5 = tuple(%3 : $ClassProt, %4 : $ClassProt) return %5 : $(ClassProt, ClassProt) } + +// Make sure that a generic struct that has an AnyObject constraint can be cast +// successfully. + +struct GenericS { + var t1: T + var t2: T +} + +sil [ossa] @unchecked_value_cast_forwarding : $@convention(thin) (@owned GenericS) -> @owned GenericS { +bb0(%0 : @owned $GenericS): + %1 = unchecked_value_cast %0 : $GenericS to $GenericS + return %1 : $GenericS +} + +sil [ossa] @unchecked_value_cast_forwarding_2 : $@convention(thin) (@guaranteed GenericS) -> @owned GenericS { +bb0(%0 : @guaranteed $GenericS): + %1 = unchecked_value_cast %0 : $GenericS to $GenericS + %2 = copy_value %1 : $GenericS + return %2 : $GenericS +} diff --git a/test/SILGen/specialize_attr.swift b/test/SILGen/specialize_attr.swift index b1816042ed18a..33185d824ba9c 100644 --- a/test/SILGen/specialize_attr.swift +++ b/test/SILGen/specialize_attr.swift @@ -1,10 +1,19 @@ - // RUN: %target-swift-emit-silgen -module-name specialize_attr -emit-verbose-sil %s | %FileCheck %s +// RUN: %target-swift-emit-sil -sil-verify-all -O -module-name specialize_attr -emit-verbose-sil %s | %FileCheck -check-prefix=CHECK-OPT %s -// CHECK-LABEL: @_specialize(exported: false, kind: full, where T == Int, U == Float) +// CHECK: @_specialize(exported: false, kind: full, where T == Int, U == Float) +// CHECK-NEXT: @_specialize(exported: false, kind: full, where T == Klass1, U == FakeString) // CHECK-NEXT: func specializeThis(_ t: T, u: U) @_specialize(where T == Int, U == Float) -func specializeThis(_ t: T, u: U) {} +@_specialize(where T == Klass1, U == FakeString) +public func specializeThis(_ t: T, 
u: U) {} + +// CHECK: @_specialize(exported: false, kind: full, where T == Klass1, U == FakeString) +// CHECK-NEXT: func specializeThis2(_ t: __owned T, u: __owned U) -> (T, U) +@_specialize(where T == Klass1, U == FakeString) +public func specializeThis2(_ t: __owned T, u: __owned U) -> (T, U) { + (t, u) +} public protocol PP { associatedtype PElt @@ -20,22 +29,88 @@ public struct SS : QQ { public typealias QElt = Int } -public struct GG {} +public class Klass1 {} +public class FakeStringData {} +public struct FakeString { + var f: FakeStringData? = nil +} + +public struct RRNonTrivial : PP { + public typealias PElt = Klass1 + var elt: Klass1? = nil +} +public struct SSNonTrivial : QQ { + public typealias QElt = FakeString + var elt: FakeString? = nil +} + +public struct GG { + var t: T +} // CHECK-LABEL: public class CC where T : PP { -// CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RR, U == SS) -// CHECK-NEXT: @inline(never) public func foo(_ u: U, g: GG) -> (U, GG) where U : QQ public class CC { + // CHECK-NEXT: var k: Klass1? + // To make non-trivial + var k: Klass1? 
= nil + + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RR, U == SS) + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RRNonTrivial, U == SSNonTrivial) + // CHECK-NEXT: @inline(never) public func foo(_ u: U, g: GG) -> (U, GG) where U : QQ @inline(never) @_specialize(where T == RR, U == SS) + @_specialize(where T == RRNonTrivial, U == SSNonTrivial) public func foo(_ u: U, g: GG) -> (U, GG) { return (u, g) } + + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RR, U == SS) + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RRNonTrivial, U == SSNonTrivial) + // CHECK-NEXT: @inline(never) public func foo2(_ u: __owned U, g: __owned GG) -> (U, GG) where U : QQ + @inline(never) + @_specialize(where T == RR, U == SS) + @_specialize(where T == RRNonTrivial, U == SSNonTrivial) + public func foo2(_ u: __owned U, g: __owned GG) -> (U, GG) { + return (u, g) + } } -// CHECK-LABEL: sil hidden [_specialize exported: false, kind: full, where T == Int, U == Float] [ossa] @$s15specialize_attr0A4This_1uyx_q_tr0_lF : $@convention(thin) (@in_guaranteed T, @in_guaranteed U) -> () { +// CHECK-LABEL: public struct CC2 where T : PP { +public struct CC2 { + // CHECK-NEXT: var k: Klass1? + // To make non-trivial + var k: Klass1? 
= nil -// CHECK-LABEL: sil [noinline] [_specialize exported: false, kind: full, where T == RR, U == SS] [ossa] @$s15specialize_attr2CCC3foo_1gqd___AA2GGVyxGtqd___AHtAA2QQRd__lF : $@convention(method) (@in_guaranteed U, GG, @guaranteed CC) -> (@out U, GG) { + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RR, U == SS) + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RRNonTrivial, U == SSNonTrivial) + // CHECK-NEXT: @inline(never) public mutating func foo(_ u: U, g: GG) -> (U, GG) where U : QQ + @inline(never) + @_specialize(where T == RR, U == SS) + @_specialize(where T == RRNonTrivial, U == SSNonTrivial) + mutating public func foo(_ u: U, g: GG) -> (U, GG) { + return (u, g) + } + + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RR, U == SS) + // CHECK-NEXT: @_specialize(exported: false, kind: full, where T == RRNonTrivial, U == SSNonTrivial) + // CHECK-NEXT: @inline(never) public mutating func foo2(_ u: __owned U, g: __owned GG) -> (U, GG) where U : QQ + @inline(never) + @_specialize(where T == RR, U == SS) + @_specialize(where T == RRNonTrivial, U == SSNonTrivial) + mutating public func foo2(_ u: __owned U, g: __owned GG) -> (U, GG) { + return (u, g) + } +} + +// CHECK-LABEL: sil [_specialize exported: false, kind: full, where T == Klass1, U == FakeString] [_specialize exported: false, kind: full, where T == Int, U == Float] [ossa] @$s15specialize_attr0A4This_1uyx_q_tr0_lF : $@convention(thin) (@in_guaranteed T, @in_guaranteed U) -> () { + +// CHECK-OPT-LABEL: sil shared [noinline] @$s15specialize_attr2CCC3foo_1gqd___AA2GGVyxGtqd___AHtAA2QQRd__lFAA12RRNonTrivialV_AA05SSNonH0VTg5 : $@convention(method) (@guaranteed SSNonTrivial, @guaranteed GG, @guaranteed CC) -> (@owned SSNonTrivial, @out GG) { + +// CHECK-OPT-LABEL: sil shared [noinline] @$s15specialize_attr2CCC4foo2_1gqd___AA2GGVyxGtqd__n_AHntAA2QQRd__lFAA2RRV_AA2SSVTg5 : $@convention(method) (SS, GG, @guaranteed CC) -> (SS, @out GG) { + +// 
CHECK-OPT-LABEL: sil [noinline] @$s15specialize_attr2CCC4foo2_1gqd___AA2GGVyxGtqd__n_AHntAA2QQRd__lF : $@convention(method) (@in U, @in GG, @guaranteed CC) -> (@out U, @out GG) { + +// CHECK-LABEL: sil [noinline] [_specialize exported: false, kind: full, where T == RRNonTrivial, U == SSNonTrivial] [_specialize exported: false, kind: full, where T == RR, U == SS] [ossa] @$s15specialize_attr2CCC3foo_1gqd___AA2GGVyxGtqd___AHtAA2QQRd__lF : $@convention(method) (@in_guaranteed U, @in_guaranteed GG, @guaranteed CC) -> (@out U, @out GG) { // ----------------------------------------------------------------------------- // Test user-specialized subscript accessors. @@ -54,10 +129,13 @@ public class ASubscriptable : TestSubscriptable { public subscript(i: Int) -> Element { @_specialize(where Element == Int) + @_specialize(where Element == Klass1) get { return storage[i] } + @_specialize(where Element == Int) + @_specialize(where Element == Klass1) set(rhs) { storage[i] = rhs } @@ -65,10 +143,10 @@ public class ASubscriptable : TestSubscriptable { } // ASubscriptable.subscript.getter with _specialize -// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr14ASubscriptableCyxSicig : $@convention(method) (Int, @guaranteed ASubscriptable) -> @out Element { +// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Klass1] [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr14ASubscriptableCyxSicig : $@convention(method) (Int, @guaranteed ASubscriptable) -> @out Element { // ASubscriptable.subscript.setter with _specialize -// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr14ASubscriptableCyxSicis : $@convention(method) (@in Element, Int, @guaranteed ASubscriptable) -> () { +// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Klass1] [_specialize exported: false, kind: full, 
where Element == Int] [ossa] @$s15specialize_attr14ASubscriptableCyxSicis : $@convention(method) (@in Element, Int, @guaranteed ASubscriptable) -> () { // ASubscriptable.subscript.modify with no attribute // CHECK-LABEL: sil [transparent] [serialized] [ossa] @$s15specialize_attr14ASubscriptableCyxSiciM : $@yield_once @convention(method) (Int, @guaranteed ASubscriptable) -> @yields @inout Element { @@ -82,10 +160,13 @@ public class Addressable : TestSubscriptable { public subscript(i: Int) -> Element { @_specialize(where Element == Int) + @_specialize(where Element == Klass1) unsafeAddress { return UnsafePointer(storage + i) } + @_specialize(where Element == Int) + @_specialize(where Element == Klass1) unsafeMutableAddress { return UnsafeMutablePointer(storage + i) } @@ -93,10 +174,10 @@ public class Addressable : TestSubscriptable { } // Addressable.subscript.unsafeAddressor with _specialize -// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr11AddressableCyxSicilu : $@convention(method) (Int, @guaranteed Addressable) -> UnsafePointer { +// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Klass1] [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr11AddressableCyxSicilu : $@convention(method) (Int, @guaranteed Addressable) -> UnsafePointer { // Addressable.subscript.unsafeMutableAddressor with _specialize -// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr11AddressableCyxSiciau : $@convention(method) (Int, @guaranteed Addressable) -> UnsafeMutablePointer { +// CHECK-LABEL: sil [_specialize exported: false, kind: full, where Element == Klass1] [_specialize exported: false, kind: full, where Element == Int] [ossa] @$s15specialize_attr11AddressableCyxSiciau : $@convention(method) (Int, @guaranteed Addressable) -> UnsafeMutablePointer { // Addressable.subscript.getter with no attribute 
// CHECK-LABEL: sil [transparent] [serialized] [ossa] @$s15specialize_attr11AddressableCyxSicig : $@convention(method) (Int, @guaranteed Addressable) -> @out Element { diff --git a/test/SILGen/synthesized_conformance_class.swift b/test/SILGen/synthesized_conformance_class.swift index 362e18b9cc172..1b03b38a803fd 100644 --- a/test/SILGen/synthesized_conformance_class.swift +++ b/test/SILGen/synthesized_conformance_class.swift @@ -4,6 +4,7 @@ final class Final { var x: T init(x: T) { self.x = x } } + // CHECK-LABEL: final class Final { // CHECK: @_hasStorage final var x: T { get set } // CHECK: init(x: T) @@ -51,6 +52,16 @@ class Nonfinal { // CHECK: func encode(to encoder: Encoder) throws // CHECK: } +// Make sure that CodingKeys members are actually emitted. + +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}21__derived_enum_equalsySbAFyx_G_AHtFZ : $@convention(method) (Final.CodingKeys, Final.CodingKeys, @thin Final.CodingKeys.Type) -> Bool { +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}9hashValueSivg : $@convention(method) (Final.CodingKeys) -> Int { +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}4hash4intoys6HasherVz_tF : $@convention(method) (@inout Hasher, Final.CodingKeys) -> () { +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}11stringValueSSvg : $@convention(method) (Final.CodingKeys) -> @owned String { +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}11stringValueAFyx_GSgSS_tcfC : $@convention(method) (@owned String, @thin Final.CodingKeys.Type) -> Optional.CodingKeys> { +// CHECK-LABEL: sil private [ossa] @$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}8intValueSiSgvg : $@convention(method) (Final.CodingKeys) -> Optional { +// CHECK-LABEL: sil private [ossa] 
@$s29synthesized_conformance_class5FinalC10CodingKeys{{.*}}8intValueAFyx_GSgSi_tcfC : $@convention(method) (Int, @thin Final.CodingKeys.Type) -> Optional.CodingKeys> { + extension Final: Encodable where T: Encodable {} // CHECK-LABEL: // Final.encode(to:) // CHECK-NEXT: sil hidden [ossa] @$s29synthesized_conformance_class5FinalCAASERzlE6encode2toys7Encoder_p_tKF : $@convention(method) (@in_guaranteed Encoder, @guaranteed Final) -> @error Error { diff --git a/test/SILOptimizer/eager_specialize_ossa.sil b/test/SILOptimizer/eager_specialize_ossa.sil new file mode 100644 index 0000000000000..892531543a111 --- /dev/null +++ b/test/SILOptimizer/eager_specialize_ossa.sil @@ -0,0 +1,1064 @@ +// RUN: %target-sil-opt -enable-sil-verify-all -eager-specializer %s | %FileCheck %s +// RUN: %target-sil-opt -enable-sil-verify-all -eager-specializer %s -o %t.sil && %target-swift-frontend -module-name=eager_specialize -emit-ir %t.sil | %FileCheck --check-prefix=CHECK-IRGEN --check-prefix=CHECK-IRGEN-%target-cpu %s +// RUN: %target-sil-opt -enable-sil-verify-all -eager-specializer -sil-inline-generics=true -inline %s | %FileCheck --check-prefix=CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE %s +// rdar://problem/65373647 +// UNSUPPORTED: CPU=arm64e + +sil_stage canonical + +import Builtin +import Swift +import SwiftShims + +public protocol AnElt { +} + +public protocol HasElt { + associatedtype Elt + init(e: Elt) +} + +struct X : AnElt { + @_hasStorage var i: Int { get set } + init(i: Int) +} + +struct S : HasElt { + typealias Elt = X + @_hasStorage var e: X { get set } + init(e: Elt) +} + +class Klass {} + +class EltTrivialKlass : AnElt { + var i: Int { get set } + init(i: Int) {} +} + +class ContainerKlass : HasElt { + typealias Elt = EltTrivialKlass + var e: Elt { get set } + required init(e: Elt) +} + +public struct G { + var container: Container? 
= nil + public func getContainer(e: Container.Elt) -> Container + init() +} + +public struct GTrivial { + public func getContainer(e: Container.Elt) -> Container + init() +} + +// CHECK: @_specialize(exported: false, kind: full, where T == S) +// CHECK: public func getGenericContainer(g: G, e: T.Elt) -> T where T : HasElt, T.Elt : AnElt +@_specialize(where T == S) +public func getGenericContainer(g: G, e: T.Elt) -> T where T.Elt : AnElt + +// CHECK: @_specialize(exported: false, kind: full, where T == S) +// CHECK: @_specialize(exported: false, kind: full, where T == ContainerKlass) +// CHECK: public func getGenericContainerTrivial(g: GTrivial, e: T.Elt) -> T where T : HasElt, T.Elt : AnElt +@_specialize(where T == S) +@_specialize(where T == ContainerKlass) +public func getGenericContainerTrivial(g: GTrivial, e: T.Elt) -> T where T.Elt : AnElt + +// CHECK: @_specialize(exported: false, kind: full, where T == ContainerKlass) +// CHECK: public func getGenericContainer2_owned(g: G, e: T.Elt) -> T where T : HasElt, T.Elt : AnElt +@_specialize(where T == ContainerKlass) +public func getGenericContainer2_owned(g: G, e: T.Elt) -> T where T.Elt : AnElt + +@_specialize(where T == ContainerKlass) +public func getGenericContainer2_guaranteed(g: G, e: T.Elt) -> T where T.Elt : AnElt + +enum ArithmeticError : Error { + case DivByZero + func hash(into hasher: inout Hasher) + var _code: Int { get } +} + +// CHECK: @_specialize(exported: false, kind: full, where T == Int) +// CHECK: public func divideNum(num: T, den: T) throws -> T where T : SignedInteger, T : _ExpressibleByBuiltinIntegerLiteral +@_specialize(where T == Int) +public func divideNum(num: T, den: T) throws -> T + +@inline(never) @_optimize(none) func foo(t: T) -> Int64 + +// CHECK: @_specialize(exported: false, kind: full, where T == Int64) +// CHECK: @_specialize(exported: false, kind: full, where T == Float) +// CHECK: @_specialize(exported: false, kind: full, where T == Klass) +// CHECK: public func voidReturn(t: 
T) +@_specialize(where T == Int64) +@_specialize(where T == Float) +@_specialize(where T == Klass) +public func voidReturn(t: T) + +// CHECK: @_specialize(exported: false, kind: full, where T == Int64) +// CHECK: @_specialize(exported: false, kind: full, where T == Float) +// CHECK: @_specialize(exported: false, kind: full, where T == Klass) +// CHECK: public func nonvoidReturn(t: T) -> Int64 +@_specialize(where T == Int64) +@_specialize(where T == Float) +@_specialize(where T == Klass) +public func nonvoidReturn(t: T) -> Int64 + +// --- test: protocol conformance, substitution for dependent types +// non-layout dependent generic arguments, emitUncheckedBitCast (non +// address-type) + +// Helper +// +// G.getContainer(A.Elt) -> A +sil [ossa] @$s16eager_specialize1GV12getContaineryx3EltQzF : $@convention(method) (@in Container.Elt, @in_guaranteed G) -> @out Container { +bb0(%0 : $*Container, %1 : $*Container.Elt, %2 : $*G): + %4 = witness_method $Container, #HasElt.init!allocator : $@convention(witness_method: HasElt) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @thick τ_0_0.Type) -> @out τ_0_0 + %5 = metatype $@thick Container.Type + %6 = apply %4(%0, %1, %5) : $@convention(witness_method: HasElt) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @thick τ_0_0.Type) -> @out τ_0_0 + %7 = tuple () + return %7 : $() +} + +// getGenericContainer (G, e : A.Elt) -> A +sil [_specialize where T == S] [ossa] @$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in T.Elt) -> @out T { +bb0(%0 : $*T, %1 : $*G, %2 : $*T.Elt): + // function_ref G.getContainer(A.Elt) -> A + %5 = function_ref @$s16eager_specialize1GV12getContaineryx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + %6 = apply %5(%0, %2, %1) : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + %7 = tuple () + return %7 : $() 
+} + +// Specialization getGenericContainer +// +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5 : $@convention(thin) (G, X) -> S { +// CHECK: bb0(%0 : $G, %1 : $X): +// CHECK: return %{{.*}} : $S +// CHECK: } // end sil function '$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5' + +// Generic with specialized dispatch. No more [specialize] attribute. +// +// CHECK-LABEL: sil [ossa] @$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in T.Elt) -> @out T { +// CHECK: bb0([[ARG0:%.*]] : $*T, [[ARG1:%.*]] : $*G, [[ARG2:%.*]] : $*T.Elt): +// CHECK: [[META_LHS:%.*]] = metatype $@thick T.Type +// CHECK: [[META_RHS:%.*]] = metatype $@thick S.Type +// CHECK: [[META_LHS_WORD:%.*]] = unchecked_bitwise_cast [[META_LHS]] : $@thick T.Type to $Builtin.Word +// CHECK: [[META_RHS_WORD:%.*]] = unchecked_bitwise_cast [[META_RHS]] : $@thick S.Type to $Builtin.Word +// CHECK: [[CMP:%.*]] = builtin "cmp_eq_Word"([[META_LHS_WORD]] : $Builtin.Word, [[META_RHS_WORD]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[CMP]], bb3, bb1 + +// CHECK: bb1: +// CHECK: [[ORIGINAL_FN:%.*]] = function_ref @$s16eager_specialize1GV12getContaineryx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 +// CHECK: %10 = apply [[ORIGINAL_FN]]([[ARG0]], [[ARG2]], [[ARG1]]) : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 +// CHECK: br bb2 + +// CHECK: bb2: +// CHECK: tuple () +// CHECK: return {{%.*}} : $() + +// CHECK: bb3: +// CHECK: [[ARG0_CAST:%.*]] = unchecked_addr_cast [[ARG0]] : $*T to $*S +// CHECK: [[ARG1_CAST:%.*]] = unchecked_addr_cast [[ARG1]] : $*G to $*G +// CHECK: [[ARG1_VAL:%.*]] = load [trivial] [[ARG1_CAST]] +// CHECK: [[ARG2_CAST:%.*]] = unchecked_addr_cast 
[[ARG2]] : $*T.Elt to $*X +// CHECK: [[ARG2_VAL:%.*]] = load [trivial] [[ARG2_CAST]] : $*X + // function_ref specialized getGenericContainer (G, e : A.Elt) -> A +// CHECK: [[SPEC_FUNC:%.*]] = function_ref @$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5 : $@convention(thin) (G, X) -> S +// CHECK: [[RESULT:%.*]] = apply [[SPEC_FUNC]]([[ARG1_VAL]], [[ARG2_VAL]]) : $@convention(thin) (G, X) -> S +// CHECK: store [[RESULT]] to [trivial] [[ARG0_CAST]] : $*S +// CHECK: [[TUP:%.*]] = tuple () +// CHECK: unchecked_trivial_bit_cast [[TUP]] : $() to $() +// CHECK: br bb2 +// CHECK: } // end sil function '$s16eager_specialize19getGenericContainer_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF' + +// Helper +sil [ossa] @$s16eager_specialize1GV19getContainerTrivialyx3EltQzF : $@convention(method) (@in Container.Elt, GTrivial) -> @out Container { +bb0(%0 : $*Container, %1 : $*Container.Elt, %2 : $GTrivial): + %4 = witness_method $Container, #HasElt.init!allocator : $@convention(witness_method: HasElt) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @thick τ_0_0.Type) -> @out τ_0_0 + %5 = metatype $@thick Container.Type + %6 = apply %4(%0, %1, %5) : $@convention(witness_method: HasElt) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @thick τ_0_0.Type) -> @out τ_0_0 + %7 = tuple () + return %7 : $() +} + +sil [_specialize where T == ContainerKlass] [_specialize where T == S] [ossa] @$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (GTrivial, @in T.Elt) -> @out T { +bb0(%0 : $*T, %1 : $GTrivial, %2 : $*T.Elt): + // function_ref G.getContainer(A.Elt) -> A + %5 = function_ref @$s16eager_specialize1GV19getContainerTrivialyx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, GTrivial<τ_0_0>) -> @out τ_0_0 + %6 = apply %5(%0, %2, %1) : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, GTrivial<τ_0_0>) -> @out τ_0_0 + %7 = tuple () + return 
%7 : $() +} + +// Specialization getGenericContainerTrivial +// +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main0E5KlassC_Tg5 : $@convention(thin) (GTrivial, @owned EltTrivialKlass) -> @owned ContainerKlass { +// CHECK: return {{%.*}} : $ContainerKlass +// CHECK: } // end sil function '$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main0E5KlassC_Tg5' +// +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5 : $@convention(thin) (GTrivial, X) -> S { +// CHECK: bb0(%0 : $GTrivial, %1 : $X): +// CHECK: return %{{.*}} : $S +// CHECK: } // end sil function '$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5' + +// Generic with specialized dispatch. No more [specialize] attribute. +// +// CHECK-LABEL: sil [ossa] @$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (GTrivial, @in T.Elt) -> @out T { +// CHECK: bb0([[ARG0:%.*]] : $*T, [[ARG1:%.*]] : $GTrivial, [[ARG2:%.*]] : $*T.Elt): +// CHECK: [[META_LHS:%.*]] = metatype $@thick T.Type +// CHECK: [[META_RHS:%.*]] = metatype $@thick S.Type +// CHECK: [[META_LHS_WORD:%.*]] = unchecked_bitwise_cast [[META_LHS]] : $@thick T.Type to $Builtin.Word +// CHECK: [[META_RHS_WORD:%.*]] = unchecked_bitwise_cast [[META_RHS]] : $@thick S.Type to $Builtin.Word +// CHECK: [[CMP:%.*]] = builtin "cmp_eq_Word"([[META_LHS_WORD]] : $Builtin.Word, [[META_RHS_WORD]] : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br [[CMP]], bb5, bb1 + +// CHECK: bb2: +// CHECK: [[ORIGINAL_FN:%.*]] = function_ref @$s16eager_specialize1GV19getContainerTrivialyx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, GTrivial<τ_0_0>) -> @out τ_0_0 +// CHECK: apply [[ORIGINAL_FN]]([[ARG0]], [[ARG2]], [[ARG1]]) : 
$@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, GTrivial<τ_0_0>) -> @out τ_0_0 +// CHECK: br bb3 + +// CHECK: bb3: +// CHECK: tuple () +// CHECK: return {{%.*}} : $() + +// CHECK: bb4: +// CHECK: [[ARG0_CAST:%.*]] = unchecked_addr_cast [[ARG0]] : $*T to $*S +// CHECK: [[ARG1_CAST:%.*]] = unchecked_trivial_bit_cast [[ARG1]] : $GTrivial to $GTrivial +// CHECK: [[ARG2_CAST:%.*]] = unchecked_addr_cast [[ARG2]] : $*T.Elt to $*X +// CHECK: [[ARG2_VAL:%.*]] = load [trivial] [[ARG2_CAST]] : $*X + // function_ref specialized getGenericContainer (G, e : A.Elt) -> A +// CHECK: [[SPEC_FUNC:%.*]] = function_ref @$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF4main1SV_Tg5 : $@convention(thin) (GTrivial, X) -> S +// CHECK: [[RESULT:%.*]] = apply [[SPEC_FUNC]]([[ARG1_CAST]], [[ARG2_VAL]]) : $@convention(thin) (GTrivial, X) -> S +// CHECK: store [[RESULT]] to [trivial] [[ARG0_CAST]] : $*S +// CHECK: [[TUP:%.*]] = tuple () +// CHECK: unchecked_trivial_bit_cast [[TUP]] : $() to $() +// CHECK: br bb3 +// CHECK: } // end sil function '$s16eager_specialize26getGenericContainerTrivial_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF' + +// Now we check getGenericContainer_owned. 
+ +sil [_specialize where T == ContainerKlass] [ossa] @$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in T.Elt) -> @out T { +bb0(%0 : $*T, %1 : $*G, %2 : $*T.Elt): + // function_ref G.getContainer(A.Elt) -> A + %5 = function_ref @$s16eager_specialize1GV12getContaineryx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + %6 = apply %5(%0, %2, %1) : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + %7 = tuple () + return %7 : $() +} + +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA8HasownedRzAA7AnownedAHRQlF4main0E5KlassC_Tg5 : $@convention(thin) (@guaranteed G, @owned EltTrivialKlass) -> @owned ContainerKlass { +// CHECK: } // end sil function '$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA8HasownedRzAA7AnownedAHRQlF4main0E5KlassC_Tg5' + +// CHECK-LABEL: sil [ossa] @$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in T.Elt) -> @out T { +// CHECK: bb0([[ARG0:%.*]] : $*T, [[ARG1:%.*]] : $*G, [[ARG2:%.*]] : $*T.Elt): +// CHECK: bb3: +// CHECK: [[ARG0_CAST:%.*]] = unchecked_addr_cast [[ARG0]] +// CHECK: [[ARG1_CAST:%.*]] = unchecked_addr_cast [[ARG1]] +// CHECK: [[ARG1_LOAD_BORROW:%.*]] = load_borrow [[ARG1_CAST]] +// CHECK: [[ARG2_CAST:%.*]] = unchecked_addr_cast [[ARG2]] +// CHECK: [[ARG2_LOAD:%.*]] = load [take] [[ARG2_CAST]] +// CHECK: [[FUNC:%.*]] = function_ref @$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA8HasownedRzAA7AnownedAHRQlF4main0E5KlassC_Tg5 : $@convention(thin) (@guaranteed G, @owned EltTrivialKlass) -> @owned ContainerKlass +// CHECK: [[RESULT:%.*]] = apply [[FUNC]]([[ARG1_LOAD_BORROW]], [[ARG2_LOAD]]) : $@convention(thin) (@guaranteed G, @owned EltTrivialKlass) 
-> @owned ContainerKlass // user: %22 +// CHECK: store [[RESULT]] to [init] [[ARG0_CAST]] +// CHECK: } // end sil function '$s16eager_specialize25getGenericContainer_owned_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF' + +sil [_specialize where T == ContainerKlass] [ossa] @$s16eager_specialize30getGenericContainer_guaranteed_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in_guaranteed T.Elt) -> @out T { +bb0(%0 : $*T, %1 : $*G, %2 : $*T.Elt): + %a = alloc_stack $*T.Elt + copy_addr %2 to [initialization] %a : $*T.Elt + // function_ref G.getContainer(A.Elt) -> A + %5 = function_ref @$s16eager_specialize1GV12getContaineryx3EltQzF : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + %6 = apply %5(%0, %a, %1) : $@convention(method) <τ_0_0 where τ_0_0 : HasElt> (@in τ_0_0.Elt, @in_guaranteed G<τ_0_0>) -> @out τ_0_0 + dealloc_stack %a : $*T.Elt + %7 = tuple () + return %7 : $() +} + +// CHECK-LABEL: sil [ossa] @$s16eager_specialize30getGenericContainer_guaranteed_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF : $@convention(thin) (@in_guaranteed G, @in_guaranteed T.Elt) -> @out T { +// CHECK: bb0([[ARG0:%.*]] : $*T, [[ARG1:%.*]] : $*G, [[ARG2:%.*]] : $*T.Elt): +// CHECK: bb3: +// CHECK: [[ARG0_CAST:%.*]] = unchecked_addr_cast [[ARG0]] +// CHECK: [[ARG1_CAST:%.*]] = unchecked_addr_cast [[ARG1]] +// CHECK: [[ARG1_LOAD_BORROW:%.*]] = load_borrow [[ARG1_CAST]] +// CHECK: [[ARG2_CAST:%.*]] = unchecked_addr_cast [[ARG2]] +// CHECK: [[ARG2_LOAD:%.*]] = load_borrow [[ARG2_CAST]] +// CHECK: [[FUNC:%.*]] = function_ref @$s16eager_specialize30getGenericContainer_guaranteed_1exAA1GVyxG_3EltQztAA13HasguaranteedRzAA12AnguaranteedAHRQlF4main0E5KlassC_Tg5 : $@convention(thin) (@guaranteed G, @guaranteed EltTrivialKlass) -> @owned ContainerKlass +// CHECK: [[RESULT:%.*]] = apply [[FUNC]]([[ARG1_LOAD_BORROW]], [[ARG2_LOAD]]) : $@convention(thin) (@guaranteed G, @guaranteed EltTrivialKlass) -> @owned 
ContainerKlass +// CHECK: store [[RESULT]] to [init] [[ARG0_CAST]] +// CHECK: } // end sil function '$s16eager_specialize30getGenericContainer_guaranteed_1exAA1GVyxG_3EltQztAA03HasF0RzAA02AnF0AHRQlF' + +// --- test: rethrow + +// Helper +// +// static != infix (A, A) -> Bool +sil public_external [serialized] [ossa] @$ss2neoiySbx_xts9EquatableRzlFZ : $@convention(thin) (@in T, @in T) -> Bool { +bb0(%0 : $*T, %1 : $*T): + %4 = witness_method $T, #Equatable."==" : $@convention(witness_method: Equatable) <τ_0_0 where τ_0_0 : Equatable> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> Bool + %5 = metatype $@thick T.Type + %6 = apply %4(%0, %1, %5) : $@convention(witness_method: Equatable) <τ_0_0 where τ_0_0 : Equatable> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> Bool + %7 = struct_extract %6 : $Bool, #Bool._value + %8 = integer_literal $Builtin.Int1, -1 + %9 = builtin "xor_Int1"(%7 : $Builtin.Int1, %8 : $Builtin.Int1) : $Builtin.Int1 + %10 = struct $Bool (%9 : $Builtin.Int1) + return %10 : $Bool +} + +// divideNum (A, den : A) throws -> A +sil [_specialize where T == Int] [ossa] @$s16eager_specialize9divideNum_3denxx_xtKs13SignedIntegerRzlF : $@convention(thin) (@in T, @in T) -> (@out T, @error Error) { +bb0(%0 : $*T, %1 : $*T, %2 : $*T): + // function_ref static != infix (A, A) -> Bool + %5 = function_ref @$ss2neoiySbx_xts9EquatableRzlFZ : $@convention(thin) <τ_0_0 where τ_0_0 : Equatable> (@in τ_0_0, @in τ_0_0) -> Bool + %6 = alloc_stack $T + copy_addr %2 to [initialization] %6 : $*T + %8 = witness_method $T, #_ExpressibleByBuiltinIntegerLiteral.init!allocator : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : _ExpressibleByBuiltinIntegerLiteral> (Builtin.IntLiteral, @thick τ_0_0.Type) -> @out τ_0_0 + %9 = metatype $@thick T.Type + %10 = integer_literal $Builtin.IntLiteral, 0 + %11 = alloc_stack $T + %12 = apply %8(%11, %10, %9) : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : 
_ExpressibleByBuiltinIntegerLiteral> (Builtin.IntLiteral, @thick τ_0_0.Type) -> @out τ_0_0 + %13 = apply %5(%6, %11) : $@convention(thin) <τ_0_0 where τ_0_0 : Equatable> (@in τ_0_0, @in τ_0_0) -> Bool + %14 = struct_extract %13 : $Bool, #Bool._value + dealloc_stack %11 : $*T + dealloc_stack %6 : $*T + cond_br %14, bb2, bb1 + +bb1: + destroy_addr %2 : $*T + destroy_addr %1 : $*T + %24 = alloc_existential_box $Error, $ArithmeticError + %25 = project_existential_box $ArithmeticError in %24 : $Error + %26 = enum $ArithmeticError, #ArithmeticError.DivByZero!enumelt + store %26 to [trivial] %25 : $*ArithmeticError + throw %24 : $Error + +bb2: + %18 = witness_method $T, #BinaryInteger."/" : $@convention(witness_method: BinaryInteger) <τ_0_0 where τ_0_0 : BinaryInteger> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 + %19 = apply %18(%0, %1, %2, %9) : $@convention(witness_method: BinaryInteger) <τ_0_0 where τ_0_0 : BinaryInteger> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 + %20 = tuple () + return %20 : $() +} + +// specialized divideNum (A, den : A) throws -> A +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize9divideNum_3denxx_xtKSZRzlFSi_Tg5 : $@convention(thin) (Int, Int) -> (Int, @error Error) { +// CHECK: bb0(%0 : $Int, %1 : $Int): +// CHECK: return %{{.*}} +// CHECK: throw %{{.*}} + +// Generic with specialized dispatch. No more [specialize] attribute. 
+// +// CHECK-LABEL: sil [ossa] @$s16eager_specialize9divideNum_3denxx_xtKs13SignedIntegerRzlF : $@convention(thin) (@in T, @in T) -> (@out T, @error Error) { +// CHECK: bb0(%0 : $*T, %1 : $*T, %2 : $*T): +// CHECK: %3 = metatype $@thick T.Type +// CHECK: %4 = metatype $@thick Int.Type +// CHECK: %5 = unchecked_bitwise_cast %3 : $@thick T.Type to $Builtin.Word +// CHECK: %6 = unchecked_bitwise_cast %4 : $@thick Int.Type to $Builtin.Word +// CHECK: %7 = builtin "cmp_eq_Word"(%5 : $Builtin.Word, %6 : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br %7, bb6, bb1 + +// CHECK: bb1: +// CHECK: // function_ref static != infix(_:_:) +// CHECK: cond_br %{{.*}}, bb4, bb2 + +// CHECK: bb2: +// CHECK: br bb3(%{{.*}} : $Error) + +// CHECK: bb3(%{{.*}} : @owned $Error): +// CHECK: throw %{{.*}} : $Error + +// CHECK: bb4: +// CHECK: %{{.*}} = witness_method $T, #BinaryInteger."/" : {{.*}} : $@convention(witness_method: BinaryInteger) <τ_0_0 where τ_0_0 : BinaryInteger> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: apply %{{.*}}({{.*}}) : $@convention(witness_method: BinaryInteger) <τ_0_0 where τ_0_0 : BinaryInteger> (@in τ_0_0, @in τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: br bb5 + +// CHECK: bb5: +// CHECK: %{{.*}} = tuple () +// CHECK: return %{{.*}} : $() + +// CHECK: bb6: +// CHECK: %{{.*}} = unchecked_addr_cast %0 : $*T to $*Int +// CHECK: %{{.*}} = unchecked_addr_cast %1 : $*T to $*Int +// CHECK: %{{.*}} = load [trivial] %{{.*}} : $*Int +// CHECK: %{{.*}} = unchecked_addr_cast %2 : $*T to $*Int +// CHECK: %{{.*}} = load [trivial] %{{.*}} : $*Int +// CHECK: // function_ref specialized divideNum(_:den:) +// CHECK: %{{.*}} = function_ref @$s16eager_specialize9divideNum_3denxx_xtKSZRzlFSi_Tg5 : $@convention(thin) (Int, Int) -> (Int, @error Error) +// CHECK: try_apply %{{.*}}(%{{.*}}, %{{.*}}) : $@convention(thin) (Int, Int) -> (Int, @error Error), normal bb8, error bb7 + +// CHECK: bb7(%{{.*}} : @owned $Error): +// CHECK: br bb3(%{{.*}} : $Error) 
+ +// CHECK: bb8(%{{.*}} : $Int): +// CHECK: store %{{.*}} to [trivial] %{{.*}} : $*Int +// CHECK: %{{.*}} = tuple () +// CHECK: %{{.*}} = unchecked_trivial_bit_cast %{{.*}} : $() to $() +// CHECK: br bb5 + +// --- test: multiple void and non-void return values + +// foo (A) -> Int64 +sil hidden [noinline] [Onone] [ossa] @$s16eager_specialize3fooys5Int64VxlF : $@convention(thin) (@in T) -> Int64 { +// %0 // users: %1, %4 +bb0(%0 : $*T): + %2 = integer_literal $Builtin.Int64, 3 + %3 = struct $Int64 (%2 : $Builtin.Int64) + destroy_addr %0 : $*T + return %3 : $Int64 +} + +// voidReturn (A) -> () +sil [_specialize where T == Float] [_specialize where T == Int64] [ossa] @$s16eager_specialize10voidReturnyyxlF : $@convention(thin) (@in T) -> () { +bb0(%0 : $*T): + // function_ref foo (A) -> Int64 + %2 = function_ref @$s16eager_specialize3fooys5Int64VxlF : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 + %3 = apply %2(%0) : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 + %4 = tuple () + return %4 : $() +} + +// CHECK-LABEL: // specialized voidReturn(_:) +// CHECK: sil shared [ossa] @$s16eager_specialize10voidReturnyyxlFSf_Tg5 : $@convention(thin) (Float) -> () { +// %0 // user: %2 +// CHECK: bb0(%0 : $Float): +// CHECK: return %5 : $() + +// CHECK-LABEL: // specialized voidReturn(_:) +// CHECK: sil shared [ossa] @$s16eager_specialize10voidReturnyyxlFs5Int64V_Tg5 : $@convention(thin) (Int64) -> () { +// CHECK: bb0(%0 : $Int64): +// CHECK: return %5 : $() + +// Generic with specialized dispatch. No more [specialize] attribute. 
+// +// CHECK-LABEL: // voidReturn(_:) +// CHECK: sil [ossa] @$s16eager_specialize10voidReturnyyxlF : $@convention(thin) (@in T) -> () { +// CHECK: bb0(%0 : $*T): +// CHECK: builtin "cmp_eq_Word" +// CHECK: cond_br %5, bb5, bb1 + +// CHECK: bb1: +// CHECK: builtin "cmp_eq_Word" +// CHECK: cond_br %11, bb4, bb2 + +// CHECK: bb2: +// CHECK: function_ref @$s16eager_specialize3fooys5Int64VxlF : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 +// CHECK: apply %13(%0) : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 +// CHECK: br bb3 + +// CHECK: bb3: +// CHECK: tuple () +// CHECK: return + +// CHECK: bb4: +// CHECK: function_ref @$s16eager_specialize10voidReturnyyxlFSf_Tg5 : $@convention(thin) (Float) -> () +// CHECK: br bb3 + +// CHECK: bb5: +// CHECK: br bb3 + +// nonvoidReturn(A) -> Int64 +sil [_specialize where T == Float] [_specialize where T == Int64] [ossa] @$s16eager_specialize13nonvoidReturnys5Int64VxlF : $@convention(thin) (@in T) -> Int64 { +// %0 // users: %1, %3 +bb0(%0 : $*T): + // function_ref foo(A) -> Int64 + %2 = function_ref @$s16eager_specialize3fooys5Int64VxlF : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 + %3 = apply %2(%0) : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 + return %3 : $Int64 +} + +// CHECK-LABEL: // specialized nonvoidReturn(_:) +// CHECK: sil shared [ossa] @$s16eager_specialize13nonvoidReturnys5Int64VxlFSf_Tg5 : $@convention(thin) (Float) -> Int64 { +// CHECK: bb0(%0 : $Float): +// CHECK: return %4 : $Int64 +// CHECK: } // end sil function '$s16eager_specialize13nonvoidReturnys5Int64VxlFSf_Tg5' + +// CHECK-LABEL: // specialized nonvoidReturn(_:) +// CHECK: sil shared [ossa] @$s16eager_specialize13nonvoidReturnys5Int64VxlFAD_Tg5 : $@convention(thin) (Int64) -> Int64 { +// CHECK: bb0(%0 : $Int64): +// CHECK: return %4 : $Int64 +// CHECK: } // end sil function '$s16eager_specialize13nonvoidReturnys5Int64VxlFAD_Tg5' + +// CHECK-LABEL: // nonvoidReturn(_:) +// CHECK: sil [ossa] @$s16eager_specialize13nonvoidReturnys5Int64VxlF : 
$@convention(thin) (@in T) -> Int64 { +// CHECK: bb0(%0 : $*T): +// CHECK: builtin "cmp_eq_Word" +// CHECK: cond_br %{{.*}}, bb5, bb1 + +// CHECK: bb1: +// CHECK: builtin "cmp_eq_Word" +// CHECK: cond_br %{{.*}}, bb4, bb2 + +// CHECK: bb2: +// CHECK: // function_ref foo(_:) +// CHECK: function_ref @$s16eager_specialize3fooys5Int64VxlF : $@convention(thin) <τ_0_0> (@in τ_0_0) -> Int64 +// CHECK: apply %13 +// CHECK: br bb3(%{{.*}} : $Int64) + +// CHECK: bb3(%{{.*}} : $Int64): +// CHECK: return %{{.*}} : $Int64 + +// CHECK: bb4: +// CHECK: br bb3(%{{.*}} : $Int64) + +// CHECK: bb5: +// CHECK: br bb3(%{{.*}} : $Int64) + +//////////////////////////////////////////////////////////////////// +// Check the ability to specialize for _Trivial(64) and _Trivial(32) +//////////////////////////////////////////////////////////////////// + +// copyValueAndReturn (A, s : inout A) -> A +sil [noinline] [_specialize where S : _Trivial(32)] [_specialize where S : _Trivial(64)] [ossa] @$s16eager_specialize18copyValueAndReturn_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S { +bb0(%0 : $*S, %1 : $*S, %2 : $*S): + copy_addr %2 to [initialization] %0 : $*S + destroy_addr %1 : $*S + %7 = tuple () + return %7 : $() +} // end sil function '$s16eager_specialize18copyValueAndReturn_1sxx_xztlF' + + +// Check specialized for 32 bits +// specialized copyValueAndReturn(A, s : inout A) -> A +// CHECK-LABEL: sil shared [noinline] [ossa] @$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze31_lIetilr_Tp5 : $@convention(thin) (@in S, @inout S) -> @out S +// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S): +// CHECK: copy_addr %2 to [initialization] %0 : $*S +// CHECK: destroy_addr %1 : $*S +// CHECK: %5 = tuple () +// CHECK: return %5 : $() +// CHECK: } // end sil function '$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze31_lIetilr_Tp5' + +// Check specialized for 64 bits +// specialized copyValueAndReturn(A, s : inout A) -> A +// CHECK-LABEL: sil shared [noinline] [ossa] 
@$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze63_lIetilr_Tp5 : $@convention(thin) (@in S, @inout S) -> @out S
+// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S):
+// CHECK: copy_addr %2 to [initialization] %0 : $*S
+// CHECK: destroy_addr %1 : $*S
+// CHECK: %5 = tuple ()
+// CHECK: return %5 : $()
+// CHECK: } // end sil function '$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze63_lIetilr_Tp5'
+
+// Generic with specialized dispatch. No more [specialize] attribute.
+//
+// CHECK-LABEL: sil [noinline] [ossa] @$s16eager_specialize18copyValueAndReturn_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S
+// Check if size == 8 bytes, i.e. 64 bits
+// CHECK: [[META_S:%.*]] = metatype $@thick S.Type
+// CHECK: [[SIZE_S:%.*]] = builtin "sizeof"([[META_S]] : $@thick S.Type) : $Builtin.Word
+// CHECK: [[SIZE_CHECK:%.*]] = integer_literal $Builtin.Word, 8
+// CHECK: [[CMP:%.*]] = builtin "cmp_eq_Word"([[SIZE_S]] : $Builtin.Word, [[SIZE_CHECK]] : $Builtin.Word) : $Builtin.Int1
+// CHECK: cond_br [[CMP]], bb6, bb1
+
+// Check if size == 4 bytes, i.e. 32 bits
+// CHECK: bb1:
+// CHECK: [[META:%.*]] = metatype $@thick S.Type
+// CHECK: [[SIZE_S:%.*]] = builtin "sizeof"([[META]] : $@thick S.Type) : $Builtin.Word
+// CHECK: [[EXPECTED_SIZE:%.*]] = integer_literal $Builtin.Word, 4
+// CHECK: [[CMP:%.*]] = builtin "cmp_eq_Word"([[SIZE_S]] : $Builtin.Word, [[EXPECTED_SIZE]] : $Builtin.Word) : $Builtin.Int1
+// CHECK: cond_br [[CMP]], bb4, bb2
+
+// None of the constraint checks was successful, perform a generic copy. 
+// CHECK: bb2: +// CHECK: copy_addr %2 to [initialization] %0 : $*S +// CHECK: destroy_addr %1 : $*S +// CHECK: br bb3 + +// CHECK: bb3: +// CHECK: %16 = tuple () +// CHECK: return %16 : $() + +// Check if it is a trivial type +// CHECK: bb4: +// CHECK: %18 = builtin "ispod"(%8 : $@thick S.Type) : $Builtin.Int1 +// CHECK: cond_br %18, bb5, bb2 + +// Invoke the specialized function for 32 bits +// CHECK: bb5: +// CHECK: %20 = unchecked_addr_cast %0 : $*S to $*S +// CHECK: %21 = unchecked_addr_cast %1 : $*S to $*S +// CHECK: %22 = unchecked_addr_cast %2 : $*S to $*S +// function_ref specialized copyValueAndReturn (A, s : inout A) -> A +// CHECK: %23 = function_ref @$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze31_lIetilr_Tp5 : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial(32)> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: apply %23(%20, %21, %22) : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial(32)> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: %25 = tuple () +// CHECK: unchecked_trivial_bit_cast %25 : $() to $() +// CHECK: br bb3 + +// Check if it is a trivial type +// CHECK: bb6: +// CHECK: %28 = builtin "ispod"(%3 : $@thick S.Type) : $Builtin.Int1 +// CHECK: cond_br %28, bb7, bb1 + +// Invoke the specialized function for 64 bits +// CHECK: bb7: +// CHECK: %30 = unchecked_addr_cast %0 : $*S to $*S +// CHECK: %31 = unchecked_addr_cast %1 : $*S to $*S +// CHECK: %32 = unchecked_addr_cast %2 : $*S to $*S +// function_ref specialized copyValueAndReturn (A, s : inout A) -> A +// CHECK: %33 = function_ref @$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze63_lIetilr_Tp5 : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial(64)> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: apply %33(%30, %31, %32) : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial(64)> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: %35 = tuple () +// CHECK: unchecked_trivial_bit_cast %35 : $() to $() +// CHECK: br bb3 +// CHECK: } // end sil function 
'$s16eager_specialize18copyValueAndReturn_1sxx_xztlF' + +//////////////////////////////////////////////////////////////////// +// Check the ability to specialize for _Trivial +//////////////////////////////////////////////////////////////////// + +// copyValueAndReturn2 (A, s : inout A) -> A +sil [noinline] [_specialize where S : _Trivial] [ossa] @$s16eager_specialize19copyValueAndReturn2_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S { +bb0(%0 : $*S, %1 : $*S, %2 : $*S): + copy_addr %2 to [initialization] %0 : $*S + destroy_addr %1 : $*S + %7 = tuple () + return %7 : $() +} // end sil function '$s16eager_specialize19copyValueAndReturn2_1sxx_xztlF' + +// Check the specialization for _Trivial +// specialized copyValueAndReturn2 (A, s : inout A) -> A +// CHECK-LABEL: sil shared [noinline] [ossa] @$s16eager_specialize19copyValueAndReturn2_1sxx_xztlFxxxRlzTlIetilr_Tp5 : $@convention(thin) (@in S, @inout S) -> @out S +// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S): +// CHECK: copy_addr %2 to [initialization] %0 : $*S +// CHECK: destroy_addr %1 : $*S +// CHECK: %5 = tuple () +// CHECK: return %5 : $() +// CHECK: } // end sil function '$s16eager_specialize19copyValueAndReturn2_1sxx_xztlFxxxRlzTlIetilr_Tp5' + +// Generic with specialized dispatch. No more [specialize] attribute. +// copyValueAndReturn2 (A, s : inout A) -> A +// CHECK-LABEL: sil [noinline] [ossa] @$s16eager_specialize19copyValueAndReturn2_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S +// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S): +// CHECK: %3 = metatype $@thick S.Type +// CHECK: %4 = builtin "ispod"(%3 : $@thick S.Type) : $Builtin.Int1 +// CHECK: cond_br %4, bb3, bb1 + +// None of the constraint checks was successful, perform a generic copy. 
+// CHECK: bb1: +// CHECK: copy_addr %2 to [initialization] %0 : $*S +// CHECK: destroy_addr %1 : $*S +// CHECK: br bb2 + +// CHECK: bb2: +// CHECK: %9 = tuple () +// CHECK: return %9 : $() + +// Invoke the specialized function for trivial types +// CHECK: bb3: +// CHECK: %11 = unchecked_addr_cast %0 : $*S to $*S +// CHECK: %12 = unchecked_addr_cast %1 : $*S to $*S +// CHECK: %13 = unchecked_addr_cast %2 : $*S to $*S + // function_ref specialized copyValueAndReturn2 (A, s : inout A) -> A +// CHECK: %14 = function_ref @$s16eager_specialize19copyValueAndReturn2_1sxx_xztlFxxxRlzTlIetilr_Tp5 : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: %15 = apply %14(%11, %12, %13) : $@convention(thin) <τ_0_0 where τ_0_0 : _Trivial> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0 +// CHECK: %16 = tuple () +// CHECK: %17 = unchecked_trivial_bit_cast %16 : $() to $() +// CHECK: br bb2 +// CHECK: } // end sil function '$s16eager_specialize19copyValueAndReturn2_1sxx_xztlF' + +//////////////////////////////////////////////////////////////////// +// Check the ability to specialize for _RefCountedObject +//////////////////////////////////////////////////////////////////// + +// copyValueAndReturn3 (A, s : inout A) -> A +sil [noinline] [_specialize where S : _RefCountedObject] [ossa] @$s16eager_specialize19copyValueAndReturn3_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S { +bb0(%0 : $*S, %1 : $*S, %2 : $*S): + copy_addr %2 to [initialization] %0 : $*S + destroy_addr %1 : $*S + %7 = tuple () + return %7 : $() +} // end sil function '$s16eager_specialize19copyValueAndReturn3_1sxx_xztlF' + + +// Check for specialized function for _RefCountedObject +// specialized copyValueAndReturn3 (A, s : inout A) -> A +// CHECK-LABEL: sil shared [noinline] [ossa] @$s16eager_specialize19copyValueAndReturn3_1sxx_xztlFxxxRlzRlIetilr_Tp5 : $@convention(thin) (@in S, @inout S) -> @out S +// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S): +// CHECK: 
copy_addr %2 to [initialization] %0 : $*S
+// CHECK: destroy_addr %1 : $*S
+// CHECK: %5 = tuple ()
+// CHECK: return %5 : $()
+// CHECK: } // end sil function '$s16eager_specialize19copyValueAndReturn3_1sxx_xztlFxxxRlzRlIetilr_Tp5'
+
+
+// Generic with specialized dispatch. No more [specialize] attribute.
+// copyValueAndReturn3 (A, s : inout A) -> A
+// CHECK-LABEL: sil [noinline] [ossa] @$s16eager_specialize19copyValueAndReturn3_1sxx_xztlF : $@convention(thin) (@in S, @inout S) -> @out S {
+// Check if can be a class
+// CHECK: bb0(%0 : $*S, %1 : $*S, %2 : $*S):
+// CHECK: %3 = metatype $@thick S.Type
+// CHECK: %4 = builtin "canBeClass"(%3 : $@thick S.Type) : $Builtin.Int8
+// CHECK: %5 = integer_literal $Builtin.Int8, 1
+// CHECK: %6 = builtin "cmp_eq_Int8"(%4 : $Builtin.Int8, %5 : $Builtin.Int8) : $Builtin.Int1
+// True if it is a Swift class
+// CHECK: cond_br %6, bb3, bb4
+
+// CHECK: bb1:
+// CHECK: copy_addr %2 to [initialization] %0 : $*S
+// CHECK: destroy_addr %1 : $*S
+// CHECK: br bb2
+
+// CHECK: bb2:
+// CHECK: %11 = tuple ()
+// CHECK: return %11 : $()
+
+// Invoke the specialized function for ref-counted objects
+// CHECK: bb3:
+// CHECK: %13 = unchecked_addr_cast %0 : $*S to $*S
+// CHECK: %14 = unchecked_addr_cast %1 : $*S to $*S
+// CHECK: %15 = unchecked_addr_cast %2 : $*S to $*S
+ // function_ref specialized copyValueAndReturn3 (A, s : inout A) -> A
+// CHECK: %16 = function_ref @$s16eager_specialize19copyValueAndReturn3_1sxx_xztlFxxxRlzRlIetilr_Tp5 : $@convention(thin) <τ_0_0 where τ_0_0 : _RefCountedObject> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0
+// CHECK: %17 = apply %16(%13, %14, %15) : $@convention(thin) <τ_0_0 where τ_0_0 : _RefCountedObject> (@in τ_0_0, @inout τ_0_0) -> @out τ_0_0
+// CHECK: %18 = tuple ()
+// CHECK: %19 = unchecked_trivial_bit_cast %18 : $() to $()
+// CHECK: br bb2
+
+// Check if the object could be of a class or objc existential type
+// CHECK: bb4:
+// CHECK: %21 = integer_literal $Builtin.Int8, 2
+// CHECK: %22 = 
builtin "cmp_eq_Int8"(%4 : $Builtin.Int8, %21 : $Builtin.Int8) : $Builtin.Int1 +// CHECK: cond_br %22, bb5, bb1 + +// CHECK: bb5: +// CHECK: %24 = function_ref @_swift_isClassOrObjCExistentialType : $@convention(thin) <τ_0_0> (@thick τ_0_0.Type) -> Bool +// CHECK: %25 = apply %24(%3) : $@convention(thin) <τ_0_0> (@thick τ_0_0.Type) -> Bool +// CHECK: %26 = struct_extract %25 : $Bool, #Bool._value +// CHECK: cond_br %26, bb3, bb1 +// CHECK: } // end sil function '$s16eager_specialize19copyValueAndReturn3_1sxx_xztlF' + +//////////////////////////////////////////////////////////////////// +// Check the ability to produce exported specializations, which can +// be referenced from other object files. +//////////////////////////////////////////////////////////////////// + +// exportSpecializations (A) -> () +sil [_specialize exported: true, where T == Int64] [ossa] @$s16eager_specialize21exportSpecializationsyyxlF : $@convention(thin) (@in T) -> () { +bb0(%0 : $*T): + destroy_addr %0 : $*T + %3 = tuple () + return %3 : $() +} // end sil function '$s16eager_specialize21exportSpecializationsyyxlF' + +//////////////////////////////////////////////////////////////////// +// Check the ability to produce explicit partial specializations. 
+//////////////////////////////////////////////////////////////////// + +// checkExplicitPartialSpecialization (A, B) -> () +sil [_specialize kind: partial, where T == Int64] [ossa] @$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lF : $@convention(thin) (@in T, @in S) -> () { +bb0(%0 : $*T, %1 : $*S): + destroy_addr %1 : $*S + destroy_addr %0 : $*T + %6 = tuple () + return %6 : $() +} // end sil function '$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lF' + + +// Check for specialized function for τ_0_0 == Int64 +// specialized checkExplicitPartialSpecialization (A, B) -> () +// CHECK-LABEL: sil shared [ossa] @$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lFs5Int64Vq_ADRszr0_lIetyi_Tp5 : $@convention(thin) (Int64, @in S) -> () +// CHECK: bb0(%0 : $Int64, %1 : $*S): +// CHECK: %2 = alloc_stack $Int64 +// CHECK: store %0 to [trivial] %2 : $*Int64 +// CHECK: destroy_addr %1 : $*S +// CHECK: destroy_addr %2 : $*Int64 +// CHECK: %6 = tuple () +// CHECK: dealloc_stack %2 : $*Int64 +// CHECK: return %6 : $() +// CHECK: } // end sil function '$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lFs5Int64Vq_ADRszr0_lIetyi_Tp5' + +// Generic with specialized dispatch. No more [specialize] attribute. +// checkExplicitPartialSpecialization (A, B) -> () +// CHECK-LABEL: sil [ossa] @$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lF : $@convention(thin) (@in T, @in S) -> () +// CHECK: bb0(%0 : $*T, %1 : $*S): +// CHECK: %2 = metatype $@thick T.Type +// CHECK: %3 = metatype $@thick Int64.Type +// CHECK: %4 = unchecked_bitwise_cast %2 : $@thick T.Type to $Builtin.Word +// CHECK: %5 = unchecked_bitwise_cast %3 : $@thick Int64.Type to $Builtin.Word +// CHECK: %6 = builtin "cmp_eq_Word"(%4 : $Builtin.Word, %5 : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br %6, bb3, bb1 + +// Type dispatch was not successful. 
+// CHECK: bb1: +// CHECK: destroy_addr %1 : $*S +// CHECK: destroy_addr %0 : $*T +// CHECK: br bb2 + +// CHECK: bb2: +// CHECK: %11 = tuple () +// CHECK: return %11 : $() + +// Invoke a partially specialized function. +// CHECK: bb3: +// CHECK: %13 = unchecked_addr_cast %0 : $*T to $*Int64 +// CHECK: %14 = load [trivial] %13 : $*Int64 +// CHECK: %15 = unchecked_addr_cast %1 : $*S to $*S +// function_ref specialized checkExplicitPartialSpecialization (A, B) -> () +// CHECK: %16 = function_ref @$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lFs5Int64Vq_ADRszr0_lIetyi_Tp5 : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 == Int64> (Int64, @in τ_0_1) -> () +// CHECK: %17 = apply %16(%14, %15) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 == Int64> (Int64, @in τ_0_1) -> () +// CHECK: %18 = tuple () +// CHECK: %19 = unchecked_trivial_bit_cast %18 : $() to $() +// CHECK: br bb2 +// CHECK: } // end sil function '$s16eager_specialize34checkExplicitPartialSpecializationyyx_q_tr0_lF' + +///////////////////////////////////////////////////////////////////////// +// Check that functions with unreachable instructions can be specialized. 
+///////////////////////////////////////////////////////////////////////// + +protocol P { +} + +struct T : P { + init() +} + +extension P { + public static func f(_ x: Self) -> Self +} + +sil @error : $@convention(thin) () -> Never + +// CHECK-LABEL: sil [ossa] @$s16eager_specialize1PPAAE1fyxxFZ : $@convention(method) (@in Self, @thick Self.Type) -> @out Self +// CHECK: %3 = metatype $@thick Self.Type +// CHECK: %4 = metatype $@thick T.Type +// CHECK: %5 = unchecked_bitwise_cast %3 : $@thick Self.Type to $Builtin.Word +// CHECK: %6 = unchecked_bitwise_cast %4 : $@thick T.Type to $Builtin.Word +// CHECK: %7 = builtin "cmp_eq_Word"(%5 : $Builtin.Word, %6 : $Builtin.Word) : $Builtin.Int1 +// CHECK: cond_br %7, bb2, bb1 + +// CHECK: bb1: +// CHECK: %9 = function_ref @error : $@convention(thin) () -> Never +// CHECK: %10 = apply %9() : $@convention(thin) () -> Never +// CHECK: unreachable + +// CHECK: bb2: +// CHECK: %12 = unchecked_addr_cast %0 : $*Self to $*T +// CHECK: %13 = unchecked_addr_cast %1 : $*Self to $*T +// CHECK: %14 = load [trivial] %13 : $*T +// CHECK: %15 = unchecked_trivial_bit_cast %2 : $@thick Self.Type to $@thick T.Type +// CHECK: %16 = function_ref @$s16eager_specialize1PPAAE1fyxxFZ4main1TV_Tg5 : $@convention(method) (T, @thick T.Type) -> T +// CHECK: %17 = apply %16(%14, %15) : $@convention(method) (T, @thick T.Type) -> T +// CHECK: store %17 to [trivial] %12 : $*T +// CHECK: %19 = tuple () +// CHECK: unreachable +// CHECK: } // end sil function '$s16eager_specialize1PPAAE1fyxxFZ' + +sil [_specialize exported: false, kind: full, where Self == T] [ossa] @$s16eager_specialize1PPAAE1fyxxFZ : $@convention(method) (@in Self, @thick Self.Type) -> @out Self { +bb0(%0 : $*Self, %1 : $*Self, %2 : $@thick Self.Type): + // function_ref error + %5 = function_ref @error : $@convention(thin) () -> Never + %6 = apply %5() : $@convention(thin) () -> Never + unreachable +} // end sil function '$s16eager_specialize1PPAAE1fyxxFZ' + + 
+//////////////////////////////////////////////////////////////////// +// Check that IRGen generates efficient code for fixed-size Trivial +// constraints. +//////////////////////////////////////////////////////////////////// + +// Check that a specialization for _Trivial(32) uses direct loads and stores +// instead of value witness functions to load and store the value of a generic type. +// CHECK-IRGEN-LABEL: define linkonce_odr hidden swiftcc void @"$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze31_lIetilr_Tp5"(i32* noalias nocapture sret %0, i32* noalias nocapture dereferenceable(4) %1, i32* nocapture dereferenceable(4) %2, %swift.type* %S +// CHECK-IRGEN: entry: +// CHECK-IRGEN: %3 = load i32, i32* %2 +// CHECK-IRGEN-NEXT: store i32 %3, i32* %0 +// CHECK-IRGEN-NEXT: ret void +// CHECK-IRGEN-NEXT:} + +// Check that a specialization for _Trivial(64) uses direct loads and stores +// instead of value witness functions to load and store the value of a generic type. +// CHECK-IRGEN-LABEL: define linkonce_odr hidden swiftcc void @"$s16eager_specialize18copyValueAndReturn_1sxx_xztlFxxxRlze63_lIetilr_Tp5"(i64* noalias nocapture sret %0, i64* noalias nocapture dereferenceable(8) %1, i64* nocapture dereferenceable(8) %2, %swift.type* %S +// CHECK-IRGEN: entry: +// CHECK-IRGEN: %3 = load i64, i64* %2 +// CHECK-IRGEN-NEXT: store i64 %3, i64* %0 +// CHECK-IRGEN-NEXT: ret void +// CHECK-IRGEN-NEXT: } + +// Check that a specialization for _Trivial does not call the 'destroy' value witness, +// because it is known that the object is Trivial, i.e. contains no references. 
+// CHECK-IRGEN-LABEL: define linkonce_odr hidden swiftcc void @"$s16eager_specialize19copyValueAndReturn2_1sxx_xztlFxxxRlzTlIetilr_Tp5"(%swift.opaque* noalias nocapture sret %0, %swift.opaque* noalias nocapture %1, %swift.opaque* nocapture %2, %swift.type* %S +// CHECK-IRGEN-NEXT: entry: +// CHECK-IRGEN: %3 = bitcast %swift.type* %S to i8*** +// CHECK-IRGEN-NEXT: %4 = getelementptr inbounds i8**, i8*** %3, i{{.*}} -1 +// CHECK-IRGEN-NEXT: %S.valueWitnesses = load i8**, i8*** %4 +// CHECK-IRGEN-NEXT: %5 = getelementptr inbounds i8*, i8** %S.valueWitnesses +// CHECK-IRGEN-NEXT: %6 = load i8*, i8** %5 +// CHECK-IRGEN-NEXT: %initializeWithCopy = {{.*}} +// CHECK-IRGEN-arm64e-NEXT: ptrtoint i8** %5 to i64 +// CHECK-IRGEN-arm64e-NEXT: call i64 @llvm.ptrauth.blend.i64 +// CHECK-IRGEN-NEXT: call {{.*}} %initializeWithCopy +// CHECK-IRGEN-NEXT: ret void +// CHECK-IRGEN-NEXT: } + +// Check that a specialization for _RefCountedObject just copies the fixed-size reference, +// and call retain/release directly, instead of calling the value witness functions. +// The matching patterns in this test are rather non-precise to cover both objc and non-objc platforms. +// CHECK-IRGEN-LABEL: define{{.*}}@"$s16eager_specialize19copyValueAndReturn3_1sxx_xztlFxxxRlzRlIetilr_Tp5" +// CHECK-IRGEN: entry: +// CHECK-IRGEN-NOT: ret void +// CHECK-IRGEN: call {{.*}}etain +// CHECK-IRGEN-NOT: ret void +// CHECK-IRGEN: call {{.*}}elease +// CHECK-IRGEN: ret void + +//////////////////////////////////////////////////////////////////// +// Check that try_apply instructions are handled correctly by the +// eager specializer. 
+//////////////////////////////////////////////////////////////////// + +protocol ThrowingP { + func action() throws -> Int64 +} + +extension Int64 : ThrowingP { + public func action() throws -> Int64 +} + +extension Klass : ThrowingP { + public func action() throws -> Int64 +} + +class ClassUsingThrowingP { + required init() + + @_specialize(exported: false, kind: full, where T == Klass) + @_specialize(exported: false, kind: full, where T == Int64) + public static func f(_: T) throws -> Self where T : ThrowingP + + @_specialize(exported: false, kind: full, where T == Klass) + @_specialize(exported: false, kind: full, where T == Int64) + public static func g(_ t: T) throws -> Int64 where T : ThrowingP + + @_specialize(exported: false, kind: full, where T == Klass) + @_specialize(exported: false, kind: full, where T == Int64) + public static func h(_: T) throws -> Self where T : ThrowingP + + deinit +} + +// Int64.action() +sil @$ss5Int64V34eager_specialize_throwing_functionE6actionAByKF : $@convention(method) (Int64) -> (Int64, @error Error) + +// protocol witness for ThrowingP.action() in conformance Int64 +sil @$ss5Int64V34eager_specialize_throwing_function9ThrowingPA2cDP6actionAByKFTW : $@convention(witness_method: ThrowingP) (@in_guaranteed Int64) -> (Int64, @error Error) + +sil @$s34eager_specialize_throwing_function19ClassUsingThrowingPCACycfc : $@convention(method) (@owned ClassUsingThrowingP) -> @owned ClassUsingThrowingP + +sil @$s34eager_specialize_throwing_function19ClassUsingThrowingPCfd : $@convention(method) (@guaranteed ClassUsingThrowingP) -> @owned Builtin.NativeObject + +// ClassUsingThrowingP.__allocating_init() +sil @$s34eager_specialize_throwing_function19ClassUsingThrowingPCACycfC : $@convention(method) (@thick ClassUsingThrowingP.Type) -> @owned ClassUsingThrowingP + +// ClassUsingThrowingP.__deallocating_deinit +sil @$s34eager_specialize_throwing_function19ClassUsingThrowingPCfD : $@convention(method) (@owned ClassUsingThrowingP) -> () + +// 
f is a function that may throw according to its type, but does not actually do it. +// Check that this function is properly specialized by the eager specializer. +// It should dispatch to its specialized version, but use apply [nothrow] to invoke +// the specialized version. + +// CHECK-LABEL: sil [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1fyACXDxKAA0G1PRzlFZ : $@convention(method) (@in T, @thick ClassUsingThrowingP.Type) -> (@owned ClassUsingThrowingP, @error Error) +// CHECK: [[SPECIALIZED:%.*]] = function_ref @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1fyACXDxKAA0G1PRzlFZs5Int64V_Tg5 : $@convention(method) (Int64, @thick ClassUsingThrowingP.Type) -> (@owned ClassUsingThrowingP, @error Error) +// CHECK: apply [nothrow] [[SPECIALIZED]] +// CHECK: // end sil function '$s34eager_specialize_throwing_function19ClassUsingThrowingPC1fyACXDxKAA0G1PRzlFZ' +// static ClassUsingThrowingP.f(_:) +sil [_specialize exported: false, kind: full, where T == Int64] [_specialize exported: false, kind: full, where T == Klass] [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1fyACXDxKAA0G1PRzlFZ : $@convention(method) (@in T, @thick ClassUsingThrowingP.Type) -> (@owned ClassUsingThrowingP, @error Error) { +bb0(%0 : $*T, %1 : $@thick ClassUsingThrowingP.Type): + destroy_addr %0 : $*T + %4 = unchecked_trivial_bit_cast %1 : $@thick ClassUsingThrowingP.Type to $@thick @dynamic_self ClassUsingThrowingP.Type + // function_ref ClassUsingThrowingP.__allocating_init() + %7 = function_ref @$s34eager_specialize_throwing_function19ClassUsingThrowingPCACycfC : $@convention(method) (@thick ClassUsingThrowingP.Type) -> @owned ClassUsingThrowingP + %8 = upcast %4 : $@thick @dynamic_self ClassUsingThrowingP.Type to $@thick ClassUsingThrowingP.Type + %9 = apply %7(%8) : $@convention(method) (@thick ClassUsingThrowingP.Type) -> @owned ClassUsingThrowingP + %10 = unchecked_ref_cast %9 : $ClassUsingThrowingP to $ClassUsingThrowingP + return 
%10 : $ClassUsingThrowingP
+} // end sil function '$s34eager_specialize_throwing_function19ClassUsingThrowingPC1fyACXDxKAA0G1PRzlFZ'
+
+// g is a function that may throw according to its type and has a try_apply inside
+// its body.
+// Check that this function is properly specialized by the eager specializer.
+// It should dispatch to its specialized version and use try_apply to invoke
+// the specialized version.
+
+// CHECK-LABEL: sil [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1gys5Int64VxKAA0G1PRzlFZ : $@convention(method) (@in T, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error)
+// CHECK: [[SPECIALIZED:%.*]] = function_ref @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1gys5Int64VxKAA0G1PRzlFZAF_Tg5 : $@convention(method) (Int64, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error)
+// CHECK: try_apply [[SPECIALIZED]]
+// CHECK: // end sil function '$s34eager_specialize_throwing_function19ClassUsingThrowingPC1gys5Int64VxKAA0G1PRzlFZ'
+
+// static ClassUsingThrowingP.g(_:)
+sil [_specialize exported: false, kind: full, where T == Int64] [_specialize exported: false, kind: full, where T == Klass] [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1gys5Int64VxKAA0G1PRzlFZ : $@convention(method) (@in T, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error) {
+bb0(%0 : $*T, %1 : $@thick ClassUsingThrowingP.Type):
+ %5 = witness_method $T, #ThrowingP.action : (Self) -> () throws -> Int64 : $@convention(witness_method: ThrowingP) <τ_0_0 where τ_0_0 : ThrowingP> (@in_guaranteed τ_0_0) -> (Int64, @error Error)
+ try_apply %5(%0) : $@convention(witness_method: ThrowingP) <τ_0_0 where τ_0_0 : ThrowingP> (@in_guaranteed τ_0_0) -> (Int64, @error Error), normal bb1, error bb2
+
+bb1(%7 : $Int64): // Preds: bb0
+ destroy_addr %0 : $*T
+ return %7 : $Int64
+
+bb2(%10 : @owned $Error): // Preds: bb0
+ destroy_addr %0 : $*T
+ throw %10 : $Error
+} // end sil function 
'$s34eager_specialize_throwing_function19ClassUsingThrowingPC1gys5Int64VxKAA0G1PRzlFZ'
+
+// Make sure we specialize this appropriately. This will cause us to need to use
+// load_borrows when reabstracting. Make sure that we put the end_borrow in the
+// same place.
+//
+// CHECK-LABEL: sil [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1hys5Int64VxKAA0G1PRzlFZ : $@convention(method) (@in_guaranteed T, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error) {
+// CHECK: [[SPECIALIZED_1:%.*]] = function_ref @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1hys5Int64VxKAA0G1PRzlFZ4main5KlassC_Tg5 : $@convention(method) (@guaranteed Klass, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error)
+// CHECK: try_apply [[SPECIALIZED_1]]
+// CHECK: } // end sil function '$s34eager_specialize_throwing_function19ClassUsingThrowingPC1hys5Int64VxKAA0G1PRzlFZ'
+
+sil [_specialize exported: false, kind: full, where T == Int64] [_specialize exported: false, kind: full, where T == Klass] [ossa] @$s34eager_specialize_throwing_function19ClassUsingThrowingPC1hys5Int64VxKAA0G1PRzlFZ : $@convention(method) (@in_guaranteed T, @thick ClassUsingThrowingP.Type) -> (Int64, @error Error) {
+bb0(%0 : $*T, %1 : $@thick ClassUsingThrowingP.Type):
+ %5 = witness_method $T, #ThrowingP.action : (Self) -> () throws -> Int64 : $@convention(witness_method: ThrowingP) <τ_0_0 where τ_0_0 : ThrowingP> (@in_guaranteed τ_0_0) -> (Int64, @error Error)
+ try_apply %5(%0) : $@convention(witness_method: ThrowingP) <τ_0_0 where τ_0_0 : ThrowingP> (@in_guaranteed τ_0_0) -> (Int64, @error Error), normal bb1, error bb2
+
+bb1(%7 : $Int64): // Preds: bb0
+ return %7 : $Int64
+
+bb2(%10 : @owned $Error): // Preds: bb0
+ throw %10 : $Error
+} // end sil function '$s34eager_specialize_throwing_function19ClassUsingThrowingPC1hys5Int64VxKAA0G1PRzlFZ'
+
+// Check that a specialization was produced and it is not inlined. 
+// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE-LABEL: sil{{.*}}@{{.*}}testSimpleGeneric{{.*}}where T : _Trivial(64, 64) +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE-LABEL: sil{{.*}}@testSimpleGeneric : +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: [[METATYPE:%.*]] = metatype $@thick T.Type +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: [[SIZEOF:%.*]] = builtin "sizeof"([[METATYPE]] : $@thick T.Type) +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: [[SIZE:%.*]] = integer_literal $Builtin.Word, 8 +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: builtin "cmp_eq_Word"([[SIZEOF]] : $Builtin.Word, [[SIZE]] : $Builtin.Word) +// Invoke the specialization, but do not inline it! +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: function_ref @{{.*}}testSimpleGeneric{{.*}} +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: apply +// CHECK-EAGER-SPECIALIZE-AND-GENERICS-INLINE: // end sil function 'testSimpleGeneric' + +sil [_specialize exported: false, kind: full, where T: _Trivial(64, 64)] [ossa] @testSimpleGeneric : $@convention(thin) (@in T) -> Builtin.Int64 { +bb0(%0 : $*T): + %1 = metatype $@thick T.Type + %2 = builtin "sizeof"(%1 : $@thick T.Type) : $Builtin.Word + %8 = builtin "zextOrBitCast_Word_Int64"(%2 : $Builtin.Word) : $Builtin.Int64 + destroy_addr %0 : $*T + return %8 : $Builtin.Int64 +} + +sil_vtable ClassUsingThrowingP { + #ClassUsingThrowingP.init!allocator: (ClassUsingThrowingP.Type) -> () -> ClassUsingThrowingP : @$s34eager_specialize_throwing_function19ClassUsingThrowingPCACycfC // ClassUsingThrowingP.__allocating_init() + #ClassUsingThrowingP.init!initializer: (ClassUsingThrowingP.Type) -> () -> ClassUsingThrowingP : @$s34eager_specialize_throwing_function19ClassUsingThrowingPCACycfc // ClassUsingThrowingP.init() + #ClassUsingThrowingP.deinit!deallocator: @$s34eager_specialize_throwing_function19ClassUsingThrowingPCfD // ClassUsingThrowingP.__deallocating_deinit +} + +sil_vtable Klass { +} + +sil_vtable EltTrivialKlass { +} + +sil_vtable ContainerKlass { +} + 
+sil_witness_table hidden Int64: ThrowingP module eager_specialize_throwing_function { + method #ThrowingP.action: (Self) -> () throws -> Int64 : @$ss5Int64V34eager_specialize_throwing_function9ThrowingPA2cDP6actionAByKFTW // protocol witness for ThrowingP.action() in conformance Int64 +} + +sil_default_witness_table hidden ThrowingP { + no_default +} diff --git a/test/SILOptimizer/ownership_model_eliminator.sil b/test/SILOptimizer/ownership_model_eliminator.sil index 2de217a619903..5a119d1d5278e 100644 --- a/test/SILOptimizer/ownership_model_eliminator.sil +++ b/test/SILOptimizer/ownership_model_eliminator.sil @@ -342,3 +342,14 @@ bb0(%0a : $Builtin.Int32, %0b : $Builtin.Int32): %9999 = tuple() return %9999 : $() } + +// Just make sure that we do not crash on this function. +// +// CHECK-LABEL: sil @lower_unchecked_value_cast_to_unchecked_bitwise_cast : $@convention(thin) (Builtin.Int32) -> Builtin.Int32 { +// CHECK: unchecked_bitwise_cast +// CHECK: } // end sil function 'lower_unchecked_value_cast_to_unchecked_bitwise_cast' +sil [ossa] @lower_unchecked_value_cast_to_unchecked_bitwise_cast : $@convention(thin) (Builtin.Int32) -> Builtin.Int32 { +bb0(%0a : $Builtin.Int32): + %0b = unchecked_value_cast %0a : $Builtin.Int32 to $Builtin.Int32 + return %0b : $Builtin.Int32 +} diff --git a/test/Sema/generalized_accessors_availability.swift b/test/Sema/generalized_accessors_availability.swift new file mode 100644 index 0000000000000..62be2691ff26d --- /dev/null +++ b/test/Sema/generalized_accessors_availability.swift @@ -0,0 +1,69 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -target x86_64-apple-macosx10.9 -typecheck -verify %s + +// REQUIRES: OS=macosx + +@propertyWrapper +struct SetterConditionallyAvailable { + var wrappedValue: T { + get { fatalError() } + + @available(macOS 10.10, *) + set { fatalError() } + } + + var projectedValue: T { + get { fatalError() } + + @available(macOS 10.10, *) + set { fatalError() } + } +} + +@propertyWrapper 
+struct ModifyConditionallyAvailable { + var wrappedValue: T { + get { fatalError() } + + @available(macOS 10.10, *) + _modify { fatalError() } + } + + var projectedValue: T { + get { fatalError() } + + @available(macOS 10.10, *) + _modify { fatalError() } + } +} + +struct Butt { + var modify_conditionally_available: Int { + get { fatalError() } + + @available(macOS 10.10, *) + _modify { fatalError() } + } + + @SetterConditionallyAvailable + var wrapped_setter_conditionally_available: Int + + @ModifyConditionallyAvailable + var wrapped_modify_conditionally_available: Int +} + +func butt(x: inout Butt) { // expected-note*{{}} + x.modify_conditionally_available = 0 // expected-error{{only available in macOS 10.10 or newer}} expected-note{{}} + x.wrapped_setter_conditionally_available = 0 // expected-error{{only available in macOS 10.10 or newer}} expected-note{{}} + x.wrapped_modify_conditionally_available = 0 // expected-error{{only available in macOS 10.10 or newer}} expected-note{{}} + x.$wrapped_setter_conditionally_available = 0 // expected-error{{only available in macOS 10.10 or newer}} expected-note{{}} + x.$wrapped_modify_conditionally_available = 0 // expected-error{{only available in macOS 10.10 or newer}} expected-note{{}} + + if #available(macOS 10.10, *) { + x.modify_conditionally_available = 0 + x.wrapped_setter_conditionally_available = 0 + x.wrapped_modify_conditionally_available = 0 + x.$wrapped_setter_conditionally_available = 0 + x.$wrapped_modify_conditionally_available = 0 + } +} diff --git a/test/expr/unary/keypath/keypath-availability.swift b/test/expr/unary/keypath/keypath-availability.swift new file mode 100644 index 0000000000000..8787647d1fce0 --- /dev/null +++ b/test/expr/unary/keypath/keypath-availability.swift @@ -0,0 +1,35 @@ +// RUN: %target-swift-frontend -target x86_64-apple-macosx10.9 -typecheck -verify %s +// REQUIRES: OS=macosx + +struct Butt { + var setter_conditionally_available: Int { + get { fatalError() } + + @available(macOS 
10.10, *) + set { fatalError() } + } +} + +func assertExactType(of _: inout T, is _: T.Type) {} + +@available(macOS 10.9, *) +public func unavailableSetterContext() { + var kp = \Butt.setter_conditionally_available + assertExactType(of: &kp, is: KeyPath.self) +} +@available(macOS 10.10, *) +public func availableSetterContext() { + var kp = \Butt.setter_conditionally_available + assertExactType(of: &kp, is: WritableKeyPath.self) +} +@available(macOS 10.9, *) +public func conditionalAvailableSetterContext() { + if #available(macOS 10.10, *) { + var kp = \Butt.setter_conditionally_available + assertExactType(of: &kp, is: WritableKeyPath.self) + } else { + var kp = \Butt.setter_conditionally_available + assertExactType(of: &kp, is: KeyPath.self) + } +} + diff --git a/utils/remote-run b/utils/remote-run index effa8ef0d347e..b0efe7f95c080 100755 --- a/utils/remote-run +++ b/utils/remote-run @@ -89,7 +89,7 @@ class CommandRunner(object): print(concatenated_commands, file=sys.stderr) if self.dry_run: return - _, _ = sftp_proc.communicate(concatenated_commands) + _, _ = sftp_proc.communicate(concatenated_commands.encode('utf-8')) if sftp_proc.returncode: sys.exit(sftp_proc.returncode) diff --git a/utils/sil-mode.el b/utils/sil-mode.el index b55cf6ccee8b0..fe29c99d9b58b 100644 --- a/utils/sil-mode.el +++ b/utils/sil-mode.el @@ -150,6 +150,7 @@ "unchecked_ref_cast" "unchecked_trivial_bit_cast" "unchecked_bitwise_cast" + "unchecked_value_cast" "ref_to_raw_pointer" "raw_pointer_to_ref" "unowned_to_ref" "ref_to_unowned" "convert_function" "convert_escape_to_noescape"