From 368d8b1463a11ed34f0de202e224e0946031495f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 16:45:49 -0700 Subject: [PATCH 1/3] [AutoDiff] Type-check `@differentiable` attributes during validation. Type-check `@differentiable` attributes during `TypeChecker::validateDecl` for all relevant declaration kinds (initializers, subscripts, variables), not just function declarations. Resolves TF-888. TF-789 tracks proper request-based type-checking for `@differentiable` attribute. --- lib/Sema/TypeCheckDecl.cpp | 17 +++++++++++- .../differentiable_attr_other_module.swift | 26 +++++++++++++++++++ .../differentiable_attr_cross_module.swift | 25 ++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 test/AutoDiff/Inputs/differentiable_attr_other_module.swift create mode 100644 test/AutoDiff/differentiable_attr_cross_module.swift diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index d4bd8736e5070..4bc50ab8a3a58 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -3989,6 +3989,11 @@ void TypeChecker::validateDecl(ValueDecl *D) { assert(VD->hasInterfaceType()); } + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(VD); + // SWIFT_ENABLE_TENSORFLOW END + // We're not really done with processing the signature yet, but // @objc checking requires the declaration to call itself validated // so that it can be considered as a witness. @@ -4118,8 +4123,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { // FIXME: Roll all of this interface type computation into a request. FD->computeType(); - // TODO(TF-789): Figure out the proper way to typecheck these. + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. checkDeclDifferentiableAttributes(FD); + // SWIFT_ENABLE_TENSORFLOW END // Member functions need some special validation logic. if (FD->getDeclContext()->isTypeContext()) { @@ -4164,6 +4171,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { typeCheckParameterList(CD->getParameters(), res, TypeResolverContext::AbstractFunctionDecl); CD->computeType(); + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(CD); + // SWIFT_ENABLE_TENSORFLOW END break; } @@ -4196,6 +4207,10 @@ void TypeChecker::validateDecl(ValueDecl *D) { SF->markDeclWithOpaqueResultTypeAsValidated(SD); } } + // SWIFT_ENABLE_TENSORFLOW + // TODO(TF-789): Find proper way to type-check `@differentiable` attributes. + checkDeclDifferentiableAttributes(SD); + // SWIFT_ENABLE_TENSORFLOW END break; } diff --git a/test/AutoDiff/Inputs/differentiable_attr_other_module.swift b/test/AutoDiff/Inputs/differentiable_attr_other_module.swift new file mode 100644 index 0000000000000..5dd77d828cb70 --- /dev/null +++ b/test/AutoDiff/Inputs/differentiable_attr_other_module.swift @@ -0,0 +1,26 @@ +// Verify that `@differentiable` declarations can be differentiated from other +// modules. + +public struct Foo: Differentiable { + public var x: Float + + @differentiable + public init(_ x: Float) { + self.x = x + } + + @differentiable + public func method() -> Float { + x + } + + @differentiable + public var computedProperty: Float { + x + } + + @differentiable + public subscript() -> Float { + x + } +} diff --git a/test/AutoDiff/differentiable_attr_cross_module.swift b/test/AutoDiff/differentiable_attr_cross_module.swift new file mode 100644 index 0000000000000..7958c580b13fd --- /dev/null +++ b/test/AutoDiff/differentiable_attr_cross_module.swift @@ -0,0 +1,25 @@ +// Verify that `@differentiable` declarations can be differentiated from other +// modules. + +// RUN: %empty-directory(%t) +// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null + +@differentiable(wrt: x) +func testInitializer(_ x: Float) -> Float { + return Foo(x).x +} + +@differentiable(wrt: foo) +func testMethod(_ foo: Foo) -> Float { + return foo.method() +} + +@differentiable(wrt: foo) +func testComputedProperty(_ foo: Foo) -> Float { + return foo.computedProperty +} + +@differentiable(wrt: foo) +func testSubscript(_ foo: Foo) -> Float { + return foo[] +} From ac17a00c4691bbd49bd1bdcf4193c94faac50157 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 10 Oct 2019 17:12:22 -0700 Subject: [PATCH 2/3] Fix test. The primary file must be named "main" to avoid a linkear error: ``` Undefined symbols for architecture x86_64: "_main", referenced from: implicit entry/start for main executable ld: symbol(s) not found for architecture x86_64 ``` --- .../main.swift} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/AutoDiff/{differentiable_attr_cross_module.swift => differentiable_attr_cross_module/main.swift} (100%) diff --git a/test/AutoDiff/differentiable_attr_cross_module.swift b/test/AutoDiff/differentiable_attr_cross_module/main.swift similarity index 100% rename from test/AutoDiff/differentiable_attr_cross_module.swift rename to test/AutoDiff/differentiable_attr_cross_module/main.swift From 9b9541c49738be69a1ff34dc93352c4178de600e Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 11 Oct 2019 01:28:58 +0000 Subject: [PATCH 3/3] Work around `ElementaryFunctions` linker error on Ubuntu. TF-892 tracks fixing the issue. --- test/AutoDiff/differentiable_attr_cross_module/main.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/AutoDiff/differentiable_attr_cross_module/main.swift b/test/AutoDiff/differentiable_attr_cross_module/main.swift index 7958c580b13fd..7ce6da06cca2c 100644 --- a/test/AutoDiff/differentiable_attr_cross_module/main.swift +++ b/test/AutoDiff/differentiable_attr_cross_module/main.swift @@ -2,7 +2,8 @@ // modules. // RUN: %empty-directory(%t) -// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null +// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null -lm +// NOTE(TF-892): `-lm` is necessary to prevent linker errors related to `ElementaryFunctions` on Ubuntu. @differentiable(wrt: x) func testInitializer(_ x: Float) -> Float {