You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[AutoDiff] Update TypeBase::getAutoDiffAssociatedVectorSpace to handle function types. (#22166)
Formally, when a type `T where T : Differentiable` gets abstracted as a function `(X...) -> T` for any `X...`, the differentiability of the abstracted type depends entirely on the differentiability of `T`. Since structural types cannot conform to protocols yet in Swift, we need to handle in AD-associated type calculation the same way we handle tuples.
The type calculation rules are better described as code, in imaginary syntax where parameterized extensions, variadic generic parameters, and protocol conformances for structural types are supported.
```swift
extension<T..., U> ((T...) -> U) : Differentiable where U : Differentiable {
public typealias TangentVector = (T...) -> U.TangentVector
public typealias CotangentVector = (T...) -> U.CotangentVector
public func moved(along direction: TangentVector) -> (T...) -> U {
return { (x...) in self(x...).moved(along: direction(x...)) }
}
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
return { (x...) in self(x...).tangentVector(from: cotangent(x...)) }
}
}
```
This is a crucial step towards the correct typing of curried differentiable functions, which helps us differentiate through curry thunks for methods.
```swift
func curry<T : Differentiable, U : Differentiable>(
f: @autodiff (T, U) -> V
) -> @autodiff (T) -> @autodiff (U) -> V {
return { x in { y in f(x, y) } }
}
```
Partially resolves [SR-9448](https://bugs.swift.org/browse/SR-9448), which needs this patch to be able to calculate the associate vector space of a curried function.
0 commit comments