diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 6b766aa793324..67ca13e6ed86e 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -401,8 +401,8 @@ DECL_ATTR(differentiable, Differentiable, AllowMultipleAttributes, 83) DECL_ATTR(differentiating, Differentiating, - OnFunc | LongAttribute | AllowMultipleAttributes, - 84) + OnFunc | LongAttribute | AllowMultipleAttributes | + NotSerialized, 84) SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable, OnAccessor | OnFunc | OnConstructor | OnSubscript, /* Not serialized */ 85) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 0b96f0687fd50..143bca4fba5e4 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1497,9 +1497,9 @@ class DifferentiableAttr final void setRequirements(ASTContext &context, ArrayRef requirements); FuncDecl *getJVPFunction() const { return JVPFunction; } - void setJVPFunction(FuncDecl *decl) { JVPFunction = decl; } + void setJVPFunction(FuncDecl *decl); FuncDecl *getVJPFunction() const { return VJPFunction; } - void setVJPFunction(FuncDecl *decl) { VJPFunction = decl; } + void setVJPFunction(FuncDecl *decl); bool parametersMatch(const DifferentiableAttr &other) const { assert(ParameterIndices && other.ParameterIndices); diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index fb6ad365faf1d..e465f4da131e2 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -1279,6 +1279,18 @@ void DifferentiableAttr::setRequirements(ASTContext &context, Requirements = context.AllocateCopy(requirements); } +void DifferentiableAttr::setJVPFunction(FuncDecl *decl) { + JVPFunction = decl; + if (decl && !JVP) + JVP = {decl->getFullName(), DeclNameLoc(decl->getNameLoc())}; +} + +void DifferentiableAttr::setVJPFunction(FuncDecl *decl) { + VJPFunction = decl; + if (decl && !VJP) + VJP = {decl->getFullName(), DeclNameLoc(decl->getNameLoc())}; +} + void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, ModuleDecl *prettyPrintInModule) const { StreamPrinter P(OS); diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 64f1cb4fd89d8..4a667378be8c0 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -2590,31 +2590,6 @@ ModuleFile::getDeclCheckedImpl(DeclID DID) { break; } - // SWIFT_ENABLE_TENSORFLOW - case decls_block::Differentiating_DECL_ATTR: { - bool isImplicit; - uint64_t origNameId; - DeclID origDeclId; - ArrayRef parameters; - - serialization::decls_block::DifferentiatingDeclAttrLayout::readRecord( - scratch, isImplicit, origNameId, origDeclId, parameters); - - DeclNameWithLoc origName = {getIdentifier(origNameId), DeclNameLoc()}; - FuncDecl *origDecl = cast(getDecl(origDeclId)); - - llvm::SmallBitVector parametersBitVector(parameters.size()); - for (unsigned i : indices(parameters)) - parametersBitVector[i] = parameters[i]; - auto *indices = AutoDiffParameterIndices::get(parametersBitVector, ctx); - - auto diffAttr = DifferentiatingAttr::create( - ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices); - diffAttr->setOriginalFunction(origDecl); - Attr = diffAttr; - break; - } - case decls_block::DynamicReplacement_DECL_ATTR: { bool isImplicit; uint64_t numArgs; diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index b667c47c99e62..54b93b4ca7275 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2193,6 +2193,8 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) { case DAK_RestatedObjCConformance: case DAK_ClangImporterSynthesizedType: case DAK_PrivateImport: + // SWIFT_ENABLE_TENSORFLOW + case DAK_Differentiating: llvm_unreachable("cannot serialize attribute"); case DAK_Count: @@ -2397,26 +2399,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) { writeGenericRequirements(attr->getRequirements(), DeclTypeAbbrCodes); return; } - - // SWIFT_ENABLE_TENSORFLOW - case DAK_Differentiating: { - auto abbrCode = DeclTypeAbbrCodes[DifferentiatingDeclAttrLayout::Code]; - auto attr = cast(DA); - IdentifierID origName = - addDeclBaseNameRef(attr->getOriginal().Name.getBaseName()); - DeclID origRef = addDeclRef(attr->getOriginalFunction()); - - auto paramIndices = attr->getParameterIndices(); - assert(paramIndices && "Checked parameter indices must be resolved"); - SmallVector indices; - for (unsigned i : swift::indices(paramIndices->parameters)) - indices.push_back(paramIndices->parameters[i]); - - DifferentiatingDeclAttrLayout::emitRecord( - Out, ScratchRecord, abbrCode, attr->isImplicit(), origName, origRef, - indices); - return; - } } } diff --git a/test/Serialization/differentiating_attr.swift b/test/Serialization/differentiating_attr.swift index 8b51df9cf0b61..e50b2c5470792 100644 --- a/test/Serialization/differentiating_attr.swift +++ b/test/Serialization/differentiating_attr.swift @@ -7,27 +7,24 @@ // BCANALYZER-NOT: UnknownCode +// CHECK: @differentiable(wrt: x, jvp: jvpAddWrtX) +// CHECK-NEXT: @differentiable(vjp: vjpAdd) func add(x: Float, y: Float) -> Float { return x + y } -// CHECK: @differentiating(add, wrt: x) -// CHECK-NEXT: func jvpAddWrtX(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) @differentiating(add, wrt: x) -func jvpAddWrtX(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) { +func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) { return (x + y, { $0 }) } -// CHECK: @differentiating(add) -// CHECK-NEXT: func vjpAddWrtXY(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) @differentiating(add) -func vjpAddWrtXY(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { +func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { return (x + y, { ($0, $0) }) } +// CHECK: @differentiable(vjp: vjpGeneric where T : Differentiable) func generic(x: T) -> T { return x } -// CHECK: @differentiating(generic) -// CHECK-NEXT: func vjpGeneric(x: T) -> (value: T, pullback: (T.CotangentVector) -> T.CotangentVector) @differentiating(generic) func vjpGeneric(x: T) -> (value: T, pullback: (T.CotangentVector) -> T.CotangentVector) where T : Numeric, T : Differentiable @@ -36,21 +33,21 @@ func vjpGeneric(x: T) -> (value: T, pullback: (T.CotangentVector) -> T.Cotang } protocol InstanceMethod : Differentiable { + // CHECK: @differentiable(vjp: vjpFoo) func foo(_ x: Self) -> Self + // CHECK: @differentiable(jvp: jvpBarWrt where T == T.TangentVector) func bar(_ x: T) -> Self } extension InstanceMethod { - // CHECK: @differentiating(foo) - // CHECK-NEXT: func vjpFoo(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) @differentiating(foo) func vjpFoo(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) { return (x, { ($0, $0) }) } - // CHECK: @differentiating(bar) - // CHECK-NEXT: func jvpBarWrt(_ x: T) -> (value: Self, differential: (Self.TangentVector, T.TangentVector) -> Self.TangentVector) where T : Differentiable @differentiating(bar, wrt: (self, x)) - func jvpBarWrt(_ x: T) -> (value: Self, differential: (Self.TangentVector, T.TangentVector) -> Self.TangentVector) { + func jvpBarWrt(_ x: T) -> (value: Self, differential: (Self.TangentVector, T) -> Self.TangentVector) + where T == T.TangentVector + { return (self, { dself, dx in dself }) } }