Skip to content

Commit

Permalink
vectorize derivative computation (sumtable_ti and core_site_likelihoo…
Browse files Browse the repository at this point in the history
…d_derivatives)
  • Loading branch information
amkozlov committed Jul 25, 2016
1 parent c0f1cb0 commit 3035dd1
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 16 deletions.
92 changes: 76 additions & 16 deletions src/core_likelihood.c
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,11 @@ PLL_EXPORT int pll_core_update_sumtable_ti(unsigned int states,

unsigned int states_padded = states;

if (states == 4)
#ifdef HAVE_AVX
if (attrib & PLL_ATTRIB_ARCH_AVX)
{
return pll_core_update_sumtable_ti_4x4(sites,
return pll_core_update_sumtable_ti_avx(states,
sites,
rate_cats,
parent_clv,
left_tipchars,
Expand All @@ -740,16 +742,26 @@ PLL_EXPORT int pll_core_update_sumtable_ti(unsigned int states,
attrib);
}

#ifdef HAVE_AVX
if (attrib & PLL_ATTRIB_ARCH_AVX)
{
states_padded = (states+3) & 0xFFFFFFFC;
}
#endif
#ifdef HAVE_SSE
#endif

/* build sumtable */
/* non-vectorized version, special case for 4 states */
if (states == 4)
{
return pll_core_update_sumtable_ti_4x4(sites,
rate_cats,
parent_clv,
left_tipchars,
eigenvecs,
inv_eigenvecs,
freqs,
tipmap,
sumtable,
attrib);
}

/* build sumtable: non-vectorized version, general case */
for (n = 0; n < sites; n++)
{
for (i = 0; i < rate_cats; ++i)
Expand Down Expand Up @@ -810,7 +822,7 @@ static void core_site_likelihood_derivatives(unsigned int states,
cat_sitelk[0] += sum[j] * diagp[0];
cat_sitelk[1] += sum[j] * diagp[1];
cat_sitelk[2] += sum[j] * diagp[2];
diagp += 3;
diagp += 4;
}

/* account for invariant sites */
Expand All @@ -819,6 +831,7 @@ static void core_site_likelihood_derivatives(unsigned int states,
{
inv_site_lk =
(invariant[0] == -1) ? 0 : t_freqs[invariant[0]] * t_prop_invar;

cat_sitelk[0] = cat_sitelk[0] * (1. - t_prop_invar) + inv_site_lk;
cat_sitelk[1] = cat_sitelk[1] * (1. - t_prop_invar);
cat_sitelk[2] = cat_sitelk[2] * (1. - t_prop_invar);
Expand Down Expand Up @@ -852,7 +865,6 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
{
unsigned int n, i, j;
unsigned int ef_sites;
double site_lk[3];

const double * sum;
double deriv1, deriv2;
Expand All @@ -869,12 +881,25 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
unsigned int pattern_weight_sum = 0;

#ifdef HAVE_AVX
double site_lk[4] __attribute__( ( aligned ( PLL_ALIGNMENT_AVX ) ) ) ;
double * invar_lk = (double *) pll_aligned_alloc (rate_cats * states * sizeof(double), PLL_ALIGNMENT_AVX);
if (attrib & PLL_ATTRIB_ARCH_AVX)
{
states_padded = (states+3) & 0xFFFFFFFC;

/* pre-compute invariant site likelihoods*/
for(i = 0; i < states; ++i)
{
for(j = 0; j < rate_cats; ++j)
{
invar_lk[i * rate_cats + j] = freqs[j][i] * prop_invar[j];
}
}
}
#endif
#ifdef HAVE_SSE
#elif defined(HAVE_SSE)
double site_lk[4] __attribute__( ( aligned ( PLL_ALIGNMENT_SSE3 ) ) ) ;
#else
double site_lk[3];
#endif

/* For Stamatakis correction, the likelihood derivatives are computed in
Expand All @@ -891,7 +916,7 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
*d_f = 0.0;
*dd_f = 0.0;

diagptable = (double *) calloc (rate_cats * states * 3, sizeof(double));
diagptable = (double *) pll_aligned_alloc (rate_cats * states * 4 * sizeof(double), PLL_ALIGNMENT_AVX);
if (!diagptable)
{
pll_errno = PLL_ERROR_MEM_ALLOC;
Expand All @@ -911,15 +936,42 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
diagp[0] = exp(t_eigenvals[j] * ki * t_branch_length);
diagp[1] = t_eigenvals[j] * ki * diagp[0];
diagp[2] = t_eigenvals[j] * ki * t_eigenvals[j] * ki * diagp[0];
diagp += 3;
diagp[3] = 0;
diagp += 4;
}
}

sum = sumtable;
invariant_ptr = invariant;
for (n = 0; n < ef_sites; ++n)
{
core_site_likelihood_derivatives(states,
#ifdef HAVE_AVX
if (attrib & PLL_ATTRIB_ARCH_AVX)
{
const double * site_invar_lk = (!invariant || *invariant_ptr == -1) ? NULL : &invar_lk[(*invariant_ptr) * rate_cats];

if (states == 4)
core_site_likelihood_derivatives_4x4_avx(rate_cats,
rate_weights,
prop_invar,
site_invar_lk,
sum,
diagptable,
site_lk);
else
core_site_likelihood_derivatives_avx(states,
states_padded,
rate_cats,
rate_weights,
prop_invar,
site_invar_lk,
sum,
diagptable,
site_lk);
}
else
#endif
core_site_likelihood_derivatives(states,
states_padded,
rate_cats,
rate_weights,
Expand All @@ -929,6 +981,9 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
sum,
diagptable,
site_lk);

// printf("SITELK: %f %f %f\n", site_lk[0], site_lk[1], site_lk[2]);

invariant_ptr++;
sum += rate_cats * states_padded;

Expand Down Expand Up @@ -1008,6 +1063,11 @@ PLL_EXPORT int pll_core_likelihood_derivatives(unsigned int states,
}
}

free (diagptable);
pll_aligned_free (diagptable);

#ifdef HAVE_AVX
pll_aligned_free (invar_lk);
#endif

return PLL_SUCCESS;
}
Loading

0 comments on commit 3035dd1

Please sign in to comment.