diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index d0781b281b3ca..91492656af88f 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -374,6 +374,8 @@ ERROR(autodiff_unsupported_type,none, "differentiating '%0' is not supported yet", (Type)) ERROR(autodiff_function_not_differentiable,none, "function is not differentiable", ()) +ERROR(autodiff_property_not_differentiable,none, + "property is not differentiable", ()) NOTE(autodiff_function_generic_functions_unsupported,none, "differentiating generic functions is not supported yet", ()) NOTE(autodiff_value_defined_here,none, diff --git a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp index 2efea97422e78..b11b0fe91cb54 100644 --- a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp +++ b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp @@ -472,9 +472,9 @@ class PrimalInfo { /// corresponding tape of its type. DenseMap nestedStaticPrimalValueMap; - /// Mapping from `apply` instructions in the original function to the - /// corresponding pullback decl in the primal struct. - DenseMap pullbackValueMap; + /// Mapping from `apply` and `struct_extract` instructions in the original + /// function to the corresponding pullback decl in the primal struct. + DenseMap pullbackValueMap; /// Mapping from types of control-dependent nested primal values to district /// tapes. @@ -573,7 +573,7 @@ class PrimalInfo { } /// Add a pullback to the primal value struct. - VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) { + VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) { // Decls must have AST types (not `SILFunctionType`), so we convert the // `SILFunctionType` of the pullback to a `FunctionType` with the same // parameters and results. @@ -605,9 +605,9 @@ class PrimalInfo { : lookup->getSecond(); } - /// Finds the pullback decl in the primal value struct for an `apply` in the - /// original function. - VarDecl *lookUpPullbackDecl(ApplyInst *inst) { + /// Finds the pullback decl in the primal value struct for an `apply` or + /// `struct_extract` in the original function. + VarDecl *lookUpPullbackDecl(SILInstruction *inst) { auto lookup = pullbackValueMap.find(inst); return lookup == pullbackValueMap.end() ? nullptr : lookup->getSecond(); @@ -2227,6 +2227,79 @@ class PrimalGenCloner final : public SILClonerWithScopes { SILClonerWithScopes::visitReleaseValueInst(rvi); } + void visitStructExtractInst(StructExtractInst *sei) { + // Special handling logic only applies when the `struct_extract` is active. + // If not, just do standard cloning. + if (!activityInfo.isActive(sei, synthesis.indices)) { + LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); + SILClonerWithScopes::visitStructExtractInst(sei); + return; + } + + // This instruction is active. Replace it with a call to the corresponding + // getter's VJP. + + // Find the corresponding getter and its VJP. + auto *getterDecl = sei->getField()->getGetter(); + assert(getterDecl); + auto *getterFn = getContext().getModule().lookUpFunction( + SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); + if (!getterFn) { + getContext().emitNondifferentiabilityError( + sei, synthesis.task, diag::autodiff_property_not_differentiable); + errorOccurred = true; + return; + } + auto getterDiffAttrs = getterFn->getDifferentiableAttrs(); + if (getterDiffAttrs.size() < 1) { + getContext().emitNondifferentiabilityError( + sei, synthesis.task, diag::autodiff_property_not_differentiable); + errorOccurred = true; + return; + } + auto *getterDiffAttr = getterDiffAttrs[0]; + if (!getterDiffAttr->hasVJP()) { + getContext().emitNondifferentiabilityError( + sei, synthesis.task, diag::autodiff_property_not_differentiable); + errorOccurred = true; + return; + } + assert(getterDiffAttr->getIndices() == + SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0})); + auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(), + getContext().getModule()); + + // Reference and apply the VJP. + auto loc = sei->getLoc(); + auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); + auto *getterVJPApply = getBuilder().createApply( + loc, getterVJPRef, /*substitutionMap*/ {}, + /*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false); + + // Get the VJP results (original results and pullback). + SmallVector vjpDirectResults; + extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); + ArrayRef originalDirectResults = + ArrayRef(vjpDirectResults).drop_back(1); + SILValue originalDirectResult = joinElements(originalDirectResults, + getBuilder(), + getterVJPApply->getLoc()); + SILValue pullback = vjpDirectResults.back(); + + // Store the original result to the value map. + mapValue(sei, originalDirectResult); + + // Checkpoint the original results. + getPrimalInfo().addStaticPrimalValueDecl(sei); + getBuilder().createRetainValue(loc, originalDirectResult, + getBuilder().getDefaultAtomicity()); + staticPrimalValues.push_back(originalDirectResult); + + // Checkpoint the pullback. + getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType()); + staticPrimalValues.push_back(pullback); + } + void visitApplyInst(ApplyInst *ai) { if (DifferentiationUseVJP) visitApplyInstWithVJP(ai); @@ -3522,33 +3595,36 @@ class AdjointEmitter final : public SILInstructionVisitor { } } - /// Handle `struct_extract` instruction. - /// y = struct_extract , x - /// adj[x] = struct (0, ..., key: adj[y], ..., 0) void visitStructExtractInst(StructExtractInst *sei) { - auto *structDecl = sei->getStructDecl(); - auto av = getAdjointValue(sei); - switch (av.getKind()) { - case AdjointValue::Kind::Zero: - addAdjointValue(sei->getOperand(), - AdjointValue::getZero(sei->getOperand()->getType())); - break; - case AdjointValue::Kind::Materialized: - case AdjointValue::Kind::Aggregate: { - SmallVector eltVals; - for (auto *field : structDecl->getStoredProperties()) { - if (field == sei->getField()) - eltVals.push_back(av); - else - eltVals.push_back(AdjointValue::getZero( - SILType::getPrimitiveObjectType( - field->getType()->getCanonicalType()))); - } - addAdjointValue(sei->getOperand(), - AdjointValue::getAggregate(sei->getOperand()->getType(), - eltVals, allocator)); - } + // Replace a `struct_extract` with a call to its pullback. + auto loc = remapLocation(sei->getLoc()); + + // Get the pullback. + auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); + if (!pullbackField) { + // Inactive `struct_extract` instructions don't need to be cloned into the + // adjoint. + assert(!activityInfo.isActive(sei, synthesis.indices)); + return; } + SILValue pullback = builder.createStructExtract(loc, + primalValueAggregateInAdj, + pullbackField); + + // Construct the pullback arguments. + SmallVector args; + auto seed = getAdjointValue(sei); + assert(seed.getType().isObject()); + args.push_back(materializeAdjointDirect(seed, loc)); + + // Call the pullback. + auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), + args, /*isNonThrowing*/ false); + assert(!pullbackCall->hasIndirectResults()); + + // Set adjoint for the `struct_extract` operand. + addAdjointValue(sei->getOperand(), + AdjointValue::getMaterialized(pullbackCall)); } /// Handle `tuple` instruction. diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index 71a94742cada7..75e7286d71eb9 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1582,11 +1582,25 @@ extension ${Self} { extension ${Self} { @_transparent + // SWIFT_ENABLE_TENSORFLOW + @differentiable(adjoint: _adjointNegate) public static prefix func - (x: ${Self}) -> ${Self} { return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value)) } } +// SWIFT_ENABLE_TENSORFLOW +extension ${Self} { + @usableFromInline + @_transparent + // SWIFT_ENABLE_TENSORFLOW + static func _adjointNegate( + seed: ${Self}, originalValue: ${Self}, x: ${Self} + ) -> ${Self} { + return -seed + } +} + //===----------------------------------------------------------------------===// // Explicit conversions between types. //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index d00afdbbbbf68..8f7db39a0a45f 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -31,6 +31,28 @@ func generic(_ x: T) -> T { return x + 1 } +//===----------------------------------------------------------------------===// +// Non-differentiable stored properties +//===----------------------------------------------------------------------===// + +struct S { + let p: Float +} + +extension S : Differentiable, VectorNumeric { + static var zero: S { return S(p: 0) } + typealias Scalar = Float + static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) } + static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) } + static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) } + + typealias TangentVector = S + typealias CotangentVector = S +} + +// expected-error @+1 {{property is not differentiable}} +_ = gradient(at: S(p: 0)) { s in 2 * s.p } + //===----------------------------------------------------------------------===// // Function composition //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index c10bf69a33ecd..4078ae67a65b6 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -75,14 +75,12 @@ E2EDifferentiablePropertyTests.test("computed property") { expectEqual(expectedGrad, actualGrad) } -// FIXME: The AD pass cannot differentiate this because it sees -// `struct_extract`s instead of calls to getters. -// E2EDifferentiablePropertyTests.test("stored property") { -// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in -// return 3 * point.y -// } -// let expectedGrad = TangentSpace(dx: 0, dy: 3) -// expectEqual(expectedGrad, actualGrad) -// } +E2EDifferentiablePropertyTests.test("stored property") { + let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in + return 3 * point.y + } + let expectedGrad = TangentSpace(dx: 0, dy: 3) + expectEqual(expectedGrad, actualGrad) +} runAllTests() diff --git a/test/AutoDiff/method.swift b/test/AutoDiff/method.swift index 6a85170ecf6b7..c950a8528024a 100644 --- a/test/AutoDiff/method.swift +++ b/test/AutoDiff/method.swift @@ -9,7 +9,12 @@ var MethodTests = TestSuite("Method") // ==== Tests with generated adjoint ==== struct Parameter : Equatable { + @differentiable(wrt: (self), vjp: vjpX) let x: Float + + func vjpX() -> (Float, (Float) -> Parameter) { + return (x, { dx in Parameter(x: dx) } ) + } } extension Parameter { @@ -132,7 +137,12 @@ MethodTests.test("static method with generated adjoint, wrt all params") { // ==== Tests with custom adjoint ==== struct CustomParameter : Equatable { + @differentiable(wrt: (self), vjp: vjpX) let x: Float + + func vjpX() -> (Float, (Float) -> CustomParameter) { + return (x, { dx in CustomParameter(x: dx) }) + } } extension CustomParameter : Differentiable, VectorNumeric { diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index ec60ccac368de..e02d49cff97c1 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -27,7 +27,24 @@ struct Quadratic : DiffReq, Equatable { typealias TangentVector = Quadratic typealias CotangentVector = Quadratic - let a, b, c: Float + @differentiable(wrt: (self), vjp: vjpA) + let a: Float + func vjpA() -> (Float, (Float) -> Quadratic) { + return (a, { da in Quadratic(da, 0, 0) } ) + } + + @differentiable(wrt: (self), vjp: vjpB) + let b: Float + func vjpB() -> (Float, (Float) -> Quadratic) { + return (b, { db in Quadratic(0, db, 0) } ) + } + + @differentiable(wrt: (self), vjp: vjpC) + let c: Float + func vjpC() -> (Float, (Float) -> Quadratic) { + return (c, { dc in Quadratic(0, 0, dc) } ) + } + init(_ a: Float, _ b: Float, _ c: Float) { self.a = a self.b = b diff --git a/test/AutoDiff/simple_model.swift b/test/AutoDiff/simple_model.swift index 451f8bde0104c..e2b7ec8e1c6cd 100644 --- a/test/AutoDiff/simple_model.swift +++ b/test/AutoDiff/simple_model.swift @@ -7,8 +7,17 @@ import StdlibUnittest var SimpleModelTests = TestSuite("SimpleModel") struct DenseLayer : Equatable { + @differentiable(wrt: (self), vjp: vjpW) let w: Float + func vjpW() -> (Float, (Float) -> DenseLayer) { + return (w, { dw in DenseLayer(w: dw, b: 0) } ) + } + + @differentiable(wrt: (self), vjp: vjpB) let b: Float + func vjpB() -> (Float, (Float) -> DenseLayer) { + return (b, { db in DenseLayer(w: 0, b: db) } ) + } } extension DenseLayer : Differentiable, VectorNumeric { @@ -39,9 +48,23 @@ extension DenseLayer { } struct Model : Equatable { + @differentiable(wrt: (self), vjp: vjpL1) let l1: DenseLayer + func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) { + return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } ) + } + + @differentiable(wrt: (self), vjp: vjpL2) let l2: DenseLayer + func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) { + return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } ) + } + + @differentiable(wrt: (self), vjp: vjpL3) let l3: DenseLayer + func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) { + return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } ) + } } extension Model : Differentiable, VectorNumeric { diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 3ae6175af8df8..a823d0fba21d0 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -21,7 +21,11 @@ struct S : Proto, VectorNumeric { typealias TangentVector = S typealias CotangentVector = S + @differentiable(wrt: (self), vjp: vjpP) let p: Float + func vjpP() -> (Float, (Float) -> S) { + return (p, { dp in S(p: dp) }) + } func function1(_ x: Float, _ y: Float) -> Float { return x + y + p