In [1]:
import torch
from torch import nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi, SNLE, MNLE, SNRE, SNRE_A
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from sbi.utils import BoxUniform
from sbi.utils import MultipleIndependent
from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding
from sbi.utils.user_input_checks import process_prior, process_simulator
from sbi.utils import get_density_thresholder, RestrictedPrior
from sbi.utils.get_nn_models import posterior_nn
import numpy as np
import moments
from matplotlib import pyplot as plt
import pickle
import os
import seaborn as sns
import datetime
import pandas as pd
import logging
import atexit
import torch.nn.functional as F
import subprocess
import sparselinear as sl
from sortedcontainers import SortedDict
from scipy.spatial import KDTree
import os
import re
from monarch_linear import MonarchLinear
import pdb
logging.getLogger('matplotlib').setLevel(logging.ERROR) # See: https://github.com/matplotlib/matplotlib/issues/14523
from collections import defaultdict
from sbi.analysis import pairplot



In [65]:
def ImportanceSamplingEstimator(sample, threshold, target=None, num_particles=None):
    """_summary_

    Args:
        sample (_type_): _description_
        target (_type_): _description_
        threshold (_type_): _description_

    Returns:
        _type_: _description_
    """    
    cdftest = target.q.transforms
    low_samples = sample - .0025
    high_samples = sample + .0025 
    num_particles = 1000 if num_particles is None else num_particles
    if target is not None:
        with torch.no_grad():
            for transform in cdftest[::-1]:
                value = transform.inv(high_samples)
            if target.q._validate_args:
                target.q.base_dist._validate_sample(value)
            value = target.q.base_dist.base_dist.cdf(value)
            #value = target.q._monotonize_cdf(value)
        with torch.no_grad():
            for transform in cdftest[::-1]:
                value2 = transform.inv(low_samples)
            if target.q._validate_args:
                target.q.base_dist._validate_sample(value2)
            value2 = target.q.base_dist.base_dist.cdf(value2)
            #value2 = target.q.base_dist._monotonize_cdf(value2)
        
        return value - value2
    else:
        sample_low = sample-threshold
        sample_high = sample+threshold
        proposal = torch.distributions.uniform.Uniform(sample_low, sample_high)
        prop_samps = proposal.sample((num_particles,))
        target_logprobs = target.log_prob(prop_samps)
        proposal_logprobs = proposal.log_prob(prop_samps)
        log_importance_weights = target_logprobs - proposal_logprobs

    ret = torch.sum(torch.exp(log_importance_weights))/num_particles


    return ret 

In [66]:
def generate_moments_sim_data(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    #s_prior, weights = prior[:6], prior[6:]
    #s_prior, weights = prior[:5], prior[5:]
    s_prior, p_misid, weights = prior[:7], prior[7], prior[7:]
    fs_aggregate = None
    gammas = s_prior.cpu().numpy().squeeze()
    weights = weights.cpu().numpy().squeeze()
    p_misid = p_misid.cpu().numpy()
    for j, (gamma, weight) in enumerate(zip(gammas, weights)):
        while rerun:
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
            else:
                rerun = False
        if j == 0:
            fs_aggregate = fs.project([projected_sample_size]).compressed()*theta_mis * weight
            fs_aggregate = (1 - p_misid)*fs_aggregate + p_misid * fs_aggregate[::-1]
        else:
            fs_aggregate += fs.project([projected_sample_size]).compressed()*theta_mis * weight
            fs_aggregate = (1 - p_misid)*fs_aggregate + p_misid * fs_aggregate[::-1]

    fs_aggregate = torch.tensor(fs_aggregate).type(torch.float32) 
    return fs_aggregate


In [2]:
def generate_moments_sim_data2(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    #s_prior, weights = prior[:6], prior[6:]
    #s_prior, weights = prior[:5], prior[5:]
    #s_prior, p_misid, weights = prior[:7], prior[7], prior[7:]
    fs_aggregate = None
    p_misid = 0.0137
    gammas = -1*10**(prior.cpu().numpy().squeeze())
   
    for j, gamma in enumerate(gammas):
        while rerun:
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
            else:
                rerun = False
        if j == 0:
            fs2 = fs.project([projected_sample_size]).compressed()*theta_mis
            fs2 = (1 - p_misid) * fs2 + p_misid * fs2[::-1]
            fs_aggregate = fs2
        else:
            fs2 = fs.project([projected_sample_size]).compressed()*theta_mis
            fs2 = (1 - p_misid) * fs2 + p_misid * fs2[::-1]
            fs_aggregate += fs2
            del fs2
            
    fs_aggregate /= gammas.shape[0]
    fs_aggregate = torch.poisson(torch.tensor(fs_aggregate)).type(torch.float32) 
    return fs_aggregate

In [317]:
def generate_momments(prior: float, sample_size) -> torch.float32:
    
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    gammas = -1*10**(prior)

    while rerun:
        ns_sim = 2 * ns_sim
        fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
        fs = moments.Spectrum(fs)
        fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
        nu_func = lambda t: [opt_params[0] * np.exp(
            np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
        fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
        if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
            # large gamma-values can require large sample sizes for stability
            rerun = True
        else:
            rerun = False
        fs2 = fs.project([projected_sample_size]).compressed()*theta_mis


   
    return fs2

In [320]:
b = torch.poisson(torch.abs(torch.tensor(generate_momments(-3, 85)*15583)))

In [321]:
b

tensor([6.7659e+08, 2.6023e+08, 1.5035e+08, 1.0196e+08, 7.4956e+07, 5.7785e+07,
        4.5954e+07, 3.7372e+07, 3.0901e+07, 2.5886e+07, 2.1907e+07, 1.8718e+07,
        1.6108e+07, 1.3943e+07, 1.2142e+07, 1.0625e+07, 9.3428e+06, 8.2343e+06,
        7.2927e+06, 6.4832e+06, 5.7749e+06, 5.1637e+06, 4.6275e+06, 4.1499e+06,
        3.7357e+06, 3.3714e+06, 3.0467e+06, 2.7600e+06, 2.5021e+06, 2.2734e+06,
        2.0701e+06, 1.8896e+06, 1.7251e+06, 1.5772e+06, 1.4456e+06, 1.3229e+06,
        1.2177e+06, 1.1203e+06, 1.0316e+06, 9.5172e+05, 8.7733e+05, 8.1359e+05,
        7.5097e+05, 6.9594e+05, 6.4496e+05, 5.9884e+05, 5.5723e+05, 5.1825e+05,
        4.8411e+05, 4.5120e+05, 4.2000e+05, 3.9252e+05, 3.6800e+05, 3.4562e+05,
        3.2281e+05, 3.0260e+05, 2.8444e+05, 2.6742e+05, 2.5100e+05, 2.3663e+05,
        2.2238e+05, 2.1040e+05, 1.9849e+05, 1.8694e+05, 1.7688e+05, 1.6800e+05,
        1.5847e+05, 1.5046e+05, 1.4256e+05, 1.3514e+05, 1.2841e+05, 1.2195e+05,
        1.1569e+05, 1.1075e+05, 1.0537e+

In [295]:
c = torch.poisson(torch.abs(torch.tensor(generate_momments(-1, 85)*15583)))

In [296]:
c

tensor([4.4567e+04, 1.7949e+04, 1.0958e+04, 7.6070e+03, 5.8640e+03, 4.6710e+03,
        4.0230e+03, 3.3090e+03, 2.8310e+03, 2.5110e+03, 2.1770e+03, 1.9510e+03,
        1.7350e+03, 1.6050e+03, 1.4240e+03, 1.2950e+03, 1.1680e+03, 1.1090e+03,
        1.0190e+03, 9.4700e+02, 8.8800e+02, 8.1100e+02, 7.7900e+02, 7.1000e+02,
        6.5900e+02, 6.1300e+02, 6.3600e+02, 5.3400e+02, 5.0900e+02, 5.0100e+02,
        4.5300e+02, 4.4300e+02, 4.0300e+02, 3.8200e+02, 3.4400e+02, 3.6200e+02,
        3.4300e+02, 3.1700e+02, 3.4100e+02, 2.8800e+02, 3.1900e+02, 2.6900e+02,
        2.5500e+02, 2.6200e+02, 2.5100e+02, 2.1800e+02, 2.3800e+02, 2.3300e+02,
        2.0900e+02, 1.9300e+02, 1.9700e+02, 1.7000e+02, 1.9100e+02, 1.6700e+02,
        1.6700e+02, 1.5900e+02, 1.5600e+02, 1.4600e+02, 1.5300e+02, 1.0100e+02,
        1.1700e+02, 1.4200e+02, 1.1900e+02, 1.4600e+02, 1.2800e+02, 1.3000e+02,
        1.2000e+02, 1.2400e+02, 1.1000e+02, 1.3300e+02, 1.0600e+02, 1.0400e+02,
        1.1000e+02, 9.8000e+01, 1.0600e+

In [3]:
last_posterior = torch.load('Experiments/saved_posteriors_msl_mcf_6_and_psmid_params_2023-04-06_14-14/posterior_observed_round_15.pkl')



In [4]:
true_x = np.load('emperical_missense_sfs_msl.npy')


In [5]:
accept_reject_fn = get_density_thresholder(last_posterior, quantile=1e-5, num_samples_to_estimate_support=30000)
proposal = RestrictedPrior(last_posterior._prior, accept_reject_fn, last_posterior, sample_with="sir", device='cuda')

OutOfMemoryError: CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 11.90 GiB total capacity; 296.54 MiB already allocated; 78.12 MiB free; 382.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [350]:
predicted_fs=[]
predicted_fs2=[]
sample_size = 85

In [329]:
obs_samples = proposal.sample((2000,), oversampling_factor=1024)


In [330]:
torch.mean(obs_samples,dim=0)


tensor([-1.4829, -2.1209, -2.2635, -2.4465, -2.9487, -2.0098, -2.0398, -1.7250,
        -2.0853, -3.0351, -2.1298, -2.4436, -2.1012, -2.0673, -2.4945, -2.3914,
        -2.1178, -1.9946, -2.1168, -1.8758, -2.0457, -2.3790, -2.3992, -1.8551,
        -2.3266, -1.9339, -1.9661, -2.3549, -1.9196, -2.6118, -2.2146, -2.1941,
        -2.6707, -2.2191, -1.5458, -2.0106, -2.1676, -1.8533, -2.0701, -2.4434],
       device='cuda:0')

In [334]:
plt.ioff()
obs_samples2 = obs_samples.view(-1)
obs_samples2.shape

torch.Size([80000])

In [335]:
samps = obs_samples2.cpu().numpy()
sns.histplot(samps)
plt.savefig('posterior_selection_coef.png')
plt.close()

In [204]:
print("Shape of sampled selection coefficients: {}".format(obs_samples.shape))

Shape of sampled selection coefficients: torch.Size([2000, 20])


In [353]:
for obs_sample in obs_samples:
    
    fs = generate_moments_sim_data2(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())
    predicted_fs2.append(np.log10(fs[1:169:10].unsqueeze(0).cpu().numpy()))

In [354]:
new_predicted_fs = np.asarray(predicted_fs).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (2000, 169)


In [359]:
smaller_true_x = np.log10(true_x[1:169:10])


In [356]:
new_predicted_fs2 = np.log10(new_predicted_fs[:,1:169:10])


In [345]:
idx = np.arange(1,169,10)
print("Shape of indices (bin) that were chosen to plot {}".format(idx.shape[0]))

Shape of indices (bin) that were chosen to plot 17


In [187]:
idx

array([  1,  11,  21,  31,  41,  51,  61,  71,  81,  91, 101, 111, 121,
       131, 141, 151, 161])

In [357]:
print("How many subplots to create per bin: {}".format(smaller_true_x.shape))
print("Shape of predicted SFS: {}".format(new_predicted_fs2.shape))

How many subplots to create per bin: (17,)
Shape of predicted SFS: (2000, 17)


In [402]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
x_points = np.arange(1, smaller_true_x.shape[0])
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    #sns.scatterplot(ax=ax, x=predicted_fs2[:,i-1], y=x_points[i-1])
    sns.histplot(np.log10(new_predicted_fs[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    #true_diff_predicted = 10**(np.mean(new_predicted_fs2[:,i-1])) - 10**(smaller_true_x[i-1])
    #plt.text(x=((np.mean(new_predicted_fs2[:,i-1])+smaller_true_x[i-1])/2), y=50, s="{:.2f}.".format(true_diff_predicted))
    #ax.plot(smaller_true_x[i-1], x_points[i-1], markersize=7.0, c='r', marker='+')
    #plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(smaller_true_x[i-1], idx[i-1] ))
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_averaged_samples_is_40_different_bins.png')
plt.close()



In [380]:
mean_predicted = np.mean(new_predicted_fs[:],axis=0)

In [381]:
mean_predicted

array([44606.355 , 18036.723 , 10953.882 ,  7805.9033,  6024.93  ,
        4882.157 ,  4076.6145,  3481.797 ,  3023.556 ,  2660.1836,
        2366.924 ,  2124.216 ,  1919.974 ,  1746.943 ,  1598.359 ,
        1471.364 ,  1359.3855,  1260.101 ,  1171.34  ,  1093.393 ,
        1024.5284,   961.585 ,   905.8385,   853.722 ,   806.275 ,
         766.1625,   725.4145,   690.0025,   656.746 ,   627.951 ,
         600.7235,   572.4925,   549.3955,   527.0005,   506.0025,
         485.702 ,   468.388 ,   450.8835,   435.0095,   420.4755,
         406.683 ,   393.7765,   381.103 ,   368.8925,   356.7455,
         347.734 ,   336.9785,   327.3815,   318.447 ,   309.545 ,
         301.9875,   294.3455,   287.2055,   279.658 ,   273.996 ,
         266.9535,   260.7385,   256.083 ,   249.4475,   244.08  ,
         238.929 ,   234.353 ,   229.043 ,   224.8635,   220.698 ,
         216.174 ,   212.639 ,   208.8525,   205.0715,   201.8095,
         197.828 ,   194.6575,   191.204 ,   188.152 ,   185.1

In [376]:
x = np.arange(0,mean_predicted.shape[0])

In [397]:
fig = plt.figure(figsize=(4,4))
sns.scatterplot(x=x,y=np.log10(mean_predicted))
sns.scatterplot(x=x,y=np.log10(true_x))
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
plt.savefig('ppc_scatter.png')
plt.close()


In [390]:
obs_samples2 = obs_samples.view(-1)
obs_samples2.shape
samps = obs_samples2.cpu().numpy()


In [395]:
prior_samps =last_posterior._prior.sample((2000,)).view(-1).cpu().numpy()

In [396]:
plt.close()
fig = plt.figure(figsize=(8,8))
sns.kdeplot(samps)
sns.kdeplot(prior_samps)
plt.title("Kernel Density Estimation Inferred Scaled Selection")
plt.ylabel("Density")
plt.xlabel("Log of Absolute Scaled Selection Coefficient")
plt.savefig('ppc_kde_selection.png')
plt.close()

    