From 1e1f39cb80ff5a1e7f1d2aed020d99d31b91b36f Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 10 Aug 2020 09:13:59 -0700 Subject: [PATCH] [AutoDiff] Fix `Optional` differentiation crash. Fix `Optional` differentiation crash for non-resilient `Wrapped` reference type. Add `NonresilientTracked` type to `DifferentiationUnittest` for testing. Resolves SR-13377. --- .../Differentiation/PullbackCloner.cpp | 13 +- .../DifferentiationUnittest/CMakeLists.txt | 2 +- ...wift => DifferentiationUnittest.swift.gyb} | 206 +++++++----- test/AutoDiff/validation-test/optional.swift | 295 +++++++++++++++--- 4 files changed, 387 insertions(+), 129 deletions(-) rename stdlib/private/DifferentiationUnittest/{DifferentiationUnittest.swift => DifferentiationUnittest.swift.gyb} (50%) diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 79611c38eaaf4..aa40f16f53d91 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -2056,8 +2056,8 @@ void PullbackCloner::Implementation::accumulateAdjointForOptional( // Find `Optional.some` EnumElementDecl. auto someEltDecl = builder.getASTContext().getOptionalSomeDecl(); - // Initialize a `Optional` buffer from `wrappedAdjoint`as the - // input for `Optional.TangentVector.init`. + // Initialize an `Optional` buffer from `wrappedAdjoint` as + // the input for `Optional.TangentVector.init`. auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType); if (optionalOfWrappedTanType.isLoadableOrOpaque(builder.getFunction())) { // %enum = enum $Optional, #Optional.some!enumelt, @@ -2066,7 +2066,7 @@ void PullbackCloner::Implementation::accumulateAdjointForOptional( optionalOfWrappedTanType); // store %enum to %optArgBuf builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf, - StoreOwnershipQualifier::Trivial); + StoreOwnershipQualifier::Init); } else { // %enumAddr = init_enum_data_addr %optArgBuf $Optional, // #Optional.some!enumelt @@ -2279,14 +2279,15 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { for (auto pair : incomingValues) { auto *predBB = std::get<0>(pair); auto incomingValue = std::get<1>(pair); - blockTemporaries[getPullbackBlock(predBB)].insert(concreteBBArgAdjCopy); // Handle `switch_enum` on `Optional`. auto termInst = bbArg->getSingleTerminator(); - if (isSwitchEnumInstOnOptional(termInst)) + if (isSwitchEnumInstOnOptional(termInst)) { accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy); - else + } else { + blockTemporaries[getPullbackBlock(predBB)].insert(concreteBBArgAdjCopy); setAdjointValue(predBB, incomingValue, makeConcreteAdjointValue(concreteBBArgAdjCopy)); + } } break; } diff --git a/stdlib/private/DifferentiationUnittest/CMakeLists.txt b/stdlib/private/DifferentiationUnittest/CMakeLists.txt index bb6284b191310..33da12b9b766a 100644 --- a/stdlib/private/DifferentiationUnittest/CMakeLists.txt +++ b/stdlib/private/DifferentiationUnittest/CMakeLists.txt @@ -1,6 +1,6 @@ add_swift_target_library(swiftDifferentiationUnittest ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB # This file should be listed first. Module name is inferred from the filename. - DifferentiationUnittest.swift + GYB_SOURCES DifferentiationUnittest.swift.gyb SWIFT_MODULE_DEPENDS _Differentiation StdlibUnittest INSTALL_IN_COMPONENT stdlib-experimental diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb similarity index 50% rename from stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift rename to stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb index f24dd18a7925e..282f89f93324f 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift.gyb @@ -1,4 +1,4 @@ -//===--- DifferentiationUnittest.swift ------------------------------------===// +//===--- DifferentiationUnittest.swift.gyb --------------------------------===// // // This source file is part of the Swift.org open source project // @@ -43,19 +43,21 @@ public extension TestSuite { _ testFunction: @escaping () -> Void ) { test(name, file: file, line: line) { - withLeakChecking(expectedLeakCount: expectedLeakCount, file: file, - line: line, testFunction) + withLeakChecking( + expectedLeakCount: expectedLeakCount, file: file, + line: line, testFunction) } } } -/// A type that tracks the number of live instances of a wrapped value type. +/// A resilient type that tracks the number of live instances of a wrapped +/// value type. /// /// `Tracked` is used to check for memory leaks in functions created via /// automatic differentiation. public struct Tracked { - fileprivate class Box { - fileprivate var value : T + internal class Box { + fileprivate var value: T init(_ value: T) { self.value = value _GlobalLeakCount.count += 1 @@ -64,71 +66,109 @@ public struct Tracked { _GlobalLeakCount.count -= 1 } } - private var handle: Box + internal var handle: Box - @differentiable(where T : Differentiable, T == T.TangentVector) + @differentiable(where T: Differentiable, T == T.TangentVector) public init(_ value: T) { self.handle = Box(value) } - @differentiable(where T : Differentiable, T == T.TangentVector) + @differentiable(where T: Differentiable, T == T.TangentVector) public var value: T { get { handle.value } set { handle.value = newValue } } } -extension Tracked : ExpressibleByFloatLiteral where T : ExpressibleByFloatLiteral { +/// A non-resilient type that tracks the number of live instances of a wrapped +/// value type. +/// +/// `NonresilientTracked` is used to check for memory leaks in functions +/// created via automatic differentiation. +@frozen +public struct NonresilientTracked { + @usableFromInline + internal class Box { + fileprivate var value: T + init(_ value: T) { + self.value = value + _GlobalLeakCount.count += 1 + } + deinit { + _GlobalLeakCount.count -= 1 + } + } + @usableFromInline + internal var handle: Box + + @differentiable(where T: Differentiable, T == T.TangentVector) + public init(_ value: T) { + self.handle = Box(value) + } + + @differentiable(where T: Differentiable, T == T.TangentVector) + public var value: T { + get { handle.value } + set { handle.value = newValue } + } +} + +% for Self in ['Tracked', 'NonresilientTracked']: + +extension ${Self}: ExpressibleByFloatLiteral +where T: ExpressibleByFloatLiteral { public init(floatLiteral value: T.FloatLiteralType) { self.handle = Box(T(floatLiteral: value)) } } -extension Tracked : CustomStringConvertible { - public var description: String { return "Tracked(\(value))" } +extension ${Self}: CustomStringConvertible { + public var description: String { return "${Self}(\(value))" } } -extension Tracked : ExpressibleByIntegerLiteral where T : ExpressibleByIntegerLiteral { +extension ${Self}: ExpressibleByIntegerLiteral +where T: ExpressibleByIntegerLiteral { public init(integerLiteral value: T.IntegerLiteralType) { self.handle = Box(T(integerLiteral: value)) } } -extension Tracked : Comparable where T : Comparable { - public static func < (lhs: Tracked, rhs: Tracked) -> Bool { +extension ${Self}: Comparable where T: Comparable { + public static func < (lhs: ${Self}, rhs: ${Self}) -> Bool { return lhs.value < rhs.value } - public static func <= (lhs: Tracked, rhs: Tracked) -> Bool { + public static func <= (lhs: ${Self}, rhs: ${Self}) -> Bool { return lhs.value <= rhs.value } - public static func > (lhs: Tracked, rhs: Tracked) -> Bool { + public static func > (lhs: ${Self}, rhs: ${Self}) -> Bool { return lhs.value > rhs.value } - public static func >= (lhs: Tracked, rhs: Tracked) -> Bool { + public static func >= (lhs: ${Self}, rhs: ${Self}) -> Bool { return lhs.value >= rhs.value } } -extension Tracked : AdditiveArithmetic where T : AdditiveArithmetic { - public static var zero: Tracked { return Tracked(T.zero) } - public static func + (lhs: Tracked, rhs: Tracked) -> Tracked { - return Tracked(lhs.value + rhs.value) +extension ${Self}: AdditiveArithmetic where T: AdditiveArithmetic { + public static var zero: ${Self} { return ${Self}(T.zero) } + public static func + (lhs: ${Self}, rhs: ${Self}) -> ${Self} { + return ${Self}(lhs.value + rhs.value) } - public static func - (lhs: Tracked, rhs: Tracked) -> Tracked { - return Tracked(lhs.value - rhs.value) + public static func - (lhs: ${Self}, rhs: ${Self}) -> ${Self} { + return ${Self}(lhs.value - rhs.value) } } -extension Tracked : Equatable where T : Equatable { - public static func == (lhs: Tracked, rhs: Tracked) -> Bool { +extension ${Self}: Equatable where T: Equatable { + public static func == (lhs: ${Self}, rhs: ${Self}) -> Bool { return lhs.value == rhs.value } } -extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magnitude { - public typealias Magnitude = Tracked +extension ${Self}: SignedNumeric & Numeric +where T: SignedNumeric, T == T.Magnitude { + public typealias Magnitude = ${Self} - public init?(exactly source: U) where U : BinaryInteger { + public init?(exactly source: U) where U: BinaryInteger { if let t = T(exactly: source) { self.init(t) } @@ -136,169 +176,191 @@ extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magn } public var magnitude: Magnitude { return Magnitude(value.magnitude) } - public static func * (lhs: Tracked, rhs: Tracked) -> Tracked { - return Tracked(lhs.value * rhs.value) + public static func * (lhs: ${Self}, rhs: ${Self}) -> ${Self} { + return ${Self}(lhs.value * rhs.value) } - public static func *= (lhs: inout Tracked, rhs: Tracked) { + public static func *= (lhs: inout ${Self}, rhs: ${Self}) { lhs = lhs * rhs } } -extension Tracked where T : FloatingPoint { - public static func / (lhs: Tracked, rhs: Tracked) -> Tracked { - return Tracked(lhs.value / rhs.value) +extension ${Self} where T: FloatingPoint { + public static func / (lhs: ${Self}, rhs: ${Self}) -> ${Self} { + return ${Self}(lhs.value / rhs.value) } - public static func /= (lhs: inout Tracked, rhs: Tracked) { + public static func /= (lhs: inout ${Self}, rhs: ${Self}) { lhs = lhs / rhs } } -extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnitude { - public typealias Stride = Tracked +extension ${Self}: Strideable +where T: Strideable, T.Stride == T.Stride.Magnitude { + public typealias Stride = ${Self} - public func distance(to other: Tracked) -> Stride { + public func distance(to other: ${Self}) -> Stride { return Stride(value.distance(to: other.value)) } - public func advanced(by n: Stride) -> Tracked { - return Tracked(value.advanced(by: n.value)) + public func advanced(by n: Stride) -> ${Self} { + return ${Self}(value.advanced(by: n.value)) } } // For now, `T` must be restricted to trivial types (like `Float` or `Tensor`). -extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector { - public typealias TangentVector = Tracked +extension ${Self}: Differentiable +where T: Differentiable, T == T.TangentVector { + public typealias TangentVector = ${Self} } -extension Tracked where T : Differentiable, T == T.TangentVector { +extension ${Self} where T: Differentiable, T == T.TangentVector { @usableFromInline @derivative(of: init) internal static func _vjpInit(_ value: T) - -> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) { - return (Tracked(value), { v in v.value }) + -> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) + { + return (${Self}(value), { v in v.value }) } @usableFromInline @derivative(of: init) internal static func _jvpInit(_ value: T) - -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) { - return (Tracked(value), { v in Tracked(v) }) + -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) + { + return (${Self}(value), { v in ${Self}(v) }) } @usableFromInline @derivative(of: value) - internal func _vjpValue() -> (value: T, pullback: (T.TangentVector) -> Self.TangentVector) { - return (value, { v in Tracked(v) }) + internal func _vjpValue() -> ( + value: T, pullback: (T.TangentVector) -> Self.TangentVector + ) { + return (value, { v in ${Self}(v) }) } @usableFromInline @derivative(of: value) - internal func _jvpValue() -> (value: T, differential: (Self.TangentVector) -> T.TangentVector) { + internal func _jvpValue() -> ( + value: T, differential: (Self.TangentVector) -> T.TangentVector + ) { return (value, { v in v.value }) } } -extension Tracked where T : Differentiable, T == T.TangentVector { +extension ${Self} where T: Differentiable, T == T.TangentVector { @usableFromInline @derivative(of: +) internal static func _vjpAdd(lhs: Self, rhs: Self) - -> (value: Self, pullback: (Self) -> (Self, Self)) { + -> (value: Self, pullback: (Self) -> (Self, Self)) + { return (lhs + rhs, { v in (v, v) }) } @usableFromInline @derivative(of: +) internal static func _jvpAdd(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> Self) { + -> (value: Self, differential: (Self, Self) -> Self) + { return (lhs + rhs, { $0 + $1 }) } @usableFromInline @derivative(of: -) internal static func _vjpSubtract(lhs: Self, rhs: Self) - -> (value: Self, pullback: (Self) -> (Self, Self)) { + -> (value: Self, pullback: (Self) -> (Self, Self)) + { return (lhs - rhs, { v in (v, .zero - v) }) } @usableFromInline @derivative(of: -) internal static func _jvpSubtract(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> Self) { + -> (value: Self, differential: (Self, Self) -> Self) + { return (lhs - rhs, { $0 - $1 }) } } -extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, - T == T.TangentVector { +extension ${Self} +where + T: Differentiable & SignedNumeric, T == T.Magnitude, + T == T.TangentVector +{ @usableFromInline @derivative(of: *) internal static func _vjpMultiply(lhs: Self, rhs: Self) - -> (value: Self, pullback: (Self) -> (Self, Self)) { + -> (value: Self, pullback: (Self) -> (Self, Self)) + { return (lhs * rhs, { v in (v * rhs, v * lhs) }) } @usableFromInline @derivative(of: *) internal static func _jvpMultiply(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> (Self)) { + -> (value: Self, differential: (Self, Self) -> (Self)) + { return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs }) } } -extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector { +extension ${Self} +where T: Differentiable & FloatingPoint, T == T.TangentVector { @usableFromInline @derivative(of: /) internal static func _vjpDivide(lhs: Self, rhs: Self) - -> (value: Self, pullback: (Self) -> (Self, Self)) { + -> (value: Self, pullback: (Self) -> (Self, Self)) + { return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) } @usableFromInline @derivative(of: /) internal static func _jvpDivide(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> (Self)) { + -> (value: Self, differential: (Self, Self) -> (Self)) + { return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy }) } } -// Differential operators for `Tracked`. +// Differential operators for `${Self}`. public func gradient( - at x: T, in f: @differentiable (T) -> Tracked + at x: T, in f: @differentiable (T) -> ${Self} ) -> T.TangentVector where R.TangentVector == R { return pullback(at: x, in: f)(1) } public func gradient( - at x: T, _ y: U, in f: @differentiable (T, U) -> Tracked + at x: T, _ y: U, in f: @differentiable (T, U) -> ${Self} ) -> (T.TangentVector, U.TangentVector) where R.TangentVector == R { return pullback(at: x, y, in: f)(1) } public func derivative( - at x: Tracked, in f: @differentiable (Tracked) -> R + at x: ${Self}, in f: @differentiable (${Self}) -> R ) -> R.TangentVector where T.TangentVector == T { return differential(at: x, in: f)(1) } public func derivative( - at x: Tracked, _ y: Tracked, - in f: @differentiable (Tracked, Tracked) -> R + at x: ${Self}, _ y: ${Self}, + in f: @differentiable (${Self}, ${Self}) -> R ) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U { return differential(at: x, y, in: f)(1, 1) } public func valueWithGradient( - at x: T, in f: @differentiable (T) -> Tracked -) -> (value: Tracked, gradient: T.TangentVector) { + at x: T, in f: @differentiable (T) -> ${Self} +) -> (value: ${Self}, gradient: T.TangentVector) { let (y, pullback) = valueWithPullback(at: x, in: f) return (y, pullback(1)) } public func valueWithDerivative( - at x: Tracked, in f: @differentiable (Tracked) -> R + at x: ${Self}, in f: @differentiable (${Self}) -> R ) -> (value: R, derivative: R.TangentVector) { let (y, differential) = valueWithDifferential(at: x, in: f) return (y, differential(1)) } + +% end diff --git a/test/AutoDiff/validation-test/optional.swift b/test/AutoDiff/validation-test/optional.swift index c2fc0fac4a695..4aeb60ba42762 100644 --- a/test/AutoDiff/validation-test/optional.swift +++ b/test/AutoDiff/validation-test/optional.swift @@ -1,8 +1,8 @@ // RUN: %target-run-simple-swift // REQUIRES: executable_test -import StdlibUnittest import DifferentiationUnittest +import StdlibUnittest var OptionalTests = TestSuite("OptionalDifferentiation") @@ -23,7 +23,7 @@ OptionalTests.test("Let") { @differentiable func optional_let(_ maybeX: Float?) -> Float { if let x = maybeX { - return x * x + return x * x } return 10 } @@ -33,13 +33,27 @@ OptionalTests.test("Let") { @differentiable func optional_let_tracked(_ maybeX: Tracked?) -> Tracked { if let x = maybeX { - return x * x + return x * x } return 10 } expectEqual(gradient(at: 10, in: optional_let_tracked), .init(20.0)) expectEqual(gradient(at: nil, in: optional_let_tracked), .init(0.0)) + @differentiable + func optional_let_nonresilient_tracked(_ maybeX: NonresilientTracked?) + -> NonresilientTracked + { + if let x = maybeX { + return x * x + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_let_nonresilient_tracked), .init(20.0)) + expectEqual( + gradient(at: nil, in: optional_let_nonresilient_tracked), .init(0.0)) + @differentiable func optional_let_nested(_ nestedMaybeX: Float??) -> Float { if let maybeX = nestedMaybeX { @@ -54,7 +68,26 @@ OptionalTests.test("Let") { expectEqual(gradient(at: nil, in: optional_let_nested), .init(.init(0.0))) @differentiable - func optional_let_nested_tracked(_ nestedMaybeX: Tracked??) -> Tracked { + func optional_let_nested_tracked(_ nestedMaybeX: Tracked??) -> Tracked< + Float + > { + if let maybeX = nestedMaybeX { + if let x = maybeX { + return x * x + } + return 10 + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_let_nested_tracked), .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_let_nested_tracked), .init(.init(0.0))) + + @differentiable + func optional_let_nested_nonresilient_tracked( + _ nestedMaybeX: NonresilientTracked?? + ) -> NonresilientTracked { if let maybeX = nestedMaybeX { if let x = maybeX { return x * x @@ -63,24 +96,38 @@ OptionalTests.test("Let") { } return 10 } - expectEqual(gradient(at: 10, in: optional_let_nested_tracked), .init(.init(20.0))) - expectEqual(gradient(at: nil, in: optional_let_nested_tracked), .init(.init(0.0))) + expectEqual( + gradient(at: 10, in: optional_let_nested_nonresilient_tracked), + .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_let_nested_nonresilient_tracked), + .init(.init(0.0))) @differentiable - func optional_let_generic(_ maybeX: T?, _ defaultValue: T) -> T { + func optional_let_generic(_ maybeX: T?, _ defaultValue: T) + -> T + { if let x = maybeX { - return x + return x } return defaultValue } expectEqual(gradient(at: 10, 20, in: optional_let_generic), (.init(1.0), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_let_generic), (.init(0.0), 1.0)) + expectEqual( + gradient(at: nil, 20, in: optional_let_generic), (.init(0.0), 1.0)) - expectEqual(gradient(at: Tracked.init(10), Tracked.init(20), in: optional_let_generic), (.init(1.0), 0.0)) - expectEqual(gradient(at: nil, Tracked.init(20), in: optional_let_generic), (.init(0.0), 1.0)) + expectEqual( + gradient( + at: Tracked.init(10), Tracked.init(20), + in: optional_let_generic), (.init(1.0), 0.0)) + expectEqual( + gradient(at: nil, Tracked.init(20), in: optional_let_generic), + (.init(0.0), 1.0)) @differentiable - func optional_let_nested_generic(_ nestedMaybeX: T??, _ defaultValue: T) -> T { + func optional_let_nested_generic( + _ nestedMaybeX: T??, _ defaultValue: T + ) -> T { if let maybeX = nestedMaybeX { if let x = maybeX { return x @@ -90,8 +137,12 @@ OptionalTests.test("Let") { return defaultValue } - expectEqual(gradient(at: 10.0, 20.0, in: optional_let_nested_generic), (.init(.init(1.0)), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_let_nested_generic), (.init(.init(0.0)), 1.0)) + expectEqual( + gradient(at: 10.0, 20.0, in: optional_let_nested_generic), + (.init(.init(1.0)), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_let_nested_generic), + (.init(.init(0.0)), 1.0)) } OptionalTests.test("Switch") { @@ -115,14 +166,28 @@ OptionalTests.test("Switch") { expectEqual(gradient(at: 10, in: optional_switch_tracked), .init(20.0)) expectEqual(gradient(at: nil, in: optional_switch_tracked), .init(0.0)) + @differentiable + func optional_switch_nonresilient_tracked( + _ maybeX: NonresilientTracked? + ) -> NonresilientTracked { + switch maybeX { + case nil: return 10 + case let .some(x): return x * x + } + } + expectEqual( + gradient(at: 10, in: optional_switch_nonresilient_tracked), .init(20.0)) + expectEqual( + gradient(at: nil, in: optional_switch_nonresilient_tracked), .init(0.0)) + @differentiable func optional_switch_nested(_ nestedMaybeX: Float??) -> Float { switch nestedMaybeX { case nil: return 10 case let .some(maybeX): switch maybeX { - case nil: return 10 - case let .some(x): return x * x + case nil: return 10 + case let .some(x): return x * x } } } @@ -130,42 +195,76 @@ OptionalTests.test("Switch") { expectEqual(gradient(at: nil, in: optional_switch_nested), .init(.init(0.0))) @differentiable - func optional_switch_nested_tracked(_ nestedMaybeX: Tracked??) -> Tracked { + func optional_switch_nested_tracked(_ nestedMaybeX: Tracked??) + -> Tracked + { switch nestedMaybeX { case nil: return 10 case let .some(maybeX): switch maybeX { - case nil: return 10 - case let .some(x): return x * x + case nil: return 10 + case let .some(x): return x * x } } } - expectEqual(gradient(at: 10, in: optional_switch_nested_tracked), .init(.init(20.0))) - expectEqual(gradient(at: nil, in: optional_switch_nested_tracked), .init(.init(0.0))) + expectEqual( + gradient(at: 10, in: optional_switch_nested_tracked), .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_switch_nested_tracked), .init(.init(0.0))) @differentiable - func optional_switch_generic(_ maybeX: T?, _ defaultValue: T) -> T { + func optional_switch_nested_nonresilient_tracked( + _ nestedMaybeX: NonresilientTracked?? + ) -> NonresilientTracked { + switch nestedMaybeX { + case nil: return 10 + case let .some(maybeX): + switch maybeX { + case nil: return 10 + case let .some(x): return x * x + } + } + } + expectEqual( + gradient(at: 10, in: optional_switch_nested_nonresilient_tracked), + .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_switch_nested_nonresilient_tracked), + .init(.init(0.0))) + + @differentiable + func optional_switch_generic( + _ maybeX: T?, _ defaultValue: T + ) -> T { switch maybeX { case nil: return defaultValue case let .some(x): return x } } - expectEqual(gradient(at: 10, 20, in: optional_switch_generic), (.init(1.0), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_switch_generic), (.init(0.0), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_switch_generic), (.init(1.0), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_switch_generic), (.init(0.0), 1.0)) @differentiable - func optional_switch_nested_generic(_ nestedMaybeX: T??, _ defaultValue: T) -> T { + func optional_switch_nested_generic( + _ nestedMaybeX: T??, _ defaultValue: T + ) -> T { switch nestedMaybeX { case nil: return defaultValue case let .some(maybeX): switch maybeX { - case nil: return defaultValue - case let .some(x): return x + case nil: return defaultValue + case let .some(x): return x } } } - expectEqual(gradient(at: 10, 20, in: optional_switch_nested_generic), (.init(.init(1.0)), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_switch_nested_generic), (.init(.init(0.0)), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_switch_nested_generic), + (.init(.init(1.0)), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_switch_nested_generic), + (.init(.init(0.0)), 1.0)) } OptionalTests.test("Var1") { @@ -173,7 +272,7 @@ OptionalTests.test("Var1") { func optional_var1(_ maybeX: Float?) -> Float { var maybeX = maybeX if let x = maybeX { - return x * x + return x * x } return 10 } @@ -184,13 +283,28 @@ OptionalTests.test("Var1") { func optional_var1_tracked(_ maybeX: Tracked?) -> Tracked { var maybeX = maybeX if let x = maybeX { - return x * x + return x * x } return 10 } expectEqual(gradient(at: 10, in: optional_var1_tracked), .init(20.0)) expectEqual(gradient(at: nil, in: optional_var1_tracked), .init(0.0)) + @differentiable + func optional_var1_nonresilient_tracked(_ maybeX: NonresilientTracked?) + -> NonresilientTracked + { + var maybeX = maybeX + if let x = maybeX { + return x * x + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_var1_nonresilient_tracked), .init(20.0)) + expectEqual( + gradient(at: nil, in: optional_var1_nonresilient_tracked), .init(0.0)) + @differentiable func optional_var1_nested(_ nestedMaybeX: Float??) -> Float { var nestedMaybeX = nestedMaybeX @@ -206,7 +320,9 @@ OptionalTests.test("Var1") { expectEqual(gradient(at: nil, in: optional_var1_nested), .init(.init(0.0))) @differentiable - func optional_var1_nested_tracked(_ nestedMaybeX: Tracked??) -> Tracked { + func optional_var1_nested_tracked(_ nestedMaybeX: Tracked??) + -> Tracked + { var nestedMaybeX = nestedMaybeX if let maybeX = nestedMaybeX { if var x = maybeX { @@ -216,22 +332,50 @@ OptionalTests.test("Var1") { } return 10 } - expectEqual(gradient(at: 10, in: optional_var1_nested_tracked), .init(.init(20.0))) - expectEqual(gradient(at: nil, in: optional_var1_nested_tracked), .init(.init(0.0))) + expectEqual( + gradient(at: 10, in: optional_var1_nested_tracked), .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_var1_nested_tracked), .init(.init(0.0))) @differentiable - func optional_var1_generic(_ maybeX: T?, _ defaultValue: T) -> T { + func optional_var1_nested_nonresilient_tracked( + _ nestedMaybeX: NonresilientTracked?? + ) -> NonresilientTracked { + var nestedMaybeX = nestedMaybeX + if let maybeX = nestedMaybeX { + if var x = maybeX { + return x * x + } + return 10 + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_var1_nested_nonresilient_tracked), + .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_var1_nested_nonresilient_tracked), + .init(.init(0.0))) + + @differentiable + func optional_var1_generic(_ maybeX: T?, _ defaultValue: T) + -> T + { var maybeX = maybeX if let x = maybeX { - return x + return x } return defaultValue } - expectEqual(gradient(at: 10, 20, in: optional_var1_generic), (.init(1.0), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_var1_generic), (.init(0.0), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_var1_generic), (.init(1.0), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_var1_generic), (.init(0.0), 1.0)) @differentiable - func optional_var1_nested_generic(_ nestedMaybeX: T??, _ defaultValue: T) -> T { + func optional_var1_nested_generic( + _ nestedMaybeX: T??, _ defaultValue: T + ) -> T { var nestedMaybeX = nestedMaybeX if let maybeX = nestedMaybeX { if var x = maybeX { @@ -241,8 +385,12 @@ OptionalTests.test("Var1") { } return defaultValue } - expectEqual(gradient(at: 10, 20, in: optional_var1_nested_generic), (.init(.init(1.0)), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_var1_nested_generic), (.init(.init(0.0)), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_var1_nested_generic), + (.init(.init(1.0)), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_var1_nested_generic), + (.init(.init(0.0)), 1.0)) } OptionalTests.test("Var2") { @@ -266,6 +414,20 @@ OptionalTests.test("Var2") { expectEqual(gradient(at: 10, in: optional_var2_tracked), .init(20.0)) expectEqual(gradient(at: nil, in: optional_var2_tracked), .init(0.0)) + @differentiable + func optional_var2_nonresilient_tracked(_ maybeX: NonresilientTracked?) + -> NonresilientTracked + { + if var x = maybeX { + return x * x + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_var2_nonresilient_tracked), .init(20.0)) + expectEqual( + gradient(at: nil, in: optional_var2_nonresilient_tracked), .init(0.0)) + @differentiable func optional_var2_nested(_ nestedMaybeX: Float??) -> Float { if var maybeX = nestedMaybeX { @@ -280,7 +442,26 @@ OptionalTests.test("Var2") { expectEqual(gradient(at: nil, in: optional_var2_nested), .init(.init(0.0))) @differentiable - func optional_var2_nested_tracked(_ nestedMaybeX: Tracked??) -> Tracked { + func optional_var2_nested_tracked(_ nestedMaybeX: Tracked??) + -> Tracked + { + if var maybeX = nestedMaybeX { + if var x = maybeX { + return x * x + } + return 10 + } + return 10 + } + expectEqual( + gradient(at: 10, in: optional_var2_nested_tracked), .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_var2_nested_tracked), .init(.init(0.0))) + + @differentiable + func optional_var2_nested_nonresilient_tracked( + _ nestedMaybeX: NonresilientTracked?? + ) -> NonresilientTracked { if var maybeX = nestedMaybeX { if var x = maybeX { return x * x @@ -289,21 +470,31 @@ OptionalTests.test("Var2") { } return 10 } - expectEqual(gradient(at: 10, in: optional_var2_nested_tracked), .init(.init(20.0))) - expectEqual(gradient(at: nil, in: optional_var2_nested_tracked), .init(.init(0.0))) + expectEqual( + gradient(at: 10, in: optional_var2_nested_nonresilient_tracked), + .init(.init(20.0))) + expectEqual( + gradient(at: nil, in: optional_var2_nested_nonresilient_tracked), + .init(.init(0.0))) @differentiable - func optional_var2_generic(_ maybeX: T?, _ defaultValue: T) -> T { + func optional_var2_generic(_ maybeX: T?, _ defaultValue: T) + -> T + { if var x = maybeX { return x } return defaultValue } - expectEqual(gradient(at: 10, 20, in: optional_var2_generic), (.init(1.0), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_var2_generic), (.init(0.0), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_var2_generic), (.init(1.0), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_var2_generic), (.init(0.0), 1.0)) @differentiable - func optional_var2_nested_generic(_ nestedMaybeX: T??, _ defaultValue: T) -> T { + func optional_var2_nested_generic( + _ nestedMaybeX: T??, _ defaultValue: T + ) -> T { if var maybeX = nestedMaybeX { if var x = maybeX { return x @@ -312,8 +503,12 @@ OptionalTests.test("Var2") { } return defaultValue } - expectEqual(gradient(at: 10, 20, in: optional_var2_nested_generic), (.init(.init(1.0)), 0.0)) - expectEqual(gradient(at: nil, 20, in: optional_var2_nested_generic), (.init(.init(0.0)), 1.0)) + expectEqual( + gradient(at: 10, 20, in: optional_var2_nested_generic), + (.init(.init(1.0)), 0.0)) + expectEqual( + gradient(at: nil, 20, in: optional_var2_nested_generic), + (.init(.init(0.0)), 1.0)) } runAllTests()