[core] refactor and code cleanups
terhorst committed Mar 31, 2018
1 parent f553b46 commit 39f8164
Showing 6 changed files with 40 additions and 33 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -38,6 +38,7 @@
"smcpp._estimation_tools",
sources=["smcpp/_estimation_tools.pyx"],
include_dirs=[np.get_include()],
libraries=libraries
)
]
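The one-line change above links smcpp._estimation_tools against the same `libraries` list as the other extensions, which is needed now that _estimation_tools.pyx calls GSL's gsl_sf_lnbeta (see the next file). A hedged sketch of the resulting Extension block; the actual contents of `libraries` are defined elsewhere in setup.py and are not shown in this diff, so the "gsl"/"gslcblas" values below are assumptions:

```python
# Hypothetical reconstruction -- `libraries` is defined elsewhere in
# setup.py; "gsl"/"gslcblas" are assumed values for illustration only.
from setuptools import Extension
import numpy as np

libraries = ["gsl", "gslcblas"]  # assumption: GSL link libraries

ext = Extension(
    "smcpp._estimation_tools",
    sources=["smcpp/_estimation_tools.pyx"],
    include_dirs=[np.get_include()],
    libraries=libraries,  # the line this commit adds
)
```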

21 changes: 21 additions & 0 deletions smcpp/_estimation_tools.pyx
@@ -1,6 +1,9 @@
 from cython.parallel import prange
 from libc.math cimport exp, log
 import numpy as np
 
+cdef extern from "<gsl/gsl_sf_gamma.h>":
+    double gsl_sf_lnbeta(double, double) nogil
+
 def thin_data(data, int thinning, int offset=0):
     '''
@@ -250,3 +253,21 @@ def windowed_mutation_counts(contig, int w):
             last[k] = cd[i, k]
         ret[j] = [nmiss, mut]
     return ret.T
+
+
+def beta_de_avg_pdf(double[:] X, double[:] y, double h):
+    ret = np.zeros(y.shape[0])
+    cdef double[:] vret = ret
+    cdef int i, j
+    cdef double a, b, ln_B
+    for j in prange(y.shape[0], nogil=True):
+        a = 1. + y[j] / h
+        b = 1. + (1. - y[j]) / h
+        ln_B = gsl_sf_lnbeta(a, b)
+        for i in range(X.shape[0]):
+            if (a == 1 and X[i] == 0.) or (b == 1 and X[i] == 1.):
+                vret[j] += exp(-ln_B)
+            if 0. < X[i] < 1.:
+                vret[j] += exp((a - 1.) * log(X[i]) + (b - 1.) * log(1. - X[i]) - ln_B)
+    ret /= len(X)
+    return ret
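The new beta_de_avg_pdf is a beta-kernel density estimate: for each evaluation point y[j] it averages the Beta(1 + y[j]/h, 1 + (1 - y[j])/h) density over the observations X, with the boundary cases X[i] = 0 (where a = 1) and X[i] = 1 (where b = 1) handled explicitly so the density stays finite. A minimal NumPy/SciPy reference sketch of the same quantity, mirroring the commented-out version in smcpp/beta_de.py (scipy handles the boundary cases internally):

```python
import numpy as np
import scipy.stats

def beta_de_avg_pdf_reference(X, y, h):
    # Average the Beta(1 + y/h, 1 + (1 - y)/h) density over all
    # observations X, producing one value per evaluation point in y.
    a = 1. + y[:, None] / h
    b = 1. + (1. - y[:, None]) / h
    return scipy.stats.beta.pdf(X[None, :], a, b).mean(axis=1)
```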
3 changes: 0 additions & 3 deletions smcpp/_smcpp.pxd
@@ -5,9 +5,6 @@ from libcpp cimport bool
 from libcpp.memory cimport unique_ptr
 from libcpp.string cimport string
 
-cdef extern from "<gsl/gsl_sf_gamma.h>":
-    double gsl_sf_lnbeta(double, double) nogil
-
 cdef extern from "common.h":
     ctypedef vector[vector[adouble]] ParameterVector
     cdef cppclass Vector[T]:
17 changes: 0 additions & 17 deletions smcpp/_smcpp.pyx
@@ -2,7 +2,6 @@ cimport openmp
 cimport numpy as np
 from libc.math cimport exp, log
 from cython.operator cimport dereference as deref, preincrement as inc
-from cython.parallel import prange
 
 import random
 import sys
@@ -481,19 +480,3 @@ def realign(contig, int w):
     contig.data = ret
 
 
-def beta_de_avg_pdf(double[:] X, double[:] y, double h):
-    ret = np.zeros(y.shape[0])
-    cdef double[:] vret = ret
-    cdef int i, j
-    cdef double a, b, ln_B
-    for j in prange(y.shape[0], nogil=True):
-        a = 1. + y[j] / h
-        b = 1. + (1. - y[j]) / h
-        ln_B = gsl_sf_lnbeta(a, b)
-        for i in range(X.shape[0]):
-            if (a == 1 and X[i] == 0.) or (b == 1 and X[i] == 1.):
-                vret[j] += exp(-ln_B)
-            if 0. < X[i] < 1.:
-                vret[j] += exp((a - 1.) * log(X[i]) + (b - 1.) * log(1. - X[i]) - ln_B)
-    ret /= len(X)
-    return ret
9 changes: 5 additions & 4 deletions smcpp/beta_de.py
@@ -20,9 +20,8 @@ def harmonic_number(x):
 def quantile(X, h, q):
     # def g(y):
     #     return scipy.stats.beta.pdf(X[None, :], 1. + y / h, 1. + (1. - y) / h).mean(axis=1)
-    import smcpp._smcpp
     x = np.linspace(0, 1., 10000)[1:]
-    y = smcpp._smcpp.beta_de_avg_pdf(X, x, h)
+    y = estimation_tools.beta_de_avg_pdf(X, x, h) # this is implemented in cython
     x = np.r_[0, x]
     y = np.cumsum(np.r_[0, y])
     y /= y[-1]
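At this point y holds a normalized cumulative sum, i.e. an empirical CDF of the kernel density on the grid x. The remainder of quantile() is not shown in this diff; one plausible way it could finish, inverting the grid CDF by linear interpolation (an assumption for illustration, not necessarily the actual code):

```python
import numpy as np

def invert_grid_cdf(x, cdf, q):
    # cdf is nondecreasing on the grid x with cdf[0] == 0 and
    # cdf[-1] == 1, so interpolating x against cdf inverts it at q.
    return np.interp(q, cdf, x)
```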
@@ -54,10 +53,11 @@ def sample_beta_kernel(X, mu, h):

     def g(y):
         return scipy.stats.beta.logpdf(X, 1. + y / h, 1. + (1. - y) / h)
+
     def dg(y):
         '(Sign of) dg/dy'
-        return (harmonic_number((1 - y) / h) -
-                harmonic_number(y / h) -
+        return (harmonic_number((1 - y) / h) -
+                harmonic_number(y / h) -
                 np.log(1 / X - 1))
 
     # Slice sample from unimodal density g.
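Why dg has the stated sign: writing $a = 1 + y/h$ and $b = 1 + (1 - y)/h$, differentiating the Beta log-density in $y$ gives

$$\frac{\partial}{\partial y}\log f(X; a, b) = \frac{1}{h}\left[\psi\!\left(1 + \frac{1 - y}{h}\right) - \psi\!\left(1 + \frac{y}{h}\right) - \log\!\left(\frac{1}{X} - 1\right)\right].$$

Since harmonic_number(x) = ψ(x + 1) + γ, the bracket is exactly harmonic_number((1 - y)/h) - harmonic_number(y/h) - np.log(1/X - 1), which is what dg returns; the positive factor 1/h only rescales it, hence the docstring '(Sign of) dg/dy'.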
@@ -71,6 +71,7 @@ def sl(z):
     else:
         assert dg(0) > 0 and dg(1) < 0
         mid = scipy.optimize.brentq(dg, 0, 1)
+
     def sl(z):
         l, _ = positive_part(lambda y: g(y) - z, 0, mid)
         l1, _ = positive_part(lambda y: g(1 - y) - z, 0, 1 - mid)
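For context on sl: a slice sampler draws a level z below g at the current point and then needs the slice {y : g(y) > z}, which for unimodal g is an interval around the mode mid, found by locating where g crosses z on each side. The positive_part helper is not shown in this diff; the sketch below is a hypothetical stand-in using plain brentq root-finding, under the assumptions that g is unimodal on [0, 1] and g(mid) > z:

```python
import scipy.optimize

def slice_interval(g, z, mid, lo=0., hi=1.):
    # {y : g(y) > z} is an interval around the mode for unimodal g;
    # find each endpoint where g crosses the level z, if it does.
    left = lo if g(lo) >= z else scipy.optimize.brentq(
        lambda y: g(y) - z, lo, mid)
    right = hi if g(hi) >= z else scipy.optimize.brentq(
        lambda y: g(y) - z, mid, hi)
    return left, right
```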
22 changes: 13 additions & 9 deletions smcpp/estimation_tools.py
@@ -11,7 +11,7 @@

 from . import util, logging, model, defaults
 from .contig import Contig
-from ._estimation_tools import realign, thin_data, bin_observations, windowed_mutation_counts
+from ._estimation_tools import realign, thin_data, bin_observations, windowed_mutation_counts, beta_de_avg_pdf
 
 
 logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ def compress_repeated_obs(dataset):


 def decompress_polymorphic_spans(dataset):
-    miss = (np.all(dataset[:, 1::3] == -1, axis=1) &
+    miss = (np.all(dataset[:, 1::3] == -1, axis=1) &
             np.all(dataset[:, 3::3] == 0, axis=1))
-    nonseg = (np.all(dataset[:, 1::3] == 0, axis=1) &
+    nonseg = (np.all(dataset[:, 1::3] == 0, axis=1) &
               (np.all(dataset[:, 2::3] == dataset[:, 3::3], axis=1) |
                np.all(dataset[:, 2::3] == 0, axis=1)))
     psp = np.where((dataset[:, 0] > 1) & (~nonseg) & (~miss))[0]
@@ -69,10 +69,12 @@ def decompress_polymorphic_spans(dataset):
     for i in psp:
         row = dataset[i]
         if first:
-            ret = np.r_[dataset[last:i], np.tile(np.r_[1, row[1:]], (row[0], 1))]
+            ret = np.r_[dataset[last:i], np.tile(
+                np.r_[1, row[1:]], (row[0], 1))]
             first = False
         else:
-            ret = np.r_[ret, dataset[last:i], np.tile(np.r_[1, row[1:]], (row[0], 1))]
+            ret = np.r_[ret, dataset[last:i], np.tile(
+                np.r_[1, row[1:]], (row[0], 1))]
         last = i + 1
     ret = np.r_[ret, dataset[last:]]
     return ret
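The np.tile expression reflowed above is the core of the decompression: a polymorphic row whose span column row[0] is greater than 1 becomes row[0] identical rows of span 1. A toy illustration with hypothetical data:

```python
import numpy as np

row = np.array([3, 1, 0, 2])  # span 3, then the genotype columns
expanded = np.tile(np.r_[1, row[1:]], (row[0], 1))
# expanded == [[1, 1, 0, 2],
#              [1, 1, 0, 2],
#              [1, 1, 0, 2]]
```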
@@ -98,7 +100,8 @@ def recode_nonseg(contig, cutoff):
         txt = " (converted to missing)"
     d[runs, 1::3] = -1
     d[runs, 3::3] = 0
-    f("Long runs of homozygosity%s in contig %s: \n%s (base pairs)", txt, contig.fn, d[runs, 0])
+    f("Long runs of homozygosity%s in contig %s: \n%s (base pairs)",
+      txt, contig.fn, d[runs, 0])
     return contig


@@ -116,7 +119,8 @@ def break_long_spans(contig, span_cutoff):
                            np.all(obs[:, 3::3] == 0, axis=1))[0]
     cob = 0
     if obs[long_spans].size:
-        logger.debug("Long missing spans:\n%s (base pairs)", (obs[long_spans, 0]))
+        logger.debug("Long missing spans:\n%s (base pairs)",
+                     (obs[long_spans, 0]))
     positions = np.insert(np.cumsum(obs[:, 0]), 0, 0)
     for x in long_spans.tolist() + [None]:
         s = obs[cob:x, 0].sum()
@@ -195,12 +199,12 @@ def calculate_t1(model, n, q):
     import smcpp._smcpp
     eta = smcpp._smcpp.PyRateFunction(model, [0., np.inf])
     c = n * (n - 1) / 2
+
     def f(t):
         return np.expm1(-c * eta.R(t)) + q
     return scipy.optimize.brentq(f, 0., model.knots[-1])
 
-
 
 def _load_data_helper(fn):
     try:
         # This parser is way faster than np.loadtxt
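The root-finding step in calculate_t1 above can be read as inverting the distribution of the first coalescence time: among n lineages there are $c = \binom{n}{2}$ pairs, and with cumulative coalescent rate $R(t)$,

$$\Pr(T_1 \le t) = 1 - e^{-c\,R(t)},$$

so brentq solves $1 - e^{-c R(t_1)} = q$ for $t_1$ on $[0, \text{model.knots}[-1]]$; the code writes this as $\operatorname{expm1}(-c\,R(t_1)) + q = 0$, which avoids cancellation when $c\,R(t)$ is small.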
@@ -219,7 +223,7 @@ def _load_data_helper(fn):
         attrs = json.loads(first_line[7:])
         a = [len(a) for a in attrs['dist']]
         n = [len(u) for u in attrs['undist']]
-        if "pids" not in attrs:
+        if "pids" not in attrs:
             # FIXME this code really only exists to analyze old data sets (before the fmt changed)
             # it should probably be removed
             attrs["pids"] = ["pop%d" % i for i, _ in enumerate(a, 1)]