diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 3b616360ee896..0e27e01331698 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -801,7 +801,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn, // SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end // of this SILFunction. while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR && - kind != SIL_WITNESS_TABLE) { + // SWIFT_ENABLE_TENSORFLOW + kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) { + // SWIFT_ENABLE_TENSORFLOW END if (kind == SIL_BASIC_BLOCK) // Handle a SILBasicBlock record. CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch); diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 41b8fc3d9c679..4d055cc7d6034 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -201,6 +201,12 @@ namespace { /// Global variables that we've emitted a reference to. llvm::DenseSet GlobalsToEmit; + // SWIFT_ENABLE_TENSORFLOW + /// Referenced differentiability witnesses that need to be emitted. + llvm::DenseSet + DifferentiabilityWitnessesToEmit; + // SWIFT_ENABLE_TENSORFLOW END + /// Additional functions we might need to serialize. llvm::SmallVector Worklist; @@ -1061,6 +1067,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { auto *dwfi = cast(&SI); auto *witness = dwfi->getWitness(); + DifferentiabilityWitnessesToEmit.insert(witness); Mangle::ASTMangler mangler; auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey( witness->getKey()); @@ -2718,17 +2725,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { writeSILDefaultWitnessTable(wt); } - // SWIFT_ENABLE_TENSORFLOW - // Write out differentiability witnesses. - for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) { - // TODO(TF-893): Consider checking - // `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP - // functions. - if ((ShouldSerializeAll || diffWitness.isSerialized())) - writeSILDifferentiabilityWitness(diffWitness); - } - // SWIFT_ENABLE_TENSORFLOW END - // Emit only declarations if it is a module with pre-specializations. // And only do it in optimized builds. bool emitDeclarationsForOnoneSupport = @@ -2751,6 +2747,26 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) { processSILFunctionWorklist(); } + // SWIFT_ENABLE_TENSORFLOW + // Write out differentiability witnesses. + // Note: this must be done after visiting SIL functions above so that + // differentiability witness references (`differentiability_witness_function` + // instructions) have been tracked. + for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) { + // TODO(TF-893): Consider checking + // `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP + // functions. + if ((ShouldSerializeAll || diffWitness.isSerialized())) + DifferentiabilityWitnessesToEmit.insert(&diffWitness); + } + for (auto *diffWitness : DifferentiabilityWitnessesToEmit) + writeSILDifferentiabilityWitness(*diffWitness); + // Process SIL functions referenced by differentiability witnesses. + // Note: this is necessary despite processing `FuncsToEmit` below because + // `Worklist` is processed separately. + processSILFunctionWorklist(); + // SWIFT_ENABLE_TENSORFLOW END + // Now write function declarations for every function we've // emitted a reference to without emitting a function body for. for (const SILFunction &F : *SILMod) { diff --git a/test/AutoDiff/sil_differentiability_witness.sil b/test/AutoDiff/sil_differentiability_witness.sil index 7a6ba5829d8a3..d47a5ff03f2f4 100644 --- a/test/AutoDiff/sil_differentiability_witness.sil +++ b/test/AutoDiff/sil_differentiability_witness.sil @@ -1,13 +1,13 @@ // Round-trip parsing/printing test. -// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck --check-prefix=ROUNDTRIP %s +// RUN: %target-sil-opt %s | %target-sil-opt -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s // Round-trip serialization-deserialization test. // RUN: %empty-directory(%t) // RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main // RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name main -// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck --check-prefix=ROUNDTRIP %s +// RUN: %target-sil-opt %t/tmp.2.sib -module-name main -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s // IRGen test. @@ -19,6 +19,71 @@ import Builtin import Swift import SwiftShims +// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp. + +sil @externalFn1 : $@convention(thin) (Float) -> Float + +sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} + +sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +bb0(%0 : $Float): + return undef : $(Float, @callee_guaranteed (Float) -> Float) +} + +sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float { + jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} + +// ROUNDTRIP-LABEL: // differentiability witness for externalFn1 +// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float { +// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// ROUNDTRIP: } + +// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } { +// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0 +// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0 +// IRGEN-SAME: } + +// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp. + +sil @externalFn2 : $@convention(thin) (Float) -> Float + +sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float { + jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} + +// ROUNDTRIP-LABEL: // differentiability witness for externalFn2 +// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float { +// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +// ROUNDTRIP: } + +// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } { +// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0 +// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0 +// IRGEN-SAME: } + +// Test SIL differentiability witness declaration. + +sil @externalFn3 : $@convention(thin) (Float) -> Float + +sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float + +// ROUNDTRIP-LABEL: // differentiability witness for externalFn3 +// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}} + +// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* } + // Test public non-generic function. // SIL differentiability witness: // - Has public linkage (implicit). @@ -92,68 +157,3 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where // IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1 // IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1 // IRGEN-SAME: } - -// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp. - -sil @externalFn1 : $@convention(thin) (Float) -> Float - -sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - return undef : $(Float, @callee_guaranteed (Float) -> Float) -} - -sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { -bb0(%0 : $Float): - return undef : $(Float, @callee_guaranteed (Float) -> Float) -} - -sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float { - jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -} - -// ROUNDTRIP-LABEL: // differentiability witness for externalFn1 -// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float { -// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// ROUNDTRIP: } - -// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } { -// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0 -// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0 -// IRGEN-SAME: } - -// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp. - -sil @externalFn2 : $@convention(thin) (Float) -> Float - -sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - -sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - -sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float { - jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) - vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -} - -// ROUNDTRIP-LABEL: // differentiability witness for externalFn2 -// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float { -// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) -// ROUNDTRIP: } - -// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } { -// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0 -// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0 -// IRGEN-SAME: } - -// Test SIL differentiability witness declaration. - -sil @externalFn3 : $@convention(thin) (Float) -> Float - -sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float - -// ROUNDTRIP-LABEL: // differentiability witness for externalFn3 -// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}} - -// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* } diff --git a/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil b/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil new file mode 100644 index 0000000000000..e51d485b77ce5 --- /dev/null +++ b/test/AutoDiff/sil_differentiability_witness_reference_serialization.sil @@ -0,0 +1,18 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -emit-module -emit-module-path %t/test.swiftmodule -module-name test %s +// RUN: %target-sil-opt %t/test.swiftmodule + +sil_stage raw + +import Swift +import Builtin + +sil_differentiability_witness [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + +sil @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + +sil [serialized] @test_serialized : $@convention(thin) () -> () { +bb0: + %referenced_from_serialized_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float + return undef : $() +}