In [5]:
import torch
from torch import nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi, SNLE, MNLE, SNRE, SNRE_A
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from sbi.utils import BoxUniform
from sbi.utils import MultipleIndependent
from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding
from sbi.utils.user_input_checks import process_prior, process_simulator
from sbi.utils import get_density_thresholder, RestrictedPrior
from sbi.utils.get_nn_models import posterior_nn
import numpy as np
import moments
from matplotlib import pyplot as plt
import pickle
import os
import seaborn as sns
import datetime
import pandas as pd
import logging
import atexit
import torch.nn.functional as F
import subprocess
import sparselinear as sl
from sortedcontainers import SortedDict
from scipy.spatial import KDTree
import os
import re
from monarch_linear import MonarchLinear
import pdb
logging.getLogger('matplotlib').setLevel(logging.ERROR) # See: https://github.com/matplotlib/matplotlib/issues/14523
from collections import defaultdict
from sbi.analysis import pairplot
from tqdm import tqdm


In [11]:
sample_size=85

In [12]:
def generate_moments_sim_data2(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    #s_prior, weights = prior[:6], prior[6:]
    #s_prior, weights = prior[:5], prior[5:]
    #s_prior, p_misid, weights = prior[:7], prior[7], prior[7:]
    fs_aggregate = None
    p_misid = 0.0021 #.0021 # lof missid # 0.0147 missense
    gammas = -1*10**(prior.cpu().numpy().squeeze())
    nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
    for j, gamma in enumerate(gammas):
        while rerun:
            #print(gamma, j)
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
                del fs
            else:
                rerun = False
        if j == 0:
            fs_aggregate = fs.project([projected_sample_size]).compressed()*theta_lof
            #check_test_2.append(fs_aggregate)
        else:
            fs2 = fs.project([projected_sample_size]).compressed()*theta_lof
            fs_aggregate += fs2
            #check_test_2.append(fs2)
            del fs2
        rerun = True
        ns_sim = 100
    
    #check_fs_aggregate = np.copy(fs_aggregate)            
    fs_aggregate /= gammas.shape[0]
    #fs_aggregate = torch.poisson(torch.nn.functional.relu(torch.tensor(fs_aggregate))).type(torch.float32) 
    fs_aggregate = torch.nn.functional.relu(torch.tensor(fs_aggregate)).type(torch.float32) 
    #return gammas, check_test_2, check_fs_aggregate, fs_aggregate
    return fs_aggregate

In [7]:
def generate_momments(prior: float, sample_size) -> torch.float32:
    
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    gamma = -1*10**(prior)
    p_misid = 0 #.0021 # lof missid

    while rerun:
        ns_sim = 2 * ns_sim
        fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
        fs = moments.Spectrum(fs)
        fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
        nu_func = lambda t: [opt_params[0] * np.exp(
            np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
        fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
        if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
            # large gamma-values can require large sample sizes for stability
            rerun = True
            print("rerunning")
        else:
            rerun = False
        fs2 = fs.project([projected_sample_size]).compressed()*theta_lof
        fs2 = (1 - p_misid) * fs2 + p_misid * fs2[::-1]
        
    return fs2

In [8]:
# optimal parameters from moments, gamma
# shape: 0.3589
# scale: 7830.5
gdist = torch.distributions.gamma.Gamma(torch.tensor([0.3589]),torch.tensor([1/7830.5]))
gamma_samples = gdist.sample((1000,)).type(torch.int32)

In [9]:
udist = torch.distributions.uniform.Uniform(-6.0*torch.ones(50),4.0*torch.ones(50))
uniform_samples = udist.sample((1000,))

In [67]:
check_aggregate = generate_momments(4.8, 85)
print(check_aggregate)

rerunning
rerunning
[ 1.56807761e+000  1.92047862e-004  3.09283739e-008  5.49554045e-012
  1.01775767e-015  1.91188976e-019  3.58466579e-023  6.63316621e-027
  1.20074054e-030  2.11021947e-034  3.57503841e-038  5.79793904e-042
  8.93652579e-046  1.29896117e-049  1.76529847e-053  2.22108648e-057
  2.55753885e-061  2.65778165e-065  2.44945132e-069  1.95708051e-073
  1.31424473e-077  7.08975219e-082  2.85672160e-086  7.48897812e-091
  8.78183348e-096 -2.77437259e-101  5.10761874e-106 -1.73773963e-110
  8.68585657e-115 -5.75698542e-119  4.77243795e-123 -4.76409225e-127
  5.57626255e-131 -7.50304099e-135  1.14292966e-138 -1.94711855e-142
  3.67305161e-146 -7.60877730e-150  1.71869203e-153 -4.20774350e-157
  1.11066774e-160 -3.14627244e-164  9.52593434e-168 -3.07135209e-171
  1.05107499e-174 -3.80650779e-178  1.45489483e-181 -5.85425267e-185
  2.47432372e-188 -1.09615801e-191  5.08013000e-195 -2.45852271e-198
  1.24032867e-201 -6.51290841e-205  3.55421303e-208 -2.01295866e-211
  1.18161523e-

In [15]:
#last_posterior = torch.load('Experiments/saved_posteriors_msl_nsf_and_wass_params_2023-04-07_16-29/posterior_observed_round_10.pkl')
last_posterior=torch.load('Experiments/saved_posteriors_msl_lof_scf_sinkhorn_2023-04-10_17-07/posterior_observed_round_5.pkl')


In [16]:
#true_x = np.load('emperical_missense_sfs_msl.npy')
true_x = np.load('emperical_lof_sfs_msl.npy')


In [17]:
accept_reject_fn = get_density_thresholder(last_posterior, quantile=1e-5, num_samples_to_estimate_support=100000)
proposal = RestrictedPrior(last_posterior._prior, accept_reject_fn, last_posterior, sample_with="sir", device='cuda')

In [28]:
predicted_fs=[]
predicted_fs2=[]
sample_size = 85

In [29]:
obs_samples = proposal.sample((100,), oversampling_factor=1024)[:10,:-1]


In [344]:
obs_samples3 = obs_samples[:,:10]

In [30]:
torch.mean(obs_samples,dim=0)


tensor([ 3.9468, -3.1502,  0.4001, -0.5060, -5.3432,  3.9965,  3.7823,  3.9954,
         3.9926,  3.9964], device='cuda:0')

In [31]:
print("Shape of sampled selection coefficients: {}".format(obs_samples.shape))

Shape of sampled selection coefficients: torch.Size([10, 10])


In [32]:
for obs_sample in tqdm(obs_samples):
    fs = generate_moments_sim_data2(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [09:03<00:00, 54.35s/it]


In [33]:
new_predicted_fs = np.asarray(predicted_fs).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (10, 169)


In [55]:
smaller_true_x = np.log10(true_x[1:169:10])


In [None]:
idx = np.arange(1,169,10)
print("Shape of indices (bin) that were chosen to plot {}".format(idx.shape[0]))
print("Showing idx for sanity check\n")
print(idx)
print("\n")

In [34]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    sns.histplot(np.log10(new_predicted_fs[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_miss.png')
plt.close()



<Figure size 640x480 with 0 Axes>

In [39]:
mean_predicted = np.mean(new_predicted_fs[:],axis=0)

In [40]:
new_predicted_fs[:,0]

array([1250.1741 , 1029.977  , 1054.1991 , 1029.2208 , 1222.707  ,
       1118.6437 , 1030.3286 , 1029.6094 ,  730.56055, 1036.2322 ],
      dtype=float32)

In [337]:
# Test to create a normalized batch tensor of the predicted frequency spectrum
temp = torch.tensor(predicted_fs)
temp = temp.squeeze(1)

In [338]:
norm_predicted_fs = temp/temp.sum(dim=1).view(temp.shape[0],1)

In [339]:
print(temp[0,:10])
print(temp.sum(dim=1)[0])
print(norm_predicted_fs[0,:10])
# Test to create a normalized batch tensor of the predicted frequency spectrum

tensor([3370.9434, 1358.4443,  822.8662,  584.6456,  450.1974,  363.5256,
         302.8641,  258.0094,  223.5210,  196.2169])
tensor(11438.5957)
tensor([0.2947, 0.1188, 0.0719, 0.0511, 0.0394, 0.0318, 0.0265, 0.0226, 0.0195,
        0.0172])


In [340]:
mean_predicted

array([3374.2642   , 1361.4269   ,  825.75073  ,  587.5043   ,
        453.05035  ,  366.37555  ,  305.70734  ,  260.84024  ,
        226.33376  ,  199.00636  ,  176.86497  ,  158.59363  ,
        143.2868   ,  130.30013  ,  119.16239  ,  109.52099  ,
        101.10663  ,   93.71038  ,   87.1675   ,   81.34653  ,
         76.14109  ,   71.46456  ,   67.245316 ,   63.42396  ,
         59.950645 ,   56.783337 ,   53.886322 ,   51.229076 ,
         48.785355 ,   46.532536 ,   44.450882 ,   42.52328  ,
         40.73469  ,   39.07191  ,   37.52327  ,   36.07848  ,
         34.72836  ,   33.464733 ,   32.28032  ,   31.16855  ,
         30.123589 ,   29.140114 ,   28.213364 ,   27.339005 ,
         26.513128 ,   25.732172 ,   24.992893 ,   24.292343 ,
         23.627832 ,   22.996864 ,   22.39721  ,   21.826763 ,
         21.28363  ,   20.766052 ,   20.272406 ,   19.801207 ,
         19.35106  ,   18.920698 ,   18.508928 ,   18.114658 ,
         17.73687  ,   17.374615 ,   17.027006 ,   16.6

In [327]:
true_x

array([2.6296e+04, 8.8700e+03, 4.9600e+03, 3.3690e+03, 2.5990e+03,
       2.0210e+03, 1.5630e+03, 1.3470e+03, 1.1680e+03, 1.0570e+03,
       9.3100e+02, 8.1200e+02, 6.7700e+02, 7.0300e+02, 5.9200e+02,
       5.4800e+02, 5.3900e+02, 4.4700e+02, 4.5500e+02, 4.2300e+02,
       4.0100e+02, 3.7700e+02, 3.2300e+02, 2.9000e+02, 2.9800e+02,
       2.9500e+02, 2.6700e+02, 2.2500e+02, 2.2600e+02, 2.0100e+02,
       2.0800e+02, 2.2200e+02, 1.9500e+02, 2.0800e+02, 1.8600e+02,
       1.6900e+02, 1.9800e+02, 1.5100e+02, 1.6200e+02, 1.7700e+02,
       1.5000e+02, 1.4600e+02, 1.3900e+02, 1.5100e+02, 1.3100e+02,
       1.3600e+02, 1.2100e+02, 1.3700e+02, 1.1300e+02, 1.0800e+02,
       1.1000e+02, 1.1100e+02, 1.0100e+02, 8.8000e+01, 9.8000e+01,
       9.2000e+01, 9.7000e+01, 1.0700e+02, 9.9000e+01, 8.3000e+01,
       8.5000e+01, 8.8000e+01, 7.6000e+01, 7.4000e+01, 7.7000e+01,
       6.9000e+01, 7.1000e+01, 7.8000e+01, 6.6000e+01, 9.2000e+01,
       6.6000e+01, 7.4000e+01, 8.3000e+01, 6.6000e+01, 6.4000e

In [306]:
true_x2 = true_x/true_x.sum()
true_x2

array([3.63134201e-01, 1.22490126e-01, 6.84950424e-02, 4.65241528e-02,
       3.58908498e-02, 2.79089679e-02, 2.15842240e-02, 1.86013754e-02,
       1.61294777e-02, 1.45966250e-02, 1.28566299e-02, 1.12133013e-02,
       9.34902091e-03, 9.70806750e-03, 8.17521474e-03, 7.56759743e-03,
       7.44331207e-03, 6.17283951e-03, 6.28331538e-03, 5.84141188e-03,
       5.53760323e-03, 5.20617560e-03, 4.46046345e-03, 4.00475046e-03,
       4.11522634e-03, 4.07379788e-03, 3.68713232e-03, 3.10713398e-03,
       3.12094346e-03, 2.77570636e-03, 2.87237275e-03, 3.06570553e-03,
       2.69284945e-03, 2.87237275e-03, 2.56856409e-03, 2.33380286e-03,
       2.73427790e-03, 2.08523214e-03, 2.23713647e-03, 2.44427873e-03,
       2.07142265e-03, 2.01618472e-03, 1.91951833e-03, 2.08523214e-03,
       1.80904245e-03, 1.87808987e-03, 1.67094761e-03, 1.89189936e-03,
       1.56047173e-03, 1.49142431e-03, 1.51904328e-03, 1.53285276e-03,
       1.39475792e-03, 1.21523462e-03, 1.35332947e-03, 1.27047256e-03,
      

In [44]:
x = np.arange(0,mean_predicted.shape[0])

In [60]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=np.log10(mean_predicted+1), label="Predicted")
sns.scatterplot(x=x,y=np.log10(true_x+1), label="Emperical")
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('ppc_scatter_lof.png')
plt.close()


In [53]:
obs_samples2 = obs_samples.reshape(-1)
obs_samples2.shape
samps = obs_samples2.cpu().numpy()


In [61]:
prior_samps =last_posterior._prior.sample((2000,)).view(-1).cpu().numpy()
gamma_samples2 = gdist.sample((1000,))

In [66]:
plt.close()
fig = plt.figure(figsize=(8,8))
sns.kdeplot(samps, label="DFE", c='r')
sns.kdeplot(prior_samps, label="Initial Proposal", c='g')
sns.kdeplot(torch.log10(gamma_samples2), label="Moments Proposal")
plt.title("Kernel Density Estimation Inferred Scaled Selection")
plt.ylabel("Density")
plt.xlabel("Log of Absolute Scaled Selection Coefficient")
fig.legend(["DFE", "Initial Proposal", "Moments Proposal"], loc="lower center", ncol=3)
plt.savefig('ppc_mmd_kde_selection.png')
plt.close()

    

In [58]:
dfe2= samps.squeeze()
dfe2

array([ 3.9935045, -3.4411573, -3.6136167,  1.3261223, -5.8830423,
        3.996296 ,  3.8651047,  3.9957   ,  3.9964046,  3.9964104,
        3.9939442, -5.0167375,  3.3824282, -4.251507 , -3.4186954,
        3.9966106,  3.6330109,  3.9960814,  3.9956932,  3.9964314,
        3.8496761, -2.8735144,  2.5188694, -0.6839266, -5.819459 ,
        3.9965496,  3.6210136,  3.9954052,  3.9890232,  3.9962568,
        3.9948769, -3.0523849,  3.3872833, -2.216776 , -3.1356847,
        3.9966087,  3.7913303,  3.9959002,  3.989542 ,  3.9963455,
        3.9962912, -5.0614786,  1.4560118, -2.8558784, -5.8096375,
        3.996582 ,  3.7620373,  3.9959173,  3.996502 ,  3.9964218,
        3.90096  , -2.3870935,  1.9460859, -1.6362019, -5.8646245,
        3.9965744,  3.7686205,  3.9951096,  3.9898424,  3.996314 ,
        3.7577372, -0.7037678,  3.2789488, -3.3734376, -5.840725 ,
        3.9965773,  3.816288 ,  3.9954224,  3.9873314,  3.9962091,
        3.9932766, -2.8975136, -5.788995 ,  3.3305435, -5.9017

In [59]:
bins = [-5, -4, -3, -2, -1, 0, 5]
for s0, s1 in zip(bins[:-1], bins[1:]):
    the_dat=np.extract((s0 <= dfe2) & (dfe2 < s1), dfe2)
    prop = the_dat.shape[0]/obs_samples2.shape[0]
    print(f"{10**s0} <= |s| < {10**s1}: {prop:.7f}")
    if s1 == bins[-1]:
        the_dat=np.extract(dfe2 > s1, dfe2)
        prop = the_dat.shape[0]/500000.0
        print(f"|s| > {10**s1}: {prop:.7f}")

1e-05 <= |s| < 0.0001: 0.0100000
0.0001 <= |s| < 0.001: 0.0700000
0.001 <= |s| < 0.01: 0.0600000
0.01 <= |s| < 0.1: 0.0100000
0.1 <= |s| < 1: 0.0200000
1 <= |s| < 100000: 0.7100000
|s| > 100000: 0.0000000


In [41]:
# Loss between true and predicted in poisson log-liklihood
loss = torch.nn.functional.poisson_nll_loss(torch.tensor(mean_predicted+1), torch.tensor(true_x+1),log_input=False, full=True)

In [42]:
print("poisson negative log-liklihood loss {}".format(loss))

poisson negative log-liklihood loss 7.422133771241665


In [326]:
print("poisson negative log-liklihood loss {}".format(loss))

poisson negative log-liklihood loss 3.217984279400709
