Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify pullback calls in the reverse mode #802

Merged
merged 10 commits into from
Mar 18, 2024
Merged
9 changes: 5 additions & 4 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,12 @@ namespace clad {
/// otherwise returns false.
bool HasAnyReferenceOrPointerArgument(const clang::FunctionDecl* FD);

/// Returns true if `T` is a reference, pointer or array type.
/// Returns true if `arg` is an argument passed by reference or is of
/// pointer/array type.
///
/// \note Please note that this function returns true for array types as
/// well.
bool IsReferenceOrPointerType(clang::QualType T);
/// \note Please note that this function returns false for temporary
/// expressions.
bool IsReferenceOrPointerArg(const clang::Expr* arg);
parth-07 marked this conversation as resolved.
Show resolved Hide resolved

/// Returns true if `T1` and `T2` have same cononical type; otherwise
/// returns false.
Expand Down
15 changes: 7 additions & 8 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] CallArgs The orignal call arguments of the function call.
/// \param[in] ArgResultDecls The differentiated call arguments.
/// \param[in] numArgs The number of call args.
void EmitNestedFunctionParamError(
clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, size_t numArgs);
void
EmitNestedFunctionParamError(clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::Expr*>& ArgResult,
size_t numArgs);

/// Checks if a variable should be considered in error estimation.
///
Expand Down Expand Up @@ -181,16 +182,14 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
virtual void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {}
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {}

Check warning on line 130 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L130

Added line #L130 was not covered by tests

/// This is called just before finalising processing of post and pre
/// increment and decrement operations.
Expand Down Expand Up @@ -157,7 +157,7 @@

virtual void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {}
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {}

Check warning on line 160 in include/clad/Differentiator/ExternalRMVSource.h

View check run for this annotation

Codecov / codecov/patch

include/clad/Differentiator/ExternalRMVSource.h#L160

Added line #L160 was not covered by tests

virtual void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
Expand Down
6 changes: 2 additions & 4 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override;
void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) override;
Expand All @@ -60,8 +59,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
10 changes: 7 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,20 @@ namespace clad {
///
/// \param[in] targetFuncCall The function to get the derivative for.
/// \param[in] retType The return type of the target call expression.
/// \param[in] dfdx The dfdx corresponding to this call expression.
/// \param[in] numArgs The total number of 'args'.
/// \param[in] NumericalDiffMultiArg The built statements to add to block
/// later.
/// \param[in] PreCallStmts The built statements to add to block
/// before the call to the derived function.
/// \param[in] PostCallStmts The built statements to add to block
/// after the call to the derived function.
/// \param[in] args All the arguments to the target function.
/// \param[in] outputArgs The output gradient arguments.
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
PetroZarytskyi marked this conversation as resolved.
Show resolved Hide resolved
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
llvm::SmallVectorImpl<clang::Stmt*>& NumericalDiffMultiArg,
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
Expand Down
8 changes: 6 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,12 @@ namespace clad {
return false;
}

bool IsReferenceOrPointerType(QualType T) {
return T->isReferenceType() || isArrayOrPointerType(T);
bool IsReferenceOrPointerArg(const Expr* arg) {
// The argument is passed by reference if it's passed as an L-value.
// However, if arg is a MaterializeTemporaryExpr, then arg is a
// temporary variable passed as a const reference.
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why is lvalue a reference type?

int a = b; // a is an l-value, but not a reference.
int &a_ref = a; // a_ref is an l-value and a reference.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arg is supposed to be the argument expression passed to the function. If the function expects a ref-type argument, then arg is an l-value (usually a DeclRefExpr). But when it expects a non-ref type argument, it is implicitly converted to an r-value. The AST of arg will look somewhat like this:

ImplicitCastExpr <l-value to r-value>
-DeclRefExpr

So arg will be an r-value. At least this is my understanding.

return isRefType || isArrayOrPointerType(arg->getType());
}

bool SameCanonicalType(clang::QualType T1, clang::QualType T2) {
Expand Down
10 changes: 5 additions & 5 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void ErrorEstimationHandler::SaveReturnExpr(Expr* retExpr) {

void ErrorEstimationHandler::EmitNestedFunctionParamError(
FunctionDecl* fnDecl, llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, size_t numArgs) {
llvm::SmallVectorImpl<Expr*>& ArgResult, size_t numArgs) {
assert(fnDecl && "Must have a value");
for (size_t i = 0; i < numArgs; i++) {
if (!fnDecl->getParamDecl(0)->getType()->isLValueReferenceType())
Expand All @@ -109,7 +109,7 @@ void ErrorEstimationHandler::EmitNestedFunctionParamError(
// if (utils::IsReferenceOrPointerType(fnDecl->getParamDecl(i)->getType()))
// continue;
Expr* errorExpr = m_EstModel->AssignError(
{derivedCallArgs[i], m_RMV->BuildDeclRef(ArgResultDecls[i])},
{derivedCallArgs[i], m_RMV->Clone(ArgResult[i])},
fnDecl->getNameInfo().getAsString() + "_param_" + std::to_string(i));
Expr* errorStmt = m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr);
m_ReverseErrorStmts.push_back(errorStmt);
Expand Down Expand Up @@ -372,7 +372,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) {
void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<Expr*>& ArgResult, bool asGrad) {
if (OverloadedDerivedFn && asGrad) {
// Derivative was found.
FunctionDecl* fnDecl =
Expand All @@ -382,7 +382,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
// in the input prameters (if of reference type) to call and save to
// emit them later.

EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResultDecls,
EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResult,
CE->getNumArgs());
}
}
Expand Down Expand Up @@ -416,7 +416,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleExpr(

void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<Stmt*>& ArgDecls, bool hasAssignee) {
auto errorRef =
m_RMV->BuildVarDecl(m_RMV->m_Context.DoubleTy, "_t",
m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy));
Expand Down
8 changes: 4 additions & 4 deletions lib/Differentiator/MultiplexExternalRMVSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingVisitReturnStmt(
void MultiplexExternalRMVSource::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {
for (auto source : m_Sources) {
source->ActBeforeFinalizingVisitCallExpr(CE, OverloadedDerivedFn, derivedCallArgs,
ArgResultDecls, asGrad);
source->ActBeforeFinalizingVisitCallExpr(
CE, OverloadedDerivedFn, derivedCallArgs, ArgResult, asGrad);
}
}

Expand Down Expand Up @@ -199,7 +199,7 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingDifferentiateSingleExpr(

void MultiplexExternalRMVSource::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {
for (auto source : m_Sources)
source->ActBeforeDifferentiatingCallExpr(pullbackArgs, ArgDecls,
hasAssignee);
Expand Down
Loading
Loading