In [None]:
# This notebook illustrates handling the July 2024 Demo of the 6mo Hackathon Scenario 2, described at: https://github.com/DARPA-ASKEM/program-milestones/issues/74

# Import funman related code
import os
from funman.api.run import Runner
from funman_demo import summarize_results
from funman import FunmanWorkRequest, EncodingSchedule 
import json
from funman.representation.constraint import LinearConstraint, ParameterConstraint, StateVariableConstraint
from funman.representation import Interval
import pandas as pd

# Root of the shared resources and the directory that receives run artifacts.
RESOURCES = "../../resources"
SAVED_RESULTS_DIR = "./out"

# Model and request files for the July 2024 monthly demo.
EXAMPLE_DIR = os.path.join(RESOURCES, "amr", "petrinet", "monthly-demo", "2024-07")
MODEL_PATH = os.path.join(EXAMPLE_DIR, "SEIRHD_vacc_petrinet.json")
REQUEST_PATH = os.path.join(EXAMPLE_DIR, "SEIRHD_vacc_petrinet_request.json")

# Parameter values recorded by report(), keyed by experiment name.
request_params = {}

# %load_ext autoreload
# %autoreload 2

In [None]:
# Constants for the scenario
# State variables that set_compartment_bounds() constrains and plot_last_point() plots.
STATES = [
    "S_u", "I_u", "E_u",
    "S_v", "E_v", "I_v",
    "H", "R", "D",
]

# Encoding schedule: with these values the timepoints are [0, 28].
MAX_TIME = 28
STEP_SIZE = 28
timepoints = [*range(0, MAX_TIME + STEP_SIZE, STEP_SIZE)]

In [None]:
# Helper functions to set up FUNMAN for the different steps of the scenario

def get_request():
    """Read the JSON file at REQUEST_PATH and return it validated as a FunmanWorkRequest."""
    with open(REQUEST_PATH, "r") as request_file:
        raw_request = json.load(request_file)
    return FunmanWorkRequest.model_validate(raw_request)

def set_timepoints(funman_request, timepoints):
    """Replace the schedules of the request's first structure parameter.

    The new value is a one-element list holding an EncodingSchedule built from
    ``timepoints``. The request is modified in place.
    """
    schedule = EncodingSchedule(timepoints=timepoints)
    funman_request.structure_parameters[0].schedules = [schedule]

def unset_all_labels(funman_request):
    """Set the label of every parameter of the request to "any", in place.

    Afterwards no parameter carries the "all" label that
    get_synthesized_vars() selects on.
    """
    for param in funman_request.parameters:
        setattr(param, "label", "any")
    
def set_config_options(funman_request, debug=False):
    """Apply this notebook's configuration overrides to ``funman_request.config``.

    Parameters
    ----------
    funman_request
        Request whose ``config`` attribute is modified in place.
    debug : bool, optional
        When true, ``config.save_smtlib`` is also set to ``SAVED_RESULTS_DIR``,
        the same directory that ``run`` passes as ``case_out_dir``, so all
        artifacts of a run end up in one place.
    """
    config = funman_request.config
    config.substitute_subformulas = True
    # The default compartmental constraints are switched off; the scenario notes
    # say the model has vital dynamics, so bounds are added explicitly instead.
    config.use_compartmental_constraints = False
    if debug:
        # Use the shared constant rather than repeating the "./out" literal.
        config.save_smtlib = SAVED_RESULTS_DIR
    config.tolerance = 0.01
    config.dreal_precision = 1000
    # Further overrides that are currently disabled:
    # config.use_transition_symbols = True
    # config.verbosity = 10
    # config.dreal_log_level = "debug"
    # config.dreal_prefer_parameters = ["beta","NPI_mult","r_Sv","r_EI","r_IH_u","r_IH_v","r_HR","r_HD","r_IR_u","r_IR_v"]

def get_synthesized_vars(funman_request):
    """Return the names of the request parameters whose label is "all", in request order."""
    names = []
    for param in funman_request.parameters:
        if param.label == "all":
            names.append(param.name)
    return names

def run(funman_request, plot=False):
    """Run FUNMAN on the model at MODEL_PATH with ``funman_request``.

    Results are written under SAVED_RESULTS_DIR. ``plot`` is forwarded as
    ``dump_plot``, and the parameters labeled "all" in the request are passed
    as ``parameters_to_plot``. Returns whatever ``Runner().run`` returns.
    """
    run_options = dict(
        description="SEIRHD Hackathon 6mo",
        case_out_dir=SAVED_RESULTS_DIR,
        dump_plot=plot,
        print_last_time=True,
        parameters_to_plot=get_synthesized_vars(funman_request),
    )
    return Runner().run(MODEL_PATH, funman_request, **run_options)

def setup_common(funman_request, synthesize=False, debug=False):
    """Prepare a freshly loaded request with the settings shared by every experiment.

    Installs the notebook-level ``timepoints`` schedule, resets all parameter
    labels unless ``synthesize`` is true, and applies the configuration
    overrides (``debug`` is forwarded to set_config_options).
    """
    # The schedule comes from the module-level `timepoints`, not from an argument.
    set_timepoints(funman_request, timepoints)

    # Only synthesis runs keep the labels that came with the request.
    if not synthesize:
        unset_all_labels(funman_request)

    set_config_options(funman_request, debug=debug)
    

def set_compartment_bounds(funman_request, upper_bound=9830000.0, error=0.01):
    """Append compartment constraints to ``funman_request.constraints``.

    One hard (``soft=False``) StateVariableConstraint is added per name in
    STATES, with interval lb=0, ub=``upper_bound`` and a closed upper bound.
    Then a single soft LinearConstraint named "compartment_bounds" over all
    STATES is added, intended to keep the sum of the compartments between
    ``upper_bound - error`` and ``upper_bound + error`` (upper bound open).
    """
    # Per-compartment bounds.
    for state in STATES:
        state_range = Interval(lb=0, ub=upper_bound, closed_upper_bound=True)
        state_constraint = StateVariableConstraint(
            name=f"{state}_bounds",
            variable=state,
            interval=state_range,
            soft=False,
        )
        funman_request.constraints.append(state_constraint)

    # Bound over all compartments together.
    total_range = Interval(
        lb=upper_bound - error,
        ub=upper_bound + error,
        closed_upper_bound=False,
    )
    total_constraint = LinearConstraint(
        name="compartment_bounds",
        variables=STATES,
        additive_bounds=total_range,
        soft=True,
    )
    funman_request.constraints.append(total_constraint)

def relax_parameter_bounds(funman_request, factor = 0.1):
    """Widen the interval of every request parameter in place.

    Each interval grows by ``factor`` times its current width in total: half
    of that amount is subtracted from ``lb`` and half is added to ``ub``.
    """
    for param in funman_request.parameters:
        bounds = param.interval
        # Measure the width once, before either end is moved.
        margin = factor / 2 * float(bounds.width())
        bounds.lb -= margin
        bounds.ub += margin

def plot_last_point(results):
    """Print how many points the parameter space holds and plot the last one.

    When at least one point exists, the STATES columns of the dataframe built
    from the last point are plotted on a logarithmic y axis.
    """
    points = results.parameter_space.points()
    print(f"{len(points)} points")

    if len(points) == 0:
        return

    # Slice rather than index: dataframe() is given a one-element list.
    last_point_frame = results.dataframe(points=points[-1:])
    axes = last_point_frame[STATES].plot()
    axes.set_yscale("log")

def get_last_point_parameters(results):
    """Return the model-parameter values of the last point, or None without points.

    The last point's ``values`` mapping is filtered down to the keys listed by
    ``results.model._parameter_names()``.
    """
    points = results.parameter_space.points()
    if len(points) == 0:
        return None

    last_point = points[-1]
    model_parameters = results.model._parameter_names()
    return {
        name: value
        for name, value in last_point.values.items()
        if name in model_parameters
    }

def pretty_print_request_params(params):
    """Print ``params`` as a transposed pandas DataFrame; print nothing when it is empty."""
    if len(params) == 0:
        return

    table = pd.DataFrame(params).T
    print(table)


def report(results, name):
    """Plot and print the last point of ``results`` and record its parameters.

    When a point exists, its parameter values are stored in the module-level
    ``request_params`` under ``name``. The accumulated table is printed either way.
    """
    plot_last_point(results)

    point_parameters = get_last_point_parameters(results)
    print(f"Point parameters: {point_parameters}")

    # A run without points leaves the table untouched.
    if point_parameters is not None:
        request_params[name] = point_parameters

    pretty_print_request_params(request_params)

In [None]:
# Find a single parameterization of the model that allows H <= 3000 for all timepoints.
# Do not assume the default compartmental constraints because model has vital dynamics.
# The chosen parameters result in an impossibly large value for some compartments.
# Further runs need to eliminate these impossible values.

# Baseline run: labels are reset to "any" and no extra constraints are added.
funman_request = get_request()
setup_common(funman_request)
# funman_request.model.to_dot()
results = run(funman_request)
# The next cell calls render() on the object returned by to_dot(), which is the
# graphviz API. graphviz graph objects write their DOT source with save();
# they have no write() method.
results.model.to_dot().save("SEIRHD-vacc-strat.dot")
# report(results, "unconstrained")

In [None]:
# Render the model graph under the name "SEIRHD-vacc-strat", using `results` from the previous cell.
# NOTE(review): render() looks like the graphviz API — confirm what to_dot() returns.
results.model.to_dot().render("SEIRHD-vacc-strat")


In [None]:
# Add bounds [0, N] to the STATE compartments.  
# Add bounds sum(STATE) in [N-e, N+e], for a small e.

# Start from a fresh copy of the request so earlier cells cannot leak constraints into this run.
funman_request = get_request()
# debug=True additionally sets config.save_smtlib (see set_config_options).
setup_common(funman_request, debug=True)
set_compartment_bounds(funman_request)
results = run(funman_request)
# Plot the last point and record its parameters under "compartmental_constrained".
report(results, "compartmental_constrained")

In [None]:
# Relax the bounds on the parameters to allow additional parameterizations

# Start from a fresh copy of the request; the relaxation below mutates its parameter intervals.
funman_request = get_request()
setup_common(funman_request)
set_compartment_bounds(funman_request)
# factor=0.75 widens every parameter interval by 75% of its width, split evenly between both ends.
relax_parameter_bounds(funman_request, factor = 0.75)
results = run(funman_request)
# Plot the last point and record its parameters under "relaxed_bounds".
report(results, "relaxed_bounds")

In [None]:
# Synthesis run: synthesize=True keeps the labels loaded from the request file
# (unset_all_labels is skipped), so run() passes the parameters labeled "all"
# as parameters_to_plot, and plot=True is forwarded as dump_plot.
funman_request = get_request()
setup_common(funman_request, synthesize=True)
set_compartment_bounds(funman_request)
# relax_parameter_bounds(funman_request, factor=0.75)
# funman_request.config.verbosity=10
results = run(funman_request, plot=True)
# Plot the last point and record its parameters under "synthesis".
report(results, "synthesis")

In [None]:
# import pandas as pd
# df1 = pd.DataFrame({"ltp": ltp, "gtp": gtp})

# # df2 = 
# # df1.ltp.N == df1.gtp.N
# # df2.loc[df2].sort_index()[0:60]

# #(= H_10 (+ H_8 (* 2 (+ (* r_HD (- 1) H_8) (* r_HR (- 1) H_8) (* r_IH_v I_v_8) (* r_IH_u I_u_8)))))

# df1['same'] = df1['ltp'] == df1["gtp"]
# df1[0:20]
# # df1.loc[df1.index.str.endswith("_6")].sort_values(by="same")

In [None]:
# # Get points (trajectories generated)
# pts = results.parameter_space.points() 
# print(f"{len(pts)} points")

# # Get a plot for last point
# df = results.dataframe(points=pts[-1:])
# ax = df[STATES].plot()
# ax.set_yscale("log")


# # Get the values of the point
# gtp=pts[-1].values


# # Output the model diagram
# #
# # results_unconstrained_point.model.to_dot()
# # gtp