Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numerical differentiation support to clad #261

Merged
merged 1 commit into from Aug 22, 2021

Conversation

grimmmyshini
Copy link
Collaborator

@grimmmyshini grimmmyshini commented Jul 15, 2021

This functionality will allow clad to be able to use numerical differentiation as a backup (or even standalone) in the case that a differentiation request is made for a function that is not visible to clad. Full list of possible functionalities is as given below:

PR Checklist:

  • Add support for using numerical diff as a backup for unsupported function calls in forward-mode.
  • Add support for using numerical diff as a backup for unsupported function calls in reverse-mode.
  • Add support for usage of numerical diff as a standalone function (i.e. clad::numerical_diff).
  • Add support for numerical differentiation of multi-arg calls.
  • Extend support of numerical diff to differentiate unsupported operators.
  • Support printing of error information for the numerical methods.

Good to implement:

  • Add support for non-primitive types like user structs, tensors etc.

Known bugs:

  • Standalone code does not work because central_difference is not specialized correctly.

@codecov
Copy link

codecov bot commented Jul 15, 2021

Codecov Report

Merging #261 (d20aedc) into master (b0fa87a) will increase coverage by 0.25%.
The diff coverage is 95.96%.

❗ Current head d20aedc differs from pull request most recent head 53db6d1. Consider uploading reports for the commit 53db6d1 to get more accurate results
Impacted file tree graph

@@            Coverage Diff             @@
##           master     #261      +/-   ##
==========================================
+ Coverage   88.20%   88.46%   +0.25%     
==========================================
  Files          28       28              
  Lines        3636     3735      +99     
==========================================
+ Hits         3207     3304      +97     
- Misses        429      431       +2     
Impacted Files Coverage Δ
include/clad/Differentiator/VisitorBase.h 100.00% <ø> (ø)
lib/Differentiator/ForwardModeVisitor.cpp 90.87% <87.50%> (-0.36%) ⬇️
lib/Differentiator/ReverseModeVisitor.cpp 90.97% <96.15%> (+0.31%) ⬆️
include/clad/Differentiator/DerivativeBuilder.h 100.00% <100.00%> (ø)
lib/Differentiator/DerivativeBuilder.cpp 100.00% <100.00%> (ø)
lib/Differentiator/VisitorBase.cpp 96.28% <100.00%> (+0.60%) ⬆️
tools/ClangPlugin.cpp 88.97% <100.00%> (+0.16%) ⬆️
tools/ClangPlugin.h 73.21% <100.00%> (+1.51%) ⬆️
Impacted Files Coverage Δ
include/clad/Differentiator/VisitorBase.h 100.00% <ø> (ø)
lib/Differentiator/ForwardModeVisitor.cpp 90.87% <87.50%> (-0.36%) ⬇️
lib/Differentiator/ReverseModeVisitor.cpp 90.97% <96.15%> (+0.31%) ⬆️
include/clad/Differentiator/DerivativeBuilder.h 100.00% <100.00%> (ø)
lib/Differentiator/DerivativeBuilder.cpp 100.00% <100.00%> (ø)
lib/Differentiator/VisitorBase.cpp 96.28% <100.00%> (+0.60%) ⬆️
tools/ClangPlugin.cpp 88.97% <100.00%> (+0.16%) ⬆️
tools/ClangPlugin.h 73.21% <100.00%> (+1.51%) ⬆️

@grimmmyshini grimmmyshini marked this pull request as ready for review August 13, 2021 18:35
@grimmmyshini grimmmyshini force-pushed the central-diff branch 2 times, most recently from 7a9ebec to e9bc67d Compare August 17, 2021 16:09
Copy link
Owner

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add tests for the diagnostics branches? Codecov tells where they are.

demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
include/clad/Differentiator/DerivativeBuilder.h Outdated Show resolved Hide resolved
include/clad/Differentiator/NumericalDiff.h Outdated Show resolved Hide resolved
include/clad/Differentiator/VisitorBase.h Outdated Show resolved Hide resolved
include/clad/Differentiator/VisitorBase.h Outdated Show resolved Hide resolved
lib/Differentiator/VisitorBase.cpp Outdated Show resolved Hide resolved
lib/Differentiator/VisitorBase.cpp Outdated Show resolved Hide resolved
test/lit.cfg Show resolved Hide resolved
@grimmmyshini
Copy link
Collaborator Author

Could you add tests for the diagnostics branches? Codecov tells where they are.

That is blocked right now. When clad fails to derive a function, it crashes. Let me put up an issue for this soon.

Copy link
Owner

@vgvassilev vgvassilev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think the more I convince myself that PR, at some stage, should turn into a callback class based implementation. It is an opt-in piece and also somewhat an AD bail out plan. Let's resolve the comments we got so far and discuss how to move forward given the limited amount of time we have left.

demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
demos/CustomTypeNumDiff.cpp Outdated Show resolved Hide resolved
include/clad/Differentiator/NumericalDiff.h Show resolved Hide resolved
include/clad/Differentiator/NumericalDiff.h Outdated Show resolved Hide resolved
lib/Differentiator/ForwardModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/ReverseModeVisitor.cpp Outdated Show resolved Hide resolved
Comment on lines 1373 to 1410
// Try numerically deriving it.
if (NArgs == 1) {
// Build a clone call expression so that we can correctly
// scope the function to be differentiated.
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Clone(CE->getCallee()),
noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs),
noLoc)
.get();
Expr* fnCallee = cast<CallExpr>(call)->getCallee();
OverloadedDerivedFn = GetSingleArgCentralDiffCall(fnCallee,
DerivedCallArgs
[0],
0, 1,
DerivedCallArgs);
asGrad = !OverloadedDerivedFn;
} else {
// Build a clone call expression so that we can correctly
// scope the function to be differentiated.
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(
CallArgs),
noLoc)
.get();
Expr* fnCallee = cast<CallExpr>(call)->getCallee();
OverloadedDerivedFn = GetMultiArgCentralDiffCall(
fnCallee, CEType.getCanonicalType(), CE->getNumArgs(),
NumericalDiffMultiArg, DerivedCallArgs, DerivedCallOutputArgs);
}
if (!OverloadedDerivedFn) {
// Function was not derived => issue a warning.
diag(DiagnosticsEngine::Warning,
CE->getBeginLoc(),
"function '%0' was not differentiated because clad failed to "
"differentiate it and no suitable overload was found in "
"namespace 'custom_derivatives', and function may not be "
"eligible for numerical differentiation.",
{FD->getNameAsString()});
return StmtDiff(Clone(CE));
} else {
diag(DiagnosticsEngine::Remark, noLoc,
"Falling back to numerical differentiation for '%0' since no "
"suitable overload was found and clad could not derive it. "
"To disable this feature, compile your programs with "
"-DCLAD_NO_NUM_DIFF.",
{FD->getNameAsString()});
}
} else {
OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(derivedFD),
noLoc,
llvm::MutableArrayRef<Expr*>(
DerivedCallArgs),
noLoc)
.get();
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this into a separate function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole thing? I don't know if that would be feasible...We will have to send a lot of information as parameters to that function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about putting 1406 - 1423 in a function? Putting all of it will require a function with 8-9 arguments.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vgvassilev How do you think it looks now? If it is not good, I can put it all in a function.

noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs),
noLoc)
.get();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider moving into a separate function.

}

/// A funtion to calculate the derivative of a function using the central
/// difference formula. Note: we do not propogate errors resulting in the
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate here and maybe in the comments.

include/clad/Differentiator/NumericalDiff.h Show resolved Hide resolved
include/clad/Differentiator/NumericalDiff.h Outdated Show resolved Hide resolved
lib/Differentiator/ForwardModeVisitor.cpp Outdated Show resolved Hide resolved
lib/Differentiator/ForwardModeVisitor.cpp Outdated Show resolved Hide resolved
@@ -179,6 +179,12 @@ namespace clad {
estimationPlugin->InstantiateCustomModel(*m_DerivativeBuilder));
}
}

// If enabled, set the proper fields in derivative builder.
if (m_DO.PrintNumDiffErrorInfo) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

test/Jacobian/Pointers.C Outdated Show resolved Hide resolved
Clad can now build calls to a "forward" (single-arg) and "reverse" (multi-arg) numerical diff function. Numerical diff can be used in the following contexts:
- With clad forward and reverse mode for all in-built scalar and non scalar types.
- With clad and standalone for "forward" and "reverse" derivatives.
- Standalone for in-built scalar and non-scalar types.
- Standalone with support for user defined types as input (both value and pointer forms).
- Standalone to differentiate overloaded operators.
- Standalone to differentiate functors.

Numerical diff error estimates can be printed with the help of the '-fprint-num-diff-errors' flag, and numerical diff may be disabled using the '-DCLAD_NO_NUM_DIFF' flag during compilation.
@vgvassilev vgvassilev merged commit 4f1c10e into vgvassilev:master Aug 22, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants