From 0501f5ffa4d8325a3391445557a22be71fb6f443 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 9 Oct 2019 12:02:45 -0700 Subject: [PATCH 1/4] [NFC] Gardening. --- .../Mandatory/Differentiation.cpp | 90 ++++++++++--------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 1354d8ce1bdda..f4147ec43a72c 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -434,6 +434,9 @@ class LinearMapInfo { /// The original function. SILFunction *const original; + /// The derivative function. + SILFunction *const derivative; + /// Activity info of the original function. const DifferentiableActivityInfo &activityInfo; @@ -464,9 +467,9 @@ class LinearMapInfo { /// A type converter, used to compute struct/enum SIL types. Lowering::TypeConverter &typeConverter; - SILBuilder &builder; - private: + /// Adds a `VarDecl` member with the given name and type to the given nominal + /// declaration. VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) { auto &astCtx = nominal->getASTContext(); auto id = astCtx.getIdentifier(name); @@ -485,9 +488,9 @@ class LinearMapInfo { /// Retrieves the file unit that contains implicit declarations in the /// current Swift module. If it does not exist, create one. /// - // FIXME: Currently it defaults to the file containing `origFn`, if it can be - // determined. Otherwise, it defaults to any file unit in the module. To - // handle this more properly, we should make a DerivedFileUnit class to + // FIXME: Currently it defaults to the file containing `original`, if it can + // be determined. Otherwise, it defaults to any file unit in the module. To + // handle this more properly, we could revive the DerivedFileUnit class to // contain all synthesized implicit type declarations. 
SourceFile &getDeclarationFileUnit() { if (original->hasLocation()) @@ -699,7 +702,7 @@ class LinearMapInfo { /// branching enum field. void generateDifferentiationDataStructures( ADContext &context, const SILAutoDiffIndices &indices, - SILFunction *assocFn); + SILFunction *derivative); public: bool shouldDifferentiateApplyInst(ApplyInst *ai); @@ -710,10 +713,9 @@ class LinearMapInfo { explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, - SILFunction *original, SILFunction *assocFn, + SILFunction *original, SILFunction *derivative, const SILAutoDiffIndices &indices, - const DifferentiableActivityInfo &activityInfo, - SILBuilder &builder); + const DifferentiableActivityInfo &activityInfo); /// Returns the linear map struct associated with the given original block. StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const { @@ -771,7 +773,9 @@ class LinearMapInfo { /// `struct_extract` in the original function. VarDecl *lookUpLinearMapDecl(SILInstruction *inst) { auto lookup = linearMapValueMap.find(inst); - return lookup == linearMapValueMap.end() ? 
nullptr : lookup->getSecond(); + assert(lookup != linearMapValueMap.end() && + "No linear map declaration corresponding to the given instruction"); + return lookup->getSecond(); } }; @@ -1513,14 +1517,13 @@ static void collectMinimalIndicesForFunctionCall( LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, - SILFunction *original, SILFunction *assocFn, + SILFunction *original, SILFunction *derivative, const SILAutoDiffIndices &indices, - const DifferentiableActivityInfo &activityInfo, - SILBuilder &builder) - : kind(kind), original(original), activityInfo(activityInfo), - indices(indices), typeConverter(context.getTypeConverter()), - builder(builder) { - generateDifferentiationDataStructures(context, indices, assocFn); + const DifferentiableActivityInfo &activityInfo) + : kind(kind), original(original), derivative(derivative), + activityInfo(activityInfo), indices(indices), + typeConverter(context.getTypeConverter()) { + generateDifferentiationDataStructures(context, indices, derivative); } /// Returns a flag that indicates whether the `apply` instruction should be @@ -1615,7 +1618,7 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { } /// Takes an `apply` instruction and adds its linear map function to the -/// linear map struct if it's active. +/// linear map struct if it is active. void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, const SILAutoDiffIndices &indices) { SmallVector allResults; @@ -1627,8 +1630,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, // Check if there are any active results or arguments. If not, skip // this instruction. 
- auto hasActiveResults = llvm::any_of( - allResults, [&](SILValue res) { + auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { return activityInfo.isActive(res, indices); }); auto hasActiveArguments = llvm::any_of( @@ -1660,18 +1662,18 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, // Check for non-differentiable original function type. auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType origFnTy) { - // Check and diagnose non-differentiable arguments. + // Check non-differentiable arguments. for (unsigned paramIndex : range(origFnTy->getNumParameters())) { if (applyIndices.isWrtParameter(paramIndex) && !origFnTy->getParameters()[paramIndex] .getSILStorageType() - .isDifferentiable(builder.getModule())) + .isDifferentiable(derivative->getModule())) return true; } // Check non-differentiable results. if (!origFnTy->getResults()[applyIndices.source] .getSILStorageType() - .isDifferentiable(builder.getModule())) + .isDifferentiable(derivative->getModule())) return true; return false; }; @@ -1682,7 +1684,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType( parameters, source, /*differentiationOrder*/ 1, assocFnKind, context.getTypeConverter(), - LookUpConformanceInModule(builder.getModule().getSwiftModule())); + LookUpConformanceInModule(derivative->getModule().getSwiftModule())); auto assocFnResultTypes = assocFnType->getAllResultsType().castTo<TupleType>(); @@ -1745,8 +1747,6 @@ void LinearMapInfo::generateDifferentiationDataStructures( for (auto &origBB : *original) { for (auto &inst : origBB) { if (auto *ai = dyn_cast<ApplyInst>(&inst)) { - LLVM_DEBUG(getADDebugStream() - << "Adding linear map struct field for " << *ai); // Check for active 'inout' arguments. 
bool isInout = false; auto paramInfos = ai->getSubstCalleeConv().getParameters(); @@ -1761,13 +1761,17 @@ void LinearMapInfo::generateDifferentiationDataStructures( } } if (isInout) - break; + continue; + + // Add linear map field to struct for active `apply` instructions. + // Skip array literal intrinsic applications since array literal + // initialization is linear and handled separately. + if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai)) + continue; - // Add linear map to struct for active instructions. - // Do not add it for array functions since those are already linear - // and we don't need to add it to the struct. - if (shouldDifferentiateApplyInst(ai) && !isArrayLiteralIntrinsic(ai)) - addLinearMapToStruct(context, ai, indices); + LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for " + << *ai); + addLinearMapToStruct(context, ai, indices); } } } @@ -3327,8 +3331,8 @@ class VJPEmitter final context(context), original(original), attr(attr), vjp(vjp), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), vjp)), - pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, - vjp, attr->getIndices(), activityInfo, getBuilder()) { + pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, + attr->getIndices(), activityInfo) { // Create empty pullback function. pullback = createEmptyPullback(); context.getGeneratedFunctions().push_back(pullback); @@ -4156,7 +4160,7 @@ class JVPEmitter final //--------------------------------------------------------------------------// /// The builder for the differential function. - SILBuilder differentialAndBuilder; + SILBuilder differentialBuilder; /// Mapping from original basic blocks to corresponding differential basic /// blocks. 
@@ -4196,9 +4200,9 @@ class JVPEmitter final ASTContext &getASTContext() const { return jvp->getASTContext(); } SILModule &getModule() const { return jvp->getModule(); } const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); } - SILBuilder &getDifferentialBuilder() { return differentialAndBuilder; } + SILBuilder &getDifferentialBuilder() { return differentialBuilder; } SILFunction &getDifferential() { - return differentialAndBuilder.getFunction(); + return differentialBuilder.getFunction(); } SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { #ifndef NDEBUG @@ -4243,11 +4247,11 @@ class JVPEmitter final } static SILBuilder - initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, - SILDifferentiableAttr *attr, - LinearMapInfo *linearMapInfo) { + initializeDifferentialBuilder(ADContext &context, SILFunction *original, + SILDifferentiableAttr *attr, + LinearMapInfo *differentialInfo) { auto *differential = - createEmptyDifferential(context, original, attr, linearMapInfo); + createEmptyDifferential(context, original, attr, differentialInfo); return SILBuilder(*differential); } @@ -5226,8 +5230,8 @@ class JVPEmitter final invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), jvp)), differentialInfo(context, AutoDiffLinearMapKind::Differential, original, - jvp, attr->getIndices(), activityInfo, getBuilder()), - differentialAndBuilder(initializeDifferentialAndBuilder( + jvp, attr->getIndices(), activityInfo), + differentialBuilder(initializeDifferentialBuilder( context, original, attr, &differentialInfo)), diffLocalAllocBuilder(getDifferential()) { // Create empty differential function. From 9fffe8960b92c1630927bd3103858ab387068a67 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 9 Oct 2019 12:05:07 -0700 Subject: [PATCH 2/4] Remap `apply` callee type in derivative context. This is significant when the derivative has a more constrained generic signature. Resolves TF-817. 
--- .../Mandatory/Differentiation.cpp | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index f4147ec43a72c..b9e46370a719f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -468,6 +468,13 @@ class LinearMapInfo { Lowering::TypeConverter &typeConverter; private: + /// Remaps the given type into the derivative function's context. + SILType remapTypeInDerivative(SILType ty) { + if (ty.hasArchetype()) + return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext()); + return derivative->mapTypeIntoContext(ty); + } + /// Adds a `VarDecl` member with the given name and type to the given nominal /// declaration. VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) { @@ -1647,9 +1654,12 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, // parameters from the function type. // - Otherwise, use the active parameters. AutoDiffIndexSubset *parameters; - auto originalFnSubstTy = ai->getSubstCalleeType(); - if (originalFnSubstTy->isDifferentiable()) { - parameters = originalFnSubstTy->getDifferentiationParameterIndices(); + auto origFnSubstTy = ai->getSubstCalleeType(); + auto remappedOrigFnSubstTy = + remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy)) + .castTo<SILFunctionType>(); + if (remappedOrigFnSubstTy->isDifferentiable()) { + parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices(); } else { parameters = AutoDiffIndexSubset::get( original->getASTContext(), @@ -1664,24 +1674,24 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, [&](CanSILFunctionType origFnTy) { // Check non-differentiable arguments. 
for (unsigned paramIndex : range(origFnTy->getNumParameters())) { + auto remappedParamType = + origFnTy->getParameters()[paramIndex].getSILStorageType(); if (applyIndices.isWrtParameter(paramIndex) && - !origFnTy->getParameters()[paramIndex] - .getSILStorageType() - .isDifferentiable(derivative->getModule())) + !remappedParamType.isDifferentiable(derivative->getModule())) return true; } // Check non-differentiable results. - if (!origFnTy->getResults()[applyIndices.source] - .getSILStorageType() - .isDifferentiable(derivative->getModule())) + auto remappedResultType = + origFnTy->getResults()[applyIndices.source].getSILStorageType(); + if (!remappedResultType.isDifferentiable(derivative->getModule())) return true; return false; }; - if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy)) + if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy)) return; AutoDiffAssociatedFunctionKind assocFnKind(kind); - auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType( + auto assocFnType = remappedOrigFnSubstTy->getAutoDiffAssociatedFunctionType( parameters, source, /*differentiationOrder*/ 1, assocFnKind, context.getTypeConverter(), LookUpConformanceInModule(derivative->getModule().getSwiftModule())); From 39034209f28d7bb8345eba78600eda88bcd412af Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 9 Oct 2019 12:31:17 -0700 Subject: [PATCH 3/4] Add test (independent reproducer). --- test/AutoDiff/generics.swift | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index 8780a92df8b57..ff640ce45a71f 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -274,6 +274,25 @@ extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer { } } +// TF-817: Test remapping `apply` callee types in derivative function context. 
+struct TF_817<T> { + func foo(_ index: Int) -> T { + fatalError() + } +} +extension TF_817: Differentiable where T: Differentiable { + @differentiating(foo) + func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) { + fatalError() + } +} +extension TF_817 { + @differentiable(wrt: self where T: Differentiable) + public func test(index: Int) -> T { + return self.foo(0) // crash happened here + } +} + // Test layout requirements. // The layout requirement is "contextual": the requirement is not on `T`, the From b2481dee4ded16998f0d03dfdf216f66557813c1 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 9 Oct 2019 12:31:30 -0700 Subject: [PATCH 4/4] Address review feedback. Inline code. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index b9e46370a719f..1263cb1916c31 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -4256,15 +4256,6 @@ class JVPEmitter final return activityInfo; } - static SILBuilder - initializeDifferentialBuilder(ADContext &context, SILFunction *original, - SILDifferentiableAttr *attr, - LinearMapInfo *differentialInfo) { - auto *differential = - createEmptyDifferential(context, original, attr, differentialInfo); - return SILBuilder(*differential); - } - //--------------------------------------------------------------------------// // Differential struct mapping //--------------------------------------------------------------------------// @@ -5241,8 +5232,8 @@ class JVPEmitter final context, original, attr->getIndices(), jvp)), differentialInfo(context, AutoDiffLinearMapKind::Differential, original, jvp, attr->getIndices(), activityInfo), - differentialBuilder(initializeDifferentialBuilder( - context, original, attr, &differentialInfo)), + 
differentialBuilder(SILBuilder(*createEmptyDifferential( + context, original, attr, &differentialInfo))), diffLocalAllocBuilder(getDifferential()) { // Create empty differential function. context.getGeneratedFunctions().push_back(&getDifferential());