From ae5d08e2c63ce98c3c46ec83b3d52149ea343a8e Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 17 Apr 2019 20:09:08 -0700 Subject: [PATCH] [AutoDiff] Add requirements to `@differentiable` attribute on getters. If a nominal type conditionally conforms to `Differentiable`, use the conditional conformance requirements in getter `@differentiable` attributes. Resolves TF-435. --- lib/Sema/DerivedConformanceDifferentiable.cpp | 8 +++++++- test/AutoDiff/derived_differentiable_properties.swift | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index d68a3b04c92c4..dfd9c51ea21f8 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -776,9 +776,15 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, // call to the getter. if (member->getEffectiveAccess() > AccessLevel::Internal && !member->getAttrs().hasAttribute()) { + ArrayRef requirements; + // If the parent declaration context is an extension, the nominal type may + // conditionally conform to `Differentiable`. Use the conditional + // conformance requirements in getter `@differentiable` attributes. + if (auto *extDecl = dyn_cast(parentDC->getAsDecl())) + requirements = extDecl->getGenericRequirements(); auto *diffableAttr = DifferentiableAttr::create( C, /*implicit*/ true, SourceLoc(), SourceLoc(), {}, None, - None, nullptr); + None, requirements); member->getAttrs().add(diffableAttr); // If getter does not exist, trigger synthesis and compute type. if (!member->getGetter()) diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index 975f6465fc72e..d8a95673a4e55 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -85,3 +85,14 @@ struct GenericCotanMember : Differentiable, AdditiveArithmet // CHECK-AST: internal typealias TangentVector = GenericCotanMember.CotangentVector // CHECK-AST: internal typealias CotangentVector = GenericCotanMember // CHECK-AST: internal typealias AllDifferentiableVariables = GenericCotanMember.CotangentVector + +public struct ConditionallyDifferentiable { + public let x: T +} +extension ConditionallyDifferentiable : Differentiable where T : Differentiable {} + +// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable { +// CHECK-AST: @differentiable(where T : Differentiable) +// CHECK-AST: public let x: T +// CHECK-AST: internal init(x: T) +// CHECK-AST: }