In [None]:
!pip install numpyro

Collecting numpyro
  Downloading numpyro-0.18.0-py3-none-any.whl.metadata (37 kB)
Collecting multipledispatch (from numpyro)
  Downloading multipledispatch-1.0.0-py3-none-any.whl.metadata (3.8 kB)
Downloading numpyro-0.18.0-py3-none-any.whl (365 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m365.8/365.8 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multipledispatch-1.0.0-py3-none-any.whl (12 kB)
Installing collected packages: multipledispatch, numpyro
Successfully installed multipledispatch-1.0.0 numpyro-0.18.0


In [None]:
import argparse
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
import pickle
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax.random import PRNGKey
from jax import jit
import jax
import time
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, autoguide, SVI, Trace_ELBO, TraceMeanField_ELBO
from numpyro.handlers import mask
import sys
import numpy as np
from copy import deepcopy
# Keep JAX in single precision (64-bit mode disabled).
numpyro.enable_x64(False)
# Target accelerator: TPU by default, but overridable (e.g. NUMPYRO_PLATFORM=cpu or gpu)
# so the notebook can also be executed on machines without a TPU.
numpyro.set_platform(os.environ.get("NUMPYRO_PLATFORM", "tpu"))

# Non-interactive (file-only) backend; figures will not render inline.
matplotlib.use("Agg")

In [None]:
def dz_dt(state, xs, parameters, sampling_time=5.0):
    """
    One forward-Euler step of the glycemia model, shaped as a ``jax.lax.scan`` body.

    State variables (order of ``state``):
        SG: Subcutaneous glucose (passed through softplus inside the step to keep it > 0)
        X1: Insulin effect
        Ip: Plasma insulin
        S1: Subcutaneous insulin compartment 1 (driven by basal insulin)
        G: Blood glucose (passed through softplus inside the step to keep it > 0)
        S2: Subcutaneous insulin compartment 2 (driven by S1 and the bolus appearance Sa)
        E1, E2: Energy expenditure compartments 1, 2

    Parameters (order of ``parameters``):
        EGP: Endogenous glucose production. Unused here; the time-varying input EGPt is used instead.
        GEZI: Defines glucose steady state at zero plasma insulin.
        SI: Insulin sensitivity. Unused here; the time-varying input SIt is used instead.
        tau1: Time constant of S1.
        p2: Time constant of the insulin effect X1.
        tau2: Time constant of S2.
        taup: Time constant of plasma insulin, shortened when heart rate is elevated.
        tau0: Meal absorption time constant. Unused here (consumed by ``Ra``).
        hypo: Increased insulin effect in hypoglycemia.
        ht: Threshold for increased insulin effect in hypoglycemia, approx. 60-70 mg/dL.
        taue: Time constant of the energy expenditure compartments.
        EI: Sensitivity of energy expenditure to elevated heart rate.
        tauG: Time constant between plasma and subcutaneous glucose.
        ihr: Heart-rate scale of the increased insulin absorption; might be very arbitrary.

    Inputs (order of ``xs``, values for one time step):
        basal_insulin: Long acting insulin injections converted to constant infusion rates
            assuming 24 hours of total action.
        bolus_insulin: Unused here; bolus insulin enters through ``Sa``.
        Ra: Rate of glucose appearance from meal intakes (sum of third-order impulse responses).
        SIt: Time-varying insulin sensitivity.
        Sa: Short-acting insulin appearance in the subcutaneous space (sum of first-order
            impulse responses).
        EGPt: Time-varying endogenous glucose production.
        hr: Excess heart rate, measured from 70 BPM resting heart rate.

    Args:
        sampling_time: Euler step size (default 5.0, i.e. the 5-minute grid used by the model).

    Returns:
        ``(next_state, next_state)``: the new scan carry and the per-step output.
    """
    SG, X1, Ip, S1, G, S2, E1, E2 = state
    G = jax.nn.softplus(G)  # Constraint on G > 0
    SG = jax.nn.softplus(SG)  # Constraint on SG > 0
    # Smooth excess of G above 9*18 = 162 (presumably 9 mmol/L in mg/dL); drives the renal loss term.
    renal_threshold = jax.nn.softplus(G - 9 * 18.0)

    # Parameters (EGP, SI and tau0 are unpacked only to fix the positional layout)
    EGP, GEZI, SI, tau1, p2, tau2, taup, tau0, hypo, ht, taue, EI, tauG, ihr = parameters
    CI = 1200.0  # Insulin clearance fixed for identifiability

    # Inputs (bolus_insulin is unpacked only to fix the positional layout)
    basal_insulin, bolus_insulin, Ra, SIt, Sa, EGPt, hr = xs

    # Differential equations
    dE1dt = (-E1 + hr * EI) / taue
    dE2dt = (-E2 + E1) / taue
    dS1dt = (-S1 + basal_insulin / CI) / tau1
    dS2dt = (-S2 + S1 + Sa) / tau2
    # Elevated heart rate shortens the plasma-insulin time constant (faster absorption).
    dIpdt = (-Ip + S2) / (taup / (1.0 + hr / ihr))
    dX1dt = (-X1 + SIt * Ip) / p2
    # Below the threshold ht the insulin effect is amplified by (1 + hypo).
    dGdt = (- (GEZI + (1.0 + hypo * jnp.heaviside(ht - G, 0.0)) * (X1)) * G
            + EGPt + Ra - E2 / taue - 0.003 * renal_threshold)
    dSGdt = (G - SG) / tauG

    derivatives = jnp.stack((dSGdt, dX1dt, dIpdt, dS1dt, dGdt, dS2dt, dE1dt, dE2dt))
    next_state = state + sampling_time * derivatives

    return next_state, next_state

In [None]:
def Ra(time, meal_info, Vg, tau0):
    """
    Rate of glucose appearance from meals, as a sum of third-order impulse responses.

    Each meal contributes ``carb / Vg * 800`` times the unit-area impulse response of a
    second-order (gamma) stage with time constant ``taud`` in series with a first-order
    stage with time constant ``tau0``::

        h(t) = t*exp(-t/taud) / (taud*(taud - tau0))
               + tau0*(exp(-t/tau0) - exp(-t/taud)) / (tau0 - taud)**2     for t >= 0

    and zero before the meal starts.

    Args:
        time: 1-D time grid; only its first value, last value and length are used, so it
            is assumed to be evenly spaced.
        meal_info: ``(carb, start_times, durations, taud)``, each of length n_meals.
            ``durations`` is not used.
        Vg: Glucose distribution volume.
        tau0: Time constant of the first-order stage. The expression is singular when
            ``tau0`` equals a meal's ``taud``.

    Returns:
        Array with the same length as ``time``: appearance summed over all meals.
    """
    carb, start_times, durations, taud = meal_info  # 4 by Number of meals

    sampled_tauds = taud[:, None]
    sampled_carbs = carb[:, None]

    # (n_meals, T) elapsed time since each meal start (negative before the meal).
    elapsed = jnp.transpose(jnp.linspace(time[0] - start_times, time[-1] - start_times, time.size))
    binmap = jnp.heaviside(elapsed, 1.0)  # 0 before a meal starts, 1 from its start on
    elapsed = jnp.clip(elapsed, min=0.0)

    expbt = jnp.exp(jnp.divide(-elapsed, sampled_tauds))
    expat = jnp.exp(jnp.divide(-elapsed, tau0 * jnp.ones(sampled_tauds.shape)))

    # The t*exp(-t/taud) term needs the *signed* difference (taud - tau0): it turns negative
    # when tau0 > taud, which keeps the response non-negative with unit area for either ordering.
    term1 = jnp.divide(jnp.multiply(expbt, binmap * elapsed), (sampled_tauds - tau0) * sampled_tauds)
    term2 = jnp.divide(jnp.multiply(expat, binmap * tau0), jnp.power(tau0 - sampled_tauds, 2))
    term3 = jnp.divide(jnp.multiply(expbt, binmap * tau0), jnp.power(tau0 - sampled_tauds, 2))

    return jnp.sum(jnp.multiply(sampled_carbs / Vg * 800, term1 + term2 - term3), 0)

In [None]:
def Sa(time, insulin_info, CI=1200.0):
    """
    Appearance of short-acting (bolus) insulin in the subcutaneous space.

    Sum over boluses of first-order impulse responses: each bolus adds
    ``bolus * 1e6 / CI * exp(-t / taud) / taud`` for t >= 0 after its start, nothing before.

    Args:
        time: 1-D time grid; only its endpoints and length are used, so it is assumed
            to be evenly spaced.
        insulin_info: ``(bolus, start_times, durations, taud)``, each of length n_boluses;
            ``durations`` is ignored.
        CI: Insulin clearance dividing the scaled dose (default 1200.0).

    Returns:
        Array with the same length as ``time``.
    """
    doses, onsets, _durations, time_constants = insulin_info

    # Elapsed time since each injection, one row per bolus (negative before injection).
    elapsed = jnp.linspace(time[0] - onsets, time[-1] - onsets, time.size).T
    started = jnp.heaviside(elapsed, 1.0)
    elapsed = jnp.clip(elapsed, min=0.0)

    tau = time_constants[:, None]
    scaled_dose = doses[:, None] * 1E6 / CI
    kernel = jnp.exp(-elapsed / tau) * started / tau
    return (scaled_dose * kernel).sum(axis=0)

In [None]:
def Pa(time, exercise_info, gain=1.0):
    """
    Effect of physical activity as a sum of first-order impulse responses.

    Each event adds ``gain * exercise * exp(-t / taud) / taud`` for t >= 0 after its
    start and nothing before it.

    Args:
        time: 1-D time grid; only its endpoints and length are used, so it is assumed
            to be evenly spaced.
        exercise_info: ``(exercise, start_times, durations, taud)``, each of length
            n_events; ``durations`` is ignored.
        gain: Common multiplier applied to every event amplitude.

    Returns:
        Array with the same length as ``time``, summed over all events.
    """
    amounts, onsets, _durations, time_constants = exercise_info

    # Elapsed time since each event, one row per event (negative before the event).
    elapsed = jnp.linspace(time[0] - onsets, time[-1] - onsets, time.size).T
    started = jnp.heaviside(elapsed, 1.0)
    elapsed = jnp.clip(elapsed, min=0.0)

    tau = time_constants[:, None]
    amplitude = gain * amounts[:, None]
    kernel = jnp.exp(-elapsed / tau) * started / tau
    return (amplitude * kernel).sum(axis=0)

In [None]:
def model_cohort(sample_data, cohort_meals, cohort_boluses, all_roll, cohort_pa, y=None):
    """
    Hierarchical (cohort -> patient -> sample) NumPyro model of glucose traces.

    Samples latent initial conditions, physiological parameters (normalized to [0, 1] and
    mapped into [mp_low, mp_high]), a sinusoidal intraday insulin-sensitivity profile and
    per-event meal / bolus / physical-activity perturbations. The dynamics in ``dz_dt`` are
    then integrated with ``jax.lax.scan`` (forward Euler, 5-minute steps). The first state
    (subcutaneous glucose) is either recorded as ``y_pred`` or scored against the observed
    glucose with per-sample Gaussian noise.

    Args:
        sample_data: tuple ``(all_bg, all_basal, all_bolus, all_ts_grid, all_ts_norm, all_hr)``.
            The time-series arrays are indexed ``[patient, sample, time]``, and
            ``all_bg[:, :, 0]`` seeds both glucose states. ``all_ts_norm`` is not used.
        cohort_meals: ``[patient, meal, field]`` with field 0 = start time and
            field 1 = carbohydrate amount. Entries with amount 0 are padding and are masked.
        cohort_boluses: ``[patient, bolus, field]`` with field 0 = start time and
            field 1 = bolus amount. Zero entries are padding.
        all_roll: Shift applied to the 288-point intraday insulin-sensitivity profile, with
            leading axis = patient. NOTE(review): exact shape and units not visible here
            (presumably 5-minute steps); confirm.
        cohort_pa: ``[patient, event, field]`` with field 0 = start time and field 1 = activity
            magnitude (scaled by ``pa_coeff``). Zero entries are padding.
        y: If None, deterministic predictive sites (``y_pred``, ``SIt_pred``, ...) are
            recorded. Otherwise the observed site ``y`` is conditioned on ``all_bg``; the
            value of ``y`` itself is not used.
    """
    all_bg, all_basal, all_bolus, all_ts_grid, all_ts_norm, all_hr = sample_data
    full_shape = jnp.ones((all_bg.shape[0],all_bg.shape[1]))
    patient_shape = jnp.ones((all_bg.shape[0],))

    # Init conditions on the cohort level
    Ieff0_l3 = numpyro.sample("Ieff0_l3",dist.TruncatedNormal(low=0.0, high=0.2, loc=0.0, scale=2E-3))
    Ip0_l3 = numpyro.sample("Ip0_l3",dist.TruncatedNormal(low=0.0, high=200.0, loc=0.0, scale=400*0.2))
    Isc0_l3 = numpyro.sample("Isc0_l3",dist.TruncatedNormal(low=0.0, high=100.0, loc=0.0, scale=400*0.2))

    # Init conditions on the patient level
    Ieff0_l2 = numpyro.sample("Ieff0_l2",dist.TruncatedNormal(low=0.0*jnp.ones((all_bg.shape[0])), high=0.2*jnp.ones((all_bg.shape[0])), loc=Ieff0_l3*jnp.ones((all_bg.shape[0])), scale=2E-3*jnp.ones((all_bg.shape[0]))))
    Ip0_l2 = numpyro.sample("Ip0_l2",dist.TruncatedNormal(low=0.0*jnp.ones((all_bg.shape[0])), high=200.0*jnp.ones((all_bg.shape[0])), loc=Ip0_l3*jnp.ones((all_bg.shape[0])), scale=400*0.2*jnp.ones((all_bg.shape[0]))))
    Isc0_l2 = numpyro.sample("Isc0_l2",dist.TruncatedNormal(low=0.0*jnp.ones((all_bg.shape[0])), high=100.0*jnp.ones((all_bg.shape[0])), loc=Isc0_l3*jnp.ones((all_bg.shape[0])), scale=400*0.2*jnp.ones((all_bg.shape[0]))))

    # Init conditions on the sample level: patient-level spreads (…_l2S) first, then
    # per-sample values centred on the patient-level means
    Ieff0_l2S = numpyro.sample("Ieff0_l2S",dist.HalfNormal(scale=2E-3*jnp.ones((all_bg.shape[0]))))
    Ip0_l2S = numpyro.sample("Ip0_l2S",dist.HalfNormal(scale=400*0.2*jnp.ones((all_bg.shape[0]))))
    Isc0_l2S = numpyro.sample("Isc0_l2S",dist.HalfNormal( scale=400*0.2*jnp.ones((all_bg.shape[0]))))
    Ieff0 = numpyro.sample("Ieff0_l1",dist.TruncatedNormal(low=0.0*full_shape, high=0.2*full_shape, loc=Ieff0_l2[:,None], scale=Ieff0_l2S[:,None]))
    Ip0 = numpyro.sample("Ip0_l1",dist.TruncatedNormal(low=0.0*full_shape, high=200.0*full_shape, loc=Ip0_l2[:,None], scale=Ip0_l2S[:,None]))
    Isc0 = numpyro.sample("Isc0_l1",dist.TruncatedNormal(low=0.0*full_shape, high=100.0*full_shape, loc=Isc0_l2[:,None], scale=Isc0_l2S[:,None]))

    # Initial state in dz_dt order (SG, X1, Ip, S1, G, S2, E1, E2); both glucose states start
    # at the first observed glucose value, S1/E1/E2 start at zero
    z_init = jnp.stack([all_bg[:,:,0],Ieff0,Ip0,jnp.zeros(Ip0.shape),all_bg[:,:,0],Isc0,jnp.zeros(Ip0.shape),jnp.zeros(Ip0.shape)],axis=-1)

    # Lower and upper bounds for node C
    # EGP, GEZI, SI, tau1, p2, tau2, taup, tau0 ,hypo, ht, taue ,EI, tauG, ihr
    mp_low = jnp.array([0.25,1E-3,2.5E-5,5.0,20.0,20.0,20.0,1.0,0.0,50.0,10.0,0.0,5.0,25.0])
    mp_high = jnp.array([2.5,3E-3,80E-5,80.0,80.0,80.0,80.0,10.0,1.0,100.0,40.0,2.0,30.0,300.0])

    # Glucose distribution volume on the patient level
    VG = numpyro.sample("VG",dist.TruncatedNormal(low=40.0*patient_shape, high=300.0*patient_shape, loc=120.0*patient_shape, scale=100.0*patient_shape))

    # Node C: cohort-level normalized parameters in [0, 1]
    param_prior = 0.5*jnp.ones(mp_low.shape)
    param_prior = param_prior.at[2].set(0.25)
    model_parameters_l3 = numpyro.sample("model_parameters_l3",
                               dist.TruncatedNormal(low=jnp.zeros(mp_low.shape),
                                                    high=jnp.ones(mp_low.shape),
                                                    loc=param_prior,
                                                    scale=0.4*jnp.ones(mp_low.shape)))

    # Node CP: patient-level normalized parameters, centred on node C
    model_parameters_l2 = numpyro.sample("model_parameters_l2",
        dist.TruncatedNormal(
            low=jnp.zeros((all_bg.shape[0],mp_low.shape[0])),
            high=jnp.ones((all_bg.shape[0],mp_low.shape[0])),
            loc=jnp.ones((all_bg.shape[0],mp_low.shape[0]))*model_parameters_l3,
            scale=0.2*jnp.ones((all_bg.shape[0],mp_low.shape[0]))))
    # Map normalized [0, 1] patient parameters into the physical range [mp_low, mp_high]
    model_parameters = mp_low[None,:] + (mp_high[None,:]-mp_low[None,:])*model_parameters_l2


    # 288 evenly spaced points on [0, 1): one day on the 5-minute grid
    time_range = jnp.linspace(0,1,288,endpoint=False)
    # Intraday insulin variability on the cohort level
    phase3 = numpyro.sample("SIp_l3",dist.Uniform(low=-1.0,high=1.0))
    amp3 = numpyro.sample("SIA_l3",dist.TruncatedNormal(loc=0.5,scale=1.0,low=0.0,high=1.0))
    # Intraday insulin variability on the patient level
    phase2 = numpyro.sample("SIp_l2",dist.TruncatedNormal(low=-1.0*patient_shape,high=1.0*patient_shape,loc=phase3*patient_shape,scale=0.5*patient_shape))
    amp2 = numpyro.sample("SIA_l2",dist.TruncatedNormal(low=0.0*patient_shape,high=1.0*patient_shape,loc=amp3*patient_shape,scale=0.5*patient_shape))
    # Patient-level spreads of the sample-level EGP, SI, phase and amplitude
    EGPsS = numpyro.sample("EGPsS",dist.Exponential(rate=(1.0/(0.01))*patient_shape))
    SIsS = numpyro.sample("SIsS",dist.Exponential(rate=(1.0/(0.01))*patient_shape))
    phaseS = numpyro.sample("SIpS",dist.Exponential(rate=1.0/0.15*patient_shape))
    ampS = numpyro.sample("SIAS",dist.Exponential(rate=1.0/0.25*patient_shape))


    EGPsRi = model_parameters_l2[:,0] # Intercept of the normalized EGP on the patient level
    # Normalized insulin sensitivity on the sample level (relative spread SIsS)
    SIsRu = numpyro.sample("SIsR",dist.TruncatedNormal(low=0.0*full_shape, high=1.0*full_shape, loc=model_parameters_l2[:,[2]], scale=SIsS[:,None]*model_parameters_l2[:,[2]]))
    # Scaled insulin sensitivity on the sample level
    SIs = numpyro.deterministic("SIs",mp_low[2] + (mp_high[2]-mp_low[2])*SIsRu)

    # Correlation between SI and EGP, slope between 0 and 1
    SI_EGP = numpyro.sample("SI_EGP",dist.TruncatedNormal(low=0.0*patient_shape, high=1.0*patient_shape, loc=0.1*patient_shape, scale=1.0*patient_shape))
    EGPsRu = EGPsRi[:,None] + SI_EGP[:,None]*SIsRu
    EGPex = mp_low[0] + (mp_high[0]-mp_low[0])*EGPsRu
    # Scaled EGP on sample level (upper bound is twice mp_high[0])
    EGPs = numpyro.sample("EGPs",dist.TruncatedNormal(low=0.25*full_shape, high=2*2.5*full_shape, loc=EGPex, scale=EGPsS[:,None]*EGPex))

    # Intraday insulin variability on the sample level
    phase = numpyro.sample("SIp_l1",dist.TruncatedNormal(low=-1.0*full_shape,high=1.0*full_shape,loc=phase2[:,None]*full_shape,scale=phaseS[:,None]))
    amp = numpyro.sample("SIA_l1",dist.TruncatedNormal(low=0.0*full_shape,high=1.0*full_shape,loc=amp2[:,None]*full_shape,scale=ampS[:,None]))

    # Sinusoidal intraday insulin sensitivity variation on the sample level
    SIintra = numpyro.deterministic("SIintra",amp[:,:,None]*jnp.sin(2*jnp.pi*(time_range[None,None,:]+phase[:,:,None])))
    # Time varying insulin sensitivity on the sample level, adjusted for the starting time of the current sample
    # NOTE(review): vmap maps jnp.roll over patients only. If all_roll is [patient, sample], each
    # patient's shift vector is applied to axis 1 as successive rolls (NumPy roll semantics), so
    # every sample of that patient is rolled by the SUM of the shifts rather than its own shift.
    # Confirm the shape of all_roll; a per-sample shift would need a second (inner) vmap.
    SIt =  numpyro.deterministic("SIt",jax.vmap(jnp.roll,(0,0,None))(SIs[:,:,None]*(1.0+0.25*SIintra),-all_roll,1))
    EGPt = EGPs[:,:,None]*jnp.ones(SIt.shape)

    # Meal absorption time constant on the patient level
    tauds_prior = numpyro.sample(name="taud_m_prior", fn=dist.TruncatedNormal(low=11.0,high=90.0,loc=35.0,scale=30.0))
    taudsS = numpyro.sample(name="taud_mS", fn=dist.Exponential(rate=1.0/25.0*patient_shape))

    meal_shape = jnp.ones((cohort_meals.shape[0],cohort_meals.shape[1]))
    # Excluding padded zero meals from the likelihood calculation
    with mask(mask=cohort_meals[:,:,1]>0):
        # Offset of meal start time on the meal level
        start_offset = numpyro.sample(name="start_m", fn=dist.TruncatedNormal(low=-90.0*meal_shape,high=90.0*meal_shape,loc=10.0*meal_shape,scale=15.0))
        # Meal absorption time constant on the meal level
        tauds = numpyro.sample(name="taud_m", fn=dist.TruncatedNormal(low=11.0*meal_shape,high=90.0*meal_shape,loc=tauds_prior*meal_shape,scale=taudsS[:,None]))
        # Carb content coefficient on the meal level
        meal_coeff = numpyro.sample(name="meal_coeff", fn=dist.TruncatedNormal(low=0.0*meal_shape,high=3.0*meal_shape,loc=1.0*meal_shape,scale=0.2*meal_shape))


    # [patient, 4, meal]: scaled carbs, shifted start, fixed duration 20 (unused by Ra), taud
    meal_info = jnp.asarray([cohort_meals[:,:,1]*meal_coeff,
                             cohort_meals[:,:,0]+start_offset,20.0*jnp.ones_like(cohort_meals[:,:,0]),tauds]).transpose(1,0,2)
    # Calculates rate of appearance as 3rd order impulse response
    # (outer vmap over patients, inner over samples; each sample sees all of its patient's meals)
    ra_arr = jax.vmap(jax.vmap(Ra,(0,None,None,None)),(0,0,0,0))(all_ts_grid,meal_info,VG,model_parameters.at[:,7].get())#

    # Insulin absorption time constant 1 on the patient level
    taudbi_shape = jnp.ones((cohort_boluses.shape[0],cohort_boluses.shape[1]))
    taudbiS = numpyro.sample(name="taud_biS", fn=dist.HalfNormal(scale=15.0*patient_shape))
    # Excluding padded zero boluses from the likelihood calculation
    with mask(mask=cohort_boluses[:,:,1]>0):
        # Offset of bolus start time on the bolus level
        start_offset_b = numpyro.sample(name="start_b", fn=dist.TruncatedNormal(low=-45.0*taudbi_shape,high=45.0*taudbi_shape, loc=0*taudbi_shape,scale=5*taudbi_shape))
        # Insulin absorption time constant on the bolus level
        taudbi = numpyro.sample(name="taud_bi", fn=dist.TruncatedNormal(low=5.0*taudbi_shape,high=80.0*taudbi_shape,loc=model_parameters[:,[3]]*taudbi_shape,scale=taudbiS[:,None]))
        # Injection amount coefficient on the bolus level
        bolus_coeff = numpyro.sample(name="bolus_coeff", fn=dist.TruncatedNormal(low=0.75*taudbi_shape,high=1.5*taudbi_shape,loc=1.0*taudbi_shape,scale=0.025*taudbi_shape))


    # [patient, 4, bolus]: scaled amount, shifted start, fixed duration 20 (unused by Sa), taud
    insulin_info = jnp.asarray([cohort_boluses[:,:,1]*bolus_coeff,cohort_boluses[:,:,0]+start_offset_b,20.0*jnp.ones_like(cohort_boluses[:,:,1]),taudbi]).transpose(1,0,2)
    # Calculates insulin appearance in the subcutaneous compartment 2 as a 1st order impulse response
    sa_arr = jax.vmap(jax.vmap(Sa,(0,None)),(0,0))(all_ts_grid,insulin_info)

    tudpa_shape = jnp.ones((cohort_pa.shape[0],cohort_pa.shape[1]))

    # Time constant of the effect of physical activity on the patient level
    taudpa_prior = numpyro.sample(name="taudpa_prior", fn=dist.TruncatedNormal(low=120.0*patient_shape,high=2500.0*patient_shape,loc=600.0*patient_shape,scale=1000.0*patient_shape))
    taudpaS = numpyro.sample("taudpaS",dist.Exponential(rate=1.0/1000.0*patient_shape))

    # Excluding padded zero physical activity events from the likelihood calculation
    with mask(mask=cohort_pa[:,:,1]>0):
        # Offset in the start time of the self-reported physical activity
        start_offset_pa = numpyro.sample(name="start_pa", fn=dist.TruncatedNormal(low=-60.0*tudpa_shape,high=60.0*tudpa_shape,loc=0.0*tudpa_shape,scale=15.0))
        # Time constants for the short and long effect of physical activity on the event level
        taudpa = numpyro.sample(name="taudpa", fn=dist.TruncatedNormal(low=30.0*tudpa_shape,high=2500.0*tudpa_shape,loc=taudpa_prior[:,None],scale=taudpaS[:,None]))
        taudpa_short = numpyro.sample(name="taudpa_short", fn=dist.TruncatedNormal(low=5.0*tudpa_shape,high=120.0*tudpa_shape,loc=30.0*tudpa_shape,scale=120.0*tudpa_shape))
        # Gain for the short and long effect of physical activity on the event level
        # (the short-term gain may be negative)
        pa_coeff = numpyro.sample(name="pa_coeff", fn=dist.TruncatedNormal(low=0.0*tudpa_shape,high=2500.0*tudpa_shape,loc=taudpa/5.0,scale=taudpa/5.0))
        pa_coeff_short = numpyro.sample(name="pa_coeff_short", fn=dist.TruncatedNormal(low=-120.0*tudpa_shape,high=120.0*tudpa_shape,loc=1.0,scale=taudpa_short/5.0))
    # Calculates the effect of physical activity as a 1st order impulse response
    pa_info = jnp.asarray([cohort_pa[:,:,1]*pa_coeff,cohort_pa[:,:,0]+start_offset_pa,20.0*jnp.ones_like(cohort_pa[:,:,1]),taudpa]).transpose(1,0,2)
    pa_arr = jax.vmap(jax.vmap(Pa,(0,None)),(0,0))(all_ts_grid,pa_info)
    pa_info = jnp.asarray([cohort_pa[:,:,1]*pa_coeff_short,cohort_pa[:,:,0]+start_offset_pa,20.0*jnp.ones_like(cohort_pa[:,:,1]),taudpa_short]).transpose(1,0,2)
    pa_arr_short = jax.vmap(jax.vmap(Pa,(0,None)),(0,0))(all_ts_grid,pa_info)
    # Physical activity mediated insulin sensitivity; softplus(100*x)/100 is a smooth max(x, 0),
    # so the multiplier on SIt stays non-negative
    SItPA = jax.nn.softplus(100.0*(1.0+pa_arr+pa_arr_short))/100.0*SIt

    # Per-step inputs in the order dz_dt unpacks xs; heart rate enters as a smooth excess over 70
    inputs = jnp.stack((all_basal,all_bolus,ra_arr,SItPA,sa_arr, EGPt, jax.nn.softplus(all_hr-70.0))).transpose(1,2,3,0)
    # Vectorized numerical integration using 1st order Euler with a sampling time of 5 minutes
    _,states = jax.vmap(jax.vmap(lambda z_initp, inputsp, paramp: jax.lax.scan(jax.tree_util.Partial(dz_dt, parameters=paramp),z_initp, inputsp),(0,0,None)),(0,0,0))(z_init, inputs,model_parameters)

    # Subcutaneous glucose concentration
    glucose = states[:,:,:,0]

    if y is None:
        numpyro.deterministic("y_pred", glucose)
        numpyro.deterministic("SIt_pred", SIt)
        numpyro.deterministic("Ra_pred", ra_arr)
        numpyro.deterministic("Sa_pred", sa_arr)
        numpyro.deterministic("Pa_pred", pa_arr+pa_arr_short)
        numpyro.deterministic("Ieff_pred",states[:,:,:,1])
        numpyro.deterministic("E2_pred",states[:,:,:,7])
    else:
        # Measurement noise on the sample level
        with numpyro.plate(name="patients", size=all_bg.shape[0], dim=-2) as ind:
            with numpyro.plate(name="samples", size=all_bg.shape[1], dim=-1):
                sigma = numpyro.sample('sigma_i', dist.LogNormal(3.3,0.5))
        # Flatten patients x samples so each row is one sample's trace with its own sigma
        numpyro.sample("y", dist.Normal(glucose.reshape(-1,glucose.shape[-1]), sigma.reshape((sigma.shape[0]*sigma.shape[1],1))), obs=all_bg.reshape(-1,all_bg.shape[-1]))#



    return

In [None]:
# Preprocessed cohort; the training cell reads the keys "samples", "meals", "boluses",
# "sample_times" and "exercises".
# NOTE: pickle.load can execute arbitrary code -- only open dataset files from a trusted source.
with open("dataset.pkl", mode="rb") as dataset_file:
    dataset = pickle.load(dataset_file)

In [None]:
if __name__ == "__main__":
    # Stochastic variational inference with a low-rank multivariate-normal guide,
    # averaging 15 particles per ELBO estimate.
    num_steps = 80000
    guide = autoguide.AutoLowRankMultivariateNormal(model_cohort)
    optimizer = numpyro.optim.Adam(step_size=0.001)
    svi = SVI(model_cohort, guide, optimizer, loss=Trace_ELBO(num_particles=15))

    model_kwargs = dict(
        sample_data=dataset["samples"],
        cohort_meals=dataset["meals"],
        cohort_boluses=dataset["boluses"],
        all_roll=dataset["sample_times"],
        cohort_pa=dataset["exercises"],
        y=True,  # any non-None value makes the model condition on the observed glucose
    )
    svi_result = svi.run(PRNGKey(1), num_steps, **model_kwargs)

    # Persist the model, guide and fitted variational parameters/optimizer state.
    output_dict = {
        'model': model_cohort,
        'guide': guide,
        'params': svi_result.params,
        "state": svi_result.state,
    }
    with open('svi_result_LRMv101rf3.pkl', 'wb') as result_file:
        pickle.dump(output_dict, result_file)

100%|██████████| 80000/80000 [2:25:40<00:00,  9.15it/s, init loss: 9516162048.0000, avg. loss [76001-80000]: 527598.4375]
