diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 4afed5c78da76..3677f48bd1231 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -424,6 +424,8 @@ WARNING(autodiff_nonvaried_result_fixit,none, "result does not depend on differentiation arguments and will always " "have a zero derivative; do you want to add '.withoutDerivative()'?", ()) +NOTE(autodiff_enums_unsupported,none, + "differentiating enum values is not yet supported", ()) NOTE(autodiff_global_let_closure_not_differentiable,none, "global constant closure is not differentiable", ()) NOTE(autodiff_cannot_differentiate_global_var_closures,none, diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 5000136f40f92..2488e7ef94455 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, setVaried(cbi->getFalseBB()->getArgument(opIdx), i); } } + // Handle `switch_enum`. + else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) { + if (isVaried(sei->getOperand(), i)) { + for (auto *succBB : sei->getSuccessorBlocks()) + for (auto *arg : succBB->getArguments()) + setVaried(arg, i); + // Default block cannot have arguments. + } + } // Handle everything else. else { for (auto &op : inst.getAllOperands()) @@ -1767,8 +1776,9 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, // Diagnose unsupported branching terminators. for (auto &bb : *original) { auto *term = bb.getTerminator(); - // Supported terminators are: `br`, `cond_br`. - if (isa<BranchInst>(term) || isa<CondBranchInst>(term)) + // Supported terminators are: `br`, `cond_br`, `switch_enum`. + if (isa<BranchInst>(term) || isa<CondBranchInst>(term) || + isa<SwitchEnumInst>(term)) continue; // If terminator is an unsupported branching terminator, emit an error. 
if (term->isBranch()) { @@ -3134,6 +3144,56 @@ class VJPEmitter final getOpBasicBlock(cbi->getFalseBB()), falseArgs); } + void visitSwitchEnumInst(SwitchEnumInst *sei) { + // Build pullback struct value for original block. + auto *origBB = sei->getParent(); + auto *pbStructVal = buildPullbackValueStructValue(sei); + + // Creates a trampoline block for given original successor block. The + // trampoline block has the same arguments as the VJP successor block but + // drops the last predecessor enum argument. The generated `switch_enum` + // instruction branches to the trampoline block, and the trampoline block + // constructs a predecessor enum value and branches to the VJP successor + // block. + auto createTrampolineBasicBlock = + [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { + auto *vjpSuccBB = getOpBasicBlock(origSuccBB); + // Create the trampoline block. + auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); + for (auto *arg : vjpSuccBB->getArguments().drop_back()) + trampolineBB->createPhiArgument(arg->getType(), + arg->getOwnershipKind()); + // Build predecessor enum value for successor block and branch to it. + SILBuilder trampolineBuilder(trampolineBB); + auto *succEnumVal = buildPredecessorEnumValue( + trampolineBuilder, origBB, origSuccBB, pbStructVal); + SmallVector<SILValue, 8> forwardedArguments( + trampolineBB->getArguments().begin(), + trampolineBB->getArguments().end()); + forwardedArguments.push_back(succEnumVal); + trampolineBuilder.createBranch( + sei->getLoc(), vjpSuccBB, forwardedArguments); + return trampolineBB; + }; + + // Create trampoline successor basic blocks. + SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs; + for (unsigned i : range(sei->getNumCases())) { + auto caseBB = sei->getCase(i); + auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); + caseBBs.push_back({caseBB.first, trampolineBB}); + } + // Create trampoline default basic block. 
+ SILBasicBlock *newDefaultBB = nullptr; + if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) + newDefaultBB = createTrampolineBasicBlock(defaultBB); + + // Create a new `switch_enum` instruction. + getBuilder().createSwitchEnum( + sei->getLoc(), getOpValue(sei->getOperand()), + newDefaultBB, caseBBs); + } + // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { @@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { auto addActiveValue = [&](SILValue v) { if (visited.count(v)) return; + // Diagnose active enum values. Differentiation of enum values is not + // yet supported; requires special adjoint handling. + if (v->getType().getEnumOrBoundGenericEnum()) { + getContext().emitNondifferentiabilityError( + v, getInvoker(), diag::autodiff_enums_unsupported); + errorOccurred = true; + } // Skip address projections. // Address projections do not need their own adjoint buffers; they // become projections into their adjoint base buffer. @@ -4175,8 +4242,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { if (getActivityInfo().isActive(result, getIndices())) addActiveValue(result); } + if (errorOccurred) + break; domOrder.pushChildren(bb); } + if (errorOccurred) + return true; // Create adjoint blocks and arguments, visiting original blocks in // post-order. @@ -4196,7 +4267,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { adjointPullbackStructArguments[origBB] = lastArg; continue; } - // Get all active values in the original block. // If the original block has no active values, continue. auto &bbActiveValues = activeValues[origBB]; @@ -4421,7 +4491,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb); adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB}); } - // Emit clenaups for all block-local adjoint values. 
+ // Emit cleanups for all block-local adjoint values. for (auto adjVal : blockLocalAdjointValues) emitCleanupForAdjointValue(adjVal); blockLocalAdjointValues.clear(); diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index 596b92d7ed7c7..ee26d8e18fbb2 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -171,6 +171,16 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables, } } +extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, + T == T.AllDifferentiableVariables, T == T.TangentVector { + @usableFromInline + @differentiating(*) + internal static func _vjpMultiply(lhs: Self, rhs: Self) + -> (value: Self, pullback: (Self) -> (Self, Self)) { + return (lhs * rhs, { v in (v * rhs, v * lhs) }) + } +} + // Differential operators for `Tracked`. public extension Differentiable { @inlinable diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 34fb22cb1f6f5..2af3b6b3f14a9 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -75,7 +75,7 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { func uses_optionals(_ x: Float) -> Float { var maybe: Float? = 10 maybe = x - // expected-note @+1 {{differentiating control flow is not yet supported}} + // expected-note @+1 {{differentiating enum values is not yet supported}} return maybe! 
} diff --git a/test/AutoDiff/control_flow.swift b/test/AutoDiff/control_flow.swift index 70f109fceb548..d3754063f1fab 100644 --- a/test/AutoDiff/control_flow.swift +++ b/test/AutoDiff/control_flow.swift @@ -258,7 +258,7 @@ ControlFlowTests.test("Conditionals") { } expectEqual((0, 10), gradient(at: 4, 5, in: guard3)) expectCrash { - gradient(at: -3, -2, in: guard3) + _ = gradient(at: -3, -2, in: guard3) } func cond_empty(_ x: Float) -> Float { @@ -424,4 +424,99 @@ ControlFlowTests.test("Recursion") { expectEqual(1, gradient(at: 100, in: { x in product(x, count: 1) })) } +ControlFlowTests.test("Enums") { + enum Enum { + case a(Float) + case b(Float, Float) + + func enum_notactive1(_ x: Float) -> Float { + switch self { + case let .a(a): return x * a + case let .b(b1, b2): return x * b1 * b2 + } + } + } + + func enum_notactive1(_ e: Enum, _ x: Float) -> Float { + switch e { + case let .a(a): return x * a + case let .b(b1, b2): return x * b1 * b2 + } + } + expectEqual(10, gradient(at: 2, in: { x in enum_notactive1(.a(10), x) })) + expectEqual(10, gradient(at: 2, in: { x in Enum.a(10).enum_notactive1(x) })) + expectEqual(20, gradient(at: 2, in: { x in enum_notactive1(.b(4, 5), x) })) + expectEqual(20, gradient(at: 2, in: { x in Enum.b(4, 5).enum_notactive1(x) })) + + func enum_notactive2(_ e: Enum, _ x: Float) -> Float { + var y = x + if x > 0 { + var z = y + y + switch e { + case .a: z = z - y + case .b: y = y + x + } + var w = y + if case .a = e { + w = w + z + } + return w + } else if case .b = e { + return y + y + } + return x + y + } + expectEqual((8, 2), valueWithGradient(at: 4, in: { x in enum_notactive2(.a(10), x) })) + expectEqual((20, 2), valueWithGradient(at: 10, in: { x in enum_notactive2(.b(4, 5), x) })) + expectEqual((-20, 2), valueWithGradient(at: -10, in: { x in enum_notactive2(.a(10), x) })) + expectEqual((-2674, 2), valueWithGradient(at: -1337, in: { x in enum_notactive2(.b(4, 5), x) })) + + func optional_notactive1(_ optional: Float?, _ x: Float) -> 
Float { + if let y = optional { + return x * y + } + return x + x + } + expectEqual(2, gradient(at: 2, in: { x in optional_notactive1(nil, x) })) + expectEqual(10, gradient(at: 2, in: { x in optional_notactive1(10, x) })) + + struct Dense : Differentiable { + var w1: Float + @noDerivative var w2: Float? + + @differentiable + func callAsFunction(_ input: Float) -> Float { + if let w2 = w2 { + return input * w1 * w2 + } + return input * w1 + } + } + expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20), + Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) })) + expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4), + Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) })) + + indirect enum Indirect { + case e(Float, Enum) + case indirect(Indirect) + } + + func enum_indirect_notactive1(_ indirect: Indirect, _ x: Float) -> Float { + switch indirect { + case let .e(f, e): + switch e { + case .a: return x * f * enum_notactive1(e, x) + case .b: return x * f * enum_notactive1(e, x) + } + case let .indirect(ind): return enum_indirect_notactive1(ind, x) + } + } + do { + let ind: Indirect = .e(10, .a(3)) + expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(ind, x) })) + expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(.indirect(ind), x) })) + } +} + runAllTests() diff --git a/test/AutoDiff/control_flow_diagnostics.swift b/test/AutoDiff/control_flow_diagnostics.swift index 625ad9340422e..0e0f5e10911b7 100644 --- a/test/AutoDiff/control_flow_diagnostics.swift +++ b/test/AutoDiff/control_flow_diagnostics.swift @@ -1,6 +1,6 @@ // RUN: %target-swift-frontend -emit-sil -verify %s -// Test supported `br` and `cond_br` terminators. +// Test supported `br`, `cond_br`, and `switch_enum` terminators. @differentiable func branch(_ x: Float) -> Float { @@ -12,21 +12,82 @@ func branch(_ x: Float) -> Float { return x } -// Test currently unsupported `switch_enum` terminator. 
- enum Enum { case a(Float) case b(Float) } +@differentiable +func enum_nonactive1(_ e: Enum, _ x: Float) -> Float { + switch e { + case .a: return x + case .b: return x + } +} + +@differentiable +func enum_nonactive2(_ e: Enum, _ x: Float) -> Float { + switch e { + case let .a(a): return x + a + case let .b(b): return x + b + } +} + +// Test unsupported differentiation of active enum values. + // expected-error @+1 {{function is not differentiable}} @differentiable // expected-note @+1 {{when differentiating this function definition}} -func switch_enum(_ e: Enum, _ x: Float) -> Float { - // expected-note @+1 {{differentiating control flow is not yet supported}} +func enum_active(_ x: Float) -> Float { + // expected-note @+1 {{differentiating enum values is not yet supported}} + let e: Enum + if x > 0 { + e = .a(x) + } else { + e = .b(x) + } switch e { - case let .a(a): return a - case let .b(b): return b + case let .a(a): return x + a + case let .b(b): return x + b + } +} + +enum Tree : Differentiable & AdditiveArithmetic { + case leaf(Float) + case branch(Float, Float) + + typealias TangentVector = Self + typealias AllDifferentiableVariables = Self + static var zero: Self { .leaf(0) } + + // expected-error @+1 {{function is not differentiable}} + @differentiable + // expected-note @+2 {{when differentiating this function definition}} + // expected-note @+1 {{differentiating enum values is not yet supported}} + static func +(_ lhs: Self, _ rhs: Self) -> Self { + switch (lhs, rhs) { + case let (.leaf(x), .leaf(y)): + return .leaf(x + y) + case let (.branch(x1, x2), .branch(y1, y2)): + return .branch(x1 + x2, y1 + y2) + default: + fatalError() + } + } + + // expected-error @+1 {{function is not differentiable}} + @differentiable + // expected-note @+2 {{when differentiating this function definition}} + // expected-note @+1 {{differentiating enum values is not yet supported}} + static func -(_ lhs: Self, _ rhs: Self) -> Self { + switch (lhs, rhs) { + case let 
(.leaf(x), .leaf(y)): + return .leaf(x - y) + case let (.branch(x1, x2), .branch(y1, y2)): + return .branch(x1 - x2, y1 - y2) + default: + fatalError() + } } } diff --git a/test/AutoDiff/leakchecking.swift b/test/AutoDiff/leakchecking.swift index 75f6c9a66f8bf..6cb791bd5d305 100644 --- a/test/AutoDiff/leakchecking.swift +++ b/test/AutoDiff/leakchecking.swift @@ -137,6 +137,59 @@ LeakCheckingTests.test("ControlFlow") { expectEqual((-2674, 2), Tracked<Float>(-1337).valueWithGradient(in: cond_nestedstruct_var)) } + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. + testWithLeakChecking(expectedLeakCount: 12) { + struct Dense : Differentiable { + var w1: Tracked<Float> + @noDerivative var w2: Tracked<Float>? + + func callAsFunction(_ input: Tracked<Float>) -> Tracked<Float> { + if let w2 = w2 { + return input * w1 * w2 + } + return input * w1 + } + } + expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20), + Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) })) + expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4), + Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) })) + } + + // FIXME: Fix control flow AD memory leaks. + // See related FIXME comments in adjoint value/buffer propagation in + // lib/SILOptimizer/Mandatory/Differentiation.cpp. 
+ testWithLeakChecking(expectedLeakCount: 48) { + enum Enum { + case a(Tracked<Float>) + case b(Tracked<Float>, Tracked<Float>) + } + func enum_notactive2(_ e: Enum, _ x: Tracked<Float>) -> Tracked<Float> { + var y = x + if x > 0 { + var z = y + y + switch e { + case .a: z = z - y + case .b: y = y + x + } + var w = y + if case .a = e { + w = w + z + } + return w + } else if case .b = e { + return y + y + } + return x + y + } + expectEqual((8, 2), Tracked<Float>(4).valueWithGradient(in: { x in enum_notactive2(.a(10), x) })) + expectEqual((20, 2), Tracked<Float>(10).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) })) + expectEqual((-20, 2), Tracked<Float>(-10).valueWithGradient(in: { x in enum_notactive2(.a(10), x) })) + expectEqual((-2674, 2), Tracked<Float>(-1337).valueWithGradient(in: { x in enum_notactive2(.b(4, 5), x) })) + } + // FIXME: Fix control flow AD memory leaks. // See related FIXME comments in adjoint value/buffer propagation in // lib/SILOptimizer/Mandatory/Differentiation.cpp.