In [1]:
import torch
import torch.nn as nn
from typing_extensions import Literal
import torch.nn.functional as F
from typing import Optional
import numpy as np
from scipy.spatial.distance import cdist
from scvi.nn import Encoder, FCLayers
import scvelo as scv
import scanpy as sc
import pandas as pd
import sctour as rgv

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: '/Users/weixu.wang/miniconda3/envs/RegVelo/lib/python3.10/site-packages/torchvision/image.so'
  warn(
Global seed set to 0


In [2]:
scv.set_figure_params()

In [3]:
adata = sc.read_h5ad("dataset_branch_v2.h5ad")

In [4]:
adata

AnnData object with n_obs × n_vars = 1000 × 280
    obs: 'step_ix', 'simulation_i', 'sim_time'
    var: 'module_id', 'basal', 'burn', 'independence', 'color', 'is_tf', 'is_hk', 'transcription_rate', 'splicing_rate', 'translation_rate', 'mrna_halflife', 'protein_halflife', 'mrna_decay_rate', 'protein_decay_rate', 'max_premrna', 'max_mrna', 'max_protein', 'mol_premrna', 'mol_mrna', 'mol_protein'
    uns: 'network', 'regulators', 'skeleton', 'targets', 'traj_dimred_segments', 'traj_milestone_network', 'traj_progressions'
    obsm: 'dimred'
    layers: 'counts_protein', 'counts_spliced', 'counts_unspliced', 'logcounts', 'rna_velocity'

In [5]:
def sanity_check(adata):
    reg_index = [i in adata.var.index.values for i in adata.uns["regulators"]]
    tar_index = [i in adata.var.index.values for i in adata.uns["targets"]]
    adata.uns["regulators"] = adata.uns["regulators"][reg_index]
    adata.uns["targets"] = adata.uns["targets"][tar_index]
    W = adata.uns["skeleton"]
    W = W[reg_index,:]
    W = W[:,tar_index]
    adata.uns["skeleton"] = W
    W = adata.uns["network"]
    W = W[reg_index,:]
    W = W[:,tar_index]
    adata.uns["network"] = W
    
    regulators = adata.uns["regulators"][adata.uns["skeleton"].sum(1) > 0]
    targets = adata.uns["targets"][adata.uns["skeleton"].sum(0) > 0]

    W = pd.DataFrame(adata.uns["skeleton"],index = adata.uns["regulators"],columns = adata.uns["targets"])
    W = W.loc[regulators,targets]
    adata.uns["skeleton"] = W
    W = pd.DataFrame(adata.uns["network"],index = adata.uns["regulators"],columns = adata.uns["targets"])
    W = W.loc[regulators,targets]
    adata.uns["network"] = W
    
    adata.uns["regulators"] = regulators
    adata.uns["targets"] = targets
    
    adata = adata[:,np.unique(adata.uns["regulators"].tolist()+adata.uns["targets"].tolist())].copy()
    
    return adata

In [6]:
adata.uns["skeleton"].shape

(280, 280)

In [7]:
adata.X = adata.X.copy()
adata.layers["spliced"] = adata.layers["counts_spliced"].copy()
adata.layers["unspliced"] = adata.layers["counts_unspliced"].copy()

In [8]:
scv.pp.filter_and_normalize(adata, min_shared_counts=5, n_top_genes=280)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)

Filtered out 14 genes that are detected 5 counts (shared).
Normalized count data: X, spliced, unspliced.
Skip filtering by dispersion since number of variables are less than `n_top_genes`.
computing neighbors
    finished (0:00:02) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)


In [9]:
adata.layers["Mu"].max()

1086.6066

In [10]:
adata.X = np.log1p(adata.X.copy())

In [11]:
adata = sanity_check(adata)

In [12]:
adata.uns["skeleton"].shape

(135, 250)

In [13]:
adata.layers["Ms"].max()

39747.324

In [14]:
W = adata.uns["skeleton"].copy()

In [15]:
import torch
W = torch.tensor(np.array(W)).int()

In [16]:
W.shape

torch.Size([135, 250])

In [17]:
rgv_m = rgv.train.Trainer(adata, W=W.T,early_stopping = False, nepoch = 200, solver = "AdaBelief", lr = 0.01,wt_decay=0.01,T_max=300,batch_size=128,grad_clip = 5,alpha_recon_reg = 1)

In [210]:
rgv_m.train()

[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief


Epoch 200: 100%|████████████████████████████████| 200/200 [12:19<00:00,  3.70s/epoch, train_loss=1.83e+7, val_loss=1.62e+7]


In [18]:
class alpha_encoder(nn.Module):
    """ 
    encode the time dependent alpha (f)
    time dependent transcription rate is determined by upstream emulator

    """                 
    def __init__(
        self,
        n_int: int = 5,
        alpha_unconstr_init: torch.Tensor = None,
        W: torch.Tensor = (torch.FloatTensor(5, 5).uniform_() > 0.5).int(),
        W_int: torch.Tensor = None,
        log_h_int: torch.Tensor = None,
        global_time: bool = False,
    ):
        device = W.device
        super().__init__()
        self.n_int = n_int
        if global_time:
            self.log_h = torch.nn.Parameter(torch.randn(W.shape[1]))
            self.log_phi = torch.nn.Parameter(torch.randn(W.shape[1]))
            self.tau = torch.nn.Parameter(torch.randn(W.shape[1]))
            self.o = torch.nn.Parameter(torch.randn(W.shape[1]))
        else:
            self.log_h = torch.nn.Parameter(log_h_int.repeat(W.shape[0],1)*W)
            self.log_phi = torch.nn.Parameter(torch.ones(W.shape).to(device)*W)
            self.tau = torch.nn.Parameter(torch.ones(W.shape).to(device)*W*10)
            self.o = torch.nn.Parameter(torch.ones(W.shape).to(device)*W)

        self.mask_m = W
        self.global_time = global_time

        ## initialize grn
        self.grn = torch.nn.Parameter(W_int*self.mask_m)
        
        ## initilize gamma and beta
        self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_int))
        self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_int))
        self.alpha_unconstr_bias = torch.nn.Parameter(torch.zeros(n_int))
        self.alpha_unconstr_max = torch.nn.Parameter(torch.randn(n_int))
        # calculating emulating matrix
    
    ### define hook to froze the parameters
    def _set_mask_grad(self):
        self.hooks_grn = []
        if not self.global_time:
            self.hooks_log_h = []
            self.hooks_log_phi = []
            self.hooks_tau = []
            self.hooks_o = []
        #mask_m = self.mask_m
        
        def _hook_mask_no_regulator(grad):
            return grad * self.mask_m

        w_grn = self.grn.register_hook(_hook_mask_no_regulator)
        self.hooks_grn.append(w_grn)
        if not self.global_time:
            w_log_h = self.log_h.register_hook(_hook_mask_no_regulator)
            w_log_phi = self.log_phi.register_hook(_hook_mask_no_regulator)
            w_tau = self.tau.register_hook(_hook_mask_no_regulator)
            w_o = self.o.register_hook(_hook_mask_no_regulator)

            self.hooks_log_h.append(w_log_h)
            self.hooks_log_phi.append(w_log_phi)
            self.hooks_tau.append(w_tau)
            self.hooks_o.append(w_o)

    def emulator(t,log_h_v,log_phi_v,tau_v,o_v):
        pre = torch.exp(log_h_v)*torch.exp(-torch.exp(phi_v)*(t-tau_v)**2)+o_v
        
        return pre

    def emulation_all(self,t: torch.Tensor = None):
        if self.global_time:
            # broadcasting the time t
            t = t.repeat((self.mask_m.shape[0],1))

        emulate_m = torch.zeros([self.mask_m.shape[0], self.mask_m.shape[1], t.shape[1]])

        h = torch.exp(self.log_h)
        phi = torch.exp(self.log_phi)
        for i in range(t.shape[1]):
            # for each time stamps, predict the emulator predict value
            tt = t[:,i]
            emu = h * torch.exp(-phi*(tt.reshape((len(tt),1))-self.tau)**2) + self.o
            emulate_m[:,:,i] = emu

        return emulate_m


    def forward(self,t,g):
        ## Encode 

        if self.global_time:
            u = u[locate]
            s = s[locate]
            ## when use global time, t is a single value
            T = t.repeat((dim,1))

            ## calculate emulator vector
            h = torch.exp(self.log_h)
            phi = torch.exp(self.log_phi)
            emu = h[locate,:] * torch.exp(-phi[locate,:]*(T.reshape((dim,1))-self.tau[locate,:])**2) + self.o[locate,:]
            
            ## Use the Emulator matrix to predict alpha
            emu = emu * self.grn[locate,:]
            alpha_unconstr = emu.sum(dim=1)
            alpha_unconstr = alpha_unconstr + self.alpha_unconstr[locate]
            
            ## Generate kinetic rate
            beta = torch.clamp(F.softplus(self.beta_mean_unconstr[locate]), 0, 50)
            gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr[locate]), 0, 50)
            alpha = torch.clamp(F.softplus(alpha_unconstr),0,50)

            ## Predict velocity
            du = alpha - beta*u
            ds = beta*u - gamma*s

            du = du.reshape((dim,1))
            ds = ds.reshape((dim,1))

            v = torch.concatenate([du,ds],axis = 1)

        else:
            ## calculate emulator value
            ## output the f_g(t)
            
            ## Build Emulator matrix for gene g
            
            h = torch.exp(self.log_h)[g,:].view(-1)
            phi = torch.exp(self.log_phi)[g,:].view(-1)
            tau = self.tau[g,:].view(-1)
            o = self.o[g,:].view(-1)
            w = self.grn[g,:].view(-1)
            bias = self.alpha_unconstr_bias[g]

            #emu = h[locate,:] * torch.exp(-phi[locate,:]*(T.reshape((dim,1))-self.tau[locate,:])**2) + self.o[locate,:]
            emu = h * torch.exp(-phi*(t - tau)**2) + o

            ## Use the Emulator matrix to predict alpha
            #emu = emu * self.grn[locate,:]
            emu = emu * w
            
            alpha_unconstr = emu.sum()
            #alpha_unconstr = alpha_unconstr + self.alpha_unconstr_bias[locate]
            alpha_unconstr = alpha_unconstr + bias

            ## Generate transcription kinetic rate for time t
            alpha = torch.clamp(alpha_unconstr,0,)
            alpha = F.softsign(alpha)

        return alpha
    
def SolveInitialValueProblem(f_t, x0, t0, t_eval):
    
    ## generate the prediction of unspliced/spliced readout at time t for every gene
    ## use torchquad integral, different with torchode, the t_eval no longer need to be ordered
    
    ## get the kinetic parameters
    beta = torch.clamp(F.softplus(f_t.beta_mean_unconstr), 0, 50)
    gamma = torch.clamp(F.softplus(f_t.gamma_mean_unconstr), 0, 50)
    alpha_max = torch.clamp(F.softplus(f_t.alpha_unconstr_max),0,50)

    ## define integral function
    def integral_alpha_beta(f_t, tt, t0, beta, g):
        f_i = lambda t: f_t(t,g)*torch.exp(beta[g]*t)
        integration_domain = [[t0[g],tt]]
        result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
        return result
        
    def integral_alpha_gamma(f_t, tt,t0, gamma, g):
        f_i = lambda t: f_t(t,g)*torch.exp(gamma[g]*t)
        integration_domain = [[t0[g],tt]]
        result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
        return result

    ## get the initial condition (i.e. u,s = 0,0)
    u0 = x0[:,0].view(-1)
    s0 = x0[:,1].view(-1)
    pre_u = torch.zeros(t_eval.shape)
    pre_s = torch.zeros(t_eval.shape)
    
    ## build for loop to generate readout for each targets
    for g, t in enumerate(t_eval):
        u0g = u0[g]
        s0g = s0[g]
        
        ## calculate integral for gene g
        integral_tensor_alpha_beta = torch.tensor(list(map(lambda tt: integral_alpha_beta(tt=tt,f_t=f_t,t0=t0,beta = beta,g = g), t)))
        integral_tensor_alpha_gamma = torch.tensor(list(map(lambda tt: integral_alpha_gamma(tt=tt,f_t=f_t,t0=t0,gamma = gamma,g = g), t)))
        
        ug = u0g*torch.exp(-beta[g]*t) + alpha_max[g]*torch.exp(-beta[g]*t)*integral_tensor_alpha_beta
        sg = s0g*torch.exp(-gamma[g]*t) + \
            ( (alpha_max[g]*beta[g])/(gamma[g] - beta[g]) )*(torch.exp(-beta[g]*t)*integral_tensor_alpha_beta - torch.exp(-gamma[g]*t)*integral_tensor_alpha_gamma) + \
            ( (beta[g]*u0g)/(gamma[g] - beta[g]) )*(torch.exp(-beta[g]*t) - torch.exp(-gamma[g]*t))
        
        pre_u[g,:] = ug
        pre_s[g,:] = sg
        #print(g)
    return pre_u, pre_s

In [71]:
def SolveInitialValueProblem(t_eval,f_t, x0, t0):
    
    ## generate the prediction of unspliced/spliced readout at time t for every gene
    ## use torchquad integral, different with torchode, the t_eval no longer need to be ordered
    
    ## get the kinetic parameters
    g = int(t_eval[0])
    t = t_eval[1:]
    beta = torch.clamp(F.softplus(f_t.beta_mean_unconstr), 0, 50)
    gamma = torch.clamp(F.softplus(f_t.gamma_mean_unconstr), 0, 50)
    alpha_max = torch.clamp(F.softplus(f_t.alpha_unconstr_max),0,50)

    ## define integral function
    def integral_alpha_beta(f_t, tt, t0, beta, g):
        f_i = lambda t: f_t(t,g)*torch.exp(beta[g]*t)
        integration_domain = [[t0[g],tt]]
        result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
        return result
        
    def integral_alpha_gamma(f_t, tt,t0, gamma, g):
        f_i = lambda t: f_t(t,g)*torch.exp(gamma[g]*t)
        integration_domain = [[t0[g],tt]]
        result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
        return result

    ## get the initial condition (i.e. u,s = 0,0)
    u0 = x0[:,0].view(-1)
    s0 = x0[:,1].view(-1)
    pre_u = torch.zeros(t_eval.shape)
    pre_s = torch.zeros(t_eval.shape)
    
    ## build for loop to generate readout for each targets
    ## calculate integral for gene g
    u0g = u0[g]
    s0g = s0[g]
    
    integral_tensor_alpha_beta = torch.tensor(list(map(lambda tt: integral_alpha_beta(tt=tt,f_t=f_t,t0=t0,beta = beta,g = g), t)))
    integral_tensor_alpha_gamma = torch.tensor(list(map(lambda tt: integral_alpha_gamma(tt=tt,f_t=f_t,t0=t0,gamma = gamma,g = g), t)))

    ug = u0g*torch.exp(-beta[g]*t) + alpha_max[g]*torch.exp(-beta[g]*t)*integral_tensor_alpha_beta
    sg = s0g*torch.exp(-gamma[g]*t) + \
        ( (alpha_max[g]*beta[g])/(gamma[g] - beta[g]) )*(torch.exp(-beta[g]*t)*integral_tensor_alpha_beta - torch.exp(-gamma[g]*t)*integral_tensor_alpha_gamma) + \
        ( (beta[g]*u0g)/(gamma[g] - beta[g]) )*(torch.exp(-beta[g]*t) - torch.exp(-gamma[g]*t))
    
    pre = torch.cat([ug,sg])
    #print(g)
    return pre

In [35]:
from torchquad import set_up_backend  # Necessary to enable GPU support
from torchquad import Trapezoid, Simpson, Boole, MonteCarlo, VEGAS # The available integrators
from torchquad.utils.set_precision import set_precision
import torchquad
simp = Simpson()

indices = torch.arange(t.shape[0])
indices = indices.repeat_interleave(t.shape[1])
t_all = torch.cat((indices.reshape(-1,1),t.reshape(t.shape[0]*t.shape[1],1)),1)

## define integral function
def integral_alpha_beta(v,f_t, t0, beta):
    g = int(v[0])
    tt = v[1]
    f_i = lambda t: f_t(t,g)*torch.exp(beta[g]*t)
    integration_domain = [[t0[g],tt]]
    result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
    return result

def integral_alpha_gamma(v,f_t,t0, gamma):
    g = int(v[0])
    tt = v[1]
    f_i = lambda t: f_t(t,g)*torch.exp(gamma[g]*t)
    integration_domain = [[t0[g],tt]]
    result = simp.integrate(f_i, dim=1, N=101, integration_domain=integration_domain)
    return result

def sum_all(v):
    return int(v.min())

In [72]:
indices = torch.arange(t.shape[0])
t_all = torch.cat((indices.reshape(-1,1),t),1)
partial_func = functools.partial(SolveInitialValueProblem, f_t=f_t,x0 = x0, t0 = t0)

In [77]:
len(partial_func(t_all[0,:]))

256

In [75]:
a = list(map(lambda t_eval: partial_func(t_eval), t_all))

In [76]:
t_all.shape

torch.Size([250, 129])

In [78]:
import torch
from concurrent.futures import ThreadPoolExecutor

starttime = datetime.datetime.now()
output_tensor = torch.randn((250,256))

with ThreadPoolExecutor(max_workers=3) as executor:
    for i, output_row in enumerate(executor.map(partial_func, t_all)):
        output_tensor[i] = output_row
        
endtime = datetime.datetime.now()
print(endtime - starttime)

0:00:23.494697


In [28]:
import functools
partial_func = functools.partial(integral_alpha_beta, f_t=f_t, t0 = t0, beta = beta)

In [29]:
import datetime
starttime = datetime.datetime.now()
datetime.datetime.now()

datetime.datetime(2023, 3, 23, 18, 21, 13, 822388)

In [43]:
import torch
from concurrent.futures import ThreadPoolExecutor

starttime = datetime.datetime.now()
input_tensor = torch.randn((t_all.shape[0],1))
output_tensor = torch.empty_like(input_tensor)
partial_func = functools.partial(integral_alpha_beta, f_t=f_t, t0 = t0, beta = beta)
with ThreadPoolExecutor(max_workers=3) as executor:
    for i, output_row in enumerate(executor.map(partial_func, t_all)):
        output_tensor[i] = output_row
        
endtime = datetime.datetime.now()
print(endtime - starttime)

0:00:13.365531


In [37]:
starttime = datetime.datetime.now()
a = torch.tensor(list(map(lambda v: integral_alpha_beta(v=v,f_t = f_t,t0=t0,beta = beta), t_all)))
b = torch.tensor(list(map(lambda v: integral_alpha_beta(v=v,f_t = f_t,t0=t0,beta = beta), t_all)))
endtime = datetime.datetime.now()
print(endtime - starttime)

0:00:33.851643


In [289]:
integral_alpha_beta(t_all[0,:],f_t = f_t,t0 = t0, beta = beta)

tensor(467.4549, grad_fn=<SumBackward1>)

In [259]:
t_all[0,:].sum()

tensor([ 6.2917, 13.5292, 11.5073,  ..., 13.7434, 14.8046, 14.8780])

In [229]:
x = torch.arange(t.shape[0])
x = x.repeat_interleave(t.shape[1])
torch.cat((x.reshape(-1,1),t.reshape(t.shape[0]*t.shape[1],1)),1)

tensor([[  0.0000,   6.2917],
        [  0.0000,  13.5292],
        [  0.0000,  11.5073],
        ...,
        [249.0000,  13.7434],
        [249.0000,  14.8046],
        [249.0000,  14.8780]])

In [186]:
result2

tensor([[ 0.3027,  0.1420,  0.7881, -0.9004, -0.3324],
        [-0.8600, -0.1008,  1.7802, -0.9399,  0.1204],
        [ 1.2766,  0.2640, -1.9085,  0.5039, -0.1361]])

In [22]:
f_t = alpha_encoder(n_int = rgv_m.model.n_targets, alpha_unconstr_init = rgv_m.model.v_encoder.alpha_unconstr_bias,log_h_int = rgv_m.model.v_encoder.log_h[0,:],
                    W = W.T, W_int = rgv_m.model.v_encoder.grn)

In [20]:
m = nn.Sigmoid()
t = m(torch.randn(250,128))

In [21]:
t = t*20

In [23]:
x0 = torch.zeros((250,2))
t0 = torch.zeros((250))

In [109]:
t0

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [125]:
t_pre = t[0,:]
t_pre

tensor([ 6.2917, 13.5292, 11.5073,  6.2158, 11.7081,  7.6167, 10.6023, 14.8237,
         2.7619,  9.0510, 10.1462,  7.9613, 18.2336,  5.5493,  3.8115,  7.9946,
        11.6737,  8.7470, 14.3674,  4.1293, 12.6021, 11.3531,  9.9593, 13.2177,
        11.7522,  5.0506, 14.9015,  5.1413, 13.6931,  8.8733,  9.1672,  4.3102,
        10.7343, 13.2622, 17.3496,  6.5893,  8.6659,  4.7248, 14.5547, 10.4318,
         5.8495, 16.8982,  8.1507,  7.0760,  9.2485, 11.7672, 10.9677, 15.8301,
        15.8142,  2.8028,  8.7850,  5.1704, 14.3562,  5.5317, 10.6229,  4.0864,
         2.2072,  8.5267,  2.2765, 10.4272, 16.4707,  9.5146, 16.3261,  8.7235,
         2.5453, 11.2223, 15.9432,  3.7466,  5.5299, 12.8501, 15.6438,  6.8605,
         5.8711,  3.9872,  2.9348,  3.4938,  3.8725, 10.9450,  7.5372,  9.4815,
        16.3129, 17.6023,  6.9619, 16.5306, 15.2938,  8.3910,  9.2520, 10.5136,
        14.5773, 10.6170, 12.6587, 11.8950,  9.0369, 12.1820, 13.9930,  7.3579,
         1.0173,  1.9698,  2.9729, 14.06

In [38]:
starttime = datetime.datetime.now()
a,b = SolveInitialValueProblem(f_t = f_t, x0 = x0,t0 = t0, t_eval = t)
endtime = datetime.datetime.now()
print(endtime - starttime)

0:00:29.881662


In [39]:
a

tensor([[0.4227, 0.4186, 0.4227,  ..., 0.4227, 0.4227, 0.2320],
        [1.9704, 1.9611, 1.9658,  ..., 1.9831, 1.9704, 1.8495],
        [0.7740, 0.7740, 0.7740,  ..., 0.7504, 0.7503, 0.7740],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.6349, 0.6529, 0.6465,  ..., 0.6463, 0.6527, 0.6512],
        [0.4001, 0.3839, 0.3989,  ..., 0.4006, 0.3969, 0.4000]],
       grad_fn=<CopySlices>)

In [40]:
b

tensor([[1.2948, 0.9760, 1.2915,  ..., 1.2651, 1.3062, 0.0964],
        [5.4182, 4.4900, 4.7838,  ..., 5.8561, 5.4047, 2.6192],
        [2.3793, 2.3892, 2.3487,  ..., 1.3490, 1.3474, 2.3738],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.2124, 1.9342, 1.8324,  ..., 1.7767, 1.9468, 2.0018],
        [1.2203, 0.6343, 1.1521,  ..., 1.1820, 0.9574, 1.2271]],
       grad_fn=<CopySlices>)

In [26]:
beta = torch.clamp(F.softplus(f_t.beta_mean_unconstr), 0, 50)

In [27]:
beta

tensor([0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741, 0.9741,
        0.9741, 0.9741, 0.9741, 0.9741, 