Skip to content

Commit

Permalink
Added the Stan test and removed unnecciary comments
Browse files Browse the repository at this point in the history
  • Loading branch information
FaroukY committed May 4, 2018
1 parent 74027a9 commit f60500b
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tests/unit/StanMath/StanPerceptronTest_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//tests/unit/stanmath/StanPerceptronTest_unittest.cc
#include <gtest/gtest.h>
#include <stan/math.hpp>
#include <iostream>
#include <cmath>
#include <Eigen/Dense>
#include <random>
using namespace std;
using namespace Eigen;
using namespace stan::math;


/*
This test will implement a perceptron using stan math
The perceptron will have one input, the column vector [1 1 1]^T
Then, it will have 2x3 Matrix of weights that it will learn
Then, the output is a 2x1 column vector
In this example, we want to learn the weights W such that the square
Error loss from the output of the perceptron to [1 1]^T is minimized.
Since we can find weights from [1 1 1]^T to [1 1]^T in a perceptron,
this error should be very close to zero after 100 epochs.
*/
TEST(StanPerceptronTest, sample_perceptron)
{
//Initialize the Input Vector
Matrix<var, 3, 1> inp;
inp(0,0)=1;
inp(1,0)=1;
inp(2,0)=1;

//Randomly Initialize the weights on the perceptron
std::random_device rd{};
std::mt19937 gen{rd()};
normal_distribution<> d{0,1};
Matrix<var, 2, 3> W1;
for(int i=0; i<2; ++i)
{
for(int j=0; j<3; ++j)
{
W1(i,j)=0.01 * d(gen);
}
}

//Define the outputs of the neural network
Matrix<var, 2, 1> outputs;

double learning_rate = 0.1;
double last_error = 0;
for(int epoch=0; epoch<100; ++epoch)
{
var error=0;
outputs = W1*inp;
for(int i=0; i<2; ++i)
{
error += (outputs(i,0)-1)*(outputs(i,0)-1);
}
error.grad();

//Now use gradient descent to change the weights
for(int i=0; i<2; ++i)
{
for(int j=0; j<3; ++j)
{
W1(i,j)=W1(i,j) - learning_rate* W1(i,j).adj();
}
}

//Store the value of current error in last_error
last_error = value_of(error);
}

//Error should be very close to 0.0
EXPECT_NEAR(last_error, 0.0, 1e-6);
}

0 comments on commit f60500b

Please sign in to comment.