diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 133efb7408268..75cddf6bdd90f 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3001,18 +3001,18 @@ ERROR(differentiable_attr_protocol_req_assoc_func,none, ERROR(differentiable_attr_stored_property_variable_unsupported,none, "'@differentiable' attribute on stored property cannot specify " "'jvp:' or 'vjp:'", ()) -ERROR(differentiable_attr_class_member_no_dynamic_self,none, - "'@differentiable' attribute cannot be declared on class methods " +ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none, + "'@differentiable' attribute cannot be declared on class members " "returning 'Self'", ()) -// TODO(TF-654): Remove when differentiation supports class initializers. -ERROR(differentiable_attr_class_init_not_yet_supported,none, - "'@differentiable' attribute does not yet support class initializers", - ()) +ERROR(differentiable_attr_nonfinal_class_init_unsupported,none, + "'@differentiable' attribute cannot be declared on 'init' in a non-final " + "class; consider making %0 final", (Type)) ERROR(differentiable_attr_empty_where_clause,none, "empty 'where' clause in '@differentiable' attribute", ()) // SWIFT_ENABLE_TENSORFLOW ERROR(differentiable_attr_nongeneric_trailing_where,none, - "trailing 'where' clause in '@differentiable' attribute of non-generic function %0", (DeclName)) + "trailing 'where' clause in '@differentiable' attribute of non-generic " + "function %0", (DeclName)) ERROR(differentiable_attr_where_clause_for_nongeneric_original,none, "'where' clause is valid only when original function is generic %0", (DeclName)) @@ -3049,6 +3049,12 @@ ERROR(derivative_attr_not_in_same_file_as_original,none, "derivative not in the same file as the original function", ()) ERROR(derivative_attr_original_stored_property_unsupported,none, "cannot register derivative for stored property %0", (DeclNameRef)) +ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none, + "cannot register derivative for class member %0 returning 'Self'", + (DeclNameRef)) +ERROR(derivative_attr_nonfinal_class_init_unsupported,none, + "cannot register derivative for 'init' in a non-final class; consider " + "making %0 final", (Type)) ERROR(derivative_attr_original_already_has_derivative,none, "a derivative already exists for %0", (DeclName)) NOTE(derivative_attr_duplicate_note,none, diff --git a/include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h b/include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h index 074a762b32eef..f3f7df3614915 100644 --- a/include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h +++ b/include/swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h @@ -434,6 +434,16 @@ class PullbackEmitter final : public SILInstructionVisitor { void visitUnconditionalCheckedCastAddrInst( UnconditionalCheckedCastAddrInst *uccai); + /// Handle `unchecked_ref_cast` instruction. + /// Original: y = unchecked_ref_cast x + /// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type) + void visitUncheckedRefCastInst(UncheckedRefCastInst *urci); + + /// Handle `upcast` instruction. + /// Original: y = upcast x + /// Adjoint: adj[x] += adj[y] (assuming x' and y' have the same type) + void visitUpcastInst(UpcastInst *ui); + #define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); #undef NOT_DIFFERENTIABLE diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 455c48ac220ad..4620ebbf3a1b3 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -2705,17 +2705,15 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, loweredInterfaceType); // SWIFT_ENABLE_TENSORFLOW - // In the case of autodiff derivative functions, the above computations - // determine `silFnType` by first computing the derivative function type at - // the AST level and then lowering that. Unfortunately, the actual - // SILFunctionType for the function is determined by first lowering the - // function's AST type, and then computing the derivative function type at the - // SIL level. "Lowering" does not commute with "getting the autodiff - // associated type", so these two computations produce different results. - // Therefore `silFnType` is not the actual type of the function that - // `constant` refers to. + // For derivative functions, the above computations determine `silFnType` + // by first computing the derivative AST function type and then lowering it to + // SIL. Unfortunately, the expected derivative SIL function type is determined + // by first lowering the original function's AST type, and then computing its + // SIL derivative function type. "Lowering" does not commute with "getting the + // derivative type", so these two computations produce different results. + // Therefore, `silFnType` is not the expected SIL derivative function type. // - // We hackily fix this problem by redoing the computation in the right order. + // We fix this problem by performing the computation in the right order. if (auto *autoDiffFuncId = constant.autoDiffDerivativeFunctionIdentifier) { auto origFnConstantInfo = getConstantInfo( TypeExpansionContext::minimal(), constant.asAutoDiffOriginalFunction()); @@ -2725,6 +2723,7 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, loweredIndices, /*resultIndex*/ 0, autoDiffFuncId->getKind(), *this, LookUpConformanceInModule(&M)); } + // SWIFT_ENABLE_TENSORFLOW END LLVM_DEBUG(llvm::dbgs() << "lowering type for constant "; constant.print(llvm::dbgs()); diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 8417067db1c1e..e3ef10c4cce24 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3749,8 +3749,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( auto *thunk = fb.getOrCreateFunction( loc, name, customDerivativeFn->getLinkage(), thunkFnTy, IsBare, IsNotTransparent, customDerivativeFn->isSerialized(), - customDerivativeFn->isDynamicallyReplaceable(), customDerivativeFn->getEntryCount(), - IsThunk, customDerivativeFn->getClassSubclassScope()); + customDerivativeFn->isDynamicallyReplaceable(), + customDerivativeFn->getEntryCount(), IsThunk, + customDerivativeFn->getClassSubclassScope()); thunk->setInlineStrategy(AlwaysInline); if (!thunk->empty()) return thunk; @@ -3762,15 +3763,39 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk( thunkSGF.collectThunkParams(loc, params, &indirectResults); auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn); - auto fnRefType = - fnRef->getType().castTo(); + auto fnRefType = fnRef->getType().castTo(); // Collect thunk arguments, converting ownership. SmallVector arguments; for (auto *indRes : indirectResults) arguments.push_back(indRes); - forwardFunctionArguments(thunkSGF, loc, fnRefType, params, - arguments); + forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments); + + // Special support for thunking class initializer derivatives. + // + // User-defined custom derivatives take a metatype as the last parameter: + // - `$(Param0, Param1, ..., @thick Class.Type) -> (...)` + // But class initializers take an allocated instance as the last parameter: + // - `$(Param0, Param1, ..., @owned Class) -> (...)` + // + // Adjust forwarded arguments: + // - Pop the last `@owned Class` argument. + // - Create a `@thick Class.Type` value and pass it as the last argument. + auto *origAFD = + cast(originalFn->getDeclContext()->getAsDecl()); + if (isa(origAFD) && + SILDeclRef(origAFD, SILDeclRef::Kind::Initializer).mangle() == + originalFn->getName()) { + auto classArgument = arguments.pop_back_val(); + auto *classDecl = classArgument->getType().getClassOrBoundGenericClass(); + assert(classDecl && "Expected last argument to have class type"); + auto classMetatype = MetatypeType::get( + classDecl->getDeclaredInterfaceType(), MetatypeRepresentation::Thick); + auto canClassMetatype = classMetatype->getCanonicalType(); + auto *metatype = thunkSGF.B.createMetatype( + loc, SILType::getPrimitiveObjectType(canClassMetatype)); + arguments.push_back(metatype); + } // Apply function argument. auto apply = thunkSGF.emitApplyWithRethrow( loc, fnRef, /*substFnType*/ fnRef->getType(), diff --git a/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp b/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp index 334b20204d8b4..2a01ce66c1142 100644 --- a/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp +++ b/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp @@ -376,7 +376,8 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB, auto *tanField = cast(tanFieldLookup.front()); // Create a local allocation for the element adjoint buffer. auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType(); - auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType); + auto eltTanSILType = + remapType(SILType::getPrimitiveAddressType(eltTanType)); auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc); builder.emitScopedBorrowOperation( loc, adjClass, [&](SILValue borrowedAdjClass) { @@ -1090,7 +1091,7 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint, auto arrayTanType = cast(arrayAdjoint->getType().getASTType()); auto arrayType = arrayTanType->getParent()->castTo(); auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType(); - auto eltTanSILType = SILType::getPrimitiveAddressType(eltTanType); + auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType)); // Get `function_ref` and generic signature of // `Array.TangentVector.subscript.getter`. auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct(); @@ -1602,12 +1603,11 @@ void PullbackEmitter::visitLoadOperation(SingleValueInstruction *inst) { void PullbackEmitter::visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, SILValue origDest) { auto &adjBuf = getAdjointBuffer(bb, origDest); - auto bufType = remapType(adjBuf->getType()); auto adjVal = builder.emitLoadValueOperation(loc, adjBuf, LoadOwnershipQualifier::Take); recordTemporary(adjVal); addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); - emitZeroIndirect(bufType.getASTType(), adjBuf, loc); + emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); } void PullbackEmitter::visitStoreInst(StoreInst *si) { @@ -1672,6 +1672,26 @@ void PullbackEmitter::visitUnconditionalCheckedCastAddrInst( emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc()); } +void PullbackEmitter::visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { + auto *bb = urci->getParent(); + assert(urci->getOperand()->getType().isObject()); + assert(getRemappedTangentType(urci->getOperand()->getType()) == + getRemappedTangentType(urci->getType()) && + "Operand/result must have the same `TangentVector` type"); + auto adj = getAdjointValue(bb, urci); + addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc()); +} + +void PullbackEmitter::visitUpcastInst(UpcastInst *ui) { + auto *bb = ui->getParent(); + assert(ui->getOperand()->getType().isObject()); + assert(getRemappedTangentType(ui->getOperand()->getType()) == + getRemappedTangentType(ui->getType()) && + "Operand/result must have the same `TangentVector` type"); + auto adj = getAdjointValue(bb, ui); + addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc()); +} + #define NOT_DIFFERENTIABLE(INST, DIAG) \ void PullbackEmitter::visit##INST##Inst(INST##Inst *inst) { \ getContext().emitNondifferentiabilityError(inst, getInvoker(), \ diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index edf19b838eb14..d73f3089c6030 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3969,32 +3969,39 @@ llvm::Expected DifferentiableAttributeTypeCheckRequest::evaluate( return nullptr; } + // Diagnose if original function is an invalid class member. bool isOriginalClassMember = original->getDeclContext() && original->getDeclContext()->getSelfClassDecl(); - - // Diagnose if original function is an invalid class member. if (isOriginalClassMember) { - // Class methods returning dynamic `Self` are not supported. - // (For class methods, dynamic `Self` is supported only as the single - // result - tuple-returning JVPs/VJPs would not type-check.) - if (auto *originalFn = dyn_cast(original)) { - if (originalFn->hasDynamicSelfResult()) { - diags.diagnose(attr->getLocation(), - diag::differentiable_attr_class_member_no_dynamic_self); + auto *classDecl = original->getDeclContext()->getSelfClassDecl(); + assert(classDecl); + // Class members returning dynamic `Self` are not supported. + // Dynamic `Self` is supported only as a single top-level result for class + // members. JVP/VJP functions returning `(Self, ...)` tuples would not + // type-check. + bool diagnoseDynamicSelfResult = original->hasDynamicSelfResult(); + if (diagnoseDynamicSelfResult) { + // Diagnose class initializers in non-final classes. + if (isa(original)) { + if (!classDecl->isFinal()) { + diags.diagnose( + attr->getLocation(), + diag::differentiable_attr_nonfinal_class_init_unsupported, + classDecl->getDeclaredInterfaceType()); + attr->setInvalid(); + return nullptr; + } + } + // Diagnose all other declarations returning dynamic `Self`. + else { + diags.diagnose( + attr->getLocation(), + diag:: + differentiable_attr_class_member_dynamic_self_result_unsupported); attr->setInvalid(); return nullptr; } } - - // TODO(TF-654): Class initializers are not yet supported. - // Extra JVP/VJP type calculation logic is necessary because classes have - // both allocators and initializers. - if (auto *initDecl = dyn_cast(original)) { - diags.diagnose(attr->getLocation(), - diag::differentiable_attr_class_init_not_yet_supported); - attr->setInvalid(); - return nullptr; - } } // Resolve the derivative generic signature. @@ -4284,6 +4291,38 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, return true; } } + // Diagnose if original function is an invalid class member. + bool isOriginalClassMember = + originalAFD->getDeclContext() && + originalAFD->getDeclContext()->getSelfClassDecl(); + if (isOriginalClassMember) { + auto *classDecl = originalAFD->getDeclContext()->getSelfClassDecl(); + assert(classDecl); + // Class members returning dynamic `Self` are not supported. + // Dynamic `Self` is supported only as a single top-level result for class + // members. JVP/VJP functions returning `(Self, ...)` tuples would not + // type-check. + bool diagnoseDynamicSelfResult = originalAFD->hasDynamicSelfResult(); + if (diagnoseDynamicSelfResult) { + // Diagnose class initializers in non-final classes. + if (isa(originalAFD)) { + if (!classDecl->isFinal()) { + diags.diagnose(attr->getLocation(), + diag::derivative_attr_nonfinal_class_init_unsupported, + classDecl->getDeclaredInterfaceType()); + return true; + } + } + // Diagnose all other declarations returning dynamic `Self`. + else { + diags.diagnose( + attr->getLocation(), + diag::derivative_attr_class_member_dynamic_self_result_unsupported, + DeclNameRef(originalAFD->getFullName())); + return true; + } + } + } attr->setOriginalFunction(originalAFD); // Get the resolved differentiability parameter indices. diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index a5f6c05debfe7..bcf9151b58583 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -1080,8 +1080,7 @@ class Super: Differentiable { var base: Float - // NOTE(TF-654): Class initializers are not yet supported. - // expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}} + // expected-error @+1 {{'@differentiable' attribute cannot be declared on 'init' in a non-final class; consider making 'Super' final}} @differentiable init(base: Float) { self.base = base @@ -1124,7 +1123,7 @@ class Super: Differentiable { func instanceMethod(_ x: Float, y: T) -> Float { x } // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}} + // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} @differentiable(vjp: vjpDynamicSelfResult) func dynamicSelfResult() -> Self { self } @@ -1148,6 +1147,18 @@ class Sub: Super { override func testSuperclassDerivatives(_ x: Float) -> Float { x } } +final class FinalClass: Differentiable { + typealias TangentVector = DummyTangentVector + func move(along _: TangentVector) {} + + var base: Float + + @differentiable + init(base: Float) { + self.base = base + } +} + // Test unsupported accessors: `set`, `_read`, `_modify`. struct UnsupportedAccessors: Differentiable { diff --git a/test/AutoDiff/downstream/class_differentiation.swift b/test/AutoDiff/downstream/class_differentiation.swift index 83505d73e9af9..c88cdab1a20e4 100644 --- a/test/AutoDiff/downstream/class_differentiation.swift +++ b/test/AutoDiff/downstream/class_differentiation.swift @@ -9,17 +9,23 @@ import DifferentiationUnittest var ClassTests = TestSuite("ClassDifferentiation") ClassTests.test("TrivialMember") { - class C: Differentiable { + final class C: Differentiable { @differentiable var float: Float @noDerivative final var noDerivative: Float = 1 + @differentiable init(_ float: Float) { self.float = float } + @differentiable + convenience init(convenience x: Float) { + self.init(x) + } + @differentiable func method(_ x: Float) -> Float { x * float @@ -44,6 +50,7 @@ ClassTests.test("TrivialMember") { } // Test class initializer differentiation. expectEqual(10, pullback(at: 3, in: { C($0) })(.init(float: 10))) + expectEqual(10, pullback(at: 3, in: { C(convenience: $0) })(.init(float: 10))) // Test class method differentiation. expectEqual((.init(float: 3), 10), gradient(at: C(10), 3, in: { c, x in c.method(x) })) expectEqual(.init(float: 0), gradient(at: C(10), in: { c in c.testNoDerivative() })) @@ -52,10 +59,11 @@ ClassTests.test("TrivialMember") { } ClassTests.test("NontrivialMember") { - class C: Differentiable { + final class C: Differentiable { @differentiable var float: Tracked + @differentiable init(_ float: Tracked) { self.float = float } @@ -84,14 +92,35 @@ ClassTests.test("NontrivialMember") { gradient(at: C(10), C(20), in: { c1, c2 in C.controlFlow(c1, c2, true) })) } +ClassTests.test("GenericNontrivialMember") { + final class C: Differentiable where T == T.TangentVector { + @differentiable + var x: Tracked + + @differentiable + init(_ x: T) { + self.x = Tracked(x) + } + + @differentiable + convenience init(convenience x: T) { + self.init(x) + } + } + // Test class initializer differentiation. + expectEqual(10, pullback(at: 3, in: { C($0) })(.init(x: 10))) + expectEqual(10, pullback(at: 3, in: { C(convenience: $0) })(.init(x: 10))) +} + // TF-1149: Test class with loadable type but address-only `TangentVector` type. // TODO(TF-1149): Uncomment when supported. /* ClassTests.test("AddressOnlyTangentVector") { - class C: Differentiable { + final class C: Differentiable { @differentiable var stored: T + @differentiable init(_ stored: T) { self.stored = stored } diff --git a/test/AutoDiff/downstream/class_method.swift b/test/AutoDiff/downstream/class_method.swift index f3220d9cfa83c..d9bff362b1079 100644 --- a/test/AutoDiff/downstream/class_method.swift +++ b/test/AutoDiff/downstream/class_method.swift @@ -7,7 +7,7 @@ import DifferentiationUnittest var ClassMethodTests = TestSuite("ClassMethods") ClassMethodTests.test("Final") { - final class Final : Differentiable { + final class Final: Differentiable { func method(_ x: Tracked) -> Tracked { return x * x } @@ -35,14 +35,14 @@ ClassMethodTests.test("Simple") { } } - class SubOverride : Super { + class SubOverride: Super { @differentiable(wrt: x) override func f(_ x: Tracked) -> Tracked { return 3 * x } } - class SubOverrideCustomDerivatives : Super { + class SubOverrideCustomDerivatives: Super { @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) override func f(_ x: Tracked) -> Tracked { return 3 * x @@ -64,19 +64,14 @@ ClassMethodTests.test("Simple") { } ClassMethodTests.test("SimpleWrtSelf") { - class Super : Differentiable { + class Super: Differentiable { var base: Tracked // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial. var _nontrivial: [Tracked] = [] - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit) - required init(base: Tracked) { + init(base: Tracked) { self.base = base } - static func vjpInit(base: Tracked) -> (Super, (TangentVector) -> Tracked) { - return (Super(base: base), { x in x.base }) - } @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf) func f(_ x: Tracked) -> Tracked { @@ -97,14 +92,36 @@ ClassMethodTests.test("SimpleWrtSelf") { } } - class SubOverride : Super { + final class SubOverride: Super { + @differentiable + override init(base: Tracked) { + super.init(base: base) + } + + // Note: `TangentVector` type is unused. + // There is no way to customize `SubOverride: Differentiable` conformance. + // The conformance is always inherited from `Super`. + struct TangentVector: Differentiable & AdditiveArithmetic { + var base: Float + } + @differentiable(wrt: (self, x)) override func f(_ x: Tracked) -> Tracked { return 3 * x } } - class SubOverrideCustomDerivatives : Super { + final class SubOverrideCustomDerivatives: Super { + @differentiable(vjp: vjpInit) + override init(base: Tracked) { + super.init(base: base) + } + static func vjpInit(base: Tracked) -> ( + SubOverrideCustomDerivatives, (Super.TangentVector) -> Tracked + ) { + return (SubOverrideCustomDerivatives(base: base), { x in x.base * 2 }) + } + @differentiable(wrt: (self, x)) @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) override func f(_ x: Tracked) -> Tracked { @@ -118,13 +135,10 @@ ClassMethodTests.test("SimpleWrtSelf") { } } - // TODO(TF-654): Uncomment when differentiation supports class initializers. - /* let v = Super.TangentVector(base: 100, _nontrivial: []) expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) - expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) - */ + expectEqual(200, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) // `valueWithGradient` is not used because nested tuples cannot be compared // with `expectEqual`. @@ -140,7 +154,7 @@ ClassMethodTests.test("SimpleWrtSelf") { } ClassMethodTests.test("Generics") { - class Super where T == T.TangentVector { + class Super where T == T.TangentVector { @differentiable(wrt: x, jvp: jvpf, vjp: vjpf) func f(_ x: Tracked) -> Tracked { return Tracked(2) * x @@ -157,21 +171,21 @@ ClassMethodTests.test("Generics") { } } - class SubOverride : Super where T == T.TangentVector { + class SubOverride: Super where T == T.TangentVector { @differentiable(wrt: x) override func f(_ x: Tracked) -> Tracked { return x } } - class SubSpecializeOverride : Super { + class SubSpecializeOverride: Super { @differentiable(wrt: x) override func f(_ x: Tracked) -> Tracked { return 3 * x } } - class SubOverrideCustomDerivatives : Super + class SubOverrideCustomDerivatives: Super where T == T.TangentVector { @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) override func f(_ x: Tracked) -> Tracked { @@ -189,7 +203,7 @@ ClassMethodTests.test("Generics") { } } - class SubSpecializeOverrideCustomDerivatives : Super { + class SubSpecializeOverrideCustomDerivatives: Super { @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) override func f(_ x: Tracked) -> Tracked { return 3 * x @@ -206,7 +220,7 @@ ClassMethodTests.test("Generics") { } } - func classValueWithGradient( + func classValueWithGradient( _ c: Super ) -> (T, T) where T == T.TangentVector { let (x,y) = valueWithGradient(at: Tracked(1), in: { @@ -221,19 +235,14 @@ ClassMethodTests.test("Generics") { } ClassMethodTests.test("Methods") { - class Super : Differentiable { + class Super: Differentiable { var base: Tracked // Dummy to make `Super.AllDifferentiableVariables` be nontrivial. var _nontrivial: [Tracked] = [] - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit) init(base: Tracked) { self.base = base } - static func vjpInit(base: Tracked) -> (Super, (TangentVector) -> Tracked) { - return (Super(base: base), { x in x.base }) - } @differentiable(vjp: vjpSquared) func squared() -> Tracked { base * base } @@ -246,7 +255,7 @@ ClassMethodTests.test("Methods") { } } - class Sub1 : Super { + class Sub1: Super { @differentiable(vjp: vjpSquared2) override func squared() -> Tracked { base * base } final func vjpSquared2() -> (Tracked, (Tracked) -> TangentVector) { @@ -261,15 +270,8 @@ ClassMethodTests.test("Methods") { return valueWithGradient(at: c) { c in c.squared() } } - // TODO(TF-654, TF-645): Uncomment when differentiation supports class initializers or `ref_element_addr`. - // expectEqual(4, gradient(at: 2) { x in Super(base: x).squared() }) - - // TODO(TF-647): Handle `unchecked_ref_cast` in `Sub1.init` during pullback generation. - // FIXME: `Super.init` VJP type mismatch for empty `Super.AllDifferentiableVariables`: - // SIL verification failed: VJP type does not match expected VJP type - // $@convention(method) (Tracked, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked) - // $@convention(method) (Tracked, @owned Super) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked) - // expectEqual(4, gradient(at: 2) { x in Sub1(base: x).squared() }) + expectEqual(4, gradient(at: 2) { x in Super(base: x).squared() }) + expectEqual(4, gradient(at: 2) { x in Sub1(base: x).squared() }) expectEqual(Super.TangentVector(base: 4, _nontrivial: []), gradient(at: Super(base: 2)) { foo in foo.squared() }) @@ -278,15 +280,11 @@ ClassMethodTests.test("Methods") { } ClassMethodTests.test("Properties") { - class Super : Differentiable { + class Super: Differentiable { + @differentiable var base: Tracked - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit) init(base: Tracked) { self.base = base } - static func vjpInit(base: Tracked) -> (Super, (TangentVector) -> Tracked) { - return (Super(base: base), { x in x.base }) - } @differentiable(vjp: vjpSquared) var squared: Tracked { base * base } @@ -297,22 +295,16 @@ ClassMethodTests.test("Properties") { } } - class Sub1 : Super { - // FIXME(TF-625): Crash due to `Super.AllDifferentiableVariables` abstraction pattern mismatch. - // SIL verification failed: vtable entry for #Super.squared!getter.1.jvp.1.S must be ABI-compatible - // ABI incompatible return values - // @convention(method) (@guaranteed Super) -> (Tracked, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked) - // @convention(method) (@guaranteed Sub1) -> (Tracked, @owned @callee_guaranteed (Super.AllDifferentiableVariables) -> Tracked) - // @differentiable - // override var squared: Tracked { base * base } + class Sub1: Super { + @differentiable + override var squared: Tracked { base * base } } func classValueWithGradient(_ c: Super) -> (Tracked, Super.TangentVector) { return valueWithGradient(at: c) { c in c.squared } } - // TODO(TF-654, TF-645): Uncomment when differentiation supports class initializers or `ref_element_addr`. - // expectEqual(4, gradient(at: 2) { x in Super(base: x).squared }) + expectEqual(4, gradient(at: 2) { x in Super(base: x).squared }) expectEqual(Super.TangentVector(base: 4), gradient(at: Super(base: 2)) { foo in foo.squared }) } diff --git a/test/AutoDiff/downstream/derivative_attr_type_checking.swift b/test/AutoDiff/downstream/derivative_attr_type_checking.swift index 5838e5e8c67a2..79abde03de09d 100644 --- a/test/AutoDiff/downstream/derivative_attr_type_checking.swift +++ b/test/AutoDiff/downstream/derivative_attr_type_checking.swift @@ -448,22 +448,43 @@ func two7(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Flo // Test class methods. -class Super { +class Super: Differentiable { + var float: Float + + init(_ float: Float) { + self.float = float + } + + // expected-error @+1 {{cannot register derivative for 'init' in a non-final class; consider making 'Super' final}} + @derivative(of: init) + static func vjpInit(_ float: Float) -> (value: Super, pullback: (TangentVector) -> Float) { + return (Super(float), { v in v.float }) + } + @differentiable func foo(_ x: Float) -> Float { return x } - @derivative(of: foo) + @derivative(of: foo, wrt: x) func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (foo(x), { v in v }) } } -class Sub : Super { +final class Sub : Super { + override init(_ float: Float) { + self.float = float + } + + @derivative(of: init) + static func vjpSubInit(_ float: Float) -> (value: Sub, pullback: (TangentVector) -> Float) { + return (Sub(float), { v in v.float }) + } + // TODO(TF-649): Enable `@derivative` to override original functions from superclass. // expected-error @+1 {{'foo' is not defined in the current type context}} - @derivative(of: foo) + @derivative(of: foo, wrt: x) override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (foo(x), { v in v }) } diff --git a/test/AutoDiff/downstream/differentiable_attr_type_checking.swift b/test/AutoDiff/downstream/differentiable_attr_type_checking.swift index b338ff13588c8..058876d055df3 100644 --- a/test/AutoDiff/downstream/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/downstream/differentiable_attr_type_checking.swift @@ -1030,8 +1030,7 @@ extension ProtocolRequirementUnsupported { class Super : Differentiable { var base: Float - // NOTE(TF-654): Class initializers are not yet supported. - // expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}} + // expected-error @+1 {{'@differentiable' attribute cannot be declared on 'init' in a non-final class; consider making 'Super' final}} @differentiable init(base: Float) { self.base = base @@ -1074,10 +1073,14 @@ class Super : Differentiable { func instanceMethod(_ x: Float, y: T) -> Float { x } // expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}} - // expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}} + // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} @differentiable(vjp: vjpDynamicSelfResult) func dynamicSelfResult() -> Self { self } + // expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}} + @differentiable + var testDynamicSelfProperty: Self { self } + // TODO(TF-632): Fix "'TangentVector' is not a member type of 'Self'" diagnostic. // The underlying error should appear instead: // "covariant 'Self' can only appear at the top level of method result type". @@ -1098,6 +1101,15 @@ class Sub : Super { override func testSuperclassDerivatives(_ x: Float) -> Float { x } } +final class FinalClass: Differentiable { + var base: Float + + @differentiable + init(base: Float) { + self.base = base + } +} + // Test unsupported accessors: `set`, `_read`, `_modify`. struct UnsupportedAccessors: Differentiable { diff --git a/test/AutoDiff/downstream/differentiation_transform_diagnostics.swift b/test/AutoDiff/downstream/differentiation_transform_diagnostics.swift index 3561184e5153a..5af8acdd40181 100644 --- a/test/AutoDiff/downstream/differentiation_transform_diagnostics.swift +++ b/test/AutoDiff/downstream/differentiation_transform_diagnostics.swift @@ -247,9 +247,6 @@ class MultipleDiffAttrsClass : Differentiable { func f(_ x: Float) -> Float { x } } func testMultipleDiffAttrsClass(_ c: C, _ x: Float) { - // TODO(TF-647): Handle differentiation of `upcast` instruction. - // expected-error @+2 {{function is not differentiable}} - // expected-note @+1 {{expression is not differentiable}} _ = gradient(at: c, x) { c, x in c.f(x) } _ = gradient(at: x) { x in c.f(x) } } diff --git a/test/AutoDiff/downstream/forward_mode_runtime.swift b/test/AutoDiff/downstream/forward_mode_runtime.swift index d346e8e7fdec9..1e94ab898d4a8 100644 --- a/test/AutoDiff/downstream/forward_mode_runtime.swift +++ b/test/AutoDiff/downstream/forward_mode_runtime.swift @@ -745,18 +745,11 @@ ForwardModeTests.test("SimpleWrtSelf") { // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial. var _nontrivial: [Float] = [] - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit) + // FIXME(SR-12175): Fix forward-mode differentiation crash. + // @differentiable required init(base: Float) { self.base = base } - static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) { - return (Super(base: base), { x in x.base }) - } - - static func jvpInit(base: Float) -> (Super, (Float) -> TangentVector) { - return (Super(base: base), { x in TangentVector(base: x, _nontrivial: []) }) - } @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf) func f(_ x: Float) -> Float { @@ -794,7 +787,7 @@ ForwardModeTests.test("SimpleWrtSelf") { } } - // TODO(TF-654): Uncomment when differentiation supports class initializers. + // FIXME(SR-12175): Fix forward-mode differentiation crash. // let v = Super.TangentVector(base: 100, _nontrivial: []) // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) diff --git a/test/AutoDiff/downstream/vtable_sil.swift b/test/AutoDiff/downstream/vtable_sil.swift index 2639e251f3988..d2f46426f1145 100644 --- a/test/AutoDiff/downstream/vtable_sil.swift +++ b/test/AutoDiff/downstream/vtable_sil.swift @@ -10,14 +10,9 @@ class Super : Differentiable { // FIXME(TF-648): Dummy to make `Super.TangentVector` be nontrivial. var _nontrivial: [Float] = [] - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit) init(base: Float) { self.base = base } - static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) { - return (Super(base: base), { x in x.base }) - } @differentiable var property: Float { base } @@ -54,14 +49,8 @@ class Super : Differentiable { } class Sub : Super { - // TODO(TF-654): Remove attribute when differentiation supports class initializers. - // @differentiable(vjp: vjpInit2) override init(base: Float) { super.init(base: base) - self.base = base - } - static func vjpInit2(base: Float) -> (Sub, (TangentVector) -> Float) { - return (Sub(base: base), { x in x.base }) } @differentiable