diff --git a/docs/SIL.rst b/docs/SIL.rst index 224aca31b9752..47e56f0fbf6f4 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5611,28 +5611,29 @@ differentiable_function sil-differentiable-function-derivative-functions-clause? sil-differentiable-function-parameter-indices ::= - '[' 'wrt' [0-9]+ (' ' [0-9]+)* ']' + '[' 'parameters' [0-9]+ (' ' [0-9]+)* ']' sil-differentiable-derivative-functions-clause ::= - 'with' '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}' + 'with_derivative' + '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}' - differentiable_function [wrt 0] %0 : $(T) -> T \ - with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)} + differentiable_function [parameters 0] %0 : $(T) -> T \ + with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)} Bundles a function with its derivative functions into a ``@differentiable`` function. There are two derivative functions: a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP) function. -``[wrt ...]`` specifies parameter indices that the original function is +``[parameters ...]`` specifies parameter indices that the original function is differentiable with respect to. When not specified, it defaults to all parameters. -A ``with`` clause specifies the differentiation functions associated -with the original function. When a ``with`` clause is not specified, the first -operand will be differentiated to produce derivative functions, and a ``with`` -clause will be added to the instruction. +A ``with_derivative`` clause specifies the differentiation functions associated +with the original function. When a ``with_derivative`` clause is not specified, +the first operand will be differentiated to produce derivative functions, and a +``with_derivative`` clause will be added to the instruction. -In raw SIL, it is optional to provide a derivative function ``with`` clause. -In canonical SIL, a ``with`` clause is mandatory. +In raw SIL, it is optional to provide a derivative function ``with_derivative`` +clause. In canonical SIL, a ``with_derivative`` clause is mandatory. linear_function diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 1caa6ed4eb9ff..fc034f853122c 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -2922,17 +2922,17 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { // SWIFT_ENABLE_TENSORFLOW case SILInstructionKind::DifferentiableFunctionInst: { - // e.g. differentiable_function [wrt 0 1 2] %0 : $T + // e.g. differentiable_function [parameters 0 1 2] %0 : $T // - // e.g. differentiable_function [wrt 0 1 2] %0 : $T with + // e.g. differentiable_function [parameters 0 1 2] %0 : $T with_derivative // {%1 : $T, %2 : $T} // ^ jvp ^ vjp SourceLoc lastLoc; SmallVector parameterIndices; - // Parse optional `[wrt ...]` + // Parse optional `[parameters ...]` if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::identifier) && - P.peekToken().getText() == "wrt") { + P.peekToken().getText() == "parameters") { P.consumeToken(tok::l_square); P.consumeToken(tok::identifier); // Parse indices. @@ -2960,8 +2960,9 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { return true; } Optional> derivativeFunctions = None; - // Parse an optional operand list `with { , }`. - if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with") { + // Parse an optional operand list + // `with_derivative { , }`. + if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with_derivative") { P.consumeToken(tok::identifier); // Parse derivative function values as an operand list. // FIXME(rxwei): Change this to *not* require a type signature once diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 4452543e3dc07..f854bae04bd9d 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1163,14 +1163,14 @@ class SILPrinter : public SILInstructionVisitor { // SWIFT_ENABLE_TENSORFLOW void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { if (!dfi->getParameterIndices()->isEmpty()) { - *this << "[wrt"; + *this << "[parameters"; for (auto i : dfi->getParameterIndices()->getIndices()) *this << ' ' << i; *this << "] "; } *this << getIDAndType(dfi->getOriginalFunction()); if (dfi->hasDerivativeFunctions()) { - *this << " with "; + *this << " with_derivative "; *this << '{' << getIDAndType(dfi->getJVPFunction()) << ", " << getIDAndType(dfi->getVJPFunction()) << '}'; } diff --git a/test/AutoDiff/differentiable_function_inst.sil b/test/AutoDiff/differentiable_function_inst.sil index 8e9df1c6adc74..879ad6317e99d 100644 --- a/test/AutoDiff/differentiable_function_inst.sil +++ b/test/AutoDiff/differentiable_function_inst.sil @@ -17,20 +17,20 @@ sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float sil @test : $@convention(thin) () -> () { bb0: %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float - %1 = differentiable_function [wrt 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float + %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float - %3 = differentiable_function [wrt 0] %0 : $@convention(thin) (Float, Float, Float) -> Float + %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float - %6 = differentiable_function [wrt 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float + %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float - %8 = differentiable_function [wrt 0] %5 : $@convention(method) (Float, Float, Float) -> Float + %8 = differentiable_function [parameters 0] %5 : $@convention(method) (Float, Float, Float) -> Float // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float @@ -68,9 +68,9 @@ bb0(%0 : $Float): sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float { bb0: %orig = function_ref @foo : $@convention(thin) (Float) -> Float - %undiffedFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float + %undiffedFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float %vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - %diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + %diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} %extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float %extractedOriginal = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float return %undiffedFunc : $@differentiable @convention(thin) (Float) -> Float @@ -78,9 +78,9 @@ bb0: // CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float // CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float -// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float +// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float // CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [wrt 0] [[FOO]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} +// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} // CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: [[EXTRACTED_ORIG:%.*]] = differentiable_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float // CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float diff --git a/test/AutoDiff/differentiable_function_inst_irgen.sil b/test/AutoDiff/differentiable_function_inst_irgen.sil index 79cd8a6a9b702..601f1907ed265 100644 --- a/test/AutoDiff/differentiable_function_inst_irgen.sil +++ b/test/AutoDiff/differentiable_function_inst_irgen.sil @@ -36,7 +36,7 @@ sil @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Flo bb0: %orig = function_ref @foo : $@convention(thin) (Float) -> Float %vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - %diffFunc = differentiable_function [wrt 0] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} + %diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} %extractedOrig = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float %extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float %tuple = tuple (%extractedOrig : $@convention(thin) (Float) -> Float, %extractedVJP : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) diff --git a/test/AutoDiff/differentiable_function_silgen.swift b/test/AutoDiff/differentiable_function_silgen.swift index 8728c7e85760b..03bc23ec7a7e8 100644 --- a/test/AutoDiff/differentiable_function_silgen.swift +++ b/test/AutoDiff/differentiable_function_silgen.swift @@ -75,7 +75,7 @@ func apply() { // CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}} // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float -// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [wrt 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float +// CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float // CHECK-SILGEN-NEXT: [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float @@ -110,6 +110,6 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) { // CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} +// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [parameters 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with_derivative {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} // CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector // CHECK-SILGEN: apply [[DIFF_API]]({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift index 81dace664f739..ec40f38b5877d 100644 --- a/test/AutoDiff/forward_mode_sil.swift +++ b/test/AutoDiff/forward_mode_sil.swift @@ -23,14 +23,14 @@ func unary(_ x: Float) -> Float { // CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT_2:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float) @@ -68,7 +68,7 @@ func binary(x: Float, y: Float) -> Float { // CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float // CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -// CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [wrt 0 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_INST:%.*]] = differentiable_function [parameters 0 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} // CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = differentiable_function_extract [jvp] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float // CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) // CHECK-SIL: ([[ORIG_RESULT:%.*]], [[MULT_DIFF:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float) diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift index 3927212225f47..377efec9f0459 100644 --- a/test/AutoDiff/refcounting.swift +++ b/test/AutoDiff/refcounting.swift @@ -86,7 +86,7 @@ _ = pullback(at: Vector.zero, in: testOwnedVector) // CHECK: [[ADD:%.*]] = function_ref @Vector_plus // CHECK: [[ADD_JVP:%.*]] = function_ref @{{.*}}Vector_plus__jvp_src_0_wrt_0_1{{.*}} // CHECK: [[ADD_VJP:%.*]] = function_ref @{{.*}}Vector_plus__vjp_src_0_wrt_0_1{{.*}} -// CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [wrt 0 1] [[ADD]] {{.*}} with {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}} +// CHECK: [[ADD_AD_FUNC:%.*]] = differentiable_function [parameters 0 1] [[ADD]] {{.*}} with_derivative {[[ADD_JVP]] {{.*}}, [[ADD_VJP]] {{.*}}} // CHECK: [[ADD_AD_FUNC_EXTRACT:%.*]] = differentiable_function_extract [vjp] [[ADD_AD_FUNC]] // CHECK: [[ADD_VJP_RESULT:%.*]] = apply [[ADD_AD_FUNC_EXTRACT]]({{.*}}, {{.*}}, {{.*}}) : $@convention(method) (@guaranteed Vector, @guaranteed Vector, @thin Vector.Type) -> (@owned Vector, @owned @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)) // CHECK: [[ADD_PULLBACK:%.*]] = tuple_extract [[ADD_VJP_RESULT]] : $(Vector, @callee_guaranteed (@guaranteed Vector) -> (@owned Vector, @owned Vector)), 1 diff --git a/test/AutoDiff/simple_real_vector.swift b/test/AutoDiff/simple_real_vector.swift index 6dc92fdf90f5d..97eb50f79d301 100644 --- a/test/AutoDiff/simple_real_vector.swift +++ b/test/AutoDiff/simple_real_vector.swift @@ -46,7 +46,7 @@ public func test1() -> Vector { // CHECK-LABEL: @{{.*}}test1{{.*}} // CHECK: [[CLOSURE:%.*]] = function_ref @{{.*}}test1{{.*}}foo{{.*}} : $@convention(thin) (Vector) -> Float // CHECK: [[CLOSURE_THICK:%.*]] = thin_to_thick_function [[CLOSURE]] : $@convention(thin) (Vector) -> Float to $@callee_guaranteed (Vector) -> Float -// CHECK: [[CLOSURE_DIFF:%.*]] = differentiable_function [wrt 0] [[CLOSURE_THICK]] : $@callee_guaranteed (Vector) -> Float +// CHECK: [[CLOSURE_DIFF:%.*]] = differentiable_function [parameters 0] [[CLOSURE_THICK]] : $@callee_guaranteed (Vector) -> Float // CHECK: [[CLOSURE_DIFF_NOESC:%.*]] = convert_escape_to_noescape [not_guaranteed] [[CLOSURE_DIFF]] : $@differentiable @callee_guaranteed (Vector) -> Float to $@differentiable @noescape @callee_guaranteed (Vector) -> Float // TF-189: `TF189` is a non-trivial type but `TF189.AllDifferentiableVariables` is trivial. diff --git a/test/AutoDiff/subset_parameters_thunk.swift b/test/AutoDiff/subset_parameters_thunk.swift index 2dc0308285faf..102a439fcceb3 100644 --- a/test/AutoDiff/subset_parameters_thunk.swift +++ b/test/AutoDiff/subset_parameters_thunk.swift @@ -26,5 +26,5 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float { // CHECK: [[FOO_VJP_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_VJP]]() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric, τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector)) // CHECK: [[FOO_VJP_SUBSET_THUNK_THIN:%.*]] = function_ref @AD__orig_{{.*}}foo{{.*}}_src_0_wrt_0_vjp_subset_parameters_thunk : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) // CHECK: [[FOO_VJP_SUBSET_THUNK:%.*]] = thin_to_thick_function [[FOO_VJP_SUBSET_THUNK_THIN]] : $@convention(thin) (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) to $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float) -// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [wrt 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} +// CHECK: [[FOO_DIFF:%.*]] = differentiable_function [parameters 0] [[FOO_FLOAT]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> @out Float with_derivative {[[FOO_JVP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[FOO_VJP_SUBSET_THUNK]] : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)} // CHECK: } diff --git a/test/AutoDiff/witness_method_autodiff.sil b/test/AutoDiff/witness_method_autodiff.sil index 1c3603bb0c837..baa7c7518324f 100644 --- a/test/AutoDiff/witness_method_autodiff.sil +++ b/test/AutoDiff/witness_method_autodiff.sil @@ -14,7 +14,7 @@ protocol DiffReq { sil @differentiateWitnessMethod : $@convention(thin) (@in_guaranteed T) -> () { bb0(%0 : $*T): %1 = witness_method $T, #DiffReq.f!1 : (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float - %2 = differentiable_function [wrt 0] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float + %2 = differentiable_function [parameters 0] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float %ret = tuple () return %ret : $() @@ -24,14 +24,14 @@ bb0(%0 : $*T): // CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f!1 // CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.SU // CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.SU -// CHECK: differentiable_function [wrt 0] [[ORIG_REF]] : {{.*}} with {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}} +// CHECK: differentiable_function [parameters 0] [[ORIG_REF]] : {{.*}} with_derivative {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}} // CHECK: } // end sil function 'differentiateWitnessMethod' sil @differentiatePartiallyAppliedWitnessMethod : $@convention(thin) (@in_guaranteed T) -> () { bb0(%0 : $*T): %1 = witness_method $T, #DiffReq.f!1 : (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float %2 = partial_apply [callee_guaranteed] %1(%0) : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float - %3 = differentiable_function [wrt 0] %2 : $@callee_guaranteed (Float) -> Float + %3 = differentiable_function [parameters 0] %2 : $@callee_guaranteed (Float) -> Float %ret = tuple () return %ret : $() @@ -51,5 +51,5 @@ bb0(%0 : $*T): // CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]([[ARGCOPY2]]) // CHECK: dealloc_stack [[ARGCOPY2]] // CHECK: dealloc_stack [[ARGCOPY1]] -// CHECK: differentiable_function [wrt 0] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}} +// CHECK: differentiable_function [parameters 0] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}} // CHECK: } // end sil function 'differentiatePartiallyAppliedWitnessMethod' diff --git a/test/AutoDiff/witness_table_sil.swift b/test/AutoDiff/witness_table_sil.swift index 8cd907ef86d95..a4d7fd80e8742 100644 --- a/test/AutoDiff/witness_table_sil.swift +++ b/test/AutoDiff/witness_table_sil.swift @@ -25,7 +25,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_jvp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double) -> Float) { // CHECK: [[JVP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[JVP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 - // CHECK: [[JVP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [[JVP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP1_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP1_ADFUNC:%.*]] = differentiable_function [parameters 0 1] [[JVP1_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP1_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP1:%.*]] = differentiable_function_extract [jvp] [[JVP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[JVP1]] // CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_jvp_SSU' @@ -33,7 +33,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function1{{.*}}_vjp_SSU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double)) { // CHECK: [[VJP1_ORIG_FNREF:%.*]] = function_ref {{.*}}function1{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[VJP1_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function1{{.*}}__vjp_src_0_wrt_0_1 - // CHECK: [[VJP1_ADFUNC:%.*]] = differentiable_function [wrt 0 1] [[VJP1_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP1_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP1_ADFUNC:%.*]] = differentiable_function [parameters 0 1] [[VJP1_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP1_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP1:%.*]] = differentiable_function_extract [vjp] [[VJP1_ADFUNC]] : $@differentiable @convention(method) (Float, Double, @nondiff S) -> Float // CHECK: apply [[VJP1]] // CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_vjp_SSU' @@ -46,7 +46,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_jvp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float, Double, @in_guaranteed S) -> Float) { // CHECK: [[JVP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[JVP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 - // CHECK: [[JVP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [[JVP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP2_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP2_ADFUNC:%.*]] = differentiable_function [parameters 0 1 2] [[JVP2_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP2_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP2:%.*]] = differentiable_function_extract [jvp] [[JVP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[JVP2]] // CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_jvp_SSS' @@ -54,7 +54,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function2{{.*}}_vjp_SSS : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Double, @out S)) { // CHECK: [[VJP2_ORIG_FNREF:%.*]] = function_ref {{.*}}function2{{.*}} : $@convention(method) (Float, Double, S) -> Float // CHECK: [[VJP2_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function2{{.*}}__vjp_src_0_wrt_0_1_2 - // CHECK: [[VJP2_ADFUNC:%.*]] = differentiable_function [wrt 0 1 2] [[VJP2_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP2_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP2_ADFUNC:%.*]] = differentiable_function [parameters 0 1 2] [[VJP2_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP2_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP2:%.*]] = differentiable_function_extract [vjp] [[VJP2_ADFUNC]] : $@differentiable @convention(method) (Float, Double, S) -> Float // CHECK: apply [[VJP2]] // CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_vjp_SSS' @@ -67,7 +67,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_jvp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[JVP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double // CHECK: [[JVP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 - // CHECK: [[JVP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [[JVP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[JVP3_VJP_FNREF]] : {{.*}}} + // CHECK: [[JVP3_ADFUNC:%.*]] = differentiable_function [parameters 1] [[JVP3_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[JVP3_VJP_FNREF]] : {{.*}}} // CHECK: [[JVP3:%.*]] = differentiable_function_extract [jvp] [[JVP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[JVP3]] // CHECK: } // end sil function 'AD__{{.*}}function3{{.*}}_jvp_USU' @@ -75,7 +75,7 @@ struct S : Proto, VectorProtocol { // CHECK-LABEL: sil {{.*}} @AD__{{.*}}function3{{.*}}_vjp_USU : $@convention(witness_method: Proto) (Float, Double, @in_guaranteed S) -> (Double, @owned @callee_guaranteed (Double) -> Double) { // CHECK: [[VJP3_ORIG_FNREF:%.*]] = function_ref {{.*}}function3{{.*}} : $@convention(method) (Float, Double, S) -> Double // CHECK: [[VJP3_VJP_FNREF:%.*]] = function_ref @AD__{{.*}}function3{{.*}}__vjp_src_0_wrt_1 - // CHECK: [[VJP3_ADFUNC:%.*]] = differentiable_function [wrt 1] [[VJP3_ORIG_FNREF]] : {{.*}} with {{{%.*}} : {{.*}}, [[VJP3_VJP_FNREF]] : {{.*}}} + // CHECK: [[VJP3_ADFUNC:%.*]] = differentiable_function [parameters 1] [[VJP3_ORIG_FNREF]] : {{.*}} with_derivative {{{%.*}} : {{.*}}, [[VJP3_VJP_FNREF]] : {{.*}}} // CHECK: [[VJP3:%.*]] = differentiable_function_extract [vjp] [[VJP3_ADFUNC]] : $@differentiable @convention(method) (@nondiff Float, Double, @nondiff S) -> Double // CHECK: apply [[VJP3]] // CHECK: } // end sil function 'AD__{{.*}}function3{{.*}}_vjp_USU'