In [1]:
import torch
from torch import nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi, SNLE, MNLE, SNRE, SNRE_A
from sbi.utils.posterior_ensemble import NeuralPosteriorEnsemble
from sbi.utils import BoxUniform
from sbi.utils import MultipleIndependent
from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding
from sbi.utils.user_input_checks import process_prior, process_simulator
from sbi.utils import get_density_thresholder, RestrictedPrior
from sbi.utils.get_nn_models import posterior_nn
import numpy as np
import moments
from matplotlib import pyplot as plt
import pickle
import os
import seaborn as sns
import datetime
import pandas as pd
import logging
import atexit
import torch.nn.functional as F
import subprocess
import sparselinear as sl
from sortedcontainers import SortedDict
from scipy.spatial import KDTree
import os
import re
from monarch_linear import MonarchLinear
import pdb
logging.getLogger('matplotlib').setLevel(logging.ERROR) # See: https://github.com/matplotlib/matplotlib/issues/14523
from collections import defaultdict
from sbi.analysis import pairplot
from tqdm import tqdm




In [2]:
sample_size=360388
the_device='cuda'


In [22]:
min_freq = int(sample_size*2*(0.1/100))# cut_off frequency

In [3]:
class SummaryNet(nn.Module):
    def __init__(self, sample_size, block_sizes, dropout_rate=0.0):
        super().__init__()
        self.sample_size = sample_size # For monarch this needs to be divisible by the block size
        self.block_size = block_sizes
        self.linear4 = MonarchLinear(sample_size, int(sample_size / 10), nblocks=self.block_size[0]) # 11171
        self.linear5 = MonarchLinear(int(self.sample_size / 10), int(self.sample_size / 10) , nblocks=self.block_size[1]) # 11171
        self.linear6 = MonarchLinear(int(self.sample_size / 10), int(self.sample_size / 10), nblocks=self.block_size[2]) # 11171

        self.model = nn.Sequential(self.linear4, nn.Dropout(dropout_rate), nn.GELU(),
                                   self.linear5, nn.Dropout(dropout_rate), nn.GELU(),
                                   self.linear6) 
    def forward(self, x):
        
        x=self.model(x)
        return x

In [4]:
def get_sim_datafrom_hdf5(path_to_sim_file: str):
    """_summary_

    Args:
        path_to_sim_file (str): _description_
    """    
    #TODO probably will be better to use https://github.com/quantopian/warp_prism for faster look-up tables
    global loaded_file 
    global loaded_file_keys
    global loaded_tree
    import h5py
    loaded_file = h5py.File(path_to_sim_file, 'r')
    loaded_file_keys = list(loaded_file.keys())
    loaded_tree = KDTree(np.asarray(loaded_file_keys)[:,None]) # needs to have a column dimension

In [28]:
def aggregated_generate_sim_data(prior: float) -> torch.float32:

    data = np.zeros((sample_size*2-1))
    theprior = prior[:-1] # last dim is misidentification
    gammas = 10**(theprior.cpu().numpy().squeeze())

    scaling_theta=prior[-1].cpu().numpy()
    for a_prior in gammas:
        _, idx = loaded_tree.query(a_prior, k=(1,)) # the k sets number of neighbors, while we only want 1, we need to make sure it returns an array that can be indexed
        fs = loaded_file[loaded_file_keys[idx[0]]][:]
        fs = fs*(10**scaling_theta) # scale to gwas theta rate
        fs = (fs[:sample_size*2-1])
        data += fs
    data = data /(theprior.shape[0]-1)
    data = data.astype(int)
    return torch.log(torch.nn.functional.relu(torch.tensor(data)+1).type(torch.float32))

In [6]:
def load_true_data(a_path: str, type: int) -> torch.float32:
    """Loads a true SFS, note that the sample size must be consistent with the passed parameters

    Args:
        path (str): Where the true-SFS is located, must be a numpy array
        type (int): is data stored in numpy pickle (0) or torch pickle (1)

    Returns:
        Returns the SFS of the true data-set
    """
    if type == 0:
        sfs = np.load(a_path)
        sfs = torch.tensor(sfs, device=the_device).type(torch.float32)
    else:
        sfs = torch.load(a_path)
        sfs.to(the_device)
    assert sfs.shape[0] == sample_size*2-1, "Sample Size must be the same dimensions as the Site Frequency Spectrum, SFS shape: {} and sample shape (2*N-1): {}".format(sfs.shape[0], sample_size*2-1)

    return sfs 

In [7]:
get_sim_datafrom_hdf5('chr10_sim_genome_wide_mut_sfs_gwas.h5')

In [8]:
udist = torch.distributions.uniform.Uniform(-6.0*torch.ones(50),4.0*torch.ones(50))
uniform_samples = udist.sample((1000,))

In [9]:
last_posterior=torch.load('nfe_restriction_classifier_gwas_embedding_final.pkl')


In [10]:
true_x = (load_true_data('emperical_standiing_height_gwas.npy', 0)[:-1]).unsqueeze(0)


In [17]:
print("Shape of true data: {}".format(true_x.shape[1]))

Shape of true data: 720774


In [12]:
proposal = last_posterior.restrict_prior()

In [19]:
predicted_fs=[]

In [29]:
obs_samples = proposal.sample((100,), oversampling_factor=1024)[:,:-1]


In [30]:
torch.mean(obs_samples,dim=0)


tensor([-3.7436, -3.3859, -3.6393, -3.7399, -3.7507, -3.6475, -3.7763, -3.8871,
        -3.4344, -3.3926, -3.6163, -3.5240, -3.2402, -3.5826, -3.7637, -3.5524,
        -3.5065, -3.4306, -3.6372, -3.2210, -3.5641, -3.2965, -3.3962, -3.4016,
        -3.8603, -3.5941, -3.8820, -3.6374, -3.4875, -3.7401, -3.6473, -3.6426,
        -3.5212, -3.1718, -3.6219, -3.8104, -3.6296, -3.5811, -3.2859, -3.8620,
        -3.6544, -3.3300, -3.3908, -3.9823, -3.4256, -3.4680, -3.7673, -3.5709,
        -3.3899, -3.5671, -3.5878, -3.3960, -3.1534, -3.3696, -3.4476, -3.7206,
        -3.6791, -3.5400, -3.8363, -3.4467], device='cuda:0')

In [34]:
obs_samples[0,:-1]

tensor([-6.4785, -0.6062, -2.4141, -5.4499, -1.3839, -1.0286, -1.7420, -2.1192,
        -6.6727, -6.2917, -2.2872, -2.0298, -0.2443, -5.7555, -6.4988, -4.5858,
        -6.7222, -5.8016, -6.9031, -0.8002, -1.5555, -4.2450, -0.5596, -0.3094,
        -2.8221, -1.8723, -5.1864, -1.7985, -5.7647, -6.6136, -1.8481, -3.0345,
        -0.1278, -2.9326, -0.2739, -6.5146, -2.7501, -2.5380, -4.7748, -2.1977,
        -3.0237, -4.5051, -1.7391, -6.0884, -2.3353, -5.0147, -6.7540, -5.1095,
        -5.3367, -1.1296, -2.7863, -1.3614, -3.3874, -3.3689, -1.2930, -3.8117,
        -3.8608, -0.3258, -1.6060], device='cuda:0')

In [42]:
_, idx = loaded_tree.query(obs_samples[0,1].cpu().numpy(), k=(1,)) # the k sets number of neighbors, while we only want 1, we need to make sure it returns an array that can be indexed
fs = loaded_file[loaded_file_keys[idx[0]]][:]

In [43]:
fs[min_freq]

1060.08

In [31]:
print("Shape of sampled selection coefficients: {}".format(obs_samples.shape))

Shape of sampled selection coefficients: torch.Size([100, 60])


In [32]:
predicted_fs=[]

In [38]:
for obs_sample in tqdm(obs_samples):
    fs = aggregated_generate_sim_data(obs_sample)
    predicted_fs.append(fs.unsqueeze(0).cpu().numpy())


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.56it/s]


In [39]:
new_predicted_fs = np.asarray(predicted_fs).squeeze(1)
print("Shape of frequency spectrum containing all bins {}".format(new_predicted_fs.shape))

Shape of frequency spectrum containing all bins (200, 720775)


In [26]:
new_predicted_fs[0]

array([13.058207, 12.203549, 11.606325, ...,  0.      ,  0.      ,
        0.      ], dtype=float32)

In [40]:
mean_predicted = np.mean(new_predicted_fs[:],axis=0)
print(f"Mean of predicted SFS: {mean_predicted[min_freq:min_freq+10]}")
print(true_x[0, min_freq:min_freq+10])

Mean of predicted SFS: [1.302827  1.3019902 1.2986102 1.297986  1.297057  1.2960618 1.2949531
 1.2939458 1.2934831 1.2928555]
tensor([619., 674., 718., 692., 804., 772., 847., 868., 905., 903.],
       device='cuda:0')


In [196]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    sns.histplot(np.log10(new_predicted_fs[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_mis_36_coefficients_round2_blur_2_hs.png')
plt.close()



In [197]:
plt.ioff()
plt.tight_layout()

fig = plt.figure(figsize=(30,30))
fig.subplots_adjust(hspace=0.6, wspace=0.6)
for i in range(1, 25):
    ax = fig.add_subplot(5, 5, i)
    sns.histplot(np.log10(new_predicted_fs2[:,i-1]))
    plt.axvline(x=np.mean(np.log10(new_predicted_fs2[:,i-1])), color='m', label="mean")
    plt.axvline(x=np.median(np.log10(new_predicted_fs2[:,i-1])), color='k', label="median")
    ax.axline((np.log10(true_x[i-1]), 1), (np.log10(true_x[i-1]),100), marker='+', c='r', label="Emperical SFS")
    plt.title("Emperical SFS: 10^{:.3f} at bin: {}".format(np.log10(true_x[i-1]),i))
fig.legend(["mean", "median", "Emperical SFS"], loc="lower center", ncol=4)
plt.savefig('ppc_check_hist_lof_36_coefficients_round2_blur_2_s.png')
plt.close()



In [20]:
true_x.shape

torch.Size([1, 111708])

In [None]:
# Test to create a normalized batch tensor of the predicted frequency spectrum
temp = torch.tensor(predicted_fs)
temp = temp.squeeze(1)

In [None]:
norm_predicted_fs = temp/temp.sum(dim=1).view(temp.shape[0],1)

In [None]:
print(temp[0,:10])
print(temp.sum(dim=1)[0])
print(norm_predicted_fs[0,:10])
# Test to create a normalized batch tensor of the predicted frequency spectrum

In [27]:
x = np.arange(0,mean_predicted[:200].shape[0])
print("Shape of frequency bins on x-axis for plotting {}".format(x.shape[0]))

Shape of frequency bins on x-axis for plotting 200


In [31]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=np.log10(mean_predicted[:200]+1), label="Predicted")
sns.scatterplot(x=x,y=np.log10(true_x.cpu().numpy()[0,:200]+1), label="Emperical")
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('ppc_scatter_lof_nfe_round_3.png')
plt.close()


In [338]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=np.log10(mean_predicted2+1), label="Predicted")
sns.scatterplot(x=x,y=np.log10(true_x+1), label="Emperical")
plt.title("Mean of posterior predicted SFS vs True SFS")
plt.ylabel("Log Scaled Allele Frequency")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('ppc_scatter_mis_36_coefficients_round1.png')
plt.close()


In [27]:
print(obs_samples.shape)
print(obs_samples.mean(dim=0))
obs_samples2 = obs_samples.reshape(-1)

samps = obs_samples2.cpu().numpy()

torch.Size([1000, 60])
tensor([-3.6042, -3.4862, -3.4878, -3.5170, -3.5543, -3.5059, -3.5199, -3.5383,
        -3.7101, -3.5415, -3.4439, -3.6128, -3.5215, -3.5429, -3.6729, -3.5535,
        -3.6120, -3.4460, -3.5710, -3.5401, -3.6356, -3.5594, -3.4464, -3.5448,
        -3.5348, -3.4697, -3.6557, -3.5755, -3.4994, -3.5083, -3.4280, -3.5558,
        -3.6247, -3.5190, -3.5291, -3.4545, -3.6503, -3.4407, -3.5591, -3.5766,
        -3.5873, -3.5375, -3.4949, -3.5990, -3.5165, -3.5074, -3.5106, -3.5948,
        -3.5323, -3.5583, -3.5479, -3.5093, -3.6229, -3.5799, -3.5042, -3.5697,
        -3.4173, -3.6066, -3.6029, -3.5630], device='cuda:0')


In [37]:
prior_samps =last_posterior._prior.sample((2000,)).view(-1).cpu().numpy()


In [39]:
plt.close()
fig = plt.figure(figsize=(8,8))
sns.kdeplot(samps, label="DFE", c='r')
sns.kdeplot(prior_samps, label="Initial Proposal", c='g')
plt.title("Kernel Density Estimation Inferred Scaled Selection")
plt.ylabel("Density")
plt.xlabel("Log of Absolute Scaled Selection Coefficient")
fig.legend(["DFE", "Initial Proposal"], loc="lower center", ncol=2)
plt.savefig('ppc_nfe_selectiion_round3.png')
plt.close()

    

In [133]:
dfe2= samps.squeeze()

In [134]:
bins = [-6, -5, -4, -3, -2, -1, 0, 4.0]
for s0, s1 in zip(bins[:-1], bins[1:]):
    the_dat=np.extract((s0 <= dfe2) & (dfe2 < s1), dfe2)
    prop = the_dat.shape[0]/obs_samples2.shape[0]
    print(f"{10**s0} <= |s| < {10**s1}: {prop:.7f}")
    if s1 == bins[-1]:
        the_dat=np.extract(dfe2 > s1, dfe2)
        prop = the_dat.shape[0]/500000.0
        print(f"|s| > {10**s1}: {prop:.7f}")

1e-06 <= |s| < 1e-05: 0.0000000
1e-05 <= |s| < 0.0001: 0.0000000
0.0001 <= |s| < 0.001: 0.0000000
0.001 <= |s| < 0.01: 0.3749375
0.01 <= |s| < 0.1: 0.0915938
0.1 <= |s| < 1: 0.0851250
1 <= |s| < 10000.0: 0.4483437
|s| > 10000.0: 0.0000000


In [135]:
# Loss between true and predicted in poisson log-liklihood
loss = -1*torch.nn.functional.poisson_nll_loss(torch.log(torch.tensor(mean_predicted2+1)), torch.log(torch.tensor(true_x+1)),log_input=True, full=False, reduction='sum' )

In [248]:
log_predicted = torch.log(torch.tensor(new_predicted_fs).unsqueeze(1))

In [249]:
log_target = torch.log(torch.tensor(true_x+1).repeat(log_predicted.shape[0],1).unsqueeze(1))

In [250]:
loss2 = -1*torch.nn.functional.poisson_nll_loss(log_predicted, log_target, reduction='mean')

In [213]:
print("poisson log-liklihood loss {} vs moments log-liklihood loss for lof {}".format(loss, moments_loss_lof))

poisson log-liklihood loss -4795.529159783006 vs moments log-liklihood loss for lof -237.79096099886482


In [252]:
print("poisson log-liklihood loss over batch {} vs moments log-liklihood loss for lof {}".format(loss2, moments_loss_lof))

poisson log-liklihood loss over batch -47.59079092120809 vs moments log-liklihood loss for lof -237.79096099886482


In [58]:
resid = (mean_predicted - true_x)/np.sqrt(mean_predicted)

In [59]:
fig = plt.figure(figsize=(8,8))
sns.scatterplot(x=x,y=resid, label="Residual")
plt.title("Poisson Residuals")
plt.ylabel("Residuals")
plt.xlabel("Frequency Bin")
#fig.legend(["Predicted", "Emperical"], loc="lower center", ncol=2)
plt.savefig('resid_lof_32_coefficients_round10.png')
plt.close()

In [138]:
log_target[0]

tensor([[6.5309, 5.0752, 4.4188, 4.0943, 3.6376, 3.1781, 3.1781, 2.6391, 2.5649,
         2.7726, 2.7081, 2.7081, 2.7081, 2.7726, 2.3026, 1.0986, 2.3026, 2.3979,
         1.6094, 1.9459, 1.3863, 1.6094, 1.3863, 1.9459, 1.9459, 1.7918, 1.3863,
         1.6094, 1.0986, 1.0986, 1.6094, 1.7918, 2.0794, 1.6094, 1.0986, 1.6094,
         2.0794, 0.6931, 1.0986, 1.3863, 1.3863, 0.6931, 1.0986, 0.6931, 0.0000,
         1.0986, 0.6931, 1.0986, 0.6931, 0.6931, 0.0000, 1.0986, 1.0986, 0.6931,
         1.0986, 1.0986, 0.6931, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6931,
         0.0000, 0.6931, 0.6931, 0.6931, 0.0000, 0.0000, 0.6931, 0.6931, 0.0000,
         0.6931, 0.6931, 1.3863, 0.6931, 0.6931, 0.0000, 0.6931, 0.6931, 0.0000,
         0.0000, 0.0000, 1.3863, 0.6931, 0.6931, 0.6931, 0.0000, 0.6931, 0.6931,
         1.0986, 0.0000, 0.6931, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0986,
         0.0000, 0.0000, 0.0000, 0.6931, 0.0000, 0.6931, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0