Skip to content

Conversation

@rxwei
Copy link
Contributor

@rxwei rxwei commented Jan 28, 2019

Formally, when a type T where T : Differentiable gets abstracted as a function (X...) -> T for any X..., the differentiability of the abstracted type depends entirely on the differentiability of T. Since structural types cannot conform to protocols yet in Swift, we need to handle in AD-associated type calculation the same way we handle tuples.

The type calculation rules are better described as code, in imaginary syntax where parameterized extensions, variadic generic parameters, and protocol conformances for structural types are supported.

extension<T..., U> ((T...) -> U) : Differentiable where U : Differentiable {
  public typealias TangentVector = (T...) -> U.TangentVector
  public typealias CotangentVector = (T...) -> U.CotangentVector
  public func moved(along direction: TangentVector) -> (T...) -> U {
    return { (x...) in self(x...).moved(along: direction(x...)) }
  }
  public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    return { (x...) in self(x...).tangentVector(from: cotangent(x...)) }
  }
}

This is a crucial step towards the correct typing of curried differentiable functions, which helps us differentiate through curry thunks for methods.

func curry<T : Differentiable, U : Differentiable>(
  f: @autodiff (T, U) -> V
) -> @autodiff (T) -> @autodiff (U) -> V {
  return { x in { y in f(x, y) } }
}

Partially resolves SR-9448, which needs this patch to be able to calculate the associate vector space of a curried function.

…dle function types.

Formally, a type `T where T : Differentiable` gets abstracted as a function `(X...) -> T` for any `X...`, the differentiability of the abstracted type depends entirely on the differentiability of `T`. Since structural types cannot conform to functions yet in Swift, we need to handle in AD-associated type calculation the same way we handle tuples.

The type calculation rules are better described as code, in imaginary syntax where parameterized extensions, variadic generic parameters, and protocol conformances for structural types are supported.

```swift
extension<T..., U> ((T...) -> U) : Differentiable where U : Differentiable {
  public typealias TangentVector = (T...) -> U.TangentVector
  public typealias CotangentVector = (T...) -> U.CotangentVector
  public func moved(along direction: TangentVector) -> (T...) -> U {
    return { (x...) in self(x...).moved(along: direction(x...)) }
  }
  public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
    return { (x...) in self(x...).tangentVector(from: cotangent(x...)) }
  }
}
```

This is a crucial step towards the correct typing of curried differentiable functions, which helps us differentiate through curry thunks for methods.

```swift
func curry<T : Differentiable, U : Differentiable>(
  f: @autodiff (T, U) -> V
) -> @autodiff (T) -> @autodiff (U) -> V {
  return { x in { y in f(x, y) } }
}
```
@rxwei rxwei added the tensorflow This is for "tensorflow" branch PRs. label Jan 28, 2019
@rxwei rxwei requested review from marcrasi and pschuh January 28, 2019 06:46
@rxwei
Copy link
Contributor Author

rxwei commented Jan 28, 2019

@swift-ci please test tensorflow

@rxwei
Copy link
Contributor Author

rxwei commented Jan 28, 2019

CC @stephentyrone: I'm not sure the demonstrated body of tangentVector(from:) is mathematically correct because cotangent may not be isomorphic to tangent in function spaces.

Copy link
Contributor

@pschuh pschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

tensorflow This is for "tensorflow" branch PRs.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants