diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h index f37459f6d8b7b..f7f2b43095323 100644 --- a/include/swift/SIL/SILDifferentiabilityWitness.h +++ b/include/swift/SIL/SILDifferentiabilityWitness.h @@ -91,6 +91,9 @@ class SILDifferentiabilityWitness GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized, DeclAttribute *attribute = nullptr); + void convertToDefinition(SILFunction *jvp, SILFunction *vjp, + bool isSerialized); + SILDifferentiabilityWitnessKey getKey() const; SILModule &getModule() const { return Module; } SILLinkage getLinkage() const { return Linkage; } diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index 2d5ceaffd26a0..4a39941680711 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -613,6 +613,10 @@ class SILModule { /// Look up the differentiability witness corresponding to the given key. SILDifferentiabilityWitness * lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key); + + /// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if + /// deserialization succeeded, false otherwise. + bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *W); // SWIFT_ENABLE_TENSORFLOW_END // Given a protocol, attempt to create a default witness table declaration diff --git a/include/swift/SILOptimizer/PassManager/Passes.def b/include/swift/SILOptimizer/PassManager/Passes.def index f108e4604a133..4796abde148e8 100644 --- a/include/swift/SILOptimizer/PassManager/Passes.def +++ b/include/swift/SILOptimizer/PassManager/Passes.def @@ -148,6 +148,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable", "Diagnose Unreachable Code") PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation", "Constants Propagation for Diagnostics") +PASS(DifferentiabilityWitnessDevirtualizer, + "differentiability-witness-devirtualizer", + "Inlines Differentiability Witnesses") PASS(EagerSpecializer, "eager-specializer", "Eager Specialization via @_specialize") PASS(EarlyCodeMotion, "early-codemotion", diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp index 407902ad41509..8ec15f4129119 100644 --- a/lib/SIL/SILDifferentiabilityWitness.cpp +++ b/lib/SIL/SILDifferentiabilityWitness.cpp @@ -60,6 +60,16 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition( return diffWitness; } +void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp, + SILFunction *vjp, + bool isSerialized) { + assert(IsDeclaration); + IsDeclaration = false; + JVP = jvp; + VJP = vjp; + IsSerialized = isSerialized; +} + SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const { return std::make_pair(getOriginalFunction()->getName(), getConfig()); } diff --git a/lib/SIL/SILModule.cpp b/lib/SIL/SILModule.cpp index af2603d0275a9..520c3e3a07af3 100644 --- a/lib/SIL/SILModule.cpp +++ b/lib/SIL/SILModule.cpp @@ -599,6 +599,15 @@ SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key) { mangler.mangleSILDifferentiabilityWitnessKey(key)); } +bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *W) { + auto *NewW = getSILLoader()->lookupDifferentiabilityWitness(W->getKey()); + if (!NewW) + return false; + + assert(W == NewW); + return true; +} + void SILModule::registerDeserializationNotificationHandler( std::unique_ptr &&handler) { deserializationNotificationHandlers.add(std::move(handler)); diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index b7ac4f9beb195..77e8289d1121b 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -824,13 +824,10 @@ void SILGenModule::emitDifferentiabilityWitness( derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); // Create new SIL differentiability witness. // Witness JVP and VJP are set below. - // TODO(TF-919): Explore creating serialized differentiability witnesses. - // Currently, differentiability witnesses are never serialized to avoid - // deserialization issues where JVP/VJP functions cannot be found. auto *diffWitness = SILDifferentiabilityWitness::createDefinition( - M, originalFunction->getLinkage(), originalFunction, - loweredParamIndices, config.resultIndices, derivativeCanGenSig, - /*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false); + M, originalFunction->getLinkage(), originalFunction, loweredParamIndices, + config.resultIndices, derivativeCanGenSig, + /*jvp*/ nullptr, /*vjp*/ nullptr, originalFunction->isSerialized()); // Set derivative function in differentiability witness. auto setDerivativeInDifferentiabilityWitness = diff --git a/lib/SILOptimizer/PassManager/PassPipeline.cpp b/lib/SILOptimizer/PassManager/PassPipeline.cpp index d81c316068713..903f175158751 100644 --- a/lib/SILOptimizer/PassManager/PassPipeline.cpp +++ b/lib/SILOptimizer/PassManager/PassPipeline.cpp @@ -404,6 +404,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) { // we do not spend time optimizing them. P.addDeadFunctionElimination(); + // SWIFT_ENABLE_TENSORFLOW + // This unblocks many other passes' optimizations (e.g. inlining) and this is + // not blocked by any other passes' optimizations, so do it early. + P.addDifferentiabilityWitnessDevirtualizer(); + // Strip ownership from non-transparent functions. if (P.getOptions().StripOwnershipAfterSerialization) P.addNonTransparentFunctionOwnershipModelEliminator(); diff --git a/lib/SILOptimizer/Transforms/CMakeLists.txt b/lib/SILOptimizer/Transforms/CMakeLists.txt index fbba6b0fcac59..f48201ab43c7a 100644 --- a/lib/SILOptimizer/Transforms/CMakeLists.txt +++ b/lib/SILOptimizer/Transforms/CMakeLists.txt @@ -17,6 +17,9 @@ silopt_register_sources( DeadStoreElimination.cpp DestroyHoisting.cpp Devirtualizer.cpp + # SWIFT_ENABLE_TENSORFLOW + DifferentiabilityWitnessDevirtualizer.cpp + # SWIFT_ENABLE_TENSORFLOW_END GenericSpecializer.cpp MergeCondFail.cpp Outliner.cpp diff --git a/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp b/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp new file mode 100644 index 0000000000000..8b7066b3e331b --- /dev/null +++ b/lib/SILOptimizer/Transforms/DifferentiabilityWitnessDevirtualizer.cpp @@ -0,0 +1,72 @@ +//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// Devirtualized differentiability witnesses whose bodies are availabe, by +// turning "differentiability_witness_function" instructions into "function_ref" +// instructions referencing the appropriate function. +// +//===----------------------------------------------------------------------===// + +#include "swift/SIL/SILBuilder.h" +#include "swift/SIL/SILFunction.h" +#include "swift/SIL/SILInstruction.h" +#include "swift/SILOptimizer/PassManager/Transforms.h" + +using namespace swift; + +namespace { +class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform { + + /// Returns true if and changes were made. + bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f); + + /// The entry point to the transformation. + void run() override { + if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction())) + invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions); + } +}; +} // end anonymous namespace + +bool DifferentiabilityWitnessDevirtualizer:: + devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) { + bool changed = false; + llvm::SmallVector insts; + for (auto &bb : f) { + for (auto &inst : bb) { + auto *dfwi = dyn_cast(&inst); + if (!dfwi) + continue; + insts.push_back(dfwi); + } + } + for (auto *inst : insts) { + auto *wit = inst->getWitness(); + if (wit->isDeclaration()) + f.getModule().loadDifferentiabilityWitness(wit); + if (wit->isDeclaration()) + continue; + changed = true; + SILBuilderWithScope builder(inst); + auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind(); + assert(kind.hasValue()); + auto *newInst = builder.createFunctionRefFor(inst->getLoc(), + wit->getDerivative(*kind)); + inst->replaceAllUsesWith(newInst); + inst->getParent()->erase(inst); + } + return changed; +} + +SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() { + return new DifferentiabilityWitnessDevirtualizer(); +} diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 0e27e01331698..2bc964e9bbe45 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3459,18 +3459,29 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) { ArrayRef(parameterAndResultIndices) .take_back(numResultIndices)); - if (isDeclaration) { - auto *diffWitness = SILDifferentiabilityWitness::createDeclaration( + AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig); + auto *diffWitness = + SILMod.lookUpDifferentiabilityWitness({originalName, config}); + + // If there is no existing differentiability witness, create one. + if (!diffWitness) + diffWitness = SILDifferentiabilityWitness::createDeclaration( SILMod, *linkage, original, parameterIndices, resultIndices, derivativeGenSig); - diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); - return diffWitness; - } - auto *diffWitness = SILDifferentiabilityWitness::createDefinition( - SILMod, *linkage, original, parameterIndices, resultIndices, - derivativeGenSig, jvp, vjp, isSerialized); - diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true); + // If the current differentiability witness is merely a declaration, and the + // deserialized witness is a definition, upgrade the current differentiability + // witness to a definition. This can happen in the following situations: + // 1. The witness was just created above. + // 2. The witness started out as a declaration (e.g. the differentiation + // pass emitted a witness for an external function) and now we're loading + // the definition (e.g. an optimization pass asked for the definition and + // we found the definition serialized in this module). + if (diffWitness->isDeclaration() && !isDeclaration) + diffWitness->convertToDefinition(jvp, vjp, isSerialized); + + diffWitnessOrOffset.set(diffWitness, + /*isFullyDeserialized*/ diffWitness->isDefinition()); return diffWitness; } diff --git a/lib/Serialization/SerializedSILLoader.cpp b/lib/Serialization/SerializedSILLoader.cpp index f9223c4fabf4c..204cb75fdd2fb 100644 --- a/lib/Serialization/SerializedSILLoader.cpp +++ b/lib/Serialization/SerializedSILLoader.cpp @@ -133,10 +133,15 @@ SerializedSILLoader::lookupDifferentiabilityWitness( SILDifferentiabilityWitnessKey key) { Mangle::ASTMangler mangler; std::string mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(key); - for (auto &Des : LoadedSILSections) - if (auto *diffWitness = Des->lookupDifferentiabilityWitness(mangledKey)) - return diffWitness; - return nullptr; + // It is possible that one module has a declaration of a + // SILDifferentiabilityWitness, while another has the full definition. + SILDifferentiabilityWitness *wit = nullptr; + for (auto &Des : LoadedSILSections) { + wit = Des->lookupDifferentiabilityWitness(mangledKey); + if (wit && wit->isDefinition()) + return wit; + } + return wit; } // SWIFT_ENABLE_TENSORFLOW END diff --git a/test/AutoDiff/differentiability_witness_inlining.sil b/test/AutoDiff/differentiability_witness_inlining.sil new file mode 100644 index 0000000000000..b13a2b4a4eaa4 --- /dev/null +++ b/test/AutoDiff/differentiability_witness_inlining.sil @@ -0,0 +1,42 @@ +// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s + +sil_stage raw + +import Swift +import Builtin + +sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float { + jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) +} + +sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + +// This is an example of a witness that is available (via deserialization) +// even though it is not defined in the current module. +// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float +sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +sil @witness_defined_in_module : $@convention(thin) (Float) -> Float + +sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + +sil @witness_definition_not_available : $@convention(thin) (Float) -> Float + +sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + +sil @test : $@convention(thin) (Float) -> () { +bb0(%0 : $Float): + %1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float + // CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) + + %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + // CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float + + %3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float + // CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) + + return undef : $() +} diff --git a/test/AutoDiff/sil_differentiability_witness_silgen.swift b/test/AutoDiff/sil_differentiability_witness_silgen.swift index 22d763f6bdee9..6eff1bcd8471e 100644 --- a/test/AutoDiff/sil_differentiability_witness_silgen.swift +++ b/test/AutoDiff/sil_differentiability_witness_silgen.swift @@ -75,7 +75,7 @@ public struct Foo: Differentiable { public var x: Float // CHECK-LABEL: // differentiability witness for Foo.x.getter -// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float { // CHECK-NEXT: } @differentiable