In [57]:
%load_ext cython
import cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [71]:
%%cython -f
cimport numpy as np
import numpy as np
from libc.stdlib cimport malloc, free
from scipy.optimize import fmin_l_bfgs_b
ctypedef np.float64_t c_float_t
cdef extern from "/Users/christian/CLionProjects/felsenstein/felsenstein.c":
    pass
    
A = 20
cdef extern from "/Users/christian/CLionProjects/felsenstein/felsenstein.h":
    
    ctypedef struct NodePrecomputation:
        pass
    
    ctypedef struct Node:
        Node* left
        Node* right

        int seq_id
        c_float_t phi_left
        c_float_t phi_right
        
        NodePrecomputation* data
        
    ctypedef struct NodeBuffer:
        pass
    
    ctypedef struct Buffer:
        NodeBuffer* left
        NodeBuffer* right
    
    ctypedef struct Constants:
        int L;
        c_float_t* single_aa_frequencies
        Node* phylo_tree;
        np.uint8_t* msa;
        int i;
        int j;
    
    void initialize_constants(Constants* consts);
    cdef c_float_t calculate_fx_grad(c_float_t* x, c_float_t* grad, Constants* consts, Buffer* buf)
    cdef void initialize_buffer(NodeBuffer* buffer)
    
    # debug stuff
    int check_tree(Node* tree)
    int check_msa(Constants* consts, int n, int i, int N, int L);
    c_float_t check_freqs(Constants* consts, int a)


cdef class ExtraArguments:
    
    cdef single_aa_frequencies
    cdef Constants consts
    cdef Buffer buffer
    cdef c_float_t lam
    cdef int n_nodes
    cdef Node* tree_nodes
    
    def __cinit__(self, msa, i, j, lam, node_info):
        N, L = msa.shape
        self.lam = lam
        
        cdef int leaf_index = N - 1
        n_nodes = len(node_info)
        cdef Node* nodes = <Node*> malloc(sizeof(Node)*n_nodes)
        for i, connectivity in enumerate(node_info[::-1]):
            node_idx = n_nodes - 1 - i
            if connectivity is None:
                # this is a leaf
                nodes[node_idx].left = NULL
                nodes[node_idx].right = NULL
                nodes[node_idx].seq_id = leaf_index
                leaf_index -= 1
            else:
                (left_node, left_time), (right_node, right_time) = connectivity
                if left_node is not None:
                    nodes[node_idx].left = &nodes[left_node]
                    nodes[node_idx].phi_left = np.exp(-left_time)
                else:
                    nodes[node_idx].left = NULL
                if right_node is not None:
                    nodes[node_idx].right = &nodes[right_node]
                    nodes[node_idx].phi_right = np.exp(-right_time)
                else:
                    nodes[node_idx].right = NULL
        self.tree_nodes = nodes

        cdef Constants consts = Constants()
        consts.L = L
        consts.phylo_tree = &nodes[0]
        cdef np.uint8_t[:] my_msa = msa.ravel()
        consts.msa = &my_msa[0]
        consts.i = i
        consts.j = j
        aa_counts = np.bincount(msa.ravel(), minlength=20)
        aa_freqs = aa_counts / aa_counts.sum()
        self.single_aa_frequencies = aa_freqs
        cdef c_float_t* aa_freqs_c = <c_float_t*> malloc(sizeof(c_float_t)*A)
        for a in range(A):
            aa_freqs_c[a] = aa_freqs[a]
        consts.single_aa_frequencies = aa_freqs_c
        initialize_constants(&consts)
        self.consts = consts
        
        cdef Buffer buffer = Buffer()
        cdef NodeBuffer* buffer_left = <NodeBuffer*> malloc(sizeof(NodeBuffer))
        initialize_buffer(buffer_left)
        buffer.left = buffer_left
        cdef NodeBuffer* buffer_right = <NodeBuffer*> malloc(sizeof(NodeBuffer))
        initialize_buffer(buffer_right)
        buffer.right = buffer_right
        self.buffer = buffer
        
    def __dealloc__(self):
        
        for i in range(self.n_nodes):
            free(&self.tree_nodes[i])
        free(self.tree_nodes)
        free(self.buffer.left)
        free(self.buffer.right)
        free(self.consts.single_aa_frequencies)
    
    
def optimize_felsenstein(msa, i, j):
    A = 20
    N, L = msa.shape
    np.random.seed(42)
    x0 = np.random.rand(2*A + A*A)
    
    node_info = [((1,10), (2,15)), None, None]
    extra_args = ExtraArguments(msa, i, j, 10, node_info)
    x_opt, fx_opt, info = fmin_l_bfgs_b(felsenstein_fx_grad, x0, args=(extra_args,), factr=10, pgtol=1e-9)
    info['fx_opt'] = fx_opt
    return x_opt[:2*A], x_opt[2*A:], info


def felsenstein_fx_grad(double[:] x, ExtraArguments extra_args):
    cdef Constants* consts = &extra_args.consts
    cdef Buffer* buffer = &extra_args.buffer
    cdef c_float_t lam = extra_args.lam
    grad = np.empty(2*A + A*A)
    cdef c_float_t[:] grad_c = grad
    fx = calculate_fx_grad(&x[0], &grad_c[0], consts, buffer)
    print(np.log(fx))
    
    cdef c_float_t penalty = 0
    cdef c_float_t w
    
    cdef int i
    
    for i in range(2*A + A*A):
        grad[i] /= -fx    
    
    for i in range(A*A):
        w = x[2*A + i]
        penalty += 0.5 * lam * (w*w)
        grad[2*A + i] += lam * w
    
    return -np.log(fx) + penalty, grad
    
  



In [72]:
import numpy as np
N = 2
L = 5
A = 20
np.random.seed(42)
#msa = np.random.randint(0, A, (N,L), dtype=np.uint8)
msa = np.array([1]*L + [2]*L, dtype=np.uint8).reshape(N, L)
i = 0
j = 1
v_opt, w_opt, info = optimize_felsenstein(msa, i, j)

-12.35736698182829
-12.260278867091179
-11.8843330814185
-10.965090456780134
-10.626343721940888
-9.319648728958004
-5.229803006898489
-3.0806398006547617
-3.8012680985986225
-2.8863810241284393
-2.8508362971489487
-2.798090602849246
-2.7512619437755075
-2.7089223363226327
-2.7272930975903344
-2.7037270989049236
-2.6944520775292466
-2.689358681551149
-2.68341663797362
-2.6815650147010874
-2.680313115612099
-2.682434039138206
-2.6804195399255457
-2.679711240027099
-2.679509584531197
-2.679876303714885
-2.679715568512531
-2.679663666343296
-2.679619853299775
-2.6796626027386035
-2.679633215061479
-2.6795966206939306
-2.6796276084461956
-2.6796145468556447
-2.6796144993694773
-2.6796369228194683
-2.6796137767088504
-2.6796155160556414
-2.6796048094841476
-2.6796561435345
-2.679628939507259
-2.679620384330889
-2.6796149440540913
-2.679613876219724
-2.67961551078156
-2.679616014538998
-2.679617555634602
-2.679596628277287
-2.6796104409207095
-2.6796153703668817
-2.679615880787527
-2.6796181

In [65]:
info['fx_opt']

2.724968817196762

In [73]:
info

{'grad': array([ 1.65316017e-13, -1.97276457e-07,  1.97273600e-07,  1.13928150e-13,
         2.16961595e-13,  2.16822719e-13,  2.38519163e-13,  6.31864877e-14,
         1.13184877e-13,  9.12128151e-14,  2.46299130e-13,  4.80478457e-14,
         6.85127905e-14,  2.03950542e-13,  2.11234995e-13,  2.10571161e-13,
         1.82349464e-13,  1.30468142e-13,  1.52021013e-13,  1.85374485e-13,
         1.86895825e-15, -2.33680938e-07,  2.33680893e-07,  2.98118204e-15,
         2.55043372e-15,  1.23252090e-15,  3.82236596e-15,  2.29092113e-15,
         1.94637832e-15,  4.59176819e-15,  1.88418206e-15,  3.97000754e-15,
         4.48941639e-15,  7.80342440e-16,  7.40687262e-16,  1.16164377e-15,
         3.29680370e-15,  4.33008190e-15,  1.58081721e-15,  2.62208787e-15,
        -1.45989217e-08,  1.27233585e-08,  7.31689568e-10, -1.14534167e-10,
        -2.92969152e-09,  9.33291527e-09,  1.03315909e-09,  8.47000526e-09,
         9.05240705e-09, -4.54616231e-09, -1.04279522e-08,  7.71452443e-09,
    