Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 135 additions & 142 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,141 @@ class OriginallyDefinedInAttr: public DeclAttribute {
}
};

/// A declaration name with location.
struct DeclNameWithLoc {
DeclName Name;
DeclNameLoc Loc;
};

/// Attribute that marks a function as differentiable and optionally specifies
/// custom associated derivative functions: 'jvp' and 'vjp'.
///
/// Examples:
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// Whether this function is linear (optional).
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameWithLoc> JVP;
/// The VJP function.
Optional<DeclNameWithLoc> VJP;
/// The JVP function (optional), resolved by the type checker if JVP name is
/// specified.
FuncDecl *JVPFunction = nullptr;
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
/// type checker based on the original function's generic signature and the
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
GenericSignature DerivativeGenericSignature;

explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenericSignature);

public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(AbstractFunctionDecl *original,
bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Optional<DeclNameWithLoc> getJVP() const { return JVP; }

/// Get the optional 'vjp:' function name and location.
/// Use this instead of `getVJPFunction` to check whether the attribute has a
/// registered VJP.
Optional<DeclNameWithLoc> getVJP() const { return VJP; }

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *parameterIndices) {
ParameterIndices = parameterIndices;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

bool isLinear() const { return Linear; }

TrailingWhereClause *getWhereClause() const { return WhereClause; }

GenericSignature getDerivativeGenericSignature() const {
return DerivativeGenericSignature;
}
void setDerivativeGenericSignature(GenericSignature derivativeGenSig) {
DerivativeGenericSignature = derivativeGenSig;
}

FuncDecl *getJVPFunction() const { return JVPFunction; }
void setJVPFunction(FuncDecl *decl);
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
GenericEnvironment *
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitAssociatedFunctions` is true, omit printing associated functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
bool omitAssociatedFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
}
};

/// Attributes that may be applied to declarations.
class DeclAttributes {
/// Linked list of declaration attributes.
Expand Down Expand Up @@ -1764,148 +1899,6 @@ class DeclAttributes {
SourceLoc getStartLoc(bool forModifiers = false) const;
};

/// A declaration name with location.
struct DeclNameWithLoc {
DeclName Name;
DeclNameLoc Loc;
};

/// Attribute that marks a function as differentiable and optionally specifies
/// custom associated derivative functions: 'jvp' and 'vjp'.
///
/// Examples:
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// Whether this function is linear (optional).
bool linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Optional<DeclNameWithLoc> JVP;
/// The VJP function.
Optional<DeclNameWithLoc> VJP;
/// The JVP function (optional), resolved by the type checker if JVP name is
/// specified.
FuncDecl *JVPFunction = nullptr;
/// The VJP function (optional), resolved by the type checker if VJP name is
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
/// type checker based on the original function's generic signature and the
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
GenericSignature DerivativeGenericSignature;

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenericSignature);

public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause);

static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig);

/// Get the optional 'jvp:' function name and location.
/// Use this instead of `getJVPFunction` to check whether the attribute has a
/// registered JVP.
Optional<DeclNameWithLoc> getJVP() const { return JVP; }

/// Get the optional 'vjp:' function name and location.
/// Use this instead of `getVJPFunction` to check whether the attribute has a
/// registered VJP.
Optional<DeclNameWithLoc> getVJP() const { return VJP; }

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
}

/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

bool isLinear() const { return linear; }

TrailingWhereClause *getWhereClause() const { return WhereClause; }

GenericSignature getDerivativeGenericSignature() const {
return DerivativeGenericSignature;
}
void setDerivativeGenericSignature(ASTContext &context,
GenericSignature derivativeGenSig) {
DerivativeGenericSignature = derivativeGenSig;
}

FuncDecl *getJVPFunction() const { return JVPFunction; }
void setJVPFunction(FuncDecl *decl);
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);

bool parametersMatch(const DifferentiableAttr &other) const {
assert(ParameterIndices && other.ParameterIndices);
return ParameterIndices == other.ParameterIndices;
}

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
GenericEnvironment *
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitAssociatedFunctions` is true, omit printing associated functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
bool omitAssociatedFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
}
};


void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);

inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {
Expand Down
36 changes: 18 additions & 18 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,31 +1352,30 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
specializedSignature);
}

DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
VJP(std::move(vjp)), WhereClause(clause) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
IndexSubset *indices,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
ParameterIndices(indices) {
setDerivativeGenericSignature(context, derivativeGenSig);
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
setParameterIndices(parameterIndices);
setDerivativeGenericSignature(derivativeGenSig);
}

DifferentiableAttr *
Expand All @@ -1389,22 +1388,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
TrailingWhereClause *clause) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
linear, parameters, std::move(jvp),
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
parameters, std::move(jvp),
std::move(vjp), clause);
}

DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, IndexSubset *indices,
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
SourceLoc atLoc, SourceRange baseRange, bool linear,
IndexSubset *parameterIndices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig) {
void *mem = context.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
linear, indices, std::move(jvp),
auto &ctx = original->getASTContext();
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
linear, parameterIndices, std::move(jvp),
std::move(vjp), derivativeGenSig);
}

Expand Down
Loading