diff --git a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
index 6253491e2023e..b1e84686c3641 100644
--- a/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
+++ b/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
@@ -4466,12 +4466,30 @@ void DifferentiationTask::createVJP() {
       SILDebugScope(original->getLocation(), vjp));
   attr->setVJPName(vjp->getName());
 
+  // Work around a bad interaction between VJPs, TFDeabstraction, and SIL
+  // optimizations.
+  //
+  // The bad interaction is: TFDeabstraction cannot inline the functions that
+  // the VJP partially applies. When run in "-Onone" mode, this is fine (but
+  // inefficient), because we just generate sends/recvs to handle it. But in
+  // "-O" mode, later SIL optimization passes:
+  // - fold the partial applications and inline them, or
+  // - specialize the partial_apply callees.
+  // In either case, the inlined/specialized bodies haven't been deabstracted,
+  // so partitioning crashes when it sees them.
+  //
+  // TODO: Fix this problem in a way that allows inlining/specialization of
+  // VJPs. Teaching TFDeabstraction to fold partial applications might work.
+  vjp->setInlineStrategy(NoInline);
+  vjp->addSemanticsAttr("optimize.sil.specialize.generic.never");
+
   LLVM_DEBUG(llvm::dbgs() << " vjp type: "
                           << vjp->getLoweredFunctionType() << "\n");
 
   // We'll use these conventions frequently.
   auto originalConv = original->getConventions();
   auto primalConv = primal->getConventions();
+  auto vjpConv = vjp->getConventions();
 
   // Keep track of some stack allocation to clean up.
   SmallVector<AllocStackInst *, 2> stackAllocsToCleanUp;
@@ -4479,7 +4497,6 @@ void DifferentiationTask::createVJP() {
   // Validate signatures.
 #ifndef NDEBUG
   auto adjointConv = adjoint->getConventions();
-  auto vjpConv = vjp->getConventions();
 
   unsigned numOriginalParameters = originalConv.getNumParameters();
   unsigned numOriginalResults = originalConv.getResults().size();
@@ -4538,7 +4555,16 @@ void DifferentiationTask::createVJP() {
 
   // Create the entry block with indirect results and parameters.
   auto *entry = vjp->createBasicBlock();
-  createEntryArguments(vjp);
+  unsigned argumentIndex = 0;
+  for (auto indResultTy : vjpConv.getIndirectSILResultTypes())
+    entry->createFunctionArgument(
+        vjp->mapTypeIntoContext(indResultTy).getAddressType(),
+        original->getEntryBlock()->getArgument(argumentIndex++)->getDecl());
+  for (auto paramTy : vjpConv.getParameterSILTypes())
+    entry->createFunctionArgument(
+        vjp->mapTypeIntoContext(paramTy),
+        original->getEntryBlock()->getArgument(argumentIndex++)->getDecl());
+
   SILBuilder builder(entry);
   auto loc = vjp->getLocation();
 
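For context on the workaround above: a VJP returns the original result paired with a pullback closure, and at the SIL level that closure is a partial_apply (presumably of the adjoint over values the primal saved), which is exactly the construct TFDeabstraction cannot inline. Below is a minimal hand-written sketch of that (result, pullback) shape using plain Float; `square` and `squareVJP` are illustrative names, not the generated "AD__" symbols.

```swift
// Hand-written analogue of what a generated VJP returns: the original
// result plus a pullback closure. Names here are hypothetical, purely
// for illustration.
func square(_ x: Float) -> Float {
  return x * x
}

func squareVJP(_ x: Float) -> (Float, (Float) -> Float) {
  // The pullback captures `x`; at the SIL level such a capture becomes
  // a partial_apply, the construct that motivates marking the VJP
  // NoInline above.
  return (x * x, { seed in 2 * x * seed })
}

let (value, pullback) = squareVJP(3)
print(value)       // 9.0
print(pullback(1)) // 6.0, i.e. d(x*x)/dx evaluated at x = 3
```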
diff --git a/lib/SILOptimizer/Mandatory/TFLowerGraph.cpp b/lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
index e64abeb111b3e..6c3b6be61a21c 100644
--- a/lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
+++ b/lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
@@ -2694,6 +2694,8 @@ bool TFGraphLowering::serializeGraphProtoBuf(ASTContext &ctx,
 /// repl)
 std::string getTFCompatibleFuncName(SILFunction *fn) {
   auto fnName = fn->getName();
+  if (fnName.startswith("AD__"))
+    fnName = fnName.substr(4);
   if (fnName.startswith("$"))
     fnName = fnName.substr(1);
 
diff --git a/stdlib/public/TensorFlow/Ops.swift b/stdlib/public/TensorFlow/Ops.swift
index 097c487472bab..75ffc84617428 100644
--- a/stdlib/public/TensorFlow/Ops.swift
+++ b/stdlib/public/TensorFlow/Ops.swift
@@ -275,7 +275,10 @@ public extension Tensor where Scalar : Numeric {
 public func matmul<Scalar : Numeric>(
   _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
 ) -> Tensor<Scalar> {
-  return Raw.matMul(left, right)
+  // Default arguments specified explicitly to avoid "external declarations of
+  // SILFunctions with shared visibility is not allowed" SILVerifier error in
+  // "tests/AutoDiff/tensor_autodiff_runtime.swift".
+  return Raw.matMul(left, right, transposeA: false, transposeB: false)
 }
 
 infix operator • : MultiplicationPrecedence
diff --git a/test/TensorFlow/tensor_autodiff.swift b/test/TensorFlow/tensor_autodiff.swift
index 9944117f0d42d..419f6b1b67a57 100644
--- a/test/TensorFlow/tensor_autodiff.swift
+++ b/test/TensorFlow/tensor_autodiff.swift
@@ -1,4 +1,5 @@
 // RUN: %target-swift-frontend -O -emit-sil %s | %FileCheck %s
+// RUN: %target-swift-frontend -Xllvm -differentiation-use-vjp -O -emit-sil %s | %FileCheck %s
 
 import TensorFlow
 
diff --git a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift
index 7d274987e04f2..3b53e3fee9a7c 100644
--- a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift
+++ b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift
@@ -1,4 +1,5 @@
 // RUN: %target-run-simple-swift %swift-tensorflow-test-run-extra-options
+// RUN: %target-run-use-vjp-swift %swift-tensorflow-test-run-extra-options
 //
 // TODO(SR-9110): Make this pass in dynamic compilation mode.
 // %target-run-dynamic-compilation-swift
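For reference, a minimal Swift sketch of the symbol-name transformation that the TFLowerGraph.cpp change performs; `tfCompatibleFuncName` and the sample mangled names are illustrative, and the ordering (strip "AD__" before "$") follows the C++ above.

```swift
// Mirrors getTFCompatibleFuncName's prefix stripping: drop a leading
// "AD__" (the prefix on differentiation-generated functions), then a
// leading "$" from the Swift mangling.
func tfCompatibleFuncName(_ name: String) -> String {
  var result = name
  if result.hasPrefix("AD__") {
    result = String(result.dropFirst(4))
  }
  if result.hasPrefix("$") {
    result = String(result.dropFirst())
  }
  return result
}

print(tfCompatibleFuncName("AD__$s4main6squareyS2fF__vjp"))
// prints "s4main6squareyS2fF__vjp"
print(tfCompatibleFuncName("$s4main6squareyS2fF"))
// prints "s4main6squareyS2fF"
```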