In [36]:
import functions
from dismal import blocking
import pyranges as pr
from collections import Counter
import numpy as np
import scipy
from joblib import Parallel, delayed
from sklearn.ensemble import RandomForestRegressor
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

## Real (simulated) data

In [2]:
blocks = pr.read_bed("../simulations/stdpopsim_simulations/dismal_ponabe/PonAbe_blocks_1932.bed")
blocks_df = blocks.df.rename(columns={"ThickStart": "Sample1", "ThickEnd": "Sample2", "ItemRGB": "NumSegSites"}).iloc[:, [0,1,2,6,7,8]].sample(50000)

  return {k: v for k, v in df.groupby(grpby_key)}


In [3]:
pop1_samples = ["Bornean_0", "Bornean_1", "Bornean_2"]
pop2_samples = ["Sumatran_3", "Sumatran_4", "Sumatran_5"]

s1_counter = Counter(blocks_df["NumSegSites"][(blocks_df["Sample1"].isin(pop1_samples)) & (blocks_df["Sample2"].isin(pop1_samples))])
s2_counter = Counter(blocks_df["NumSegSites"][(blocks_df["Sample1"].isin(pop2_samples)) & (blocks_df["Sample2"].isin(pop2_samples))])
s3_counter = Counter(blocks_df["NumSegSites"][(blocks_df["Sample1"].isin(pop1_samples)) & (blocks_df["Sample2"].isin(pop2_samples))])

S = np.array([functions.counter_to_arr(counter, 1932) for counter in [s1_counter, s2_counter, s3_counter]], dtype="int_")

In [4]:
num_blocks_per_state = S.sum(axis=1)
num_blocks_per_state

array([10019,  9986, 29995])

In [5]:
X_test = np.array([np.concatenate(S), ])

## Generate training set

In [12]:
n = 10_000

X_train, y_train = functions.generate_training_set(blocklen=1932,
                                                   mutation_rate=2e-8,
                                                   recombination_rate=1.5e-8,
                                                   num_blocks_per_state=num_blocks_per_state,
                                                   n=n, n_cpus=-1,
                                                   saveto=f"pongo_trainset_{n}.npz")

Generating training data of length 10000 of 50000 blocks each on 7 cores


## Fit & test RF

In [13]:
rf = RandomForestRegressor()
rf.fit(X_train, y_train)

In [14]:
y_pred = rf.predict(X_test)

In [15]:
y_pred

array([[18.44045902, 20.59966805,  4.41326209, 11.06900345, 19.19128785,
        20.56019107]])

In [16]:
functions.reparameterise(y_pred[0,0], y_pred[0,1], y_pred[0,2], y_pred[0,3], y_pred[0,4], y_pred[0, 5], 
                         blocklen=1939,
                         mutation_rate=2e-8)

(118878.66823283379,
 132798.27262463045,
 28450.63234594308,
 2631736.7773714797,
 0.000161435925661952,
 0.00015482272969187847)

## Check whole test set

In [21]:
test = np.load("pongo_trainset_1000.npz")
X_test = test["X_train"]
y_test = test["y_train"]

In [24]:
y_pred = rf.predict(X_test)

In [31]:
res_df = pd.DataFrame(y_pred/y_test)
res_df.columns = ["theta1", "theta2", "theta3", "tau", "M12", "M21"]
res_df

Unnamed: 0,theta1,theta2,theta3,tau,M12,M21
0,0.923495,0.803485,1.005440,0.876453,1.583530,0.580069
1,1.613518,0.570036,0.989028,0.515694,1.165076,0.565776
2,0.666090,0.730622,1.004018,1.559654,2.523739,2.293896
3,1.709185,4.115651,1.001698,0.635404,2.377720,39.526331
4,0.696642,0.903353,1.000757,5.353509,0.808529,16.293516
...,...,...,...,...,...,...
995,0.871774,1.405761,0.987056,8.287942,2.854298,1.947258
996,1.566191,0.731255,0.992323,1.876877,3.870714,3.091778
997,0.695997,0.615536,1.006314,3.360232,1.292154,0.606472
998,0.632655,0.559205,1.000736,8.504169,0.687229,0.615480


In [39]:
res_df.describe()

Unnamed: 0,theta1,theta2,theta3,tau,M12,M21
count,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
mean,3.347896,3.36527,0.999861,3.185849,40.194619,5.578272
std,14.052405,10.605588,0.008052,12.911109,1161.984438,52.880078
min,0.35773,0.388533,0.968197,0.301838,0.2517,0.402354
25%,0.665105,0.667886,0.995346,0.667775,0.645657,0.682335
50%,1.024031,1.036287,1.000112,1.007169,0.95472,1.038461
75%,2.007626,2.02821,1.004258,1.877448,2.002638,2.281014
max,270.53757,148.030739,1.047498,266.674395,36746.039662,1587.776739


In [38]:
sns.catplot(data=res_df, log_scale=True, alpha=0.1)
plt.axhline(y=1)



<matplotlib.lines.Line2D at 0x133678b90>



Error in callback <function flush_figures at 0x130f277e0> (for post_execute):


KeyboardInterrupt: 

If no good:
* Consider alternative parameterisation (Nes, generations)
* Consider if there are ways to narrow down prior - dxy estimate of split time?