diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index d4bd8736e5070..4bc50ab8a3a58 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -3989,6 +3989,11 @@ void TypeChecker::validateDecl(ValueDecl *D) { assert(VD->hasInterfaceType()); } + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(VD); + // SWIFT_ENABLE_TENSORFLOW END + // We're not really done with processing the signature yet, but // @objc checking requires the declaration to call itself validated // so that it can be considered as a witness. @@ -4118,8 +4123,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { // FIXME: Roll all of this interface type computation into a request. FD->computeType(); - // TODO(TF-789): Figure out the proper way to typecheck these. + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. checkDeclDifferentiableAttributes(FD); + // SWIFT_ENABLE_TENSORFLOW END // Member functions need some special validation logic. if (FD->getDeclContext()->isTypeContext()) { @@ -4164,6 +4171,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { typeCheckParameterList(CD->getParameters(), res, TypeResolverContext::AbstractFunctionDecl); CD->computeType(); + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(CD); + // SWIFT_ENABLE_TENSORFLOW END break; } @@ -4196,6 +4207,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { SF->markDeclWithOpaqueResultTypeAsValidated(SD); } } + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(SD); + // SWIFT_ENABLE_TENSORFLOW END break; } diff --git a/test/AutoDiff/Inputs/differentiable_attr_other_module.swift b/test/AutoDiff/Inputs/differentiable_attr_other_module.swift new file mode 100644 index 0000000000000..5dd77d828cb70 --- /dev/null +++ b/test/AutoDiff/Inputs/differentiable_attr_other_module.swift @@ -0,0 +1,26 @@ +// Verify that `@differentiable` declarations can be differentiated from other +// modules. + +public struct Foo: Differentiable { + public var x: Float + + @differentiable + public init(_ x: Float) { + self.x = x + } + + @differentiable + public func method() -> Float { + x + } + + @differentiable + public var computedProperty: Float { + x + } + + @differentiable + public subscript() -> Float { + x + } +} diff --git a/test/AutoDiff/differentiable_attr_cross_module/main.swift b/test/AutoDiff/differentiable_attr_cross_module/main.swift new file mode 100644 index 0000000000000..7ce6da06cca2c --- /dev/null +++ b/test/AutoDiff/differentiable_attr_cross_module/main.swift @@ -0,0 +1,26 @@ +// Verify that `@differentiable` declarations can be differentiated from other +// modules. + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null -lm +// NOTE(TF-892): `-lm` is necessary to prevent linker errors related to `ElementaryFunctions` on Ubuntu. + +@differentiable(wrt: x) +func testInitializer(_ x: Float) -> Float { + return Foo(x).x +} + +@differentiable(wrt: foo) +func testMethod(_ foo: Foo) -> Float { + return foo.method() +} + +@differentiable(wrt: foo) +func testComputedProperty(_ foo: Foo) -> Float { + return foo.computedProperty +} + +@differentiable(wrt: foo) +func testSubscript(_ foo: Foo) -> Float { + return foo[] +}