diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 05ac8cb913d52..77a7e868ef3a9 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -445,6 +445,9 @@ ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable " // Automatic differentiation diagnostics ERROR(autodiff_internal_swift_not_imported,none, "AD internal error: the Swift module is not imported", ()) +ERROR(autodiff_conversion_to_linear_function_not_supported,none, + "conversion to '@differentiable(linear)' function type is not yet " + "supported", ()) ERROR(autodiff_function_not_differentiable_error,none, "function is not differentiable", ()) ERROR(autodiff_expression_not_differentiable_error,none, diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 6cf576f188769..e0560fc63312c 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -1177,9 +1177,14 @@ ERROR(c_function_pointer_from_generic_function,none, "a C function pointer cannot be formed from a reference to a generic " "function", ()) // SWIFT_ENABLE_TENSORFLOW +// TODO(TF-908): Remove this diagnostic once linear-to-differentiable conversion +// is supported. +ERROR(unsupported_linear_to_differentiable_conversion,none, + "conversion from '@differentiable(linear)' to '@differentiable' is not " + "yet supported", ()) ERROR(invalid_differentiable_function_conversion_expr,none, - "a '@differentiable' function can only be formed from a reference to a " - "'func' or a literal closure", ()) + "a '@differentiable%select{|(linear)}0' function can only be formed from " + "a reference to a 'func' or a literal closure", (bool)) NOTE(invalid_differentiable_function_conversion_parameter,none, "did you mean to take a '%0' closure?", (StringRef)) ERROR(invalid_autoclosure_forwarding,none, diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index b75f6dedb2d73..bb8aaff53ad3e 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -3034,6 +3034,16 @@ class DifferentiableFunctionExpr : public ImplicitConversionExpr { } }; +class LinearFunctionExpr : public ImplicitConversionExpr { +public: + LinearFunctionExpr(Expr *subExpr, Type ty) + : ImplicitConversionExpr(ExprKind::LinearFunction, subExpr, ty) {} + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::LinearFunction; + } +}; + class DifferentiableFunctionExtractOriginalExpr : public ImplicitConversionExpr { public: @@ -3045,6 +3055,28 @@ class DifferentiableFunctionExtractOriginalExpr return E->getKind() == ExprKind::DifferentiableFunctionExtractOriginal; } }; + +class LinearFunctionExtractOriginalExpr : public ImplicitConversionExpr { +public: + LinearFunctionExtractOriginalExpr(Expr *subExpr, Type ty) + : ImplicitConversionExpr(ExprKind::LinearFunctionExtractOriginal, + subExpr, ty) {} + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::LinearFunctionExtractOriginal; + } +}; + +class LinearToDifferentiableFunctionExpr : public ImplicitConversionExpr { +public: + LinearToDifferentiableFunctionExpr(Expr *subExpr, Type ty) + : ImplicitConversionExpr( + ExprKind::LinearToDifferentiableFunction, subExpr, ty) {} + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::LinearToDifferentiableFunction; + } +}; // SWIFT_ENABLE_TENSORFLOW END /// Use an opaque type to abstract a value of the underlying concrete type. diff --git a/include/swift/AST/ExprNodes.def b/include/swift/AST/ExprNodes.def index 7aa566f0962c1..5492c78a82504 100644 --- a/include/swift/AST/ExprNodes.def +++ b/include/swift/AST/ExprNodes.def @@ -174,8 +174,11 @@ ABSTRACT_EXPR(ImplicitConversion, Expr) EXPR(UnderlyingToOpaque, ImplicitConversionExpr) // SWIFT_ENABLE_TENSORFLOW EXPR(DifferentiableFunction, ImplicitConversionExpr) + EXPR(LinearFunction, ImplicitConversionExpr) EXPR(DifferentiableFunctionExtractOriginal, ImplicitConversionExpr) - EXPR_RANGE(ImplicitConversion, Load, DifferentiableFunctionExtractOriginal) + EXPR(LinearFunctionExtractOriginal, ImplicitConversionExpr) + EXPR(LinearToDifferentiableFunction, ImplicitConversionExpr) + EXPR_RANGE(ImplicitConversion, Load, LinearToDifferentiableFunction) // SWIFT_ENABLE_TENSORFLOW END ABSTRACT_EXPR(ExplicitCast, Expr) ABSTRACT_EXPR(CheckedCast, ExplicitCastExpr) diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 225f0803d5fc8..9bcee7cea58a1 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -521,7 +521,7 @@ class SILBuilder { LinearFunctionInst *createLinearFunction( SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction, - Optional TransposeFunction) { + Optional TransposeFunction = None) { return insert(LinearFunctionInst::create( getModule(), getSILDebugLocation(Loc), ParameterIndices, OriginalFunction, TransposeFunction, hasOwnership())); diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index b710e6585e06a..192f41da8150f 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -2401,12 +2401,29 @@ class PrintExpr : public ExprVisitor { printRec(E->getSubExpr()); PrintWithColorRAII(OS, ParenthesisColor) << ')'; } + void visitLinearFunctionExpr(LinearFunctionExpr *E) { + printCommon(E, "linear_function") << '\n'; + printRec(E->getSubExpr()); + PrintWithColorRAII(OS, ParenthesisColor) << ')'; + } void visitDifferentiableFunctionExtractOriginalExpr( DifferentiableFunctionExtractOriginalExpr *E) { printCommon(E, "differentiable_function_extract_original") << '\n'; printRec(E->getSubExpr()); PrintWithColorRAII(OS, ParenthesisColor) << ')'; } + void visitLinearFunctionExtractOriginalExpr( + LinearFunctionExtractOriginalExpr *E) { + printCommon(E, "linear_function_extract_original") << '\n'; + printRec(E->getSubExpr()); + PrintWithColorRAII(OS, ParenthesisColor) << ')'; + } + void visitLinearToDifferentiableFunctionExpr( + LinearToDifferentiableFunctionExpr *E) { + printCommon(E, "linear_to_differentiable_function") << '\n'; + printRec(E->getSubExpr()); + PrintWithColorRAII(OS, ParenthesisColor) << ')'; + } // SWIFT_ENABLE_TENSORFLOW END void visitInOutExpr(InOutExpr *E) { diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp index 44267d27ec8ff..aadf403e120d7 100644 --- a/lib/AST/Expr.cpp +++ b/lib/AST/Expr.cpp @@ -352,7 +352,10 @@ ConcreteDeclRef Expr::getReferencedDecl() const { PASS_THROUGH_REFERENCE(UnevaluatedInstance, getSubExpr); // SWIFT_ENABLE_TENSORFLOW PASS_THROUGH_REFERENCE(DifferentiableFunction, getSubExpr); + PASS_THROUGH_REFERENCE(LinearFunction, getSubExpr); PASS_THROUGH_REFERENCE(DifferentiableFunctionExtractOriginal, getSubExpr); + PASS_THROUGH_REFERENCE(LinearFunctionExtractOriginal, getSubExpr); + PASS_THROUGH_REFERENCE(LinearToDifferentiableFunction, getSubExpr); // SWIFT_ENABLE_TENSORFLOW END PASS_THROUGH_REFERENCE(BridgeToObjC, getSubExpr); PASS_THROUGH_REFERENCE(BridgeFromObjC, getSubExpr); @@ -678,7 +681,10 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const { case ExprKind::UnevaluatedInstance: // SWIFT_ENABLE_TENSORFLOW case ExprKind::DifferentiableFunction: + case ExprKind::LinearFunction: case ExprKind::DifferentiableFunctionExtractOriginal: + case ExprKind::LinearFunctionExtractOriginal: + case ExprKind::LinearToDifferentiableFunction: // SWIFT_ENABLE_TENSORFLOW END case ExprKind::EnumIsCase: case ExprKind::ConditionalBridgeFromObjC: diff --git a/lib/SIL/OwnershipUtils.cpp b/lib/SIL/OwnershipUtils.cpp index b1e23a36b3f90..373e9342a4057 100644 --- a/lib/SIL/OwnershipUtils.cpp +++ b/lib/SIL/OwnershipUtils.cpp @@ -46,8 +46,8 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) { case SILNodeKind::DestructureTupleInst: // SWIFT_ENABLE_TENSORFLOW case SILNodeKind::DifferentiableFunctionInst: - case SILNodeKind::DifferentiableFunctionExtractInst: - // SWIFT_ENABLE_TENSORFLOW + case SILNodeKind::LinearFunctionInst: + // SWIFT_ENABLE_TENSORFLOW END return true; default: return false; @@ -62,6 +62,10 @@ bool swift::isGuaranteedForwardingValueKind(SILNodeKind kind) { case SILNodeKind::StructExtractInst: case SILNodeKind::OpenExistentialValueInst: case SILNodeKind::OpenExistentialBoxValueInst: + // SWIFT_ENABLE_TENSORFLOW + case SILNodeKind::DifferentiableFunctionExtractInst: + case SILNodeKind::LinearFunctionExtractInst: + // SWIFT_ENABLE_TENSORFLOW END return true; default: return isOwnershipForwardingValueKind(kind); diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index 5fe0b23be7029..daf33f57f9e87 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -506,8 +506,14 @@ namespace { // SWIFT_ENABLE_TENSORFLOW RValue visitDifferentiableFunctionExpr(DifferentiableFunctionExpr *E, SGFContext C); + RValue visitLinearFunctionExpr(LinearFunctionExpr *E, SGFContext C); RValue visitDifferentiableFunctionExtractOriginalExpr( DifferentiableFunctionExtractOriginalExpr *E, SGFContext C); + RValue visitLinearFunctionExtractOriginalExpr( + LinearFunctionExtractOriginalExpr *E, SGFContext C); + RValue visitLinearToDifferentiableFunctionExpr( + LinearToDifferentiableFunctionExpr *E, SGFContext C); + // SWIFT_ENABLE_TENSORFLOW END }; } // end anonymous namespace @@ -5436,6 +5442,15 @@ RValue RValueEmitter::visitDifferentiableFunctionExpr( return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc)); } +RValue RValueEmitter::visitLinearFunctionExpr( + LinearFunctionExpr *E, SGFContext C) { + auto origFunc = SGF.emitRValueAsSingleValue(E->getSubExpr()); + auto destTy = SGF.getLoweredType(E->getType()).castTo(); + auto *diffFunc = SGF.B.createLinearFunction( + E, destTy->getDifferentiationParameterIndices(), origFunc.forward(SGF)); + return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc)); +} + RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr( DifferentiableFunctionExtractOriginalExpr *E, SGFContext C) { auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr()); @@ -5445,6 +5460,22 @@ RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr( auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc); return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc)); } + +RValue RValueEmitter::visitLinearFunctionExtractOriginalExpr( + LinearFunctionExtractOriginalExpr *E, SGFContext C) { + auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr()); + auto borrowedDiffFunc = diffFunc.borrow(SGF, E); + auto *borrowedOrigFunc = SGF.B.createLinearFunctionExtract( + E, LinearFunctionExtractee::Original, borrowedDiffFunc.getValue()); + auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc); + return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc)); +} + +RValue RValueEmitter::visitLinearToDifferentiableFunctionExpr( + LinearToDifferentiableFunctionExpr *E, SGFContext C) { + // TODO: Implement this. + llvm_unreachable("Unsupported!"); +} // SWIFT_ENABLE_TENSORFLOW END RValue RValueEmitter::visitTapExpr(TapExpr *E, SGFContext C) { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 7c5646392535c..54e10068964c9 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -8773,6 +8773,8 @@ void Differentiation::run() { // A global differentiation context. ADContext context(*this); + bool errorOccurred = false; + // Register all `@differentiable` attributes and `differentiable_function` // instructions in the module that trigger differentiation. for (SILFunction &f : module) { @@ -8783,10 +8785,18 @@ void Differentiation::run() { context.getInvokers().insert({diffAttr, invoker}); continue; } - for (SILBasicBlock &bb : f) - for (SILInstruction &i : bb) + for (SILBasicBlock &bb : f) { + for (SILInstruction &i : bb) { if (auto *dfi = dyn_cast(&i)) context.getDifferentiableFunctionInsts().push_back(dfi); + else if (auto *lfi = dyn_cast(&i)) { + astCtx.Diags.diagnose( + lfi->getLoc().getSourceLoc(), + diag::autodiff_conversion_to_linear_function_not_supported); + errorOccurred = true; + } + } + } } // If nothing has triggered differentiation, there's nothing to do. @@ -8802,8 +8812,6 @@ void Differentiation::run() { return; } - bool errorOccurred = false; - // Process all `[differentiable]` attributes. for (auto invokerPair : context.getInvokers()) { auto *attr = invokerPair.first; diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 6e98c9ca25143..fddaa57523a8e 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -5883,9 +5883,21 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs, auto &tc = cs.getTypeChecker(); Type fromType = cs.getType(expr); auto fromFnType = fromType->getAs(); + auto isToTypeLinear = + toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear; + // Conversion from a `@differentiable` function to a `@differentiable(linear)` + // function is not allowed, because the from-expression will never be a + // closure expression or a declaration/member reference. + if (fromFnType->getDifferentiabilityKind() == DifferentiabilityKind::Normal && + toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear) { + tc.diagnose(expr->getLoc(), + diag::invalid_differentiable_function_conversion_expr, + isToTypeLinear); + return; + } // Conversion from a non-`@differentiable` function to a `@differentiable` is // only allowed from a closure expression or a declaration/member reference. - if (toType->isDifferentiable() && !fromFnType->isDifferentiable()) { + if (!fromFnType->isDifferentiable() && toType->isDifferentiable()) { auto maybeDiagnoseFunctionRef = [&](Expr *semanticExpr) { if (auto *capture = dyn_cast(semanticExpr)) semanticExpr = capture->getClosureBody(); @@ -5897,20 +5909,15 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs, // note with a fix-it. if (auto *paramDecl = dyn_cast(declRef->getDecl())) { tc.diagnose(expr->getLoc(), - diag::invalid_differentiable_function_conversion_expr); + diag::invalid_differentiable_function_conversion_expr, + isToTypeLinear); if (paramDecl->getType()->is()) { auto *typeRepr = paramDecl->getTypeLoc().getTypeRepr(); while (auto *attributed = dyn_cast(typeRepr)) typeRepr = attributed->getTypeRepr(); std::string attributeString = "@differentiable"; - switch (toType->getDifferentiabilityKind()) { - case DifferentiabilityKind::Linear: + if (isToTypeLinear) attributeString += "(linear)"; - break; - case DifferentiabilityKind::Normal: - case DifferentiabilityKind::NonDifferentiable: - break; - } auto *funcTypeRepr = cast(typeRepr); auto paramListLoc = funcTypeRepr->getArgsTypeRepr()->getStartLoc(); tc.diagnose(paramDecl->getLoc(), @@ -5930,7 +5937,8 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs, return; } tc.diagnose(expr->getLoc(), - diag::invalid_differentiable_function_conversion_expr); + diag::invalid_differentiable_function_conversion_expr, + isToTypeLinear); }; maybeDiagnoseFunctionRef(getSemanticExprForDeclOrMemberRef(expr)); } @@ -6583,23 +6591,55 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, // SWIFT_ENABLE_TENSORFLOW auto fromEI = fromFunc->getExtInfo(); + auto isFromDifferentiable = fromEI.isDifferentiable(); + auto isToDifferentiable = toEI.isDifferentiable(); // Handle implicit conversion from @differentiable. - if (fromEI.isDifferentiable() && !toEI.isDifferentiable()) { + if (isFromDifferentiable && !isToDifferentiable) { fromFunc = fromFunc->getWithoutDifferentiability() ->castTo(); - expr = cs.cacheType(new (tc.Context) - DifferentiableFunctionExtractOriginalExpr(expr, fromFunc)); + switch (fromEI.getDifferentiabilityKind()) { + case DifferentiabilityKind::Normal: + expr = cs.cacheType(new (tc.Context) + DifferentiableFunctionExtractOriginalExpr(expr, fromFunc)); + break; + case DifferentiabilityKind::Linear: + expr = cs.cacheType(new (tc.Context) + LinearFunctionExtractOriginalExpr(expr, fromFunc)); + break; + case DifferentiabilityKind::NonDifferentiable: + llvm_unreachable("Cannot be NonDifferentiable"); + } } - // Handle implicit conversion to @differentiable. + // Handle implicit conversion from @differentiable(linear) to + // @differentiable. + else if (fromEI.getDifferentiabilityKind() == + DifferentiabilityKind::Linear && + toEI.getDifferentiabilityKind() == DifferentiabilityKind::Normal) { + // TODO(TF-908): Create a `LinearToDifferentiableFunctionExpr` and SILGen + // it as thunk application. Remove the diagnostic. + tc.diagnose(expr->getLoc(), + diag::unsupported_linear_to_differentiable_conversion); + } + // Handle implicit conversion from non-@differentiable to @differentiable. maybeDiagnoseUnsupportedDifferentiableConversion(cs, expr, toFunc); - if (!fromEI.isDifferentiable() && toEI.isDifferentiable()) { + if (!isFromDifferentiable && isToDifferentiable) { auto newEI = fromEI.withDifferentiabilityKind(toEI.getDifferentiabilityKind()); fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult()) ->withExtInfo(newEI) ->castTo(); - expr = cs.cacheType(new (tc.Context) - DifferentiableFunctionExpr(expr, fromFunc)); + switch (toEI.getDifferentiabilityKind()) { + case DifferentiabilityKind::Normal: + expr = cs.cacheType(new (tc.Context) + DifferentiableFunctionExpr(expr, fromFunc)); + break; + case DifferentiabilityKind::Linear: + expr = cs.cacheType(new (tc.Context) + LinearFunctionExpr(expr, fromFunc)); + break; + case DifferentiabilityKind::NonDifferentiable: + llvm_unreachable("Cannot be NonDifferentiable"); + } } // If we have a ClosureExpr, then we can safely propagate the 'no escape' diff --git a/test/AutoDiff/differentiable_func_type_type_checking.swift b/test/AutoDiff/differentiable_func_type_type_checking.swift index 6f129e4b8003a..0177876a972a0 100644 --- a/test/AutoDiff/differentiable_func_type_type_checking.swift +++ b/test/AutoDiff/differentiable_func_type_type_checking.swift @@ -50,6 +50,17 @@ extension Float { _ = gradient(of: Float.addOne) // okay _ = gradient(of: Float(1.0).addOne) // okay +// TODO(TF-908): Remove this test once linear-to-differentiable conversion is supported. +func linearToDifferentiable(_ f: @escaping @differentiable(linear) (Float) -> Float) { + // expected-error @+1 {{conversion from '@differentiable(linear)' to '@differentiable' is not yet supported}} + _ = f as @differentiable (Float) -> Float +} + +func differentiableToLinear(_ f: @escaping @differentiable (Float) -> Float) { + // expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or a literal closure}} + _ = f as @differentiable(linear) (Float) -> Float +} + //===----------------------------------------------------------------------===// // Parameter selection (@nondiff) //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/differentiable_function_silgen.swift b/test/AutoDiff/differentiable_function_silgen.swift index 79685da1d9718..8728c7e85760b 100644 --- a/test/AutoDiff/differentiable_function_silgen.swift +++ b/test/AutoDiff/differentiable_function_silgen.swift @@ -1,6 +1,5 @@ // RUN: %target-swift-frontend -dump-ast %s | %FileCheck %s -check-prefix=CHECK-AST // RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s -check-prefix=CHECK-SILGEN -// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL //===----------------------------------------------------------------------===// // Closure conversion @@ -14,16 +13,14 @@ func myfunction(_ f: @escaping @differentiable (Float) -> (Float)) -> (Float) -> return f } -func myfunction2(_ f: @escaping @differentiable(linear) (Float) -> (Float)) /*-> (Float) -> Float*/ { +func myfunction2(_ f: @escaping @differentiable(linear) (Float) -> (Float)) -> (Float) -> Float { // @differentiable(linear) functions should be callable. _ = f(.zero) - // TODO(TF-900): Uncomment the following line to test conversion to non-differentiable function type. - // return f + return f } var global_f: @differentiable (Float) -> Float = {$0} -// TODO(TF-902): Uncomment the following line to test linear function storage. -// var global_f_linear: @differentiable(linear) (Float) -> Float = {$0} +var global_f_linear: @differentiable(linear) (Float) -> Float = {$0} func calls_global_f() { _ = global_f(10) @@ -33,8 +30,7 @@ func calls_global_f() { func apply() { _ = myfunction(thin) - // TODO(TF-900): Uncomment the following line to test direct calls to a linear function. - // _ = myfunction2(thin) + _ = myfunction2(thin) } // CHECK-AST-LABEL: (func_decl {{.*}} "myfunction(_:)" @@ -62,24 +58,27 @@ func apply() { // CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float // CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float -// CHEK-SILGEN-LABEL: @{{.*}}myfunction2{{.*}} : $@convention(thin) (@guaranteed @differentiable(linear) @callee_guaranteed (Float) -> Float) -> () { -// CHEK-SILGEN: bb0([[LIN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float): -// CHEK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float -// CHEK-SILGEN: apply [[COPIED_LIN]]({{%.*}}, {{%.*}}) : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0 -// CHEK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float -// CHEK-SILGEN: apply [[BORROWED_LIN]]({{%.*}}) : $@differentiable(linear) @callee_guaranteed (Float) -> Float -// CHEK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float -// CHEK-SILGEN: destroy_value [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float // id: %13 -// TODO(TF-900): Change this to returning the extracted original function. -// CHEK-SILGEN: %14 = tuple () -// CHEK-SILGEN: return %14 : $() +// CHECK-SILGEN-LABEL: @{{.*}}myfunction2{{.*}} +// CHECK-SILGEN: bb0([[LIN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float): +// CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: apply [[BORROWED_LIN]]({{%.*}}) : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = linear_function_extract [original] [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: destroy_value [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float +// CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float // CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}} // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float // CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [wrt 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float - -// CHECK-SIL: [[DIFFED:%.*]] = differentiable_function [wrt 0] {{%.*}} : $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float +// CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN-NEXT: [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float //===----------------------------------------------------------------------===// // Reabstraction @@ -113,4 +112,4 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) { // CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} // CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector -// HECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector +// CHECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/differentiation_transform_diagnostics.swift similarity index 96% rename from test/AutoDiff/autodiff_diagnostics.swift rename to test/AutoDiff/differentiation_transform_diagnostics.swift index 8a36f7fb68be4..976f3a69ec54c 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/differentiation_transform_diagnostics.swift @@ -1,5 +1,7 @@ // RUN: %target-swift-frontend -emit-sil -verify %s +// This file tests SIL diagnostics during the differentiation transform. + //===----------------------------------------------------------------------===// // Basic function //===----------------------------------------------------------------------===// @@ -329,3 +331,10 @@ struct TF_675 : Differentiable { } // expected-error @+1 {{function is not differentiable}} let _: @differentiable (Float) -> Float = TF_675().method + +//===----------------------------------------------------------------------===// +// Conversion to `@differentiable(linear)` (not yet supported) +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{conversion to '@differentiable(linear)' function type is not yet supported}} +let _: @differentiable(linear) (Float) -> Float = { x in x } diff --git a/test/AutoDiff/differentiable_func_sil_diagnostics.swift b/test/AutoDiff/sil_diagnostics_after_differentiation.swift similarity index 97% rename from test/AutoDiff/differentiable_func_sil_diagnostics.swift rename to test/AutoDiff/sil_diagnostics_after_differentiation.swift index ed91b27c736fc..c0f07a5c3e898 100644 --- a/test/AutoDiff/differentiable_func_sil_diagnostics.swift +++ b/test/AutoDiff/sil_diagnostics_after_differentiation.swift @@ -2,7 +2,7 @@ // This test file contains SIL diagnostics tests for differentiable functions // such as escaping capture errors. -// NOTE: Only add tests for errors that would occur after the differentiation +// NOTE: Only add tests for errors that would occur after the differentiation // transform. func nonescapingArgument(f: @differentiable (Float, Float) -> Float) -> Float {