Skip to content

Commit

Permalink
use eigen instead of stanmath
Browse files Browse the repository at this point in the history
compile time differentiation is much faster
  • Loading branch information
gf712 committed Nov 5, 2019
1 parent 5815a3d commit aa229fa
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/shogun/kernel/GaussianKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
* Tonmoy Saikia, Sergey Lisitsyn, Matt Aasted, Sanuj Sharma
*/

#include <Eigen/Core>
#include <unsupported/Eigen/AutoDiff>
#include <shogun/lib/common.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/features/DotFeatures.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/mathematics/Math.h>
#include <stan/math/rev/scal.hpp>


using namespace shogun;

Expand Down Expand Up @@ -92,26 +94,25 @@ void CGaussianKernel::set_width(float64_t w)
SGMatrix<float64_t> CGaussianKernel::get_parameter_gradient(const TParameter* param, index_t index)
{
using std::exp;
using std::log;

require(lhs, "Left hand side features must be set!");
require(rhs, "Right hand side features must be set!");

if (!strcmp(param->m_name, "log_width"))
{
SGMatrix<float64_t> derivative=SGMatrix<float64_t>(num_lhs, num_rhs);
stan::math::var log_width = m_log_width;
auto constant_part = exp(log_width * 2.0) * 2.0;
using EigenScalar = Eigen::Matrix<float64_t, 1, 1>;
Eigen::AutoDiffScalar<EigenScalar> eigen_log_width = m_log_width;

for (int k=0; k<num_rhs; k++)
{
#pragma omp parallel for
for (int j=0; j<num_lhs; j++)
{
auto f = exp(-CShiftInvariantKernel::distance(j, k) / constant_part);
f.grad();
derivative(j, k) = log_width.adj();
stan::math::set_zero_all_adjoints();
eigen_log_width.derivatives() = EigenScalar::Unit(1,0);
auto el = CShiftInvariantKernel::distance(j, k);
Eigen::AutoDiffScalar<EigenScalar> kernel = exp(-el / (exp(eigen_log_width * 2.0) * 2.0));
derivative(j, k) = kernel.derivatives()(0);
}
}
return derivative;
Expand Down

0 comments on commit aa229fa

Please sign in to comment.