diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 8b8d5f5414ba9..5c1a434d340ce 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -421,10 +421,8 @@ DECL_ATTR(differentiating, Differentiating, SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable, OnAccessor | OnFunc | OnConstructor | OnSubscript, /* Not serialized */ 90) -SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable, - OnNominalType | UserInaccessible, 91) SIMPLE_DECL_ATTR(noDerivative, NoDerivative, - OnVar, 92) + OnVar, 91) #undef TYPE_ATTR #undef DECL_ATTR_ALIAS diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index fa47890fb971a..4afed5c78da76 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -447,6 +447,9 @@ NOTE(autodiff_opaque_function_not_differentiable,none, "opaque non-'@differentiable' function is not differentiable", ()) NOTE(autodiff_property_not_differentiable,none, "property is not differentiable", ()) +NOTE(autodiff_stored_property_no_corresponding_tangent,none, + "property cannot be differentiated because '%0.TangentVector' does not " + "have a member named '%1'", (StringRef, StringRef)) NOTE(autodiff_value_defined_here,none, "value defined here", ()) NOTE(autodiff_when_differentiating_function_call,none, diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 2d4fd2696027f..083cc724fb6cf 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) +ERROR(differentiable_attr_stored_property_variable_unsupported,none, + "'jvp:' or 'vjp:' cannot be specified for stored properties", ()) NOTE(protocol_witness_missing_specific_differentiable_attr,none, "candidate is missing attribute '%0'", (StringRef)) @@ -2806,11 +2808,6 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none, "'@noDerivative' is only allowed on stored properties in structure types " "that declare a conformance to 'Differentiable'", ()) -// @_fieldwiseDifferentiable attribute -ERROR(fieldwise_differentiable_only_on_differentiable_structs,none, - "'@_fieldwiseDifferentiable' is only allowed on structure types that " - "conform to 'Differentiable'", ()) - //------------------------------------------------------------------------------ // MARK: Type Check Expressions //------------------------------------------------------------------------------ diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index fc31564e32fb0..d1087f152ddf2 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -726,24 +726,6 @@ struct NestedApplyInfo { Optional originalPullbackType; }; -/// Specifies how we should differentiate a `struct_extract` instruction. -enum class StructExtractDifferentiationStrategy { - // The `struct_extract` is not active, so do not differentiate it. - Inactive, - - // The `struct_extract` is extracting a field from a Differentiable struct - // with @_fieldwiseProductSpace tangent space. Therefore, differentiate the - // `struct_extract` by setting the adjoint to a vector in the tangent space - // that is zero except along the direction of the corresponding field. - // - // Fields correspond by matching name. - Fieldwise, - - // Differentiate the `struct_extract` by looking up the corresponding getter - // and using its VJP. - Getter -}; - static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, DifferentiationInvoker invoker) { invoker.print(os); @@ -904,11 +886,6 @@ class ADContext { /// `NestedApplyInfo`s. DenseMap nestedApplyInfo; - /// Mapping from original `struct_extract` and `struct_element_addr` - /// instructions to their strategies. - DenseMap - structExtractDifferentiationStrategies; - /// List of generated functions (JVPs, VJPs, adjoints, and thunks). /// Saved for deletion during cleanup. SmallVector generatedFunctions; @@ -970,11 +947,6 @@ class ADContext { return nestedApplyInfo; } - DenseMap - &getStructExtractDifferentiationStrategies() { - return structExtractDifferentiationStrategies; - } - SmallVector &getGeneratedFunctions() { return generatedFunctions; } @@ -1192,9 +1164,9 @@ class ADContext { SILDifferentiableAttr *attr, StringRef name, AutoDiffAssociatedFunctionKind kind); - template + template InFlightDiagnostic diagnose(SourceLoc loc, Diag diag, - U &&... args) const { + U &&...args) const { return getASTContext().Diags.diagnose(loc, diag, std::forward(args)...); } @@ -1202,23 +1174,26 @@ class ADContext { /// parent function, emits a "not differentiable" error based on the task. If /// the task is indirect, emits notes all the way up to the outermost task, /// and emits an error at the outer task. Otherwise, emits an error directly. + template InFlightDiagnostic emitNondifferentiabilityError( SILInstruction *inst, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); /// Given a value and a differentiation task associated with the parent /// function, emits a "not differentiable" error based on the task. If the /// task is indirect, emits notes all the way up to the outermost task, and /// emits an error at the outer task. Otherwise, emits an error directly. + template InFlightDiagnostic emitNondifferentiabilityError( SILValue value, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); /// Emit a "not differentiable" error based on the given differentiation task /// and diagnostic. + template InFlightDiagnostic emitNondifferentiabilityError( SourceLoc loc, DifferentiationInvoker invoker, - Optional> diag = None); + Diag diag, U &&...args); }; } // end anonymous namespace @@ -1226,36 +1201,41 @@ ADContext::ADContext(SILModuleTransform &transform) : transform(transform), module(*transform.getModule()), passManager(*transform.getPassManager()) {} +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SILValue value, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { LLVM_DEBUG({ getADDebugStream() << "Diagnosing non-differentiability.\n"; getADDebugStream() << "For value:\n" << value; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); auto valueLoc = value.getLoc().getSourceLoc(); - return emitNondifferentiabilityError(valueLoc, invoker, diag); + return emitNondifferentiabilityError(valueLoc, invoker, diag, + std::forward(args)...); } +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SILInstruction *inst, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { LLVM_DEBUG({ getADDebugStream() << "Diagnosing non-differentiability.\n"; getADDebugStream() << "For instruction:\n" << *inst; getADDebugStream() << "With invoker:\n" << invoker << '\n'; }); auto instLoc = inst->getLoc().getSourceLoc(); - return emitNondifferentiabilityError(instLoc, invoker, diag); + return emitNondifferentiabilityError(instLoc, invoker, diag, + std::forward(args)...); } +template InFlightDiagnostic ADContext::emitNondifferentiabilityError(SourceLoc loc, DifferentiationInvoker invoker, - Optional> diag) { + Diag diag, U &&...args) { switch (invoker.getKind()) { // For `autodiff_function` instructions: if the `autodiff_function` // instruction comes from a differential operator, emit an error on the @@ -1266,12 +1246,10 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, if (auto *expr = findDifferentialOperator(inst)) { diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error) .highlight(expr->getSubExpr()->getSourceRange()); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); } diagnose(loc, diag::autodiff_expression_not_differentiable_error); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); } // For `[differentiable]` attributes, try to find an AST function declaration @@ -1299,8 +1277,7 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, if (!foundAttr) diagnose(original->getLocation().getSourceLoc(), diag::autodiff_function_not_differentiable_error); - return diagnose(loc, - diag.getValueOr(diag::autodiff_expression_not_differentiable_note)); + return diagnose(loc, diag, std::forward(args)...); } // For indirect differentiation, emit a "not differentiable" note on the @@ -1313,9 +1290,9 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc, std::tie(inst, attr) = invoker.getIndirectDifferentiation(); auto invokerLookup = invokers.find(attr); assert(invokerLookup != invokers.end() && "Expected parent invoker"); - emitNondifferentiabilityError(inst, invokerLookup->second, None); - return diagnose(loc, - diag.getValueOr(diag::autodiff_when_differentiating_function_call)); + emitNondifferentiabilityError(inst, invokerLookup->second, + diag::autodiff_expression_not_differentiable_note); + return diagnose(loc, diag::autodiff_when_differentiating_function_call); } } } @@ -1564,18 +1541,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, } // Handle `struct_extract` and `struct_element_addr` instructions. -// - If the field is marked `@noDerivative` and belongs to a -// `@_fieldwiseDifferentiable` struct, do not set the result as varied because -// it is not in the set of differentiable variables. +// - If the field is marked `@noDerivative`, do not set the result as varied +// because it is not in the set of differentiable variables. // - Otherwise, propagate variedness from operand to result as usual. #define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \ else if (auto *sei = dyn_cast(&inst)) { \ if (isVaried(sei->getOperand(), i)) { \ auto hasNoDeriv = sei->getField()->getAttrs() \ .hasAttribute(); \ - auto structIsFieldwiseDiffable = sei->getStructDecl()->getAttrs() \ - .hasAttribute(); \ - if (!(hasNoDeriv && structIsFieldwiseDiffable)) \ + if (!hasNoDeriv) \ setVaried(sei, i); \ } \ } @@ -2842,20 +2816,6 @@ static void collectMinimalIndicesForFunctionCall( } } -// Returns the associated function with name `assocFnName`. If the function -// cannot be found, returns a reference to an external asssociated function -// declaration. -static SILFunction *getAssociatedFunction( - ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, - AutoDiffAssociatedFunctionKind kind, StringRef assocFnName) { - auto &module = context.getModule(); - auto *assocFn = module.lookUpFunction(assocFnName); - if (!assocFn) - assocFn = context.declareExternalAssociatedFunction( - original, attr, assocFnName, kind); - return assocFn; -} - namespace { class VJPEmitter final : public TypeSubstCloner { @@ -3099,7 +3059,8 @@ class VJPEmitter final } void visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError(inst, invoker); + context.emitNondifferentiabilityError(inst, invoker, + diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } @@ -3221,164 +3182,6 @@ class VJPEmitter final ri->getLoc(), joinElements(directResults, builder, loc)); } - void visitStructExtractInst(StructExtractInst *sei) { - auto &strategies = context.getStructExtractDifferentiationStrategies(); - // Special handling logic only applies when the `struct_extract` is active. - // If not, just do standard cloning. - if (!activityInfo.isActive(sei, getIndices())) { - LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); - strategies.insert( - {sei, StructExtractDifferentiationStrategy::Inactive}); - SILClonerWithScopes::visitStructExtractInst(sei); - return; - } - // This instruction is active. Determine the appropriate differentiation - // strategy, and use it. - auto *structDecl = sei->getStructDecl(); - if (structDecl->getAttrs().hasAttribute()) { - strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructExtractInst(sei); - return; - } - // The FieldwiseProductSpace strategy is not appropriate, so use the Getter - // strategy. - strategies[sei] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = sei->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - SILAutoDiffIndices indices(/*source*/ 0, - AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); - auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); - if (!attr) { - context.emitNondifferentiabilityError( - sei, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - // Reference and apply the VJP. - auto loc = sei->getLoc(); - auto *getterVJP = getAssociatedFunction( - context, getterFn, attr, AutoDiffAssociatedFunctionKind::VJP, - attr->getVJPName()); - assert(getterVJP && "Expected to find getter VJP"); - auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); - auto *getterVJPApply = getBuilder().createApply( - loc, getterVJPRef, - getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()), - /*args*/ {getOpValue(sei->getOperand())}, /*isNonThrowing*/ false); - // Extract direct results from `getterVJPApply`. - SmallVector vjpDirectResults; - extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); - // Map original result. - auto originalDirectResults = - ArrayRef(vjpDirectResults).drop_back(1); - auto originalDirectResult = joinElements(originalDirectResults, - getBuilder(), - getterVJPApply->getLoc()); - mapValue(sei, originalDirectResult); - // Checkpoint the pullback. - auto pullback = vjpDirectResults.back(); - // TODO: Check whether it's necessary to reabstract getter pullbacks. - pullbackInfo.addPullbackDecl(sei, getOpType(pullback->getType())); - pullbackValues[sei->getParent()].push_back(pullback); - } - - void visitStructElementAddrInst(StructElementAddrInst *seai) { - auto &strategies = context.getStructExtractDifferentiationStrategies(); - // Special handling logic only applies when the `struct_element_addr` is - // active. If not, just do standard cloning. - if (!activityInfo.isActive(seai, getIndices())) { - LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *seai << '\n'); - strategies[seai] =StructExtractDifferentiationStrategy::Inactive; - SILClonerWithScopes::visitStructElementAddrInst(seai); - return; - } - // This instruction is active. Determine the appropriate differentiation - // strategy, and use it. - auto *structDecl = seai->getStructDecl(); - if (structDecl->getAttrs().hasAttribute()) { - strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise; - SILClonerWithScopes::visitStructElementAddrInst(seai); - return; - } - // The FieldwiseProductSpace strategy is not appropriate, so use the Getter - // strategy. - strategies[seai] = StructExtractDifferentiationStrategy::Getter; - // Find the corresponding getter and its VJP. - auto *getterDecl = seai->getField()->getGetter(); - assert(getterDecl); - auto *getterFn = getModule().lookUpFunction( - SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); - if (!getterFn) { - context.emitNondifferentiabilityError( - seai, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - SILAutoDiffIndices indices(/*source*/ 0, - AutoDiffIndexSubset::getDefault(getASTContext(), 1, true)); - auto *attr = context.lookUpDifferentiableAttr(getterFn, indices); - if (!attr) { - context.emitNondifferentiabilityError( - seai, invoker, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - // Set generic context scope before getting VJP function type. - auto vjpGenSig = SubsMap.getGenericSignature() - ? SubsMap.getGenericSignature()->getCanonicalSignature() - : nullptr; - Lowering::GenericContextScope genericContextScope( - context.getTypeConverter(), vjpGenSig); - // Reference the getter VJP. - auto loc = seai->getLoc(); - auto *getterVJP = getModule().lookUpFunction(attr->getVJPName()); - assert(getterVJP && "Expected to find getter VJP"); - auto vjpFnTy = getterVJP->getLoweredFunctionType(); - auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); - // Store getter VJP arguments and indirect result buffers. - SmallVector vjpArgs; - SmallVector vjpIndirectResults; - for (auto indRes : vjpFnTy->getIndirectFormalResults()) { - auto *alloc = getBuilder().createAllocStack( - loc, getOpType(indRes.getSILStorageType())); - vjpArgs.push_back(alloc); - vjpIndirectResults.push_back(alloc); - } - vjpArgs.push_back(getOpValue(seai->getOperand())); - // Apply the getter VJP. - auto *getterVJPApply = getBuilder().createApply( - loc, getterVJPRef, - getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()), - vjpArgs, /*isNonThrowing*/ false); - // Collect all results from `getterVJPApply` in type-defined order. - SmallVector vjpDirectResults; - extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); - SmallVector allResults; - collectAllActualResultsInTypeOrder( - getterVJPApply, vjpDirectResults, - getterVJPApply->getIndirectSILResults(), allResults); - // Deallocate VJP indirect results. - for (auto alloc : vjpIndirectResults) - getBuilder().createDeallocStack(loc, alloc); - auto originalDirectResult = allResults[indices.source]; - // Map original result. - mapValue(seai, originalDirectResult); - // Checkpoint the pullback. - SILValue pullback = vjpDirectResults.back(); - // TODO: Check whether it's necessary to reabstract getter pullbacks. - pullbackInfo.addPullbackDecl(seai, getOpType(pullback->getType())); - pullbackValues[seai->getParent()].push_back(pullback); - } - // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { @@ -3438,7 +3241,8 @@ class VJPEmitter final s << "}\n";); // FIXME: We don't support multiple active results yet. if (activeResultIndices.size() > 1) { - context.emitNondifferentiabilityError(ai, invoker); + context.emitNondifferentiabilityError( + ai, invoker, diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } @@ -4558,7 +4362,8 @@ class AdjointEmitter final : public SILInstructionVisitor { void visitSILInstruction(SILInstruction *inst) { LLVM_DEBUG(getADDebugStream() << "Unhandled instruction in adjoint emitter: " << *inst); - getContext().emitNondifferentiabilityError(inst, getInvoker()); + getContext().emitNondifferentiabilityError(inst, getInvoker(), + diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } @@ -4719,56 +4524,52 @@ class AdjointEmitter final : public SILInstructionVisitor { break; case AdjointValueKind::Concrete: { auto adjStruct = materializeAdjointDirect(std::move(av), loc); - if (structDecl->getAttrs().hasAttribute()) { - // Find the struct `TangentVector` type. - auto structTy = remapType(si->getType()).getASTType(); - auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())) - ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering( - tangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); - auto *tangentVectorDecl = - tangentVectorTy->getStructOrBoundGenericStruct(); - assert(tangentVectorDecl); - - // Accumulate adjoints for the fields of the `struct` operand. - for (auto *field : structDecl->getStoredProperties()) { - // There does not exist a corresponding tangent field for original - // fields with `@noDerivative` attribute. Emit an error. - if (field->getAttrs().hasAttribute()) - continue; - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - if (tangentVectorDecl == structDecl) - tanField = field; - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(field->getName()); - assert(tanFieldLookup.size() == 1); - tanField = cast(tanFieldLookup.front()); + // Find the struct `TangentVector` type. + auto structTy = remapType(si->getType()).getASTType(); + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())) + ->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering( + tangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + + // Accumulate adjoints for the fields of the `struct` operand. + for (auto *field : structDecl->getStoredProperties()) { + // There does not exist a corresponding tangent field for original + // fields with `@noDerivative` attribute. Emit an error. + if (field->getAttrs().hasAttribute()) + continue; + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + if (tangentVectorDecl == structDecl) + tanField = field; + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(field->getName()); + if (tanFieldLookup.empty()) { + getContext().emitNondifferentiabilityError( + si, getInvoker(), + diag::autodiff_stored_property_no_corresponding_tangent, + tangentVectorDecl->getNameStr(), field->getNameStr()); + errorOccurred = true; + return; } - auto *adjStructElt = - builder.createStructExtract(loc, adjStruct, tanField); - addAdjointValue( - si->getFieldValue(field), - makeConcreteAdjointValue(ValueWithCleanup( - adjStructElt, makeCleanup(adjStructElt, emitCleanup)))); + tanField = cast(tanFieldLookup.front()); } - } else { - // FIXME(TF-21): If `TangentVector` is not marked - // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. - llvm_unreachable("Unhandled. Are you trying to differentiate a " - "memberwise initializer?"); + auto *adjStructElt = + builder.createStructExtract(loc, adjStruct, tanField); + addAdjointValue( + si->getFieldValue(field), + makeConcreteAdjointValue(ValueWithCleanup( + adjStructElt, makeCleanup(adjStructElt, emitCleanup)))); } break; } case AdjointValueKind::Aggregate: { - // FIXME(TF-21): If `TangentVector` is not marked - // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. - // for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements())) - // addAdjointValue(std::get<0>(pair), std::get<1>(pair)); llvm_unreachable("Unhandled. Are you trying to differentiate a " "memberwise initializer?"); } @@ -4779,96 +4580,69 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(!sei->getField()->getAttrs().hasAttribute() && "`struct_extract` with `@noDerivative` field should not be " "differentiated; activity analysis should not marked as varied"); - auto loc = sei->getLoc(); - auto &differentiationStrategies = - getContext().getStructExtractDifferentiationStrategies(); - auto strategy = differentiationStrategies.lookup(sei); - switch (strategy) { - case StructExtractDifferentiationStrategy::Inactive: - assert(!getActivityInfo().isActive(sei, getIndices())); - return; - case StructExtractDifferentiationStrategy::Fieldwise: { - // Compute adjoint as follows: - // y = struct_extract x, #key - // adj[x] += struct (0, ..., #key': adj[y], ..., 0) - // where `#key'` is the field in the tangent space corresponding to - // `#key`. - auto structTy = remapType(sei->getOperand()->getType()).getASTType(); - auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( - LookUpConformanceInModule(getModule().getSwiftModule())) - ->getType()->getCanonicalType(); - assert(!getModule().Types.getTypeLowering( - tangentVectorTy, ResilienceExpansion::Minimal) - .isAddressOnly()); - auto tangentVectorSILTy = - SILType::getPrimitiveObjectType(tangentVectorTy); - auto *tangentVectorDecl = - tangentVectorTy->getStructOrBoundGenericStruct(); - assert(tangentVectorDecl); - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == sei->getStructDecl()) - tanField = sei->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(sei->getField()->getName()); - assert(tanFieldLookup.size() == 1); - tanField = cast(tanFieldLookup.front()); - } - // Accumulate adjoint for the `struct_extract` operand. - auto av = takeAdjointValue(sei); - switch (av.getKind()) { - case AdjointValueKind::Zero: - addAdjointValue(sei->getOperand(), - makeZeroAdjointValue(tangentVectorSILTy)); - break; - case AdjointValueKind::Concrete: - case AdjointValueKind::Aggregate: { - SmallVector eltVals; - for (auto *field : tangentVectorDecl->getStoredProperties()) { - if (field == tanField) { - eltVals.push_back(av); - } else { - auto substMap = tangentVectorTy->getMemberSubstitutionMap( - field->getModuleContext(), field); - auto fieldTy = field->getType().subst(substMap); - auto fieldSILTy = - getContext().getTypeConverter().getLoweredType( - fieldTy, ResilienceExpansion::Minimal); - assert(fieldSILTy.isObject()); - eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); - } - } - addAdjointValue(sei->getOperand(), - makeAggregateAdjointValue(tangentVectorSILTy, eltVals)); - } + // Compute adjoint as follows: + // y = struct_extract x, #key + // adj[x] += struct (0, ..., #key': adj[y], ..., 0) + // where `#key'` is the field in the tangent space corresponding to + // `#key`. + auto structTy = remapType(sei->getOperand()->getType()).getASTType(); + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())) + ->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering( + tangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); + auto tangentVectorSILTy = + SILType::getPrimitiveObjectType(tangentVectorTy); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. + if (tangentVectorDecl == sei->getStructDecl()) + tanField = sei->getField(); + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(sei->getField()->getName()); + if (tanFieldLookup.empty()) { + getContext().emitNondifferentiabilityError( + sei, getInvoker(), + diag::autodiff_stored_property_no_corresponding_tangent, + sei->getStructDecl()->getNameStr(), + sei->getField()->getNameStr()); + errorOccurred = true; + return; } - return; + tanField = cast(tanFieldLookup.front()); } - case StructExtractDifferentiationStrategy::Getter: { - // Get the pullback. - auto *pullbackField = getPullbackInfo().lookUpPullbackDecl(sei); - assert(pullbackField); - auto pullback = builder.createStructExtract( - loc, getAdjointBlockPullbackStructArgument(sei->getParent()), - pullbackField); - - // Construct the pullback arguments. - auto av = takeAdjointValue(sei); - auto vector = materializeAdjointDirect(std::move(av), loc); - - // Call the pullback. - auto *pullbackCall = builder.createApply( - loc, pullback, SubstitutionMap(), {vector}, /*isNonThrowing*/ false); - assert(!pullbackCall->hasIndirectResults()); - - // Accumulate adjoint for the `struct_extract` operand. + // Accumulate adjoint for the `struct_extract` operand. + auto av = takeAdjointValue(sei); + switch (av.getKind()) { + case AdjointValueKind::Zero: addAdjointValue(sei->getOperand(), - makeConcreteAdjointValue( - ValueWithCleanup(pullbackCall, vector.getCleanup()))); + makeZeroAdjointValue(tangentVectorSILTy)); break; + case AdjointValueKind::Concrete: + case AdjointValueKind::Aggregate: { + SmallVector eltVals; + for (auto *field : tangentVectorDecl->getStoredProperties()) { + if (field == tanField) { + eltVals.push_back(av); + } else { + auto substMap = tangentVectorTy->getMemberSubstitutionMap( + field->getModuleContext(), field); + auto fieldTy = field->getType().subst(substMap); + auto fieldSILTy = + getContext().getTypeConverter().getLoweredType( + fieldTy, ResilienceExpansion::Minimal); + assert(fieldSILTy.isObject()); + eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); + } + } + addAdjointValue(sei->getOperand(), + makeAggregateAdjointValue(tangentVectorSILTy, eltVals)); } } } diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 43fbdc18560bf..d3165a51d9e3e 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -690,8 +690,6 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(), /*Inherited*/ C.AllocateCopy(inherited), /*GenericParams*/ {}, parentDC); - structDecl->getAttrs().add( - new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); structDecl->setImplicit(); structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); @@ -960,12 +958,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, if (!getAssociatedType(member, parentDC, id)) return nullptr; - // Since associated types will be derived, we make this struct a fieldwise - // differentiable type. - if (!nominal->getAttrs().hasAttribute()) - nominal->getAttrs().add( - new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); - // Prevent re-synthesis during repeated calls. // FIXME: Investigate why this is necessary to prevent duplicate synthesis. auto lookup = nominal->lookupDirect(id); diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 60e522d32df97..d2c307a486d0e 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -135,7 +135,6 @@ class AttributeEarlyChecker : public AttributeVisitor { IGNORED_ATTR(Differentiable) IGNORED_ATTR(Differentiating) IGNORED_ATTR(CompilerEvaluable) - IGNORED_ATTR(FieldwiseDifferentiable) IGNORED_ATTR(NoDerivative) #undef IGNORED_ATTR @@ -872,7 +871,6 @@ class AttributeChecker : public AttributeVisitor { void visitDifferentiableAttr(DifferentiableAttr *attr); void visitDifferentiatingAttr(DifferentiatingAttr *attr); void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr); - void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr); void visitNoDerivativeAttr(NoDerivativeAttr *attr); }; } // end anonymous namespace @@ -2887,6 +2885,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AbstractFunctionDecl *original = dyn_cast(D); if (auto *asd = dyn_cast(D)) { + if (asd->getImplInfo().isSimpleStored() && + (attr->getJVP() || attr->getVJP())) { + diagnoseAndRemoveAttr(attr, + diag::differentiable_attr_stored_property_variable_unsupported); + return; + } // When used directly on a storage decl (stored/computed property or // subscript), the getter is currently inferred to be `@differentiable`. // TODO(TF-129): Infer setter to also be `@differentiable` after @@ -3570,23 +3574,6 @@ void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { // TypeChecker::checkFunctionBodyCompilerEvaluable(). } -// SWIFT_ENABLE_TENSORFLOW -void AttributeChecker::visitFieldwiseDifferentiableAttr( - FieldwiseDifferentiableAttr *attr) { - auto *structDecl = dyn_cast(D); - if (!structDecl) { - diagnoseAndRemoveAttr(attr, - diag::fieldwise_differentiable_only_on_differentiable_structs); - return; - } - if (!conformsToDifferentiableInModule( - structDecl->getDeclaredInterfaceType(), D->getModuleContext())) { - diagnoseAndRemoveAttr(attr, - diag::fieldwise_differentiable_only_on_differentiable_structs); - return; - } -} - // SWIFT_ENABLE_TENSORFLOW void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) { auto *vd = dyn_cast(D); diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 6e340460c83f8..ff25f8c4d4b15 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -1303,7 +1303,6 @@ namespace { UNINTERESTING_ATTR(Differentiable) UNINTERESTING_ATTR(Differentiating) UNINTERESTING_ATTR(CompilerEvaluable) - UNINTERESTING_ATTR(FieldwiseDifferentiable) UNINTERESTING_ATTR(NoDerivative) // These can't appear on overridable declarations. diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 46e1ec3aa875e..0e57d4c6dbe69 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -29,17 +29,23 @@ struct S { } extension S : Differentiable, VectorNumeric { + struct TangentVector: Differentiable, VectorNumeric { + var dp: Float + } + typealias AllDifferentiableVariables = S static var zero: S { return S(p: 0) } typealias Scalar = Float static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) } static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) } static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) } - typealias TangentVector = S + func moved(along direction: TangentVector) -> S { + return S(p: p + direction.dp) + } } // expected-error @+2 {{function is not differentiable}} -// expected-note @+1 {{property is not differentiable}} +// expected-note @+1 {{property cannot be differentiated because 'S.TangentVector' does not have a member named 'p'}} _ = gradient(at: S(p: 0)) { s in 2 * s.p } struct NoDerivativeProperty : Differentiable { diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index 35126fa38c58f..29e669fe17116 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -6,11 +6,11 @@ public struct Foo : Differentiable { public var a: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable { +// CHECK-AST-LABEL: public struct Foo : Differentiable { // CHECK-AST: @differentiable // CHECK-AST: public var a: Float // CHECK-AST: internal init(a: Float) -// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables +// CHECK-AST: public struct AllDifferentiableVariables // CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables @@ -25,7 +25,7 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in x.a + x.a } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { +// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { // CHECK-AST: internal var a: Float // CHECK-AST: internal init(a: Float) // CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf @@ -36,11 +36,11 @@ struct TestNoDerivative : Differentiable { @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestNoDerivative : Differentiable { +// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables @@ -50,11 +50,11 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable { @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct TestKeyPathIterable : Differentiable, KeyPathIterable { +// CHECK-AST-LABEL: internal struct TestKeyPathIterable : Differentiable, KeyPathIterable { // CHECK-AST: var w: Float // CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float // CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float) -// CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric +// CHECK-AST: internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables @@ -66,7 +66,7 @@ struct GenericTanMember : Differentiable, AdditiveArithmetic // TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized. // `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized. -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable +// CHECK-AST-LABEL: internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable // CHECK-AST: internal var x: T.TangentVector // CHECK-AST: internal init(x: T.TangentVector) // CHECK-AST: internal typealias TangentVector = GenericTanMember @@ -81,7 +81,7 @@ public struct ConditionallyDifferentiable { } extension ConditionallyDifferentiable : Differentiable where T : Differentiable {} -// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct ConditionallyDifferentiable { +// CHECK-AST-LABEL: public struct ConditionallyDifferentiable { // CHECK-AST: @differentiable(wrt: self where T : Differentiable) // CHECK-AST: public let x: T // CHECK-AST: internal init(x: T) diff --git a/test/AutoDiff/differentiable_attr_silgen.swift b/test/AutoDiff/differentiable_attr_silgen.swift index a63a8f84ea114..6f3ce873f6073 100644 --- a/test/AutoDiff/differentiable_attr_silgen.swift +++ b/test/AutoDiff/differentiable_attr_silgen.swift @@ -76,43 +76,6 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float) // CHECK-LABEL: sil [ossa] @dhasvjp -//===----------------------------------------------------------------------===// -// Stored property -//===----------------------------------------------------------------------===// - -struct DiffStoredProp { - @differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP) - let storedProp: Float - - @_silgen_name("storedPropJVP") - func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) { - fatalError("unimplemented") - } - - @_silgen_name("storedPropVJP") - func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) { - fatalError("unimplemented") - } -} - -extension DiffStoredProp : VectorNumeric { - static var zero: DiffStoredProp { fatalError("unimplemented") } - static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } - static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } - typealias Scalar = Float - static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp { - fatalError("unimplemented") - } -} - -extension DiffStoredProp : Differentiable { - typealias TangentVector = DiffStoredProp -} - //===----------------------------------------------------------------------===// // Computed property //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index b7f30d522bc86..d00f12e4b348f 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -1,7 +1,10 @@ // RUN: %target-swift-frontend -typecheck -verify %s @differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} -let global: Float = 1 +let globalConst: Float = 1 + +@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}} +var globalVar: Float = 1 func testLocalVariables() { // expected-error @+1 {{'_' has no parameters to differentiate with respect to}} @@ -225,25 +228,18 @@ class Foo { } struct JVPStruct { + @differentiable let p: Float - @differentiable(wrt: (self), jvp: storedPropJVP) - let storedImmutableOk: Float - - // expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(wrt: (self), jvp: storedPropJVP) - let storedImmutableWrongType: Double - - @differentiable(wrt: (self), jvp: storedPropJVP) - var storedMutableOk: Float - - // expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} - @differentiable(wrt: (self), jvp: storedPropJVP) - var storedMutableWrongType: Double + // expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}} + @differentiable(wrt: (self), jvp: funcJVP) + func funcWrongType() -> Double { + fatalError("unimplemented") + } } extension JVPStruct { - func storedPropJVP() -> (Float, (JVPStruct) -> Float) { + func funcJVP() -> (Float, (JVPStruct) -> Float) { fatalError("unimplemented") } } @@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) { struct VJPStruct { let p: Float - @differentiable(vjp: storedPropVJP) - let storedImmutableOk: Float - - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: storedPropVJP) - let storedImmutableWrongType: Double - - @differentiable(vjp: storedPropVJP) - var storedMutableOk: Float - - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} - @differentiable(vjp: storedPropVJP) - var storedMutableWrongType: Double + // expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} + @differentiable(vjp: funcVJP) + func funcWrongType() -> Double { + fatalError("unimplemented") + } } extension VJPStruct { - func storedPropVJP() -> (Float, (Float) -> VJPStruct) { + func funcVJP() -> (Float, (Float) -> VJPStruct) { fatalError("unimplemented") } } diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 6e6b7c525fd0f..2170a666d96d9 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -295,3 +295,23 @@ func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float) func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) } + +// Test usage of `@differentiable` on a stored property +struct PropertyDiff : Differentiable & AdditiveArithmetic { + // expected-error @+1 {{'jvp:' or 'vjp:' cannot be specified for stored properties}} + @differentiable(vjp: vjpPropertyA) + var a: Float = 1 + typealias TangentVector = PropertyDiff + typealias AllDifferentiableVariables = PropertyDiff + func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) { + (.zero, { _ in .zero }) + } +} + +@differentiable +func f(_ x: PropertyDiff) -> Float { + return x.a +} + +let a = gradient(at: PropertyDiff(), in: f) +print(a) diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index a140ba08766f9..2822697e0b8da 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -9,7 +9,7 @@ import StdlibUnittest var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty") struct TangentSpace : VectorNumeric { - let dx, dy: Float + let x, y: Float } extension TangentSpace : Differentiable { @@ -26,18 +26,13 @@ struct Space { } func vjpX() -> (Float, (Float) -> TangentSpace) { - return (x, { v in TangentSpace(dx: v, dy: 0) } ) + return (x, { v in TangentSpace(x: v, y: 0) } ) } private let storedX: Float - /// `y` is a stored property with a custom vjp for its getter. - @differentiable(vjp: vjpY) - let y: Float - - func vjpY() -> (Float, (Float) -> TangentSpace) { - return (y, { v in TangentSpace(dx: 0, dy: v) }) - } + @differentiable + var y: Float init(x: Float, y: Float) { self.storedX = x @@ -48,7 +43,7 @@ struct Space { extension Space : Differentiable { typealias TangentVector = TangentSpace func moved(along: TangentSpace) -> Space { - return Space(x: x + along.dx, y: y + along.dy) + return Space(x: x + along.x, y: y + along.y) } } @@ -56,7 +51,7 @@ E2EDifferentiablePropertyTests.test("computed property") { let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in return 2 * point.x } - let expectedGrad = TangentSpace(dx: 2, dy: 0) + let expectedGrad = TangentSpace(x: 2, y: 0) expectEqual(expectedGrad, actualGrad) } @@ -64,13 +59,13 @@ E2EDifferentiablePropertyTests.test("stored property") { let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in return 3 * point.y } - let expectedGrad = TangentSpace(dx: 0, dy: 3) + let expectedGrad = TangentSpace(x: 0, y: 3) expectEqual(expectedGrad, actualGrad) } struct GenericMemberWrapper : Differentiable { // Stored property. - @differentiable(vjp: vjpX) + @differentiable var x: T func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) { @@ -86,7 +81,6 @@ E2EDifferentiablePropertyTests.test("generic stored property") { expectEqual(expectedGrad, actualGrad) } -@_fieldwiseDifferentiable struct ProductSpaceSelfTangent : VectorNumeric { let x, y: Float } @@ -111,7 +105,6 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable { typealias TangentVector = ProductSpaceOtherTangentTangentSpace } -@_fieldwiseDifferentiable struct ProductSpaceOtherTangent { let x, y: Float } diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index c1bf6e9c2b31d..a1fe13a13050f 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -128,4 +128,16 @@ func TF_508_func(x: TF_508_Struct, y: TF_508_Struct) } let TF_508_bp = pullback(at: TF_508_inst, TF_508_inst, in: TF_508_func) +// TF-523 +struct TF_523_Struct : Differentiable & AdditiveArithmetic { + var a: Float = 1 + typealias TangentVector = TF_523_Struct + typealias AllDifferentiableVariables = TF_523_Struct +} + +@differentiable +func TF_523_f(_ x: TF_523_Struct) -> Float { + return x.a * 2 +} + // TODO: add more tests. diff --git a/test/AutoDiff/method.swift b/test/AutoDiff/method.swift index f3c9e15ad725f..c755df508123a 100644 --- a/test/AutoDiff/method.swift +++ b/test/AutoDiff/method.swift @@ -8,8 +8,15 @@ var MethodTests = TestSuite("Method") // ==== Tests with generated adjoint ==== struct Parameter : Equatable { + private let storedX: Float @differentiable(wrt: (self), jvp: jvpX, vjp: vjpX) - let x: Float + var x: Float { + return storedX + } + + init(x: Float) { + storedX = x + } func vjpX() -> (Float, (Float) -> Parameter) { return (x, { dx in Parameter(x: dx) } ) @@ -155,8 +162,15 @@ struct DiffWrtSelf : Differentiable { } struct CustomParameter : Equatable { + let storedX: Float @differentiable(wrt: (self), vjp: vjpX) - let x: Float + var x: Float { + return storedX + } + + init(x: Float) { + storedX = x + } func vjpX() -> (Float, (Float) -> CustomParameter) { return (x, { dx in CustomParameter(x: dx) }) diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index 5aecc510577bd..fedd5f4397bbf 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -18,23 +18,14 @@ extension DiffReq where TangentVector : AdditiveArithmetic { struct Quadratic : DiffReq, Equatable { typealias TangentVector = Quadratic - @differentiable(wrt: (self), vjp: vjpA) + @differentiable let a: Float - func vjpA() -> (Float, (Float) -> Quadratic) { - return (a, { da in Quadratic(da, 0, 0) } ) - } - @differentiable(wrt: (self), vjp: vjpB) + @differentiable let b: Float - func vjpB() -> (Float, (Float) -> Quadratic) { - return (b, { db in Quadratic(0, db, 0) } ) - } - @differentiable(wrt: (self), vjp: vjpC) + @differentiable let c: Float - func vjpC() -> (Float, (Float) -> Quadratic) { - return (c, { dc in Quadratic(0, 0, dc) } ) - } init(_ a: Float, _ b: Float, _ c: Float) { self.a = a diff --git a/test/AutoDiff/separate_cotangent_type.swift b/test/AutoDiff/separate_cotangent_type.swift index b37e869c82965..8fbe76ecc12fd 100644 --- a/test/AutoDiff/separate_cotangent_type.swift +++ b/test/AutoDiff/separate_cotangent_type.swift @@ -10,7 +10,6 @@ import Glibc var SeparateTangentTypeTests = TestSuite("SeparateTangentType") -@_fieldwiseDifferentiable struct DifferentiableSubset : Differentiable { @differentiable(wrt: self) var w: Float @@ -18,7 +17,6 @@ struct DifferentiableSubset : Differentiable { var b: Float @noDerivative var flag: Bool - @_fieldwiseDifferentiable struct TangentVector : Differentiable, VectorNumeric { typealias TangentVector = DifferentiableSubset.TangentVector var w: Float @@ -41,12 +39,10 @@ SeparateTangentTypeTests.test("Initialization") { expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) } -// FIXME(SR-9602): If `TangentVector` is not marked -// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. -// SeparateTangentTypeTests.test("SomeArithmetics") { -// let x = DifferentiableSubset(w: 0, b: 1, flag: false) -// let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } -// expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) -// } +SeparateTangentTypeTests.test("SomeArithmetics") { + let x = DifferentiableSubset(w: 0, b: 1, flag: false) + let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) +} runAllTests() diff --git a/test/AutoDiff/simple_model.swift b/test/AutoDiff/simple_model.swift index f0548f461841b..623f1263cd22d 100644 --- a/test/AutoDiff/simple_model.swift +++ b/test/AutoDiff/simple_model.swift @@ -6,17 +6,11 @@ import StdlibUnittest var SimpleModelTests = TestSuite("SimpleModel") struct DenseLayer : Equatable { - @differentiable(wrt: self, vjp: vjpW) + @differentiable let w: Float - func vjpW() -> (Float, (Float) -> DenseLayer) { - return (w, { dw in DenseLayer(w: dw, b: 0) } ) - } - @differentiable(wrt: self, vjp: vjpB) + @differentiable let b: Float - func vjpB() -> (Float, (Float) -> DenseLayer) { - return (b, { db in DenseLayer(w: 0, b: db) } ) - } } extension DenseLayer : Differentiable, VectorNumeric { @@ -46,23 +40,14 @@ extension DenseLayer { } struct Model : Equatable { - @differentiable(wrt: self, vjp: vjpL1) + @differentiable let l1: DenseLayer - func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) { - return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } ) - } - @differentiable(wrt: self, vjp: vjpL2) + @differentiable let l2: DenseLayer - func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) { - return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } ) - } - @differentiable(wrt: self, vjp: vjpL3) + @differentiable let l3: DenseLayer - func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) { - return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } ) - } } extension Model : Differentiable, VectorNumeric { diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 9a7073450a254..e52faf389813f 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -11,7 +11,6 @@ protocol Proto : Differentiable { func function3(_ x: Float, _ y: Double) -> Double } -@_fieldwiseDifferentiable struct S : Proto, VectorNumeric { static var zero: S { return S(p: 0) } typealias Scalar = Float @@ -21,11 +20,8 @@ struct S : Proto, VectorNumeric { typealias TangentVector = S - @differentiable(wrt: (self), vjp: vjpP) + @differentiable let p: Float - func vjpP() -> (Float, (Float) -> S) { - return (p, { dp in S(p: dp) }) - } @differentiable(wrt: (x, y)) func function1(_ x: Float, _ y: Double) -> Float {