# Define the drift-simulation classes

In [None]:
#!/usr/bin/env python

"""
Run drift simulations on quartet trees
"""

# imports for py3 compatibility
from __future__ import print_function
from builtins import range, input

import numpy as np
from numba import jit

class QuartetDrifter:
    """
    Simulate genetic drift of allele counts on a quartet (four-taxon) tree.

    Allele counts for ``num_genes`` independent loci are resampled each
    generation by binomial sampling out of the population size.

    Work in progress: ``run`` is not implemented yet.
    """
    def __init__(
        self,
        dem_dict,
        num_genes=1000,
        seed=None,
        Ne=None
        ):
        """
        Set up the simulator.

        Parameters
        ----------
        dem_dict : dict
            User demography. Its entries override the defaults stored in
            ``self.demography``. Keys used by the defaults are "rootNe",
            "br1", "br2", "br3", "br4", "br12", "br123", "node12",
            "node123", "tips" and "flow".
        num_genes : int
            Number of independent loci to simulate.
        seed : int or None
            When not None, seeds NumPy's global random generator.
        Ne : int or None
            Population size used to build the default demography. When
            omitted, ``dem_dict["rootNe"]`` is used instead.

        Raises
        ------
        ValueError
            If neither ``Ne`` nor ``dem_dict["rootNe"]`` is provided.
        """
        # Compare against None so that seed=0 still seeds the generator.
        if seed is not None:
            np.random.seed(seed)

        # The defaults below are multiples of Ne, so Ne must be known.
        if Ne is None:
            Ne = dem_dict.get("rootNe")
        if Ne is None:
            raise ValueError(
                "Ne must be given, either directly or as dem_dict['rootNe']")

        # Default demography: branch entries unset, node/tip entries scaled
        # from Ne (presumably heights in generations -- TODO confirm).
        self.demography = {"rootNe": Ne,
                           "br4": None,
                           "br123": None,
                           "br3": None,
                           "br12": None,
                           "br1": None,
                           "br2": None,
                           "node12": Ne,
                           "node123": Ne * 2,
                           "tips": Ne * 4,  # node height is 0
                           "flow": None
                           }
        # User-supplied entries override (or extend) the defaults.
        self.demography.update(dem_dict)

        self.num_genes = num_genes
        self.seed = seed

    @property
    def starting_variants(self):
        """
        Draw starting allele counts for every locus.

        Each locus gets a frequency drawn uniformly from [0, 1). The frequency
        is scaled by ``demography["rootNe"]`` and truncated to int. Draws come
        from NumPy's global generator, so each access returns a fresh sample.

        Returns
        -------
        numpy.ndarray of int, shape (num_genes, 1, 1)
        """
        freqs = np.random.uniform(0, 1, size=(self.num_genes, 1, 1))
        return (freqs * self.demography["rootNe"]).astype(int)

    @staticmethod
    def _add_one_gen(population_obj, anc_pop_size, num_genes):
        """
        Resample allele counts for a single generation of drift.

        Parameters
        ----------
        population_obj : array-like
            Allele counts indexed by locus on axis 0. Only the first entry
            along axis 1 (the latest generation) is used.
        anc_pop_size : int
            Population size. Counts are turned into frequencies by dividing
            by it, and new counts are binomial draws out of it. Counts must
            not exceed it, otherwise the binomial ``p`` exceeds 1.
        num_genes : int
            Number of loci in ``population_obj``.

        Returns
        -------
        numpy.ndarray of int, shape (num_genes, 1, 1)
            The new generation only; it replaces the previous one.
        """
        # Latest generation for every locus, as a flat int vector.
        latest = np.asarray(population_obj)[:, 0].reshape(num_genes).astype(int)

        # Convert counts to frequencies (float division on py2 and py3).
        freqs = latest / float(anc_pop_size)

        # Sample a new count per locus from those frequencies.
        newgen = np.random.binomial(n=anc_pop_size, p=freqs)
        return newgen.reshape(num_genes, 1, 1)

    @staticmethod
    def _run_a_bunch(population_obj, ngens, anc_pop_size):
        '''
        Apply ``_add_one_gen`` repeatedly for ``ngens`` generations.

        Parameters
        ----------
        population_obj : array-like
            Starting allele counts, indexed by locus on axis 0.
        ngens : int
            Number of generations to simulate (0 returns a copy of the input).
        anc_pop_size : int
            Population size used at every generation.

        Returns
        -------
        numpy.ndarray
            Allele counts after ``ngens`` generations. The input is never
            modified.
        '''
        # Copy so the caller's array is never aliased, even when ngens == 0.
        p = np.array(population_obj, copy=True)
        num_genes = p.shape[0]
        for _ in range(ngens):
            p = QuartetDrifter._add_one_gen(p, anc_pop_size, num_genes)
        return p

    def run(self):
        """
        Run the drift simulation over the whole tree.

        Not implemented yet: the draft stopped while computing the number of
        branch segments.

        Raises
        ------
        NotImplementedError
            Always, until the simulation loop is written.
        """
        total_tree_height = self.demography["tips"]
        # TODO: split total_tree_height into per-branch segments (the draft
        # had ``n_segments = 6 + ...``), presumably drifting each segment
        # with _run_a_bunch.
        raise NotImplementedError(
            "QuartetDrifter.run is not implemented yet "
            "(total tree height: {})".format(total_tree_height))


class PopSizeChange:
    """
    A change of population size on one branch at a given time.

    Instances are meant to be collected in a list stored under a branch key
    of the demography dict, e.g. ``{"br4": [PopSizeChange(time=1000, Ne=500)]}``.
    """
    def __init__(
        self,
        time,
        Ne
        ):
        '''
        time: time (in gens) of the population size change. The tips are 0,
              root is total height.
        Ne: the new population size from this time on.
        '''
        self.Ne = Ne
        self.time = time

    def __repr__(self):
        # __class__ (rather than type()) also works for py2 classic classes.
        return "{}(time={!r}, Ne={!r})".format(
            self.__class__.__name__, self.time, self.Ne)


class GeneFlow:
    """
    A gene-flow (migration) event from one branch into another.
    """
    def __init__(
        self,
        time,
        source,
        dest,
        prop
        ):
        '''
        time: time (in gens) of the gene-flow event. The tips are 0, root is
              total height.
        source: the branch sending genes
        dest: the branch receiving genes
        prop: the proportion of the destination branch that is to be sampled
              from the source branch
        '''
        self.time = time
        self.source = source
        self.dest = dest
        self.prop = prop

    def __repr__(self):
        # __class__ (rather than type()) also works for py2 classic classes.
        return "{}(time={!r}, source={!r}, dest={!r}, prop={!r})".format(
            self.__class__.__name__, self.time, self.source, self.dest,
            self.prop)

In [26]:
# Display the (key, value) entries of the ``temp`` demography dict.
temp.items()

[('rootNe', 1000),
 ('node12', 1000),
 ('br12', None),
 ('node123', 2000),
 ('br123', None),
 ('tips', 4000),
 ('br4', None),
 ('br3', None),
 ('br2', None),
 ('br1', None)]

In [None]:
def set_demography(self, demography):
    """
    Override entries of ``self.demography`` with those in ``demography``.

    Parameters
    ----------
    self : object
        Any object with a dict ``demography`` attribute. This is written as a
        free function; presumably it is meant to become a QuartetDrifter
        method.
    demography : dict
        Entries to set. Existing keys are overwritten and new keys are added.
    """
    # Use the arguments; the draft read and wrote the notebook globals
    # ``test`` and ``temp`` instead.
    self.demography.update(demography)

In [22]:
Ne = 1000
# Default demography: root size first, then every branch left unset, then
# the node and tip entries as multiples of Ne (tips: node height is 0).
temp = {"rootNe": Ne}
temp.update(dict.fromkeys(["br4", "br123", "br3", "br12", "br1", "br2"]))
temp.update({"node12": Ne, "node123": 2 * Ne, "tips": 4 * Ne})

In [18]:
# Example user demography: br4's size becomes 500 at generation 1000, and
# the root, node and tip entries are overridden.
test = dict(
    rootNe=1000,
    br4=[PopSizeChange(time=1000, Ne=500)],
    node123=500,
    node12=1200,
    tips=1500,  # node height is 0
)

In [23]:
# Copy every user entry into the defaults; existing keys are overwritten.
for key in test:
    temp[key] = test[key]

['rootNe', 'tips', 'br4', 'node12', 'node123']

In [27]:
# Note: this rebinds ``test``, replacing the example demography dict.
test = PopSizeChange(time=1000, Ne=2000)