Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,12 @@ class VectorSpace {
public:
/// A tangent space kind.
enum class Kind {
/// A type that conforms to `VectorNumeric`.
/// A type that conforms to `AdditiveArithmetic`.
Vector,
/// A product of vector spaces as a tuple.
Tuple,
/// A function type whose innermost result conforms to `AdditiveArithmetic`.
Function
};

private:
Expand All @@ -372,9 +374,12 @@ class VectorSpace {
Type vectorType;
// Tuple
TupleType *tupleType;
// Function
AnyFunctionType *functionType;

Value(Type vectorType) : vectorType(vectorType) {}
Value(TupleType *tupleType) : tupleType(tupleType) {}
Value(AnyFunctionType *functionType) : functionType(functionType) {}
} value;

VectorSpace(Kind kind, Value value)
Expand All @@ -389,6 +394,9 @@ class VectorSpace {
static VectorSpace getTuple(TupleType *tupleTy) {
return {Kind::Tuple, tupleTy};
}
static VectorSpace getFunction(AnyFunctionType *fnTy) {
return {Kind::Function, fnTy};
}

bool isVector() const { return kind == Kind::Vector; }
bool isTuple() const { return kind == Kind::Tuple; }
Expand All @@ -402,6 +410,10 @@ class VectorSpace {
assert(kind == Kind::Tuple);
return value.tupleType;
}
AnyFunctionType *getFunction() const {
assert(kind == Kind::Function);
return value.functionType;
}

Type getType() const;
CanType getCanonicalType() const;
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ Type VectorSpace::getType() const {
return value.vectorType;
case Kind::Tuple:
return value.tupleType;
case Kind::Function:
return value.functionType;
}
}

Expand Down
44 changes: 28 additions & 16 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4089,6 +4089,21 @@ Type TypeBase::openAnyExistentialType(ArchetypeType *&opened) {
}

// SWIFT_ENABLE_TENSORFLOW
// Makes a function with the same generic signature and extinfo as `copy`, but
// with `params` parameters and `retTy` return type.
static AnyFunctionType *
makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
Type retTy, GenericSignature *whereClauseGenSig) {
auto genericSignature = whereClauseGenSig;
if (!genericSignature)
if (auto *genericFunctionType = copy->getAs<GenericFunctionType>())
genericSignature = genericFunctionType->getGenericSignature();
if (genericSignature)
return GenericFunctionType::get(genericSignature, params, retTy,
copy->getExtInfo());
return FunctionType::get(params, retTy, copy->getExtInfo());
}

Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind kind,
LookupConformanceFn lookupConformance) {
Expand All @@ -4104,7 +4119,19 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
return vs;
};

// Tuples' Tangent/Cotangent is a tuple of each element's Tangent/Cotangent.
// Functions' tangent/cotangent is the same function except the innermost
// return type being replaced by its tangent/cotangent.
if (auto *fnTy = getAs<AnyFunctionType>()) {
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedVectorSpace(
kind, lookupConformance);
if (!resultSpace)
return cache(None);
return VectorSpace::getFunction(
makeFunctionType(fnTy, fnTy->getParams(), resultSpace->getType(),
fnTy->getOptGenericSignature()));
}

// Tuples' tangent/cotangent is a tuple of each element's Tangent/Cotangent.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
for (auto elt : tupleTy->getElements()) {
Expand Down Expand Up @@ -4148,21 +4175,6 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
return cache(None);
}

// Makes a function with the same generic signature and extinfo as `copy`, but
// with `params` parameters and `retTy` return type.
static AnyFunctionType *
makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
Type retTy, GenericSignature *whereClauseGenSig) {
auto genericSignature = whereClauseGenSig;
if (!genericSignature)
if (auto *genericFunctionType = copy->getAs<GenericFunctionType>())
genericSignature = genericFunctionType->getGenericSignature();
if (genericSignature)
return GenericFunctionType::get(genericSignature, params, retTy,
copy->getExtInfo());
return FunctionType::get(params, retTy, copy->getExtInfo());
}

AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
AutoDiffParameterIndices *indices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
Expand Down
8 changes: 8 additions & 0 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3819,6 +3819,10 @@ SILValue AdjointEmitter::accumulateMaterializedAdjointsDirect(SILValue lhs,
}
return builder.createTuple(loc, adjointTy, adjElements);
}
case VectorSpace::Kind::Function: {
llvm_unreachable(
"Unimplemented: Emit thunks for abstracting adjoint accumulation");
}
}
}

Expand Down Expand Up @@ -3874,6 +3878,10 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
}
return;
}
case VectorSpace::Kind::Function: {
llvm_unreachable(
"Unimplemented: Emit thunks for abstracting adjoint accumulation");
}
}
}

Expand Down