30 changes: 28 additions & 2 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
@@ -4466,20 +4466,37 @@ void DifferentiationTask::createVJP() {
SILDebugScope(original->getLocation(), vjp));
attr->setVJPName(vjp->getName());

// Work around a bad interaction between VJPs, TFDeabstraction, and SIL
// optimizations.
//
// The bad interaction is: TFDeabstraction cannot inline the functions that
// the VJP partially applies. When run in "-Onone" mode, this is fine (but
// inefficient), because we just generate sends/recvs to handle it. But in
// "-O" mode, later SIL optimization passes:
// - fold the partial applications and inline them, or
// - specialize the partial_apply callees.
// In either case, the inlined/specialized bodies haven't been deabstracted,
// so partitioning crashes when it sees them.
//
// TODO: Fix this problem in a way that allows inlining/specialization of
// VJPs. Teaching TFDeabstraction to fold partial applications might work.
vjp->setInlineStrategy(NoInline);
vjp->addSemanticsAttr("optimize.sil.specialize.generic.never");

LLVM_DEBUG(llvm::dbgs() << " vjp type: "
<< vjp->getLoweredFunctionType() << "\n");

// We'll use these conventions frequently.
auto originalConv = original->getConventions();
auto primalConv = primal->getConventions();
auto vjpConv = vjp->getConventions();

// Keep track of some stack allocation to clean up.
SmallVector<SILValue, 2> stackAllocsToCleanUp;

// Validate signatures.
#ifndef NDEBUG
auto adjointConv = adjoint->getConventions();
auto vjpConv = vjp->getConventions();

unsigned numOriginalParameters = originalConv.getNumParameters();
unsigned numOriginalResults = originalConv.getResults().size();
@@ -4538,7 +4555,16 @@ void DifferentiationTask::createVJP() {

// Create the entry block with indirect results and parameters.
auto *entry = vjp->createBasicBlock();
createEntryArguments(vjp);
unsigned argumentIndex = 0;
for (auto indResultTy : vjpConv.getIndirectSILResultTypes())
entry->createFunctionArgument(
vjp->mapTypeIntoContext(indResultTy).getAddressType(),
original->getEntryBlock()->getArgument(argumentIndex++)->getDecl());
for (auto paramTy : vjpConv.getParameterSILTypes())
entry->createFunctionArgument(
vjp->mapTypeIntoContext(paramTy),
original->getEntryBlock()->getArgument(argumentIndex++)->getDecl());

SILBuilder builder(entry);
auto loc = vjp->getLocation();

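For orientation, here is a minimal hand-written sketch of the VJP convention this pass builds: a VJP computes the original result and returns it together with a pullback closure that maps a cotangent of the result to cotangents of the inputs. The closure the compiler forms is a partial application (of the adjoint over the primal's checkpoints), which is exactly the partial application the no-inline/no-specialize workaround above protects from later SIL passes. The function name and the transposed() calls are illustrative assumptions, not the generated code:

    // A hand-written analogue of a compiler-generated VJP for matmul.
    // Illustrative only; names and APIs are assumptions, not generated code.
    func matmulVJP<Scalar : Numeric>(
      _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
    ) -> (value: Tensor<Scalar>,
          pullback: (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
      let value = matmul(left, right)
      // Pullback for C = A * B: dA = seed * B^T, dB = A^T * seed.
      // The closure captures `left` and `right`, i.e. it is a partial
      // application like the ones the VJP emits.
      return (value, { seed in
        (matmul(seed, right.transposed()), matmul(left.transposed(), seed))
      })
    }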
2 changes: 2 additions & 0 deletions lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
@@ -2694,6 +2694,8 @@ bool TFGraphLowering::serializeGraphProtoBuf(ASTContext &ctx,
/// repl)
std::string getTFCompatibleFuncName(SILFunction *fn) {
auto fnName = fn->getName();
if (fnName.startswith("AD__"))
fnName = fnName.substr(4);
if (fnName.startswith("$"))
fnName = fnName.substr(1);

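The two added lines extend the sanitizer for TensorFlow graph function names: differentiation-generated functions carry an "AD__" prefix, and mangled Swift symbols a leading "$", neither of which is a legal TensorFlow node name. A rough Swift rendering of the C++ logic above, for illustration only:

    // Swift sketch of getTFCompatibleFuncName's prefix stripping
    // (the real code is C++ in TFLowerGraph.cpp).
    func tfCompatibleFuncName(_ name: String) -> String {
      var result = name
      if result.hasPrefix("AD__") {   // drop the autodiff-generated prefix
        result.removeFirst(4)
      }
      if result.hasPrefix("$") {      // drop the Swift mangling prefix
        result.removeFirst()
      }
      return result
    }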
5 changes: 4 additions & 1 deletion stdlib/public/TensorFlow/Ops.swift
@@ -275,7 +275,10 @@ public extension Tensor where Scalar : Numeric {
public func matmul<Scalar : Numeric>(
_ left: Tensor<Scalar>, _ right: Tensor<Scalar>
) -> Tensor<Scalar> {
return Raw.matMul(left, right)
// Default arguments specified explicitly to avoid "external declarations of
// SILFunctions with shared visibility is not allowed" SILVerifier error in
// "tests/AutoDiff/tensor_autodiff_runtime.swift".
return Raw.matMul(left, right, transposeA: false, transposeB: false)
}

infix operator • : MultiplicationPrecedence
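A small usage sketch of the affected API, assuming the Tensor(shape:scalars:) initializer available in this tree; both calls route through Raw.matMul with the transpose flags now spelled out:

    let a = Tensor<Float>(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6])
    let b = Tensor<Float>(shape: [3, 2], scalars: [1, 0, 0, 1, 1, 1])
    let c = matmul(a, b)  // 2x2 product
    let d = a • b         // the operator form computes the same product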
1 change: 1 addition & 0 deletions test/TensorFlow/tensor_autodiff.swift
@@ -1,4 +1,5 @@
// RUN: %target-swift-frontend -O -emit-sil %s | %FileCheck %s
// RUN: %target-swift-frontend -Xllvm -differentiation-use-vjp -O -emit-sil %s | %FileCheck %s

import TensorFlow

1 change: 1 addition & 0 deletions test/TensorFlowRuntime/tensor_autodiff_runtime.swift
@@ -1,4 +1,5 @@
// RUN: %target-run-simple-swift %swift-tensorflow-test-run-extra-options
// RUN: %target-run-use-vjp-swift %swift-tensorflow-test-run-extra-options
//
// TODO(SR-9110): Make this pass in dynamic compilation mode.
// %target-run-dynamic-compilation-swift