@@ -43,78 +43,89 @@ class SILDifferentiabilityWitness
4343{
4444private:
4545 // / The module which contains the differentiability witness.
46- SILModule &module ;
46+ SILModule &Module ;
4747 // / The linkage of the differentiability witness.
48- SILLinkage linkage ;
48+ SILLinkage Linkage ;
4949 // / The original function.
50- SILFunction *originalFunction ;
50+ SILFunction *OriginalFunction ;
5151 // / The autodiff configuration: parameter indices, result indices, derivative
5252 // / generic signature (optional).
53- AutoDiffConfig config ;
53+ AutoDiffConfig Config ;
5454 // / The JVP (Jacobian-vector products) derivative function.
55- SILFunction *jvp ;
55+ SILFunction *JVP ;
5656 // / The VJP (vector-Jacobian products) derivative function.
57- SILFunction *vjp;
57+ SILFunction *VJP;
58+ // / Whether or not this differentiability witness is a declaration.
59+ bool IsDeclaration;
5860 // / Whether or not this differentiability witness is serialized, which allows
5961 // / devirtualization from another module.
60- bool serialized ;
62+ bool IsSerialized ;
6163 // / The AST `@differentiable` or `@differentiating` attribute from which the
6264 // / differentiability witness is generated. Used for diagnostics.
6365 // / Null if the differentiability witness is parsed from SIL or if it is
6466 // / deserialized.
65- DeclAttribute *attribute = nullptr ;
67+ DeclAttribute *Attribute = nullptr ;
6668
6769 SILDifferentiabilityWitness (SILModule &module , SILLinkage linkage,
6870 SILFunction *originalFunction,
6971 IndexSubset *parameterIndices,
7072 IndexSubset *resultIndices,
7173 GenericSignature derivativeGenSig,
7274 SILFunction *jvp, SILFunction *vjp,
73- bool isSerialized, DeclAttribute *attribute)
74- : module (module ), linkage(linkage), originalFunction(originalFunction),
75- config (parameterIndices, resultIndices, derivativeGenSig.getPointer()),
76- jvp(jvp), vjp(vjp), serialized(isSerialized), attribute(attribute) {}
75+ bool isDeclaration, bool isSerialized,
76+ DeclAttribute *attribute)
77+ : Module(module ), Linkage(linkage), OriginalFunction(originalFunction),
78+ Config (parameterIndices, resultIndices, derivativeGenSig.getPointer()),
79+ JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
80+ IsSerialized(isSerialized), Attribute(attribute) {}
7781
7882public:
79- static SILDifferentiabilityWitness *create (
83+ static SILDifferentiabilityWitness *createDeclaration (
84+ SILModule &module , SILLinkage linkage, SILFunction *originalFunction,
85+ IndexSubset *parameterIndices, IndexSubset *resultIndices,
86+ GenericSignature derivativeGenSig, DeclAttribute *attribute = nullptr );
87+
88+ static SILDifferentiabilityWitness *createDefinition (
8089 SILModule &module , SILLinkage linkage, SILFunction *originalFunction,
8190 IndexSubset *parameterIndices, IndexSubset *resultIndices,
8291 GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
8392 bool isSerialized, DeclAttribute *attribute = nullptr );
8493
8594 SILDifferentiabilityWitnessKey getKey () const ;
86- SILModule &getModule () const { return module ; }
87- SILLinkage getLinkage () const { return linkage ; }
88- SILFunction *getOriginalFunction () const { return originalFunction ; }
89- const AutoDiffConfig &getConfig () const { return config ; }
95+ SILModule &getModule () const { return Module ; }
96+ SILLinkage getLinkage () const { return Linkage ; }
97+ SILFunction *getOriginalFunction () const { return OriginalFunction ; }
98+ const AutoDiffConfig &getConfig () const { return Config ; }
9099 IndexSubset *getParameterIndices () const {
91- return config .parameterIndices ;
100+ return Config .parameterIndices ;
92101 }
93102 IndexSubset *getResultIndices () const {
94- return config .resultIndices ;
103+ return Config .resultIndices ;
95104 }
96105 GenericSignature getDerivativeGenericSignature () const {
97- return config .derivativeGenericSignature ;
106+ return Config .derivativeGenericSignature ;
98107 }
99- SILFunction *getJVP () const { return jvp ; }
100- SILFunction *getVJP () const { return vjp ; }
108+ SILFunction *getJVP () const { return JVP ; }
109+ SILFunction *getVJP () const { return VJP ; }
101110 SILFunction *getDerivative (AutoDiffDerivativeFunctionKind kind) const {
102111 switch (kind) {
103- case AutoDiffDerivativeFunctionKind::JVP: return jvp ;
104- case AutoDiffDerivativeFunctionKind::VJP: return vjp ;
112+ case AutoDiffDerivativeFunctionKind::JVP: return JVP ;
113+ case AutoDiffDerivativeFunctionKind::VJP: return VJP ;
105114 }
106115 }
107- void setJVP (SILFunction *jvp) { this -> jvp = jvp; }
108- void setVJP (SILFunction *vjp) { this -> vjp = vjp; }
116+ void setJVP (SILFunction *jvp) { JVP = jvp; }
117+ void setVJP (SILFunction *vjp) { VJP = vjp; }
109118 void setDerivative (AutoDiffDerivativeFunctionKind kind,
110119 SILFunction *derivative) {
111120 switch (kind) {
112- case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break ;
113- case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break ;
121+ case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break ;
122+ case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break ;
114123 }
115124 }
116- bool isSerialized () const { return serialized; }
117- DeclAttribute *getAttribute () const { return attribute; }
125+ bool isDeclaration () const { return IsDeclaration; }
126+ bool isDefinition () const { return !IsDeclaration; }
127+ bool isSerialized () const { return IsSerialized; }
128+ DeclAttribute *getAttribute () const { return Attribute; }
118129
119130 // / Verify that the differentiability witness is well-formed.
120131 void verify (const SILModule &module ) const ;
0 commit comments