In [None]:
%load_ext autoreload
%autoreload 2

import scipy
from scipy.integrate import odeint
import numpy as np
import matplotlib.pyplot as plt
import random
import sir_model
import json
from random import randint

In [None]:
# Configuration for the two-parameter (beta, lambda) synthesis run.
# Every name defined here is read as a global by the cells below.

# --- Search grid -------------------------------------------------------------
# USER: [lower, upper] bound of the interval scanned for each synthesized parameter
beta_search_bounds = [0.1, 1]          # S, D -> I   (alternative range: [0.008, 0.01])
lambda_search_bounds = [0.000, 0.1]    # I -> H      (alternative range: [0.36, 0.55])

# USER: number of grid points per synthesized parameter
beta_values_to_synthesize = 10
lambda_values_to_synthesize = 10

# --- Fixed rates (parameters that are NOT synthesized) -----------------------
# USER: the trailing comment on each line names the transition the rate drives.

# transitions into I
alpha_val = 0.1      # I -> I  (NOTE(review): sibling rates read "S, X -> I"; presumably "S, I -> I" - confirm)
gamma_val = 0.0456   # S, A -> I
delta_val = 0.00011  # S, R -> I

# transitions between I, D, A, R
epsilon_val = 0.03   # I -> D
zeta_val = 0.0125    # I -> A
eta_val = 0.0125     # D -> R
theta_val = 0.571    # A -> R

# transitions into T
mu_val = 0.017       # A -> T
nu_val = 0.027       # R -> T

# transitions into H
rho_val = 0.034      # D -> H
kappa_val = 0.017    # A -> H
xi_val = 0.017       # R -> H
sigma_val = 0.017    # T -> H

# transition into E
tau_val = 0.01       # T -> E

# --- Initial conditions ------------------------------------------------------
# USER: initial value of each compartment; S0 is whatever remains so the total is 1
I0 = 0.01
D0 = A0 = R0 = T0 = H0 = E0 = 0
S0 = 1 - I0 - D0 - A0 - R0 - T0 - H0 - E0

# --- Simulation time grid ----------------------------------------------------
# USER: start time, end time and step of the simulated time vector
tstart = 0
tend = 100
dt = 1

# --- Synthesis mode ----------------------------------------------------------
# USER: set parameter synthesis method: "all" or "any"
method = "any"



In [None]:
def query_1(sim, tstart, tend, dt):
    """Return True when I stays within [0.45, 0.55] at every sampled time t with 45 <= t < 55.

    Parameters
    ----------
    sim : array whose transpose unpacks into the eight series S, I, D, A, R, T, H, E
        (one row per time sample). Row k is taken to hold the state at time
        ``tstart + k*dt``, which is how ``eval_point`` builds its time vector.
    tstart, tend, dt : int
        Start, end and step of the time grid. They are fed to ``range``, so
        they must be integers and ``dt`` must be non-zero.

    Returns
    -------
    bool
        True if every sample in the window satisfies the bound. If no sample
        time falls in the window the result is also True, because ``all`` of an
        empty sequence is True.
    """
    S, I, D, A, R, T, H, E = sim.T

    # USER: write query condition.
    # Alternative conditions that were used here before:
    #   0.15 <= max(I) <= 0.3
    #   (0.45 <= max(I) <= 0.55) and (45 <= np.argmax(I)*dt <= 55)
    window_times = [t for t in range(tstart, tend + 1, dt) if 45 <= t < 55]

    # The row index is the offset from tstart, not the absolute time: indexing
    # with t/dt alone reads the wrong rows whenever tstart != 0.
    return all(0.45 <= I[(t - tstart) // dt] <= 0.55 for t in window_times)

def query_2(sim, tstart, tend, dt):
    """Return True when I stays within [0.25, 0.55] at every sampled time t with 45 <= t < 55.

    Parameters
    ----------
    sim : array whose transpose unpacks into the eight series S, I, D, A, R, T, H, E
        (one row per time sample). Row k is taken to hold the state at time
        ``tstart + k*dt``, which is how ``eval_point`` builds its time vector.
    tstart, tend, dt : int
        Start, end and step of the time grid. They are fed to ``range``, so
        they must be integers and ``dt`` must be non-zero.

    Returns
    -------
    bool
        True if every sample in the window satisfies the bound. If no sample
        time falls in the window the result is also True, because ``all`` of an
        empty sequence is True.
    """
    S, I, D, A, R, T, H, E = sim.T

    # USER: write query condition.
    # Alternative conditions that were used here before:
    #   0.15 <= max(I) <= 0.3
    #   (0.45 <= max(I) <= 0.55) and (45 <= np.argmax(I)*dt <= 55)
    window_times = [t for t in range(tstart, tend + 1, dt) if 45 <= t < 55]

    # The row index is the offset from tstart, not the absolute time: indexing
    # with t/dt alone reads the wrong rows whenever tstart != 0.
    return all(0.25 <= I[(t - tstart) // dt] <= 0.55 for t in window_times)

def eval_point(beta_val, lamb_val, query_condition=query_1, plot=False, rtol=1, atol=1, mxstep=10, mxordn = 1, mxords=1, hmin=1):
    """Simulate the model for one (beta, lambda) pair and evaluate a query on the result.

    Parameters
    ----------
    beta_val, lamb_val : the two synthesized rate values for this grid point.
    query_condition : callable ``(sim, tstart, tend, dt) -> truthy/falsy``
        applied to the simulated trajectory.
    plot : when true, the trajectory is drawn with ``sir_model.plotSIDARTHE``.
    rtol, atol, mxstep, mxordn, mxords, hmin :
        Accepted for interface compatibility only. They are NOT forwarded to
        ``odeint``, which therefore runs with its own defaults.

    All other rates (``alpha_val`` ... ``tau_val``), the initial conditions
    (``S0`` ... ``E0``) and the time grid (``tstart``, ``tend``, ``dt``) are
    read from module-level globals.

    Returns
    -------
    dict with keys ``'beta'``, ``'lambda'`` and ``'assignment'``; the last is
    the string ``'1'`` when the query holds and ``'0'`` otherwise.
    """
    def constant_rate(value):
        # The model receives each rate as a function of time; this wraps a
        # scalar so that it evaluates to `value` wherever t >= 0.
        return lambda t: np.piecewise(t, [t >= 0], [value])

    # Passed positionally to sir_model.SIDARTHE_model through odeint's `args`:
    # the order below is significant and must not be rearranged.
    rate_values = (
        alpha_val, beta_val, gamma_val, delta_val,
        epsilon_val, mu_val, zeta_val, lamb_val,
        eta_val, rho_val, theta_val, kappa_val,
        nu_val, xi_val, sigma_val, tau_val,
    )
    rate_functions = tuple(constant_rate(value) for value in rate_values)

    initial_state = (S0, I0, D0, A0, R0, T0, H0, E0)

    # Sample times; np.arange stops before `tend`.
    tvect = np.arange(tstart, tend, dt)

    # simulate/solve ODEs
    sim = odeint(sir_model.SIDARTHE_model, initial_state, tvect, args=rate_functions)

    satisfied = query_condition(sim, tstart, tend, dt)

    # Plotting every grid point is not recommended for large numbers of points.
    if plot:
        S, I, D, A, R, T, H, E = sim.T
        sir_model.plotSIDARTHE(tvect, S, I, D, A, R, T, H, E)

    return {
        'beta': beta_val,
        'lambda': lamb_val,
        'assignment': '1' if satisfied else '0',
    }

def ps(param_synth_method, search_points_beta, search_points_gamma, query_condition=query_1, plot=False):
    """Evaluate the query at every point of a two-parameter grid.

    Parameters
    ----------
    param_synth_method : "all" or "any". Currently not acted upon: every grid
        point is evaluated regardless of its value (the early exit for "any"
        was disabled in the original code and is left that way).
    search_points_beta : sequence of beta values (outer loop).
    search_points_gamma : sequence of values for the second synthesized
        parameter (inner loop). Despite the name, each value is handed to
        ``eval_point`` as its lambda argument; the parameter name is kept so
        existing keyword callers continue to work.
    query_condition : query callable forwarded to ``eval_point``.
    plot : forwarded to ``eval_point``.

    Returns
    -------
    list of the dicts returned by ``eval_point``, ordered beta-major.
    """
    param_choices_true_false = []
    for beta_val in search_points_beta:
        # Iterate the grid that was passed in. The previous version looped over
        # the global `search_points_lambda` here and ignored this argument.
        for lambda_val in search_points_gamma:
            param_assignments = eval_point(
                beta_val, lambda_val,
                query_condition=query_condition, plot=plot,
                rtol=1, atol=1, mxstep=10, mxordn=1, mxords=1, hmin=1,
            )
            param_choices_true_false.append(param_assignments)
    return param_choices_true_false
    

# Build the evenly spaced search grid for each synthesized parameter.
search_points_beta = np.linspace(beta_search_bounds[0], beta_search_bounds[1], beta_values_to_synthesize)
search_points_lambda = np.linspace(lambda_search_bounds[0], lambda_search_bounds[1], lambda_values_to_synthesize)

# For a quick single-point run, override the grids like so:
# search_points_beta = [0.4]
# search_points_lambda = [0.001]

# `method` is taken from the configuration cell. It is intentionally not
# re-assigned here: doing so silently discarded the user's choice, which is
# also embedded in the name of the results file.
param_choices_true_false = ps(method, search_points_beta, search_points_lambda, query_condition=query_2, plot=False)

# Plot "true/false" points.
sir_model.plot_two_params("beta", "lambda", param_choices_true_false)

In [None]:
# Save results
# The file name carries a random six-digit id (100000..999999) plus the
# synthesis method, so name collisions between runs are unlikely.
id_number = randint(100_000, 999_999)
output_name = f'sidarthe_query_auto_2_param_{id_number}_{method}.json'

with open(output_name, 'w', encoding='utf-8') as f:
    json.dump(param_choices_true_false, f, ensure_ascii=False, indent=4)

In [None]:
### For reference: old values

# alpha_val = 0.57
# beta_val = 0.011
# delta_val = 0.011
# gamma_val = 0.456

# epsilon_val = 0.05 #0.171
# theta_val = 0.371

# zeta_val = 0.125
# eta_val = 0.125

# mu_val = 0.017
# nu_val = 0.027
# lamb_val = 0.034
# rho_val = 0.034

# kappa_val = 0.017
# xi_val = 0.017
# sigma_val = 0.017

# tau_val = 0.01