# A Wright-Fisher simulation implemented in C via Cython.

OMG!

We would use GSL via CythonGSL, but that would require a GPL license for this notebook, but we're doing CCBY.

In [1]:
%load_ext Cython

In [2]:
import msprime
import numpy as np

  from ._conv import register_converters as _register_converters


In [119]:
%%cython -3 -lgsl -lgslcblas -lm

import msprime
import numpy as np
from collections import namedtuple
import pickle
cimport numpy as np
from cython.view cimport array as cvarray
from libc.stdlib cimport malloc, realloc, free
from libc.stdint cimport int32_t, uint32_t

from cython_gsl.gsl_rng cimport gsl_rng,gsl_rng_mt19937,gsl_rng_alloc,gsl_rng_free,gsl_rng_set
from cython_gsl.gsl_rng cimport gsl_rng_uniform
from cython_gsl.gsl_random cimport gsl_ran_flat
from cython_gsl.gsl_random cimport gsl_ran_poisson

cdef int32_t * malloc_int32_t(size_t n):
    return <int32_t*>malloc(n*sizeof(int32_t))

cdef int32_t * realloc_int32_t(void * x, size_t n):
    return <int32_t*>realloc(x,n*sizeof(int32_t))

cdef double * malloc_double(size_t n):
    return <double*>malloc(n*sizeof(double))

cdef double * realloc_double(double * x, size_t n):
    return <double*>realloc(<double *>x,n*sizeof(double))

cdef struct Mutations:
    double * pos
    int32_t * time
    int32_t * node
    size_t next_mutation, capacity
    
cdef int init_Mutations(Mutations * m):
    m.next_mutation = 0
    m.capacity = 10000
    m.pos = malloc_double(m.capacity)
    if m.pos == NULL:
        return -1
    m.time = malloc_int32_t(m.capacity)
    if m.time == NULL:
        return -1
    m.node = malloc_int32_t(m.capacity)
    if m.node == NULL:
        return -1
    return 0

cdef int realloc_Mutations(Mutations * m):
    m.capacity *= 2
    m.pos = realloc_double(m.pos,
                          m.capacity)
    if m.pos == NULL:
        return -1
    m.time = realloc_int32_t(m.time,
                            m.capacity)
    if m.time == NULL:
        return -1
    m.node = realloc_int32_t(m.node,
                            m.capacity)
    if m.node == NULL:
        return -1
    return 0

cdef void free_Mutations(Mutations * m):
    free(m.pos)
    free(m.time)
    free(m.node)
    m.next_mutation = 0
    m.capacity = 10000
    
cdef int add_mutation(double pos,
                     int32_t generation,
                     int32_t node,
                     Mutations * m):
    cdef int rv = 0
    if m.next_mutation+1 >= m.capacity:
        rv = realloc_Mutations(m)
        if rv != 0:
            return rv
    m.pos[m.next_mutation] = pos
    m.time[m.next_mutation] = generation
    m.node[m.next_mutation] = node
    m.next_mutation+=1
    return rv
    
cdef struct Nodes:
    double * time
    size_t next_node, capacity
    
cdef int init_Nodes(Nodes * n):
    n.next_node = 0
    n.capacity = 10000
    n.time = malloc_double(n.capacity)
    if n.time == NULL:
        return -1
    return 0

cdef int realloc_Nodes(Nodes * n):
    n.capacity *= 2
    n.time = realloc_double(n.time,
                            n.capacity)
    if n.time == NULL:
        return -1
    return 0
    
cdef void free_Nodes(Nodes * n):
    if n.time != NULL:
        free(n.time)
    n.next_node = 0
    n.capacity = 10000

cdef int add_node(double t, Nodes *n):
    cdef int rv = 0
    if n.next_node >= n.capacity:
        rv = realloc_Nodes(n)
        if rv != 0:
            return rv
    n.time[n.next_node] = t
    n.next_node+=1
    return rv
    
cdef struct Edges:
    double *left
    double *right
    int32_t *parent
    int32_t *child
    size_t next_edge, capacity
    
cdef int init_Edges(Edges * e):
    e.next_edge = 0
    e.capacity = 10000
    e.left = malloc_double(e.capacity)
    if e.left == NULL:
        return -1
    e.right = malloc_double(e.capacity)
    if e.right == NULL:
        return -1
    e.parent = malloc_int32_t(e.capacity)
    if e.parent == NULL:
        return -1
    e.child = malloc_int32_t(e.capacity)
    if e.child == NULL:
        return -1
    return 0
   
cdef int realloc_Edges(Edges * e):
    e.capacity *= 2
    e.left = realloc_double(e.left,e.capacity)
    if e.left == NULL:
        return -1
    e.right = realloc_double(e.right,e.capacity)
    if e.right == NULL:
        return -1
    e.parent = realloc_int32_t(e.parent,e.capacity)
    if e.parent == NULL:
        return -1
    e.child = realloc_int32_t(e.child,e.capacity)
    if e.child == NULL:
        return -1
    return 0

cdef void free_Edges(Edges * e):
    free(e.left)
    free(e.right)
    free(e.parent)
    free(e.child)
    e.next_edge = 0
    e.capacity = 10000
    
cdef int add_edge(double left, double right,
             int32_t parent, int32_t child,
             Edges * edges):
    cdef int rv=0
    if edges.next_edge+1 >= edges.capacity:
        rv = realloc_Edges(edges)
        if rv != 0:
            return rv
        
    edges.left[edges.next_edge] = left
    edges.right[edges.next_edge] = right
    edges.parent[edges.next_edge] = parent
    edges.child[edges.next_edge] = child
    edges.next_edge += 1
    return rv

cdef struct Tables:
    Nodes nodes
    Edges edges
    Mutations mutations
    gsl_rng * rng
    
cdef int init_Tables(Tables * t, int seed):
    cdef int rv = 0
    rv = init_Nodes(&t.nodes)
    if rv != 0:
        return rv
    rv = init_Edges(&t.edges)
    if rv != 0:
        return rv
    rv = init_Mutations(&t.mutations)
    if rv != 0:
        return rv
    t.rng = gsl_rng_alloc(gsl_rng_mt19937)
    if t.rng == NULL:
        return -1
    gsl_rng_set(t.rng, seed)
    return rv

cdef void free_Tables(Tables * t):
    free_Nodes(&t.nodes)
    free_Edges(&t.edges)
    free_Mutations(&t.mutations)
    gsl_rng_free(t.rng)
    
cdef int infsites(double mu, int32_t generation,
                  int32_t next_offspring_index,
                  Tables * tables,
                  dict lookup):
    cdef unsigned nmut = gsl_ran_poisson(tables.rng, mu)
    cdef unsigned i = 0
    cdef double pos
    cdef int rv = 0
    for i in range(nmut):
        pos = gsl_rng_uniform(tables.rng)
        while pos in lookup:
            pos = gsl_rng_uniform(tables.rng)
        rv = add_mutation(pos,
                         generation,
                         next_offspring_index,
                         &tables.mutations)
        if rv != 0:
            return rv
        lookup[pos] = True
    return rv

cdef int poisson_recombination(double r,
                               size_t pg1, size_t pg2,
                                int32_t next_offspring_id,
                               Tables * tables):
    cdef unsigned nbreaks = gsl_ran_poisson(tables.rng, r)
    cdef list b = []
    cdef size_t lenb
    cdef unsigned i = 0
    cdef double x
    # cdef list pgams
    cdef int rv = 0
    cdef double left,right
    cdef int32_t p
    if nbreaks == 0:
        # The parent passes the entire region onto the child
        rv = add_edge(0.0,1.0,pg1,
                      next_offspring_id,&tables.edges)
        if rv != 0:
            return rv
    else:
        while i < nbreaks:
            x = gsl_rng_uniform(tables.rng)
            while x in b:
                x = gsl_rng_uniform(tables.rng)
            b.append(x)
            i += 1
        b.sort()
        b.append(1.0)

        if b[0] != 0.0:
            b.insert(0,0.0)
        else:
            pg1,pg2 = pg2,pg1
            #parent_indexes = (parent_indexes[1], parent_indexes[0])
        
        #pgams = list([*tuple(i for i in parent_indexes)]*int(len(b)/2))

        #for left,right in zip(b[0:len(b)-1],b[1:len(b)]):
        lenb = len(b)-1
        for i in range(lenb):
            left=b[i]
            right=b[i+1]
            # print(pg1,pg2,left,right)
            rv = add_edge(left,right,pg1,
                          next_offspring_id,&tables.edges)
            pg1,pg2 = pg2,pg1
            if rv != 0:
                return rv
    # print("//")
    return 0

MutationMetadata = namedtuple('MutationMetadata',['pos','origin'])

cdef int simplify(Tables * tables, 
            double g,
            double dt,
            object nodes,
            object edges,
            object sites,
            object mutations):
    if tables.nodes.next_node > 0:
        assert(g == tables.nodes.time[tables.nodes.next_node-1])
    cdef int rv = 0
    cdef size_t i,nedges=len(edges)
    cdef np.ndarray[double,ndim=1] dview,lview,rview
    cdef np.ndarray[int32_t,ndim=1] pview,cview
    # Reverse time for our new nodes
    for i in range(tables.nodes.next_node):
        tables.nodes.time[i] = -1.0*(tables.nodes.time[i]-g)
    
    
    nodes.set_columns(time=nodes.time+dt,flags=nodes.flags)
    dview = np.asarray(<double[:tables.nodes.next_node]>tables.nodes.time)
    
    nodes.append_columns(time=dview,
                        flags=np.ones(tables.nodes.next_node,dtype=np.uint32))
    
    
    lview = np.asarray(<double[:tables.edges.next_edge]>tables.edges.left)
    rview = np.asarray(<double[:tables.edges.next_edge]>tables.edges.right)
    pview = np.asarray(<int32_t[:tables.edges.next_edge]>tables.edges.parent)
    cview = np.asarray(<int32_t[:tables.edges.next_edge]>tables.edges.child)
    edges.append_columns(left=lview,
                        right=rview,
                        parent=pview,
                        child=cview)
    
    # Disclaimer: we will use .add_row here 
    # to keep the code as simple as possible.
    # However, things will go MUCH faster
    # if we use .append_columns, whose syntax
    # is not going to be pleasant here :)
    try:
        for i in range(tables.mutations.next_mutation):
            sites.add_row(position=tables.mutations.pos[i],
                         ancestral_state='0')
            # md = MutationMetadata(float(tables.mutations.pos[i]),
            #                       int(tables.mutations.time[i]))
            # print(md)
            #mdp = pickle.dumps(md)
            #print(type(mdp))
            mutations.add_row(site=<uint32_t>(len(sites)-1),
                              node=<uint32_t>tables.mutations.node[i],
                              derived_state='1')
                              # metadata=md)
    except:
        return -1
    msprime.sort_tables(nodes=nodes,edges=edges,
                       sites=sites,mutations=mutations)
                       #edge_start=nedges)
    
    samples = np.where(nodes.time == 0)[0]
    print(len(nodes),len(edges),len(sites),len(mutations),len(samples))
    print(samples) 
    msprime.simplify_tables(samples=samples.tolist(),
                            nodes=nodes,
                           edges=edges,
                           sites=sites,
                           mutations=mutations)
    
    # "clear" our temp containers
    tables.nodes.next_node = 0
    tables.mutations.next_mutation = 0
    tables.edges.next_edge = 0
                          
    return rv

def evolve(int N, int ngens, double theta, double rho, int gc, int seed):
    nodes = msprime.NodeTable()
    edges = msprime.EdgeTable()
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    
    cdef double mu = theta/<double>(4*N)
    cdef double r = rho/<double>(4*N)
    
    cdef int rv
    cdef size_t i, generation
    cdef Tables tables
    rv = init_Tables(&tables, seed)
    if rv != 0:
        free_Tables(&tables)
        raise RuntimeError("could not initialize tables")
        
    for i in range(2*<size_t>N):
        nodes.add_row(time=0.0,
                      flags=msprime.NODE_IS_SAMPLE)
        
    
    cdef int32_t next_offspring_index, first_parental_index
    next_offspring_index = len(nodes)
    first_parental_index = 0
    PARENT_DTYPE = np.int32
    cdef np.ndarray[int32_t,ndim=1] parents
    cdef size_t parent1, parent2,pindex
    cdef int32_t p1g1, p1g2, p2g1, p2g2
    cdef dict lookup = {}
    cdef size_t last_gen_gc = 0
    for generation in range(1,<size_t>(ngens+1)):
        if generation%gc == 0.0:
            if tables.nodes.next_node > 0:
                print(generation-1,
                     tables.nodes.time[tables.nodes.next_node-1])
            rv = simplify(&tables,
                         generation-1,
                         generation-last_gen_gc,
                         nodes,edges,sites,mutations)
            last_gen_gc=generation
            next_osffpring_index = len(nodes)
            first_parental_index = 0
            if rv != 0:
                raise RuntimeError("simplification error")
        else:
            first_parental_index = next_offspring_index - 2*N
        for pindex in range(0,2*N,2):
            parent1=<size_t>gsl_ran_flat(tables.rng,0.0,<double>N)
            parent2=<size_t>gsl_ran_flat(tables.rng,0.0,<double>N)
            p1g1 = first_parental_index + 2*parent1
            p1g2 = p1g1 + 1
            p2g1 = first_parental_index + 2*parent2
            p2g2 = p2g1 + 1
            
            if gsl_rng_uniform(tables.rng) < 0.5:
                p1g1, p1g2 = p1g2, p1g1
            if gsl_rng_uniform(tables.rng) < 0.5:
                p2g1, p2g2 = p2g2, p2g1
                
            rv = poisson_recombination(r,p1g1,p1g2,
                                      next_offspring_index,
                                      &tables)
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during recombination")
                
            rv = infsites(mu,generation,
                          next_offspring_index,
                         &tables,lookup)
            
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during mutation")
                
            rv = add_node(<double>generation, &tables.nodes)
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during adding nodes")
                
            next_offspring_index += 1
                
            rv = poisson_recombination(r,p2g1,p2g2,
                                      next_offspring_index,
                                      &tables)
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during recombination")
                
            rv = infsites(mu,generation,
                          next_offspring_index,
                         &tables,lookup)
            
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during mutation")
                
            rv = add_node(<double>generation, &tables.nodes)
            if rv != 0:
                free_Tables(&tables)
                raise RuntimeError("error during adding nodes")
                
            next_offspring_index += 1
        # first_parental_index += 2*N
        
    if tables.nodes.next_node > 0:
        rv=simplify(&tables,
                   generation,
                   generation-last_gen_gc,
                    nodes,edges,sites,mutations
                   )
    
    # Push first nodes times further back
    # nodes.set_columns(time=nodes.time + ngens +1,
    #                  flags=nodes.flags)
    #         
    # # Add our data to the tables
    # cdef double[:] timeview = <double[:tables.nodes.next_node]>tables.nodes.time
    # time=np.asarray(timeview,np.float)
    # time-=time.max()
    # time*=-1.0
    # nodes.append_columns(time=time,
    #                      flags=np.ones(tables.nodes.next_node,np.uint32))
    # edges.append_columns(left=np.asarray(<double[:tables.edges.next_edge]>tables.edges.left),
    #                     right=np.asarray(<double[:tables.edges.next_edge]>tables.edges.right),
    #                     parent=np.asarray(<int32_t[:tables.edges.next_edge]>tables.edges.parent),
    #                     child=np.asarray(<int32_t[:tables.edges.next_edge]>tables.edges.child))
    # 
    # print(nodes.time.min(),nodes.time.max())
# 
    # msprime.sort_tables(nodes=nodes,edges=edges)
    # 
    # samples = np.where(nodes.time==0)[0]
    # 
    # print(samples)
    # 
    # msprime.simplify_tables(samples=samples.tolist(),
    #                        nodes=nodes,
    #                        edges=edges)
    #             
    # free_Tables(&tables)
    return msprime.load_tables(nodes=nodes,edges=edges)

In [120]:
%%time
evolve(1000, 11, 500.0, 500.0, 10, 42)

9 9.0
20000 20270 2259 2259 2000
[18000 18001 18002 ... 19997 19998 19999]


IndexError: Array index out of bounds

Exception ignored in: '_cython_magic_1b669dc1c001826c4bcf4f68dd6a6deb.simplify'
Traceback (most recent call last):
  File "/Users/kevin/anaconda3/lib/python3.5/site-packages/msprime/tables.py", line 1229, in sort_tables
    return _msprime.sort_tables(**kwargs)
IndexError: Array index out of bounds


LibraryError: Edges must be listed in (time[parent], child, left) order; time[parent] order violated