Commit

Merge pull request #3661 from abhinavrai44/inference
Port Inference.cpp to use OpenMP
vigsterkr committed Mar 14, 2017
2 parents 649142d + 2e8330c commit 35d9c9d
Showing 2 changed files with 30 additions and 108 deletions.
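The diff below drops the hand-rolled pthread machinery (the GRADIENT_THREAD_PARAM struct, an explicit CLock, and the static get_derivative_helper() thread entry point) and instead parallelises the per-parameter gradient loop with OpenMP: the loop gets a #pragma omp parallel for, each iteration computes its gradient independently, and only the insertion into the shared result map is serialised with #pragma omp critical. A minimal standalone sketch of that pattern, with hypothetical names rather than the Shogun API:

// sketch.cpp: illustrative only, not Shogun code. Each loop iteration does
// independent work; only the update of the shared map is serialised.
// Build with OpenMP enabled (e.g. g++ -fopenmp sketch.cpp); without OpenMP
// the pragmas are ignored and the loop simply runs serially.
#include <cmath>
#include <cstdio>
#include <map>

int main()
{
    const int num_deriv = 8;        // number of parameters to differentiate
    std::map<int, double> result;   // shared container, like the result CMap

    #pragma omp parallel for
    for (int i = 0; i < num_deriv; i++)
    {
        // thread-local work, standing in for get_derivative_wrt_*()
        double gradient = std::sin(static_cast<double>(i));

        // mutual exclusion only around the shared-map update
        #pragma omp critical
        {
            result[i] = gradient;
        }
    }

    for (const auto& kv : result)
        std::printf("param %d: %f\n", kv.first, kv.second);
    return 0;
}

Two consequences of this pattern are visible in the diff: the #ifdef HAVE_PTHREAD branching and the single-threaded fallback for num_deriv<2 disappear, since a disabled or unrecognised OpenMP pragma already degrades to a serial loop, and the CLock plus helper-function indirection is replaced by a critical section placed directly where the shared map is touched.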
133 changes: 30 additions & 103 deletions src/shogun/machine/gp/Inference.cpp
@@ -37,21 +37,9 @@
 #include <shogun/distributions/classical/GaussianDistribution.h>
 #include <shogun/mathematics/Statistics.h>
 #include <shogun/mathematics/Math.h>
-#include <shogun/lib/Lock.h>
 
 using namespace shogun;
 
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-struct GRADIENT_THREAD_PARAM
-{
-    CInference* inf;
-    CMap<TParameter*, SGVector<float64_t> >* grad;
-    CSGObject* obj;
-    TParameter* param;
-    CLock* lock;
-};
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
 CInference::CInference()
 {
     init();
@@ -211,108 +199,47 @@ get_negative_log_marginal_likelihood_derivatives(CMap<TParameter*, CSGObject*>*
 
     SG_REF(result);
 
-    // create lock object
-    CLock lock;
-
-#ifdef HAVE_PTHREAD
-    if (num_deriv<2)
-    {
-#endif /* HAVE_PTHREAD */
-        for (index_t i=0; i<num_deriv; i++)
-        {
-            CMapNode<TParameter*, CSGObject*>* node=params->get_node_ptr(i);
-
-            GRADIENT_THREAD_PARAM thread_params;
-
-            thread_params.inf=this;
-            thread_params.obj=node->data;
-            thread_params.param=node->key;
-            thread_params.grad=result;
-            thread_params.lock=&lock;
-
-            get_derivative_helper((void*) &thread_params);
-        }
-#ifdef HAVE_PTHREAD
-    }
-    else
-    {
-        pthread_t* threads=SG_MALLOC(pthread_t, num_deriv);
-        GRADIENT_THREAD_PARAM* thread_params=SG_MALLOC(GRADIENT_THREAD_PARAM,
-            num_deriv);
-
-        for (index_t t=0; t<num_deriv; t++)
-        {
-            CMapNode<TParameter*, CSGObject*>* node=params->get_node_ptr(t);
-
-            thread_params[t].inf=this;
-            thread_params[t].obj=node->data;
-            thread_params[t].param=node->key;
-            thread_params[t].grad=result;
-            thread_params[t].lock=&lock;
-
-            pthread_create(&threads[t], NULL, CInference::get_derivative_helper,
-                (void*)&thread_params[t]);
-        }
-
-        for (index_t t=0; t<num_deriv; t++)
-            pthread_join(threads[t], NULL);
-
-        SG_FREE(thread_params);
-        SG_FREE(threads);
-    }
-#endif /* HAVE_PTHREAD */
+    #pragma omp parallel for
+    for (index_t i=0; i<num_deriv; i++)
+    {
+        CMapNode<TParameter*, CSGObject*>* node=params->get_node_ptr(i);
+        SGVector<float64_t> gradient;
+
+        if (node->data == this)
+        {
+            // try to find derivative wrt InferenceMethod.parameter
+            gradient=this->get_derivative_wrt_inference_method(node->key);
+        }
+        else if (node->data == this->m_model)
+        {
+            // try to find derivative wrt LikelihoodModel.parameter
+            gradient=this->get_derivative_wrt_likelihood_model(node->key);
+        }
+        else if (node->data == this->m_kernel)
+        {
+            // try to find derivative wrt Kernel.parameter
+            gradient=this->get_derivative_wrt_kernel(node->key);
+        }
+        else if (node->data == this->m_mean)
+        {
+            // try to find derivative wrt MeanFunction.parameter
+            gradient=this->get_derivative_wrt_mean(node->key);
+        }
+        else
+        {
+            SG_SERROR("Can't compute derivative of negative log marginal "
+                "likelihood wrt %s.%s", node->data->get_name(), node->key->m_name);
+        }
+
+        #pragma omp critical
+        {
+            result->add(node->key, gradient);
+        }
+    }
 
     return result;
 }
 
-void* CInference::get_derivative_helper(void *p)
-{
-    GRADIENT_THREAD_PARAM* thread_param=(GRADIENT_THREAD_PARAM*)p;
-
-    CInference* inf=thread_param->inf;
-    CSGObject* obj=thread_param->obj;
-    CMap<TParameter*, SGVector<float64_t> >* grad=thread_param->grad;
-    TParameter* param=thread_param->param;
-    CLock* lock=thread_param->lock;
-
-    REQUIRE(param, "Parameter should not be NULL\n");
-    REQUIRE(obj, "Object of the parameter should not be NULL\n");
-
-    SGVector<float64_t> gradient;
-
-    if (obj==inf)
-    {
-        // try to find derivative wrt InferenceMethod.parameter
-        gradient=inf->get_derivative_wrt_inference_method(param);
-    }
-    else if (obj==inf->m_model)
-    {
-        // try to find derivative wrt LikelihoodModel.parameter
-        gradient=inf->get_derivative_wrt_likelihood_model(param);
-    }
-    else if (obj==inf->m_kernel)
-    {
-        // try to find derivative wrt Kernel.parameter
-        gradient=inf->get_derivative_wrt_kernel(param);
-    }
-    else if (obj==inf->m_mean)
-    {
-        // try to find derivative wrt MeanFunction.parameter
-        gradient=inf->get_derivative_wrt_mean(param);
-    }
-    else
-    {
-        SG_SERROR("Can't compute derivative of negative log marginal "
-            "likelihood wrt %s.%s", obj->get_name(), param->m_name);
-    }
-
-    lock->lock();
-    grad->add(param, gradient);
-    lock->unlock();
-
-    return NULL;
-}
-
 void CInference::update()
 {
     check_members();
5 changes: 0 additions & 5 deletions src/shogun/machine/gp/Inference.h
@@ -450,11 +450,6 @@ class CInference : public CDifferentiableFunction
     virtual SGVector<float64_t> get_derivative_wrt_mean(
             const TParameter* param)=0;
 
-    /** pthread helper method to compute negative log marginal likelihood
-     * derivatives wrt hyperparameter
-     */
-    static void* get_derivative_helper(void* p);
-
     /** update gradients */
     virtual void compute_gradient();
 
