Skip to content

Commit

Permalink
Make GlobalStoreAndRef consistent in usage inside and outside of loop…
Browse files Browse the repository at this point in the history
…s (and consistent with StoreAndRef)
  • Loading branch information
PetroZarytskyi committed May 23, 2024
1 parent 4126fa7 commit 9565bdb
Show file tree
Hide file tree
Showing 14 changed files with 399 additions and 538 deletions.
22 changes: 11 additions & 11 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,17 @@ namespace clad {
/// it into m_Globals block (to be inserted into the beginning of fn's
/// body). Returns reference R to the created declaration. If E is not null,
/// puts an additional assignment statement (R = E) in the forward block.
/// Alternatively, if isInsideLoop is true, stores E in a stack. Returns
/// StmtDiff, where .getExpr() is intended to be used in forward pass and
/// .getExpr_dx() in the reverse pass. Two expressions can be different in
/// some cases, e.g. clad::push/pop inside loops.
StmtDiff GlobalStoreAndRef(clang::Expr* E,
clang::QualType Type,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
/// Alternatively, if isInsideLoop is true, stores E in a stack S. Puts a
/// push statement (clad::push(S, E)) in the forward block and a pop
/// statement
/// ((clad::pop(S))) in the reverse block. Returns a reference to the top
/// of the stack (clad::back(S)).
clang::Expr* GlobalStoreAndRef(clang::Expr* E, clang::QualType Type,
llvm::StringRef prefix = "_t",
bool force = false);
clang::Expr* GlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t",
bool force = false);
StmtDiff BuildPushPop(clang::Expr* E, clang::QualType Type,
llvm::StringRef prefix = "_t", bool force = false);
StmtDiff StoreAndRestore(clang::Expr* E, llvm::StringRef prefix = "_t",
Expand Down
331 changes: 83 additions & 248 deletions lib/Differentiator/ReverseModeVisitor.cpp

Large diffs are not rendered by default.

38 changes: 19 additions & 19 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ float func(float* a, float* b) {
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, a[i]);
Expand Down Expand Up @@ -95,7 +95,7 @@ float func2(float* a) {
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down Expand Up @@ -129,7 +129,7 @@ float func3(float* a, float* b) {
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down Expand Up @@ -168,7 +168,7 @@ double func4(double x) {
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down Expand Up @@ -226,14 +226,14 @@ double func5(int k) {
//CHECK-NEXT: double _d_arr[n];
//CHECK-NEXT: clad::zero_init(_d_arr, n);
//CHECK-NEXT: double arr[n];
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[i]);
//CHECK-NEXT: arr[i] = k;
//CHECK-NEXT: }
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: _t2 = {{0U|0UL}};
//CHECK-NEXT: for (i0 = 0; i0 < 3; i0++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
Expand Down Expand Up @@ -283,7 +283,7 @@ double func6(double seed) {
//CHECK-NEXT: clad::array<double> arr({{3U|3UL}});
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr) , arr = {seed, seed * i, seed + i};
Expand Down Expand Up @@ -338,7 +338,7 @@ double func7(double *params) {
//CHECK-NEXT: clad::array<double> paramsPrime({{1U|1UL}});
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double out = 0.;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 1; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, paramsPrime) , paramsPrime = {params[0]};
Expand Down Expand Up @@ -438,7 +438,7 @@ double func9(double i, double j) {
//CHECK-NEXT: int idx = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[5] = {};
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[idx]);
Expand All @@ -456,11 +456,11 @@ double func9(double i, double j) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --idx;
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = clad::pop(_t1);
//CHECK-NEXT: arr[idx] = _r0;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: modify_pullback(_r0, i, &_d_arr[idx], &_r1);
//CHECK-NEXT: *_d_i += _r1;
//CHECK-NEXT: arr[idx] = clad::back(_t1);
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: modify_pullback(clad::back(_t1), i, &_d_arr[idx], &_r0);
//CHECK-NEXT: clad::pop(_t1);
//CHECK-NEXT: *_d_i += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -489,7 +489,7 @@ double func10(double *arr, int n) {
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
Expand All @@ -504,9 +504,9 @@ double func10(double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: double _r0 = clad::pop(_t2);
//CHECK-NEXT: arr[i] = _r0;
//CHECK-NEXT: sq_pullback(_r0, _r_d0, &_d_arr[i]);
//CHECK-NEXT: arr[i] = clad::back(_t2);
//CHECK-NEXT: sq_pullback(clad::back(_t2), _r_d0, &_d_arr[i]);
//CHECK-NEXT: clad::pop(_t2);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -584,7 +584,7 @@ int main() {
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, ret);
Expand Down
2 changes: 1 addition & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) {
//CHECK-NEXT: double _t5;
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double t = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < dim; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
Expand Down
6 changes: 3 additions & 3 deletions test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ float func(float* p, int n) {
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: unsigned {{int|long}} p_size = 0;
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down Expand Up @@ -67,7 +67,7 @@ float func2(float x) {
//CHECK-NEXT: float m = 0;
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float z;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 9; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, m) , m = x * x;
Expand Down Expand Up @@ -168,7 +168,7 @@ float func4(float x[10], float y[10]) {
//CHECK-NEXT: unsigned {{int|long}} y_size = 0;
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < 10; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, x[i]);
Expand Down
6 changes: 3 additions & 3 deletions test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ double runningSum(float* f, int n) {
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: unsigned {{int|long}} f_size = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 1; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down Expand Up @@ -72,7 +72,7 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: unsigned {{int|long}} b_size = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, {{0U|0UL}});
Expand Down Expand Up @@ -131,7 +131,7 @@ double divSum(float* a, float* b, int n) {
//CHECK-NEXT: unsigned {{int|long}} b_size = 0;
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
Expand Down
12 changes: 6 additions & 6 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ double fn4(double* arr, int n) {
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: _t0 = res;
// CHECK-NEXT: res += sum(arr, n);
// CHECK-NEXT: _t1 = 0;
// CHECK-NEXT: _t1 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; i < n; ++i) {
// CHECK-NEXT: _t1++;
// CHECK-NEXT: clad::push(_t2, arr[i]);
Expand All @@ -190,9 +190,9 @@ double fn4(double* arr, int n) {
// CHECK-NEXT: _d_arr[i] += _r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r1 = clad::pop(_t2);
// CHECK-NEXT: arr[i] = _r1;
// CHECK-NEXT: twice_pullback(_r1, &_d_arr[i]);
// CHECK-NEXT: arr[i] = clad::back(_t2);
// CHECK-NEXT: twice_pullback(clad::back(_t2), &_d_arr[i]);
// CHECK-NEXT: clad::pop(_t2);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
Expand Down Expand Up @@ -494,7 +494,7 @@ double fn13(double* x, const double* w) {
// CHECK-NEXT: std::size_t i = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: double wCopy[2];
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; i < 2; ++i) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, wCopy[i]);
Expand Down Expand Up @@ -834,7 +834,7 @@ double sq_defined_later(double x) {
// CHECK-NEXT: clad::tape<float> _t1 = {};
// CHECK-NEXT: double _t2;
// CHECK-NEXT: float res = 0;
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; i < n; ++i) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, res);
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ float running_sum(float* p, int n) {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: int i = 0;
// CHECK-NEXT: clad::tape<float> _t1 = {};
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 1; i < n; i++) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, p[i]);
Expand Down
Loading

0 comments on commit 9565bdb

Please sign in to comment.