Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/Serialization/DeserializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
// of this SILFunction.
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
kind != SIL_WITNESS_TABLE) {
// SWIFT_ENABLE_TENSORFLOW
kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) {
// SWIFT_ENABLE_TENSORFLOW END
if (kind == SIL_BASIC_BLOCK)
// Handle a SILBasicBlock record.
CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch);
Expand Down
38 changes: 27 additions & 11 deletions lib/Serialization/SerializeSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ namespace {
/// Global variables that we've emitted a reference to.
llvm::DenseSet<const SILGlobalVariable *> GlobalsToEmit;

// SWIFT_ENABLE_TENSORFLOW
/// Referenced differentiability witnesses that need to be emitted.
llvm::DenseSet<const SILDifferentiabilityWitness *>
DifferentiabilityWitnessesToEmit;
// SWIFT_ENABLE_TENSORFLOW END

/// Additional functions we might need to serialize.
llvm::SmallVector<const SILFunction *, 16> Worklist;

Expand Down Expand Up @@ -1061,6 +1067,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
auto *dwfi = cast<DifferentiabilityWitnessFunctionInst>(&SI);
auto *witness = dwfi->getWitness();
DifferentiabilityWitnessesToEmit.insert(witness);
Mangle::ASTMangler mangler;
auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(
witness->getKey());
Expand Down Expand Up @@ -2718,17 +2725,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
writeSILDefaultWitnessTable(wt);
}

// SWIFT_ENABLE_TENSORFLOW
// Write out differentiability witnesses.
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
// TODO(TF-893): Consider checking
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
// functions.
if ((ShouldSerializeAll || diffWitness.isSerialized()))
writeSILDifferentiabilityWitness(diffWitness);
}
// SWIFT_ENABLE_TENSORFLOW END

// Emit only declarations if it is a module with pre-specializations.
// And only do it in optimized builds.
bool emitDeclarationsForOnoneSupport =
Expand All @@ -2751,6 +2747,26 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
processSILFunctionWorklist();
}

// SWIFT_ENABLE_TENSORFLOW
// Write out differentiability witnesses.
// Note: this must be done after visiting SIL functions above so that
// differentiability witness references (`differentiability_witness_function`
// instructions) have been tracked.
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
// TODO(TF-893): Consider checking
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
// functions.
if ((ShouldSerializeAll || diffWitness.isSerialized()))
DifferentiabilityWitnessesToEmit.insert(&diffWitness);
}
for (auto *diffWitness : DifferentiabilityWitnessesToEmit)
writeSILDifferentiabilityWitness(*diffWitness);
// Process SIL functions referenced by differentiability witnesses.
// Note: this is necessary despite processing `FuncsToEmit` below because
// `Worklist` is processed separately.
processSILFunctionWorklist();
// SWIFT_ENABLE_TENSORFLOW END

// Now write function declarations for every function we've
// emitted a reference to without emitting a function body for.
for (const SILFunction &F : *SILMod) {
Expand Down
134 changes: 67 additions & 67 deletions test/AutoDiff/sil_differentiability_witness.sil
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Round-trip parsing/printing test.

// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck --check-prefix=ROUNDTRIP %s
// RUN: %target-sil-opt %s | %target-sil-opt -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s

// Round-trip serialization-deserialization test.

// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name main
// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck --check-prefix=ROUNDTRIP %s
// RUN: %target-sil-opt %t/tmp.2.sib -module-name main -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s

// IRGen test.

Expand All @@ -19,6 +19,71 @@ import Builtin
import Swift
import SwiftShims

// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.

sil @externalFn1 : $@convention(thin) (Float) -> Float

sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// ROUNDTRIP-LABEL: // differentiability witness for externalFn1
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.

sil @externalFn2 : $@convention(thin) (Float) -> Float

sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// ROUNDTRIP-LABEL: // differentiability witness for externalFn2
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness declaration.

sil @externalFn3 : $@convention(thin) (Float) -> Float

sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float

// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}

// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }

// Test public non-generic function.
// SIL differentiability witness:
// - Has public linkage (implicit).
Expand Down Expand Up @@ -92,68 +157,3 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where
// IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1
// IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1
// IRGEN-SAME: }

// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.

sil @externalFn1 : $@convention(thin) (Float) -> Float

sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
return undef : $(Float, @callee_guaranteed (Float) -> Float)
}

sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// ROUNDTRIP-LABEL: // differentiability witness for externalFn1
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.

sil @externalFn2 : $@convention(thin) (Float) -> Float

sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)

sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// ROUNDTRIP-LABEL: // differentiability witness for externalFn2
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// ROUNDTRIP: }

// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
// IRGEN-SAME: }

// Test SIL differentiability witness declaration.

sil @externalFn3 : $@convention(thin) (Float) -> Float

sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float

// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}

// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend -emit-module -emit-module-path %t/test.swiftmodule -module-name test %s
// RUN: %target-sil-opt %t/test.swiftmodule

sil_stage raw

import Swift
import Builtin

sil_differentiability_witness [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float

sil @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float

sil [serialized] @test_serialized : $@convention(thin) () -> () {
bb0:
%referenced_from_serialized_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
return undef : $()
}