In [None]:
import torch
from torch import nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi, SNLE, MNLE, SNRE, SNRE_A
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from sbi.utils import BoxUniform
from sbi.utils import MultipleIndependent
from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding
from sbi.utils.user_input_checks import process_prior, process_simulator
from sbi.utils import get_density_thresholder, RestrictedPrior
from sbi.utils.get_nn_models import posterior_nn
import numpy as np
import moments
from matplotlib import pyplot as plt
import pickle
import os
import seaborn as sns
import datetime
import pandas as pd
import logging
import atexit
import torch.nn.functional as F
import subprocess
import sparselinear as sl
from sortedcontainers import SortedDict
from scipy.spatial import KDTree
import os
import re
from monarch_linear import MonarchLinear
import pdb
logging.getLogger('matplotlib').setLevel(logging.ERROR) # See: https://github.com/matplotlib/matplotlib/issues/14523
from collections import defaultdict
from sbi.analysis import pairplot

In [65]:
def ImportanceSamplingEstimator(sample, threshold, target=None, num_particles=None):
    """_summary_

    Args:
        sample (_type_): _description_
        target (_type_): _description_
        threshold (_type_): _description_

    Returns:
        _type_: _description_
    """    
    cdftest = target.q.transforms
    low_samples = sample - .0025
    high_samples = sample + .0025 
    num_particles = 1000 if num_particles is None else num_particles
    if target is not None:
        with torch.no_grad():
            for transform in cdftest[::-1]:
                value = transform.inv(high_samples)
            if target.q._validate_args:
                target.q.base_dist._validate_sample(value)
            value = target.q.base_dist.base_dist.cdf(value)
            #value = target.q._monotonize_cdf(value)
        with torch.no_grad():
            for transform in cdftest[::-1]:
                value2 = transform.inv(low_samples)
            if target.q._validate_args:
                target.q.base_dist._validate_sample(value2)
            value2 = target.q.base_dist.base_dist.cdf(value2)
            #value2 = target.q.base_dist._monotonize_cdf(value2)
        
        return value - value2
    else:
        sample_low = sample-threshold
        sample_high = sample+threshold
        proposal = torch.distributions.uniform.Uniform(sample_low, sample_high)
        prop_samps = proposal.sample((num_particles,))
        target_logprobs = target.log_prob(prop_samps)
        proposal_logprobs = proposal.log_prob(prop_samps)
        log_importance_weights = target_logprobs - proposal_logprobs

    ret = torch.sum(torch.exp(log_importance_weights))/num_particles


    return ret 

In [66]:
def generate_moments_sim_data(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    #s_prior, weights = prior[:6], prior[6:]
    #s_prior, weights = prior[:5], prior[5:]
    s_prior, p_misid, weights = prior[:7], prior[7], prior[7:]
    fs_aggregate = None
    gammas = s_prior.cpu().numpy().squeeze()
    weights = weights.cpu().numpy().squeeze()
    p_misid = p_misid.cpu().numpy()
    for j, (gamma, weight) in enumerate(zip(gammas, weights)):
        while rerun:
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
            else:
                rerun = False
        if j == 0:
            fs_aggregate = fs.project([projected_sample_size]).compressed()*theta_mis * weight
            fs_aggregate = (1 - p_misid)*fs_aggregate + p_misid * fs_aggregate[::-1]
        else:
            fs_aggregate += fs.project([projected_sample_size]).compressed()*theta_mis * weight
            fs_aggregate = (1 - p_misid)*fs_aggregate + p_misid * fs_aggregate[::-1]

    fs_aggregate = torch.tensor(fs_aggregate).type(torch.float32) 
    return fs_aggregate


In [205]:
def generate_moments_sim_data2(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    fs_aggregate = None
    gammas = prior.cpu().numpy().squeeze()
   
    for j, gamma in enumerate(gammas):
        while rerun:
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
            else:
                rerun = False
        if j == 0:
            fs_aggregate = fs.project([projected_sample_size]).compressed()*theta_mis
        else:
            fs2 = fs.project([projected_sample_size]).compressed()*theta_mis
            fs_aggregate += fs2
            del fs2
            
    fs_aggregate /= gammas.shape[0]
    fs_aggregate = torch.poisson(torch.tensor(fs_aggregate)).type(torch.float32) 
    return fs_aggregate


In [198]:
last_posterior = torch.load('Experiments/saved_posteriors_msl_mcf_6_and_psmid_params_2023-04-05_20-19/posterior_observed_last_round.pkl')
true_x = np.load('emperical_missense_sfs_msl.npy')



In [201]:
accept_reject_fn = get_density_thresholder(last_posterior, quantile=1e-5, num_samples_to_estimate_support=10000)
proposal = RestrictedPrior(last_posterior._prior, accept_reject_fn, last_posterior, sample_with="sir", device='cuda')

In [169]:
predicted_fs=[]
predicted_fs2=[]
sample_size = 85

In [202]:
obs_samples = proposal.sample((2000,))


In [203]:
obs_samples[:50,-1]

tensor([-2.3912,  6.7486, -5.8919, -7.5726,  0.8809, -7.7729,  2.0690,  1.5029,
         6.4421,  5.5723,  5.2281,  0.1243,  6.3237,  6.1314, -8.6591,  5.1889,
        -3.6221,  5.0881,  2.9969, -0.8614,  3.8105,  0.7566,  1.2305, -8.4798,
        -4.2378, -3.2068, -0.6270, -4.6329, -7.0010, -5.5889,  5.3062,  5.7948,
         5.4552, -5.6743,  4.7063, -4.5431, -0.7692,  4.1205,  2.2638,  5.6893,
         0.6888, -0.5392, -8.3251,  5.2313,  1.1143, -5.4134,  5.9483, -0.6425,
         2.4292, -3.7556], device='cuda:0')

In [204]:
print("Shape of sampled selection coefficients: {}".format(obs_samples.shape))

Shape of sampled selection coefficients: torch.Size([2000, 20])


In [206]:
for obs_sample in obs_samples:
    
    fs = generate_moments_sim_data2(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())
    predicted_fs2.append(np.log10(fs[1:169:10].unsqueeze(0).cpu().numpy()))

In [210]:
new_predicted_fs = np.asarray(predicted_fs[2000:]).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (2000, 169)


In [227]:
smaller_true_x = np.log10(true_x[:17])
#new_predicted_fs2 = np.asarray(predicted_fs2[2000:]).squeeze(1)
new_predicted_fs2 = np.log10(new_predicted_fs[:,:17])


In [212]:
idx = np.arange(1,169,10)
print("Shape of indices (bin) that were chosen to plot {}".format(idx.shape[0]))

Shape of indices (bin) that were chosen to plot 17


In [187]:
idx

array([  1,  11,  21,  31,  41,  51,  61,  71,  81,  91, 101, 111, 121,
       131, 141, 151, 161])

In [225]:
print("How many subplots to create per bin: {}".format(smaller_true_x.shape))
print("Shape of predicted SFS: {}".format(new_predicted_fs2.shape))

How many subplots to create per bin: (17,)
Shape of predicted SFS: (2000, 17)


In [235]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(14,10))
x_points = np.arange(1, smaller_true_x.shape[0])
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 8):
    ax = fig.add_subplot(3, 3, i)
    #sns.scatterplot(ax=ax, x=predicted_fs2[:,i-1], y=x_points[i-1])
    sns.histplot(new_predicted_fs2[:,i-1])
    plt.axvline(x=np.mean(new_predicted_fs2[:,i-1]), color='m', label="mean")
    plt.axvline(x=np.median(new_predicted_fs2[:,i-1]), color='k', label="median")
    ax.axline((smaller_true_x[i-1], 1), (smaller_true_x[i-1],100), marker='+', c='r', label="Emperical SFS")
    true_diff_predicted = 10**(np.mean(new_predicted_fs2[:,i-1])) - 10**(smaller_true_x[i-1])
    plt.text(x=((np.mean(new_predicted_fs2[:,i-1])+smaller_true_x[i-1])/2), y=50, s="{:.2f}.".format(true_diff_predicted))
    #ax.plot(smaller_true_x[i-1], x_points[i-1], markersize=7.0, c='r', marker='+')
    #plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(smaller_true_x[i-1], idx[i-1] ))
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(smaller_true_x[i-1],i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_averaged_samples_is_20.png')
plt.close()



In [223]:
true_x[:10]

array([26296.,  8870.,  4960.,  3369.,  2599.,  2021.,  1563.,  1347.,
        1168.,  1057.])

In [222]:
new_predicted_fs[0,:10]

array([38460., 12776.,  6312.,  3560.,  2270.,  1564.,  1015.,   728.,
         538.,   348.], dtype=float32)