From 1c7b1aa18efbef88fdcf60b2d54cd7173d8a8af4 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 17:57:12 +0000 Subject: [PATCH 1/2] [AutoDiff] Fix `partial_apply` substitution map for subset parameters thunk. Fix `partial_apply` substitution map for subset parameters linear map thunk. The correct substitution map is computed by `buildThunkType` in the helper `ADContext::getOrCreateSubsetParametersThunkForLinearMap` and is now returned by the helper. Resolves TF-886. --- .../Mandatory/Differentiation.cpp | 34 +++++++++++-------- test/AutoDiff/generics.swift | 10 ++++++ 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 177bc7863d703..7bda4ef49c5e5 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1160,7 +1160,8 @@ class ADContext { /// Get or create an associated function index subset thunk from /// `actualIndices` to `desiredIndices` for the given associated function /// value and original function operand. - SILFunction *getOrCreateSubsetParametersThunkForLinearMap( + std::pair + getOrCreateSubsetParametersThunkForLinearMap( SILFunction *assocFn, CanSILFunctionType linearMapType, CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); @@ -8105,7 +8106,7 @@ class Differentiation : public SILModuleTransform { }; } // end anonymous namespace -SILFunction * +std::pair ADContext::getOrCreateSubsetParametersThunkForLinearMap( SILFunction *parentThunk, CanSILFunctionType linearMapType, CanSILFunctionType targetType, AutoDiffAssociatedFunctionKind kind, @@ -8114,8 +8115,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( << "Getting a subset parameters thunk for " << linearMapType << " from " << actualIndices << " to " << desiredIndices << '\n'); - SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap(); - GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment(); + SubstitutionMap interfaceSubs; + GenericEnvironment *genericEnv = nullptr; auto thunkType = buildThunkType( parentThunk, linearMapType, targetType, genericEnv, interfaceSubs, /*withoutActuallyEscaping*/ true, @@ -8148,7 +8149,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( ProfileCounter(), IsThunk, IsNotDynamic); if (!thunk->empty()) - return thunk; + return {thunk, interfaceSubs}; thunk->setGenericEnvironment(genericEnv); thunk->setOwnershipEliminated(); @@ -8296,7 +8297,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( for (auto *alloc : reversed(localAllocations)) builder.createDeallocStack(loc, alloc); builder.createReturn(loc, ai); - return thunk; + return {thunk, interfaceSubs}; } // If pullback thunk, return only the desired results and clean up the @@ -8332,7 +8333,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( builder.createReturn(loc, result); getGeneratedFunctions().push_back(thunk); - return thunk; + return {thunk, interfaceSubs}; } std::pair @@ -8472,22 +8473,25 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction( auto linearMapTargetType = targetType->getResults().back().getSILStorageType() .castTo(); - auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap( - thunk, linearMapType, linearMapTargetType, kind, - desiredIndices, actualIndices); + SILFunction *linearMapThunk; + SubstitutionMap linearMapSubs; + std::tie(linearMapThunk, linearMapSubs) = + getOrCreateSubsetParametersThunkForLinearMap( + thunk, linearMapType, linearMapTargetType, kind, + desiredIndices, actualIndices); - auto *innerThunkFRI = builder.createFunctionRef(loc, innerThunk); - auto *newDerivative = builder.createPartialApply( - loc, innerThunkFRI, thunk->getForwardingSubstitutionMap(), {linearMap}, + auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk); + auto *thunkedLinearMap = builder.createPartialApply( + loc, linearMapThunkFRI, linearMapSubs, {linearMap}, ParameterConvention::Direct_Guaranteed); assert(origFnType->getResults().size() == 1); if (origFnType->getResults().front().isFormalDirect()) { auto result = joinElements( - {originalDirectResult, newDerivative}, builder, loc); + {originalDirectResult, thunkedLinearMap}, builder, loc); builder.createReturn(loc, result); } else { - builder.createReturn(loc, newDerivative); + builder.createReturn(loc, thunkedLinearMap); } getGeneratedFunctions().push_back(thunk); diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index ff640ce45a71f..b28093d36c90b 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -293,6 +293,16 @@ extension TF_817 { } } +// TF-886: Test `partial_apply` of linear map subset parameters thunk. +@differentiable +func TF_886_foo(_: Float, _: T, _: U) -> Float { + return 0 +} +@differentiable +func TF_886_bar(x: Float, y: T) -> Float { + return TF_886_foo(x, y, 0) +} + // Test layout requirements. // The layout requirement is "contextual": the requirement is not on `T`, the From 6f3156c3ad9cd6754ec5866f2d8a090c7b98e36f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 19:03:58 +0000 Subject: [PATCH 2/2] Update doc comments. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 7bda4ef49c5e5..61eb86080ccc4 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1146,9 +1146,11 @@ class ADContext { /// purposes. void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source); - /// Get or create an associated function index subset thunk from + /// Get or create an associated function parameter index subset thunk from /// `actualIndices` to `desiredIndices` for the given associated function - /// value and original function operand. + /// value and original function operand. Returns a pair of the parameter + /// index subset thunk and its interface substitution map (used to partially + /// apply the thunk). /// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear /// map returned by the associated function. std::pair @@ -1157,9 +1159,11 @@ class ADContext { AutoDiffAssociatedFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); - /// Get or create an associated function index subset thunk from + /// Get or create an associated function parameter index subset thunk from /// `actualIndices` to `desiredIndices` for the given associated function - /// value and original function operand. + /// value and original function operand. Returns a pair of the parameter + /// index subset thunk and its interface substitution map (used to partially + /// apply the thunk). std::pair getOrCreateSubsetParametersThunkForLinearMap( SILFunction *assocFn, CanSILFunctionType linearMapType,