Skip to content

Commit

Permalink
Make pow pushforward fn numerically stable
Browse files Browse the repository at this point in the history
This commit makes the clad::custom_derivatives::std::pow_pushforward function
numerically stable by adding the directional derivative w.r.t. the exponent
only if the directional seed of the exponent is non-zero. As a result, the
function gives the correct directional derivative in cases where log(base)
is undefined but the directional derivative w.r.t. the exponent is not
requested -- in which case log(base) should never be evaluated.

Closes #507
  • Loading branch information
parth-07 authored and vgvassilev committed Nov 20, 2022
1 parent e658d27 commit 6250d7d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 34 deletions.
13 changes: 10 additions & 3 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,16 @@ template <typename T1, typename T2>
// Pushforward of ::std::pow: returns pow(x, exponent) together with its
// directional derivative for the seeds d_x (base) and d_exponent (exponent).
CUDA_HOST_DEVICE ValueAndPushforward<decltype(::std::pow(T1(), T2())),
decltype(::std::pow(T1(), T2()))>
pow_pushforward(T1 x, T2 exponent, T1 d_x, T2 d_exponent) {
// NOTE(review): in this diff view the next three lines appear to be the
// pre-commit implementation (the file header reports 3 deletions); the
// statements after them are the replacement -- confirm against the repository.
return {::std::pow(x, exponent),
(exponent * ::std::pow(x, exponent - 1)) * d_x +
(::std::pow(x, exponent) * ::std::log(x)) * d_exponent};
auto val = ::std::pow(x, exponent);
// Contribution of the base: d/dx x^e = e * x^(e - 1), scaled by the seed d_x.
auto derivative = (exponent * ::std::pow(x, exponent - 1)) * d_x;
// Add the contribution of the exponent, x^e * log(x) * d_exponent, only when
// the seed d_exponent is non-zero. For x <= 0, log(x) is not finite (NaN for
// x < 0, -infinity for x == 0), and in IEEE 754 arithmetic multiplying a
// non-finite value by a zero seed yields NaN. Adding that term unconditionally
// would therefore turn an otherwise valid derivative w.r.t. the base into NaN
// even though the caller never requested the derivative w.r.t. the exponent.
if (d_exponent)
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent;
return {val, derivative};
}

template <typename T>
Expand Down
48 changes: 46 additions & 2 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: %cladclang %s -I%S/../../include -Xclang -verify -oBuiltinDerivatives.out -lm 2>&1 | FileCheck %s
// RUN: %cladclang %s -I%S/../../include -Xclang -verify -oBuiltinDerivatives.out -lm -lstdc++ 2>&1 | FileCheck %s
// RUN: ./BuiltinDerivatives.out | FileCheck -check-prefix=CHECK-EXEC %s

//CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

#include "../TestUtils.h"
extern "C" int printf(const char* fmt, ...);

float f1(float x) {
Expand Down Expand Up @@ -207,6 +207,46 @@ void f10_grad(float x, int y, clad::array_ref<float> _d_x, clad::array_ref<int>
// CHECK-NEXT: }
// CHECK-NEXT: }

// Rosenbrock function (1 - x)^2 + 100 * (y - x^2)^2, written with nested
// std::pow calls so that its gradient goes through pow's custom derivative.
// The expression is kept exactly as written because the generated-code
// expectations that follow mirror its structure.
double f11(double x, double y) {
return std::pow((1.-x),2) + 100. * std::pow(y-std::pow(x,2),2);
}

// CHECK: void f11_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: typename {{.*}} _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: _t0 = (1. - x);
// CHECK-NEXT: _t2 = x;
// CHECK-NEXT: _t3 = y - std::pow(_t2, 2);
// CHECK-NEXT: _t1 = std::pow(_t3, 2);
// CHECK-NEXT: double f11_return = std::pow(_t0, 2) + 100. * _t1;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: int _grad1 = 0;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t0, 2, 1, &_grad0, &_grad1);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_x += -_r0;
// CHECK-NEXT: int _r1 = _grad1;
// CHECK-NEXT: double _r2 = 1 * _t1;
// CHECK-NEXT: double _r3 = 100. * 1;
// CHECK-NEXT: double _grad4 = 0.;
// CHECK-NEXT: int _grad5 = 0;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t3, 2, _r3, &_grad4, &_grad5);
// CHECK-NEXT: double _r4 = _grad4;
// CHECK-NEXT: * _d_y += _r4;
// CHECK-NEXT: double _grad2 = 0.;
// CHECK-NEXT: int _grad3 = 0;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t2, 2, -_r4, &_grad2, &_grad3);
// CHECK-NEXT: double _r5 = _grad2;
// CHECK-NEXT: * _d_x += _r5;
// CHECK-NEXT: int _r6 = _grad3;
// CHECK-NEXT: int _r7 = _grad5;
// CHECK-NEXT: }
// CHECK-NEXT: }

int main () { //expected-no-diagnostics
float f_result[2];
double d_result[2];
Expand Down Expand Up @@ -269,5 +309,9 @@ int main () { //expected-no-diagnostics
f10_grad(3, 4, &f_result[0], &i_result[0]);
printf("Result is = {%f, %d}\n", f_result[0], i_result[0]); //CHECK-EXEC: Result is = {108.000000, 88}

INIT_GRADIENT(f11);

TEST_GRADIENT(f11, /*numOfDerivativeArgs=*/2, -1, 1, &d_result[0], &d_result[1]); // CHECK-EXEC: {-4.00, 0.00}

return 0;
}
78 changes: 49 additions & 29 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,15 @@ float f4(float x) {
// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d_y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_exponent, clad::array_ref<float> _d_d_x, clad::array_ref<float> _d_d_exponent) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: float _t1;
// CHECK-NEXT: float _d_val = 0;
// CHECK-NEXT: float _t2;
// CHECK-NEXT: float _t3;
// CHECK-NEXT: float _t4;
// CHECK-NEXT: float _t5;
// CHECK-NEXT: float _t6;
// CHECK-NEXT: float _t7;
// CHECK-NEXT: float _d_derivative = 0;
// CHECK-NEXT: float _cond0;
// CHECK-NEXT: float _t8;
// CHECK-NEXT: float _t9;
// CHECK-NEXT: float _t10;
Expand All @@ -262,43 +265,35 @@ float f4(float x) {
// CHECK-NEXT: float _t14;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: _t1 = exponent;
// CHECK-NEXT: float val = ::std::pow(_t0, _t1);
// CHECK-NEXT: _t4 = exponent;
// CHECK-NEXT: _t5 = x;
// CHECK-NEXT: _t6 = exponent - 1;
// CHECK-NEXT: _t3 = ::std::pow(_t5, _t6);
// CHECK-NEXT: _t7 = (_t4 * _t3);
// CHECK-NEXT: _t2 = d_x;
// CHECK-NEXT: _t10 = x;
// CHECK-NEXT: _t11 = exponent;
// CHECK-NEXT: _t12 = ::std::pow(_t10, _t11);
// CHECK-NEXT: _t13 = x;
// CHECK-NEXT: _t9 = ::std::log(_t13);
// CHECK-NEXT: _t14 = (_t12 * _t9);
// CHECK-NEXT: _t8 = d_exponent;
// CHECK-NEXT: float derivative = _t7 * _t2;
// CHECK-NEXT: _cond0 = d_exponent;
// CHECK-NEXT: if (_cond0) {
// CHECK-NEXT: _t10 = x;
// CHECK-NEXT: _t11 = exponent;
// CHECK-NEXT: _t12 = ::std::pow(_t10, _t11);
// CHECK-NEXT: _t13 = x;
// CHECK-NEXT: _t9 = ::std::log(_t13);
// CHECK-NEXT: _t14 = (_t12 * _t9);
// CHECK-NEXT: _t8 = d_exponent;
// CHECK-NEXT: derivative += _t14 * _t8;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _grad0 = 0.F;
// CHECK-NEXT: float _grad1 = 0.F;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t0, _t1, _d_y.value, &_grad0, &_grad1);
// CHECK-NEXT: float _r0 = _grad0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: float _r1 = _grad1;
// CHECK-NEXT: * _d_exponent += _r1;
// CHECK-NEXT: float _r2 = _d_y.pushforward * _t2;
// CHECK-NEXT: float _r3 = _r2 * _t3;
// CHECK-NEXT: * _d_exponent += _r3;
// CHECK-NEXT: float _r4 = _t4 * _r2;
// CHECK-NEXT: float _grad2 = 0.F;
// CHECK-NEXT: float _grad3 = 0.F;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t5, _t6, _r4, &_grad2, &_grad3);
// CHECK-NEXT: float _r5 = _grad2;
// CHECK-NEXT: * _d_x += _r5;
// CHECK-NEXT: float _r6 = _grad3;
// CHECK-NEXT: * _d_exponent += _r6;
// CHECK-NEXT: float _r7 = _t7 * _d_y.pushforward;
// CHECK-NEXT: * _d_d_x += _r7;
// CHECK-NEXT: float _r8 = _d_y.pushforward * _t8;
// CHECK-NEXT: _d_val += _d_y.value;
// CHECK-NEXT: _d_derivative += _d_y.pushforward;
// CHECK-NEXT: }
// CHECK-NEXT: if (_cond0) {
// CHECK-NEXT: float _r_d0 = _d_derivative;
// CHECK-NEXT: _d_derivative += _r_d0;
// CHECK-NEXT: float _r8 = _r_d0 * _t8;
// CHECK-NEXT: float _r9 = _r8 * _t9;
// CHECK-NEXT: float _grad4 = 0.F;
// CHECK-NEXT: float _grad5 = 0.F;
Expand All @@ -310,8 +305,33 @@ float f4(float x) {
// CHECK-NEXT: float _r12 = _t12 * _r8;
// CHECK-NEXT: float _r13 = _r12 * clad::custom_derivatives{{(::std)?}}::log_pushforward(_t13, 1.F).pushforward;
// CHECK-NEXT: * _d_x += _r13;
// CHECK-NEXT: float _r14 = _t14 * _d_y.pushforward;
// CHECK-NEXT: float _r14 = _t14 * _r_d0;
// CHECK-NEXT: * _d_d_exponent += _r14;
// CHECK-NEXT: _d_derivative -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: float _r2 = _d_derivative * _t2;
// CHECK-NEXT: float _r3 = _r2 * _t3;
// CHECK-NEXT: * _d_exponent += _r3;
// CHECK-NEXT: float _r4 = _t4 * _r2;
// CHECK-NEXT: float _grad2 = 0.F;
// CHECK-NEXT: float _grad3 = 0.F;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t5, _t6, _r4, &_grad2, &_grad3);
// CHECK-NEXT: float _r5 = _grad2;
// CHECK-NEXT: * _d_x += _r5;
// CHECK-NEXT: float _r6 = _grad3;
// CHECK-NEXT: * _d_exponent += _r6;
// CHECK-NEXT: float _r7 = _t7 * _d_derivative;
// CHECK-NEXT: * _d_d_x += _r7;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: float _grad0 = 0.F;
// CHECK-NEXT: float _grad1 = 0.F;
// CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(_t0, _t1, _d_val, &_grad0, &_grad1);
// CHECK-NEXT: float _r0 = _grad0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: float _r1 = _grad1;
// CHECK-NEXT: * _d_exponent += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down

0 comments on commit 6250d7d

Please sign in to comment.