From dfd0ce981635a3286d0adaa7e4ea17e754c63c1c Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 14 Jun 2019 13:49:16 -0700 Subject: [PATCH 1/3] [AutoDiff] Support differentiation of `switch_enum`. Handle `switch_enum` terminator during VJP and adjoint generation. Necessary step for differentiating `for-in` loops, which contain optional iterator `next()` values. Diagnose differentiation of active enum values, which requires further adjoint generation support. --- include/swift/AST/DiagnosticsSIL.def | 2 + .../Mandatory/Differentiation.cpp | 90 ++++++++++++++++--- test/AutoDiff/autodiff_diagnostics.swift | 2 +- test/AutoDiff/control_flow.swift | 88 +++++++++++++++++- test/AutoDiff/control_flow_diagnostics.swift | 75 ++++++++++++++-- test/AutoDiff/control_flow_sil.swift | 28 +++--- 6 files changed, 252 insertions(+), 33 deletions(-) diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 4afed5c78da76..3677f48bd1231 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -424,6 +424,8 @@ WARNING(autodiff_nonvaried_result_fixit,none, "result does not depend on differentiation arguments and will always " "have a zero derivative; do you want to add '.withoutDerivative()'?", ()) +NOTE(autodiff_enums_unsupported,none, + "differentiating enum values is not yet supported", ()) NOTE(autodiff_global_let_closure_not_differentiable,none, "global constant closure is not differentiable", ()) NOTE(autodiff_cannot_differentiate_global_var_closures,none, diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 5000136f40f92..9171cdd76a6ab 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, setVaried(cbi->getFalseBB()->getArgument(opIdx), i); } } + // Handle `switch_enum`. + else if (auto *sei = dyn_cast(&inst)) { + if (isVaried(sei->getOperand(), i)) { + for (auto *succBB : sei->getSuccessorBlocks()) + for (auto *arg : succBB->getArguments()) + setVaried(arg, i); + // Default block cannot have arguments. + } + } // Handle everything else. else { for (auto &op : inst.getAllOperands()) @@ -1767,8 +1776,9 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, // Diagnose unsupported branching terminators. for (auto &bb : *original) { auto *term = bb.getTerminator(); - // Supported terminators are: `br`, `cond_br`. - if (isa(term) || isa(term)) + // Supported terminators are: `br`, `cond_br`, `switch_enum`. + if (isa(term) || isa(term) || + isa(term)) continue; // If terminator is an unsupported branching terminator, emit an error. if (term->isBranch()) { @@ -3134,6 +3144,56 @@ class VJPEmitter final getOpBasicBlock(cbi->getFalseBB()), falseArgs); } + void visitSwitchEnumInst(SwitchEnumInst *sei) { + // Build pullback struct value for original block. + auto *origBB = sei->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(sei); + + // Creates a trampoline block for given original successor block. The + // trampoline block has the same arguments as the VJP successor block but + // drops the last predecessor enum argument. The generated `switch_enum` + // instruction branches to the trampoline block, and the trampoline block + // constructs a predecessor enum value and branches to the VJP successor + // block. + auto createTrampolineBasicBlock = + [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { + auto *vjpSuccBB = getOpBasicBlock(origSuccBB); + // Create the trampoline block. + auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); + for (auto *arg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument(arg->getType(), + arg->getOwnershipKind()); + // Build predecessor enum value for successor block and branch to it. + SILBuilder trampolineBuilder(trampolineBB); + auto *succEnumVal = buildPredecessorEnumValue( + trampolineBuilder, origBB, origSuccBB, pbStructVal); + SmallVector forwardedArguments( + trampolineBB->getArguments().begin(), + trampolineBB->getArguments().end()); + forwardedArguments.push_back(succEnumVal); + trampolineBuilder.createBranch( + sei->getLoc(), vjpSuccBB, forwardedArguments); + return trampolineBB; + }; + + // Create trampoline successor basic blocks. + SmallVector, 4> caseBBs; + for (unsigned i : range(sei->getNumCases())) { + auto caseBB = sei->getCase(i); + auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); + caseBBs.push_back({caseBB.first, trampolineBB}); + } + // Create trampoline default basic block. + SILBasicBlock *newDefaultBB = nullptr; + if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) + newDefaultBB = createTrampolineBasicBlock(defaultBB); + + // Create a new `switch_enum` instruction. + getBuilder().createSwitchEnum( + sei->getLoc(), getOpValue(sei->getOperand()), + newDefaultBB, caseBBs); + } + // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { @@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor { auto addActiveValue = [&](SILValue v) { if (visited.count(v)) return; + // Diagnose active enum values. Differentiation of enum values is not + // yet supported; requires special adjoint handling. + if (v->getType().getEnumOrBoundGenericEnum()) { + getContext().emitNondifferentiabilityError( + v, getInvoker(), diag::autodiff_enums_unsupported); + errorOccurred = true; + } // Skip address projections. // Address projections do not need their own adjoint buffers; they // become projections into their adjoint base buffer. @@ -4175,8 +4242,12 @@ class AdjointEmitter final : public SILInstructionVisitor { if (getActivityInfo().isActive(result, getIndices())) addActiveValue(result); } + if (errorOccurred) + break; domOrder.pushChildren(bb); } + if (errorOccurred) + return true; // Create adjoint blocks and arguments, visiting original blocks in // post-order. @@ -4196,7 +4267,10 @@ class AdjointEmitter final : public SILInstructionVisitor { adjointPullbackStructArguments[origBB] = lastArg; continue; } - + // Add a pullback struct argument. + auto *pbStructArg = adjointBB->createPhiArgument( + pbStructLoweredType, ValueOwnershipKind::Guaranteed); + adjointPullbackStructArguments[origBB] = pbStructArg; // Get all active values in the original block. // If the original block has no active values, continue. auto &bbActiveValues = activeValues[origBB]; @@ -4222,10 +4296,6 @@ class AdjointEmitter final : public SILInstructionVisitor { activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg; } } - // Add a pullback struct argument. - auto *pbStructArg = adjointBB->createPhiArgument( - pbStructLoweredType, ValueOwnershipKind::Guaranteed); - adjointPullbackStructArguments[origBB] = pbStructArg; // - Create adjoint trampoline blocks for each successor block of the // original block. Adjoint trampoline blocks only have a pullback // struct argument, and branch from the adjoint successor block to the @@ -4373,6 +4443,8 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(adjointSuccBB && adjointSuccBB->getNumArguments() == 1); SILBuilder adjointTrampolineBBBuilder(adjointSuccBB); SmallVector trampolineArguments; + // Propagate pullback struct argument. + trampolineArguments.push_back(adjointSuccBB->getArguments().front()); // Propagate adjoint values/buffers of active values/buffers to // predecessor blocks. auto &predBBActiveValues = activeValues[predBB]; @@ -4411,8 +4483,6 @@ class AdjointEmitter final : public SILInstructionVisitor { adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization); } } - // Propagate pullback struct argument. - trampolineArguments.push_back(adjointSuccBB->getArguments().front()); // Branch from adjoint trampoline block to adjoint block. adjointTrampolineBBBuilder.createBranch( adjLoc, adjointBB, trampolineArguments); @@ -4421,7 +4491,7 @@ class AdjointEmitter final : public SILInstructionVisitor { getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb); adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB}); } - // Emit clenaups for all block-local adjoint values. + // Emit cleanups for all block-local adjoint values. for (auto adjVal : blockLocalAdjointValues) emitCleanupForAdjointValue(adjVal); blockLocalAdjointValues.clear(); diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 34fb22cb1f6f5..2af3b6b3f14a9 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -75,7 +75,7 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { func uses_optionals(_ x: Float) -> Float { var maybe: Float? = 10 maybe = x - // expected-note @+1 {{differentiating control flow is not yet supported}} + // expected-note @+1 {{differentiating enum values is not yet supported}} return maybe! } diff --git a/test/AutoDiff/control_flow.swift b/test/AutoDiff/control_flow.swift index 70f109fceb548..728f39e5898e4 100644 --- a/test/AutoDiff/control_flow.swift +++ b/test/AutoDiff/control_flow.swift @@ -258,7 +258,7 @@ ControlFlowTests.test("Conditionals") { } expectEqual((0, 10), gradient(at: 4, 5, in: guard3)) expectCrash { - gradient(at: -3, -2, in: guard3) + _ = gradient(at: -3, -2, in: guard3) } func cond_empty(_ x: Float) -> Float { @@ -424,4 +424,90 @@ ControlFlowTests.test("Recursion") { expectEqual(1, gradient(at: 100, in: { x in product(x, count: 1) })) } +ControlFlowTests.test("Enums") { + enum Enum { + case a(Float) + case b(Float, Float) + + func enum_notactive1(_ x: Float) -> Float { + switch self { + case let .a(a): return x * a + case let .b(b1, b2): return x * b1 * b2 + } + } + } + + func enum_notactive1(_ e: Enum, _ x: Float) -> Float { + switch e { + case let .a(a): return x * a + case let .b(b1, b2): return x * b1 * b2 + } + } + expectEqual(10, gradient(at: 2, in: { x in enum_notactive1(.a(10), x) })) + expectEqual(10, gradient(at: 2, in: { x in Enum.a(10).enum_notactive1(x) })) + expectEqual(20, gradient(at: 2, in: { x in enum_notactive1(.b(4, 5), x) })) + expectEqual(20, gradient(at: 2, in: { x in Enum.b(4, 5).enum_notactive1(x) })) + + func enum_notactive2(_ e: Enum, _ x: Float) -> Float { + if x > 0 { + switch e { + case .a: return x * x * x + case .b: return -x + } + } else if case .b = e { + return -x + } + return x * x + } + expectEqual(12, gradient(at: 2, in: { x in enum_notactive2(.a(10), x) })) + expectEqual(-1, gradient(at: 2, in: { x in enum_notactive2(.b(4, 5), x) })) + + func optional_notactive1(_ optional: Float?, _ x: Float) -> Float { + if let y = optional { + return x * y + } + return x + x + } + expectEqual(2, gradient(at: 2, in: { x in optional_notactive1(nil, x) })) + expectEqual(10, gradient(at: 2, in: { x in optional_notactive1(10, x) })) + + struct Dense : Differentiable { + var w1: Float + @noDerivative var w2: Float? + + @differentiable + func callAsFunction(_ input: Float) -> Float { + if let w2 = w2 { + return input * w1 * w2 + } + return input * w1 + } + } + expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20), + Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) })) + expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4), + Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) })) + + indirect enum Indirect { + case e(Float, Enum) + case indirect(Indirect) + } + + func enum_indirect_notactive1(_ indirect: Indirect, _ x: Float) -> Float { + switch indirect { + case let .e(f, e): + switch e { + case .a: return x * f * enum_notactive1(e, x) + case .b: return x * f * enum_notactive1(e, x) + } + case let .indirect(ind): return enum_indirect_notactive1(ind, x) + } + } + do { + let ind: Indirect = .e(10, .a(3)) + expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(ind, x) })) + expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(.indirect(ind), x) })) + } +} + runAllTests() diff --git a/test/AutoDiff/control_flow_diagnostics.swift b/test/AutoDiff/control_flow_diagnostics.swift index 625ad9340422e..0e0f5e10911b7 100644 --- a/test/AutoDiff/control_flow_diagnostics.swift +++ b/test/AutoDiff/control_flow_diagnostics.swift @@ -1,6 +1,6 @@ // RUN: %target-swift-frontend -emit-sil -verify %s -// Test supported `br` and `cond_br` terminators. +// Test supported `br`, `cond_br`, and `switch_enum` terminators. @differentiable func branch(_ x: Float) -> Float { @@ -12,21 +12,82 @@ func branch(_ x: Float) -> Float { return x } -// Test currently unsupported `switch_enum` terminator. - enum Enum { case a(Float) case b(Float) } +@differentiable +func enum_nonactive1(_ e: Enum, _ x: Float) -> Float { + switch e { + case .a: return x + case .b: return x + } +} + +@differentiable +func enum_nonactive2(_ e: Enum, _ x: Float) -> Float { + switch e { + case let .a(a): return x + a + case let .b(b): return x + b + } +} + +// Test unsupported differentiation of active enum values. + // expected-error @+1 {{function is not differentiable}} @differentiable // expected-note @+1 {{when differentiating this function definition}} -func switch_enum(_ e: Enum, _ x: Float) -> Float { - // expected-note @+1 {{differentiating control flow is not yet supported}} +func enum_active(_ x: Float) -> Float { + // expected-note @+1 {{differentiating enum values is not yet supported}} + let e: Enum + if x > 0 { + e = .a(x) + } else { + e = .b(x) + } switch e { - case let .a(a): return a - case let .b(b): return b + case let .a(a): return x + a + case let .b(b): return x + b + } +} + +enum Tree : Differentiable & AdditiveArithmetic { + case leaf(Float) + case branch(Float, Float) + + typealias TangentVector = Self + typealias AllDifferentiableVariables = Self + static var zero: Self { .leaf(0) } + + // expected-error @+1 {{function is not differentiable}} + @differentiable + // expected-note @+2 {{when differentiating this function definition}} + // expected-note @+1 {{differentiating enum values is not yet supported}} + static func +(_ lhs: Self, _ rhs: Self) -> Self { + switch (lhs, rhs) { + case let (.leaf(x), .leaf(y)): + return .leaf(x + y) + case let (.branch(x1, x2), .branch(y1, y2)): + return .branch(x1 + x2, y1 + y2) + default: + fatalError() + } + } + + // expected-error @+1 {{function is not differentiable}} + @differentiable + // expected-note @+2 {{when differentiating this function definition}} + // expected-note @+1 {{differentiating enum values is not yet supported}} + static func -(_ lhs: Self, _ rhs: Self) -> Self { + switch (lhs, rhs) { + case let (.leaf(x), .leaf(y)): + return .leaf(x - y) + case let (.branch(x1, x2), .branch(y1, y2)): + return .branch(x1 - x2, y1 - y2) + default: + fatalError() + } } } diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift index 65b17c3ec1eae..44f415500281b 100644 --- a/test/AutoDiff/control_flow_sil.swift +++ b/test/AutoDiff/control_flow_sil.swift @@ -72,9 +72,9 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1 // CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0): -// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0) +// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float) -// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0): +// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): // CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]] // CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float) // CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]] @@ -87,9 +87,9 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0): -// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0) +// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float) -// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0): +// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): // CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]] // CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float) // CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]] @@ -101,12 +101,12 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0) +// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) // CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0) +// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) -// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): +// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float): // CHECK-SIL: release_value {{%.*}} : $Float // CHECK-SIL: release_value {{%.*}} : $Float // CHECK-SIL: return {{%.*}} : $Float @@ -158,9 +158,9 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1 // CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0): -// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0) +// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float) -// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0): +// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): // CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]] // CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float) // CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float @@ -168,9 +168,9 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0): -// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0) +// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float) -// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0): +// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): // CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]] // CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float) // CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float @@ -178,10 +178,10 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) +// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) // CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) +// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) -// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): +// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float): // CHECK-SIL: return {{%.*}} : $Float From 7459d4ec1c0a2af3c03ca7cfac78c666e14adce9 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 16 Jun 2019 15:38:33 -0700 Subject: [PATCH 2/3] Revert adjoint pullback struct argument reordering. --- .../Mandatory/Differentiation.cpp | 12 ++++---- test/AutoDiff/control_flow_sil.swift | 28 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 9171cdd76a6ab..2488e7ef94455 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -4267,10 +4267,6 @@ class AdjointEmitter final : public SILInstructionVisitor { adjointPullbackStructArguments[origBB] = lastArg; continue; } - // Add a pullback struct argument. - auto *pbStructArg = adjointBB->createPhiArgument( - pbStructLoweredType, ValueOwnershipKind::Guaranteed); - adjointPullbackStructArguments[origBB] = pbStructArg; // Get all active values in the original block. // If the original block has no active values, continue. auto &bbActiveValues = activeValues[origBB]; @@ -4296,6 +4292,10 @@ class AdjointEmitter final : public SILInstructionVisitor { activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg; } } + // Add a pullback struct argument. + auto *pbStructArg = adjointBB->createPhiArgument( + pbStructLoweredType, ValueOwnershipKind::Guaranteed); + adjointPullbackStructArguments[origBB] = pbStructArg; // - Create adjoint trampoline blocks for each successor block of the // original block. Adjoint trampoline blocks only have a pullback // struct argument, and branch from the adjoint successor block to the @@ -4443,8 +4443,6 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(adjointSuccBB && adjointSuccBB->getNumArguments() == 1); SILBuilder adjointTrampolineBBBuilder(adjointSuccBB); SmallVector trampolineArguments; - // Propagate pullback struct argument. - trampolineArguments.push_back(adjointSuccBB->getArguments().front()); // Propagate adjoint values/buffers of active values/buffers to // predecessor blocks. auto &predBBActiveValues = activeValues[predBB]; @@ -4483,6 +4481,8 @@ class AdjointEmitter final : public SILInstructionVisitor { adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization); } } + // Propagate pullback struct argument. + trampolineArguments.push_back(adjointSuccBB->getArguments().front()); // Branch from adjoint trampoline block to adjoint block. adjointTrampolineBBBuilder.createBranch( adjLoc, adjointBB, trampolineArguments); diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift index 44f415500281b..65b17c3ec1eae 100644 --- a/test/AutoDiff/control_flow_sil.swift +++ b/test/AutoDiff/control_flow_sil.swift @@ -72,9 +72,9 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1 // CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0): -// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float) +// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_bb1__PB__src_0_wrt_0) -// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): +// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_bb1__PB__src_0_wrt_0): // CHECK-SIL: [[BB1_PB:%.*]] = struct_extract [[BB1_PB_STRUCT]] // CHECK-SIL: [[BB1_ADJVALS:%.*]] = apply [[BB1_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float) // CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]] @@ -87,9 +87,9 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0): -// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}}: $Float) +// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}}: $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_bb2__PB__src_0_wrt_0) -// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): +// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_bb2__PB__src_0_wrt_0): // CHECK-SIL: [[BB2_PB:%.*]] = struct_extract [[BB2_PB_STRUCT]] // CHECK-SIL: [[BB2_ADJVALS:%.*]] = apply [[BB2_PB]]([[SEED]]) : $@callee_guaranteed (Float) -> (Float, Float) // CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]] @@ -101,12 +101,12 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) +// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) +// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_bb0__PB__src_0_wrt_0) -// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0, {{%.*}} : $Float): +// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_bb0__PB__src_0_wrt_0): // CHECK-SIL: release_value {{%.*}} : $Float // CHECK-SIL: release_value {{%.*}} : $Float // CHECK-SIL: return {{%.*}} : $Float @@ -158,9 +158,9 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb2!enumelt.1: bb3, case #_AD__cond_tuple_var_bb3__Pred__src_0_wrt_0.bb1!enumelt.1: bb1 // CHECK-SIL: bb1([[BB3_PRED1_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0): -// CHECK-SIL: br bb2([[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float) +// CHECK-SIL: br bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED1_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0) -// CHECK-SIL: bb2([[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): +// CHECK-SIL: bb2({{%.*}} : $Float, {{%.*}} : $Float, [[BB1_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb1__PB__src_0_wrt_0): // CHECK-SIL: [[BB1_PRED:%.*]] = struct_extract [[BB1_PB_STRUCT]] // CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float) // CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float @@ -168,9 +168,9 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: br bb5([[BB1_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb3([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0): -// CHECK-SIL: br bb4([[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float) +// CHECK-SIL: br bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB3_PRED2_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0) -// CHECK-SIL: bb4([[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0, {{%.*}} : $Float, {{%.*}} : $Float): +// CHECK-SIL: bb4({{%.*}} : $Float, {{%.*}} : $Float, [[BB2_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb2__PB__src_0_wrt_0): // CHECK-SIL: [[BB2_PRED:%.*]] = struct_extract [[BB2_PB_STRUCT]] // CHECK-SIL: copy_addr {{%.*}} to {{%.*}} : $*(Float, Float) // CHECK-SIL-NOT: copy_addr {{%.*}} to {{%.*}} : $*Float @@ -178,10 +178,10 @@ func cond_tuple_var(_ x: Float) -> Float { // CHECK-SIL: br bb6([[BB2_PB_STRUCT_DATA]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb5([[BB1_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7([[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) +// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB1_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) // CHECK-SIL: bb6([[BB2_PRED0_TRAMP_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): -// CHECK-SIL: br bb7([[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float) +// CHECK-SIL: br bb7({{%.*}} : $Float, [[BB2_PRED0_TRAMP_PB_STRUCT]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0) -// CHECK-SIL: bb7([[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0, {{%.*}} : $Float): +// CHECK-SIL: bb7({{%.*}} : $Float, [[BB0_PB_STRUCT:%.*]] : $_AD__cond_tuple_var_bb0__PB__src_0_wrt_0): // CHECK-SIL: return {{%.*}} : $Float From 3930b0d02553a9f16354243e6c5275f52e25d0e5 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 16 Jun 2019 15:58:01 -0700 Subject: [PATCH 3/3] Add `switch_enum` tests. --- .../DifferentiationUnittest.swift | 10 ++++ test/AutoDiff/control_flow.swift | 21 +++++--- test/AutoDiff/leakchecking.swift | 53 +++++++++++++++++++ 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index 596b92d7ed7c7..ee26d8e18fbb2 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -171,6 +171,16 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables, } } +extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, + T == T.AllDifferentiableVariables, T == T.TangentVector { + @usableFromInline + @differentiating(*) + internal static func _vjpMultiply(lhs: Self, rhs: Self) + -> (value: Self, pullback: (Self) -> (Self, Self)) { + return (lhs * rhs, { v in (v * rhs, v * lhs) }) + } +} + // Differential operators for `Tracked`. public extension Differentiable { @inlinable diff --git a/test/AutoDiff/control_flow.swift b/test/AutoDiff/control_flow.swift index 728f39e5898e4..d3754063f1fab 100644 --- a/test/AutoDiff/control_flow.swift +++ b/test/AutoDiff/control_flow.swift @@ -449,18 +449,27 @@ ControlFlowTests.test("Enums") { expectEqual(20, gradient(at: 2, in: { x in Enum.b(4, 5).enum_notactive1(x) })) func enum_notactive2(_ e: Enum, _ x: Float) -> Float { + var y = x if x > 0 { + var z = y + y switch e { - case .a: return x * x * x - case .b: return -x + case .a: z = z - y + case .b: y = y + x } + var w = y + if case .a = e { + w = w + z + } + return w } else if case .b = e { - return -x + return y + y } - return x * x + return x + y } - expectEqual(12, gradient(at: 2, in: { x in enum_notactive2(.a(10), x) })) - expectEqual(-1, gradient(at: 2, in: { x in enum_notactive2(.b(4, 5), x) })) + expectEqual((8, 2), valueWithGradient(at: 4, in: { x in enum_notactive2(.a(10), x) })) + expectEqual((20, 2), valueWithGradient(at: 10, in: { x in enum_notactive2(.b(4, 5), x) })) + expectEqual((-20, 2), valueWithGradient(at: -10, in: { x in enum_notactive2(.a(10), x) })) + expectEqual((-2674, 2), valueWithGradient(at: -1337, in: { x in enum_notactive2(.b(4, 5), x) })) func optional_notactive1(_ optional: Float?, _ x: Float) -> Float { if let y = optional { diff --git a/test/AutoDiff/leakchecking.swift b/test/AutoDiff/leakchecking.swift index 75f6c9a66f8bf..6cb791bd5d305 100644 --- a/test/AutoDiff/leakchecking.swift +++ b/test/AutoDiff/leakchecking.swift @@ -137,6 +137,59 @@ LeakCheckingTests.test("ControlFlow") { expectEqual((-2674, 2), Tracked(-1337).valueWithGradient(in: cond_nestedstruct_var)) } + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. + testWithLeakChecking(expectedLeakCount: 12) { + struct Dense : Differentiable { + var w1: Tracked + @noDerivative var w2: Tracked? + + func callAsFunction(_ input: Tracked) -> Tracked { + if let w2 = w2 { + return input * w1 * w2 + } + return input * w1 + } + } + expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20), + Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) })) + expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4), + Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) })) + } + + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. + testWithLeakChecking(expectedLeakCount: 48) { + enum Enum { + case a(Tracked) + case b(Tracked, Tracked) + } + func enum_notactive2(_ e: Enum, _ x: Tracked) -> Tracked { + var y = x + if x > 0 { + var z = y + y + switch e { + case .a: z = z - y + case .b: y = y + x + } + var w = y + if case .a = e { + w = w + z + } + return w + } else if case .b = e { + return y + y + } + return x + y + } + expectEqual((8, 2), Tracked(4).valueWithGradient(in: { x in enum_notactive2(.a(10), x) })) + expectEqual((20, 2), Tracked(10).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) })) + expectEqual((-20, 2), Tracked(-10).valueWithGradient(in: { x in enum_notactive2(.a(10), x) })) + expectEqual((-2674, 2), Tracked(-1337).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) })) + } + // FIXME: Fix control flow AD memory leaks. // See related FIXME comments in adjoint value/buffer propagation in // lib/SILOptimizer/Mandatory/Differentiation.cpp.