In [1]:
import torch
from torch import nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi, SNLE, MNLE, SNRE, SNRE_A
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from sbi.utils import BoxUniform
from sbi.utils import MultipleIndependent
from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding
from sbi.utils.user_input_checks import process_prior, process_simulator
from sbi.utils import get_density_thresholder, RestrictedPrior
from sbi.utils.get_nn_models import posterior_nn
import numpy as np
import moments
from matplotlib import pyplot as plt
import pickle
import os
import seaborn as sns
import datetime
import pandas as pd
import logging
import atexit
import torch.nn.functional as F
import subprocess
import sparselinear as sl
from sortedcontainers import SortedDict
from scipy.spatial import KDTree
import os
import re
from monarch_linear import MonarchLinear
import pdb
logging.getLogger('matplotlib').setLevel(logging.ERROR) # See: https://github.com/matplotlib/matplotlib/issues/14523
from collections import defaultdict
from sbi.analysis import pairplot
from tqdm import tqdm




In [2]:
sample_size=85
the_device='cuda'
moments_loss_lof = -237.79096099886482
moments_loss_mis = -1065.0798831623356

In [3]:
def get_sim_datafrom_hdf5(path_to_sim_file: str):
    """_summary_

    Args:
        path_to_sim_file (str): _description_
    """    
    #TODO probably will be better to use https://github.com/quantopian/warp_prism for faster look-up tables
    global loaded_file 
    global loaded_file_keys
    global loaded_tree
    import h5py
    loaded_file = h5py.File(path_to_sim_file, 'r')
    loaded_file_keys = list(loaded_file.keys())
    loaded_tree = KDTree(np.asarray(loaded_file_keys)[:,None]) # needs to have a column dimension

In [4]:
def generate_sim_data(prior: float) -> torch.float32:

    data = np.zeros((sample_size*2-1))
    #theprior = prior[:-1] # last dim is misidentification
    theprior=prior
    #mis_id = prior[-1].cpu().numpy()
    mis_id=0.0021
    for a_prior in theprior:
        _, idx = loaded_tree.query(a_prior.cpu().numpy(), k=(1,)) # the k sets number of neighbors, while we only want 1, we need to make sure it returns an array that can be indexed
        fs = loaded_file[loaded_file_keys[idx[0]]][:]*1164.3148344084038 #15583.437265450002  # lof scaling parameter
        fs = (1 - mis_id)*fs + mis_id * fs[::-1]
        data += fs 
    data /= theprior.shape[0]
    return torch.nn.functional.relu(torch.tensor(data, device=the_device)).type(torch.float32)

In [None]:
def generate_sim_data(prior: float) -> torch.float32:

    data = np.zeros((sample_size*2-1))
    theprior = prior[:-1] # last dim is misidentification
    #theprior=prior
    mis_id = prior[-1].cpu().numpy()
    #mis_id=0.0021
    for a_prior in theprior:
        _, idx = loaded_tree.query(a_prior.cpu().numpy(), k=(1,)) # the k sets number of neighbors, while we only want 1, we need to make sure it returns an array that can be indexed
        fs = loaded_file[loaded_file_keys[idx[0]]][:]*1164.3148344084038 #15583.437265450002  # lof scaling parameter
        fs = (1 - mis_id)*fs + mis_id * fs[::-1]
        data += fs 
    data /= theprior.shape[0]
    return torch.nn.functional.relu(torch.tensor(data, device=the_device)).type(torch.float32)

In [232]:
def generate_moments_sim_data2(prior: float) -> torch.float32:
    
    global sample_size
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    #s_prior, weights = prior[:6], prior[6:]
    #s_prior, weights = prior[:5], prior[5:]
    #s_prior, p_misid, weights = prior[:7], prior[7], prior[7:]
    fs_aggregate = None
    p_misid = 0.0021 #.0021 # lof missid # 0.0147 missense
    gammas = -1*10**(prior.cpu().numpy().squeeze())
    nu_func = lambda t: [opt_params[0] * np.exp(
                np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
    for j, gamma in enumerate(gammas):
        while rerun:
            #print(gamma, j)
            ns_sim = 2 * ns_sim
            fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
            fs = moments.Spectrum(fs)
            fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
            
            fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
            if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
                # large gamma-values can require large sample sizes for stability
                rerun = True
                del fs
            else:
                rerun = False
        if j == 0:
            fs_aggregate = fs.project([projected_sample_size]).compressed()*theta_lof
            #check_test_2.append(fs_aggregate)
        else:
            fs2 = fs.project([projected_sample_size]).compressed()*theta_lof
            fs_aggregate += fs2
            #check_test_2.append(fs2)
            del fs2
        rerun = True
        ns_sim = 100
    
    #check_fs_aggregate = np.copy(fs_aggregate)            
    fs_aggregate /= gammas.shape[0]
    #fs_aggregate = torch.poisson(torch.nn.functional.relu(torch.tensor(fs_aggregate))).type(torch.float32) 
    fs_aggregate = torch.nn.functional.relu(torch.tensor(fs_aggregate)).type(torch.float32) 
    #return gammas, check_test_2, check_fs_aggregate, fs_aggregate
    return fs_aggregate

In [28]:
def generate_momments(prior: float, sample_size) -> torch.float32:
    
    opt_params = [2.21531687, 5.29769918, 0.55450117, 0.04088086]
    theta_mis = 15583.437265450002
    theta_lof = 1164.3148344084038
    rerun = True
    ns_sim = 100
    h=0.5
    projected_sample_size = sample_size*2
    gamma = -1*10**(prior)
    p_misid = 0.0 #.0021 # lof missid

    while rerun:
        ns_sim = 2 * ns_sim
        fs = moments.LinearSystem_1D.steady_state_1D(ns_sim, gamma=gamma, h=h)
        fs = moments.Spectrum(fs)
        fs.integrate([opt_params[0]], opt_params[2], gamma=gamma, h=h)
        nu_func = lambda t: [opt_params[0] * np.exp(
            np.log(opt_params[1] / opt_params[0]) * t / opt_params[3])]
        fs.integrate(nu_func, opt_params[3], gamma=gamma, h=h)
        if abs(np.max(fs)) > 10 or np.any(np.isnan(fs)):
            # large gamma-values can require large sample sizes for stability
            rerun = True
            print("rerunning")
        else:
            rerun = False
        fs2 = fs.project([projected_sample_size]).compressed()*theta_lof
        fs2 = (1 - p_misid) * fs2 + p_misid * fs2[::-1]
        
    return fs2

In [7]:
# optimal parameters from moments, gamma for lof
# shape: 0.3589
# scale: 7830.5
gdist = torch.distributions.gamma.Gamma(torch.tensor([0.3589]),torch.tensor([1/7830.5]))
gamma_samples = gdist.sample((1000,)).type(torch.int32)

In [6]:
#optimal parameters (missense):
#shape: 0.4448
#scale: 82.4
gdist2 = torch.distributions.gamma.Gamma(torch.tensor([0.4448]),torch.tensor([1/82.4]))
gamma_samples = gdist.sample((1000,)).type(torch.int32)

In [None]:
udist = torch.distributions.uniform.Uniform(-6.0*torch.ones(50),4.0*torch.ones(50))
uniform_samples = udist.sample((1000,))

In [8]:
#last_posterior = torch.load('Experiments/saved_posteriors_msl_mis_scf_sinkhorn_36_sel_blur_2_2023-04-11_14-06/posterior_observed_round_2.pkl')
last_posterior=torch.load('Experiments/saved_posteriors_msl_lof_scf_sinkhorn_12_pmsid_and_optimizer_2023-04-17_11-50/posterior_observed_round_2.pkl')


In [9]:
#true_x = np.load('emperical_missense_sfs_msl.npy')
true_x = np.load('emperical_lof_sfs_msl.npy')


In [29]:
check_aggregate = generate_momments(2.2, 85)


rerunning
rerunning


In [30]:
print(f"Single Selection Coefficient: {check_aggregate[:20].astype(int)}")
print("True X: {}".format(true_x[:10]))

Single Selection Coefficient: [564  27   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0]
True X: [685. 159.  82.  59.  37.  23.  23.  13.  12.  15.]


In [10]:
accept_reject_fn = get_density_thresholder(last_posterior, quantile=1e-5, num_samples_to_estimate_support=100000)
proposal = RestrictedPrior(last_posterior._prior, accept_reject_fn, last_posterior, sample_with="sir", device='cuda')

In [11]:
predicted_fs=[]
predicted_fs2=[]

In [12]:
obs_samples = proposal.sample((1000,), oversampling_factor=1024)[:,:-1]


In [328]:
#with p_misid
obs_samples2 = proposal.sample((1000,), oversampling_factor=1024)


In [13]:
torch.mean(obs_samples,dim=0)


tensor([-5.1886, -2.9545, -0.9737, -5.9108,  3.8404, -5.8858, -1.5907, -3.3094,
        -2.3612, -5.9933, -2.2722], device='cuda:0')

In [14]:
print("Shape of sampled selection coefficients: {}".format(obs_samples.shape))

Shape of sampled selection coefficients: torch.Size([1000, 11])


In [None]:
# IF using non-cached data
for obs_sample in tqdm(obs_samples):
    fs = generate_moments_sim_data2(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())


In [16]:
get_sim_datafrom_hdf5('moments_msl_sfs_lof_hdf5_data.h5')

In [17]:
predicted_fs=[]

In [18]:
for obs_sample in tqdm(obs_samples):
    fs = generate_sim_data(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 208.53it/s]


In [330]:
for obs_sample in tqdm(obs_samples2):
    fs = generate_sim_data2(obs_sample)
    predicted_fs2.append(fs.unsqueeze(0).cpu().numpy())


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 113.33it/s]


In [20]:
new_predicted_fs = np.asarray(predicted_fs).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (1000, 169)


In [304]:
new_predicted_fs2 = np.asarray(predicted_fs2).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (1000, 169)


In [None]:
smaller_true_x = np.log10(true_x[1:169:10])
idx = np.arange(1,169,10)
print("Shape of indices (bin) that were chosen to plot {}".format(idx.shape[0]))
print("Showing idx for sanity check\n")
print(idx)
print("\n")

In [196]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    sns.histplot(np.log10(new_predicted_fs[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_mis_36_coefficients_round2_blur_2_hs.png')
plt.close()



In [197]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    sns.histplot(np.log10(new_predicted_fs2[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs2[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs2[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_lof_36_coefficients_round2_blur_2_s.png')
plt.close()



In [21]:
mean_predicted = np.mean(new_predicted_fs[:],axis=0)
print(f"Without pmsid: {mean_predicted[:10]}")
print(true_x[:10])

Without pmsid: [2935.0466  1176.6456   713.8925   509.01065  393.62485  319.29532
  267.25986  228.74918  199.09727  175.5803 ]
[685. 159.  82.  59.  37.  23.  23.  13.  12.  15.]


In [336]:
mean_predicted2 = np.mean(new_predicted_fs2[:],axis=0)
print(f"With pmsid: {mean_predicted2[:10]}")
print(true_x[:10])

With pmsid: [1872.4374   743.6381   450.16882  320.6129   247.7485   200.84918
  168.03575  143.76195  125.07932  110.26723]
[685. 159.  82.  59.  37.  23.  23.  13.  12.  15.]


In [None]:
# Test to create a normalized batch tensor of the predicted frequency spectrum
temp = torch.tensor(predicted_fs)
temp = temp.squeeze(1)

In [None]:
norm_predicted_fs = temp/temp.sum(dim=1).view(temp.shape[0],1)

In [None]:
print(temp[0,:10])
print(temp.sum(dim=1)[0])
print(norm_predicted_fs[0,:10])
# Test to create a normalized batch tensor of the predicted frequency spectrum

In [23]:
x = np.arange(0,mean_predicted.shape[0])

In [24]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=np.log10(mean_predicted+1), label="Predicted")
sns.scatterplot(x=x,y=np.log10(true_x+1), label="Emperical")
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('ppc_scatter_lof_12_coefficients_round5.png')
plt.close()


In [338]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=np.log10(mean_predicted2+1), label="Predicted")
sns.scatterplot(x=x,y=np.log10(true_x+1), label="Emperical")
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('ppc_scatter_mis_36_coefficients_round1.png')
plt.close()


In [25]:
print(obs_samples.shape)
print(obs_samples.mean(dim=0))
obs_samples2 = obs_samples.reshape(-1)

samps = obs_samples2.cpu().numpy()

torch.Size([1000, 11])
tensor([-5.1886, -2.9545, -0.9737, -5.9108,  3.8404, -5.8858, -1.5907, -3.3094,
        -2.3612, -5.9933, -2.2722], device='cuda:0')


In [26]:
prior_samps =last_posterior._prior.sample((2000,)).view(-1).cpu().numpy()
gamma_samples2 = gdist.sample((1000,))

In [27]:
plt.close()
fig = plt.figure(figsize=(8,8))
sns.kdeplot(samps, label="DFE", c='r')
sns.kdeplot(prior_samps, label="Initial Proposal", c='g')
sns.kdeplot(torch.log10(gamma_samples2), label="Moments Proposal")
plt.title("Kernel Density Estimation Inferred Scaled Selection")
plt.ylabel("Density")
plt.xlabel("Log of Absolute Scaled Selection Coefficient")
fig.legend(["DFE", "Initial Proposal", "Moments Proposal"], loc="lower center", ncol=3)
plt.savefig('ppc_mmd_kde_selection_lof_12_coefficients_round5.png')
plt.close()

    

In [342]:
plt.close()
fig = plt.figure(figsize=(8,8))
sns.kdeplot(samps, label="DFE", c='r')
sns.kdeplot(prior_samps, label="Initial Proposal", c='g')
sns.kdeplot(torch.log10(gamma_samples2), label="Moments Proposal")
plt.title("Kernel Density Estimation Inferred Scaled Selection")
plt.ylabel("Density")
plt.xlabel("Log of Absolute Scaled Selection Coefficient")
fig.legend(["DFE", "Initial Proposal", "Moments Proposal"], loc="lower center", ncol=3)
plt.savefig('ppc_mmd_kde_selection_mis_36_coefficients_round2_blur_1_s.png')
plt.close()

In [133]:
dfe2= samps.squeeze()

In [134]:
bins = [-6, -5, -4, -3, -2, -1, 0, 4.0]
for s0, s1 in zip(bins[:-1], bins[1:]):
    the_dat=np.extract((s0 <= dfe2) & (dfe2 < s1), dfe2)
    prop = the_dat.shape[0]/obs_samples2.shape[0]
    print(f"{10**s0} <= |s| < {10**s1}: {prop:.7f}")
    if s1 == bins[-1]:
        the_dat=np.extract(dfe2 > s1, dfe2)
        prop = the_dat.shape[0]/500000.0
        print(f"|s| > {10**s1}: {prop:.7f}")

1e-06 <= |s| < 1e-05: 0.0000000
1e-05 <= |s| < 0.0001: 0.0000000
0.0001 <= |s| < 0.001: 0.0000000
0.001 <= |s| < 0.01: 0.3749375
0.01 <= |s| < 0.1: 0.0915938
0.1 <= |s| < 1: 0.0851250
1 <= |s| < 10000.0: 0.4483437
|s| > 10000.0: 0.0000000


In [135]:
# Loss between true and predicted in poisson log-liklihood
loss = -1*torch.nn.functional.poisson_nll_loss(torch.log(torch.tensor(mean_predicted2+1)), torch.log(torch.tensor(true_x+1)),log_input=True, full=False, reduction='sum' )

In [248]:
log_predicted = torch.log(torch.tensor(new_predicted_fs).unsqueeze(1))

In [249]:
log_target = torch.log(torch.tensor(true_x+1).repeat(log_predicted.shape[0],1).unsqueeze(1))

In [250]:
loss2 = -1*torch.nn.functional.poisson_nll_loss(log_predicted, log_target, reduction='mean')

In [213]:
print("poisson log-liklihood loss {} vs moments log-liklihood loss for lof {}".format(loss, moments_loss_lof))

poisson log-liklihood loss -4795.529159783006 vs moments log-liklihood loss for lof -237.79096099886482


In [252]:
print("poisson log-liklihood loss over batch {} vs moments log-liklihood loss for lof {}".format(loss2, moments_loss_lof))

poisson log-liklihood loss over batch -47.59079092120809 vs moments log-liklihood loss for lof -237.79096099886482


In [58]:
resid = (mean_predicted - true_x)/np.sqrt(mean_predicted)

In [59]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=resid, label="Residual")
plt.title("Poisson Residuals")
plt.ylabel("Residuals")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('resid_lof_32_coefficients_round10.png')
plt.close()

In [138]:
log_target[0]

tensor([[6.5309, 5.0752, 4.4188, 4.0943, 3.6376, 3.1781, 3.1781, 2.6391, 2.5649,
         2.7726, 2.7081, 2.7081, 2.7081, 2.7726, 2.3026, 1.0986, 2.3026, 2.3979,
         1.6094, 1.9459, 1.3863, 1.6094, 1.3863, 1.9459, 1.9459, 1.7918, 1.3863,
         1.6094, 1.0986, 1.0986, 1.6094, 1.7918, 2.0794, 1.6094, 1.0986, 1.6094,
         2.0794, 0.6931, 1.0986, 1.3863, 1.3863, 0.6931, 1.0986, 0.6931, 0.0000,
         1.0986, 0.6931, 1.0986, 0.6931, 0.6931, 0.0000, 1.0986, 1.0986, 0.6931,
         1.0986, 1.0986, 0.6931, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6931,
         0.0000, 0.6931, 0.6931, 0.6931, 0.0000, 0.0000, 0.6931, 0.6931, 0.0000,
         0.6931, 0.6931, 1.3863, 0.6931, 0.6931, 0.0000, 0.6931, 0.6931, 0.0000,
         0.0000, 0.0000, 1.3863, 0.6931, 0.6931, 0.6931, 0.0000, 0.6931, 0.6931,
         1.0986, 0.0000, 0.6931, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0986,
         0.0000, 0.0000, 0.0000, 0.6931, 0.0000, 0.6931, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0