# Reformat Model

Reformat the "best" model so that it takes matrices of beta, b, and gamma as input.

In [None]:
# # mount to drive and change directory
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/scBIVI_mc/scBIVI/scBIVI/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/scBIVI_mc/scBIVI/scBIVI


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append('../')
import matplotlib.pyplot as plt
from scipy import stats

## Define model 

In [None]:
class MLP(nn.Module):
    """Two-hidden-layer perceptron predicting kernel weights and a width hyperparameter.

    Maps each input row of length ``input_dim`` to ``npdf`` softmax-normalised
    kernel weights plus one scalar squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: number of input features per row.
        npdf: number of kernel functions (width of the weight output).
        h1_dim: width of the first hidden layer.
        h2_dim: width of the second hidden layer.
    """

    def __init__(self, input_dim, npdf, h1_dim, h2_dim):
        super().__init__()

        self.input = nn.Linear(input_dim, h1_dim)
        self.hidden = nn.Linear(h1_dim, h2_dim)
        self.output = nn.Linear(h2_dim, npdf)

        # The hyperparameter head is applied to the *second* hidden layer, so
        # its input width must be h2_dim. (Using h1_dim only worked when
        # h1_dim == h2_dim; for the 256/256 checkpoint the shape is unchanged.)
        self.hyp = nn.Linear(h2_dim, 1)

        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = torch.sigmoid

    def forward(self, inputs):
        """Return ``(w_pred, hyp)`` for a 2-D batch ``inputs`` of shape (rows, input_dim).

        ``w_pred`` has shape (rows, npdf) and each row sums to 1 (softmax over
        dim 1). ``hyp`` has shape (rows, 1) with values in (0, 1).
        """
        # two sigmoid-activated hidden layers
        l_1 = self.sigmoid(self.input(inputs))
        l_2 = self.sigmoid(self.hidden(l_1))

        # unnormalised kernel scores -> weights via softmax
        w_pred = self.softmax(self.output(l_2))

        # sigmoid keeps the hyperparameter in (0, 1); get_ypred_at_RT later
        # rescales it (hyp * 5 + 1) to the range (1, 6)
        hyp = self.sigmoid(self.hyp(l_2))

        return w_pred, hyp
        

model_path = './models/best_model_MODEL'
npdf = 10

# Pick the device at runtime instead of assuming a GPU. map_location lets a
# checkpoint (possibly saved from a GPU session) load on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load in model: 7 input features, npdf kernel weights, two 256-unit hidden layers
model = MLP(7, npdf, 256, 256)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
# NOTE(review): NORM and the tensors built in get_ypred_at_RT live on the
# default (CPU) device; on a GPU runtime confirm the model inputs are moved
# to `device` as well, or keep the model on CPU.
model.to(device)

RuntimeError: ignored

In [None]:
def get_NORM(npdf, quantiles='cheb'):
    """Return standard-normal quantiles used to place ``npdf`` kernel functions.

    Args:
        npdf: number of kernel functions, i.e. number of quantiles returned.
        quantiles: ``'cheb'`` (default) for Chebyshev nodes mapped onto (0, 1),
            or ``'lin'`` for evenly spaced probabilities excluding 0 and 1.

    Returns:
        1-D float64 torch tensor of length ``npdf`` holding ``Phi^{-1}(q)`` for
        each probability ``q``, in increasing order.

    Raises:
        ValueError: if ``quantiles`` is neither ``'lin'`` nor ``'cheb'``.
    """
    if quantiles == 'lin':
        # npdf interior points of an evenly spaced grid on [0, 1]
        q = np.linspace(0, 1, npdf + 2)[1:-1]
    elif quantiles == 'cheb':
        # Chebyshev nodes cos((2k+1)*pi/(2*npdf)) mapped from [-1, 1] onto
        # [0, 1]; flipped so the probabilities are increasing.
        k = np.arange(npdf)
        q = np.flip((np.cos((2 * (k + 1) - 1) / (2 * npdf) * np.pi) + 1) / 2)
    else:
        raise ValueError(f"quantiles must be 'lin' or 'cheb', got {quantiles!r}")

    return torch.tensor(stats.norm.ppf(q))


NORM = get_NORM(10)

In [None]:
def generate_grid(logmean_cond, logstd_cond, NORM):
    """Return the grid of kernel means for each row of a conditional distribution.

    Each row's kernel means are ``exp(logmean + logstd * NORM)``: the quantiles
    in ``NORM`` scaled and shifted by that row's conditional log-moments.

    Args:
        logmean_cond: conditional log-means (flattened to a column).
        logstd_cond: conditional log-standard-deviations (flattened to a column).
        NORM: 1-D tensor of standard-normal quantiles, one per kernel.

    Returns:
        Tensor of shape ``(rows, len(NORM))``.
    """
    logmean_cond = torch.reshape(logmean_cond, (-1, 1))
    logstd_cond = torch.reshape(logstd_cond, (-1, 1))
    return torch.exp(torch.add(logmean_cond, logstd_cond * NORM))


def get_ypred_at_RT(p, w, hyp, n, m, NORM, eps=1e-8):
    """Evaluate the kernel-mixture approximation of P(m | n) row by row.

    Each row is a weighted mixture of ``len(NORM)`` kernels. The kernel means come
    from ``generate_grid``. A kernel whose variance exceeds its mean is a negative
    binomial; otherwise it falls back to a Poisson.

    Args:
        p: parameter matrix; columns 3 and 4 hold the conditional log-mean and
            log-std (other columns are not read here).
        w: kernel weights with one column per kernel, shape ``(rows, len(NORM))``.
        hyp: raw width hyperparameter in (0, 1), shape ``(rows, 1)``; rescaled
            here to (1, 6).
        n: nascent counts. Not used in the computation; kept for interface
            compatibility.
        m: mature counts as a column, shape ``(rows, 1)``.
        NORM: standard-normal quantiles that position the kernels.
        eps: added inside ``log`` to avoid ``log(0)``.

    Returns:
        1-D tensor of approximate probabilities (one per row), floored at 1e-40.
    """
    logmean_cond = p[:, 3]
    logstd_cond = p[:, 4]

    # map the network's sigmoid output from (0, 1) onto (1, 6)
    hyp = hyp * 5 + 1

    grid = generate_grid(logmean_cond, logstd_cond, NORM)

    # Kernel standard deviations: the spacing to the next kernel mean scaled by
    # hyp, except the last kernel, which gets sqrt(mean) (Poisson spread).
    # Built out-of-place so the width follows len(NORM) rather than a
    # hard-coded 10.
    s = torch.cat((torch.diff(grid, dim=1) * hyp, torch.sqrt(grid[:, -1:])), dim=1)

    v = s ** 2
    r = grid ** 2 / (v - grid)   # NB size parameter from mean/variance
    p_nb = 1 - grid / v          # > 0 only when the kernel is overdispersed

    # Poisson log-pmf of m under every kernel mean
    y_ = m * torch.log(grid + eps) - grid - torch.lgamma(m + 1)

    # For overdispersed kernels add the NB-minus-Poisson log correction:
    #   lgamma(m+r) - lgamma(r) - m*log(r+mu) + mu + r*log(r/(r+mu)).
    # The observed count m (not the kernel mean) belongs in the first and
    # third terms; using the mean would just rescale the Poisson pmf.
    nb = p_nb > 1e-10
    if nb.any():
        m_b = m.expand_as(grid)[nb]
        g = grid[nb]
        r_b = r[nb]
        y_[nb] += (torch.lgamma(m_b + r_b) - torch.lgamma(r_b)
                   - m_b * torch.log(r_b + g) + g
                   + r_b * torch.log(r_b / (r_b + g)))

    Y = (w * torch.exp(y_)).sum(dim=1)

    EPS = 1e-40
    return torch.clamp(Y, min=EPS)

In [None]:
def _theta_to_params(mu1, mu2, theta, THETA_IS):
    """Convert (mu1, mu2, theta) into the bursty-model parameters (b, beta, gamma).

    ``THETA_IS`` selects what ``theta`` encodes: ``'MAT_SHAPE'`` -> 1/gamma,
    ``'B'`` -> b, ``'NAS_SHAPE'`` -> 1/beta. Raises ValueError otherwise.
    """
    if THETA_IS == 'MAT_SHAPE':
        gamma = 1 / theta
        b = mu2 * gamma
        beta = b / mu1
    elif THETA_IS == 'B':
        b = theta
        beta = b / mu1
        gamma = b / mu2
    elif THETA_IS == 'NAS_SHAPE':
        beta = 1 / theta
        b = mu1 * beta
        gamma = b / mu2
    else:
        raise ValueError(
            f"THETA_IS must be 'MAT_SHAPE', 'B' or 'NAS_SHAPE', got {THETA_IS!r}")
    return b, beta, gamma


def _nascent_nb_pmf(n, b, beta):
    """Negative-binomial P(n) with size 1/beta and success probability 1/(b+1).

    Same quantity as ``scipy.stats.nbinom.pmf(n, 1/beta, 1/(b+1))``, computed in
    torch so it keeps autograd and does not require CPU/numpy conversion.
    """
    r = 1 / beta
    log_pmf = (torch.lgamma(n + r) - torch.lgamma(r) - torch.lgamma(n + 1)
               - r * torch.log1p(b)               # r * log(1/(1+b))
               + torch.xlogy(n, b / (1 + b)))     # n * log(b/(1+b)), 0 when n == 0
    return torch.exp(log_pmf)


def log_prob_nnNB(x: torch.Tensor, mu1: torch.Tensor, mu2: torch.Tensor,
                  theta: torch.Tensor, THETA_IS, eps=1e-8, **kwargs):
    '''Probability of (nascent, mature) counts under the bursty model via the NN approximation.

    P(n, m) = P(n) * P(m | n). P(n) is the nascent negative-binomial marginal;
    P(m | n) comes from the kernel mixture whose weights are predicted by the
    module-level ``model`` and evaluated by ``get_ypred_at_RT``.

    Args:
        x: counts; the last dimension is split in half into nascent (n) and
            mature (m) counts.
        mu1, mu2: nascent / mature means (positive), shape minibatch x vars/2.
        theta: shape-type parameter (positive); its meaning is set by ``THETA_IS``.
        THETA_IS: one of ``'MAT_SHAPE'``, ``'B'``, ``'NAS_SHAPE'``.
        eps: numerical stability constant (currently unused).
        **kwargs: ignored.

    Returns:
        Tensor of probabilities (not log-probabilities) with the shape of n.
        NOTE(review): the name suggests a log-probability; callers expecting
        one should take ``torch.log`` of the result — confirm intent.

    Raises:
        ValueError: if ``THETA_IS`` is not recognised.
    '''
    # divide the data x into nascent (n) and mature (m) halves
    n, m = torch.chunk(x, 2, dim=-1)

    b, beta, gamma = _theta_to_params(mu1, mu2, theta, THETA_IS)

    # nascent marginal negative binomial P(n)
    prob_nascent = _nascent_nb_pmf(n, b, beta)

    # moments of the bursty model
    var1 = mu1 * (1 + b)
    var2 = mu2 * (1 + b * beta / (beta + gamma))
    cov = b ** 2 / (beta + gamma)

    # log-scale (lognormal moment-matched) moments
    logvar1 = torch.log(var1 / mu1 ** 2 + 1)
    logvar2 = torch.log(var2 / mu2 ** 2 + 1)
    logstd1 = torch.sqrt(logvar1)
    logstd2 = torch.sqrt(logvar2)
    logmean1 = torch.log(mu1 ** 2 / torch.sqrt(var1 + mu1 ** 2))
    logmean2 = torch.log(mu2 ** 2 / torch.sqrt(var2 + mu2 ** 2))

    # floor the exponent at -88, presumably so exp(-val) stays within float32 range
    val = torch.clamp(logmean1 + logmean2 + (logvar1 + logvar2) / 2, min=-88)
    logcov = torch.log(cov * torch.exp(-val) + 1)
    logcorr = logcov / torch.sqrt(logvar1 * logvar2)

    # conditional log-moments of m given n
    logmean_cond = logmean2 + logcorr * logstd2 / logstd1 * (torch.log(n + 1) - logmean1)
    logstd_cond = logstd2 * torch.sqrt(1 - logcorr ** 2)

    xmax_m = torch.ceil(torch.ceil(mu2) + 4 * torch.sqrt(var2))
    xmax_m = torch.clip(xmax_m, 30, np.inf).int()

    # one row of network features per (cell, gene) entry
    pv = torch.column_stack((torch.log10(b).reshape(-1),
                             torch.log10(beta).reshape(-1),
                             torch.log10(gamma).reshape(-1),
                             logmean_cond.reshape(-1),
                             logstd_cond.reshape(-1),
                             xmax_m.reshape(-1),
                             n.reshape(-1)
                             ))
    # run through model
    w_, hyp_ = model(pv)

    # conditional probabilities P(m | n)
    ypred_cond = get_ypred_at_RT(pv, w_, hyp_, n.reshape(-1, 1), m.reshape(-1, 1), NORM)

    # multiply conditionals P(m | n) by P(n)
    predicted = prob_nascent * ypred_cond.reshape(prob_nascent.shape)

    return predicted

NameError: ignored