diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index b9d5de842e2d1..3c57643d4fd1c 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1145,9 +1145,11 @@ class ADContext { /// purposes. void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source); - /// Get or create a derivative function index subset thunk from - /// `actualIndices` to `desiredIndices` for the given derivative function - /// value and original function operand. + /// Get or create a derivative function parameter index subset thunk from + /// `actualIndices` to `desiredIndices` for the given associated function + /// value and original function operand. Returns a pair of the parameter + /// index subset thunk and its interface substitution map (used to partially + /// apply the thunk). /// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear /// map returned by the derivative function. std::pair @@ -1156,11 +1158,14 @@ class ADContext { AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); - /// Get or create a derivative function index subset thunk from - /// `actualIndices` to `desiredIndices` for the given derivative function - /// value and original function operand. - SILFunction *getOrCreateSubsetParametersThunkForLinearMap( - SILFunction *derivativeFn, CanSILFunctionType linearMapType, + /// Get or create a derivative function parameter index subset thunk from + /// `actualIndices` to `desiredIndices` for the given associated function + /// value and original function operand. Returns a pair of the parameter + /// index subset thunk and its interface substitution map (used to partially + /// apply the thunk). + std::pair + getOrCreateSubsetParametersThunkForLinearMap( + SILFunction *assocFn, CanSILFunctionType linearMapType, CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); @@ -8098,7 +8103,7 @@ class Differentiation : public SILModuleTransform { }; } // end anonymous namespace -SILFunction * +std::pair ADContext::getOrCreateSubsetParametersThunkForLinearMap( SILFunction *parentThunk, CanSILFunctionType linearMapType, CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, @@ -8107,8 +8112,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( << "Getting a subset parameters thunk for " << linearMapType << " from " << actualIndices << " to " << desiredIndices << '\n'); - SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap(); - GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment(); + SubstitutionMap interfaceSubs; + GenericEnvironment *genericEnv = nullptr; auto thunkType = buildThunkType( parentThunk, linearMapType, targetType, genericEnv, interfaceSubs, /*withoutActuallyEscaping*/ true, @@ -8141,7 +8146,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( ProfileCounter(), IsThunk, IsNotDynamic); if (!thunk->empty()) - return thunk; + return {thunk, interfaceSubs}; thunk->setGenericEnvironment(genericEnv); thunk->setOwnershipEliminated(); @@ -8289,7 +8294,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( for (auto *alloc : reversed(localAllocations)) builder.createDeallocStack(loc, alloc); builder.createReturn(loc, ai); - return thunk; + return {thunk, interfaceSubs}; } // If pullback thunk, return only the desired results and clean up the @@ -8325,7 +8330,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap( builder.createReturn(loc, result); getGeneratedFunctions().push_back(thunk); - return thunk; + return {thunk, interfaceSubs}; } std::pair @@ -8465,22 +8470,25 @@ ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction( auto linearMapTargetType = targetType->getResults().back().getSILStorageType() .castTo(); - auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap( - thunk, linearMapType, linearMapTargetType, kind, - desiredIndices, actualIndices); + SILFunction *linearMapThunk; + SubstitutionMap linearMapSubs; + std::tie(linearMapThunk, linearMapSubs) = + getOrCreateSubsetParametersThunkForLinearMap( + thunk, linearMapType, linearMapTargetType, kind, + desiredIndices, actualIndices); - auto *innerThunkFRI = builder.createFunctionRef(loc, innerThunk); - auto *newDerivative = builder.createPartialApply( - loc, innerThunkFRI, thunk->getForwardingSubstitutionMap(), {linearMap}, + auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk); + auto *thunkedLinearMap = builder.createPartialApply( + loc, linearMapThunkFRI, linearMapSubs, {linearMap}, ParameterConvention::Direct_Guaranteed); assert(origFnType->getResults().size() == 1); if (origFnType->getResults().front().isFormalDirect()) { auto result = joinElements( - {originalDirectResult, newDerivative}, builder, loc); + {originalDirectResult, thunkedLinearMap}, builder, loc); builder.createReturn(loc, result); } else { - builder.createReturn(loc, newDerivative); + builder.createReturn(loc, thunkedLinearMap); } getGeneratedFunctions().push_back(thunk); diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index f982e17acf331..035cdd40b2894 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -293,6 +293,16 @@ extension TF_817 { } } +// TF-886: Test `partial_apply` of linear map subset parameters thunk. +@differentiable +func TF_886_foo(_: Float, _: T, _: U) -> Float { + return 0 +} +@differentiable +func TF_886_bar(x: Float, y: T) -> Float { + return TF_886_foo(x, y, 0) +} + // Test layout requirements. // The layout requirement is "contextual": the requirement is not on `T`, the