In [1]:
import msprime
import functions
import numpy as np
import scipy
import tqdm
from joblib import Parallel, delayed
import itertools
import os
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
import prototype
from collections import Counter
import pyranges as pr

In [2]:
# One-off step that builds the reference simulations later loaded from
# "pongo_reference_25000.npz" (presumably written by `save_as`; confirm in
# `prototype`). With 10_000 simulations per model this is presumably expensive,
# so it is disabled by default; set the flag to True to regenerate.
# blocklen and num_blocks_per_state mirror the values used to summarize the
# observed blocks (1932; 5000 / 5000 / 15000).
REGENERATE_REFERENCE = False

if REGENERATE_REFERENCE:
    prototype.generate_reference_embeddings(blocklen=1932,
                                            mutation_rate=2e-8,
                                            recombination_rate=1.5e-8,
                                            num_blocks_per_state=[5000, 5000, 15000],
                                            num_sims_per_mod=10_000,
                                            threads=-1,
                                            save_as="pongo_reference_25000")

In [3]:
# Observed per-block data. The standard BED columns are repurposed:
# ThickStart / ThickEnd hold the two sample IDs of a pair and ItemRGB holds the
# number of segregating sites, so rename them and keep only what is needed.
# Columns are selected by label (not position) so a change in pyranges' column
# layout raises a KeyError instead of silently picking the wrong columns.
blocks = pr.read_bed("../simulations/stdpopsim_simulations/dismal_ponabe/PonAbe_blocks_1932.bed")
blocks_df = (
    blocks.df
    .rename(columns={"ThickStart": "Sample1", "ThickEnd": "Sample2", "ItemRGB": "NumSegSites"})
    .loc[:, ["Chromosome", "Start", "End", "Sample1", "Sample2", "NumSegSites"]]
)
pop1_samples = ["Bornean_0", "Bornean_1", "Bornean_2"]
pop2_samples = ["Sumatran_3", "Sumatran_4", "Sumatran_5"]


def groupby_to_dict(df, grpby_key):
    """Split ``df`` into a ``{group key: sub-DataFrame}`` dict grouped by ``grpby_key``.

    NOTE(review): reconstructed from an orphaned ``return`` line whose ``def``
    header was missing (it made this cell a SyntaxError). The original function
    name is unknown; rename if a later cell refers to it.
    """
    return {k: v for k, v in df.groupby(grpby_key)}


In [4]:
# Build the observed feature matrix: each replicate subsamples (without
# replacement) blocks from three pools (within pop1, within pop2, between
# populations), turns each subsample's segregating-site counts into a fixed
# array via functions.counter_to_arr, and concatenates the three arrays.
N_REPLICATES = 100
BLOCKLEN = 1932                            # same blocklen as the reference simulations
N_BLOCKS_PER_STATE = (5000, 5000, 15000)   # within pop1, within pop2, between pops
SEED = 42

# A seeded RandomState makes the subsamples reproducible under Restart & Run All.
rng = np.random.RandomState(SEED)

# Masks do not change between replicates, so compute the pools once.
s1_in_pop1 = blocks_df["Sample1"].isin(pop1_samples)
s2_in_pop1 = blocks_df["Sample2"].isin(pop1_samples)
s1_in_pop2 = blocks_df["Sample1"].isin(pop2_samples)
s2_in_pop2 = blocks_df["Sample2"].isin(pop2_samples)

within_pop1 = blocks_df.loc[s1_in_pop1 & s2_in_pop1, "NumSegSites"]
within_pop2 = blocks_df.loc[s1_in_pop2 & s2_in_pop2, "NumSegSites"]
# Count a between-population pair regardless of which sample is listed first.
between_pops = blocks_df.loc[(s1_in_pop1 & s2_in_pop2) | (s1_in_pop2 & s2_in_pop1), "NumSegSites"]

pools = (within_pop1, within_pop2, between_pops)

X_true_rows = []
for _ in range(N_REPLICATES):
    # NOTE(review): counter_to_arr presumably returns a fixed-length count
    # array indexed by segregating-site number; confirm in `functions`.
    S = np.array(
        [
            functions.counter_to_arr(Counter(pool.sample(n_blocks, random_state=rng)), BLOCKLEN)
            for pool, n_blocks in zip(pools, N_BLOCKS_PER_STATE)
        ],
        dtype="int_",
    )
    X_true_rows.append(np.concatenate(S))

X_true = np.array(X_true_rows)

In [5]:
# Observed feature matrix: one row per subsampling replicate (100 rows).
X_true

array([[2476,  922,  552, ...,    0,    0,    0],
       [2520,  953,  572, ...,    0,    0,    0],
       [2529,  886,  557, ...,    0,    0,    0],
       ...,
       [2476,  970,  519, ...,    0,    0,    0],
       [2477,  955,  540, ...,    0,    0,    0],
       [2466,  923,  545, ...,    0,    0,    0]])

## Model classifier

In [6]:
# The "X" entry is an object array (each element is densified in a later cell,
# presumably scipy sparse rows), and numpy only loads object arrays with
# allow_pickle=True.
# NOTE(review): unpickling can execute arbitrary code — only load archives you
# generated yourself.
npz = np.load("pongo_reference_25000.npz", allow_pickle=True)
X_sparse, y_model = npz["X"], npz["y_model"]

In [7]:
# Densify the reference rows with `.toarray()` (plain ndarray) rather than
# `.todense()` (np.matrix). np.vstack keeps the np.matrix subclass, and recent
# scikit-learn releases raise TypeError when np.matrix is passed to fit/predict.
X = np.vstack([arr.toarray() for arr in X_sparse])

In [8]:
# A fixed random_state makes the fitted forest (and its predictions)
# reproducible under Restart & Run All. n_jobs=-1 only parallelizes tree
# building and does not change the result.
model_classifier = RandomForestClassifier(n_estimators=500, random_state=42, n_jobs=-1)

In [9]:
# Train the model classifier on every reference simulation; the trailing `;`
# suppresses the estimator repr in the notebook output.
model_classifier.fit(X=X, y=y_model);

In [14]:
# Predicted model label for each observed replicate.
predicted_models = model_classifier.predict(X_true)
predicted_models

array(['im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im', 'im',
       'im'], dtype='<U10')

## Parameter regressor

In [11]:
# A fixed random_state makes the parameter estimates reproducible under
# Restart & Run All. n_jobs=-1 only parallelizes tree building and does not
# change the result.
param_regressor = RandomForestRegressor(n_estimators=500, random_state=42, n_jobs=-1)

In [12]:
# Fit the parameter regressor only on reference simulations labelled "im"
# (the model the classifier predicted for the observed data). Build the mask
# once from the already-loaded labels and reuse it for features and targets.
is_im = y_model == "im"
y_params = npz["y_params"][is_im, :]
X_im = X[is_im]
assert X_im.shape[0] == y_params.shape[0], "feature/target row counts differ"
param_regressor.fit(X_im, y_params);

In [15]:
# Parameter estimates under the "im" model, one row per observed replicate.
im_param_estimates = param_regressor.predict(X_true)
im_param_estimates

array([[2.99876972e+05, 6.10996161e+05, 9.95756048e+05, 9.11574776e+05,
        1.17508936e+01, 1.22523892e+01],
       [2.80417150e+05, 6.20504069e+05, 9.98037922e+05, 9.50977385e+05,
        1.20527555e+01, 1.21034243e+01],
       [3.54772938e+05, 5.74199063e+05, 9.16259216e+05, 9.72686314e+05,
        1.17624499e+01, 1.21382369e+01],
       [3.16549278e+05, 6.17539303e+05, 1.00780229e+06, 9.72592653e+05,
        1.18076325e+01, 1.21047718e+01],
       [3.63831923e+05, 6.20363283e+05, 9.81786393e+05, 1.00855178e+06,
        1.16570544e+01, 1.21242154e+01],
       [2.86094008e+05, 6.20542633e+05, 9.89259680e+05, 9.47261196e+05,
        1.18558183e+01, 1.20266461e+01],
       [2.78240536e+05, 6.14386457e+05, 9.85866940e+05, 9.23583939e+05,
        1.18339686e+01, 1.21760628e+01],
       [3.65896985e+05, 5.73917015e+05, 1.01072886e+06, 9.53162383e+05,
        1.16751597e+01, 1.19591333e+01],
       [3.49039177e+05, 5.88853842e+05, 9.75056225e+05, 1.00176223e+06,
        1.17373917e+01, 