In [100]:
import numpy as np
import msprime
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import resample
from statistics import mean, stdev
from sklearn.model_selection import KFold
from math import sqrt
import seaborn as sns
import allel

In [2]:
# Simulation parameters used by the msprime cells below.
seq_len = 1e8  # passed as sequence_length to sim_ancestry
rec_rate = 1e-8  # passed as recombination_rate to sim_ancestry
mut_rate = 1e-8  # passed as rate to sim_mutations
split_time = 50 # in generations; time of the A/B split from C, also used as the mutation start_time

In [3]:
# Three equally sized populations: A and B are the sampled populations,
# C is their common ancestor.
demography = msprime.Demography()
for pop_name in ("A", "B", "C"):
    demography.add_population(name=pop_name, initial_size=1000)

# A and B merge into C at `split_time` (looking backwards in time).
# Kept as the last expression so the cell still displays the split event.
demography.add_population_split(time=split_time, derived=["A", "B"], ancestral="C")

PopulationSplit(time=50, derived=['A', 'B'], ancestral='C')

In [4]:
 # Simulate ancestry for 1000 diploid individuals from each of A and B under
 # the demography defined above, using the discrete-time Wright-Fisher model.
 # NOTE(review): no random_seed is passed, so every run gives a different tree
 # sequence and the recorded outputs below are not exactly reproducible.
 ts = msprime.sim_ancestry(
        samples={'A':1000, 'B':1000}, # diploid samples
        demography=demography,
        ploidy=2,
        sequence_length=seq_len,
        discrete_genome=False,
        recombination_rate=rec_rate, 
        model='dtwf',
    )

# Overlay mutations on the simulated ancestry (rebinds `ts`).
# NOTE(review): start_time=split_time presumably restricts mutations to times
# older than the split, i.e. only ancestral variation -- confirm this is
# intended against the msprime documentation.
ts = msprime.sim_mutations(
    ts, 
    rate=mut_rate, 
    discrete_genome=False,
    start_time=split_time,
    )

In [13]:
# Display the tree sequence summary (rich repr of the last expression).
ts

Tree Sequence,Unnamed: 1
Trees,46535
Sequence Length,100000000.0
Sample Nodes,4000
Total Size,7.7 MiB
Metadata,No Metadata

Table,Rows,Size,Has Metadata
Edges,174688,4.7 MiB,
Individuals,2000,31.3 KiB,
Migrations,0,4 Bytes,
Mutations,22564,639.0 KiB,
Nodes,31076,728.3 KiB,
Populations,3,278 Bytes,✅
Provenances,2,2.5 KiB,
Sites,22564,374.6 KiB,


In [22]:
# Peek at the first site only, to see what a Site record looks like.
for site in ts.sites():
    print(site)
    break  # stop after the first site

Site(id=0, position=8473.858337816917, ancestral_state='A', mutations=[Mutation(id=0, site=0, node=23066, derived_state='T', parent=-1, metadata=b'', time=831.8728790942114)], metadata=b'')


22564

## Functions

In [6]:
def get_fst(dA, dB, dAB):
    """Return Fst as one minus the ratio of within- to between-population diversity.

    dA, dB -- within-population diversity values for the two populations;
              they are added and halved element-wise, so they must support
              arithmetic and the result must have a ``.sum()`` method
              (e.g. numpy arrays).
    dAB    -- between-population diversity values; must have ``.sum()``.

    The ratio is taken between the two totals (a ratio of sums), not
    averaged element by element.
    """
    within_total = ((dA + dB) / 2).sum()
    between_total = dAB.sum()
    return 1 - within_total / between_total

In [7]:
def get_fst_general(ts, pop1_samples, pop2_samples, sites_index):
    """Return Hudson's Fst between two sets of samples over a set of sites.

    ts           -- tree sequence providing the genotype matrix
    pop1_samples -- indices of the samples used for the first population
    pop2_samples -- indices of the samples used for the second population
    sites_index  -- indices of the sites to use

    Any of pop1_samples, pop2_samples and sites_index may contain
    duplicates (e.g. when they come from resampling with replacement).

    The sample indices select columns of ``ts.genotype_matrix()``.
    NOTE(review): callers pass ``ts.samples(population=...)`` (node IDs);
    this assumes node IDs coincide with genotype-matrix column positions
    -- confirm for this tree sequence.
    """
    # Treat each sample node as a haploid individual:
    # (num_sites, num_samples) -> (num_sites, num_samples, ploidy=1).
    haploid_shape = (ts.num_sites, ts.num_samples, 1)
    genotypes = allel.GenotypeArray(
        ts.genotype_matrix().reshape(haploid_shape),
        dtype='i1')

    # The full matrix is rebuilt on every call; only the selected sites are kept.
    selected_sites = genotypes[sites_index]
    counts_pop1 = selected_sites[:, pop1_samples, :].count_alleles()
    counts_pop2 = selected_sites[:, pop2_samples].count_alleles()

    # Ratio of sums over sites, rather than a mean of per-site ratios.
    numerator, denominator = allel.hudson_fst(counts_pop1, counts_pop2)
    return np.sum(numerator) / np.sum(denominator)

In [8]:
def get_CI_normal(data):
    """Return a normal-approximation 95% confidence interval for the mean of `data`.

    The interval is ``mean(data) +/- 1.96 * stdev(data) / sqrt(len(data))``.

    data -- a sequence of at least two numbers (``statistics.stdev`` raises
            ``StatisticsError`` for fewer than two values).

    Returns a ``(lower, upper)`` tuple.
    """
    est_mean = mean(data)

    # Standard error of the mean; computed once and reused for both bounds.
    half_width = 1.96 * (stdev(data) / sqrt(len(data)))

    # The original referenced an undefined name (`est_fst`) here, so every
    # call raised NameError; the interval is centred on the sample mean.
    lower = est_mean - half_width
    upper = est_mean + half_width

    return lower, upper


def get_CI_quantile(data):
    """Return the (5th percentile, 95th percentile) of `data` as a tuple.

    This is an empirical quantile interval covering the central 90% of the
    values. NOTE(review): `get_CI_normal` uses 1.96 (a 95% interval), so the
    two interval functions use different nominal levels -- confirm this is
    intended.
    """
    return tuple(np.quantile(data, level) for level in (0.05, 0.95))

#### Bootstrap Resampling

In [9]:
def bootstrap(ts, popA, popB, bootstrap_time, bootstrap_size, size, fst, sites_index, method):
    '''
    Estimate how often a bootstrap quantile interval contains `fst`.

    For each of `bootstrap_time` rounds, `bootstrap_size` resamples (with
    replacement, `size` items each) are drawn, Fst is computed on every
    resample, and the interval from `get_CI_quantile` is checked against
    `fst` (strict inequalities on both sides).

    @method = 0 resamples individuals (popA and popB independently);
              any other value resamples sites from `sites_index`.

    Prints each interval and the final coverage rate, and returns the
    coverage rate (fraction of rounds whose interval contains `fst`).
    '''
    over_individuals = method == 0

    # Whichever axis is not being resampled keeps its full input unchanged.
    drawn_A, drawn_B, drawn_sites = popA, popB, sites_index

    within = []
    for _ in range(bootstrap_time):
        estimates = []
        for _ in range(bootstrap_size):
            if over_individuals:
                drawn_A = np.random.choice(popA, size, replace=True)
                drawn_B = np.random.choice(popB, size, replace=True)
            else:
                drawn_sites = np.random.choice(sites_index, size, replace=True)

            estimates.append(get_fst_general(ts, drawn_A, drawn_B, drawn_sites))

        lower, upper = get_CI_quantile(estimates)
        print((lower, upper))

        within.append(1 if lower < fst < upper else 0)

    coverage = mean(within)
    print('Coverage rate is: ', coverage)
    return coverage

#### Jackknife Resampling 
See https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold 
for a reference on the `KFold` split used below to form the jackknife blocks.

In [92]:
def jackknife_ind(kf, ts, popA, popB, sites_index):
    """Return the mean Fst over leave-one-fold-out subsets of individuals.

    kf          -- splitter with a ``split`` method yielding (kept, left_out)
                   index pairs (a scikit-learn KFold in this notebook)
    popA, popB  -- sample arrays for the two populations; the same kept
                   positions are applied to both, and positions are generated
                   from ``len(popA)``
                   (assumes popB is at least as long as popA -- TODO confirm)
    sites_index -- sites passed unchanged to `get_fst_general`
    """
    positions = np.arange(len(popA))

    # Each fold's "train" part is the data with one block left out.
    leave_out_estimates = [
        get_fst_general(ts, popA[kept], popB[kept], sites_index)
        for kept, _ in kf.split(positions)
    ]

    return mean(leave_out_estimates)

def jackknife_sites(kf, ts, popA, popB, sites_index):
    """Return the mean Fst over leave-one-fold-out subsets of sites.

    kf          -- splitter with a ``split`` method yielding (kept, left_out)
                   index pairs (a scikit-learn KFold in this notebook)
    popA, popB  -- sample arrays passed unchanged to `get_fst_general`
    sites_index -- array of site indices; each fold keeps a subset of it
    """
    # Each fold's "train" part is the site set with one block left out.
    leave_out_estimates = [
        get_fst_general(ts, popA, popB, sites_index[kept])
        for kept, _ in kf.split(sites_index)
    ]

    return mean(leave_out_estimates)

def jackknife(ts, popA, popB, bootstrap_time, bootstrap_size, size, fst, sites_index, method):
    """Estimate how often a block-jackknife quantile interval contains `fst`.

    For each of `bootstrap_time` rounds, the 10-fold leave-one-block-out mean
    Fst is computed `bootstrap_size` times, the interval from
    `get_CI_quantile` is taken over those values, and `fst` is checked
    against it (strict inequalities on both sides).

    method = 0 drops blocks of individuals (`jackknife_ind`); any other
    value drops blocks of sites (`jackknife_sites`).

    `size` is accepted so the signature matches `bootstrap`, but is not used.

    Prints each interval and the final coverage rate, and returns the
    coverage rate.
    """
    # One shuffled 10-fold splitter is shared by every repetition.
    kf = KFold(n_splits=10, shuffle=True)

    # The estimator depends only on `method`, so pick it once up front.
    func = jackknife_ind if method == 0 else jackknife_sites

    within = []
    for _ in range(bootstrap_time):
        fst_jk = [func(kf, ts, popA, popB, sites_index)
                  for _ in range(bootstrap_size)]

        lower, upper = get_CI_quantile(fst_jk)
        print((lower, upper))

        within.append(int(lower < fst < upper))

    print(f'Coverage rate is: {mean(within)} \n')

    return mean(within)
    

In [90]:
def replicate(replicate_time, ts, popA, popB, bootstrap_time,
              bootstrap_size, size, fst, sites_index, func, method):
    """Run `func` on `replicate_time` independent subsamples and collect its results.

    For every replicate a subsample of `size` items is drawn WITHOUT
    replacement: from popA and popB (independently) when method == 0,
    otherwise from sites_index. The axis that is not subsampled is passed
    through in full. `func` is then called as
    ``func(ts, popA_sub, popB_sub, bootstrap_time, bootstrap_size, size,
    fst, sites_sub, method)``.

    Returns the list of values returned by `func`, one per replicate.
    """
    subsample_individuals = method == 0

    # Defaults for the axis that is not subsampled.
    drawn_A, drawn_B, drawn_sites = popA, popB, sites_index

    rate_list = []
    for time in range(replicate_time):
        print(f'Replicate time: {time}')

        if subsample_individuals:
            drawn_A = np.random.choice(popA, size, replace=False)
            drawn_B = np.random.choice(popB, size, replace=False)
        else:
            drawn_sites = np.random.choice(sites_index, size, replace=False)

        rate_list.append(
            func(ts, drawn_A, drawn_B, bootstrap_time,
                 bootstrap_size, size, fst, drawn_sites, method)
        )

    return rate_list

In [94]:
def simulation(bootstrap_time, bootstrap_size, sample_size, replicate_time, func, method):
    '''
    Run a coverage experiment for every sample size and collect the results.

    Reads the module-level tree sequence `ts`; the Fst over all samples of
    populations 0 and 1 and all sites is the reference value that the
    intervals are checked against.

    @bootstrap_time: number of resampling rounds within each replicate
    @bootstrap_size: number of resamples per round
    @sample_size: iterable of subsample sizes to try
    @replicate_time: number of replicates per sample size
    @func: resampling routine to evaluate: bootstrap or jackknife
    @method: 0: resample individuals, otherwise: resample sites

    Returns a dict mapping each size to the list of coverage rates
    returned by `replicate`.
    '''
    popA, popB = (ts.samples(population=pop_id) for pop_id in (0, 1))
    sites_index = np.arange(len(ts.sites()))

    # Reference value computed on the full data set.
    fst = get_fst_general(ts, popA, popB, sites_index)
    print(f'Fst value is: {fst}')

    func_name = func.__name__
    if method != 0:
        print(f'{func_name.upper()} OVER SITES')
    else:
        print(f'{func_name.upper()} OVER INDIVIDUALS')

    coverage_rate = {}
    for size in sample_size:
        print(f'Sample size: {size}')

        coverage_rate[size] = replicate(
            replicate_time, ts, popA, popB,
            bootstrap_time, bootstrap_size, size,
            fst, sites_index, func, method,
        )

        print(f'Coverage rate for size {size} is: {coverage_rate[size]}: \n')

    return coverage_rate


In [99]:
# Subsample sizes to evaluate: 100..1000 individuals and 1000..10000 sites.
# These are passed as `size` to np.random.choice, which needs integers;
# np.linspace returned floats (100.0, ...), so build integer grids with the
# same values instead.
ind_size = np.arange(100, 1001, 100)
sites_size = np.arange(1000, 10001, 1000)

In [87]:
# Bootstrap over individuals (method=0): for each size in ind_size, 3 replicates
# of 5 rounds with 10 resamples each.
coverage_bt_ind = simulation(bootstrap_time=5, bootstrap_size=10, 
                          sample_size=ind_size, replicate_time=3, func=bootstrap, method = 0)

Fst value is: 0.026293477935966713
BOOTSTRAP OVER INDIVIDUALS
Sample size: 100
Replicate time: 0
(0.033919379328343546, 0.03886779740563828)
(0.03360105358982893, 0.0382605836227519)
(0.03445600175989992, 0.03745124527441967)
(0.0328447944283908, 0.037655550261859676)
(0.03403031320414443, 0.039816001075653605)
Coverage rate is:  0
Replicate time: 1
(0.03464580819780551, 0.041173764105206935)
(0.03498394746745327, 0.039412129123958)
(0.034812254135339206, 0.0413452691775905)
(0.03488300033148257, 0.03910047948433612)
(0.034110410404255456, 0.03905641810342584)
Coverage rate is:  0
Replicate time: 2
(0.03361931007640105, 0.040312111229052844)
(0.03278620032800271, 0.038786715796565036)
(0.031642011338138334, 0.0400632342860376)
(0.03256754356870734, 0.03728536071736181)
(0.035326917752235926, 0.03892552574936073)
Coverage rate is:  0
Coverage rate is: [0, 0, 0]


In [91]:
# Bootstrap over sites (method=1): for each size in sites_size, 3 replicates
# of 5 rounds with 10 resamples each.
coverage_bt_sites = simulation(bootstrap_time=5, bootstrap_size=10, 
                          sample_size=sites_size, replicate_time=3, func=bootstrap, method = 1)

Fst value is: 0.026293477935966713
BOOTSTRAP OVER SITES
Sample size: 1000
Replicate time: 0
(0.026205112262437512, 0.030971963397772962)
(0.027068200212294106, 0.030443263546365784)
(0.027489093162296398, 0.030608283153173328)
(0.026583164012769132, 0.029980024185312436)
(0.026180745627258525, 0.030063523762967476)
Coverage rate is:  0.4
Replicate time: 1
(0.02445224861135497, 0.027362992779976246)
(0.024264453278104364, 0.028306466002544314)
(0.0246235002165971, 0.02785159368636903)
(0.022497749491481367, 0.027528083314013713)
(0.024005988857915824, 0.02718086126511402)
Coverage rate is:  1
Replicate time: 2
(0.0246228975340426, 0.028006411782524845)
(0.025561638514586732, 0.030779994516356286)
(0.024927942396050676, 0.02852912435720769)
(0.025281129256680815, 0.0300615214412814)
(0.025077824454483436, 0.030282822753862626)
Coverage rate is:  1
Coverage rate is: [0.4, 1, 1]


In [95]:
# Block jackknife over individuals (method=0): for each size in ind_size,
# 3 replicates of 5 rounds with 10 jackknife estimates each.
coverage_jk_ind = simulation(bootstrap_time=5, bootstrap_size=10, 
                          sample_size=ind_size, replicate_time=3, func=jackknife, method = 0)

Fst value is: 0.026293477935966713
JACKKNIFE OVER INDIVIDUALS
Sample size: 100
Replicate time: 0
(0.030104503889665547, 0.030163580318473365)
(0.03009373136660798, 0.03016832339107155)
(0.03012332179925057, 0.030170016317229563)
(0.0301115815620328, 0.030163587693588952)
(0.0300986304867933, 0.030175352034594795)
Coverage rate is: 0 

Replicate time: 1
(0.027544325244095183, 0.027596251778819005)
(0.02753792831717515, 0.027580880963223954)
(0.027541392847634807, 0.027591824761844012)
(0.02752610344405089, 0.02759148410175454)
(0.027551027310279412, 0.02760460078089473)
Coverage rate is: 0 

Replicate time: 2
(0.02518182574749256, 0.025233224361170908)
(0.025154683698643834, 0.025234480317407992)
(0.025161213630035414, 0.025242530547476577)
(0.025178987742050442, 0.025235581419799608)
(0.025155522706480013, 0.025219494043056283)
Coverage rate is: 0 

Coverage rate for 100 is: [0, 0, 0]: 



In [97]:
# Block jackknife over sites (method=1): for each size in sites_size,
# 3 replicates of 5 rounds with 10 jackknife estimates each.
coverage_jk_sites = simulation(bootstrap_time=5, bootstrap_size=10, 
                          sample_size=sites_size, replicate_time=3, func=jackknife, method = 1)

Fst value is: 0.026293477935966713
JACKKNIFE OVER SITES
Sample size: 1000
Replicate time: 0
(0.02823807376275367, 0.02824253704477888)
(0.02823694579714109, 0.02824194903624372)
(0.028238493547759813, 0.028243994596409097)
(0.02823648839252952, 0.028242071042664484)
(0.028236367775076703, 0.028241348229662837)
Coverage rate is: 0 

Replicate time: 1
(0.02733669401905544, 0.027341938868323853)
(0.027333205703725726, 0.027340333137927324)
(0.02733635451570037, 0.027340131372190067)
(0.02733691120652617, 0.027340501942564807)
(0.027333507182971332, 0.02734082653914683)
Coverage rate is: 0 

Replicate time: 2
(0.025199382176720797, 0.025204368470762642)
(0.02519924839205187, 0.025204372476698992)
(0.025199390475653165, 0.02520392194533611)
(0.025200370000232855, 0.025204210945852124)
(0.025200573116200917, 0.025205158332127324)
Coverage rate is: 0 

Coverage rate for 1000 is: [0, 0, 0]: 



In [110]:
def build_df(data):
    """Tabulate coverage rates: one row per sample size plus a row-wise mean.

    data -- dict mapping sample size to a list of coverage rates (one per
            replicate), as returned by `simulation`.

    Returns a DataFrame indexed by sample size, with one column per
    replicate and a final 'mean' column averaging across replicates.
    """
    # Build from the argument; the original read an undefined global
    # (`coverage_ind`) and ignored `data` entirely.
    df = pd.DataFrame.from_dict(data, orient='index')
    df['mean'] = df.mean(axis=1)
    return df

In [111]:
# Show the bootstrap-over-individuals coverage table.
build_df(coverage_bt_ind)

Unnamed: 0,0,1,2,mean
100,0,0,0,0.0
