diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 2fa1237462f4e..57a337865605f 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -787,7 +787,9 @@ DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst( : getDifferentiabilityWitnessType( module, witnessKind, witness)), witnessKind(witnessKind), witness(witness), - hasExplicitFunctionType(FunctionType) {} + hasExplicitFunctionType(FunctionType) { + assert(witness && "Witness must not be null"); +} // SWIFT_ENABLE_TENSORFLOW END FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind, diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 273141144f343..c61c8d60d32cc 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1576,6 +1576,28 @@ class SILVerifier : public SILVerifierBase { "The function operand must be a '@differentiable(linear)' " "function"); } + + void checkDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *dwfi) { + auto witnessFnTy = dwfi->getType().castTo(); + auto *witness = dwfi->getWitness(); + // `DifferentiabilityWitnessFunctionInst` constructor asserts that + // `witness` is non-null. + auto witnessKind = dwfi->getWitnessKind(); + // Return if not witnessing a derivative function. + auto derivKind = witnessKind.getAsDerivativeFunctionKind(); + if (!derivKind) + return; + // Return if witness does not define the referenced derivative. + auto *derivativeFn = witness->getDerivative(*derivKind); + if (!derivativeFn) + return; + auto derivativeFnTy = derivativeFn->getLoweredFunctionType(); + requireSameType(SILType::getPrimitiveObjectType(witnessFnTy), + SILType::getPrimitiveObjectType(derivativeFnTy), + "Type of witness instruction does not match actual type of " + "witnessed function"); + } // SWIFT_ENABLE_TENSORFLOW END void verifyLLVMIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID) {