Skip to content

Commit

Permalink
Synthesize a FloatingLiteral if the DRE's type is not integral.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Jul 18, 2014
1 parent 8827fa0 commit f915dd3
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 24 deletions.
42 changes: 31 additions & 11 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,43 @@ namespace clad {
NodeContext DerivativeBuilder::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DeclRefExpr* clonedDRE = VisitStmt(DRE).getAs<DeclRefExpr>();
SourceLocation noLoc;
QualType Ty = DRE->getType();
if (clonedDRE->getDecl()->getNameAsString() ==
m_IndependentVar->getNameAsString()) {
llvm::APInt one(m_Context.getIntWidth(m_Context.IntTy), /*value*/1);
IntegerLiteral* constant1 = IntegerLiteral::Create(m_Context, one,
m_Context.IntTy,
noLoc);
return NodeContext(constant1);
if (Ty->isIntegralType(m_Context)) {
llvm::APInt one(m_Context.getIntWidth(m_Context.IntTy), /*value*/1);
IntegerLiteral* constant1 = IntegerLiteral::Create(m_Context, one,
m_Context.IntTy,
noLoc);
return NodeContext(constant1);
}
else {
llvm::APFloat one(m_Context.getFloatTypeSemantics(Ty), 1);
FloatingLiteral* constant1 = FloatingLiteral::Create(m_Context, one,
/*isexact*/true,
Ty, noLoc);
return NodeContext(constant1);
}
}
else {
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/0);
IntegerLiteral* constant0 = IntegerLiteral::Create(m_Context, zero,
m_Context.IntTy,
noLoc);
return NodeContext(constant0);
if (Ty->isIntegralType(m_Context)) {
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/0);
IntegerLiteral* constant0 = IntegerLiteral::Create(m_Context, zero,
m_Context.IntTy,
noLoc);
return NodeContext(constant0);
}
else {
llvm::APFloat zero
= llvm::APFloat::getZero(m_Context.getFloatTypeSemantics(Ty));
FloatingLiteral* constant0 = FloatingLiteral::Create(m_Context, zero,
/*isexact*/true,
Ty, noLoc);
return NodeContext(constant0);
}
}
}

NodeContext DerivativeBuilder::VisitIntegerLiteral(const IntegerLiteral* IL) {
SourceLocation noLoc;
llvm::APInt zero(m_Context.getIntWidth(m_Context.IntTy), /*value*/0);
Expand Down
12 changes: 12 additions & 0 deletions test/FirstDerivative/BasicArithmeticAddSub.C
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: %cladclang %s -I%S/../../include -oBasicArithmeticAddSub.out -Xclang -verify 2>&1 | FileCheck %s
// RUN: ./BasicArithmeticAddSub.out | FileCheck -check-prefix=CHECK-EXEC %s

//CHECK-NOT: {{.*error:.*}}
#include "clad/Differentiator/Differentiator.h"

extern "C" int printf(const char* fmt, ...);
Expand Down Expand Up @@ -78,6 +79,13 @@ int as_1(int x) {
// CHECK-NEXT: return 1 + (1) - (1) + (0) - (0) + (0) - (0);
// CHECK-NEXT: }

float IntegerLiteralToFloatLiteral(float x, float y) {
return x * x - y;
}
// CHECK: float IntegerLiteralToFloatLiteral_derived_x(float x, float y) {
// CHECK-NEXT: return (1.F * x + x * 1.F) - (0.F);
// CHECK-NEXT: }

int a_1_derived_x(int x);
int a_2_derived_x(int x);
int a_3_derived_x(int x);
Expand All @@ -87,6 +95,7 @@ int s_2_derived_x(int x);
int s_3_derived_x(int x);
int s_4_derived_x(int x);
int as_1_derived_x(int x);
float IntegerLiteralToFloatLiteral_derived_x(float x, float y);

int main () { // expected-no-diagnostics
int x = 4;
Expand Down Expand Up @@ -117,5 +126,8 @@ int main () { // expected-no-diagnostics
clad::differentiate(as_1, 1);
printf("Result is = %d\n", as_1_derived_x(1)); // CHECK-EXEC: Result is = 1

clad::differentiate(IntegerLiteralToFloatLiteral, 1);
printf("Result is = %f\n", IntegerLiteralToFloatLiteral_derived_x(5., 0.)); // CHECK-EXEC: Result is = 10

return 0;
}
8 changes: 4 additions & 4 deletions test/FirstDerivative/DiffInterface.C
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ int f_1(float y) {

// CHECK: int f_1_derived_y(float y) {
// CHECK-NEXT: int x = 1, z = 3;
// CHECK-NEXT: return ((0 * y + x * 1) * z + x * y * 0);
// CHECK-NEXT: return ((0 * y + x * 1.F) * z + x * y * 0);
// CHECK-NEXT: }

int f_2(int x, float y, int z) {
return x * y * z; // y * z;
}

// CHECK: int f_2_derived_x(int x, float y, int z) {
// CHECK-NEXT: return ((1 * y + x * 0) * z + x * y * 0);
// CHECK-NEXT: return ((1 * y + x * 0.F) * z + x * y * 0);
// CHECK-NEXT: }

// x * z
// CHECK: int f_2_derived_y(int x, float y, int z) {
// CHECK-NEXT: return ((0 * y + x * 1) * z + x * y * 0);
// CHECK-NEXT: return ((0 * y + x * 1.F) * z + x * y * 0);
// CHECK-NEXT: }

// x * y
// CHECK: int f_2_derived_z(int x, float y, int z) {
// CHECK-NEXT: return ((0 * y + x * 0) * z + x * y * 1);
// CHECK-NEXT: return ((0 * y + x * 0.F) * z + x * y * 1);
// CHECK-NEXT: }

int f_3() {
Expand Down
12 changes: 6 additions & 6 deletions test/FirstDerivative/TemplateFunction.C
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ int main () {

clad::differentiate(simple_return<float>, 1);
// CHECK: float simple_return_derived_x(float x) {
// CHECK-NEXT: return 1;
// CHECK-NEXT: return 1.F;
// CHECK-NEXT: }

clad::differentiate(simple_return<double>, 1);
// CHECK: double simple_return_derived_x(double x) {
// CHECK-NEXT: return 1;
// CHECK-NEXT: return 1.;
// CHECK-NEXT: }

clad::differentiate(addition<int>, 1);
Expand All @@ -42,12 +42,12 @@ int main () {

clad::differentiate(addition<float>, 1);
// CHECK: float addition_derived_x(float x) {
// CHECK-NEXT: return 1 + (1);
// CHECK-NEXT: return 1.F + (1.F);
// CHECK-NEXT: }

clad::differentiate(addition<double>, 1);
// CHECK: double addition_derived_x(double x) {
// CHECK-NEXT: return 1 + (1);
// CHECK-NEXT: return 1. + (1.);
// CHECK-NEXT: }

clad::differentiate(multiplication<int>, 1);
Expand All @@ -57,12 +57,12 @@ int main () {

clad::differentiate(multiplication<float>, 1);
// CHECK: float multiplication_derived_x(float x) {
// CHECK-NEXT: return (1 * x + x * 1);
// CHECK-NEXT: return (1.F * x + x * 1.F);
// CHECK-NEXT: }

clad::differentiate(multiplication<double>, 1);
// CHECK: double multiplication_derived_x(double x) {
// CHECK-NEXT: return (1 * x + x * 1);
// CHECK-NEXT: return (1. * x + x * 1.);
// CHECK-NEXT: }

return 0;
Expand Down
6 changes: 3 additions & 3 deletions test/Misc/RunDemos.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
// RUN: %cladclang %S/../../demos/Gradient.cpp -I%S/../../include -oGradient.out 2>&1 | FileCheck -check-prefix CHECK_GRADIENT %s
// CHECK_GRADIENT-NOT:{{.*error|warning|note:.*}}
// CHECK_GRADIENT:float sphere_implicit_func_derived_x(float x, float y, float z, float xc, float yc, float zc, float r) {
// CHECK_GRADIENT: return ((1 - (0)) * (x - xc) + (x - xc) * (1 - (0))) + (((0 - (0)) * (y - yc) + (y - yc) * (0 - (0)))) + (((0 - (0)) * (z - zc) + (z - zc) * (0 - (0)))) - ((0 * r + r * 0));
// CHECK_GRADIENT: return ((1.F - (0.F)) * (x - xc) + (x - xc) * (1.F - (0.F))) + (((0.F - (0.F)) * (y - yc) + (y - yc) * (0.F - (0.F)))) + (((0.F - (0.F)) * (z - zc) + (z - zc) * (0.F - (0.F)))) - ((0.F * r + r * 0.F));
// CHECK_GRADIENT:}

// CHECK_GRADIENT:float sphere_implicit_func_derived_y(float x, float y, float z, float xc, float yc, float zc, float r) {
// CHECK_GRADIENT: return ((0 - (0)) * (x - xc) + (x - xc) * (0 - (0))) + (((1 - (0)) * (y - yc) + (y - yc) * (1 - (0)))) + (((0 - (0)) * (z - zc) + (z - zc) * (0 - (0)))) - ((0 * r + r * 0));
// CHECK_GRADIENT: return ((0.F - (0.F)) * (x - xc) + (x - xc) * (0.F - (0.F))) + (((1.F - (0.F)) * (y - yc) + (y - yc) * (1.F - (0.F)))) + (((0.F - (0.F)) * (z - zc) + (z - zc) * (0.F - (0.F)))) - ((0.F * r + r * 0.F));
// CHECK_GRADIENT:}

// CHECK_GRADIENT:float sphere_implicit_func_derived_z(float x, float y, float z, float xc, float yc, float zc, float r) {
// CHECK_GRADIENT: return ((0 - (0)) * (x - xc) + (x - xc) * (0 - (0))) + (((0 - (0)) * (y - yc) + (y - yc) * (0 - (0)))) + (((1 - (0)) * (z - zc) + (z - zc) * (1 - (0)))) - ((0 * r + r * 0));
// CHECK_GRADIENT: return ((0.F - (0.F)) * (x - xc) + (x - xc) * (0.F - (0.F))) + (((0.F - (0.F)) * (y - yc) + (y - yc) * (0.F - (0.F)))) + (((1.F - (0.F)) * (z - zc) + (z - zc) * (1.F - (0.F)))) - ((0.F * r + r * 0.F));
// CHECK_GRADIENT:}

// RUN: ./Gradient.out | FileCheck -check-prefix CHECK_GRADIENT_EXEC %s
Expand Down

0 comments on commit f915dd3

Please sign in to comment.