diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index c1b210cd2ad62..1d321d3d1bd9b 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -394,6 +394,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph, OnFunc, 82) SIMPLE_DECL_ATTR(TFParameter, TFParameter, OnVar, 83) +SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace, + OnTypeAlias | UserInaccessible, 84) #undef TYPE_ATTR #undef DECL_ATTR_ALIAS diff --git a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp index b11b0fe91cb54..16ee035e50de0 100644 --- a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp +++ b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp @@ -669,6 +669,24 @@ struct NestedApplyActivity { SILAutoDiffIndices indices; }; +/// Specifies how we should differentiate a `struct_extract` instruction. +enum class StructExtractDifferentiationStrategy { + // The `struct_extract` is not active, so do not differentiate it. + Inactive, + + // The `struct_extract` is extracting a field from a Differentiable struct + // with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the + // `struct_extract` by setting the adjoint to a vector in the cotangent space + // that is zero except along the direction of the corresponding field. + // + // Fields correspond by matching name. + FieldwiseProductSpace, + + // Differentiate the `struct_extract` by looking up the corresponding getter + // and using its VJP. + Getter +}; + /// A differentiation task, specifying the original function and the /// `[differentiable]` attribute on the function. PrimalGen and AdjointGen /// will synthesize the primal and the adjoint for this task, filling the primal @@ -714,6 +732,10 @@ class DifferentiationTask { /// Note: This is only used when `DifferentiationUseVJP`. DenseMap nestedApplyActivities; + /// Mapping from original `struct_extract` instructions to their strategies. 
+ DenseMap + structExtractDifferentiationStrategies; + /// Cache for associated functions. SILFunction *primal = nullptr; SILFunction *adjoint = nullptr; @@ -810,6 +832,11 @@ class DifferentiationTask { return nestedApplyActivities; } + DenseMap & + getStructExtractDifferentiationStrategies() { + return structExtractDifferentiationStrategies; + } + bool isEqual(const DifferentiationTask &other) const { return original == other.original && attr == other.attr; } @@ -2228,16 +2255,42 @@ class PrimalGenCloner final : public SILClonerWithScopes { } void visitStructExtractInst(StructExtractInst *sei) { + auto &astCtx = getContext().getASTContext(); + auto &structExtractDifferentiationStrategies = + getDifferentiationTask()->getStructExtractDifferentiationStrategies(); + // Special handling logic only applies when the `struct_extract` is active. // If not, just do standard cloning. if (!activityInfo.isActive(sei, synthesis.indices)) { LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); + structExtractDifferentiationStrategies.insert( + {sei, StructExtractDifferentiationStrategy::Inactive}); SILClonerWithScopes::visitStructExtractInst(sei); return; } - // This instruction is active. Replace it with a call to the corresponding - // getter's VJP. + // This instruction is active. Determine the appropriate differentiation + // strategy, and use it. + + // Use the FieldwiseProductSpace strategy, if appropriate. 
+ auto *structDecl = sei->getStructDecl(); + auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector); + if (aliasLookup.size() >= 1) { + assert(aliasLookup.size() == 1); + assert(isa(aliasLookup[0])); + auto *aliasDecl = cast(aliasLookup[0]); + if (aliasDecl->getAttrs().hasAttribute()) { + structExtractDifferentiationStrategies.insert( + {sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace}); + SILClonerWithScopes::visitStructExtractInst(sei); + return; + } + } + + // The FieldwiseProductSpace strategy is not appropriate, so use the Getter + // strategy. + structExtractDifferentiationStrategies.insert( + {sei, StructExtractDifferentiationStrategy::Getter}); // Find the corresponding getter and its VJP. auto *getterDecl = sei->getField()->getGetter(); @@ -3596,17 +3649,103 @@ class AdjointEmitter final : public SILInstructionVisitor { } void visitStructExtractInst(StructExtractInst *sei) { - // Replace a `struct_extract` with a call to its pullback. auto loc = remapLocation(sei->getLoc()); + auto &astCtx = getContext().getASTContext(); - // Get the pullback. - auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); - if (!pullbackField) { - // Inactive `struct_extract` instructions don't need to be cloned into the - // adjoint. 
+ auto &differentiationStrategies = + getDifferentiationTask()->getStructExtractDifferentiationStrategies(); + auto differentiationStrategyLookUp = differentiationStrategies.find(sei); + assert(differentiationStrategyLookUp != differentiationStrategies.end()); + auto differentiationStrategy = differentiationStrategyLookUp->second; + + if (differentiationStrategy == + StructExtractDifferentiationStrategy::Inactive) { assert(!activityInfo.isActive(sei, synthesis.indices)); return; } + + if (differentiationStrategy == + StructExtractDifferentiationStrategy::FieldwiseProductSpace) { + // Compute adjoint as follows: + // y = struct_extract <key>, x + // adj[x] = struct (0, ..., key': adj[y], ..., 0) + // where `key'` is the field in the cotangent space corresponding to + // `key`. + + // Find the decl of the cotangent space type. + auto *structDecl = sei->getStructDecl(); + auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector); + assert(aliasLookup.size() == 1); + assert(isa(aliasLookup[0])); + auto *aliasDecl = cast(aliasLookup[0]); + assert(aliasDecl->getAttrs().hasAttribute()); + auto cotangentVectorTy = + aliasDecl->getUnderlyingTypeLoc().getType()->getCanonicalType(); + assert(!getModule() + .Types.getTypeLowering(cotangentVectorTy) + .isAddressOnly()); + auto cotangentVectorSILTy = + SILType::getPrimitiveObjectType(cotangentVectorTy); + auto *cotangentVectorDecl = + cotangentVectorTy->getStructOrBoundGenericStruct(); + assert(cotangentVectorDecl); + + // Find the corresponding field in the cotangent space. 
+ VarDecl *correspondingField = nullptr; + if (cotangentVectorDecl == structDecl) + correspondingField = sei->getField(); + else { + auto correspondingFieldLookup = + cotangentVectorDecl->lookupDirect(sei->getField()->getName()); + assert(correspondingFieldLookup.size() == 1); + assert(isa(correspondingFieldLookup[0])); + correspondingField = cast(correspondingFieldLookup[0]); + } + assert(correspondingField); + +#ifndef NDEBUG + unsigned numMatchingStoredProperties = 0; + for (auto *storedProperty : cotangentVectorDecl->getStoredProperties()) + if (storedProperty == correspondingField) + numMatchingStoredProperties += 1; + assert(numMatchingStoredProperties == 1); +#endif + + // Compute adjoint. + auto av = getAdjointValue(sei); + switch (av.getKind()) { + case AdjointValue::Kind::Zero: + addAdjointValue(sei->getOperand(), + AdjointValue::getZero(cotangentVectorSILTy)); + break; + case AdjointValue::Kind::Materialized: + case AdjointValue::Kind::Aggregate: { + SmallVector eltVals; + for (auto *field : cotangentVectorDecl->getStoredProperties()) { + if (field == correspondingField) + eltVals.push_back(av); + else + eltVals.push_back( + AdjointValue::getZero(SILType::getPrimitiveObjectType( + field->getType()->getCanonicalType()))); + } + addAdjointValue(sei->getOperand(), + AdjointValue::getAggregate(cotangentVectorSILTy, + eltVals, allocator)); + } + } + + return; + } + + // The only remaining strategy is the getter strategy. + // Replace the `struct_extract` with a call to its pullback. + assert(differentiationStrategy == + StructExtractDifferentiationStrategy::Getter); + + // Get the pullback. 
+ auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); + assert(pullbackField); SILValue pullback = builder.createStructExtract(loc, primalValueAggregateInAdj, pullbackField); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index dfce46cf1211f..9d9740c864952 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -126,6 +126,7 @@ class AttributeEarlyChecker : public AttributeVisitor { IGNORED_ATTR(CompilerEvaluable) IGNORED_ATTR(TensorFlowGraph) IGNORED_ATTR(TFParameter) + IGNORED_ATTR(FieldwiseProductSpace) #undef IGNORED_ATTR // @noreturn has been replaced with a 'Never' return type. @@ -884,6 +885,7 @@ class AttributeChecker : public AttributeVisitor { void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr); void visitTensorFlowGraphAttr(TensorFlowGraphAttr *attr); void visitTFParameterAttr(TFParameterAttr *attr); + void visitFieldwiseProductSpaceAttr(FieldwiseProductSpaceAttr *attr); }; } // end anonymous namespace @@ -2705,6 +2707,19 @@ void AttributeChecker::visitTFParameterAttr(TFParameterAttr *attr) { } } +void AttributeChecker::visitFieldwiseProductSpaceAttr( + FieldwiseProductSpaceAttr *attr) { + // If we make this attribute user-facing, we'll need to do various checks. + // - check that this attribute is on a Tangent/Cotangent type alias + // - check that we can access the raw fields of the Tangent/Cotangent from + // this module (e.g. the Tangent can't be a public resilient struct + // defined in a different module). + // - check that the stored properties of the Tangent/Cotangent match + // + // If we don't make this attribute user-facing, we can avoid doing checks + // here: the assertions in TFDifferentiation suffice. 
+} + void TypeChecker::checkDeclAttributes(Decl *D) { AttributeChecker Checker(*this, D); diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 3354082e4cd92..79e7642066c5e 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -1218,6 +1218,7 @@ namespace { UNINTERESTING_ATTR(CompilerEvaluable) UNINTERESTING_ATTR(TensorFlowGraph) UNINTERESTING_ATTR(TFParameter) + UNINTERESTING_ATTR(FieldwiseProductSpace) // These can't appear on overridable declarations. UNINTERESTING_ATTR(Prefix) diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index 4078ae67a65b6..4e29ee90d2951 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -8,26 +8,13 @@ import StdlibUnittest var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty") -struct TangentSpace { +struct TangentSpace : VectorNumeric { let dx, dy: Float } -extension TangentSpace : Differentiable, VectorNumeric { +extension TangentSpace : Differentiable { typealias TangentVector = TangentSpace typealias CotangentVector = TangentSpace - typealias Scalar = Float - static var zero: TangentSpace { - return TangentSpace(dx: 0, dy: 0) - } - static func + (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace { - return TangentSpace(dx: lhs.dx + rhs.dx, dy: lhs.dy + rhs.dy) - } - static func - (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace { - return TangentSpace(dx: lhs.dx - rhs.dx, dy: lhs.dy - rhs.dy) - } - static func * (lhs: Float, rhs: TangentSpace) -> TangentSpace { - return TangentSpace(dx: lhs * rhs.dx, dy: lhs * rhs.dy) - } } struct Space { @@ -83,4 +70,54 @@ E2EDifferentiablePropertyTests.test("stored property") { expectEqual(expectedGrad, actualGrad) } +struct ProductSpaceSelfTangent : VectorNumeric { + let x, y: Float +} + +extension ProductSpaceSelfTangent : Differentiable { + @_fieldwiseProductSpace + 
typealias TangentVector = ProductSpaceSelfTangent + @_fieldwiseProductSpace + typealias CotangentVector = ProductSpaceSelfTangent +} + +E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") { + let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in + return 5 * point.y + } + let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5) + expectEqual(expectedGrad, actualGrad) +} + +struct ProductSpaceOtherTangentTangentSpace : VectorNumeric { + let x, y: Float +} + +extension ProductSpaceOtherTangentTangentSpace : Differentiable { + typealias TangentVector = ProductSpaceOtherTangentTangentSpace + typealias CotangentVector = ProductSpaceOtherTangentTangentSpace +} + +struct ProductSpaceOtherTangent { + let x, y: Float +} + +extension ProductSpaceOtherTangent : Differentiable { + @_fieldwiseProductSpace + typealias TangentVector = ProductSpaceOtherTangentTangentSpace + @_fieldwiseProductSpace + typealias CotangentVector = ProductSpaceOtherTangentTangentSpace + func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent { + return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y) + } +} + +E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") { + let actualGrad = gradient(at: ProductSpaceOtherTangent(x: 0, y: 0)) { (point: ProductSpaceOtherTangent) -> Float in + return 7 * point.y + } + let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7) + expectEqual(expectedGrad, actualGrad) +} + runAllTests()