In [8]:
%load_ext cython
import cython

In [60]:
%%cython -f
cimport numpy as np
import numpy as np
from libc.stdlib cimport malloc, free
from scipy.optimize import fmin_l_bfgs_b
ctypedef np.float64_t c_float_t
cdef extern from "/Users/christian/CLionProjects/felsenstein/felsenstein_logspace.c":
    pass
    
A = 20
cdef extern from "/Users/christian/CLionProjects/felsenstein/felsenstein.h":
    
    ctypedef struct NodePrecomputation:
        pass
    
    ctypedef struct Node:
        Node* left
        Node* right

        int seq_id
        c_float_t phi_left
        c_float_t phi_right
        
        NodePrecomputation* data
        
    ctypedef struct NodeBuffer:
        pass
    
    ctypedef struct Buffer:
        NodeBuffer* left
        NodeBuffer* right
    
    ctypedef struct Constants:
        int L;
        c_float_t* single_aa_frequencies
        Node* phylo_tree;
        np.uint8_t* msa;
        int i;
        int j;
    
    void initialize_constants(Constants* consts);
    cdef c_float_t calculate_fx_grad(c_float_t* x, c_float_t* grad, Constants* consts, Buffer* buf)
    cdef void initialize_buffer(NodeBuffer* buffer)


cdef class ExtraArguments:
    
    cdef single_aa_frequencies
    cdef Constants consts
    cdef Buffer buffer
    cdef c_float_t lam
    cdef int n_nodes
    cdef Node* tree_nodes
    
    def __cinit__(self, msa, i, j, lam, node_info):
        N, L = msa.shape
        self.lam = lam
        
        cdef int leaf_index = N - 1
        n_nodes = len(node_info)
        cdef Node* nodes = <Node*> malloc(sizeof(Node)*n_nodes)
        for i, connectivity in enumerate(node_info[::-1]):
            node_idx = n_nodes - 1 - i
            if connectivity is None:
                # this is a leaf
                nodes[node_idx].left = NULL
                nodes[node_idx].right = NULL
                nodes[node_idx].seq_id = leaf_index
                leaf_index -= 1
            else:
                (left_node, left_time), (right_node, right_time) = connectivity
                if left_node is not None:
                    nodes[node_idx].left = &nodes[left_node]
                    nodes[node_idx].phi_left = np.exp(-left_time)
                else:
                    nodes[node_idx].left = NULL
                if right_node is not None:
                    nodes[node_idx].right = &nodes[right_node]
                    nodes[node_idx].phi_right = np.exp(-right_time)
                else:
                    nodes[node_idx].right = NULL
        self.tree_nodes = nodes

        cdef Constants consts = Constants()
        consts.L = L
        consts.phylo_tree = &nodes[0]
        cdef np.uint8_t[:] my_msa = msa.ravel()
        consts.msa = &my_msa[0]
        consts.i = i
        consts.j = j
        aa_counts = np.bincount(msa.ravel(), minlength=20) + 1e-9
        aa_freqs = np.log(aa_counts / aa_counts.sum())
        self.single_aa_frequencies = aa_freqs
        cdef c_float_t* aa_freqs_c = <c_float_t*> malloc(sizeof(c_float_t)*A)
        for a in range(A):
            aa_freqs_c[a] = aa_freqs[a]
        consts.single_aa_frequencies = aa_freqs_c
        initialize_constants(&consts)
        self.consts = consts
        
        cdef Buffer buffer = Buffer()
        cdef NodeBuffer* buffer_left = <NodeBuffer*> malloc(sizeof(NodeBuffer))
        initialize_buffer(buffer_left)
        buffer.left = buffer_left
        cdef NodeBuffer* buffer_right = <NodeBuffer*> malloc(sizeof(NodeBuffer))
        initialize_buffer(buffer_right)
        buffer.right = buffer_right
        self.buffer = buffer
        
    def __dealloc__(self):
        
        for i in range(self.n_nodes):
            free(&self.tree_nodes[i])
        free(self.tree_nodes)
        free(self.buffer.left)
        free(self.buffer.right)
        free(self.consts.single_aa_frequencies)
    
    
def optimize_felsenstein(msa, i, j):
    A = 20
    N, L = msa.shape
    np.random.seed(42)
    x0 = np.random.rand(2*A + A*A)
    
    node_info = [((1,10), (2,15)), None, None]
    extra_args = ExtraArguments(msa, i, j, 10, node_info)
    x_opt, fx_opt, info = fmin_l_bfgs_b(felsenstein_fx_grad, x0, args=(extra_args,), factr=10, pgtol=1e-9)
    info['fx_opt'] = fx_opt
    return x_opt[:2*A], x_opt[2*A:], info


def felsenstein_fx_grad(double[:] x, ExtraArguments extra_args):
    cdef Constants* consts = &extra_args.consts
    cdef Buffer* buffer = &extra_args.buffer
    cdef c_float_t lam = extra_args.lam
    grad = np.empty(2*A + A*A)
    cdef c_float_t[:] grad_c = grad
    fx = -calculate_fx_grad(&x[0], &grad_c[0], consts, buffer)
    print(-fx)
    grad = -grad
    cdef c_float_t penalty = 0
    cdef c_float_t w
    
    cdef int i
    
    for i in range(A*A):
        w = x[2*A + i]
        penalty += 0.5 * lam * (w*w)
        grad[2*A + i] += lam * w
    
    return fx + penalty, grad



In [63]:
import numpy as np
N = 2
L = 5
A = 20
np.random.seed(42)
#msa = np.random.randint(0, A, (N,L), dtype=np.uint8)
msa = np.array([1]*L + [2]*L, dtype=np.uint8).reshape(N, L)
i = 0
j = 1
v_opt, w_opt, info = optimize_felsenstein(msa, i, j)

-12.357366981829946
-12.260279148153693
-11.884334485261803
-10.96509286930284
-10.626345881420079
-9.319649881385871
-5.229800591270545
-3.0806995182616657
-3.801363344879745
-2.8863822057270907
-2.850836993929744
-2.798087370659855
-2.7512613134006703
-2.708878971578925
-2.7272993873462914
-2.703734256441349
-2.694447084376786
-2.6893698900001146
-2.6834206611549107
-2.6815637395540395
-2.680313582685601
-2.6824240157744446
-2.6804179439937252
-2.6797132281311242
-2.6795110975667225
-2.6798712465295065
-2.6797153756196437
-2.6796640295743437
-2.67962087654999
-2.6796630088436
-2.6796340666871017
-2.679596848719222
-2.679628089152268
-2.6796142643767853
-2.679614346136868
-2.679636820288663
-2.679613932244026
-2.679615517473075
-2.6796055704466726
-2.6796574738098915
-2.679628802107053
-2.6796203393315783
-2.679614931570107
-2.6796140589791166
-2.6796156437218297
-2.679615902741192
-2.6796183850785704
-2.6795958131616096
-2.679612370239817
-2.6796155674948627
-2.6796159258395047
-2.67

In [65]:
info

{'grad': array([ 7.42868491e-13,  6.88474297e-07, -6.88487070e-07,  5.27877994e-13,
         9.50602869e-13,  9.50034518e-13,  1.03472045e-12,  3.05879896e-13,
         5.24702997e-13,  4.29888609e-13,  1.06458060e-12,  2.37096865e-13,
         3.29769523e-13,  8.99057074e-13,  9.28046557e-13,  9.25357259e-13,
         8.12353556e-13,  5.98021254e-13,  6.88096763e-13,  8.24585343e-13,
         1.11479112e-14,  2.46120432e-07, -2.46120696e-07,  1.71403616e-14,
         1.48513541e-14,  7.57801704e-15,  2.15089922e-14,  1.34539214e-14,
         1.15742995e-14,  2.53918427e-14,  1.12317069e-14,  2.22625049e-14,
         2.48779387e-14,  4.95265224e-15,  4.71782052e-15,  7.17235931e-15,
         1.87956760e-14,  2.40818951e-14,  9.54608959e-15,  1.52345416e-14,
        -5.90526075e-08,  4.22828973e-07,  3.09927159e-07, -8.71986873e-08,
        -6.58917349e-08, -1.12699467e-07, -5.88853126e-08, -8.91081093e-08,
        -9.53475572e-08, -4.30127321e-08, -9.39376381e-08, -8.01737029e-08,
    

In [37]:
x0 = np.random.rand(2*A + A*A)
epsilon = 1e-9
for comp in range(2*A + A*A):

    x_fwd = x0.copy()
    x_fwd[comp] += epsilon

    x_rev = x0.copy()
    x_rev[comp] -= epsilon

    fx, grad = check_derivative(x0, msa, 0, 1)
    fx_fwd, _ = check_derivative(x_fwd, msa, 0, 1)
    fx_rev, _ = check_derivative(x_rev, msa, 0, 1)

    grad_estim = (fx_fwd - fx_rev) / (2*epsilon)

    print(grad_estim, grad[comp])

0.08992273592411948 0.08992537097098256
-0.8452158972716006 -0.8452152692785194
-0.8775518089976231 -0.8775477783817596
0.11314638115322849 0.1131495031994259
0.07027933790482166 0.07028127268213544
0.06884404157858626 0.06884778522323476
0.05957501159059574 0.05958077378072139
0.15222800797687341 0.15223075965795543
0.1035616037370346 0.10356094082144156
0.11813483524747424 0.11813904992340099
0.06735634272558855 0.06735863459770018
0.1829798534913607 0.1829833013759842
0.1213820155498979 0.1213787375842169
0.07635936327687887 0.07636151843217204
0.06717959522006822 0.06718517359314956
0.06778222427783476 0.0677882345783588
0.07754552555638838 0.077548179581711
0.11529088794759444 0.11529554291877163
0.09677858514578475 0.09678404262084059
0.07436051774334373 0.07436422611806631
0.11238698860438488 0.11238419962016981
-0.9265708200700827 -0.9265720884174728
-0.9245839649452136 -0.9245845872485687
0.08715561605754374 0.08715225926050334
0.09312728366239753 0.09312037552532748
0.1339639

KeyboardInterrupt: 

In [38]:
fmin_l_bfgs_b(check_derivative, x0, args=(msa, 0, 1))

(array([-1.40715672e-01,  6.05850825e+00,  6.32824919e+00, -5.81570330e-02,
        -3.59071518e-01, -2.99525823e-01, -4.12542023e-01,  3.91376788e-02,
         4.45047666e-02,  5.03086831e-02, -5.36673155e-01,  9.60098961e-02,
         1.28132182e-01, -3.01380955e-01, -3.51706262e-01, -2.64822913e-01,
        -2.41894177e-01, -1.69791285e-01, -2.26579569e-01, -2.27944492e-01,
        -1.99721946e-01,  6.56847882e+00,  6.46011431e+00, -1.83940314e-01,
        -1.47203225e-01, -3.52640045e-02, -3.49532193e-01, -2.89494543e-01,
        -1.56131365e-01, -5.23365340e-01, -1.46368650e-01, -3.66453637e-01,
        -4.27989428e-01, -8.09753824e-02, -1.00230649e-01, -1.05797827e-01,
        -3.64346872e-01, -3.48638798e-01, -6.57175472e-03, -9.99879101e-02,
         1.12675349e-01,  3.35919605e-01, -7.87329386e-02,  8.92254082e-01,
         2.49127394e-01,  6.44796194e-01,  3.03422455e-01,  5.07260795e-01,
         5.32394541e-01,  1.78250692e-01,  9.47988938e-01,  7.62251036e-01,
         9.2