diff --git a/Package.swift b/Package.swift index 8e853a36d..c6e87533c 100644 --- a/Package.swift +++ b/Package.swift @@ -30,6 +30,13 @@ let package = Package( .target( name: "TensorFlow", dependencies: []), + .target( + name: "Experimental", + dependencies: [], + path: "Sources/third_party/Experimental"), + .testTarget( + name: "ExperimentalTests", + dependencies: ["Experimental"]), .testTarget( name: "TensorFlowTests", dependencies: ["TensorFlow"]), diff --git a/Sources/third_party/Experimental/Complex.swift b/Sources/third_party/Experimental/Complex.swift new file mode 100644 index 000000000..c09f3dcf3 --- /dev/null +++ b/Sources/third_party/Experimental/Complex.swift @@ -0,0 +1,349 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +/// Note +/// ---- +/// +/// This implementation uses a modified implementation from the +/// xwu/NumericAnnex Swift numeric library repo vy Xiaodi Wu. To view the +/// original code, see the implementation here +/// +/// https://github.com/xwu/NumericAnnex/blob/master/Sources/Complex.swift +/// +/// Create new instances of `Complex` using integer or floating-point +/// literals and the imaginary unit `Complex.i`. For example: +/// +/// ```swift +/// let x: Complex = 2 + 4 * .i +/// ``` +/// +/// Additional Considerations +/// ------------------------- +/// +/// Our implementation of complex number differentiation follows the same +/// convention as Autograd. In short, we can get the derivative of a +/// holomorphic function, functions whose codomain are the Reals, and +/// functions whose codomain and domain are the Reals. You can read more about +/// Autograd at +/// +/// https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers +/// +/// Floating-point types have special values that represent infinity or NaN +/// ("not a number"). Complex functions in different languages may return +/// different results when working with special values. + +struct Complex { + var real: T + var imaginary: T + + @differentiable(vjp: _vjpInit where T: Differentiable, T.TangentVector == T) + init(real: T = 0, imaginary: T = 0) { + self.real = real + self.imaginary = imaginary + } +} + +extension Complex: Differentiable where T: Differentiable { + typealias TangentVector = Complex + typealias AllDifferentiableVariables = Complex +} + +extension Complex { + static var i: Complex { + return Complex(real: 0, imaginary: 1) + } + + var isFinite: Bool { + return real.isFinite && imaginary.isFinite + } + + var isInfinite: Bool { + return real.isInfinite || imaginary.isInfinite + } + + var isNaN: Bool { + return (real.isNaN && !imaginary.isInfinite) || (imaginary.isNaN && !real.isInfinite) + } + + var isZero: Bool { + return real.isZero && imaginary.isZero + } +} + +extension Complex: ExpressibleByIntegerLiteral { + init(integerLiteral value: Int) { + self.real = T(value) + self.imaginary = 0 + } +} + +extension Complex: CustomStringConvertible { + var description: String { + return real.isNaN && real.sign == .minus + ? imaginary.sign == .minus + ? "-\(-real) - \(-imaginary)i" + : "-\(-real) + \(imaginary)i" + : imaginary.sign == .minus + ? "\(real) - \(-imaginary)i" + : "\(real) + \(imaginary)i" + } +} + +extension Complex: Equatable { + static func == (lhs: Complex, rhs: Complex) -> Bool { + return lhs.real == rhs.real && lhs.imaginary == rhs.imaginary + } +} + +extension Complex: AdditiveArithmetic { + @differentiable(vjp: _vjpAdd(lhs:rhs:) where T: Differentiable) + static func + (lhs: Complex, rhs: Complex) -> Complex { + var temp = lhs + temp += rhs + return temp + } + + static func += (lhs: inout Complex, rhs: Complex) { + lhs.real += rhs.real + lhs.imaginary += rhs.imaginary + } + + @differentiable(vjp: _vjpSubtract(lhs:rhs:) where T: Differentiable) + static func - (lhs: Complex, rhs: Complex) -> Complex { + var temp = lhs + temp -= rhs + return temp + } + + static func -= (lhs: inout Complex, rhs: Complex) { + lhs.real -= rhs.real + lhs.imaginary -= rhs.imaginary + } +} + +extension Complex: Numeric { + init?(exactly source: U) where U: BinaryInteger { + guard let t = T(exactly: source) else { return nil } + self.real = t + self.imaginary = 0 + } + + static private func handleMultiplyNaN(infiniteA: T, infiniteB: T, nanA: T, nanB: T) -> Complex { + var a = infiniteA + var b = infiniteB + var c = nanA + var d = nanB + + a = T(signOf: infiniteA, magnitudeOf: infiniteA.isInfinite ? 1 : 0) + b = T(signOf: infiniteB, magnitudeOf: infiniteB.isInfinite ? 1 : 0) + + if nanA.isNaN { c = T(signOf: nanA, magnitudeOf: 0) } + if nanB.isNaN { d = T(signOf: nanB, magnitudeOf: 0) } + + return Complex( + real: .infinity * (a * c - b * d), + imaginary: .infinity * (a * d + b * c) + ) + } + + @differentiable(vjp: _vjpMultiply(lhs:rhs:) where T: Differentiable) + static func * (lhs: Complex, rhs: Complex) -> Complex { + var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary + let ac = a * c, bd = b * d, ad = a * d, bc = b * c + let x = ac - bd + let y = ad + bc + + if x.isNaN && y.isNaN { + if a.isInfinite || b.isInfinite { + return handleMultiplyNaN(infiniteA: a, infiniteB: b, nanA: c, nanB: d) + } else if c.isInfinite || d.isInfinite { + return handleMultiplyNaN(infiniteA: c, infiniteB: d, nanA: a, nanB: b) + } else if ac.isInfinite || bd.isInfinite || ad.isInfinite || bc.isInfinite { + if a.isNaN { a = T(signOf: a, magnitudeOf: 0) } + if b.isNaN { b = T(signOf: b, magnitudeOf: 0) } + if c.isNaN { c = T(signOf: c, magnitudeOf: 0) } + if d.isNaN { d = T(signOf: d, magnitudeOf: 0) } + return Complex( + real: .infinity * (a * c - b * d), + imaginary: .infinity * (a * d + b * c) + ) + } + } + return Complex(real: x, imaginary: y) + } + + static func *= (lhs: inout Complex, rhs: Complex) { + lhs = lhs * rhs + } + + var magnitude: T { + var x = abs(real) + var y = abs(imaginary) + if x.isInfinite { return x } + if y.isInfinite { return y } + if x == 0 { return y } + if x < y { swap(&x, &y) } + let ratio = y / x + return x * (1 + ratio * ratio).squareRoot() + } +} + +extension Complex: SignedNumeric { + @differentiable(vjp: _vjpNegate where T: Differentiable) + static prefix func - (operand: Complex) -> Complex { + return Complex(real: -operand.real, imaginary: -operand.imaginary) + } + + mutating func negate() { + real.negate() + imaginary.negate() + } +} + +extension Complex { + @differentiable(vjp: _vjpDivide(lhs:rhs:) where T: Differentiable) + static func / (lhs: Complex, rhs: Complex) -> Complex { + var a = lhs.real, b = lhs.imaginary, c = rhs.real, d = rhs.imaginary + var x: T + var y: T + if c.magnitude >= d.magnitude { + let ratio = d / c + let denominator = c + d * ratio + x = (a + b * ratio) / denominator + y = (b - a * ratio) / denominator + } else { + let ratio = c / d + let denominator = c * ratio + d + x = (a * ratio + b) / denominator + y = (b * ratio - a) / denominator + } + if x.isNaN && y.isNaN { + if c == 0 && d == 0 && (!a.isNaN || !b.isNaN) { + x = T(signOf: c, magnitudeOf: .infinity) * a + y = T(signOf: c, magnitudeOf: .infinity) * b + } else if (a.isInfinite || b.isInfinite) && c.isFinite && d.isFinite { + a = T(signOf: a, magnitudeOf: a.isInfinite ? 1 : 0) + b = T(signOf: b, magnitudeOf: b.isInfinite ? 1 : 0) + x = .infinity * (a * c + b * d) + y = .infinity * (b * c - a * d) + } else if (c.isInfinite || d.isInfinite) && a.isFinite && b.isFinite { + c = T(signOf: c, magnitudeOf: c.isInfinite ? 1 : 0) + d = T(signOf: d, magnitudeOf: d.isInfinite ? 1 : 0) + x = 0 * (a * c + b * d) + y = 0 * (b * c - a * d) + } + } + return Complex(real: x, imaginary: y) + } + + static func /= (lhs: inout Complex, rhs: Complex) { + lhs = lhs / rhs + } +} + +extension Complex { + @differentiable(vjp: _vjpComplexConjugate where T: Differentiable) + func complexConjugate() -> Complex { + return Complex(real: real, imaginary: -imaginary) + } +} + +func abs(_ z: Complex) -> Complex { + return Complex(real: z.magnitude) +} + +extension Complex { + @differentiable(vjp: _vjpAdding(real:) where T: Differentiable, T.TangentVector == T) + func adding(real: T) -> Complex { + var c = self + c.real += real + return c + } + + @differentiable(vjp: _vjpSubtracting(real:) where T: Differentiable, T.TangentVector == T) + func subtracting(real: T) -> Complex { + var c = self + c.real -= real + return c + } + + @differentiable(vjp: _vjpAdding(imaginary:) where T: Differentiable, T.TangentVector == T) + func adding(imaginary: T) -> Complex { + var c = self + c.imaginary += imaginary + return c + } + + @differentiable(vjp: _vjpSubtracting(imaginary:) where T: Differentiable, T.TangentVector == T) + func subtracting(imaginary: T) -> Complex { + var c = self + c.imaginary -= imaginary + return c + } +} + +extension Complex where T: Differentiable, T.TangentVector == T { + static func _vjpInit(real: T, imaginary: T) -> (Complex, (Complex) -> (T, T)) { + return (Complex(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) }) + } +} + +extension Complex where T: Differentiable { + static func _vjpAdd(lhs: Complex, rhs: Complex) + -> (Complex, (Complex) -> (Complex, Complex)) { + return (lhs + rhs, { v in (v, v) }) + } + + static func _vjpSubtract(lhs: Complex, rhs: Complex) + -> (Complex, (Complex) -> (Complex, Complex)) { + return (lhs - rhs, { v in (v, -v) }) + } + + static func _vjpMultiply(lhs: Complex, rhs: Complex) + -> (Complex, (Complex) -> (Complex, Complex)) { + return (lhs * rhs, { v in (rhs * v, lhs * v) }) + } + + static func _vjpDivide(lhs: Complex, rhs: Complex) + -> (Complex, (Complex) -> (Complex, Complex)) { + return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) + } + + static func _vjpNegate(operand: Complex) + -> (Complex, (Complex) -> Complex) { + return (-operand, { -$0 }) + } + + func _vjpComplexConjugate() -> (Complex, (Complex) -> Complex) { + return (complexConjugate(), { v in v.complexConjugate() }) + } +} + +extension Complex where T: Differentiable, T.TangentVector == T { + func _vjpAdding(real: T) -> (Complex, (Complex) -> (Complex, T)) { + return (self.adding(real: real), { ($0, $0.real) }) + } + + func _vjpSubtracting(real: T) -> (Complex, (Complex) -> (Complex, T)) { + return (self.subtracting(real: real), { ($0, -$0.real) }) + } + + func _vjpAdding(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { + return (self.adding(real: real), { ($0, $0.imaginary) }) + } + + func _vjpSubtracting(imaginary: T) -> (Complex, (Complex) -> (Complex, T)) { + return (self.subtracting(real: real), { ($0, -$0.imaginary) }) + } +} diff --git a/Sources/third_party/Experimental/LICENSE b/Sources/third_party/Experimental/LICENSE new file mode 100644 index 000000000..8362daf4d --- /dev/null +++ b/Sources/third_party/Experimental/LICENSE @@ -0,0 +1,13 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. diff --git a/Tests/ExperimentalTests/ComplexTests.swift b/Tests/ExperimentalTests/ComplexTests.swift new file mode 100644 index 000000000..1f85894c5 --- /dev/null +++ b/Tests/ExperimentalTests/ComplexTests.swift @@ -0,0 +1,362 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import Experimental + +final class ComplexTests: XCTestCase { + func testInitializer() { + let complex = Complex(real: 2, imaginary: 3) + XCTAssertEqual(complex.real, 2) + XCTAssertEqual(complex.imaginary, 3) + } + + func testStaticImaginary() { + let imaginary = Complex(real: 0, imaginary: 1) + XCTAssertEqual(imaginary, Complex.i) + } + + func testIsFinite() { + var complex = Complex(real: 999, imaginary: 0) + XCTAssertTrue(complex.isFinite) + + complex = Complex(real: 1.0 / 0.0, imaginary: 1) + XCTAssertFalse(complex.isFinite) + + complex = Complex(real: 1.0 / 0.0, imaginary: 1.0 / 0.0) + XCTAssertFalse(complex.isFinite) + } + + func testIsInfinite() { + var complex = Complex(real: 999, imaginary: 0) + XCTAssertFalse(complex.isInfinite) + + complex = Complex(real: 1.0 / 0.0, imaginary: 1) + XCTAssertTrue(complex.isInfinite) + + complex = Complex(real: 1.0 / 0.0, imaginary: 1.0 / 0.0) + XCTAssertTrue(complex.isInfinite) + } + + func testIsNaN() { + var complex = Complex(real: 999, imaginary: 0) + XCTAssertFalse(complex.isNaN) + + complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 1) + XCTAssertTrue(complex.isNaN) + + complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0.0 * 1.0 / 0.0) + XCTAssertTrue(complex.isNaN) + } + + func testIsZero() { + var complex = Complex(real: 999, imaginary: 0) + XCTAssertFalse(complex.isZero) + + complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0) + XCTAssertFalse(complex.isZero) + + complex = Complex(real: 0.0 * 1.0 / 0.0, imaginary: 0.0 * 1.0 / 0.0) + XCTAssertFalse(complex.isZero) + + complex = Complex(real: 0, imaginary: 0) + XCTAssertTrue(complex.isZero) + } + + func testEquals() { + var complexA = Complex(real: 999, imaginary: 0) + let complexB = Complex(real: 999, imaginary: 0) + XCTAssertEqual(complexA, complexB) + + complexA = Complex(real: 5, imaginary: 0) + XCTAssertNotEqual(complexA, complexB) + } + + func testPlus() { + let input = Complex(real: 5, imaginary: 1) + let expected = Complex(real: 10, imaginary: 2) + XCTAssertEqual(expected, input + input) + } + + func testMinus() { + let inputA = Complex(real: 6, imaginary: 2) + let inputB = Complex(real: 5, imaginary: 1) + let expected = Complex(real: 1, imaginary: 1) + XCTAssertEqual(expected, inputA - inputB) + } + + func testTimes() { + let inputA = Complex(real: 6, imaginary: 2) + let inputB = Complex(real: 5, imaginary: 1) + let expected = Complex(real: 28, imaginary: 16) + XCTAssertEqual(expected, inputA * inputB) + } + + func testNegate() { + var input = Complex(real: 6, imaginary: 2) + let negated = Complex(real: -6, imaginary: -2) + XCTAssertEqual(-input, negated) + input.negate() + XCTAssertEqual(input, negated) + } + + func testDivide() { + let inputA = Complex(real: 20, imaginary: -4) + let inputB = Complex(real: 3, imaginary: 2) + let expected = Complex(real: 4, imaginary: -4) + XCTAssertEqual(expected, inputA / inputB) + } + + func testComplexConjugate() { + var input = Complex(real: 2, imaginary: -4) + var expected = Complex(real: 2, imaginary: 4) + XCTAssertEqual(expected, input.complexConjugate()) + + input = Complex(real: -2, imaginary: -4) + expected = Complex(real: -2, imaginary: 4) + XCTAssertEqual(expected, input.complexConjugate()) + + input = Complex(real: 2, imaginary: 4) + expected = Complex(real: 2, imaginary: -4) + XCTAssertEqual(expected, input.complexConjugate()) + } + + func testAdding() { + var input = Complex(real: 2, imaginary: -4) + var expected = Complex(real: 3, imaginary: -4) + XCTAssertEqual(expected, input.adding(real: 1)) + + input = Complex(real: 2, imaginary: -4) + expected = Complex(real: 2, imaginary: -3) + XCTAssertEqual(expected, input.adding(imaginary: 1)) + } + + func testSubtracting() { + var input = Complex(real: 2, imaginary: -4) + var expected = Complex(real: 1, imaginary: -4) + XCTAssertEqual(expected, input.subtracting(real: 1)) + + input = Complex(real: 2, imaginary: -4) + expected = Complex(real: 2, imaginary: -5) + XCTAssertEqual(expected, input.subtracting(imaginary: 1)) + } + + func testVjpInit() { + var pb = pullback(at: 4, -3) { r, i in + return Complex(real: r, imaginary: i) + } + var tanTuple = pb(Complex(real: -1, imaginary: 2)) + XCTAssertEqual(-1, tanTuple.0) + XCTAssertEqual(2, tanTuple.1) + + pb = pullback(at: 4, -3) { r, i in + return Complex(real: r * r, imaginary: i + i) + } + tanTuple = pb(Complex(real: -1, imaginary: 1)) + XCTAssertEqual(-8, tanTuple.0) + XCTAssertEqual(2, tanTuple.1) + } + + func testVjpAdd() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 2, imaginary: 3)) { x in + return x + Complex(real: 5, imaginary: 6) + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), + Complex(real: 1, imaginary: 1)) + } + + func testVjpSubtract() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 2, imaginary: 3)) { x in + return Complex(real: 5, imaginary: 6) - x + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: -1, imaginary: -1)) + } + + func testVjpMultiply() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 2, imaginary: 3)) { x in + return x * x + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 4, imaginary: 6)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: -6, imaginary: 4)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: -2, imaginary: 10)) + } + + func testVjpDivide() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x / Complex(real: 2, imaginary: 2) + } + XCTAssertEqual( + pb(Complex(real: 1, imaginary: 0)), + Complex(real: 0.25, imaginary: -0.25)) + XCTAssertEqual( + pb(Complex(real: 0, imaginary: 1)), + Complex(real: 0.25, imaginary: 0.25)) + } + + func testVjpNegate() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return -x + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: -1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: -1)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: -1, imaginary: -1)) + } + + func testVjpComplexConjugate() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x.complexConjugate() + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: -1)) + XCTAssertEqual(pb(Complex(real: -1, imaginary: 1)), Complex(real: -1, imaginary: -1)) + } + + func testVjpAddingReal() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x.adding(real: 5) + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: 1)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: 1, imaginary: 1)) + } + + func testVjpAddingImaginary() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x.adding(imaginary: 5) + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: 1)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: 1, imaginary: 1)) + } + + func testVjpSubtractingReal() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x.subtracting(real: 5) + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: 1)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: 1, imaginary: 1)) + } + + func testVjpSubtractingImaginary() { + let pb: (Complex) -> Complex = + pullback(at: Complex(real: 20, imaginary: -4)) { x in + return x.subtracting(imaginary: 5) + } + XCTAssertEqual(pb(Complex(real: 1, imaginary: 0)), Complex(real: 1, imaginary: 0)) + XCTAssertEqual(pb(Complex(real: 0, imaginary: 1)), Complex(real: 0, imaginary: 1)) + XCTAssertEqual(pb(Complex(real: 1, imaginary: 1)), Complex(real: 1, imaginary: 1)) + } + + func testJvpDotProduct() { + struct ComplexVector : Differentiable & AdditiveArithmetic { + var w: Complex + var x: Complex + var y: Complex + var z: Complex + + init(w: Complex, x: Complex, y: Complex, z: Complex) { + self.w = w + self.x = x + self.y = y + self.z = z + } + } + + func dot(lhs: ComplexVector, rhs: ComplexVector) -> Complex { + var result: Complex = Complex(real: 0, imaginary: 0) + result = result + lhs.w.complexConjugate() * rhs.w + result = result + lhs.x.complexConjugate() * rhs.x + result = result + lhs.y.complexConjugate() * rhs.y + result = result + lhs.z.complexConjugate() * rhs.z + return result + } + + let atVector = ComplexVector( + w: Complex(real: 1, imaginary: 1), + x: Complex(real: 1, imaginary: -1), + y: Complex(real: -1, imaginary: 1), + z: Complex(real: -1, imaginary: -1)) + let rhsVector = ComplexVector( + w: Complex(real: 3, imaginary: -4), + x: Complex(real: 6, imaginary: -2), + y: Complex(real: 1, imaginary: 2), + z: Complex(real: 4, imaginary: 3)) + let expectedVector = ComplexVector( + w: Complex(real: 7, imaginary: 1), + x: Complex(real: 8, imaginary: -4), + y: Complex(real: -1, imaginary: -3), + z: Complex(real: 1, imaginary: -7)) + + let (result, pbComplex) = valueWithPullback(at: atVector) { x in + return dot(lhs: x, rhs: rhsVector) + } + + XCTAssertEqual(Complex(real: 1, imaginary: -5), result) + XCTAssertEqual(expectedVector, pbComplex(Complex(real: 1, imaginary: 1))) + } + + func testImplicitDifferentiation() { + func addRealComponents(lhs: Complex, rhs: Complex) -> Float { + return lhs.real + rhs.real + } + + let (result, pbComplex) = valueWithPullback(at: Complex(real: 2, imaginary: -3)) { x in + return addRealComponents(lhs: x, rhs: Complex(real: -4, imaginary: 1)) + } + + XCTAssertEqual(-2, result) + XCTAssertEqual(Complex(real: 1, imaginary: 0), pbComplex(1)) + } + + static var allTests = [ + ("testInitializer", testInitializer), + ("testStaticImaginary", testStaticImaginary), + ("testIsFinite", testIsFinite), + ("testIsInfinite", testIsInfinite), + ("testIsNaN", testIsNaN), + ("testIsZero", testIsZero), + ("testEquals", testEquals), + ("testPlus", testPlus), + ("testMinus", testMinus), + ("testTimes", testTimes), + ("testNegate", testNegate), + ("testDivide", testDivide), + ("testComplexConjugate", testComplexConjugate), + ("testAdding", testAdding), + ("testSubtracting", testSubtracting), + ("testVjpInit", testVjpInit), + ("testVjpAdd", testVjpAdd), + ("testVjpSubtract", testVjpSubtract), + ("testVjpMultiply", testVjpMultiply), + ("testVjpDivide", testVjpDivide), + ("testVjpNegate", testVjpNegate), + ("testVjpComplexConjugate", testVjpComplexConjugate), + ("testVjpAddingReal", testVjpAddingReal), + ("testVjpAddingImaginary", testVjpAddingImaginary), + ("testVjpSubtractingReal", testVjpSubtractingReal), + ("testVjpSubtractingImaginary", testVjpSubtractingImaginary), + ("testJvpDotProduct", testJvpDotProduct), + ("testImplicitDifferentiation", testImplicitDifferentiation) + ] +}