In [1]:
import pyranges as pr
import numpy as np
from collections import Counter
import functions
from dismal.models import im, gim, iso_two_epoch
from dismal import model_instance
from scipy import stats
import tqdm
from joblib import Parallel, delayed
import os
import itertools

In [2]:
# Simulate data under a simple IM (isolation-with-migration) model with known
# parameter values; `true_mod` is the "ground truth" simulation used below.
modinst = model_instance.ModelInstance(
    param_vals=(3, 2, 4, 2, 0.1, 0.2),
    epochs=im().epochs,
)
true_mod = modinst.simulate(
    mutation_rate=1e-9,
    blocklen=500,
    recombination_rate=0,
    blocks_per_state=10_000,
)

In [3]:
# Display the simulated demography (population sizes, migration matrix and
# split events) to check how the parameter values passed to ModelInstance
# map onto the simulated model.
true_mod.demography

id,name,description,initial_size,growth_rate,default_sampling_time,extra_metadata
0,pop1_be760db4f5,,1500000.0,0,0.0,{}
1,pop2_8204835a35,,1000000.0,0,0.0,{}
2,pop1_63258ea307,,2000000.0,0,6000000.0,{}

Unnamed: 0,pop1_be760db4f5,pop2_8204835a35,pop1_63258ea307
pop1_be760db4f5,0.0,1.667e-08,0
pop2_8204835a35,5e-08,0.0,0
pop1_63258ea307,0.0,0.0,0

time,type,parameters,effect
6000000.0,Population Split,"derived=[pop1_be760db4f5], ancestral=pop1_63258ea307","Moves all lineages from the 'pop1_be760db4f5' derived population to the ancestral 'pop1_63258ea307' population. Also set 'pop1_be760db4f5' to inactive, and all migration rates to and from the derived population to zero."
6000000.0,Population Split,"derived=[pop2_8204835a35], ancestral=pop1_63258ea307","Moves all lineages from the 'pop2_8204835a35' derived population to the ancestral 'pop1_63258ea307' population. Also set 'pop2_8204835a35' to inactive, and all migration rates to and from the derived population to zero."


In [4]:
def _uniform_prior(n, low, width, random_state=None):
    """Draw ``n`` samples from the continuous uniform distribution on [low, low + width].

    scipy's ``uniform`` is parameterised by ``loc`` (lower bound) and ``scale``
    (the *width* of the interval), not by (low, high).

    Parameters
    ----------
    n : int
        Number of samples.
    low : float
        Lower bound of the support.
    width : float
        Width of the support; the upper bound is ``low + width``.
    random_state : None, int or numpy Generator/RandomState, optional
        Passed through to ``scipy.stats.uniform.rvs``. ``None`` uses numpy's
        global random state, matching the previous behaviour.

    Returns
    -------
    numpy.ndarray of shape (n,)
    """
    return stats.uniform.rvs(loc=low, scale=width, size=n, random_state=random_state)


def theta_prior(n, random_state=None):
    """Sample ``n`` theta values from U(1, 11).

    NOTE(review): ``scale=10`` is a width, so the upper bound is 11, not 10 —
    confirm this range is intended.
    """
    return _uniform_prior(n, low=1, width=10, random_state=random_state)


def tau_prior(n, random_state=None):
    """Sample ``n`` tau (split-time) values from U(1, 11).

    NOTE(review): as for ``theta_prior``, the support is [1, 11] — confirm.
    """
    return _uniform_prior(n, low=1, width=10, random_state=random_state)


def M_prior(n, random_state=None):
    """Sample ``n`` migration-rate (M) values from U(0, 2)."""
    return _uniform_prior(n, low=0, width=2, random_state=random_state)

In [5]:
def prior_param_set(migration=True):
    """Draw one parameter vector from the prior.

    The returned vector is laid out as
    ``[theta_1, theta_2, theta_3, tau, M_1, M_2]``: three draws from
    ``theta_prior``, one from ``tau_prior`` and two migration rates.

    Parameters
    ----------
    migration : bool, default True
        If truthy, both migration rates are drawn from ``M_prior``;
        otherwise they are fixed at zero.

    Returns
    -------
    numpy.ndarray of shape (6,), float
    """
    # Draw order (theta, tau, M) matches the original so the global RNG stream
    # is consumed identically.
    thetas = theta_prior(3)
    tau = tau_prior(1)
    # Truthiness instead of `is True`: with `is True`, np.True_ or 1 would
    # silently fall through to the zero-migration branch.
    migration_rates = M_prior(2) if migration else np.zeros(2)
    return np.concatenate([thetas, tau, migration_rates])

In [9]:
def generate_embedding(blocklen=500, mutation_rate=1e-9, blocks_per_state=10_000,
                       recombination_rate=0):
    """Simulate one dataset from the IM prior and build a training example.

    A parameter vector is drawn with ``prior_param_set(migration=True)``, data are
    simulated under the ``im()`` model with those parameters, and the model is
    then fitted to the simulated data with Nelder-Mead.

    Parameters
    ----------
    blocklen : int, default 500
        Block length passed to ``simulate``; also the length each of ``s1``,
        ``s2`` and ``s3`` is zero-padded to.
    mutation_rate : float, default 1e-9
    blocks_per_state : int, default 10_000
    recombination_rate : float, default 0
        Passed straight through to ``ModelInstance.simulate``.

    Returns
    -------
    params : numpy.ndarray
        The parameter vector drawn from the prior (the training target).
    dismal_pred : numpy.ndarray
        ``mod.inferred_params`` after fitting, or an all-NaN array the same
        length as ``params`` if fitting raised ``RuntimeError``.
    embedding : numpy.ndarray of shape (3 * blocklen,)
        ``s1``, ``s2`` and ``s3`` zero-padded to ``blocklen`` and concatenated.
        (Presumably per-state counts of blocks by number of segregating
        sites — TODO confirm against ``ModelInstance.simulate``.)

    Raises
    ------
    ValueError
        If a simulated ``s`` array is longer than ``blocklen`` and so cannot be
        padded to a fixed-length embedding.
    """
    mod = im()
    params = prior_param_set(migration=True)

    sim = model_instance.ModelInstance(params, epochs=mod.epochs).simulate(
        mutation_rate=mutation_rate,
        blocklen=blocklen,
        blocks_per_state=blocks_per_state,
        recombination_rate=recombination_rate,
    )

    padded = []
    for label, s_arr in (("s1", sim.s1), ("s2", sim.s2), ("s3", sim.s3)):
        # np.pad with a negative width fails with an opaque message; say why instead.
        if len(s_arr) > blocklen:
            raise ValueError(
                f"{label} has length {len(s_arr)}, longer than blocklen={blocklen}; "
                "cannot pad to a fixed-length embedding"
            )
        padded.append(np.pad(s_arr, (0, blocklen - len(s_arr))))
    s1, s2, s3 = padded

    try:
        mod.fit_model(s1, s2, s3, optimisers=["Nelder-Mead"])
        dismal_pred = mod.inferred_params
    except RuntimeError:
        # Best effort: a failed fit is recorded as NaNs so the simulation is kept.
        dismal_pred = np.full(len(params), np.nan)

    embedding = np.concatenate([s1, s2, s3])

    return params, dismal_pred, embedding

In [12]:
def make_training_data(n, save_as=None, n_jobs=7):
    """Run ``n`` independent ``generate_embedding`` simulations in parallel.

    Parameters
    ----------
    n : int
        Number of simulations to run.
    save_as : str or path-like, optional
        If given, the three arrays are saved with ``np.savez`` under the keys
        ``X_dismal_pred``, ``X_embeddings`` and ``y_params``.
    n_jobs : int, default 7
        Number of joblib workers.

    Returns
    -------
    X_embeddings, X_dismal_pred, y_params : numpy.ndarray
        Stacked embeddings, DISMaL fits and true prior parameters; row ``i`` of
        each array comes from the same simulation.
    """
    jobs = (delayed(generate_embedding)() for _ in range(n))
    results = list(
        tqdm.tqdm(Parallel(n_jobs=n_jobs, return_as="generator")(jobs), total=n)
    )

    y_params = np.array([params for params, _, _ in results])
    X_dismal_pred = np.array([pred for _, pred, _ in results])
    X_embeddings = np.array([emb for _, _, emb in results])

    if save_as is not None:
        np.savez(save_as, X_dismal_pred=X_dismal_pred, X_embeddings=X_embeddings, y_params=y_params)

    return X_embeddings, X_dismal_pred, y_params
    

In [14]:
# Smoke test: a single simulation, to check the pipeline runs end to end and
# returns arrays of the expected shapes before launching a large run.
make_training_data(n=1)

100%|██████████| 1/1 [01:40<00:00, 100.82s/it]


(array([[1148.,  909.,  829., ...,    0.,    0.,    0.]]),
 array([[9.65757543, 4.97556568, 5.34230084, 2.5819923 , 0.00005221,
         0.00017226]]),
 array([[1.1932046 , 7.47021684, 5.16639757, 4.37178454, 1.36229273,
         0.86275693]]))