diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 2fa1237462f4e..7ee43554d1e42 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -781,13 +781,21 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType( DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst( SILModule &module, SILDebugLocation debugLoc, DifferentiabilityWitnessFunctionKind witnessKind, - SILDifferentiabilityWitness *witness, Optional FunctionType) - : InstructionBase(debugLoc, FunctionType - ? *FunctionType + SILDifferentiabilityWitness *witness, Optional functionType) + : InstructionBase(debugLoc, functionType + ? *functionType : getDifferentiabilityWitnessType( module, witnessKind, witness)), witnessKind(witnessKind), witness(witness), - hasExplicitFunctionType(FunctionType) {} + hasExplicitFunctionType(functionType) { + assert(witness && "Differentiability witness must not be null"); +#ifndef NDEBUG + if (functionType.hasValue()) { + assert(module.getStage() == SILStage::Lowered && + "Explicit type is valid only in lowered SIL"); + } +#endif +} // SWIFT_ENABLE_TENSORFLOW END FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind, diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 273141144f343..c61c8d60d32cc 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1576,6 +1576,28 @@ class SILVerifier : public SILVerifierBase { "The function operand must be a '@differentiable(linear)' " "function"); } + + void checkDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *dwfi) { + auto witnessFnTy = dwfi->getType().castTo(); + auto *witness = dwfi->getWitness(); + // `DifferentiabilityWitnessFunctionInst` constructor asserts that + // `witness` is non-null. + auto witnessKind = dwfi->getWitnessKind(); + // Return if not witnessing a derivative function. + auto derivKind = witnessKind.getAsDerivativeFunctionKind(); + if (!derivKind) + return; + // Return if witness does not define the referenced derivative. + auto *derivativeFn = witness->getDerivative(*derivKind); + if (!derivativeFn) + return; + auto derivativeFnTy = derivativeFn->getLoweredFunctionType(); + requireSameType(SILType::getPrimitiveObjectType(witnessFnTy), + SILType::getPrimitiveObjectType(derivativeFnTy), + "Type of witness instruction does not match actual type of " + "witnessed function"); + } // SWIFT_ENABLE_TENSORFLOW END void verifyLLVMIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID) { diff --git a/test/AutoDiff/differentiability_witness_function_inst.sil b/test/AutoDiff/differentiability_witness_function_inst.sil index 6668f1399ea01..75814400eb829 100644 --- a/test/AutoDiff/differentiability_witness_function_inst.sil +++ b/test/AutoDiff/differentiability_witness_function_inst.sil @@ -58,9 +58,6 @@ bb0: // Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`. %generic_vjp_wrt_0_1_dependent_req = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T - // Test explicit function types. - %explicit_fnty = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float) - return undef : $() } @@ -73,7 +70,6 @@ bb0: // CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 // CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 // CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 -// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float) // CHECK: } // IRGEN: @AD__foo_PSUURS = external global %swift.differentiability_witness, align 8 @@ -106,6 +102,3 @@ bb0: // IRGEN: [[PTR7:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRSs14DifferentiableRz13TangentVectorsAAPQzRszl, i32 0, i32 1), align 8 // IRGEN: [[FNPTR7:%.*]] = bitcast i8* [[PTR7]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)* - -// IRGEN: [[PTR8:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSUURS, i32 0, i32 0), align 8 -// IRGEN: [[FNPTR8:%.*]] = bitcast i8* [[PTR8]] to { float, i8*, %swift.refcounted* } (float, float, float)* diff --git a/test/AutoDiff/differentiable_function_inst_lowered.sil b/test/AutoDiff/differentiable_function_inst_lowered.sil index 1e8126ba86e27..5af65f62b4032 100644 --- a/test/AutoDiff/differentiable_function_inst_lowered.sil +++ b/test/AutoDiff/differentiable_function_inst_lowered.sil @@ -1,6 +1,7 @@ // RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s -// Test `differentiable_function_extract` with explicit lowered type. +// Test `differentiable_function_extract` and +// `differentiability_witness_function` with explicit lowered type. // SIL generated via `%target-sil-opt -loadable-address %s`. // Note: SIL serialization/deserialization does not support lowered SIL. @@ -27,37 +28,43 @@ struct Large : Differentiable { mutating func move(along direction: Large.TangentVector) } +sil_differentiability_witness [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large // CHECK-LABEL: sil @test sil @test : $@convention(thin) () -> () { bb0: - %0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + %func = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %func_jvpwitness_wrt_012 = differentiability_witness_function [jvp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector) + %func_vjpwitness_wrt_012 = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + %func_diff_wrt_012 = differentiable_function [parameters 0 1 2] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large with_derivative {%func_jvpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector), %func_vjpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))} + %func_vjp_wrt_012 = differentiable_function_extract [vjp] %func_diff_wrt_012 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) - // CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + // CHECK: [[FUNC_REF:%.*]] = function_ref @examplefunc + // CHECK: [[DIFF_WRT_012:%.*]] = differentiable_function [parameters 0 1 2] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: [[VJP_WRT_012:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_012]] : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) - %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + %func_diff_wrt_0 = differentiable_function [parameters 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %func_vjp_wrt_0 = differentiable_function_extract [vjp] %func_diff_wrt_0 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) - // CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) - %5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + %method = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %method_diff_wrt_0123 = differentiable_function [parameters 0 1 2] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %7 = differentiable_function_extract [vjp] %method_diff_wrt_0123 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) - // CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + // CHECK: [[METHOD_REF:%.*]] = function_ref @examplemethod + // CHECK: [[DIFF_WRT_0123:%.*]] = differentiable_function [parameters 0 1 2] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: [[VJP_WRT_0123:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0123]] : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) - %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + %method_diff_wrt_0 = differentiable_function [parameters 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %method_vjp_wrt_0 = differentiable_function_extract [vjp] %method_diff_wrt_0 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) - // CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large - // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + // CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) %10 = tuple () return %10 : $()