Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -949,13 +949,22 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
if (autoDiffDerivativeFunctionIdentifier) {
auto overriddenAttrs =
overridden.getDecl()->getAttrs().getAttributes<DifferentiableAttr>();
if (llvm::none_of(overriddenAttrs, [&](const DifferentiableAttr *attr) {
return attr->getParameterIndices() ==
autoDiffDerivativeFunctionIdentifier->getParameterIndices();
})) {
return SILDeclRef();
for (const auto *attr : overriddenAttrs) {
if (attr->getParameterIndices() !=
autoDiffDerivativeFunctionIdentifier->getParameterIndices())
continue;

// TODO(TF-1056): Do we need to check generic signature requirements?

auto dfi = overridden.autoDiffDerivativeFunctionIdentifier;
overridden.autoDiffDerivativeFunctionIdentifier =
AutoDiffDerivativeFunctionIdentifier::get(
dfi->getKind(), dfi->getParameterIndices(),
attr->getDerivativeGenericSignature(),
getDecl()->getASTContext());
return overridden;
}
return overridden;
return SILDeclRef();
}
// SWIFT_ENABLE_TENSORFLOW END
return overridden;
Expand Down
3 changes: 3 additions & 0 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
if (auto *vjpDecl = diffAttr->getVJPFunction())
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
assert((!AFD->getGenericSignature() || diffAttr->getDerivativeGenericSignature()) &&
"type-checking should resolve derivative generic signatures for "
"all functions with generic signatures");
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
Expand Down
16 changes: 14 additions & 2 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,20 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
/// derivative generic signature (containing requirements), and substitution
/// map. Returns true if error is emitted.
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
CanSILFunctionType origFnTy,
GenericSignature derivativeGenSig,
SubstitutionMap substMap,
DifferentiationInvoker invoker,
SourceLoc loc) {
// If the original function is polymorphic and its generic signature is the
// same as the derivative generic signature, then the requirements are
// satisfied. This check is necessary because the subsequent logic does not
// correctly handle polymorphic original functions.
// TODO(TF-1055): Can be removed after we have a robust solution for TF-1055.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what's the specific symptom of TF-1055?

Referencing the example:

sil_differentiability_witness @foo <T: A> { ... }

sil_differentiability_witness @foo <T: B> { ... }

sil @foo : <T> (T) -> T

sil @example
bb0:
  %1 = function_ref @foo : <T> (T) -> T
  %2 = differentiable_function %1
  %3 = differentiable_function_extract [vjp] %2
  ...

Currently, is the symptom that one differentiability witness is chosen at random during differentiable_function canonicalization?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd get an error that foo is not differentiable, even if the next instruction applies the vjp to something that conforms to A. However, I don't know if any swift code will actually generate SIL like that.

With this check here, I know of no swift code that causes any symptoms.

A symptom that you can get from real swift, if you remove the check here, is:

protocol P {
  @differentiable
  func f(_ x: Float) -> Float
}

extension P {
  @differentiable
  func f(_ x: Float) -> Float
}

struct S: P {}

=>

error: function is not differentiable because `Self: P` is not satisfied

(error from memory, don't have a compiler to test it on right now)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that error happens is that the differentiability witness table for P.f has a Self: P constraint, and the substitution map that gets passed in to diagnoseUnsatisfiedRequirements doesn't have anything satisfying those constraints.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for the details! I wonder if this is a deficiency in the code around diagnoseUnsatisfiedRequirements (e.g. logic for computing the substitution map). I'll try to reproduce the issue after this patch lands.

if (origFnTy->getInvocationGenericSignature() && derivativeGenSig &&
origFnTy->getInvocationGenericSignature()->isEqual(derivativeGenSig))
return false;

// If there are no derivative requirements, return false.
if (!derivativeGenSig)
return false;
Expand Down Expand Up @@ -528,6 +538,7 @@ emitDerivativeFunctionReference(
peerThroughFunctionConversions<FunctionRefInst>(original)) {
auto loc = originalFRI->getLoc();
auto *originalFn = originalFRI->getReferencedFunctionOrNull();
assert(originalFn);
auto originalFnTy = originalFn->getLoweredFunctionType();
auto *desiredResultIndices =
IndexSubset::get(context.getASTContext(), originalFnTy->getNumResults(),
Expand Down Expand Up @@ -636,8 +647,9 @@ emitDerivativeFunctionReference(
substMap = ai->getSubstitutionMap();
}
if (diagnoseUnsatisfiedRequirements(
context, minimalWitness->getDerivativeGenericSignature(), substMap,
invoker, original.getLoc().getSourceLoc()))
context, original->getType().castTo<SILFunctionType>(),
minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
original.getLoc().getSourceLoc()))
return None;
DifferentiabilityWitnessFunctionKind witnessKind;
switch (kind) {
Expand Down
8 changes: 5 additions & 3 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,16 +614,18 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
getter->getAttrs().hasAttribute<DifferentiableAttr>())
continue;
GenericSignature derivativeGenSig = GenericSignature();
GenericSignature derivativeGenericSignature =
getter->getGenericSignature();
// If the parent declaration context is an extension, the nominal type may
// conditionally conform to `Differentiable`. Use the extension generic
// requirements in getter `@differentiable` attributes.
if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
derivativeGenSig = extDecl->getGenericSignature();
if (auto extGenSig = extDecl->getGenericSignature())
derivativeGenericSignature = extGenSig;
auto *diffableAttr = DifferentiableAttr::create(
getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
/*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
/*jvp*/ None, /*vjp*/ None, derivativeGenericSignature);
member->getAttrs().add(diffableAttr);
}
}
Expand Down
20 changes: 12 additions & 8 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3429,13 +3429,16 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
// Start type-checking the arguments of the @differentiable attribute. This
// covers 'wrt:', 'jvp:', 'vjp:', and 'where', all of which are optional.

// Note: If there is a 'where' clause, then the generic signature from that
// overwrites this.
GenericSignature derivativeGenSig = original->getGenericSignature();

// Handle 'where' clause, if it exists.
// - Resolve attribute where clause requirements and store in the attribute
// for serialization.
// - Compute generic signature for autodiff derivative functions based on
// the original function's generate signature and the attribute's where
// clause requirements.
GenericSignature whereClauseGenSig = GenericSignature();
GenericEnvironment *whereClauseGenEnv = nullptr;
if (auto *whereClause = attr->getWhereClause()) {
// `@differentiable` attributes on protocol requirements do not support
Expand Down Expand Up @@ -3507,13 +3510,14 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(

// Compute generic signature and environment for autodiff associated
// functions.
whereClauseGenSig = std::move(builder).computeGenericSignature(
derivativeGenSig = std::move(builder).computeGenericSignature(
attr->getLocation(), /*allowConcreteGenericParams=*/true);
whereClauseGenEnv = whereClauseGenSig->getGenericEnvironment();
// Store the resolved derivative generic signature in the attribute.
attr->setDerivativeGenericSignature(whereClauseGenSig);
whereClauseGenEnv = derivativeGenSig->getGenericEnvironment();
}

// Store the resolved derivative generic signature in the attribute.
attr->setDerivativeGenericSignature(derivativeGenSig);

// Validate the 'wrt:' parameters.

// Get the parsed wrt param indices, which have not yet been checked.
Expand Down Expand Up @@ -3572,7 +3576,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
originalFnTy->getAutoDiffDerivativeFunctionType(
checkedWrtParamIndices, /*resultIndex*/ 0,
AutoDiffDerivativeFunctionKind::JVP, lookupConformance,
whereClauseGenSig, /*makeSelfParamFirst*/ true);
derivativeGenSig, /*makeSelfParamFirst*/ true);

auto isValidJVP = [&](AbstractFunctionDecl *jvpCandidate) -> bool {
return checkFunctionSignature(
Expand All @@ -3596,7 +3600,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
originalFnTy->getAutoDiffDerivativeFunctionType(
checkedWrtParamIndices, /*resultIndex*/ 0,
AutoDiffDerivativeFunctionKind::VJP, lookupConformance,
whereClauseGenSig, /*makeSelfParamFirst*/ true);
derivativeGenSig, /*makeSelfParamFirst*/ true);

auto isValidVJP = [&](AbstractFunctionDecl *vjpCandidate) -> bool {
return checkFunctionSignature(
Expand Down Expand Up @@ -3652,7 +3656,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
// Register derivative function configuration.
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
original->addDerivativeFunctionConfiguration(
{checkedWrtParamIndices, resultIndices, whereClauseGenSig});
{checkedWrtParamIndices, resultIndices, derivativeGenSig});
return checkedWrtParamIndices;
}

Expand Down
50 changes: 34 additions & 16 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,37 +556,55 @@ swift::matchWitness(
(void)reqDiffAttr->getParameterIndices();
}
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
bool foundExactAttr = false;
bool foundSupersetAttr = false;
bool foundExactConfig = false;
Optional<AutoDiffConfig> supersetConfig = None;
for (auto witnessConfig :
witnessAFD->getDerivativeFunctionConfigurations()) {
// We can't use witnesses that have generic signatures not satisfied by
// the requirement's generic signature.
if (witnessConfig.derivativeGenericSignature &&
!witnessConfig.derivativeGenericSignature
->requirementsNotSatisfiedBy(
reqDiffAttr->getDerivativeGenericSignature())
.empty())
continue;
// All the witness's derivative generic requirements must be satisfied
// by the requirement's derivative generic requirements OR by the
// conditional conformance requirements.
if (witnessConfig.derivativeGenericSignature) {
bool genericRequirementsSatisfied = true;
auto reqDiffGenSig = reqDiffAttr->getDerivativeGenericSignature();
auto conformanceGenSig = dc->getGenericSignatureOfContext();
for (const auto &req :
witnessConfig.derivativeGenericSignature->getRequirements()) {
auto substReq = req.subst(result.WitnessSubstitutions);
bool reqDiffGenSigSatisfies =
reqDiffGenSig && substReq &&
reqDiffGenSig->isRequirementSatisfied(*substReq);
bool conformanceGenSigSatisfies =
conformanceGenSig &&
conformanceGenSig->isRequirementSatisfied(req);
if (!reqDiffGenSigSatisfies && !conformanceGenSigSatisfies) {
genericRequirementsSatisfied = false;
break;
}
}
if (!genericRequirementsSatisfied)
continue;
}

if (witnessConfig.parameterIndices ==
reqDiffAttr->getParameterIndices())
foundExactAttr = true;
reqDiffAttr->getParameterIndices()) {
foundExactConfig = true;
break;
}
if (witnessConfig.parameterIndices->isSupersetOf(
reqDiffAttr->getParameterIndices()))
foundSupersetAttr = true;
supersetConfig = witnessConfig;
}
if (!foundExactAttr) {
if (!foundExactConfig) {
bool success = false;
if (foundSupersetAttr) {
if (supersetConfig) {
// If the witness has a "superset" derivative configuration, create an
// implicit `@differentiable` attribute with the exact requirement
// `@differentiable` attribute parameter indices.
auto *newAttr = DifferentiableAttr::create(
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
/*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature());
/*vjp*/ None, supersetConfig->derivativeGenericSignature);
auto insertion = ctx.DifferentiableAttrs.try_emplace(
{witnessAFD, newAttr->getParameterIndices()}, newAttr);
// Valid `@differentiable` attributes are uniqued by original function
Expand Down
5 changes: 1 addition & 4 deletions test/AutoDiff/derivative_registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,7 @@ DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") {
let generic = Generic<Float>()
let x: Tracked<Float> = 3
let dx = gradient(at: x) { x in generic.instanceMethod(x) }
// NOTE(TF-1046): `gradient(at:in:)` calls the generated derivative for
// `Generic.instanceMethod` is used, not the registered derivative. This
// behavior is likely not expected by users; TF-1046 will fix this.
expectEqual(1, dx)
expectEqual(1000, dx)
}

runAllTests()
9 changes: 5 additions & 4 deletions test/AutoDiff/loadable_by_address_cross_module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
// Next, check that differentiability_witness_functions in the client get
// correctly modified by LBA.

// RUN: %target-swift-frontend -emit-sil -I%t %s
// RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA
// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA

// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float

// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float)
// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float) -> (Float, LargeLoadableType<τ_0_0>))
// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float)
// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float) -> (Float, LargeLoadableType<τ_0_0>))

// Finally, execute the test.

Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/nonvaried_result.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ NonVariedResultTests.testWithLeakChecking("SingleBasicBlockGeneric") {
expectEqual((0, 0, 0), gradient(at: 3, 4, 5) { simpleGeneric($0, $1, $2) })
}

// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simpleGeneric{{.*}}pullback_src_0_wrt_0_1_2 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU0_13simpleGenericL_yxx_x23DifferentiationUnittest7TrackedVySfGts14DifferentiableRz13TangentVectorsAGPQzRszlF_bb0__PB__src_0_wrt_0_1_2<τ_0_0>) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector, @owned Tracked<Float>) {
// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simpleGeneric{{.*}}pullback_src_0_wrt_0_1_2{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU0_13simpleGenericL_yxx_x23DifferentiationUnittest7TrackedVySfGts14DifferentiableRz13TangentVectorsAGPQzRszlF_bb0__PB__src_0_wrt_0_1_2<τ_0_0>) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector, @owned Tracked<Float>) {
// CHECK: bb0([[DX:%.*]] : $*τ_0_0, [[DY:%.*]] : $*τ_0_0, [[SEED:%.*]] : $*τ_0_0, [[PB_STRUCT:%.*]] : [[PB_STRUCT_TYPE:.*]]):
// CHECK: [[ZERO_FN_X:%.*]] = witness_method $τ_0_0, #AdditiveArithmetic.zero!getter.1 : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: [[METATYPE_X:%.*]] = metatype $@thick τ_0_0.Type
Expand Down Expand Up @@ -150,7 +150,7 @@ NonVariedResultTests.testWithLeakChecking("ComplexGeneric") {
expectEqual(0, pullback(at: Tracked<Float>(3)) { complexGeneric(10, $0) }(1))
}

// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complexGeneric{{.*}}pullback_src_0_wrt_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU4_14complexGenericL_yxx_xts14DifferentiableRzlF_bb9__PB__src_0_wrt_1<τ_0_0>) -> @out τ_0_0.TangentVector {
// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complexGeneric{{.*}}pullback_src_0_wrt_1{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU4_14complexGenericL_yxx_xts14DifferentiableRzlF_bb9__PB__src_0_wrt_1<τ_0_0>) -> @out τ_0_0.TangentVector {
// CHECK: bb0([[DY:%.*]] : $*τ_0_0.TangentVector, [[SEED:%.*]] : $*τ_0_0.TangentVector, [[PB_STRUCT:%.*]] : @owned [[PB_STRUCT_TYPE:.*]]):
// CHECK: destroy_value [[PB_STRUCT]] : [[PB_STRUCT_TYPE]]
// CHECK: [[ZERO_FN:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter.1 : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/protocol_requirement_autodiff_diags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ struct AttemptsToSatisfyRequirement: HasRequirement {
// This does not satisfy the requirement because the differentiable attribute is more
// constrained than the requirement's differentiable attribute.
@differentiable(where T: P)
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
}
Loading