# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint, solve_ivp
from scipy.stats import truncnorm

# Build a frozen truncated-normal distribution bounded to [low, upp].
def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    """Return a frozen ``scipy.stats.truncnorm`` distribution.

    scipy expects the truncation bounds in standard-normal units, so the
    physical bounds ``low`` and ``upp`` are first converted to z-scores
    relative to ``mean`` and ``sd``; the distribution is then shifted and
    scaled back with ``loc=mean`` and ``scale=sd``.
    """
    z_low = (low - mean) / sd
    z_high = (upp - mean) / sd
    return truncnorm(z_low, z_high, loc=mean, scale=sd)
#
# Simplified Bla model, intended to better reflect experimental data from
# clonal populations measured under three conditions:
#   A = 0,  inh = 0
#   A = 50, inh = 0
#   A = 50, inh = 25
# For simplicity, Bla inhibition is assumed to be maximal whenever the
# inhibitor is present (the model checks only whether inh > 0, not its level).
#
# Growth dynamics of cells producing Bla at different levels, including
# sensitive (non-producing) cells.
#
def growth(y, t, p, inh):
    """Right-hand side of the cell / Bla / antibiotic ODE system, for ``odeint``.

    State ``y = [n, b, a]``:
        n -- cell population density (compared against the carrying
             capacity ``Nm``).
        b -- Bla pool that is filled only by cell lysis (``dbdt`` gains
             ``lysis_rate``) and degrades antibiotic at rate ``kappab * b``;
             presumably the released/extracellular Bla -- confirm.
        a -- antibiotic concentration.
    All three components are clipped to be non-negative before use.

    Args:
        y: current state ``[n, b, a]``.
        t: current time; unused in the body, but required by the ``odeint``
            callback signature.
        p: 10-element parameter sequence
            ``(alpha, Ks, theta, Ln, kappab, phimax, gamma, betamin, db, c)``;
            see the inline comments below.
        inh: inhibitor level. Only ``inh > 0`` vs. not is used; its magnitude
            is ignored (inhibition is treated as all-or-nothing).

    Returns:
        ``[dndt, dbdt, dadt]``.

    Side effects:
        Creates on first use, reads, and sets the function attribute
        ``growth.initiate_lysis``. Once lysis starts (``a > 0`` and
        ``n > Ln``) the flag stays True for every later call, so callers must
        reset it to False before each new integration (``simulate`` does).
        NOTE(review): ``odeint`` may evaluate this function at trial points
        that its adaptive stepper later rejects, so the latch could be tripped
        by a rejected step -- results may depend on solver internals.
    """
    n, b, a = np.maximum(y, 0) # ensure values are non-negative
    alpha, Ks, theta, Ln, kappab, phimax, gamma, betamin, db, c = p

    # alpha: maximum growth rate.
    # Ks, theta: tune the shape of the growth curves (see g below).
    # Ln: cell density that must be exceeded before lysis can be initiated.
    # kappab: rate constant of antibiotic degradation by Bla (b).
    # phimax: maximum antibiotic degradation rate constant by cells (n).
    # gamma: maximum normalized lysis rate.
    # betamin in (0, 1): inversely correlates with Bla-mediated intracellular
    #     protection; a larger betamin corresponds to more sensitive cells.
    # db: maximum degradation rate constant of Bla mediated by the inhibitor.
    # c in (0, 1): ability of the inhibitor to penetrate and inhibit
    #     intracellular Bla; c = 0 means it cannot penetrate at all.
    # Note: the carrying capacity Nm is NOT part of p; it is hard-coded below.

    db0 = 0.001       # basal degradation rate constant of Bla (used only when inh <= 0)
    da0 = 0.001      # basal (first-order) degradation rate constant of antibiotic
    ha = 3          # Hill coefficient of antibiotic-mediated killing
    Nm = 3.0        # carrying capacity

    # Lazily create the lysis latch on the first call (see "Side effects").
    if not hasattr(growth, "initiate_lysis"):
        growth.initiate_lysis = False

    # With the inhibitor present, cells lose part of their Bla-mediated
    # protection (beta rises from betamin toward 1) and part of their
    # antibiotic-degradation capacity (phi falls from phimax toward 0),
    # both scaled by c.
    beta = betamin + c * (1 - betamin) if inh > 0 else betamin
    phi = phimax * (1 - c) if inh > 0 else phimax

    # Growth-limitation factor: a decreasing Hill-type term in n/Nm times the
    # logistic factor (1 - n/Nm); growth is switched off when Ks <= 0.
    # NOTE(review): g turns negative if n > Nm, which would also make the
    # lysis term l below negative -- confirm this cannot occur in practice.
    g = (Ks**theta/(Ks**theta + (n/Nm)**theta)) * (1-(n/Nm))  if Ks>0 else 0
    # Alternative (unused) growth form:
    #g = 1-(n/Nm)**theta
    #
    # Lysis term. Once lysis has been triggered (antibiotic present and
    # n > Ln) the latch keeps it active whenever a > 0, even if n later drops
    # below Ln; per the original author, this is what allows the crash and
    # recovery dynamics.
    #
    l = 0
    if a > 0:
        if not growth.initiate_lysis and n > Ln:        # initiate lysis only when n > Ln
           growth.initiate_lysis = True
        if growth.initiate_lysis:
            # Hill function of a (coefficient ha, half-maximal at a = 1,
            # presumably the MIC in normalized units). It approximates the
            # "binary" above/below-MIC antibiotic effect assumed by the model,
            # and is scaled by g so lysis tracks the growth-limitation factor.
            l = gamma * g * a**ha/(1 + a**ha)

    growth_rate = alpha * g * n
    lysis_rate = beta * l * n

    dndt = growth_rate - lysis_rate
    # Bla is produced only by lysis; it decays at the inhibitor-mediated rate
    # db when inh > 0, otherwise at the basal rate db0.
    # NOTE(review): db0 is not added to db when the inhibitor is present --
    # confirm db is meant to already include basal decay.
    dbdt = lysis_rate - db * b if inh > 0 else lysis_rate - db0 * b
    # Antibiotic is degraded by Bla (kappab * b), by cells (phi * n) and
    # spontaneously (da0). Sensitive cells have kappab = phimax = 0.
    dadt = -(kappab * b  + phi * n + da0) * a

    return [dndt, dbdt, dadt]

# Simulation function.
def simulate(p, n0, a0, inh, *, t=None):
    """Integrate the ``growth`` model for a single experimental condition.

    The default time grid, ``np.linspace(0, 24, 145)`` (0 to 24 in 144 equal
    steps), is chosen to match the time span and sampling of the experimental
    data for the 311 isolates. Because simulated and measured points then line
    up one-to-one, the objective functions can compare cell densities
    directly, without handling a time axis.

    Args:
        p: parameter sequence forwarded to ``growth``.
        n0: initial cell density.
        a0: initial antibiotic concentration.
        inh: inhibitor level forwarded to ``growth`` (only ``inh > 0``
            matters there).
        t: optional keyword-only time grid. Defaults to the experimental grid
            above; pass another grid to compare against differently sampled
            data.

    Returns:
        Array of shape ``(3, len(t))`` whose rows are ``n``, ``b`` and ``a``
        (``odeint`` returns ``(len(t), 3)``; it is transposed here).
    """
    if t is None:
        t = np.linspace(0, 24, 145)
    b0 = 0  # no Bla pool at the start of the experiment
    y0 = [n0, b0, a0]
    # growth() keeps a lysis latch as a function attribute; reset it so every
    # integration starts from the "lysis not yet initiated" state.
    growth.initiate_lysis = False
    sol = odeint(growth, y0, t, args=(p, inh))
    return sol.T

# Objective function.
def objective(p, experimental, *, conditions=((0, 0), (10, 0), (10, 10))):
    """Sum of squared errors between simulated and measured cell densities.

    ``simulate`` returns rows ``n``, ``b`` and ``a``. Only the cell density
    ``n`` (row 0) is compared with the data.

    Args:
        p: parameter sequence forwarded to ``simulate`` / ``growth``.
        experimental: 2-D array-like. ``experimental[i]`` is the measured time
            course of ``n`` for ``conditions[i]``, sampled on the same grid
            that ``simulate`` uses. Its first value is used as the initial
            density ``n0``.
        conditions: keyword-only sequence of ``(a0, inh)`` pairs, one per row
            of ``experimental``. The default reproduces the three conditions
            originally fitted: no drug, drug only, and drug plus inhibitor.
            NOTE(review): the file header lists these conditions as A = 50 and
            inh = 25 rather than 10 / 10 -- presumably rescaled units; confirm.

    Returns:
        Total squared error, summed over all conditions and time points.
    """
    # Accept lists / DataFrames as well as ndarrays for row indexing below.
    experimental = np.asarray(experimental)
    error = 0
    for i, (a0, inh) in enumerate(conditions):
        measured = experimental[i]
        simulated = simulate(p, measured[0], a0, inh)
        error += np.sum((simulated[0] - measured) ** 2)
    return error

#
# Simplified objective used to fit the growth parameters only: it evaluates
# the drug-free condition (a0 = 0, inh = 0) and nothing else.
#
def objective_drug_free(p, experimental):
    """Squared error of the drug-free simulation against ``experimental[0]``.

    ``experimental[0, 0]`` supplies the initial cell density, and only row 0
    of the simulation (cell density ``n``) is compared with the measured
    time course ``experimental[0, :]``.
    """
    measured_n = experimental[0, :]
    predicted = simulate(p, experimental[0, 0], 0, 0)
    residual = predicted[0] - measured_n
    return np.sum(residual**2)
