From 0ace6140d255f20cf3ab5851e15d5feead3729a6 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 14:14:50 -0700 Subject: [PATCH 01/13] Fix bug and add test case. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 3 ++- test/AutoDiff/generics.swift | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index fc31564e32fb0..834c05026b368 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -3235,7 +3235,8 @@ class VJPEmitter final // This instruction is active. Determine the appropriate differentiation // strategy, and use it. auto *structDecl = sei->getStructDecl(); - if (structDecl->getAttrs().hasAttribute()) { + if (structDecl->getEffectiveAccess() <= AccessLevel::Internal + || structDecl->getAttrs().hasAttribute()) { strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructExtractInst(sei); return; diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index c1bf6e9c2b31d..ba65da7e54b4c 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -128,4 +128,16 @@ func TF_508_func(x: TF_508_Struct, y: TF_508_Struct) } let TF_508_bp = pullback(at: TF_508_inst, TF_508_inst, in: TF_508_func) +// TF-523 +struct A : Differentiable & AdditiveArithmetic { + var a: Float = 1 + typealias TangentVector = A + typealias AllDifferentiableVariables = A +} + +@differentiable +func f(_ x: A) -> Float { + return x.a * 2 +} + // TODO: add more tests. From 328ce4c86d430b4e7032520cf35c2fc1462f2374 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 14:32:56 -0700 Subject: [PATCH 02/13] Fix test struct and func naming. --- test/AutoDiff/generics.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index ba65da7e54b4c..a1fe13a13050f 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -129,14 +129,14 @@ func TF_508_func(x: TF_508_Struct, y: TF_508_Struct) let TF_508_bp = pullback(at: TF_508_inst, TF_508_inst, in: TF_508_func) // TF-523 -struct A : Differentiable & AdditiveArithmetic { +struct TF_523_Struct : Differentiable & AdditiveArithmetic { var a: Float = 1 - typealias TangentVector = A - typealias AllDifferentiableVariables = A + typealias TangentVector = TF_523_Struct + typealias AllDifferentiableVariables = TF_523_Struct } @differentiable -func f(_ x: A) -> Float { +func TF_523_f(_ x: TF_523_Struct) -> Float { return x.a * 2 } From 2373df2d4a8c47a65a593ae08c1ae5ff1022e34b Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 14:33:50 -0700 Subject: [PATCH 03/13] Move or to previous line. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 834c05026b368..ce91248262cf6 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -3235,8 +3235,8 @@ class VJPEmitter final // This instruction is active. Determine the appropriate differentiation // strategy, and use it. auto *structDecl = sei->getStructDecl(); - if (structDecl->getEffectiveAccess() <= AccessLevel::Internal - || structDecl->getAttrs().hasAttribute()) { + if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + structDecl->getAttrs().hasAttribute()) { strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructExtractInst(sei); return; From 1d269cae9af55d5cd675ea4367a3e05c646f3edf Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 14:37:01 -0700 Subject: [PATCH 04/13] Make same change in VJPEmitter::visitStructElementAddrInst. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index ce91248262cf6..a2e53c5f1c31f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -3305,7 +3305,8 @@ class VJPEmitter final // This instruction is active. Determine the appropriate differentiation // strategy, and use it. auto *structDecl = seai->getStructDecl(); - if (structDecl->getAttrs().hasAttribute()) { + if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + structDecl->getAttrs().hasAttribute()) { strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructElementAddrInst(seai); return; From d3951e0f182acd5e40565ba1cec2e1489515c7c0 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 21:49:40 -0700 Subject: [PATCH 05/13] WIP --- include/swift/AST/DiagnosticsSema.def | 2 + .../Mandatory/Differentiation.cpp | 38 +++++++------------ lib/Sema/TypeCheckAttr.cpp | 3 ++ .../differentiating_attr_type_checking.swift | 20 ++++++++++ 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 2d4fd2696027f..02579db0396bd 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) +ERROR(differentiable_attr_stored_prop_unsupported,none, +"Stored properties cannot be marked with '@differentiable'", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index a2e53c5f1c31f..c7dc427a0bc03 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -3234,8 +3234,13 @@ class VJPEmitter final } // This instruction is active. Determine the appropriate differentiation // strategy, and use it. + // Find the corresponding getter. + auto *getterDecl = sei->getField()->getGetter(); + assert(getterDecl); + auto *getterFn = getModule().lookUpFunction( + SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); auto *structDecl = sei->getStructDecl(); - if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + if (!getterFn || structDecl->getAttrs().hasAttribute()) { strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructExtractInst(sei); @@ -3243,18 +3248,8 @@ class VJPEmitter final } // The FieldwiseProductSpace strategy is not appropriate, so use the Getter // strategy. + assert(getterFn); strategies[sei] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = sei->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } SILAutoDiffIndices indices(/*source*/ 0, AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); @@ -3304,8 +3299,13 @@ class VJPEmitter final } // This instruction is active. Determine the appropriate differentiation // strategy, and use it. + // Find the corresponding getter. + auto *getterDecl = seai->getField()->getGetter(); + assert(getterDecl); + auto *getterFn = getModule().lookUpFunction( + SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); auto *structDecl = seai->getStructDecl(); - if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + if (!getterFn || structDecl->getAttrs().hasAttribute()) { strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructElementAddrInst(seai); @@ -3313,18 +3313,8 @@ class VJPEmitter final } // The FieldwiseProductSpace strategy is not appropriate, so use the Getter // strategy. + assert(getterFn); strategies[seai] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = seai->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - seai, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } SILAutoDiffIndices indices(/*source*/ 0, AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 60e522d32df97..c0a1de031338f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2887,6 +2887,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AbstractFunctionDecl *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { + if (asd->getImplInfo().isSimpleStored()) { + diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_prop_unsupported); + } // When used directly on a storage decl (stored/computed property or // subscript), the getter is currently inferred to be `@differentiable`. // TODO(TF-129): Infer setter to also be `@differentiable` after diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 6e6b7c525fd0f..2205f2b22f2c5 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float) func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) } + +// Test usage of `@differentiable` on a stored property +struct PropertyDiff : Differentiable & AdditiveArithmetic { + // expected-error @+1 {{Stored properties cannot be marked with '@differentiable'}} + @differentiable(vjp: vjpPropertyA) + var a: Float = 1 + typealias TangentVector = PropertyDiff + typealias AllDifferentiableVariables = PropertyDiff + func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) { + (.zero, { _ in .zero }) + } +} + +@differentiable +func f(_ x: PropertyDiff) -> Float { + return x.a +} + +let a = gradient(at: PropertyDiff(), in: f) +print(a) From 7897faca7eb60d675ee709c9fb3b3748ca10abfb Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 21:49:40 -0700 Subject: [PATCH 06/13] Don't allow marking stored props as differentiable --- include/swift/AST/DiagnosticsSema.def | 2 + .../Mandatory/Differentiation.cpp | 38 +++++++------------ lib/Sema/TypeCheckAttr.cpp | 3 ++ .../differentiating_attr_type_checking.swift | 20 ++++++++++ 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 2d4fd2696027f..02579db0396bd 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) +ERROR(differentiable_attr_stored_prop_unsupported,none, +"Stored properties cannot be marked with '@differentiable'", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index a2e53c5f1c31f..c7dc427a0bc03 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -3234,8 +3234,13 @@ class VJPEmitter final } // This instruction is active. Determine the appropriate differentiation // strategy, and use it. + // Find the corresponding getter. + auto *getterDecl = sei->getField()->getGetter(); + assert(getterDecl); + auto *getterFn = getModule().lookUpFunction( + SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); auto *structDecl = sei->getStructDecl(); - if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + if (!getterFn || structDecl->getAttrs().hasAttribute()) { strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructExtractInst(sei); @@ -3243,18 +3248,8 @@ class VJPEmitter final } // The FieldwiseProductSpace strategy is not appropriate, so use the Getter // strategy. + assert(getterFn); strategies[sei] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = sei->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } SILAutoDiffIndices indices(/*source*/ 0, AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); @@ -3304,8 +3299,13 @@ class VJPEmitter final } // This instruction is active. Determine the appropriate differentiation // strategy, and use it. + // Find the corresponding getter. + auto *getterDecl = seai->getField()->getGetter(); + assert(getterDecl); + auto *getterFn = getModule().lookUpFunction( + SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); auto *structDecl = seai->getStructDecl(); - if (structDecl->getEffectiveAccess() <= AccessLevel::Internal || + if (!getterFn || structDecl->getAttrs().hasAttribute()) { strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; SILClonerWithScopes::visitStructElementAddrInst(seai); @@ -3313,18 +3313,8 @@ class VJPEmitter final } // The FieldwiseProductSpace strategy is not appropriate, so use the Getter // strategy. + assert(getterFn); strategies[seai] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = seai->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - seai, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } SILAutoDiffIndices indices(/*source*/ 0, AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 60e522d32df97..c0a1de031338f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2887,6 +2887,9 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AbstractFunctionDecl *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { + if (asd->getImplInfo().isSimpleStored()) { + diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_prop_unsupported); + } // When used directly on a storage decl (stored/computed property or // subscript), the getter is currently inferred to be `@differentiable`. // TODO(TF-129): Infer setter to also be `@differentiable` after diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 6e6b7c525fd0f..2205f2b22f2c5 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float) func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) } + +// Test usage of `@differentiable` on a stored property +struct PropertyDiff : Differentiable & AdditiveArithmetic { + // expected-error @+1 {{Stored properties cannot be marked with '@differentiable'}} + @differentiable(vjp: vjpPropertyA) + var a: Float = 1 + typealias TangentVector = PropertyDiff + typealias AllDifferentiableVariables = PropertyDiff + func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) { + (.zero, { _ in .zero }) + } +} + +@differentiable +func f(_ x: PropertyDiff) -> Float { + return x.a +} + +let a = gradient(at: PropertyDiff(), in: f) +print(a) From ab8cd94d4814a534e9ed8f2f3acb609ba43f1b3a Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 22:35:56 -0700 Subject: [PATCH 07/13] PR feedback. --- include/swift/AST/DiagnosticsSema.def | 4 ++-- lib/Sema/TypeCheckAttr.cpp | 3 ++- test/AutoDiff/differentiating_attr_type_checking.swift | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 02579db0396bd..24256959d9158 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2733,8 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) -ERROR(differentiable_attr_stored_prop_unsupported,none, -"Stored properties cannot be marked with '@differentiable'", ()) +ERROR(differentiable_attr_stored_property_unsupported,none, + "Stored properties cannot be marked with '@differentiable'", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index c0a1de031338f..437a3ec291d97 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2888,7 +2888,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AbstractFunctionDecl *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { if (asd->getImplInfo().isSimpleStored()) { - diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_prop_unsupported); + diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_property_unsupported); + return; } // When used directly on a storage decl (stored/computed property or // subscript), the getter is currently inferred to be `@differentiable`. diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 2205f2b22f2c5..d3dd13be4f766 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -298,7 +298,7 @@ func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { // Test usage of `@differentiable` on a stored property struct PropertyDiff : Differentiable & AdditiveArithmetic { - // expected-error @+1 {{Stored properties cannot be marked with '@differentiable'}} + // expected-error @+1 {{Stored properties cannot be marked with '@differentiable'}} @differentiable(vjp: vjpPropertyA) var a: Float = 1 typealias TangentVector = PropertyDiff From e79379e67cc207cb81e4c60b43037b1bbec7a15a Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Thu, 30 May 2019 22:38:16 -0700 Subject: [PATCH 08/13] Upper case error -> lower case. --- include/swift/AST/DiagnosticsSema.def | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 24256959d9158..8608dc78edfd2 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2734,7 +2734,7 @@ ERROR(differentiable_attr_unsupported_req_kind,none, ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) ERROR(differentiable_attr_stored_property_unsupported,none, - "Stored properties cannot be marked with '@differentiable'", ()) + "stored properties cannot be marked with '@differentiable'", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) From e9fb45fa0f88b7e7ef915615263b92792a78d9f1 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 31 May 2019 11:15:59 -0700 Subject: [PATCH 09/13] Make stored props/vars not have custom VJPs/JVPs. --- include/swift/AST/DiagnosticsSema.def | 4 ++-- lib/Sema/TypeCheckAttr.cpp | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 8608dc78edfd2..2a40f23a47ab6 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2733,8 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) -ERROR(differentiable_attr_stored_property_unsupported,none, - "stored properties cannot be marked with '@differentiable'", ()) +ERROR(differentiable_attr_stored_property_variable_unsupported,none, + "stored properties/variables cannot be marked with '@differentiable' with a custom VJP/JVP", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 437a3ec291d97..37c89c670aed6 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2887,8 +2887,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AbstractFunctionDecl *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { - if (asd->getImplInfo().isSimpleStored()) { - diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_property_unsupported); + if (asd->getImplInfo().isSimpleStored() && + (attr->getJVP() || attr->getVJP())) { + diagnoseAndRemoveAttr(attr, + diag::differentiable_attr_stored_property_variable_unsupported); return; } // When used directly on a storage decl (stored/computed property or From 19ecff5e5156579f9ffb1450e15fe0b7336a19a4 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 31 May 2019 15:53:44 -0700 Subject: [PATCH 10/13] Remove getter differentiation pass, fix tests. --- include/swift/AST/DiagnosticsSema.def | 2 +- .../Mandatory/Differentiation.cpp | 162 +----------------- lib/Sema/TypeCheckAttr.cpp | 2 +- test/AutoDiff/autodiff_diagnostics.swift | 2 - .../AutoDiff/differentiable_attr_silgen.swift | 37 ---- .../differentiable_attr_type_checking.swift | 46 ++--- .../differentiating_attr_type_checking.swift | 2 +- .../e2e_differentiable_property.swift | 9 +- test/AutoDiff/method.swift | 18 +- .../protocol_requirement_autodiff.swift | 15 +- test/AutoDiff/simple_model.swift | 25 +-- test/AutoDiff/witness_table_silgen.swift | 5 +- 12 files changed, 58 insertions(+), 267 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 2a40f23a47ab6..bf4dace00d444 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2734,7 +2734,7 @@ ERROR(differentiable_attr_unsupported_req_kind,none, ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) ERROR(differentiable_attr_stored_property_variable_unsupported,none, - "stored properties/variables cannot be marked with '@differentiable' with a custom VJP/JVP", ()) + "'jvp:' or 'vjp:' cannot be specified for stored properties", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index c7dc427a0bc03..cb6631e0d1ff0 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -737,11 +737,7 @@ enum class StructExtractDifferentiationStrategy { // that is zero except along the direction of the corresponding field. // // Fields correspond by matching name. - Fieldwise, - - // Differentiate the `struct_extract` by looking up the corresponding getter - // and using its VJP. - Getter + Fieldwise }; static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, @@ -3232,59 +3228,10 @@ class VJPEmitter final SILClonerWithScopes::visitStructExtractInst(sei); return; } - // This instruction is active. Determine the appropriate differentiation - // strategy, and use it. - // Find the corresponding getter. - auto *getterDecl = sei->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - auto *structDecl = sei->getStructDecl(); - if (!getterFn || - structDecl->getAttrs().hasAttribute()) { - strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructExtractInst(sei); - return; - } - // The FieldwiseProductSpace strategy is not appropriate, so use the Getter - // strategy. - assert(getterFn); - strategies[sei] = StructExtractDifferentiationStrategy::Getter; - SILAutoDiffIndices indices(/*source*/ 0, - AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); - auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); - if (!attr) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - // Reference and apply the VJP. - auto loc = sei->getLoc(); - auto *getterVJP = getAssociatedFunction( - context, getterFn, attr, AutoDiffAssociatedFunctionKind::VJP, - attr->getVJPName()); - assert(getterVJP && "Expected to find getter VJP"); - auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); - auto *getterVJPApply = getBuilder().createApply( - loc, getterVJPRef, - getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()), - /*args*/ {getOpValue(sei->getOperand())}, /*isNonThrowing*/ false); - // Extract direct results from `getterVJPApply`. - SmallVector vjpDirectResults; - extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); - // Map original result. - auto originalDirectResults = - ArrayRef(vjpDirectResults).drop_back(1); - auto originalDirectResult = joinElements(originalDirectResults, - getBuilder(), - getterVJPApply->getLoc()); - mapValue(sei, originalDirectResult); - // Checkpoint the pullback. - auto pullback = vjpDirectResults.back(); - // TODO: Check whether it's necessary to reabstract getter pullbacks. - pullbackInfo.addPullbackDecl(sei, getOpType(pullback->getType())); - pullbackValues[sei->getParent()].push_back(pullback); + // This instruction is active. Use the field wise differentiation strategy + // to differentiate the struct extract instruction. + strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; + SILClonerWithScopes::visitStructExtractInst(sei); } void visitStructElementAddrInst(StructElementAddrInst *seai) { @@ -3297,78 +3244,10 @@ class VJPEmitter final SILClonerWithScopes::visitStructElementAddrInst(seai); return; } - // This instruction is active. Determine the appropriate differentiation - // strategy, and use it. - // Find the corresponding getter. - auto *getterDecl = seai->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - auto *structDecl = seai->getStructDecl(); - if (!getterFn || - structDecl->getAttrs().hasAttribute()) { - strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructElementAddrInst(seai); - return; - } - // The FieldwiseProductSpace strategy is not appropriate, so use the Getter - // strategy. - assert(getterFn); - strategies[seai] = StructExtractDifferentiationStrategy::Getter; - SILAutoDiffIndices indices(/*source*/ 0, - AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); - auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); - if (!attr) { - context.emitNondifferentiabilityError( - seai, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - // Set generic context scope before getting VJP function type. - auto vjpGenSig = SubsMap.getGenericSignature() - ? SubsMap.getGenericSignature()->getCanonicalSignature() - : nullptr; - Lowering::GenericContextScope genericContextScope( - context.getTypeConverter(), vjpGenSig); - // Reference the getter VJP. - auto loc = seai->getLoc(); - auto *getterVJP = getModule().lookUpFunction(attr->getVJPName()); - assert(getterVJP && "Expected to find getter VJP"); - auto vjpFnTy = getterVJP->getLoweredFunctionType(); - auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); - // Store getter VJP arguments and indirect result buffers. - SmallVector vjpArgs; - SmallVector vjpIndirectResults; - for (auto indRes : vjpFnTy->getIndirectFormalResults()) { - auto *alloc = getBuilder().createAllocStack( - loc, getOpType(indRes.getSILStorageType())); - vjpArgs.push_back(alloc); - vjpIndirectResults.push_back(alloc); - } - vjpArgs.push_back(getOpValue(seai->getOperand())); - // Apply the getter VJP. - auto *getterVJPApply = getBuilder().createApply( - loc, getterVJPRef, - getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()), - vjpArgs, /*isNonThrowing*/ false); - // Collect all results from `getterVJPApply` in type-defined order. - SmallVector vjpDirectResults; - extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); - SmallVector allResults; - collectAllActualResultsInTypeOrder( - getterVJPApply, vjpDirectResults, - getterVJPApply->getIndirectSILResults(), allResults); - // Deallocate VJP indirect results. - for (auto alloc : vjpIndirectResults) - getBuilder().createDeallocStack(loc, alloc); - auto originalDirectResult = allResults[indices.source]; - // Map original result. - mapValue(seai, originalDirectResult); - // Checkpoint the pullback. - SILValue pullback = vjpDirectResults.back(); - // TODO: Check whether it's necessary to reabstract getter pullbacks. - pullbackInfo.addPullbackDecl(seai, getOpType(pullback->getType())); - pullbackValues[seai->getParent()].push_back(pullback); + // This instruction is active. Use the field wise differentiation strategy + // to differentiate the struct extract instruction. + strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; + SILClonerWithScopes::visitStructElementAddrInst(seai); } // If an `apply` has active results or active inout parameters, replace it @@ -4839,29 +4718,6 @@ class AdjointEmitter final : public SILInstructionVisitor { } return; } - case StructExtractDifferentiationStrategy::Getter: { - // Get the pullback. - auto *pullbackField = getPullbackInfo().lookUpPullbackDecl(sei); - assert(pullbackField); - auto pullback = builder.createStructExtract( - loc, getAdjointBlockPullbackStructArgument(sei->getParent()), - pullbackField); - - // Construct the pullback arguments. - auto av = takeAdjointValue(sei); - auto vector = materializeAdjointDirect(std::move(av), loc); - - // Call the pullback. - auto *pullbackCall = builder.createApply( - loc, pullback, SubstitutionMap(), {vector}, /*isNonThrowing*/ false); - assert(!pullbackCall->hasIndirectResults()); - - // Accumulate adjoint for the `struct_extract` operand. - addAdjointValue(sei->getOperand(), - makeConcreteAdjointValue( - ValueWithCleanup(pullbackCall, vector.getCleanup()))); - break; - } } } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 37c89c670aed6..871c99b3eab74 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2891,7 +2891,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { (attr->getJVP() || attr->getVJP())) { diagnoseAndRemoveAttr(attr, diag::differentiable_attr_stored_property_variable_unsupported); - return; + return; } // When used directly on a storage decl (stored/computed property or // subscript), the getter is currently inferred to be `@differentiable`. diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 46e1ec3aa875e..a1bbe0c5e26fa 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -38,8 +38,6 @@ extension S : Differentiable, VectorNumeric { typealias TangentVector = S } -// expected-error @+2 {{function is not differentiable}} -// expected-note @+1 {{property is not differentiable}} _ = gradient(at: S(p: 0)) { s in 2 * s.p } struct NoDerivativeProperty : Differentiable { diff --git a/test/AutoDiff/differentiable_attr_silgen.swift b/test/AutoDiff/differentiable_attr_silgen.swift index a63a8f84ea114..6f3ce873f6073 100644 --- a/test/AutoDiff/differentiable_attr_silgen.swift +++ b/test/AutoDiff/differentiable_attr_silgen.swift @@ -76,43 +76,6 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float) // CHECK-LABEL: sil [ossa] @dhasvjp -//===----------------------------------------------------------------------===// -// Stored property -//===----------------------------------------------------------------------===// - -struct DiffStoredProp { - @differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP) - let storedProp: Float - - @_silgen_name("storedPropJVP") - func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) { - fatalError("unimplemented") - } - - @_silgen_name("storedPropVJP") - func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) { - fatalError("unimplemented") - } -} - -extension DiffStoredProp : VectorNumeric { - static var zero: DiffStoredProp { fatalError("unimplemented") } - static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } - static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } - typealias Scalar = Float - static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } -} - -extension DiffStoredProp : Differentiable { - typealias TangentVector = DiffStoredProp -} - //===----------------------------------------------------------------------===// // Computed property //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index b7f30d522bc86..d00f12e4b348f 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -1,7 +1,10 @@ // RUN: %target-swift-frontend -typecheck -verify %s @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} -let global: Float = 1 +let globalConst: Float = 1 + +@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} +var globalVar: Float = 1 func testLocalVariables() { // expected-error @+1 {{'_' has no parameters to differentiate with respect to}} @@ -225,25 +228,18 @@ class Foo { } struct JVPStruct { + @differentiable let p: Float - @differentiable(wrt: (self), jvp: storedPropJVP) - let storedImmutableOk: Float - - // expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(wrt: (self), jvp: storedPropJVP) - let storedImmutableWrongType: Double - - @differentiable(wrt: (self), jvp: storedPropJVP) - var storedMutableOk: Float - - // expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(wrt: (self), jvp: storedPropJVP) - var storedMutableWrongType: Double + // expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} + @differentiable(wrt: (self), jvp: funcJVP) + func funcWrongType() -> Double { + fatalError("unimplemented") + } } extension JVPStruct { - func storedPropJVP() -> (Float, (JVPStruct) -> Float) { + func funcJVP() -> (Float, (JVPStruct) -> Float) { fatalError("unimplemented") } } @@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) { struct VJPStruct { let p: Float - @differentiable(vjp: storedPropVJP) - let storedImmutableOk: Float - - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: storedPropVJP) - let storedImmutableWrongType: Double - - @differentiable(vjp: storedPropVJP) - var storedMutableOk: Float - - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: storedPropVJP) - var storedMutableWrongType: Double + // expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} + @differentiable(vjp: funcVJP) + func funcWrongType() -> Double { + fatalError("unimplemented") + } } extension VJPStruct { - func storedPropVJP() -> (Float, (Float) -> VJPStruct) { + func funcVJP() -> (Float, (Float) -> VJPStruct) { fatalError("unimplemented") } } diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 1d3e4c356b187..2170a666d96d9 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -298,7 +298,7 @@ func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { // Test usage of `@differentiable` on a stored property struct PropertyDiff : Differentiable & AdditiveArithmetic { - // expected-error @+1 {{stored properties cannot be marked with '@differentiable'}} + // expected-error @+1 {{'jvp:' or 'vjp:' cannot be specified for stored properties}} @differentiable(vjp: vjpPropertyA) var a: Float = 1 typealias TangentVector = PropertyDiff diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index a140ba08766f9..1583ad9689084 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -30,10 +30,9 @@ struct Space { } private let storedX: Float - - /// `y` is a stored property with a custom vjp for its getter. - @differentiable(vjp: vjpY) - let y: Float + + @differentiable + var y: Float func vjpY() -> (Float, (Float) -> TangentSpace) { return (y, { v in TangentSpace(dx: 0, dy: v) }) @@ -70,7 +69,7 @@ E2EDifferentiablePropertyTests.test("stored property") { struct GenericMemberWrapper : Differentiable { // Stored property. - @differentiable(vjp: vjpX) + @differentiable var x: T func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) { diff --git a/test/AutoDiff/method.swift b/test/AutoDiff/method.swift index f3c9e15ad725f..c755df508123a 100644 --- a/test/AutoDiff/method.swift +++ b/test/AutoDiff/method.swift @@ -8,8 +8,15 @@ var MethodTests = TestSuite("Method") // ==== Tests with generated adjoint ==== struct Parameter : Equatable { + private let storedX: Float @differentiable(wrt: (self), jvp: jvpX, vjp: vjpX) - let x: Float + var x: Float { + return storedX + } + + init(x: Float) { + storedX = x + } func vjpX() -> (Float, (Float) -> Parameter) { return (x, { dx in Parameter(x: dx) } ) @@ -155,8 +162,15 @@ struct DiffWrtSelf : Differentiable { } struct CustomParameter : Equatable { + let storedX: Float @differentiable(wrt: (self), vjp: vjpX) - let x: Float + var x: Float { + return storedX + } + + init(x: Float) { + storedX = x + } func vjpX() -> (Float, (Float) -> CustomParameter) { return (x, { dx in CustomParameter(x: dx) }) diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index 5aecc510577bd..fedd5f4397bbf 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -18,23 +18,14 @@ extension DiffReq where TangentVector : AdditiveArithmetic { struct Quadratic : DiffReq, Equatable { typealias TangentVector = Quadratic - @differentiable(wrt: (self), vjp: vjpA) + @differentiable let a: Float - func vjpA() -> (Float, (Float) -> Quadratic) { - return (a, { da in Quadratic(da, 0, 0) } ) - } - @differentiable(wrt: (self), vjp: vjpB) + @differentiable let b: Float - func vjpB() -> (Float, (Float) -> Quadratic) { - return (b, { db in Quadratic(0, db, 0) } ) - } - @differentiable(wrt: (self), vjp: vjpC) + @differentiable let c: Float - func vjpC() -> (Float, (Float) -> Quadratic) { - return (c, { dc in Quadratic(0, 0, dc) } ) - } init(_ a: Float, _ b: Float, _ c: Float) { self.a = a diff --git a/test/AutoDiff/simple_model.swift b/test/AutoDiff/simple_model.swift index f0548f461841b..623f1263cd22d 100644 --- a/test/AutoDiff/simple_model.swift +++ b/test/AutoDiff/simple_model.swift @@ -6,17 +6,11 @@ import StdlibUnittest var SimpleModelTests = TestSuite("SimpleModel") struct DenseLayer : Equatable { - @differentiable(wrt: self, vjp: vjpW) + @differentiable let w: Float - func vjpW() -> (Float, (Float) -> DenseLayer) { - return (w, { dw in DenseLayer(w: dw, b: 0) } ) - } - @differentiable(wrt: self, vjp: vjpB) + @differentiable let b: Float - func vjpB() -> (Float, (Float) -> DenseLayer) { - return (b, { db in DenseLayer(w: 0, b: db) } ) - } } extension DenseLayer : Differentiable, VectorNumeric { @@ -46,23 +40,14 @@ extension DenseLayer { } struct Model : Equatable { - @differentiable(wrt: self, vjp: vjpL1) + @differentiable let l1: DenseLayer - func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) { - return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } ) - } - @differentiable(wrt: self, vjp: vjpL2) + @differentiable let l2: DenseLayer - func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) { - return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } ) - } - @differentiable(wrt: self, vjp: vjpL3) + @differentiable let l3: DenseLayer - func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) { - return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } ) - } } extension Model : Differentiable, VectorNumeric { diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 9a7073450a254..285d8a3091001 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -21,11 +21,8 @@ struct S : Proto, VectorNumeric { typealias TangentVector = S - @differentiable(wrt: (self), vjp: vjpP) + @differentiable let p: Float - func vjpP() -> (Float, (Float) -> S) { - return (p, { dp in S(p: dp) }) - } @differentiable(wrt: (x, y)) func function1(_ x: Float, _ y: Double) -> Float { From d0f77e73a8a54b556ac25f0d471afe12811a7dfd Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sat, 1 Jun 2019 04:02:16 -0700 Subject: [PATCH 11/13] Revamp struct extraction differentiation semantics. - Remove `@_fieldwiseDifferentiable`. - Remove `StructExtractDifferentiationStrategy`. - Require `TangentVector` to have a member of the same name. --- include/swift/AST/Attr.def | 4 +- include/swift/AST/DiagnosticsSIL.def | 3 + include/swift/AST/DiagnosticsSema.def | 5 - .../Mandatory/Differentiation.cpp | 337 +++++++----------- lib/Sema/DerivedConformanceDifferentiable.cpp | 8 - lib/Sema/TypeCheckAttr.cpp | 19 - lib/Sema/TypeCheckDeclOverride.cpp | 1 - test/AutoDiff/autodiff_diagnostics.swift | 10 +- .../derived_differentiable_properties.swift | 18 +- .../e2e_differentiable_property.swift | 18 +- test/AutoDiff/separate_cotangent_type.swift | 14 +- test/AutoDiff/witness_table_silgen.swift | 1 - 12 files changed, 167 insertions(+), 271 deletions(-) diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 8b8d5f5414ba9..5c1a434d340ce 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -421,10 +421,8 @@ DECL_ATTR(differentiating, Differentiating, SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable, OnAccessor | OnFunc | OnConstructor | OnSubscript, /* Not serialized */ 90) -SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable, - OnNominalType | UserInaccessible, 91) SIMPLE_DECL_ATTR(noDerivative, NoDerivative, - OnVar, 92) + OnVar, 91) #undef TYPE_ATTR #undef DECL_ATTR_ALIAS diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index fa47890fb971a..4afed5c78da76 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -447,6 +447,9 @@ NOTE(autodiff_opaque_function_not_differentiable,none, "opaque non-'@differentiable' function is not differentiable", ()) NOTE(autodiff_property_not_differentiable,none, "property is not differentiable", ()) +NOTE(autodiff_stored_property_no_corresponding_tangent,none, + "property cannot be differentiated because '%0.TangentVector' does not " + "have a member named '%1'", (StringRef, StringRef)) NOTE(autodiff_value_defined_here,none, "value defined here", ()) NOTE(autodiff_when_differentiating_function_call,none, diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index bf4dace00d444..083cc724fb6cf 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2808,11 +2808,6 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none, "'@noDerivative' is only allowed on stored properties in structure types " "that declare a conformance to 'Differentiable'", ()) -// @_fieldwiseDifferentiable attribute -ERROR(fieldwise_differentiable_only_on_differentiable_structs,none, - "'@_fieldwiseDifferentiable' is only allowed on structure types that " - "conform to 'Differentiable'", ()) - //------------------------------------------------------------------------------ // MARK: Type Check Expressions //------------------------------------------------------------------------------ diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index cb6631e0d1ff0..ef184f1bc437f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -726,20 +726,6 @@ struct NestedApplyInfo { Optional originalPullbackType; }; -/// Specifies how we should differentiate a `struct_extract` instruction. -enum class StructExtractDifferentiationStrategy { - // The `struct_extract` is not active, so do not differentiate it. - Inactive, - - // The `struct_extract` is extracting a field from a Differentiable struct - // with @_fieldwiseProductSpace tangent space. Therefore, differentiate the - // `struct_extract` by setting the adjoint to a vector in the tangent space - // that is zero except along the direction of the corresponding field. - // - // Fields correspond by matching name. - Fieldwise -}; - static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, DifferentiationInvoker invoker) { invoker.print(os); @@ -900,11 +886,6 @@ class ADContext { /// `NestedApplyInfo`s. DenseMap nestedApplyInfo; - /// Mapping from original `struct_extract` and `struct_element_addr` - /// instructions to their strategies. - DenseMap - structExtractDifferentiationStrategies; - /// List of generated functions (JVPs, VJPs, adjoints, and thunks). /// Saved for deletion during cleanup. SmallVector generatedFunctions; @@ -966,11 +947,6 @@ class ADContext { return nestedApplyInfo; } - DenseMap - &getStructExtractDifferentiationStrategies() { - return structExtractDifferentiationStrategies; - } - SmallVector &getGeneratedFunctions() { return generatedFunctions; } @@ -1188,9 +1164,9 @@ class ADContext { SILDifferentiableAttr *attr, StringRef name, AutoDiffAssociatedFunctionKind kind); - template + template InFlightDiagnostic diagnose(SourceLoc loc, Diag diag, - U &&... args) const { + U &&...args) const { return getASTContext().Diags.diagnose(loc, diag, std::forward(args)...); } @@ -1198,23 +1174,26 @@ class ADContext { /// parent function, emits a "not differentiable" error based on the task. If /// the task is indirect, emits notes all the way up to the outermost task, /// and emits an error at the outer task. Otherwise, emits an error directly. + template InFlightDiagnostic emitNondifferentiabilityError( SILInstruction *inst, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); /// Given a value and a differentiation task associated with the parent /// function, emits a "not differentiable" error based on the task. If the /// task is indirect, emits notes all the way up to the outermost task, and /// emits an error at the outer task. Otherwise, emits an error directly. + template InFlightDiagnostic emitNondifferentiabilityError( SILValue value, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); /// Emit a "not differentiable" error based on the given differentiation task /// and diagnostic. + template InFlightDiagnostic emitNondifferentiabilityError( SourceLoc loc, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); }; } // end anonymous namespace @@ -1222,36 +1201,41 @@ ADContext::ADContext(SILModuleTransform &transform) : transform(transform), module(*transform.getModule()), passManager(*transform.getPassManager()) {} +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SILValue value, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { LLVM_DEBUG({ getADDebugStream() << "Diagnosing non-differentiability.\n"; getADDebugStream() << "For value:\n" << value; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); auto valueLoc = value.getLoc().getSourceLoc(); - return emitNondifferentiabilityError(valueLoc, invoker, diag); + return emitNondifferentiabilityError(valueLoc, invoker, diag, + std::forward(args)...); } +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SILInstruction *inst, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { LLVM_DEBUG({ getADDebugStream() << "Diagnosing non-differentiability.\n"; getADDebugStream() << "For instruction:\n" << *inst; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); auto instLoc = inst->getLoc().getSourceLoc(); - return emitNondifferentiabilityError(instLoc, invoker, diag); + return emitNondifferentiabilityError(instLoc, invoker, diag, + std::forward(args)...); } +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SourceLoc loc, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { switch (invoker.getKind()) { // For `autodiff_function` instructions: if the `autodiff_function` // instruction comes from a differential operator, emit an error on the @@ -1262,12 +1246,11 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, if (auto *expr = findDifferentialOperator(inst)) { diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error) .highlight(expr->getSubExpr()->getSourceRange()); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); } diagnose(loc, diag::autodiff_expression_not_differentiable_error); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); + //diag.getValueOr(diag::autodiff_expression_not_differentiable_note), } // For `[differentiable]` attributes, try to find an AST function declaration @@ -1295,8 +1278,7 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, if (!foundAttr) diagnose(original->getLocation().getSourceLoc(), diag::autodiff_function_not_differentiable_error); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); } // For indirect differentiation, emit a "not differentiable" note on the @@ -1309,9 +1291,9 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, std::tie(inst, attr) = invoker.getIndirectDifferentiation(); auto invokerLookup = invokers.find(attr); assert(invokerLookup != invokers.end() && "Expected parent invoker"); - emitNondifferentiabilityError(inst, invokerLookup->second, None); - return diagnose(loc, - diag.getValueOr(diag::autodiff_when_differentiating_function_call)); + emitNondifferentiabilityError(inst, invokerLookup->second, + diag::autodiff_expression_not_differentiable_note); + return diagnose(loc, diag::autodiff_when_differentiating_function_call); } } } @@ -1560,18 +1542,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, } // Handle `struct_extract` and `struct_element_addr` instructions. -// - If the field is marked `@noDerivative` and belongs to a -// `@_fieldwiseDifferentiable` struct, do not set the result as varied because -// it is not in the set of differentiable variables. +// - If the field is marked `@noDerivative`, do not set the result as varied +// because it is not in the set of differentiable variables. // - Otherwise, propagate variedness from operand to result as usual. #define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \ else if (auto *sei = dyn_cast(&inst)) { \ if (isVaried(sei->getOperand(), i)) { \ auto hasNoDeriv = sei->getField()->getAttrs() \ .hasAttribute(); \ - auto structIsFieldwiseDiffable = sei->getStructDecl()->getAttrs() \ - .hasAttribute(); \ - if (!(hasNoDeriv && structIsFieldwiseDiffable)) \ + if (!hasNoDeriv) \ setVaried(sei, i); \ } \ } @@ -2838,20 +2817,6 @@ static void collectMinimalIndicesForFunctionCall( } } -// Returns the associated function with name `assocFnName`. If the function -// cannot be found, returns a reference to an external asssociated function -// declaration. -static SILFunction *getAssociatedFunction( - ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, - AutoDiffAssociatedFunctionKind kind, StringRef assocFnName) { - auto &module = context.getModule(); - auto *assocFn = module.lookUpFunction(assocFnName); - if (!assocFn) - assocFn = context.declareExternalAssociatedFunction( - original, attr, assocFnName, kind); - return assocFn; -} - namespace { class VJPEmitter final : public TypeSubstCloner { @@ -3095,7 +3060,8 @@ class VJPEmitter final } void visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError(inst, invoker); + context.emitNondifferentiabilityError(inst, invoker, + diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } @@ -3217,39 +3183,6 @@ class VJPEmitter final ri->getLoc(), joinElements(directResults, builder, loc)); } - void visitStructExtractInst(StructExtractInst *sei) { - auto &strategies = context.getStructExtractDifferentiationStrategies(); - // Special handling logic only applies when the `struct_extract` is active. - // If not, just do standard cloning. - if (!activityInfo.isActive(sei, getIndices())) { - LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); - strategies.insert( - {sei, StructExtractDifferentiationStrategy::Inactive}); - SILClonerWithScopes::visitStructExtractInst(sei); - return; - } - // This instruction is active. Use the field wise differentiation strategy - // to differentiate the struct extract instruction. - strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructExtractInst(sei); - } - - void visitStructElementAddrInst(StructElementAddrInst *seai) { - auto &strategies = context.getStructExtractDifferentiationStrategies(); - // Special handling logic only applies when the `struct_element_addr` is - // active. If not, just do standard cloning. - if (!activityInfo.isActive(seai, getIndices())) { - LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *seai << '\n'); - strategies[seai] =StructExtractDifferentiationStrategy::Inactive; - SILClonerWithScopes::visitStructElementAddrInst(seai); - return; - } - // This instruction is active. Use the field wise differentiation strategy - // to differentiate the struct extract instruction. - strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructElementAddrInst(seai); - } - // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { @@ -3309,7 +3242,8 @@ class VJPEmitter final s << "}\n";); // FIXME: We don't support multiple active results yet. if (activeResultIndices.size() > 1) { - context.emitNondifferentiabilityError(ai, invoker); + context.emitNondifferentiabilityError( + ai, invoker, diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } @@ -4429,7 +4363,8 @@ class AdjointEmitter final : public SILInstructionVisitor { void visitSILInstruction(SILInstruction *inst) { LLVM_DEBUG(getADDebugStream() << "Unhandled instruction in adjoint emitter: " << *inst); - getContext().emitNondifferentiabilityError(inst, getInvoker()); + getContext().emitNondifferentiabilityError(inst, getInvoker(), + diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } @@ -4590,48 +4525,48 @@ class AdjointEmitter final : public SILInstructionVisitor { break; case AdjointValueKind::Concrete: { auto adjStruct = materializeAdjointDirect(std::move(av), loc); - if (structDecl->getAttrs().hasAttribute()) { - // Find the struct `TangentVector` type. - auto structTy = remapType(si->getType()).getASTType(); - auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())) - ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering( - tangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); - auto *tangentVectorDecl = - tangentVectorTy->getStructOrBoundGenericStruct(); - assert(tangentVectorDecl); - - // Accumulate adjoints for the fields of the `struct` operand. - for (auto *field : structDecl->getStoredProperties()) { - // There does not exist a corresponding tangent field for original - // fields with `@noDerivative` attribute. Emit an error. - if (field->getAttrs().hasAttribute()) - continue; - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - if (tangentVectorDecl == structDecl) - tanField = field; - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(field->getName()); - assert(tanFieldLookup.size() == 1); - tanField = cast(tanFieldLookup.front()); + // Find the struct `TangentVector` type. + auto structTy = remapType(si->getType()).getASTType(); + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())) + ->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering( + tangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + + // Accumulate adjoints for the fields of the `struct` operand. + for (auto *field : structDecl->getStoredProperties()) { + // There does not exist a corresponding tangent field for original + // fields with `@noDerivative` attribute. Emit an error. + if (field->getAttrs().hasAttribute()) + continue; + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + if (tangentVectorDecl == structDecl) + tanField = field; + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(field->getName()); + if (tanFieldLookup.empty()) { + getContext().emitNondifferentiabilityError( + si, getInvoker(), + diag::autodiff_stored_property_no_corresponding_tangent, + tangentVectorDecl->getNameStr(), field->getNameStr()); + errorOccurred = true; + return; } - auto *adjStructElt = - builder.createStructExtract(loc, adjStruct, tanField); - addAdjointValue( - si->getFieldValue(field), - makeConcreteAdjointValue(ValueWithCleanup( - adjStructElt, makeCleanup(adjStructElt, emitCleanup)))); + tanField = cast(tanFieldLookup.front()); } - } else { - // FIXME(TF-21): If `TangentVector` is not marked - // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. - llvm_unreachable("Unhandled. Are you trying to differentiate a " - "memberwise initializer?"); + auto *adjStructElt = + builder.createStructExtract(loc, adjStruct, tanField); + addAdjointValue( + si->getFieldValue(field), + makeConcreteAdjointValue(ValueWithCleanup( + adjStructElt, makeCleanup(adjStructElt, emitCleanup)))); } break; } @@ -4650,73 +4585,69 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(!sei->getField()->getAttrs().hasAttribute() && "`struct_extract` with `@noDerivative` field should not be " "differentiated; activity analysis should not marked as varied"); - auto loc = sei->getLoc(); - auto &differentiationStrategies = - getContext().getStructExtractDifferentiationStrategies(); - auto strategy = differentiationStrategies.lookup(sei); - switch (strategy) { - case StructExtractDifferentiationStrategy::Inactive: - assert(!getActivityInfo().isActive(sei, getIndices())); - return; - case StructExtractDifferentiationStrategy::Fieldwise: { - // Compute adjoint as follows: - // y = struct_extract x, #key - // adj[x] += struct (0, ..., #key': adj[y], ..., 0) - // where `#key'` is the field in the tangent space corresponding to - // `#key`. - auto structTy = remapType(sei->getOperand()->getType()).getASTType(); - auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())) - ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering( - tangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); - auto tangentVectorSILTy = - SILType::getPrimitiveObjectType(tangentVectorTy); - auto *tangentVectorDecl = - tangentVectorTy->getStructOrBoundGenericStruct(); - assert(tangentVectorDecl); - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == sei->getStructDecl()) - tanField = sei->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(sei->getField()->getName()); - assert(tanFieldLookup.size() == 1); - tanField = cast(tanFieldLookup.front()); + // Compute adjoint as follows: + // y = struct_extract x, #key + // adj[x] += struct (0, ..., #key': adj[y], ..., 0) + // where `#key'` is the field in the tangent space corresponding to + // `#key`. + auto structTy = remapType(sei->getOperand()->getType()).getASTType(); + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())) + ->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering( + tangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); + auto tangentVectorSILTy = + SILType::getPrimitiveObjectType(tangentVectorTy); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. + if (tangentVectorDecl == sei->getStructDecl()) + tanField = sei->getField(); + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(sei->getField()->getName()); + if (tanFieldLookup.empty()) { + getContext().emitNondifferentiabilityError( + sei, getInvoker(), + diag::autodiff_stored_property_no_corresponding_tangent, + sei->getStructDecl()->getNameStr(), + sei->getField()->getNameStr()); + errorOccurred = true; + return; } - // Accumulate adjoint for the `struct_extract` operand. - auto av = takeAdjointValue(sei); - switch (av.getKind()) { - case AdjointValueKind::Zero: - addAdjointValue(sei->getOperand(), - makeZeroAdjointValue(tangentVectorSILTy)); - break; - case AdjointValueKind::Concrete: - case AdjointValueKind::Aggregate: { - SmallVector eltVals; - for (auto *field : tangentVectorDecl->getStoredProperties()) { - if (field == tanField) { - eltVals.push_back(av); - } else { - auto substMap = tangentVectorTy->getMemberSubstitutionMap( - field->getModuleContext(), field); - auto fieldTy = field->getType().subst(substMap); - auto fieldSILTy = - getContext().getTypeConverter().getLoweredType( - fieldTy, ResilienceExpansion::Minimal); - assert(fieldSILTy.isObject()); - eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); - } + tanField = cast(tanFieldLookup.front()); + } + // Accumulate adjoint for the `struct_extract` operand. + auto av = takeAdjointValue(sei); + switch (av.getKind()) { + case AdjointValueKind::Zero: + addAdjointValue(sei->getOperand(), + makeZeroAdjointValue(tangentVectorSILTy)); + break; + case AdjointValueKind::Concrete: + case AdjointValueKind::Aggregate: { + SmallVector eltVals; + for (auto *field : tangentVectorDecl->getStoredProperties()) { + if (field == tanField) { + eltVals.push_back(av); + } else { + auto substMap = tangentVectorTy->getMemberSubstitutionMap( + field->getModuleContext(), field); + auto fieldTy = field->getType().subst(substMap); + auto fieldSILTy = + getContext().getTypeConverter().getLoweredType( + fieldTy, ResilienceExpansion::Minimal); + assert(fieldSILTy.isObject()); + eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } - addAdjointValue(sei->getOperand(), - makeAggregateAdjointValue(tangentVectorSILTy, eltVals)); - } } - return; + addAdjointValue(sei->getOperand(), + makeAggregateAdjointValue(tangentVectorSILTy, eltVals)); } } } diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 43fbdc18560bf..d3165a51d9e3e 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -690,8 +690,6 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(), /*Inherited*/ C.AllocateCopy(inherited), /*GenericParams*/ {}, parentDC); - structDecl->getAttrs().add( - new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); structDecl->setImplicit(); structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); @@ -960,12 +958,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, if (!getAssociatedType(member, parentDC, id)) return nullptr; - // Since associated types will be derived, we make this struct a fieldwise - // differentiable type. - if (!nominal->getAttrs().hasAttribute()) - nominal->getAttrs().add( - new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); - // Prevent re-synthesis during repeated calls. // FIXME: Investigate why this is necessary to prevent duplicate synthesis. auto lookup = nominal->lookupDirect(id); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 871c99b3eab74..d2c307a486d0e 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -135,7 +135,6 @@ class AttributeEarlyChecker : public AttributeVisitor { IGNORED_ATTR(Differentiable) IGNORED_ATTR(Differentiating) IGNORED_ATTR(CompilerEvaluable) - IGNORED_ATTR(FieldwiseDifferentiable) IGNORED_ATTR(NoDerivative) #undef IGNORED_ATTR @@ -872,7 +871,6 @@ class AttributeChecker : public AttributeVisitor { void visitDifferentiableAttr(DifferentiableAttr *attr); void visitDifferentiatingAttr(DifferentiatingAttr *attr); void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr); - void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr); void visitNoDerivativeAttr(NoDerivativeAttr *attr); }; } // end anonymous namespace @@ -3576,23 +3574,6 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { // TypeChecker::checkFunctionBodyCompilerEvaluable(). } -// SWIFT_ENABLE_TENSORFLOW -void AttributeChecker::visitFieldwiseDifferentiableAttr( - FieldwiseDifferentiableAttr *attr) { - auto *structDecl = dyn_cast(D); - if (!structDecl) { - diagnoseAndRemoveAttr(attr, - diag::fieldwise_differentiable_only_on_differentiable_structs); - return; - } - if (!conformsToDifferentiableInModule( - structDecl->getDeclaredInterfaceType(), D->getModuleContext())) { - diagnoseAndRemoveAttr(attr, - diag::fieldwise_differentiable_only_on_differentiable_structs); - return; - } -} - // SWIFT_ENABLE_TENSORFLOW void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) { auto *vd = dyn_cast(D); diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 6e340460c83f8..ff25f8c4d4b15 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -1303,7 +1303,6 @@ namespace { UNINTERESTING_ATTR(Differentiable) UNINTERESTING_ATTR(Differentiating) UNINTERESTING_ATTR(CompilerEvaluable) - UNINTERESTING_ATTR(FieldwiseDifferentiable) UNINTERESTING_ATTR(NoDerivative) // These can't appear on overridable declarations. diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index a1bbe0c5e26fa..0e57d4c6dbe69 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -29,15 +29,23 @@ struct S { } extension S : Differentiable, VectorNumeric { + struct TangentVector: Differentiable, VectorNumeric { + var dp: Float + } + typealias AllDifferentiableVariables = S static var zero: S { return S(p: 0) } typealias Scalar = Float static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) } static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) } static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) } - typealias TangentVector = S + func moved(along direction: TangentVector) -> S { + return S(p: p + direction.dp) + } } +// expected-error @+2 {{function is not differentiable}} +// expected-note @+1 {{property cannot be differentiated because 'S.TangentVector' does not have a member named 'p'}} _ = gradient(at: S(p: 0)) { s in 2 * s.p } struct NoDerivativeProperty : Differentiable { diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index 35126fa38c58f..29e669fe17116 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -6,11 +6,11 @@ public struct Foo : Differentiable { public var a: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable { +// CHECK-AST-LABEL: public struct Foo : Differentiable { // CHECK-AST: @differentiable // CHECK-AST: public var a: Float // CHECK-AST: internal init(a: Float) -// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables +// CHECK-AST: public struct AllDifferentiableVariables // CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables @@ -25,7 +25,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in x.a + x.a } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { +// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { // CHECK-AST: internal var a: Float // CHECK-AST: internal init(a: Float) // CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf @@ -36,11 +36,11 @@ struct TestNoDerivative : Differentiable { @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestNoDerivative : Differentiable { +// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables @@ -50,11 +50,11 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable { @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestKeyPathIterable : Differentiable, KeyPathIterable { +// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables @@ -66,7 +66,7 @@ struct GenericTanMember : Differentiable, AdditiveArithmetic // TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized. // `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized. -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable +// CHECK-AST-LABEL: internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable // CHECK-AST: internal var x: T.TangentVector // CHECK-AST: internal init(x: T.TangentVector) // CHECK-AST: internal typealias TangentVector = GenericTanMember @@ -81,7 +81,7 @@ public struct ConditionallyDifferentiable { } extension ConditionallyDifferentiable : Differentiable where T : Differentiable {} -// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable { +// CHECK-AST-LABEL: public struct ConditionallyDifferentiable { // CHECK-AST: @differentiable(wrt: self where T : Differentiable) // CHECK-AST: public let x: T // CHECK-AST: internal init(x: T) diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index 1583ad9689084..2822697e0b8da 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -9,7 +9,7 @@ import StdlibUnittest var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty") struct TangentSpace : VectorNumeric { - let dx, dy: Float + let x, y: Float } extension TangentSpace : Differentiable { @@ -26,18 +26,14 @@ struct Space { } func vjpX() -> (Float, (Float) -> TangentSpace) { - return (x, { v in TangentSpace(dx: v, dy: 0) } ) + return (x, { v in TangentSpace(x: v, y: 0) } ) } private let storedX: Float - + @differentiable var y: Float - func vjpY() -> (Float, (Float) -> TangentSpace) { - return (y, { v in TangentSpace(dx: 0, dy: v) }) - } - init(x: Float, y: Float) { self.storedX = x self.y = y @@ -47,7 +43,7 @@ struct Space { extension Space : Differentiable { typealias TangentVector = TangentSpace func moved(along: TangentSpace) -> Space { - return Space(x: x + along.dx, y: y + along.dy) + return Space(x: x + along.x, y: y + along.y) } } @@ -55,7 +51,7 @@ E2EDifferentiablePropertyTests.test("computed property") { let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in return 2 * point.x } - let expectedGrad = TangentSpace(dx: 2, dy: 0) + let expectedGrad = TangentSpace(x: 2, y: 0) expectEqual(expectedGrad, actualGrad) } @@ -63,7 +59,7 @@ E2EDifferentiablePropertyTests.test("stored property") { let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in return 3 * point.y } - let expectedGrad = TangentSpace(dx: 0, dy: 3) + let expectedGrad = TangentSpace(x: 0, y: 3) expectEqual(expectedGrad, actualGrad) } @@ -85,7 +81,6 @@ E2EDifferentiablePropertyTests.test("generic stored property") { expectEqual(expectedGrad, actualGrad) } -@_fieldwiseDifferentiable struct ProductSpaceSelfTangent : VectorNumeric { let x, y: Float } @@ -110,7 +105,6 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable { typealias TangentVector = ProductSpaceOtherTangentTangentSpace } -@_fieldwiseDifferentiable struct ProductSpaceOtherTangent { let x, y: Float } diff --git a/test/AutoDiff/separate_cotangent_type.swift b/test/AutoDiff/separate_cotangent_type.swift index b37e869c82965..8fbe76ecc12fd 100644 --- a/test/AutoDiff/separate_cotangent_type.swift +++ b/test/AutoDiff/separate_cotangent_type.swift @@ -10,7 +10,6 @@ import Glibc var SeparateTangentTypeTests = TestSuite("SeparateTangentType") -@_fieldwiseDifferentiable struct DifferentiableSubset : Differentiable { @differentiable(wrt: self) var w: Float @@ -18,7 +17,6 @@ struct DifferentiableSubset : Differentiable { var b: Float @noDerivative var flag: Bool - @_fieldwiseDifferentiable struct TangentVector : Differentiable, VectorNumeric { typealias TangentVector = DifferentiableSubset.TangentVector var w: Float @@ -41,12 +39,10 @@ SeparateTangentTypeTests.test("Initialization") { expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) } -// FIXME(SR-9602): If `TangentVector` is not marked -// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. -// SeparateTangentTypeTests.test("SomeArithmetics") { -// let x = DifferentiableSubset(w: 0, b: 1, flag: false) -// let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } -// expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) -// } +SeparateTangentTypeTests.test("SomeArithmetics") { + let x = DifferentiableSubset(w: 0, b: 1, flag: false) + let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) +} runAllTests() diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 285d8a3091001..e52faf389813f 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -11,7 +11,6 @@ protocol Proto : Differentiable { func function3(_ x: Float, _ y: Double) -> Double } -@_fieldwiseDifferentiable struct S : Proto, VectorNumeric { static var zero: S { return S(p: 0) } typealias Scalar = Float From fa4fa16f6118f7e64751e7e0056f693dde534f3e Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Sat, 1 Jun 2019 16:52:04 -0700 Subject: [PATCH 12/13] Remove a commented-out line. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index ef184f1bc437f..e74d7c59ebb1f 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1250,7 +1250,6 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, } diagnose(loc, diag::autodiff_expression_not_differentiable_error); return diagnose(loc, diag, std::forward(args)...); - //diag.getValueOr(diag::autodiff_expression_not_differentiable_note), } // For `[differentiable]` attributes, try to find an AST function declaration From dfdceb712466104b61a7b213d55ec7e7c9b47f4b Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Sat, 1 Jun 2019 18:39:36 -0700 Subject: [PATCH 13/13] Remove TF-21 comment in code. --- lib/SILOptimizer/Mandatory/Differentiation.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index e74d7c59ebb1f..d1087f152ddf2 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -4570,10 +4570,6 @@ class AdjointEmitter final : public SILInstructionVisitor { break; } case AdjointValueKind::Aggregate: { - // FIXME(TF-21): If `TangentVector` is not marked - // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. - // for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements())) - // addAdjointValue(std::get<0>(pair), std::get<1>(pair)); llvm_unreachable("Unhandled. Are you trying to differentiate a " "memberwise initializer?"); }