Skip to content

Commit 39211c0

Browse files
committed
[AutoDiff] [IRGen] Lower @differentiable(linear) function types.
1 parent 5e52226 commit 39211c0

File tree

8 files changed

+469
-92
lines changed

8 files changed

+469
-92
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,24 @@ class SILFunctionType;
3939
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
4040
enum class SILLinkage : uint8_t;
4141

42-
enum class DifferentiabilityKind: uint8_t {
42+
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0b00,
4444
Normal = 0b01,
4545
Linear = 0b11
4646
};
4747

48+
// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
49+
enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
50+
Original = 0,
51+
JVP = 1,
52+
VJP = 2
53+
};
54+
55+
enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
56+
Original = 0,
57+
Transpose = 1
58+
};
59+
4860
class ParsedAutoDiffParameter {
4961
public:
5062
enum class Kind { Named, Ordered, Self };

include/swift/AST/Types.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4220,14 +4220,19 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42204220

42214221
CanSILFunctionType getWithoutDifferentiability();
42224222

4223-
/// Returns the type of a differentiation function that is associated with
4224-
/// a function of this type.
4223+
/// Returns the type of the derivative function.
42254224
CanSILFunctionType getAutoDiffDerivativeFunctionType(
42264225
IndexSubset *parameterIndices, unsigned resultIndex,
42274226
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
42284227
LookupConformanceFn lookupConformance,
42294228
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
42304229

4230+
/// Returns the type of the transpose function.
4231+
CanSILFunctionType getAutoDiffTransposeFunctionType(
4232+
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
4233+
LookupConformanceFn lookupConformance,
4234+
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
4235+
42314236
/// Returns a bit vector that specifices which parameters you can
42324237
/// differentiate with respect to for this differentiable function type. (e.g.
42334238
/// which parameters are not `@nondiff`). The function type must be

lib/IRGen/GenDiffFunc.cpp

Lines changed: 200 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,27 @@
3232
using namespace swift;
3333
using namespace irgen;
3434

35-
using DiffFuncIndex = DifferentiableFunctionExtractee;
3635

36+
//----------------------------------------------------------------------------//
37+
// `@differentiable` (non-linear) function type info
38+
//----------------------------------------------------------------------------//
3739
namespace {
38-
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
40+
class DifferentiableFuncFieldInfo final : public RecordField<DifferentiableFuncFieldInfo> {
3941
public:
40-
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
41-
IndexSubset *parameterIndices)
42-
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
42+
DifferentiableFuncFieldInfo(
43+
DifferentiableFunctionExtractee component, const TypeInfo &type,
44+
IndexSubset *parameterIndices)
45+
: RecordField(type), component(component),
46+
parameterIndices(parameterIndices) {}
4347

4448
/// The field index.
45-
const DiffFuncIndex Index;
49+
const DifferentiableFunctionExtractee component;
4650

4751
/// The parameter indices.
48-
IndexSubset *ParameterIndices;
52+
IndexSubset *parameterIndices;
4953

5054
std::string getFieldName() const {
51-
switch (Index) {
55+
switch (component) {
5256
case DifferentiableFunctionExtractee::Original:
5357
return "original";
5458
case DifferentiableFunctionExtractee::JVP:
@@ -61,32 +65,32 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
6165
SILType getType(IRGenModule &IGM, SILType t) const {
6266
auto fnTy = t.castTo<SILFunctionType>();
6367
auto origFnTy = fnTy->getWithoutDifferentiability();
64-
if (Index == DifferentiableFunctionExtractee::Original)
68+
if (component == DifferentiableFunctionExtractee::Original)
6569
return SILType::getPrimitiveObjectType(origFnTy);
66-
auto kind = *Index.getExtracteeAsDerivativeFunction();
70+
auto kind = *component.getExtracteeAsDerivativeFunction();
6771
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
68-
ParameterIndices, /*resultIndex*/ 0, kind,
72+
parameterIndices, /*resultIndex*/ 0, kind,
6973
IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule()));
7074
return SILType::getPrimitiveObjectType(assocTy);
7175
}
7276
};
7377

74-
class DiffFuncTypeInfo final
75-
: public RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo,
76-
DiffFuncFieldInfo> {
78+
class DifferentiableFuncTypeInfo final
79+
: public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
80+
DifferentiableFuncFieldInfo> {
7781
using super =
78-
RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, DiffFuncFieldInfo>;
82+
RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo, DifferentiableFuncFieldInfo>;
7983

8084
public:
81-
DiffFuncTypeInfo(ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
82-
llvm::Type *ty, Size size, SpareBitVector &&spareBits,
83-
Alignment align, IsPOD_t isPOD,
84-
IsFixedSize_t alwaysFixedSize)
85+
DifferentiableFuncTypeInfo(
86+
ArrayRef<DifferentiableFuncFieldInfo> fields, unsigned explosionSize,
87+
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
88+
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
8589
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
8690
isPOD, alwaysFixedSize) {}
8791

8892
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
89-
const DiffFuncFieldInfo &field) const {
93+
const DifferentiableFuncFieldInfo &field) const {
9094
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
9195
}
9296

@@ -110,50 +114,52 @@ class DiffFuncTypeInfo final
110114
}
111115
};
112116

113-
class DiffFuncTypeBuilder
114-
: public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo,
115-
DiffFuncIndex> {
117+
class DifferentiableFuncTypeBuilder
118+
: public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo,
119+
DifferentiableFunctionExtractee> {
116120

117-
SILFunctionType *origFnTy;
121+
SILFunctionType *originalType;
118122
IndexSubset *parameterIndices;
119123

120124
public:
121-
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
122-
: RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()),
125+
DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
126+
: RecordTypeBuilder(IGM),
127+
originalType(fnTy->getWithoutDifferentiability()),
123128
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
124-
assert(fnTy->isDifferentiable());
129+
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal);
125130
}
126131

127-
TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
132+
TypeInfo *createFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
128133
StructLayout &&layout) {
129134
llvm_unreachable("@differentiable functions are always loadable");
130135
}
131136

132-
DiffFuncTypeInfo *createLoadable(ArrayRef<DiffFuncFieldInfo> fields,
133-
StructLayout &&layout,
134-
unsigned explosionSize) {
135-
return DiffFuncTypeInfo::create(
137+
DifferentiableFuncTypeInfo *createLoadable(
138+
ArrayRef<DifferentiableFuncFieldInfo> fields, StructLayout &&layout,
139+
unsigned explosionSize) {
140+
return DifferentiableFuncTypeInfo::create(
136141
fields, explosionSize, layout.getType(), layout.getSize(),
137142
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
138143
layout.isAlwaysFixedSize());
139144
}
140145

141-
TypeInfo *createNonFixed(ArrayRef<DiffFuncFieldInfo> fields,
146+
TypeInfo *createNonFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
142147
FieldsAreABIAccessible_t fieldsAccessible,
143148
StructLayout &&layout) {
144149
llvm_unreachable("@differentiable functions are always loadable");
145150
}
146151

147-
DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
148-
const TypeInfo &fieldTI) {
149-
return DiffFuncFieldInfo(field, fieldTI, parameterIndices);
152+
DifferentiableFuncFieldInfo getFieldInfo(
153+
unsigned index, DifferentiableFunctionExtractee component,
154+
const TypeInfo &fieldTI) {
155+
return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices);
150156
}
151157

152-
SILType getType(DiffFuncIndex field) {
153-
if (field == DifferentiableFunctionExtractee::Original)
154-
return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType());
155-
auto kind = *field.getExtracteeAsDerivativeFunction();
156-
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
158+
SILType getType(DifferentiableFunctionExtractee component) {
159+
if (component == DifferentiableFunctionExtractee::Original)
160+
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
161+
auto kind = *component.getExtracteeAsDerivativeFunction();
162+
auto assocTy = originalType->getAutoDiffDerivativeFunctionType(
157163
parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(),
158164
LookUpConformanceInModule(IGM.getSwiftModule()));
159165
return SILType::getPrimitiveObjectType(assocTy);
@@ -166,11 +172,161 @@ class DiffFuncTypeBuilder
166172
};
167173
} // end anonymous namespace
168174

175+
//----------------------------------------------------------------------------//
176+
// `@differentiable(linear)` function type info
177+
//----------------------------------------------------------------------------//
178+
namespace {
179+
class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
180+
public:
181+
LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component,
182+
const TypeInfo &type, IndexSubset *parameterIndices)
183+
: RecordField(type), component(component),
184+
parameterIndices(parameterIndices) {}
185+
186+
/// The field index.
187+
const LinearDifferentiableFunctionTypeComponent component;
188+
189+
/// The parameter indices.
190+
IndexSubset *parameterIndices;
191+
192+
std::string getFieldName() const {
193+
switch (component) {
194+
case LinearDifferentiableFunctionTypeComponent::Original:
195+
return "original";
196+
case LinearDifferentiableFunctionTypeComponent::Transpose:
197+
return "transpose";
198+
}
199+
}
200+
201+
SILType getType(IRGenModule &IGM, SILType t) const {
202+
auto fnTy = t.castTo<SILFunctionType>();
203+
auto origFnTy = fnTy->getWithoutDifferentiability();
204+
switch (component) {
205+
case LinearDifferentiableFunctionTypeComponent::Original:
206+
return SILType::getPrimitiveObjectType(origFnTy);
207+
case LinearDifferentiableFunctionTypeComponent::Transpose:
208+
auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType(
209+
parameterIndices, IGM.getSILTypes(),
210+
LookUpConformanceInModule(IGM.getSwiftModule()));
211+
return SILType::getPrimitiveObjectType(transposeTy);
212+
}
213+
}
214+
};
215+
216+
class LinearFuncTypeInfo final
217+
: public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
218+
LinearFuncFieldInfo> {
219+
using super =
220+
RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;
221+
222+
public:
223+
LinearFuncTypeInfo(
224+
ArrayRef<LinearFuncFieldInfo> fields, unsigned explosionSize,
225+
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
226+
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
227+
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
228+
isPOD, alwaysFixedSize) {}
229+
230+
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
231+
const LinearFuncFieldInfo &field) const {
232+
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
233+
}
234+
235+
void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
236+
SILType T, bool isOutlined) const override {
237+
llvm_unreachable("unexploded @differentiable function as argument?");
238+
}
239+
240+
void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
241+
Size offset) const override {
242+
for (auto &field : getFields()) {
243+
auto fieldOffset = offset + field.getFixedByteOffset();
244+
cast<LoadableTypeInfo>(field.getTypeInfo())
245+
.addToAggLowering(IGM, lowering, fieldOffset);
246+
}
247+
}
248+
249+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
250+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
251+
return None;
252+
}
253+
};
254+
255+
class LinearFuncTypeBuilder
256+
: public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
257+
LinearDifferentiableFunctionTypeComponent> {
258+
259+
SILFunctionType *originalType;
260+
IndexSubset *parameterIndices;
261+
262+
public:
263+
LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
264+
: RecordTypeBuilder(IGM),
265+
originalType(fnTy->getWithoutDifferentiability()),
266+
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
267+
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
268+
}
269+
270+
TypeInfo *createFixed(ArrayRef<LinearFuncFieldInfo> fields,
271+
StructLayout &&layout) {
272+
llvm_unreachable("@differentiable functions are always loadable");
273+
}
274+
275+
LinearFuncTypeInfo *createLoadable(ArrayRef<LinearFuncFieldInfo> fields,
276+
StructLayout &&layout,
277+
unsigned explosionSize) {
278+
return LinearFuncTypeInfo::create(
279+
fields, explosionSize, layout.getType(), layout.getSize(),
280+
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
281+
layout.isAlwaysFixedSize());
282+
}
283+
284+
TypeInfo *createNonFixed(ArrayRef<LinearFuncFieldInfo> fields,
285+
FieldsAreABIAccessible_t fieldsAccessible,
286+
StructLayout &&layout) {
287+
llvm_unreachable("@differentiable functions are always loadable");
288+
}
289+
290+
LinearFuncFieldInfo getFieldInfo(
291+
unsigned index, LinearDifferentiableFunctionTypeComponent field,
292+
const TypeInfo &fieldTI) {
293+
return LinearFuncFieldInfo(field, fieldTI, parameterIndices);
294+
}
295+
296+
SILType getType(LinearDifferentiableFunctionTypeComponent component) {
297+
switch (component) {
298+
case LinearDifferentiableFunctionTypeComponent::Original:
299+
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
300+
case LinearDifferentiableFunctionTypeComponent::Transpose:
301+
auto transposeTy = originalType->getAutoDiffTransposeFunctionType(
302+
parameterIndices, IGM.getSILTypes(),
303+
LookUpConformanceInModule(IGM.getSwiftModule()));
304+
return SILType::getPrimitiveObjectType(transposeTy);
305+
}
306+
}
307+
308+
StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
309+
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
310+
LayoutStrategy::Universal, fieldTypes);
311+
}
312+
};
313+
} // end anonymous namespace
314+
315+
//----------------------------------------------------------------------------//
316+
// Type converter entry points
317+
//----------------------------------------------------------------------------//
318+
169319
const TypeInfo *
170-
TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) {
171-
assert(type->isDifferentiable());
172-
DiffFuncTypeBuilder builder(IGM, type);
320+
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
321+
DifferentiableFuncTypeBuilder builder(IGM, type);
173322
return builder.layout({DifferentiableFunctionExtractee::Original,
174323
DifferentiableFunctionExtractee::JVP,
175324
DifferentiableFunctionExtractee::VJP});
176325
}
326+
327+
const TypeInfo *
328+
TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) {
329+
LinearFuncTypeBuilder builder(IGM, type);
330+
return builder.layout({LinearDifferentiableFunctionTypeComponent::Original,
331+
LinearDifferentiableFunctionTypeComponent::Transpose});
332+
}

lib/IRGen/GenFunc.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,14 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF,
480480

481481
const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) {
482482
// SWIFT_ENABLE_TENSORFLOW
483-
if (T->isDifferentiable())
484-
return convertDifferentiableFunctionType(T);
483+
switch (T->getDifferentiabilityKind()) {
484+
case DifferentiabilityKind::Normal:
485+
return convertNormalDifferentiableFunctionType(T);
486+
case DifferentiabilityKind::Linear:
487+
return convertLinearDifferentiableFunctionType(T);
488+
case DifferentiabilityKind::NonDifferentiable:
489+
break;
490+
}
485491

486492
switch (T->getRepresentation()) {
487493
case SILFunctionType::Representation::Block:

lib/IRGen/GenType.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class TypeConverter {
138138
const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D);
139139
const TypeInfo *convertFunctionType(SILFunctionType *T);
140140
// SWIFT_ENABLE_TENSORFLOW
141-
const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T);
141+
const TypeInfo *convertNormalDifferentiableFunctionType(SILFunctionType *T);
142+
const TypeInfo *convertLinearDifferentiableFunctionType(SILFunctionType *T);
142143
const TypeInfo *convertBlockStorageType(SILBlockStorageType *T);
143144
const TypeInfo *convertBoxType(SILBoxType *T);
144145
const TypeInfo *convertArchetypeType(ArchetypeType *T);

0 commit comments

Comments
 (0)