@@ -153,18 +153,28 @@ final class ComplexTests: XCTestCase {
153153 }
154154
155155 func testVjpInit( ) {
156- let pb = pullback ( at: 4 , - 3 ) { r, i in
156+ var pb = pullback ( at: 4 , - 3 ) { r, i in
157157 return Complex < Float > ( real: r, imaginary: i)
158158 }
159- XCTAssertEqual ( ( - 1 , 2 ) , pb ( Complex < Float > ( real: - 1 , imaginary: 2 ) ) )
159+ var tanTuple = pb ( Complex < Float > ( real: - 1 , imaginary: 2 ) )
160+ XCTAssertEqual ( - 1 , tanTuple. 0 )
161+ XCTAssertEqual ( 2 , tanTuple. 1 )
162+
163+ pb = pullback ( at: 4 , - 3 ) { r, i in
164+ return Complex < Float > ( real: r * r, imaginary: i + i)
165+ }
166+ tanTuple = pb ( Complex < Float > ( real: - 1 , imaginary: 1 ) )
167+ XCTAssertEqual ( - 8 , tanTuple. 0 )
168+ XCTAssertEqual ( 2 , tanTuple. 1 )
160169 }
161170
162171 func testVjpAdd( ) {
163172 let pb : ( Complex < Float > ) -> Complex < Float > =
164173 pullback ( at: Complex < Float > ( real: 2 , imaginary: 3 ) ) { x in
165174 return x + Complex < Float > ( real: 5 , imaginary: 6 )
166175 }
167- XCTAssertEqual ( pb ( Complex ( real: 1 , imaginary: 1 ) ) , Complex < Float > ( real: 1 , imaginary: 1 ) )
176+ XCTAssertEqual ( pb ( Complex ( real: 1 , imaginary: 1 ) ) ,
177+ Complex < Float > ( real: 1 , imaginary: 1 ) )
168178 }
169179
170180 func testVjpSubtract( ) {
@@ -316,7 +326,7 @@ final class ComplexTests: XCTestCase {
316326 }
317327
318328 XCTAssertEqual ( - 2 , result)
319- XCTAssertEqual ( Complex ( real: 1 , imaginary: 1 ) , pbComplex ( 1 ) )
329+ XCTAssertEqual ( Complex ( real: 1 , imaginary: 0 ) , pbComplex ( 1 ) )
320330 }
321331
322332 static var allTests = [
0 commit comments