Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,24 @@ class SILFunctionType;
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
enum class SILLinkage : uint8_t;

enum class DifferentiabilityKind: uint8_t {
enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Linear = 0b11
};

// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
Original = 0,
JVP = 1,
VJP = 2
};

enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
Original = 0,
Transpose = 1
};

class ParsedAutoDiffParameter {
public:
enum class Kind { Named, Ordered, Self };
Expand Down
9 changes: 7 additions & 2 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4220,14 +4220,19 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

CanSILFunctionType getWithoutDifferentiability();

/// Returns the type of a differentiation function that is associated with
/// a function of this type.
/// Returns the type of the derivative function.
CanSILFunctionType getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFunctionGenericSignature = nullptr);

/// Returns the type of the transpose function.
CanSILFunctionType getAutoDiffTransposeFunctionType(
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFunctionGenericSignature = nullptr);

/// Returns a bit vector that specifices which parameters you can
/// differentiate with respect to for this differentiable function type. (e.g.
/// which parameters are not `@nondiff`). The function type must be
Expand Down
245 changes: 201 additions & 44 deletions lib/IRGen/GenDiffFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,28 @@
using namespace swift;
using namespace irgen;

using DiffFuncIndex = DifferentiableFunctionExtractee;

//----------------------------------------------------------------------------//
// `@differentiable` (non-linear) function type info
//----------------------------------------------------------------------------//
namespace {
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
class DifferentiableFuncFieldInfo final
: public RecordField<DifferentiableFuncFieldInfo> {
public:
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
IndexSubset *parameterIndices)
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
DifferentiableFuncFieldInfo(
DifferentiableFunctionExtractee component, const TypeInfo &type,
IndexSubset *parameterIndices)
: RecordField(type), component(component),
parameterIndices(parameterIndices) {}

/// The field index.
const DiffFuncIndex Index;
const DifferentiableFunctionExtractee component;

/// The parameter indices.
IndexSubset *ParameterIndices;
IndexSubset *parameterIndices;

std::string getFieldName() const {
switch (Index) {
switch (component) {
case DifferentiableFunctionExtractee::Original:
return "original";
case DifferentiableFunctionExtractee::JVP:
Expand All @@ -61,32 +66,32 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
SILType getType(IRGenModule &IGM, SILType t) const {
auto fnTy = t.castTo<SILFunctionType>();
auto origFnTy = fnTy->getWithoutDifferentiability();
if (Index == DifferentiableFunctionExtractee::Original)
if (component == DifferentiableFunctionExtractee::Original)
return SILType::getPrimitiveObjectType(origFnTy);
auto kind = *Index.getExtracteeAsDerivativeFunction();
auto kind = *component.getExtracteeAsDerivativeFunction();
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
ParameterIndices, /*resultIndex*/ 0, kind,
parameterIndices, /*resultIndex*/ 0, kind,
IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(assocTy);
}
};

class DiffFuncTypeInfo final
: public RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo,
DiffFuncFieldInfo> {
class DifferentiableFuncTypeInfo final
: public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
DifferentiableFuncFieldInfo> {
using super =
RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, DiffFuncFieldInfo>;
RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo, DifferentiableFuncFieldInfo>;

public:
DiffFuncTypeInfo(ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
llvm::Type *ty, Size size, SpareBitVector &&spareBits,
Alignment align, IsPOD_t isPOD,
IsFixedSize_t alwaysFixedSize)
DifferentiableFuncTypeInfo(
ArrayRef<DifferentiableFuncFieldInfo> fields, unsigned explosionSize,
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
isPOD, alwaysFixedSize) {}

Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
const DiffFuncFieldInfo &field) const {
const DifferentiableFuncFieldInfo &field) const {
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
}

Expand All @@ -110,50 +115,52 @@ class DiffFuncTypeInfo final
}
};

class DiffFuncTypeBuilder
: public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo,
DiffFuncIndex> {
class DifferentiableFuncTypeBuilder
: public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo,
DifferentiableFunctionExtractee> {

SILFunctionType *origFnTy;
SILFunctionType *originalType;
IndexSubset *parameterIndices;

public:
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
: RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()),
DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
: RecordTypeBuilder(IGM),
originalType(fnTy->getWithoutDifferentiability()),
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
assert(fnTy->isDifferentiable());
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal);
}

TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
TypeInfo *createFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
StructLayout &&layout) {
llvm_unreachable("@differentiable functions are always loadable");
}

DiffFuncTypeInfo *createLoadable(ArrayRef<DiffFuncFieldInfo> fields,
StructLayout &&layout,
unsigned explosionSize) {
return DiffFuncTypeInfo::create(
DifferentiableFuncTypeInfo *createLoadable(
ArrayRef<DifferentiableFuncFieldInfo> fields, StructLayout &&layout,
unsigned explosionSize) {
return DifferentiableFuncTypeInfo::create(
fields, explosionSize, layout.getType(), layout.getSize(),
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
layout.isAlwaysFixedSize());
}

TypeInfo *createNonFixed(ArrayRef<DiffFuncFieldInfo> fields,
TypeInfo *createNonFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
FieldsAreABIAccessible_t fieldsAccessible,
StructLayout &&layout) {
llvm_unreachable("@differentiable functions are always loadable");
}

DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
const TypeInfo &fieldTI) {
return DiffFuncFieldInfo(field, fieldTI, parameterIndices);
DifferentiableFuncFieldInfo getFieldInfo(
unsigned index, DifferentiableFunctionExtractee component,
const TypeInfo &fieldTI) {
return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices);
}

SILType getType(DiffFuncIndex field) {
if (field == DifferentiableFunctionExtractee::Original)
return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType());
auto kind = *field.getExtracteeAsDerivativeFunction();
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
SILType getType(DifferentiableFunctionExtractee component) {
if (component == DifferentiableFunctionExtractee::Original)
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
auto kind = *component.getExtracteeAsDerivativeFunction();
auto assocTy = originalType->getAutoDiffDerivativeFunctionType(
parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(),
LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(assocTy);
Expand All @@ -166,11 +173,161 @@ class DiffFuncTypeBuilder
};
} // end anonymous namespace

//----------------------------------------------------------------------------//
// `@differentiable(linear)` function type info
//----------------------------------------------------------------------------//
namespace {
class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
public:
LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component,
const TypeInfo &type, IndexSubset *parameterIndices)
: RecordField(type), component(component),
parameterIndices(parameterIndices) {}

/// The field index.
const LinearDifferentiableFunctionTypeComponent component;

/// The parameter indices.
IndexSubset *parameterIndices;

std::string getFieldName() const {
switch (component) {
case LinearDifferentiableFunctionTypeComponent::Original:
return "original";
case LinearDifferentiableFunctionTypeComponent::Transpose:
return "transpose";
}
}

SILType getType(IRGenModule &IGM, SILType t) const {
auto fnTy = t.castTo<SILFunctionType>();
auto origFnTy = fnTy->getWithoutDifferentiability();
switch (component) {
case LinearDifferentiableFunctionTypeComponent::Original:
return SILType::getPrimitiveObjectType(origFnTy);
case LinearDifferentiableFunctionTypeComponent::Transpose:
auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType(
parameterIndices, IGM.getSILTypes(),
LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(transposeTy);
}
}
};

class LinearFuncTypeInfo final
: public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
LinearFuncFieldInfo> {
using super =
RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;

public:
LinearFuncTypeInfo(
ArrayRef<LinearFuncFieldInfo> fields, unsigned explosionSize,
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
isPOD, alwaysFixedSize) {}

Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
const LinearFuncFieldInfo &field) const {
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
}

void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
SILType T, bool isOutlined) const override {
llvm_unreachable("unexploded @differentiable function as argument?");
}

void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
Size offset) const override {
for (auto &field : getFields()) {
auto fieldOffset = offset + field.getFixedByteOffset();
cast<LoadableTypeInfo>(field.getTypeInfo())
.addToAggLowering(IGM, lowering, fieldOffset);
}
}

llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
return None;
}
};

class LinearFuncTypeBuilder
: public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
LinearDifferentiableFunctionTypeComponent> {

SILFunctionType *originalType;
IndexSubset *parameterIndices;

public:
LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
: RecordTypeBuilder(IGM),
originalType(fnTy->getWithoutDifferentiability()),
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
}

TypeInfo *createFixed(ArrayRef<LinearFuncFieldInfo> fields,
StructLayout &&layout) {
llvm_unreachable("@differentiable functions are always loadable");
}

LinearFuncTypeInfo *createLoadable(ArrayRef<LinearFuncFieldInfo> fields,
StructLayout &&layout,
unsigned explosionSize) {
return LinearFuncTypeInfo::create(
fields, explosionSize, layout.getType(), layout.getSize(),
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
layout.isAlwaysFixedSize());
}

TypeInfo *createNonFixed(ArrayRef<LinearFuncFieldInfo> fields,
FieldsAreABIAccessible_t fieldsAccessible,
StructLayout &&layout) {
llvm_unreachable("@differentiable functions are always loadable");
}

LinearFuncFieldInfo getFieldInfo(
unsigned index, LinearDifferentiableFunctionTypeComponent field,
const TypeInfo &fieldTI) {
return LinearFuncFieldInfo(field, fieldTI, parameterIndices);
}

SILType getType(LinearDifferentiableFunctionTypeComponent component) {
switch (component) {
case LinearDifferentiableFunctionTypeComponent::Original:
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
case LinearDifferentiableFunctionTypeComponent::Transpose:
auto transposeTy = originalType->getAutoDiffTransposeFunctionType(
parameterIndices, IGM.getSILTypes(),
LookUpConformanceInModule(IGM.getSwiftModule()));
return SILType::getPrimitiveObjectType(transposeTy);
}
}

StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
LayoutStrategy::Universal, fieldTypes);
}
};
} // end anonymous namespace

//----------------------------------------------------------------------------//
// Type converter entry points
//----------------------------------------------------------------------------//

const TypeInfo *
TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) {
assert(type->isDifferentiable());
DiffFuncTypeBuilder builder(IGM, type);
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
DifferentiableFuncTypeBuilder builder(IGM, type);
return builder.layout({DifferentiableFunctionExtractee::Original,
DifferentiableFunctionExtractee::JVP,
DifferentiableFunctionExtractee::VJP});
}

const TypeInfo *
TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) {
LinearFuncTypeBuilder builder(IGM, type);
return builder.layout({LinearDifferentiableFunctionTypeComponent::Original,
LinearDifferentiableFunctionTypeComponent::Transpose});
}
10 changes: 8 additions & 2 deletions lib/IRGen/GenFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,14 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF,

const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) {
// SWIFT_ENABLE_TENSORFLOW
if (T->isDifferentiable())
return convertDifferentiableFunctionType(T);
switch (T->getDifferentiabilityKind()) {
case DifferentiabilityKind::Normal:
return convertNormalDifferentiableFunctionType(T);
case DifferentiabilityKind::Linear:
return convertLinearDifferentiableFunctionType(T);
case DifferentiabilityKind::NonDifferentiable:
break;
}

switch (T->getRepresentation()) {
case SILFunctionType::Representation::Block:
Expand Down
3 changes: 2 additions & 1 deletion lib/IRGen/GenType.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ class TypeConverter {
const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D);
const TypeInfo *convertFunctionType(SILFunctionType *T);
// SWIFT_ENABLE_TENSORFLOW
const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T);
const TypeInfo *convertNormalDifferentiableFunctionType(SILFunctionType *T);
const TypeInfo *convertLinearDifferentiableFunctionType(SILFunctionType *T);
const TypeInfo *convertBlockStorageType(SILBlockStorageType *T);
const TypeInfo *convertBoxType(SILBoxType *T);
const TypeInfo *convertArchetypeType(ArchetypeType *T);
Expand Down
Loading