# In [None]:  (Jupyter cell marker, presumably left over from a notebook export)
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint, solve_ivp
from scipy.stats import truncnorm

# Return a frozen truncated-normal distribution on [low, upp]; draw random values with .rvs()
def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    """Build a frozen ``scipy.stats.truncnorm`` for N(mean, sd**2) cut to [low, upp].

    ``truncnorm`` takes its clipping bounds in standard-deviation units
    measured from ``loc``, so the absolute bounds are rescaled first.
    Call ``.rvs(size)`` on the result to draw samples.
    """
    lower_std = (low - mean) / sd
    upper_std = (upp - mean) / sd
    return truncnorm(lower_std, upper_std, loc=mean, scale=sd)
#
# This version of the model accounts for dose responses to both the antibiotic [A]
# and the inhibitor [I]; this may be necessary to describe the dose-response curves.
#
# Growth dynamics of cells producing Bla at different levels, including sensitive cells.
#
def growth(y, t, p, inh):  # argument order (y, t, ...) is the one odeint expects
    """Right-hand side of the ODE model for cells (n), Bla (b) and antibiotic (a).

    The argument order ``(y, t, ...)`` matches ``scipy.integrate.odeint``.
    ``solve_ivp`` calls ``fun(t, y, ...)`` instead, so it needs a small wrapper
    that swaps the first two arguments before calling this function.

    Parameters
    ----------
    y : array-like of length 3
        Current state ``[n, b, a]``; negative entries are clipped to 0 before use.
    t : float
        Time. Not referenced in the body (the system is autonomous).
    p : sequence of 10 values
        ``(mumax, Ks, theta, Ln, kappab, phimax, gamma, betamin, db, c)``;
        see the comments below for their meaning.
    inh : float
        Inhibitor concentration, held constant (it is not a state variable).

    Returns
    -------
    list
        ``[dn/dt, db/dt, da/dt]``.

    Side effects
    ------------
    Reads and sets the function attribute ``growth.initiate_lysis``. As soon as
    one evaluation sees ``n > Ln`` while ``a > 0`` or ``inh > 0``, the flag is
    set to True. It stays True for every later call, so lysis is latched on.
    Callers must reset it to False before each new integration, as
    ``simulate`` and ``simulate_ivp`` do.
    NOTE(review): the flag is also flipped by solver trial evaluations, so
    results may depend on the solver's step history — confirm this is intended.
    """
    n, b, a = np.maximum(y, 0) # ensure values are non-negative
    mumax, Ks, theta, Ln, kappab, phimax, gamma, betamin, db, c = p

    #mumax: max growth rate
    #Ks, theta: tune the shape of growth curves.
    #Ln: cell-density threshold; lysis can only be latched on once n exceeds Ln.
    #kappab: antibiotic degradation rate constant by Bla (b); in this model b is produced by lysis.
    #phimax: the maximum antibiotic degradation rate constant by cells (n).
    #gamma: the maximum normalized lysis rate.
    #betamin (0,1): inversely correlates with Bla-mediated intracellular protection. Larger betamin corresponds to more sensitive cells.
    #db: the maximum degradation rate constant of Bla mediated by the inhibitor.
    #c (0,1): the ability of the inhibitor to inhibit intracellular Bla. c=0 means that the inhibitor is unable to penetrate and inhibit intracellular Bla.
    #c = 0.146 (for clav), 0.0357 for sul, 0.0176 (Taz)

    db0 = 0.001       #basal degradation rate of Bla
    da0 = 0.001      #basal degradation rate of antibiotic
    ha = 3          #Hill coefficient of antibiotic-mediated killing.
    hi = 3          #Hill coefficient of inhibitor effect
    Ka = 1          #Michaelis-Menten constant for antibiotic degradation
    Ki = 15         #half-activation concentration for inhibitor-mediated killing.
                    #Originally noted as "assumed 10x higher than that required to inhibit Bla";
                    #NOTE(review): iota below is half-maximal at inh = 1, so Ki = 15 is 15x — confirm intended value.
    Nm = 3.0        #carrying capacity (hard-coded here, not part of p)

    
    # Lazily create the lysis latch on first use; simulate/simulate_ivp reset it per run.
    if not hasattr(growth, "initiate_lysis"):
        growth.initiate_lysis = False

    # iota: inhibitor occupancy, a Hill function half-maximal at inh = 1.
    iota = (inh**hi) / (1 + inh**hi) if inh > 0 else 0
    # The inhibitor raises beta toward 1 (weaker Bla protection), scaled by c.
    beta = betamin + c * (1 - betamin) * iota 
    # The inhibitor lowers cell-mediated antibiotic degradation, scaled by c.
    phi = phimax * (1 - c * iota) 
    
    # Growth-limitation factor: 0 when Ks <= 0; becomes negative when n > Nm.
    g = (1/(1 + (n/(Nm * Ks))**theta)) * (1-(n/Nm))  if Ks>0 else 0
    # 
    # this implementation allows the crash and recovery dynamics.
    #
    l = 0
    if a > 0 or inh > 0:                #lysis can happen if either a or inh is high enough
        if not growth.initiate_lysis and n > Ln:        #initiate lysis only when n > Ln. 
           growth.initiate_lysis = True
        if growth.initiate_lysis:
            # Lysis scales with g and a Hill-type combination of antibiotic (half-max 1) and inhibitor (half-max Ki).
            l = gamma * g * (a**ha + (inh/Ki)**hi)/(1 + a**ha + (inh/Ki)**hi)
    
    growth_rate = mumax * g * n
    lysis_rate = beta * l * n

    dndt = growth_rate - lysis_rate
    # Bla is released by lysis and removed by inhibitor-mediated plus basal degradation.
    dbdt = lysis_rate - (db * iota + db0) * b  
    dadt = -(kappab * b  + phi * n ) * a/(Ka + a) - da0 * a                    #sensitive cells will have kappab and phimax being small
    return [dndt, dbdt, dadt]

# Simulation function (solve_ivp version).
# The simulation window (t = 0 to 24, 145 points) is chosen to align with the time span of the experimental data for the 311 isolates.
# This way, when evaluating performance (in the subsequent objective function), only cell density needs to be compared, not the time axis.
def simulate_ivp(p, n0, a0, inh):
    """Integrate the growth model with ``solve_ivp`` (LSODA) on t = 0..24, 145 points.

    ``growth`` uses odeint's argument order ``(y, t, ...)``, whereas
    ``solve_ivp`` calls ``fun(t, y, ...)``. Passing ``growth`` directly would
    feed the time scalar in as the state and fail when unpacking it. An
    adapter swaps the two arguments back.

    Parameters
    ----------
    p : sequence
        Model parameters, unpacked by ``growth``.
    n0 : float
        Initial cell density.
    a0 : float
        Initial antibiotic concentration (Bla always starts at 0).
    inh : float
        Inhibitor concentration, held constant for the whole run.

    Returns
    -------
    numpy.ndarray
        ``sol.y`` with rows ``[n, b, a]``. Its shape is (3, 145) when the solver
        succeeds; if it stops early (``sol.success`` is False), fewer columns
        are returned.
    """
    b0 = 0
    y0 = [n0, b0, a0]
    t = np.linspace(0, 24, 145)
    growth.initiate_lysis = False  # reset the lysis latch stored on `growth` for this run

    def rhs(tt, yy):
        # solve_ivp passes (t, y); growth expects (y, t, p, inh).
        return growth(yy, tt, p, inh)

    sol = solve_ivp(rhs, [t[0], t[-1]], y0, t_eval=t, method='LSODA')  # LSODA appears to be 3-4 times faster than "BDF"
    return sol.y


# Simulation function using odeint, which is faster than solve_ivp here.
# The limitation of odeint is that it supports only one method (LSODA).
def simulate(p, n0, a0, inh):
    """Integrate the growth model with ``odeint`` on t = 0..24 (145 points).

    The lysis latch stored on ``growth`` is cleared first so that every run
    starts without lysis. Bla starts at zero; ``inh`` stays constant.

    Returns
    -------
    numpy.ndarray
        Array of shape (3, 145) whose rows are ``n``, ``b`` and ``a``
        (odeint itself returns (time_points, species), hence the transpose).
    """
    time_grid = np.linspace(0, 24, 145)
    initial_state = [n0, 0, a0]  # [cells, Bla, antibiotic]
    growth.initiate_lysis = False
    trajectory = odeint(growth, initial_state, time_grid, args=(p, inh))  # odeint is about 20% faster than solve_ivp here
    return np.transpose(trajectory)


# Objective function
# simulated[0] and experimental[i, :] select the cell-density ("n") trajectory from the simulated and
# experimental results, respectively; the sum of squared differences is computed only over these "n" values.
# This assumes that simulate returns a 2D array of shape (3, time_points) holding the values of the
# three state variables (n, b, a) at each time point, and that row i of experimental is the measured
# density time course for the i-th (a0, inh) initial condition.
#
def objective(p, experimental, *, initial_conditions=((0, 0), (10, 0), (10, 10)),
              simulator=None):
    """Sum of squared errors between simulated and measured cell density.

    Parameters
    ----------
    p : sequence
        Model parameters forwarded to the simulator.
    experimental : array-like, 2-D
        Row ``i`` is the measured cell-density time course for the i-th entry
        of ``initial_conditions``. Its first value is used as the initial density
        ``n0``. Each row must have the same length as the simulated time grid.
    initial_conditions : sequence of (a0, inh) pairs, keyword-only
        Antibiotic and inhibitor concentrations for each experimental row.
        The default ``((0, 0), (10, 0), (10, 10))`` reproduces the original
        hard-coded set.
    simulator : callable, keyword-only
        ``simulator(p, n0, a0, inh)`` returning an array whose row 0 is the
        simulated cell density. Defaults to ``simulate`` (odeint). Pass
        ``simulate_ivp`` to use the solve_ivp path instead.

    Returns
    -------
    float
        Sum over conditions and time points of (simulated n - measured n)**2.
    """
    if simulator is None:
        simulator = simulate  # resolved at call time, so the default needs no import-time lookup
    data = np.asarray(experimental, dtype=float)  # also accepts nested lists
    error = 0.0
    for i, (a0, inh) in enumerate(initial_conditions):
        observed_n = data[i]
        simulated = simulator(p, observed_n[0], a0, inh)
        error += np.sum((simulated[0] - observed_n) ** 2)
    return error