In [None]:
# Took 1:54:49.48 to get a functioning encoding with just the compartmental constraints. More time is needed for probing the parameters.

from pysmt.shortcuts import And, Or, Plus, Minus, Times, Div, REAL, LE, LT, GE, GT, Equals, Symbol, Real, Solver
from pysmt.fnode import  FNode
from typing import Dict
import pandas as pd
from decimal import Decimal

# Started at 48:00 # Finished at 1:27:15
def dataframe(assignment: "Dict[FNode, object]", state_variables, parameters, timepoints) -> pd.DataFrame:
    """Convert a solver assignment into a DataFrame indexed by timepoint.

    A symbol named ``<state>_<timepoint>`` (e.g. ``S_010``) fills the row for
    that timepoint in the ``<state>`` column.  Every other symbol (e.g.
    ``beta``) becomes a column holding the same value at every timepoint.

    Args:
        assignment: Maps objects exposing ``symbol_name()`` to numeric values
            accepted by ``float()`` (e.g. ``fractions.Fraction``).
        state_variables: Names of the time-indexed columns to create up
            front; cells that never receive a value stay NaN.
        parameters: Names that are always treated as time-invariant, even
            when the name itself looks like ``<state>_<timepoint>``.
        timepoints: Time values used as the DataFrame index.

    Returns:
        A float DataFrame with one row per timepoint and one column per state
        variable and per time-invariant symbol.

    Raises:
        ValueError: If a time-indexed symbol refers to a timepoint that is
            not in ``timepoints``.
    """
    timepoints = list(timepoints)
    n_rows = len(timepoints)
    nan = float("nan")

    # Precompute timepoint -> row so each symbol is placed in O(1); setdefault
    # keeps the first occurrence, matching list.index() on duplicates.
    row_of = {}
    for row, timepoint in enumerate(timepoints):
        row_of.setdefault(timepoint, row)

    constant_names = set(parameters or ())
    timeseries = {sv: [nan] * n_rows for sv in state_variables}

    for key, raw in assignment.items():
        name = key.symbol_name()
        value = float(raw)

        # Split on the LAST underscore so state names may contain underscores.
        prefix, sep, suffix = name.rpartition("_")
        timepoint = None
        if sep and prefix and name not in constant_names:
            try:
                timepoint = int(suffix)
            except ValueError:
                timepoint = None  # non-numeric suffix: not time-indexed

        if timepoint is None:
            # Time-invariant symbol: broadcast its value over every row.
            timeseries[name] = [value] * n_rows
            continue

        if timepoint not in row_of:
            raise ValueError(
                f"symbol {name!r} refers to timepoint {timepoint}, "
                f"which is not in timepoints"
            )
        # Create the column lazily for states not listed in state_variables.
        column = timeseries.setdefault(prefix, [nan] * n_rows)
        column[row_of[timepoint]] = value

    return pd.DataFrame(timeseries, index=timepoints).astype(float)




In [None]:
# (lower, upper) search bounds for the two rate parameters.
beta_bounds, gamma_bounds = (0.0, 0.01), (0.0, 0.2)

In [None]:
def time_format(t):
    """Render an integer timepoint as a zero-padded, three-digit string."""
    return f"{t:03d}"


# Symbols for the three compartments at time 0.
S_0 = Symbol(f"S_{time_format(0)}", REAL)
I_0 = Symbol(f"I_{time_format(0)}", REAL)
R_0 = Symbol(f"R_{time_format(0)}", REAL)

# Concrete starting values for each compartment.
S_0_value, I_0_value, R_0_value = 1000, 1, 0

# Total population is the sum of the initial compartment values.
population_size = S_0_value + I_0_value + R_0_value

# SIR model initial state: pin every compartment to its starting value.
initial_state = And(
    Equals(S_0, Real(S_0_value)),
    Equals(I_0, Real(I_0_value)),
    Equals(R_0, Real(R_0_value)),
)

# Parameters: one real-valued symbol per rate.
beta = Symbol("beta", REAL)
gamma = Symbol("gamma", REAL)


def _within_bounds(symbol, bounds):
    """Constrain ``symbol`` to the half-open interval [bounds[0], bounds[1])."""
    lower, upper = bounds
    return And(LE(Real(lower), symbol), LT(symbol, Real(upper)))


parameters = And(
    _within_bounds(beta, beta_bounds),
    _within_bounds(gamma, gamma_bounds),
)

# Timepoints: 0, 10, ..., 100.
step_size = 10
timepoints = list(range(0, 110, step_size))


# Accessors for the compartment symbols at time t ("now") and at time
# t + step_size ("next").
def S_next(t):
    """Symbol for S one step after ``t``."""
    return Symbol(f"S_{time_format(t+step_size)}", REAL)


def S_now(t):
    """Symbol for S at ``t``."""
    return Symbol(f"S_{time_format(t)}", REAL)


def I_next(t):
    """Symbol for I one step after ``t``."""
    return Symbol(f"I_{time_format(t+step_size)}", REAL)


def I_now(t):
    """Symbol for I at ``t``."""
    return Symbol(f"I_{time_format(t)}", REAL)


def R_next(t):
    """Symbol for R one step after ``t``."""
    return Symbol(f"R_{time_format(t+step_size)}", REAL)


def R_now(t):
    """Symbol for R at ``t``."""
    return Symbol(f"R_{time_format(t)}", REAL)


# Step length as a real constant, used to scale each rate.
dt = Real(float(step_size))


def S_Trans(t):
    """S(t + step) = S(t) - beta * S(t) * I(t) * dt."""
    successor = S_next(t)
    susceptible = S_now(t)
    outflow = Times([beta, susceptible, I_now(t), dt])
    return Equals(successor, Minus(susceptible, outflow))


def I_Trans(t):
    """I(t + step) = I(t) + (beta * S(t) * I(t) - gamma * I(t)) * dt."""
    successor = I_next(t)
    infected = I_now(t)
    new_infections = Times([beta, S_now(t), infected])
    recoveries = Times(gamma, infected)
    net_rate = Minus(new_infections, recoveries)
    return Equals(successor, Plus(infected, Times(net_rate, dt)))


def R_Trans(t):
    """R(t + step) = R(t) + gamma * I(t) * dt."""
    successor = R_next(t)
    recovered = R_now(t)
    recoveries = Times(gamma, I_now(t))
    return Equals(successor, Plus(recovered, Times(recoveries, dt)))


def Trans(t):
    """Conjunction of the S, I and R updates for the step starting at ``t``."""
    return And(S_Trans(t), I_Trans(t), R_Trans(t))


# One transition per timepoint except the last, which has no successor.
All_Trans = And([Trans(t) for t in timepoints[:-1]])


def _compartments_valid(t):
    """S, I and R at ``t`` are non-negative and sum to the population size."""
    susceptible, infected, recovered = S_now(t), I_now(t), R_now(t)
    return And(
        LE(Real(0.0), susceptible),
        LE(Real(0.0), infected),
        LE(Real(0.0), recovered),
        Equals(Plus([susceptible, infected, recovered]), Real(population_size)),
    )


compartmental_constraint = And([_compartments_valid(t) for t in timepoints])

#Stopped at 32:00.68

# Full encoding: initial state, parameter bounds, dynamics and the
# per-timepoint compartment constraints must all hold together.
consistency = And(
    initial_state,
    parameters,
    All_Trans,
    compartmental_constraint,
)
# consistency.serialize()

In [None]:
# Solve the encoding; when it is satisfiable, collect the value of every free
# variable, tabulate the trajectory and plot the three compartments.
values = None
with Solver() as solver:
    solver.add_assertion(consistency)
    result = solver.solve()
    if result:
        model = solver.get_model()
        variables = consistency.get_free_variables()
        values = {}
        for var in variables:
            try:
                values[var] = model.get_value(var).constant_value()
            except Exception as e:
                # Best effort: keep going, but report the variable that could
                # not be read instead of dropping it silently.
                print(f"Skipping {var}: {e}")
    else:
        print("Unsat")

if values:
    results: pd.DataFrame = dataframe(values, ["S", "I", "R"], ["beta", "gamma"], list(timepoints))
    # The parameter columns are constant over time, so read the first row by
    # position rather than relying on a row labelled 0 existing.
    print(f"beta = {results['beta'].iloc[0]}, gamma = {results['gamma'].iloc[0]}")
    ax = results[["S", "I", "R"]].plot()
    # set_xlabel is a method; assigning to it would leave the axis unlabelled.
    ax.set_xlabel("Time")
