diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 23a2356271f10..7d4e0a368847b 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -47,6 +47,7 @@ #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/Passes.h" #include "swift/SILOptimizer/PassManager/Transforms.h" +#include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/LoopUtils.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/ADT/APSInt.h" @@ -3380,19 +3381,6 @@ class VJPEmitter final /// predecessor enum argument). SmallPtrSet remappedBasicBlocks; - /// A pair of a trampoline block phi argument and its corresponding - /// destination block phi argument. - struct TrampolinedArgumentPair { - SILPhiArgument *trampolineArgument; - SILPhiArgument *destinationArgument; - }; - /// An array that keeps track of all `@guaranteed` phi arguments in any - /// trampoline blocks we've added. Each of these arguments needs to have a - /// lifetime-ending use past its destination argument's lifetime-ending use, - /// so we keep track of these pairs of arguments and emit `end_borrow`s when - /// function cloning is finished. - SmallVector trampolinedGuaranteedPhiArguments; - bool errorOccurred = false; /// Mapping from original blocks to pullback values. Used to build pullback @@ -3771,18 +3759,9 @@ class VJPEmitter final auto *vjpSuccBB = getOpBasicBlock(origSuccBB); // Create the trampoline block. auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); - for (auto *destArg : vjpSuccBB->getArguments().drop_back()) { - auto *trampolineArg = trampolineBB->createPhiArgument( + for (auto *destArg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument( destArg->getType(), destArg->getOwnershipKind()); - // Each `@guaranteed` trampoline argument needs to have a - // lifetime-ending use past its destination argument's lifetime-ending - // uses, so we keep track of these pairs of arguments in - // `trampolinedGuaranteedPhiArguments` and emit `end_borrow`s when - // function cloning is finished. - if (trampolineArg->getOwnershipKind() == ValueOwnershipKind::Guaranteed) - trampolinedGuaranteedPhiArguments.push_back( - {trampolineArg, cast(destArg)}); - } // Build predecessor enum value for successor block and branch to it. SILBuilder trampolineBuilder(trampolineBB); auto *succEnumVal = buildPredecessorEnumValue( @@ -7956,21 +7935,12 @@ bool VJPEmitter::run() { if (errorOccurred) return true; - // Each `@guaranteed` trampoline argument needs to have a lifetime-ending use - // past its destination argument's lifetime-ending uses (aka. `end_borrow`). - // `trampolinedGuaranteedPhiArguments` tracks all `@guaranteed` trampoline - // arguments. We emit an `end_borrow` immediately past each destination - // argument's lifetime-ending uses. - for (auto &trampolinedArgPair : trampolinedGuaranteedPhiArguments) { - for (auto *destArgUse : trampolinedArgPair.destinationArgument->getUses()) { - if (auto *lifetimeEnd = dyn_cast(destArgUse->getUser())) { - getBuilder().setInsertionPoint(lifetimeEnd->getParentBlock(), - std::next(lifetimeEnd->getIterator())); - getBuilder().emitEndBorrowOperation( - lifetimeEnd->getLoc(), trampolinedArgPair.trampolineArgument); - } - } - } + // Merge VJP basic blocks. This is significant for control flow + // differentiation: trampoline destination bbs are merged into trampoline bbs. + // NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline + // bb arguments have a lifetime-ending `end_borrow` use, and is robust when + // `-enable-strip-ownership-after-serialization` is true. + mergeBasicBlocks(vjp); // Generate pullback code. PullbackEmitter PullbackEmitter(*this); diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift index 3f18db2ccb379..e3edc17211a39 100644 --- a/test/AutoDiff/control_flow_sil.swift +++ b/test/AutoDiff/control_flow_sil.swift @@ -47,27 +47,21 @@ func cond(_ x: Float) -> Float { // CHECK-SIL-LABEL: sil hidden [ossa] @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { // CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float): // CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 () -// CHECK-SIL: cond_br {{%.*}}, bb1, bb3 +// CHECK-SIL: cond_br {{%.*}}, bb1, bb2 // CHECK-SIL: bb1: // CHECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] -// CHECK-SIL: br bb2([[BB1_PRED]] : $_AD__cond_bb1__Pred__src_0_wrt_0) - -// CHECK-SIL: bb2([[BB1_PRED_ARG:%.*]] : $_AD__cond_bb1__Pred__src_0_wrt_0) // CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0 // CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]] -// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0) +// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0) -// CHECK-SIL: bb3: +// CHECK-SIL: bb2: // CHECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] -// CHECK-SIL: br bb4([[BB2_PRED]] : $_AD__cond_bb2__Pred__src_0_wrt_0) - -// CHECK-SIL: bb4([[BB2_PRED_ARG:%.*]] : $_AD__cond_bb2__Pred__src_0_wrt_0) // CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0 // CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]] -// CHECK-SIL: br bb5({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) +// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) -// CHECK-SIL: bb5([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0) +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0) // CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0 // CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__cond__pullback_src_0_wrt_0 // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]]) @@ -145,6 +139,70 @@ func loop_generic(_ x: T) -> T { return result } +// Test `switch_enum`. + +enum Enum { + case a(Float) + case b(Float, Float) +} +@differentiable +@_silgen_name("enum_notactive") +func enum_notactive(_ e: Enum, _ x: Float) -> Float { + switch e { + case let .a(a): return x * a + case let .b(b1, b2): return x * b1 * b2 + } +} + +// ECK-SIL-LABEL: sil hidden [ossa] @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// ECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float): +// ECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__cond_bb0__PB__src_0_wrt_0 () +// ECK-SIL: cond_br {{%.*}}, bb1, bb2 + +// ECK-SIL: bb1: +// ECK-SIL: [[BB1_PRED:%.*]] = enum $_AD__cond_bb1__Pred__src_0_wrt_0, #_AD__cond_bb1__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] +// ECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__cond_bb1__PB__src_0_wrt_0 +// ECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1, [[BB1_PB_STRUCT]] +// ECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__cond_bb3__Pred__src_0_wrt_0) + +// ECK-SIL: bb2: +// ECK-SIL: [[BB2_PRED:%.*]] = enum $_AD__cond_bb2__Pred__src_0_wrt_0, #_AD__cond_bb2__Pred__src_0_wrt_0.bb0!enumelt.1, [[BB0_PB_STRUCT]] +// ECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__cond_bb2__PB__src_0_wrt_0 +// ECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1, [[BB2_PB_STRUCT]] +// ECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) + +// ECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0) +// ECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__cond_bb3__PB__src_0_wrt_0 +// ECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__cond__pullback_src_0_wrt_0 +// ECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]]) +// ECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) +// ECK-SIL: return [[VJP_RESULT]] + +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__enum_notactive__vjp_src_0_wrt_1 : $@convention(thin) (Enum, Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL: bb0([[ENUM_ARG:%.*]] : $Enum, [[X_ARG:%.*]] : $Float): +// CHECK-SIL: [[BB0_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb0__PB__src_0_wrt_1 () +// CHECK-SIL: switch_enum [[ENUM_ARG]] : $Enum, case #Enum.a!enumelt.1: bb1, case #Enum.b!enumelt.1: bb2 + +// CHECK-SIL: bb1([[ENUM_A:%.*]] : $Float): +// CHECK-SIL: [[BB1_PRED_PRED0:%.*]] = enum $_AD__enum_notactive_bb1__Pred__src_0_wrt_1, #_AD__enum_notactive_bb1__Pred__src_0_wrt_1.bb0!enumelt.1, %4 : $_AD__enum_notactive_bb0__PB__src_0_wrt_1 +// CHECK-SIL: [[BB1_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb1__PB__src_0_wrt_1 ({{.*}}) +// CHECK-SIL: [[BB3_PRED_PRED1:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb1!enumelt.1, [[BB1_PB_STRUCT]] : $_AD__enum_notactive_bb1__PB__src_0_wrt_1 +// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED1]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) + +// CHECK-SIL: bb2([[ENUM_B:%.*]] : $(Float, Float)): +// CHECK-SIL: [[BB2_PRED_PRED0:%.*]] = enum $_AD__enum_notactive_bb2__Pred__src_0_wrt_1, #_AD__enum_notactive_bb2__Pred__src_0_wrt_1.bb0!enumelt.1, %4 : $_AD__enum_notactive_bb0__PB__src_0_wrt_1 +// CHECK-SIL: [[BB2_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb2__PB__src_0_wrt_1 ({{.*}}) +// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt.1, [[BB2_PB_STRUCT]] : $_AD__enum_notactive_bb2__PB__src_0_wrt_1 +// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) + +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) +// CHECK-SIL: [[BB3_PB_STRUCT:%.*]] = struct $_AD__enum_notactive_bb3__PB__src_0_wrt_1 +// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @AD__enum_notactive__pullback_src_0_wrt_1 +// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PB_STRUCT]]) +// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) +// CHECK-SIL: return [[VJP_RESULT]] +// CHECK-SIL: } + // Test control flow + tuple buffer. // Verify that pullback buffers are not allocated for address projections.