Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 9 additions & 39 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/CFGOptUtils.h"
#include "swift/SILOptimizer/Utils/LoopUtils.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h"
Expand Down Expand Up @@ -3380,19 +3381,6 @@ class VJPEmitter final
/// predecessor enum argument).
SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;

/// A pair of a trampoline block phi argument and its corresponding
/// destination block phi argument.
struct TrampolinedArgumentPair {
SILPhiArgument *trampolineArgument;
SILPhiArgument *destinationArgument;
};
/// An array that keeps track of all `@guaranteed` phi arguments in any
/// trampoline blocks we've added. Each of these arguments needs to have a
/// lifetime-ending use past its destination argument's lifetime-ending use,
/// so we keep track of these pairs of arguments and emit `end_borrow`s when
/// function cloning is finished.
SmallVector<TrampolinedArgumentPair, 8> trampolinedGuaranteedPhiArguments;

bool errorOccurred = false;

/// Mapping from original blocks to pullback values. Used to build pullback
Expand Down Expand Up @@ -3771,18 +3759,9 @@ class VJPEmitter final
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
// Create the trampoline block.
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
for (auto *destArg : vjpSuccBB->getArguments().drop_back()) {
auto *trampolineArg = trampolineBB->createPhiArgument(
for (auto *destArg : vjpSuccBB->getArguments().drop_back())
trampolineBB->createPhiArgument(
destArg->getType(), destArg->getOwnershipKind());
// Each `@guaranteed` trampoline argument needs to have a
// lifetime-ending use past its destination argument's lifetime-ending
// uses, so we keep track of these pairs of arguments in
// `trampolinedGuaranteedPhiArguments` and emit `end_borrow`s when
// function cloning is finished.
if (trampolineArg->getOwnershipKind() == ValueOwnershipKind::Guaranteed)
trampolinedGuaranteedPhiArguments.push_back(
{trampolineArg, cast<SILPhiArgument>(destArg)});
}
// Build predecessor enum value for successor block and branch to it.
SILBuilder trampolineBuilder(trampolineBB);
auto *succEnumVal = buildPredecessorEnumValue(
Expand Down Expand Up @@ -7956,21 +7935,12 @@ bool VJPEmitter::run() {
if (errorOccurred)
return true;

// Each `@guaranteed` trampoline argument needs to have a lifetime-ending use
// past its destination argument's lifetime-ending uses (aka. `end_borrow`).
// `trampolinedGuaranteedPhiArguments` tracks all `@guaranteed` trampoline
// arguments. We emit an `end_borrow` immediately past each destination
// argument's lifetime-ending uses.
for (auto &trampolinedArgPair : trampolinedGuaranteedPhiArguments) {
for (auto *destArgUse : trampolinedArgPair.destinationArgument->getUses()) {
if (auto *lifetimeEnd = dyn_cast<EndBorrowInst>(destArgUse->getUser())) {
getBuilder().setInsertionPoint(lifetimeEnd->getParentBlock(),
std::next(lifetimeEnd->getIterator()));
getBuilder().emitEndBorrowOperation(
lifetimeEnd->getLoc(), trampolinedArgPair.trampolineArgument);
}
}
}
// Merge VJP basic blocks. This is significant for control flow
// differentiation: trampoline destination bbs are merged into trampoline bbs.
// NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline
// bb arguments have a lifetime-ending `end_borrow` use, and is robust when
// `-enable-strip-ownership-after-serialization` is true.
mergeBasicBlocks(vjp);

// Generate pullback code.
PullbackEmitter PullbackEmitter(*this);
Expand Down
80 changes: 69 additions & 11 deletions test/AutoDiff/control_flow_sil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,21 @@ func cond(_ x: Float) -> Float {
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
// CHECK-SIL: cond_br {{%.*}}, bb1, bb3
// CHECK-SIL: cond_br {{%.*}}, bb1, bb2

// CHECK-SIL: bb1:
// CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// CHECK-SIL: br bb2([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0)

// CHECK-SIL: bb2([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]]
// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb3:
// CHECK-SIL: bb2:
// CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// CHECK-SIL: br bb4([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0)

// CHECK-SIL: bb4([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]]
// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb5([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__cond__pullback_src_0_wrt_0
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
Expand Down Expand Up @@ -145,6 +139,70 @@ func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
return result
}

// Test `switch_enum`.

enum Enum {
case a(Float)
case b(Float, Float)
}
@differentiable
@_silgen_name("enum_notactive")
func enum_notactive(_ e: Enum, _ x: Float) -> Float {
switch e {
case let .a(a): return x * a
case let .b(b1, b2): return x * b1 * b2
}
}

// ECK-SIL-LABEL: sil hidden [ossa] @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// ECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
// ECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 ()
// ECK-SIL: cond_br {{%.*}}, bb1, bb2

// ECK-SIL: bb1:
// ECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// ECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0
// ECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]]
// ECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// ECK-SIL: bb2:
// ECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]]
// ECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0
// ECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]]
// ECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// ECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
// ECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0
// ECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__cond__pullback_src_0_wrt_0
// ECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
// ECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// ECK-SIL: return [[VJP_RESULT]]

// CHECK-SIL-LABEL: sil hidden [ossa] @AD__enum_notactive__vjp_src_0_wrt_1 : $@convention(thin) (Enum, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[ENUM_ARG:%.*]] : $Enum, [[X_ARG:%.*]] : $Float):
// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb0__PB__src_0_wrt_1 ()
// CHECK-SIL: switch_enum [[ENUM_ARG]] : $Enum, case #Enum.a!enumelt.1: bb1, case #Enum.b!enumelt.1: bb2

// CHECK-SIL: bb1([[ENUM_A:%.*]] : $Float):
// CHECK-SIL: [[BB1_PRED_PRED0:%.*]] = enum $_AD__enum_notactive_bb1__Pred__src_0_wrt_1, #_AD__enum_notactive_bb1__Pred__src_0_wrt_1.bb0!enumelt.1, %4 : $_AD__enum_notactive_bb0__PB__src_0_wrt_1
// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb1__PB__src_0_wrt_1 ({{.*}})
// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb1!enumelt.1, [[BB1_PB_STRUCT]] : $_AD__enum_notactive_bb1__PB__src_0_wrt_1
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)

// CHECK-SIL: bb2([[ENUM_B:%.*]] : $(Float, Float)):
// CHECK-SIL: [[BB2_PRED_PRED0:%.*]] = enum $_AD__enum_notactive_bb2__Pred__src_0_wrt_1, #_AD__enum_notactive_bb2__Pred__src_0_wrt_1.bb0!enumelt.1, %4 : $_AD__enum_notactive_bb0__PB__src_0_wrt_1
// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb2__PB__src_0_wrt_1 ({{.*}})
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt.1, [[BB2_PB_STRUCT]] : $_AD__enum_notactive_bb2__PB__src_0_wrt_1
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb3__PB__src_0_wrt_1
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__enum_notactive__pullback_src_0_wrt_1
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]]
// CHECK-SIL: }

// Test control flow + tuple buffer.
// Verify that pullback buffers are not allocated for address projections.

Expand Down