32
32
using namespace swift ;
33
33
using namespace irgen ;
34
34
35
- using DiffFuncIndex = DifferentiableFunctionExtractee;
36
35
36
+ // ----------------------------------------------------------------------------//
37
+ // `@differentiable` (non-linear) function type info
38
+ // ----------------------------------------------------------------------------//
37
39
namespace {
38
- class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo > {
40
+ class DifferentiableFuncFieldInfo final : public RecordField<DifferentiableFuncFieldInfo > {
39
41
public:
40
- DiffFuncFieldInfo (DiffFuncIndex index, const TypeInfo &type,
41
- IndexSubset *parameterIndices)
42
- : RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
42
+ DifferentiableFuncFieldInfo (
43
+ DifferentiableFunctionExtractee component, const TypeInfo &type,
44
+ IndexSubset *parameterIndices)
45
+ : RecordField(type), component(component),
46
+ parameterIndices (parameterIndices) {}
43
47
44
48
// / The field index.
45
- const DiffFuncIndex Index ;
49
+ const DifferentiableFunctionExtractee component ;
46
50
47
51
// / The parameter indices.
48
- IndexSubset *ParameterIndices ;
52
+ IndexSubset *parameterIndices ;
49
53
50
54
std::string getFieldName () const {
51
- switch (Index ) {
55
+ switch (component ) {
52
56
case DifferentiableFunctionExtractee::Original:
53
57
return " original" ;
54
58
case DifferentiableFunctionExtractee::JVP:
@@ -61,32 +65,32 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
61
65
SILType getType (IRGenModule &IGM, SILType t) const {
62
66
auto fnTy = t.castTo <SILFunctionType>();
63
67
auto origFnTy = fnTy->getWithoutDifferentiability ();
64
- if (Index == DifferentiableFunctionExtractee::Original)
68
+ if (component == DifferentiableFunctionExtractee::Original)
65
69
return SILType::getPrimitiveObjectType (origFnTy);
66
- auto kind = *Index .getExtracteeAsDerivativeFunction ();
70
+ auto kind = *component .getExtracteeAsDerivativeFunction ();
67
71
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType (
68
- ParameterIndices , /* resultIndex*/ 0 , kind,
72
+ parameterIndices , /* resultIndex*/ 0 , kind,
69
73
IGM.getSILTypes (), LookUpConformanceInModule (IGM.getSwiftModule ()));
70
74
return SILType::getPrimitiveObjectType (assocTy);
71
75
}
72
76
};
73
77
74
- class DiffFuncTypeInfo final
75
- : public RecordTypeInfo<DiffFuncTypeInfo , LoadableTypeInfo,
76
- DiffFuncFieldInfo > {
78
+ class DifferentiableFuncTypeInfo final
79
+ : public RecordTypeInfo<DifferentiableFuncTypeInfo , LoadableTypeInfo,
80
+ DifferentiableFuncFieldInfo > {
77
81
using super =
78
- RecordTypeInfo<DiffFuncTypeInfo , LoadableTypeInfo, DiffFuncFieldInfo >;
82
+ RecordTypeInfo<DifferentiableFuncTypeInfo , LoadableTypeInfo, DifferentiableFuncFieldInfo >;
79
83
80
84
public:
81
- DiffFuncTypeInfo (ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
82
- llvm::Type *ty, Size size, SpareBitVector &&spareBits ,
83
- Alignment align, IsPOD_t isPOD ,
84
- IsFixedSize_t alwaysFixedSize)
85
+ DifferentiableFuncTypeInfo (
86
+ ArrayRef<DifferentiableFuncFieldInfo> fields, unsigned explosionSize ,
87
+ llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
88
+ IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
85
89
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
86
90
isPOD, alwaysFixedSize) {}
87
91
88
92
Address projectFieldAddress (IRGenFunction &IGF, Address addr, SILType T,
89
- const DiffFuncFieldInfo &field) const {
93
+ const DifferentiableFuncFieldInfo &field) const {
90
94
return field.projectAddress (IGF, addr, getNonFixedOffsets (IGF, T));
91
95
}
92
96
@@ -110,50 +114,52 @@ class DiffFuncTypeInfo final
110
114
}
111
115
};
112
116
113
- class DiffFuncTypeBuilder
114
- : public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo ,
115
- DiffFuncIndex > {
117
+ class DifferentiableFuncTypeBuilder
118
+ : public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo ,
119
+ DifferentiableFunctionExtractee > {
116
120
117
- SILFunctionType *origFnTy ;
121
+ SILFunctionType *originalType ;
118
122
IndexSubset *parameterIndices;
119
123
120
124
public:
121
- DiffFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
122
- : RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability ()),
125
+ DifferentiableFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
126
+ : RecordTypeBuilder(IGM),
127
+ originalType (fnTy->getWithoutDifferentiability ()),
123
128
parameterIndices(fnTy->getDifferentiationParameterIndices ()) {
124
- assert (fnTy->isDifferentiable () );
129
+ assert (fnTy->getDifferentiabilityKind () == DifferentiabilityKind::Normal );
125
130
}
126
131
127
- TypeInfo *createFixed (ArrayRef<DiffFuncFieldInfo > fields,
132
+ TypeInfo *createFixed (ArrayRef<DifferentiableFuncFieldInfo > fields,
128
133
StructLayout &&layout) {
129
134
llvm_unreachable (" @differentiable functions are always loadable" );
130
135
}
131
136
132
- DiffFuncTypeInfo *createLoadable (ArrayRef<DiffFuncFieldInfo> fields,
133
- StructLayout &&layout,
134
- unsigned explosionSize) {
135
- return DiffFuncTypeInfo ::create (
137
+ DifferentiableFuncTypeInfo *createLoadable (
138
+ ArrayRef<DifferentiableFuncFieldInfo> fields, StructLayout &&layout,
139
+ unsigned explosionSize) {
140
+ return DifferentiableFuncTypeInfo ::create (
136
141
fields, explosionSize, layout.getType (), layout.getSize (),
137
142
std::move (layout.getSpareBits ()), layout.getAlignment (), layout.isPOD (),
138
143
layout.isAlwaysFixedSize ());
139
144
}
140
145
141
- TypeInfo *createNonFixed (ArrayRef<DiffFuncFieldInfo > fields,
146
+ TypeInfo *createNonFixed (ArrayRef<DifferentiableFuncFieldInfo > fields,
142
147
FieldsAreABIAccessible_t fieldsAccessible,
143
148
StructLayout &&layout) {
144
149
llvm_unreachable (" @differentiable functions are always loadable" );
145
150
}
146
151
147
- DiffFuncFieldInfo getFieldInfo (unsigned index, DiffFuncIndex field,
148
- const TypeInfo &fieldTI) {
149
- return DiffFuncFieldInfo (field, fieldTI, parameterIndices);
152
+ DifferentiableFuncFieldInfo getFieldInfo (
153
+ unsigned index, DifferentiableFunctionExtractee component,
154
+ const TypeInfo &fieldTI) {
155
+ return DifferentiableFuncFieldInfo (component, fieldTI, parameterIndices);
150
156
}
151
157
152
- SILType getType (DiffFuncIndex field ) {
153
- if (field == DifferentiableFunctionExtractee::Original)
154
- return SILType::getPrimitiveObjectType (origFnTy ->getCanonicalType ());
155
- auto kind = *field .getExtracteeAsDerivativeFunction ();
156
- auto assocTy = origFnTy ->getAutoDiffDerivativeFunctionType (
158
+ SILType getType (DifferentiableFunctionExtractee component ) {
159
+ if (component == DifferentiableFunctionExtractee::Original)
160
+ return SILType::getPrimitiveObjectType (originalType ->getCanonicalType ());
161
+ auto kind = *component .getExtracteeAsDerivativeFunction ();
162
+ auto assocTy = originalType ->getAutoDiffDerivativeFunctionType (
157
163
parameterIndices, /* resultIndex*/ 0 , kind, IGM.getSILTypes (),
158
164
LookUpConformanceInModule (IGM.getSwiftModule ()));
159
165
return SILType::getPrimitiveObjectType (assocTy);
@@ -166,11 +172,161 @@ class DiffFuncTypeBuilder
166
172
};
167
173
} // end anonymous namespace
168
174
175
+ // ----------------------------------------------------------------------------//
176
+ // `@differentiable(linear)` function type info
177
+ // ----------------------------------------------------------------------------//
178
+ namespace {
179
+ class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
180
+ public:
181
+ LinearFuncFieldInfo (LinearDifferentiableFunctionTypeComponent component,
182
+ const TypeInfo &type, IndexSubset *parameterIndices)
183
+ : RecordField(type), component(component),
184
+ parameterIndices (parameterIndices) {}
185
+
186
+ // / The field index.
187
+ const LinearDifferentiableFunctionTypeComponent component;
188
+
189
+ // / The parameter indices.
190
+ IndexSubset *parameterIndices;
191
+
192
+ std::string getFieldName () const {
193
+ switch (component) {
194
+ case LinearDifferentiableFunctionTypeComponent::Original:
195
+ return " original" ;
196
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
197
+ return " transpose" ;
198
+ }
199
+ }
200
+
201
+ SILType getType (IRGenModule &IGM, SILType t) const {
202
+ auto fnTy = t.castTo <SILFunctionType>();
203
+ auto origFnTy = fnTy->getWithoutDifferentiability ();
204
+ switch (component) {
205
+ case LinearDifferentiableFunctionTypeComponent::Original:
206
+ return SILType::getPrimitiveObjectType (origFnTy);
207
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
208
+ auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType (
209
+ parameterIndices, IGM.getSILTypes (),
210
+ LookUpConformanceInModule (IGM.getSwiftModule ()));
211
+ return SILType::getPrimitiveObjectType (transposeTy);
212
+ }
213
+ }
214
+ };
215
+
216
+ class LinearFuncTypeInfo final
217
+ : public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
218
+ LinearFuncFieldInfo> {
219
+ using super =
220
+ RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;
221
+
222
+ public:
223
+ LinearFuncTypeInfo (
224
+ ArrayRef<LinearFuncFieldInfo> fields, unsigned explosionSize,
225
+ llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
226
+ IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
227
+ : super(fields, explosionSize, ty, size, std::move(spareBits), align,
228
+ isPOD, alwaysFixedSize) {}
229
+
230
+ Address projectFieldAddress (IRGenFunction &IGF, Address addr, SILType T,
231
+ const LinearFuncFieldInfo &field) const {
232
+ return field.projectAddress (IGF, addr, getNonFixedOffsets (IGF, T));
233
+ }
234
+
235
+ void initializeFromParams (IRGenFunction &IGF, Explosion ¶ms, Address src,
236
+ SILType T, bool isOutlined) const override {
237
+ llvm_unreachable (" unexploded @differentiable function as argument?" );
238
+ }
239
+
240
+ void addToAggLowering (IRGenModule &IGM, SwiftAggLowering &lowering,
241
+ Size offset) const override {
242
+ for (auto &field : getFields ()) {
243
+ auto fieldOffset = offset + field.getFixedByteOffset ();
244
+ cast<LoadableTypeInfo>(field.getTypeInfo ())
245
+ .addToAggLowering (IGM, lowering, fieldOffset);
246
+ }
247
+ }
248
+
249
+ llvm::NoneType getNonFixedOffsets (IRGenFunction &IGF) const { return None; }
250
+ llvm::NoneType getNonFixedOffsets (IRGenFunction &IGF, SILType T) const {
251
+ return None;
252
+ }
253
+ };
254
+
255
+ class LinearFuncTypeBuilder
256
+ : public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
257
+ LinearDifferentiableFunctionTypeComponent> {
258
+
259
+ SILFunctionType *originalType;
260
+ IndexSubset *parameterIndices;
261
+
262
+ public:
263
+ LinearFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
264
+ : RecordTypeBuilder(IGM),
265
+ originalType (fnTy->getWithoutDifferentiability ()),
266
+ parameterIndices(fnTy->getDifferentiationParameterIndices ()) {
267
+ assert (fnTy->getDifferentiabilityKind () == DifferentiabilityKind::Linear);
268
+ }
269
+
270
+ TypeInfo *createFixed (ArrayRef<LinearFuncFieldInfo> fields,
271
+ StructLayout &&layout) {
272
+ llvm_unreachable (" @differentiable functions are always loadable" );
273
+ }
274
+
275
+ LinearFuncTypeInfo *createLoadable (ArrayRef<LinearFuncFieldInfo> fields,
276
+ StructLayout &&layout,
277
+ unsigned explosionSize) {
278
+ return LinearFuncTypeInfo::create (
279
+ fields, explosionSize, layout.getType (), layout.getSize (),
280
+ std::move (layout.getSpareBits ()), layout.getAlignment (), layout.isPOD (),
281
+ layout.isAlwaysFixedSize ());
282
+ }
283
+
284
+ TypeInfo *createNonFixed (ArrayRef<LinearFuncFieldInfo> fields,
285
+ FieldsAreABIAccessible_t fieldsAccessible,
286
+ StructLayout &&layout) {
287
+ llvm_unreachable (" @differentiable functions are always loadable" );
288
+ }
289
+
290
+ LinearFuncFieldInfo getFieldInfo (
291
+ unsigned index, LinearDifferentiableFunctionTypeComponent field,
292
+ const TypeInfo &fieldTI) {
293
+ return LinearFuncFieldInfo (field, fieldTI, parameterIndices);
294
+ }
295
+
296
+ SILType getType (LinearDifferentiableFunctionTypeComponent component) {
297
+ switch (component) {
298
+ case LinearDifferentiableFunctionTypeComponent::Original:
299
+ return SILType::getPrimitiveObjectType (originalType->getCanonicalType ());
300
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
301
+ auto transposeTy = originalType->getAutoDiffTransposeFunctionType (
302
+ parameterIndices, IGM.getSILTypes (),
303
+ LookUpConformanceInModule (IGM.getSwiftModule ()));
304
+ return SILType::getPrimitiveObjectType (transposeTy);
305
+ }
306
+ }
307
+
308
+ StructLayout performLayout (ArrayRef<const TypeInfo *> fieldTypes) {
309
+ return StructLayout (IGM, /* decl=*/ nullptr , LayoutKind::NonHeapObject,
310
+ LayoutStrategy::Universal, fieldTypes);
311
+ }
312
+ };
313
+ } // end anonymous namespace
314
+
315
+ // ----------------------------------------------------------------------------//
316
+ // Type converter entry points
317
+ // ----------------------------------------------------------------------------//
318
+
169
319
const TypeInfo *
170
- TypeConverter::convertDifferentiableFunctionType (SILFunctionType *type) {
171
- assert (type->isDifferentiable ());
172
- DiffFuncTypeBuilder builder (IGM, type);
320
+ TypeConverter::convertNormalDifferentiableFunctionType (SILFunctionType *type) {
321
+ DifferentiableFuncTypeBuilder builder (IGM, type);
173
322
return builder.layout ({DifferentiableFunctionExtractee::Original,
174
323
DifferentiableFunctionExtractee::JVP,
175
324
DifferentiableFunctionExtractee::VJP});
176
325
}
326
+
327
+ const TypeInfo *
328
+ TypeConverter::convertLinearDifferentiableFunctionType (SILFunctionType *type) {
329
+ LinearFuncTypeBuilder builder (IGM, type);
330
+ return builder.layout ({LinearDifferentiableFunctionTypeComponent::Original,
331
+ LinearDifferentiableFunctionTypeComponent::Transpose});
332
+ }
0 commit comments