Skip to content

Commit 6335b99

Browse files
authored
[AutoDiff] Update TypeBase::getAutoDiffAssociatedVectorSpace to handle function types. (#22166)
Formally, when a type `T where T : Differentiable` gets abstracted as a function `(X...) -> T` for any `X...`, the differentiability of the abstracted type depends entirely on the differentiability of `T`. Since structural types cannot conform to protocols yet in Swift, we need to handle in AD-associated type calculation the same way we handle tuples. The type calculation rules are better described as code, in imaginary syntax where parameterized extensions, variadic generic parameters, and protocol conformances for structural types are supported. ```swift extension<T..., U> ((T...) -> U) : Differentiable where U : Differentiable { public typealias TangentVector = (T...) -> U.TangentVector public typealias CotangentVector = (T...) -> U.CotangentVector public func moved(along direction: TangentVector) -> (T...) -> U { return { (x...) in self(x...).moved(along: direction(x...)) } } public func tangentVector(from cotangent: CotangentVector) -> TangentVector { return { (x...) in self(x...).tangentVector(from: cotangent(x...)) } } } ``` This is a crucial step towards the correct typing of curried differentiable functions, which helps us differentiate through curry thunks for methods. ```swift func curry<T : Differentiable, U : Differentiable>( f: @autodiff (T, U) -> V ) -> @autodiff (T) -> @autodiff (U) -> V { return { x in { y in f(x, y) } } } ``` Partially resolves [SR-9448](https://bugs.swift.org/browse/SR-9448), which needs this patch to be able to calculate the associate vector space of a curried function.
1 parent 0ae53cc commit 6335b99

File tree

4 files changed

+51
-17
lines changed

4 files changed

+51
-17
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,12 @@ class VectorSpace {
359359
public:
360360
/// A tangent space kind.
361361
enum class Kind {
362-
/// A type that conforms to `VectorNumeric`.
362+
/// A type that conforms to `AdditiveArithmetic`.
363363
Vector,
364364
/// A product of vector spaces as a tuple.
365365
Tuple,
366+
/// A function type whose innermost result conforms to `AdditiveArithmetic`.
367+
Function
366368
};
367369

368370
private:
@@ -372,9 +374,12 @@ class VectorSpace {
372374
Type vectorType;
373375
// Tuple
374376
TupleType *tupleType;
377+
// Function
378+
AnyFunctionType *functionType;
375379

376380
Value(Type vectorType) : vectorType(vectorType) {}
377381
Value(TupleType *tupleType) : tupleType(tupleType) {}
382+
Value(AnyFunctionType *functionType) : functionType(functionType) {}
378383
} value;
379384

380385
VectorSpace(Kind kind, Value value)
@@ -389,6 +394,9 @@ class VectorSpace {
389394
static VectorSpace getTuple(TupleType *tupleTy) {
390395
return {Kind::Tuple, tupleTy};
391396
}
397+
static VectorSpace getFunction(AnyFunctionType *fnTy) {
398+
return {Kind::Function, fnTy};
399+
}
392400

393401
bool isVector() const { return kind == Kind::Vector; }
394402
bool isTuple() const { return kind == Kind::Tuple; }
@@ -402,6 +410,10 @@ class VectorSpace {
402410
assert(kind == Kind::Tuple);
403411
return value.tupleType;
404412
}
413+
AnyFunctionType *getFunction() const {
414+
assert(kind == Kind::Function);
415+
return value.functionType;
416+
}
405417

406418
Type getType() const;
407419
CanType getCanonicalType() const;

lib/AST/AutoDiff.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ Type VectorSpace::getType() const {
289289
return value.vectorType;
290290
case Kind::Tuple:
291291
return value.tupleType;
292+
case Kind::Function:
293+
return value.functionType;
292294
}
293295
}
294296

lib/AST/Type.cpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4089,6 +4089,21 @@ Type TypeBase::openAnyExistentialType(ArchetypeType *&opened) {
40894089
}
40904090

40914091
// SWIFT_ENABLE_TENSORFLOW
4092+
// Makes a function with the same generic signature and extinfo as `copy`, but
4093+
// with `params` parameters and `retTy` return type.
4094+
static AnyFunctionType *
4095+
makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
4096+
Type retTy, GenericSignature *whereClauseGenSig) {
4097+
auto genericSignature = whereClauseGenSig;
4098+
if (!genericSignature)
4099+
if (auto *genericFunctionType = copy->getAs<GenericFunctionType>())
4100+
genericSignature = genericFunctionType->getGenericSignature();
4101+
if (genericSignature)
4102+
return GenericFunctionType::get(genericSignature, params, retTy,
4103+
copy->getExtInfo());
4104+
return FunctionType::get(params, retTy, copy->getExtInfo());
4105+
}
4106+
40924107
Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
40934108
AutoDiffAssociatedVectorSpaceKind kind,
40944109
LookupConformanceFn lookupConformance) {
@@ -4104,7 +4119,19 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
41044119
return vs;
41054120
};
41064121

4107-
// Tuples' Tangent/Cotangent is a tuple of each element's Tangent/Cotangent.
4122+
// Functions' tangent/cotangent is the same function except the innermost
4123+
// return type being replaced by its tangent/cotangent.
4124+
if (auto *fnTy = getAs<AnyFunctionType>()) {
4125+
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedVectorSpace(
4126+
kind, lookupConformance);
4127+
if (!resultSpace)
4128+
return cache(None);
4129+
return VectorSpace::getFunction(
4130+
makeFunctionType(fnTy, fnTy->getParams(), resultSpace->getType(),
4131+
fnTy->getOptGenericSignature()));
4132+
}
4133+
4134+
// Tuples' tangent/cotangent is a tuple of each element's Tangent/Cotangent.
41084135
if (auto *tupleTy = getAs<TupleType>()) {
41094136
SmallVector<TupleTypeElt, 8> newElts;
41104137
for (auto elt : tupleTy->getElements()) {
@@ -4148,21 +4175,6 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
41484175
return cache(None);
41494176
}
41504177

4151-
// Makes a function with the same generic signature and extinfo as `copy`, but
4152-
// with `params` parameters and `retTy` return type.
4153-
static AnyFunctionType *
4154-
makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
4155-
Type retTy, GenericSignature *whereClauseGenSig) {
4156-
auto genericSignature = whereClauseGenSig;
4157-
if (!genericSignature)
4158-
if (auto *genericFunctionType = copy->getAs<GenericFunctionType>())
4159-
genericSignature = genericFunctionType->getGenericSignature();
4160-
if (genericSignature)
4161-
return GenericFunctionType::get(genericSignature, params, retTy,
4162-
copy->getExtInfo());
4163-
return FunctionType::get(params, retTy, copy->getExtInfo());
4164-
}
4165-
41664178
AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
41674179
AutoDiffParameterIndices *indices, unsigned resultIndex,
41684180
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3799,6 +3799,10 @@ SILValue AdjointEmitter::accumulateMaterializedAdjointsDirect(SILValue lhs,
37993799
}
38003800
return builder.createTuple(loc, adjointTy, adjElements);
38013801
}
3802+
case VectorSpace::Kind::Function: {
3803+
llvm_unreachable(
3804+
"Unimplemented: Emit thunks for abstracting adjoint accumulation");
3805+
}
38023806
}
38033807
}
38043808

@@ -3854,6 +3858,10 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
38543858
}
38553859
return;
38563860
}
3861+
case VectorSpace::Kind::Function: {
3862+
llvm_unreachable(
3863+
"Unimplemented: Emit thunks for abstracting adjoint accumulation");
3864+
}
38573865
}
38583866
}
38593867

0 commit comments

Comments
 (0)