In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
from functools import partial
import itertools
import os
import json
from decimal import Decimal
import time
from scipy.integrate import solve_ivp
import pandas as pd
import traceback
import copy
import concurrent.futures
import multiprocessing as mp
import sys
import uuid

from scipy.stats import skewnorm
from scipy.signal import chirp, find_peaks, peak_widths
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)


# Render figures inline and set high-resolution, larger-font defaults for every plot.
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
mpl.rcParams.update({'font.size': 13})

In [None]:
def sci_notation(number, sig_fig=2):
    r"""Format ``number`` as a LaTeX-friendly label, e.g. ``3$\cdot$$10^{3}$``.

    Rules, applied in order:
    - integers with magnitude <= 10 are printed plainly (``"5"``);
    - a mantissa equal to 1 is dropped (``"$10^{3}$"``), for any ``sig_fig``;
    - integer mantissas are printed without decimals (``"3$\cdot$$10^{3}$"``);
    - exponent 0 prints only the mantissa (``"2.50"``).

    Parameters
    ----------
    number : float
        Value to format.
    sig_fig : int
        Digits after the decimal point of the mantissa (``"{:.Ne}"`` precision).

    Returns
    -------
    str
    """
    form_string = "{0:.{1:d}e}".format(number, sig_fig)
    a,b = form_string.split("e")
    b = int(b)
    c = r"$10^{"+str(b)+r"}$"
    # abs(): large negative integers (e.g. -1000) must not take the plain path.
    if float(number).is_integer() and abs(number) <= 10:
        return str(int(number))
    # Numeric comparison: the string test a == '1.00' only matched sig_fig=2.
    elif float(a) == 1.0:
        return c
    elif float(a).is_integer():
        return str(int(float(a))) + r"$\cdot$" + c
    elif b == 0:
        return a
    else:
        return a + r"$\cdot$" + c

In [None]:
def axes_tick_formatter(self, MajorTicks, MinorTicks, axis = 'x'):
    """Place major and minor ticks at fixed multiples on one axis of an Axes.

    Parameters
    ----------
    self : matplotlib Axes
        Axes to modify (named ``self`` so the helper reads like a method).
    MajorTicks, MinorTicks : float
        Spacing of the major and minor ticks (``MultipleLocator`` bases).
    axis : {'x', 'y'}
        Axis to format. Any other value prints a message and leaves the
        Axes unchanged.
    """
    if axis not in ('x', 'y'):
        print("Algo mal hay")
        return
    target = self.xaxis if axis == 'x' else self.yaxis
    target.set_major_locator(MultipleLocator(MajorTicks))
    target.set_minor_locator(MultipleLocator(MinorTicks))

In [None]:
def remove_zeros(a):
    """Strip trailing zeros (and then a dangling '.') from a decimal string.

    Only strings containing a decimal point are touched, so integer strings
    keep their significant zeros ("10" stays "10"). The previous recursive
    version turned "10" into "1" and raised IndexError on "0.0" once the
    string was emptied.

    Parameters
    ----------
    a : str
        Decimal representation, e.g. the mantissa of ``"{:.2e}"``.

    Returns
    -------
    str
    """
    s = str(a)
    if '.' not in s:
        return s
    return s.rstrip('0').rstrip('.')

def param_folder_notation(number, sig_fig=2):
    """Encode ``number`` as a filesystem-safe token for folder names.

    Uses scientific notation with trailing zeros removed, then replaces '-'
    by 'm' and '.' by 'p' (e.g. 1500 -> '1p5e3', 1e-3 -> '1em3',
    -2.5 -> 'm2p5'). Zero is encoded as '0'.

    Parameters
    ----------
    number : float
    sig_fig : int
        Digits after the decimal point used before trimming.

    Returns
    -------
    str
    """
    if number == 0.0:
        return '0'

    mantissa, exponent = "{0:.{1:d}e}".format(number, sig_fig).split("e")
    # After trimming, an integral mantissa is already printed without decimals.
    mantissa = remove_zeros(mantissa)
    exponent = int(exponent)

    c = mantissa if exponent == 0 else mantissa + 'e' + str(exponent)
    return c.replace('-','m').replace('.','p')

In [None]:
def m_notch_epsin_coop(t, x, params):
    """Right-hand side of the cooperative Notch/epsin ODE model.

    Follows the ``solve_ivp`` convention ``f(t, y, *args)``; ``t`` is not used.

    Parameters
    ----------
    t : float
        Time (unused).
    x : sequence
        State; the first 10 entries are read in the order
        R, L, E, C, C_L, C_Li, C_R, S, G4, Cit.
    params : mapping
        Rate constants keyed by name (see ``names`` below). ``Ecte``
        multiplies dE/dt, so ``Ecte = 0`` keeps E at its initial value.

    Returns
    -------
    numpy.ndarray
        Time derivatives of the 10 state variables, in the same order.
    """
    names = ("bl", "dl", "kr", "krl", "br", "dr", "kb", "ku", "km", "n", "m",
             "ks2", "kp", "ki", "ks3", "ds", "KG", "TN", "CN", "dG", "dm", "Ecte")
    (bl, dl, kr, krl, br, dr, kb, ku, km, n, m,
     ks2, kp, ki, ks3, ds, KG, TN, CN, dG, dm, Ecte) = (params[name] for name in names)

    R, L, E, C, CL, CLi, CR, S, G4, Cit = (x[idx] for idx in range(10))

    # Shared terms, each written exactly as in the rows that use it so the
    # floating-point evaluation order is unchanged.
    binding = ks2 * km ** (1-m) * C ** n * E ** m   # feeds both C_L and C_R
    release = n * ks3 * CR                          # C_R-driven S and G4 production

    return np.array([
        br - dr * R - kb * (R * L) + ku * C,                                       # R
        bl - dl * L - kb * (R * L) + ku * C + n * krl * kr * CLi,                  # L
        Ecte * (m * kr * CLi - m * ks2 * km ** (1-m) * C ** n * E ** m),           # E
        kb * R * L - ku * C - kp * C - n * ks2 * km ** (1-m) * C ** n * E ** m,    # C
        binding - ki * CL,                                                         # C_L
        ki * CL - kr * CLi,                                                        # C_Li
        binding - ks3 * CR,                                                        # C_R
        release - ds * S,                                                          # S
        release - dG * G4,                                                         # G4
        CN * TN * G4 / (G4 + KG) - dm * Cit,                                       # Cit
    ])

In [None]:
def x_interpolation(r, q, y):
    """Abscissa at which the straight line through ``r`` and ``q`` reaches ``y``.

    Parameters
    ----------
    r, q : sequence of two numbers
        End points ``(x, y)`` of the segment, in any order.
    y : float
        Target ordinate.

    Returns
    -------
    float
        Linearly interpolated x. Divides by zero if both points share the same
        ordinate.
    """
    # Always interpolate from the point with the smaller abscissa.
    left, right = (q, r) if r[0] >= q[0] else (r, q)
    fraction = (y - left[1]) / (right[1] - left[1])
    return left[0] + fraction * (right[0] - left[0])


def _peak_index(np_input_y, max_value):
    """Peak index: rounded mean position of all samples equal to ``max_value``."""
    return int(np.round(np.mean(np.where(np_input_y == max_value)[0])))


def _half_max_times(np_input_x, np_input_y, max_value_index, level):
    """Interpolated times where the peak at ``max_value_index`` falls to ``level``.

    Returns ``(t1, t2)``: t1 is the crossing before the peak and t2 the one
    after. Requires ``np_input_y[-1] < level`` so the right-hand walk ends
    inside the array. If no sample before the peak is at or below ``level``,
    t1 is the first sample's time. The old loop walked to index -1 in that
    case and interpolated between the last and first samples.
    """
    mvi_left = max_value_index - 1
    while mvi_left >= 0 and np_input_y[mvi_left] > level:
        mvi_left -= 1
    if mvi_left < 0:
        t1 = np_input_x[0]
    else:
        r1 = [np_input_x[mvi_left], np_input_y[mvi_left]]
        q1 = [np_input_x[mvi_left + 1], np_input_y[mvi_left + 1]]
        t1 = x_interpolation(r1, q1, level)

    mvi_right = max_value_index + 1
    while np_input_y[mvi_right] > level:
        mvi_right += 1
    r2 = [np_input_x[mvi_right], np_input_y[mvi_right]]
    q2 = [np_input_x[mvi_right - 1], np_input_y[mvi_right - 1]]
    t2 = x_interpolation(r2, q2, level)

    return t1, t2


def _first_crossing_time(np_input_x, np_input_y, level):
    """Interpolated time at which the signal first reaches ``level``.

    Assumes some sample reaches ``level`` (true for a non-negative signal with
    ``level = max / 2``). If the first sample already does, its time is
    returned; the old code paired it with index -1 (the last sample).
    """
    k = 0
    while np_input_y[k] < level:
        k += 1
    if k == 0:
        return np_input_x[0]
    r12 = [np_input_x[k], np_input_y[k]]
    q12 = [np_input_x[k - 1], np_input_y[k - 1]]
    return x_interpolation(r12, q12, level)


def peak_values_old (np_input_x, np_input_y, eps = 0.1):
    """Legacy peak summary measuring the width at half height above the final value.

    The width is computed at ``end + (max - end) / 2`` only when the final value
    differs from the maximum and the relative drop ``(max - end) / max``
    exceeds ``eps``; otherwise ``t1 = t2 = 0``. ``t12`` and ``t_max`` are defined
    as in ``peak_values``.

    Returns
    -------
    tuple
        ``(t1, t2, max_value, end_value, t12, t_max)``
    """
    max_value = max(np_input_y)
    max_value_index = _peak_index(np_input_y, max_value)
    end_value = np_input_y[-1]

    t1 = 0
    t2 = 0
    if end_value != max_value and ((max_value - end_value)/max_value) > eps:
        half_max = end_value + 0.5 * (max_value - end_value)
        t1, t2 = _half_max_times(np_input_x, np_input_y, max_value_index, half_max)

    t12 = _first_crossing_time(np_input_x, np_input_y, 0.5 * max_value)
    t_max = np_input_x[max_value_index]

    return (t1, t2, max_value, end_value, t12, t_max)


def peak_values (np_input_x, np_input_y, eps = 0.1):
    """Summarise the main peak of a sampled time course.

    Parameters
    ----------
    np_input_x, np_input_y : numpy.ndarray
        Sample times and values, same length. The half-maximum searches assume
        a non-negative signal.
    eps : float
        Unused; kept so the signature matches ``peak_values_old``.

    Returns
    -------
    tuple
        ``(t1, t2, max_value, end_value, t12, t_max)``.
        - ``t1``/``t2``: interpolated half-maximum crossings before and after
          the peak. Both are 0 unless the final value is below half the maximum.
        - ``t12``: first time the signal reaches half the maximum.
        - ``t_max``: time of the peak.
    """
    max_value = max(np_input_y)
    max_value_index = _peak_index(np_input_y, max_value)
    end_value = np_input_y[-1]
    half_max = 0.5 * max_value

    t1 = 0
    t2 = 0
    if end_value < half_max:
        t1, t2 = _half_max_times(np_input_x, np_input_y, max_value_index, half_max)

    t12 = _first_crossing_time(np_input_x, np_input_y, half_max)
    t_max = np_input_x[max_value_index]

    return (t1, t2, max_value, end_value, t12, t_max)

def HS_f (a,b,eps):
    """Step function: 1.0 if the ratio ``a / b`` exceeds ``eps``, else 0.0."""
    return 1. if a / b > eps else 0.

In [None]:
def f_RS (peak):
    """Peak height above the final value, relative to the final value: (max - end) / end."""
    height = peak[2] - peak[3]
    return height / peak[3]

def f_RSM (peak):
    """Peak height above the final value, relative to the maximum: (max - end) / max."""
    height = peak[2] - peak[3]
    return height / peak[2]

def f_NRS (peak, params):
    """Peak height (max - end) divided by S0 = br / ds."""
    s0 = params['br'] / params['ds']
    return (peak[2] - peak[3]) / s0

def f_DT (peak, eps):
    """Peak duration ``t2 - t1`` (the half-maximum crossings from peak_values).

    Parameters
    ----------
    peak : tuple
        ``(t1, t2, max_value, end_value, t12, t_max)``.
    eps : float
        Accepted for call compatibility; not used.

    Returns
    -------
    float
        ``t2 - t1`` (not gated by the relative-drop threshold, unlike f_RT).
    """
    # The threshold HS_f(max - end, max, eps) used to be computed here and then
    # discarded. That dead call only raised ZeroDivisionError (Python floats)
    # or warned (NumPy) when the maximum was 0, so it was removed.
    t1 = peak[0]
    t2 = peak[1]
    return t2 - t1

def f_RT (peak, params, eps):
    """Peak duration scaled by the S decay rate, zeroed for insignificant peaks.

    Returns ``HS * (t2 - t1) * ds``. HS is 1.0 when the relative drop
    ``(max - end) / max`` exceeds ``eps`` and 0.0 otherwise.
    """
    ds = params['ds']
    t1, t2, max_val, end_val = peak[0], peak[1], peak[2], peak[3]
    # Same threshold as HS_f(max_val - end_val, max_val, eps), written inline.
    HS = 1. if (max_val - end_val) / max_val > eps else 0.
    return HS * (t2 - t1) * ds

def f_GT(peak, mu, sd):
    """Gaussian score of the peak duration ``t2 - t1`` around ``mu`` with width ``sd``."""
    return f_GS(peak[1] - peak[0], mu, sd)

def f_GS(rs, mu, sd):
    """Unnormalised Gaussian score ``exp(-((mu - rs) / sd)**2)``; equals 1 at ``rs == mu``."""
    return np.exp(-((mu-rs)/sd)**2)

def f_GT12(peak, mu, sd):
    """Gaussian score of the time to half maximum ``t12`` (``peak[4]``)."""
    return f_GS(peak[4], mu, sd)

In [None]:
def solve_resolution(KI, KJ, PPNAMES, PPVALUES, X0, PARAMS, variable, p_gauss, model = m_notch_epsin_coop, param_iter = None,
                    time_l = None, method = 'Radau'):
    """Integrate ``model`` over a 2-D parameter grid and build peak-score maps.

    Parameters
    ----------
    KI, KJ : sequence
        Row and column parameter values. Only their lengths are used here; the
        grid is ``len(KI) x len(KJ)``.
    PPNAMES : list of str
        Names of the scanned parameters (keys of ``PARAMS``).
    PPVALUES : list of sequences
        Values per name. Runs follow ``itertools.product(*PPVALUES)`` order:
        the first parameter varies slowest (row) and the last fastest (column).
    X0 : list
        Either one initial state shared by all runs, or one state per run
        (list/tuple/array of states, in run order).
    PARAMS : dict
        Base parameters; each generated run is a copy updated with its values.
    variable : int
        Index of the state variable whose time course is analysed.
    p_gauss : dict
        Score settings with keys ``mut, sdt, mut12, sdt12, mus, sds, eps``.
    model : callable
        ODE right-hand side ``model(t, x, params)``.
    param_iter : list of 1-tuples ``(params_dict,)``, optional
        Pre-built runs (e.g. with coupled parameters). If None or empty they
        are generated from ``PPNAMES``/``PPVALUES``.
    time_l : [t0, tf], optional
        Integration interval in minutes; defaults to ``[0, 25*60]``.
    method : str
        ``solve_ivp`` method.

    Returns
    -------
    list
        ``[RSM, RT, RM, GSM, GT, GM, RT12, GT12, SMax, Sstat, DT, TMAX]``; each
        entry is a ``len(KI) x len(KJ)`` nested list.

    Raises
    ------
    ValueError
        If the number of runs differs from ``len(KI) * len(KJ)``.
    """
    # None defaults avoid shared mutable default arguments.
    if time_l is None:
        time_l = [0, 25*60]
    n_i, n_j = len(KI), len(KJ)

    if not param_iter:
        param_iter = []
        for values in itertools.product(*PPVALUES):
            param_i = dict(PARAMS)
            param_i.update(zip(PPNAMES, values))
            param_iter.append((param_i,))
    # A plain dict per run, used by f_RT below.
    param_iter_dic = [dict(run[0]) for run in param_iter]

    # Accept any per-run container (the old `type(...) != list` test treated
    # tuples and arrays of states as a single state).
    if isinstance(X0[0], (list, tuple, np.ndarray)):
        sol = [par_solve_ivp(model, time_l, x0_run, method, run)
               for run, x0_run in zip(param_iter, X0)]
    else:
        sol = [par_solve_ivp(model, time_l, X0, method, run) for run in param_iter]

    if len(sol) != n_i * n_j:
        raise ValueError(
            "expected {} runs for a {}x{} grid, got {}".format(n_i * n_j, n_i, n_j, len(sol)))

    def to_grid(flat):
        # Row i holds the runs for KI[i] against every KJ value (product order).
        return [flat[i * n_j:(i + 1) * n_j] for i in range(n_i)]

    T_H = [s.t / 60. for s in sol]  # minutes -> hours
    series = [s.y[variable] for s in sol]
    table_values = to_grid([peak_values(t_h, y) for t_h, y in zip(T_H, series)])
    table_params_dict = to_grid(param_iter_dic)
    SMax = to_grid([max(y) for y in series])
    Sstat = to_grid([y[-1] for y in series])

    mut, sdt = p_gauss["mut"], p_gauss["sdt"]
    mut12, sdt12 = p_gauss["mut12"], p_gauss["sdt12"]
    mus, sds = p_gauss["mus"], p_gauss["sds"]
    eps = p_gauss["eps"]

    RSM = [[f_RSM(tv) for tv in row] for row in table_values]
    RT = [[f_RT(tv, prm, eps) for tv, prm in zip(tv_row, prm_row)]
          for tv_row, prm_row in zip(table_values, table_params_dict)]
    RM = [[rt * rsm for rt, rsm in zip(rt_row, rsm_row)] for rt_row, rsm_row in zip(RT, RSM)]
    GSM = [[f_GS(rsm, mus, sds) for rsm in row] for row in RSM]
    GT = [[f_GT(tv, mut, sdt) for tv in row] for row in table_values]
    GM = [[gt * gsm for gt, gsm in zip(gt_row, gsm_row)] for gt_row, gsm_row in zip(GT, GSM)]
    RT12 = [[tv[4] for tv in row] for row in table_values]
    GT12 = [[f_GT12(tv, mut12, sdt12) for tv in row] for row in table_values]
    DT = [[f_DT(tv, eps) for tv in row] for row in table_values]
    TMAX = [[tv[5] for tv in row] for row in table_values]

    return [RSM, RT, RM, GSM, GT, GM, RT12, GT12, SMax, Sstat, DT, TMAX]

In [None]:
def is_zero(a):
    """Return 1 if ``a`` equals zero, otherwise 0.

    Compares the value itself. The former ``int(a) == 0`` truncated toward
    zero, so any |a| < 1 (e.g. 0.5) was reported as zero, and NaN raised
    ValueError.
    """
    return 1 if a == 0 else 0
is_zero_vec = np.vectorize(is_zero)

# Contour settings for the score maps: map index, contour levels, and colours
# (the same levels/colours used inline in the Fig 4 heatmaps).
contour_values_index=[0]
contour_levels=[[0.7,0.8]]
contour_colours=[["#36454F","#808080"]]


In [None]:
def rk4(f,x,t,dt,params):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step.

    Parameters
    ----------
    f : callable
        ``f(t, x, params)`` returning dx/dt (same signature as the model).
    x : array-like
        State at time ``t``.
    t : float
        Current time.
    dt : float
        Step size.
    params : object
        Passed through to ``f``.

    Returns
    -------
    Estimate of the state at ``t + dt``.
    """
    # Intermediate stages are evaluated at t + dt/2 and t + dt; the previous
    # version used t for every stage, which is wrong for time-dependent f.
    k1 = f(t, x, params)
    k2 = f(t + dt/2, x + k1*dt/2, params)
    k3 = f(t + dt/2, x + k2*dt/2, params)
    k4 = f(t + dt, x + k3*dt, params)

    k = dt * (k1 + 2*k2 + 2*k3 + k4)/6

    return (x + k)

In [None]:
def par_solve_ivp(model, time_l, x_0, meth, arg):
    """Call ``scipy.integrate.solve_ivp`` with a positional, partial-friendly signature.

    The argument order (model, time span, initial state, method, extra args)
    lets callers pre-bind the leading arguments with ``functools.partial``.
    ``arg`` is forwarded as ``args``, the tuple of extra arguments for ``model``.
    """
    return solve_ivp(fun=model, t_span=time_l, y0=x_0, method=meth, args=arg)

In [None]:
# Plot styling: marker and line-style cycles.
mmss=['o','^','+','s','d','+']
llss = ["-", ":", "-.", "--"]

# Display (LaTeX) and file-safe names of the 10 state variables, in the same
# order as the state vector of m_notch_epsin_coop (x[0] .. x[9]).
variables = ["R","L","E","C",r"$C_L$",r"$C_{Li}$",r"$C_R$","S","Gal4","Cit"]
variables_f = ["R","L","E","C",r"C_L",r"C_Li",r"C_R","S","Gal4","Cit"]
# Ligand display names and file-safe short names.
ligands = ["Dll1", r"$Dll4_{ECD}$"+"\n"+r"$Dll1_{ICD}$", "Dll4", r"$Dll1_{ECD}$"+"\n"+r"$Dll4_{ICD}$"]
ligands_f = ["Dll1", "1I4E", "Dll4", "4I1E"]
# One colour per state variable (same order as `variables`) and per ligand.
cm_variables = ["#194569","#800020","#ff7f0e","#9467bd","#c20e35","#fa8fad","#1f77b4","#80bce6","#c8f1c8","#77dd77"]
cm_ligands = ["#194569","#ff7f0e","#800020","#9467bd"]
# Colours for the three beta_L values used in the Fig 4 curves and bars.
color_bars = [cm_variables[5],cm_variables[4],cm_variables[1]]


# Score settings passed to solve_resolution as `p_gauss`:
# (mut, sdt) target/width for the peak duration t2 - t1 (f_GT),
# (mut12, sdt12) for the time to half maximum (f_GT12),
# (mus, sds) for the relative drop (max - end)/max (f_GS),
# eps is passed to f_RT and f_DT.
p_gauss_S = {"mut": 0.8, "sdt": 0.2, "mut12": 0.1, "sdt12":0.02, "mus": 0.9, "sds": 0.15, "eps": 0.1}
p_gauss_Citrine = {"mut": 9.5, "sdt": 4.5, "mut12": 1.5, "sdt12":0.5, "mus": 0.8, "sds": 0.15, "eps": 0.1}

# NOTE(review): not referenced in the visible cells; presumably per-map
# colour ranges. Confirm before relying on it.
range_vals = [[0.,1.], 1, [0.,1.], [0.,1.], [0.,1.], [0.,1.], 1, [0.,1.], 0, 0, 0]


In [None]:
# Reporter half-lives. Presumably in hours: dG/dm below divide ln(2) by
# 60*t12 to get per-minute rates (the solver's time unit).
t12_G = 3.9
t12_m = 3.4

# Baseline parameters of m_notch_epsin_coop; each figure copies this dict and
# overrides a few entries.
params_base = {"bl": 2e0, # L production rate (beta_L in the figures)
    "br": 6e0, # R production rate
    "dl": 2e-3, # L decay rate
    "dr": 2e-3, # R decay rate
    "kb": 2e-4, # R + L -> C binding rate
    "ku": 2e2, # C unbinding rate
    "kp": 5., # C loss rate (kp * C)
    "ks2": 1., # rate constant of the cooperative C/E term
    "kr": 1e-4, # C_Li removal rate; also sets the E return rate (m * kr * C_Li)
#     "kr": 2e-3,
    'krl':0., # scales L return from C_Li (n * krl * kr * C_Li); 0 disables it
    'km':1e3, # scale in km**(1-m) of the cooperative term
    "ki": 1.5e1, # C_L -> C_Li rate
    "ks3": 1e-1, # C_R -> S / G4 rate
    "ds": 8e-3, # S decay rate
    "n":1, # exponent on C in the cooperative term
    "m":1, # exponent on E in the cooperative term
    "KG": 6.6 * 0.1 * 1200., # G4 level for half-maximal Cit production; origin of the factors not documented here
    "t12_G": t12_G,
    "t12_m": t12_m,
    "dG": np.log(2) / (60. * t12_G), # G4 decay rate from its half-life
    "dm": np.log(2) / (60. * t12_m), # Cit decay rate from its half-life
    "E0": 7e2, # initial E (used for x0[2]; not read by the model itself)
    "TN": 1/60., # Cit production factor (CN * TN * G4 / (G4 + KG))
    "CN": 500., # Cit production factor
    "Ecte":1} # multiplies dE/dt; 0 keeps E fixed at E0

# Integration window in minutes (25 h).
t0 = 0.
tf = 25. * 60.
time_l = [t0,tf]

# NOTE(review): the c_max_* values are not used in the visible cells;
# presumably reference maxima per ligand pair. Confirm.
c_max_11 = 0.68
c_max_41 = 0.59

c_max_44 = 1.4
c_max_14 = 3.6



In [None]:
# Project root, resolved only once. Figure cells chdir into sub-folders, so
# re-running this cell must not re-root the notebook there; otherwise the
# folders would nest (e.g. Fig1/Fig1). chdir returns to the root.
if 'main_dir' not in globals():
    main_dir = os.getcwd()
os.chdir(main_dir)


# Fig 1

In [None]:
filename0 = 'Fig1'
# Output folder for this figure (later cells save into the cwd).
# NOTE: `dir` shadows the builtin; the name is kept for compatibility.
dir = os.path.join(main_dir, filename0)
# exist_ok avoids the check-then-create race and is a no-op if it exists.
os.makedirs(dir, exist_ok=True)
os.chdir(dir)


## Panel B

In [None]:

# Schematic: a transient pulse (skew-normal pdf, "Dll1") and a sustained,
# saturating curve ("Dll4") drawn on two overlaid axes in one figure.
fig_transient=plt.figure()
ax_transient=fig_transient.add_subplot(111, label="1")
# Frameless second axes on top, so the sustained curve can use a log x-axis.
ax_sustained=fig_transient.add_subplot(111, label="2", frame_on=False)


a = 7  # skewness parameter of the pulse
# NOTE(review): mean/var/skew/kurt, peaks and results_half are computed but not
# used below. The arrow height 0.37616097 is hard-coded; presumably it is the
# pulse's half maximum (cf. results_half). Confirm.
mean, var, skew, kurt = skewnorm.stats(a, moments='mvsk')
x = np.linspace(skewnorm.ppf(0.0001, a),
                skewnorm.ppf(0.999, a), 100)

peaks, _ = find_peaks(skewnorm.pdf(x, a))
results_half = peak_widths(skewnorm.pdf(x, a), peaks, rel_height=0.5)

# Sustained response: X / (0.01 + X) on a log scale.
X = np.logspace(-3.,5.,int(1e3))
Y1 = 1*X/(0.01+1*X)
ax_sustained.plot(X,Y1, label = "Dll4", lw = 3, color="#386261", zorder = 1)
ax_sustained.set_xscale('log')

# Two opposite arrows form a double-headed marker of the pulse width.
ax_transient.arrow(0.0,0.37616097,1.10,0, head_width=0.03, head_length=0.1, color="orange", zorder = 1)
ax_transient.arrow(1.23,0.37616097,-1.12,0, head_width=0.03, head_length=0.1, color="orange", zorder = 1)
ax_transient.plot(x, skewnorm.pdf(x, a),label = "Dll1", lw = 3,color = '#C3B1E1', zorder = 5)

ax_sustained.set_yticks([])
ax_sustained.set_xticks([])

# Merge the handles of both axes into a single figure-level legend.
lines_1, labels_1 = ax_transient.get_legend_handles_labels()
lines_2, labels_2 = ax_sustained.get_legend_handles_labels()

lines = lines_1 + lines_2
labels = labels_1 + labels_2
fig_transient.legend(lines, labels, loc = (0.65,0.68), fontsize = 15)

ax_transient.set_yticks([])
ax_transient.set_xticks([])
# Hard-coded annotation of the pulse width.
ax_transient.text(0.22,0.32,r"$\Delta t \approx 13$ h", fontsize = 15)
ax_transient.set_ylabel('Activity', fontsize = 15)
ax_transient.set_xlabel('Time', fontsize = 15)

# Also clear the minor ticks that the log scale adds.
ax_sustained.set_xticks([])
ax_sustained.set_xticks([], minor = True)

fig_transient.savefig(filename0 + "_B.png", bbox_inches="tight", dpi = 300, transparent = True)


## Panel C

In [None]:
fig_lebon = plt.figure()
ax_lebon = fig_lebon.add_subplot(111, label="1")

# Main panel: two X**2 activation curves that differ in amplitude
# (0.8*0.8 vs 0.2*0.8) and in the denominator constant (1 vs 100).
X = np.logspace(-3.,5.,int(1e3))
Y1 = 0.8*0.8*X**2/(1+0.8*X**2)
Y2 = 0.2*0.8*X**2/(100+0.8*X**2)

ax_lebon.plot(X,Y1,lw = 3, color = '#800020')
ax_lebon.plot(X,Y2,lw = 3, color = '#72BEB7')
ax_lebon.set_xscale('log')


ax_lebon.set_ylim([-0.05,1])

ax_lebon.set_ylabel('Activity', fontsize = 15)
ax_lebon.set_xlabel('Ligand Concentration', fontsize = 15)

ax_lebon.legend(['Ligand 1', 'Ligand 2'], loc =(0.55,0.45), fontsize = 15)


# Inset: the same curves without the amplitude prefactor (both saturate at 1).
Y1 = 0.8*X**2/(1+0.8*X**2)
Y2 = 0.8*X**2/(100+0.8*X**2)

# Inset at [x0, y0, width, height] = [0.08, 0.6, 0.3, 0.3] in ax_lebon's axes
# coordinates. Axes.inset_axes replaces the axes_grid1 inset_axes +
# InsetPosition pair (InsetPosition is deprecated since Matplotlib 3.8).
axins = ax_lebon.inset_axes([0.08, 0.6, 0.3, 0.3])

# Dashed guides: the 0.5 level and hard-coded x positions near each curve's
# half maximum (analytically sqrt(1/0.8) ~ 1.12 and sqrt(100/0.8) ~ 11.2).
axins.hlines(0.5, 0.001, 100000., color = 'orange', lw = 1, linestyles = 'dashed')
axins.vlines(1.15, 0.0, 1., color = '#800020', lw = 1, linestyles = 'dashed')
axins.vlines(12.1, 0.0, 1., color = '#72BEB7', lw = 1, linestyles = 'dashed')
axins.plot(X,Y1,lw = 2, color = '#800020')
axins.plot(X,Y2,lw = 2, color = '#72BEB7')
axins.set_xscale('log')

axins.set_yticks([])
axins.set_xticks([])

axins.set(ylabel='Normalised Activity')
axins.set(xlabel='Ligand Concentration')

ax_lebon.set_yticks([])
ax_lebon.set_xticks([])
ax_lebon.set_xticks([], minor = True)

fig_lebon.savefig(filename0 + "_C.png", bbox_inches="tight", dpi = 300, transparent = True)


## Panel D

In [None]:

param_iter = []
list_dicts = []

# Single parameter set for this panel: n = m = 5, E0 = 3e3. 'Label' is a
# LaTeX legend string built from km and E0.
params = dict(params_base)
params.update({'m':5,'n':5, 'E0':3e3})
params.update({'Label': r'$k_m = $' + sci_notation(params['km']) 
               + '\n' + r'$E_0 = $' + sci_notation(params['E0'])})
list_dicts.append(params)


# beta_L (bl) scan: 300 log-spaced values.
k_i = np.logspace(-3,2.5,300)

p_p_names = ["bl"]
p_p_values = [k_i]

p_p_values_iter = list(itertools.product(*p_p_values))

p_p_values_iter.sort(key=lambda a: a[0])

x_0 = [] 

# One run per (parameter set, bl). The initial state is stored both in the
# run's params (key 'x0') and in x_0.
for j in range(len(list_dicts)):
    params = dict(list_dicts[j])
    for  e in p_p_values_iter:
        param_i = dict(params)
        param_i.update(zip(p_p_names, e))
        # Initial L from its production/decay balance bl/dl. A former
        # `if j == 1:` branch recomputed this identical value (dead code).
        L0 = round(param_i["bl"]/param_i["dl"],2)

        # R and L at b/d, E at E0, all other species zero.
        x_0_i = [round(param_i["br"]/param_i["dr"],2), 
           L0,
           round(param_i["E0"],2),
           0., 0., 0., 0., 0., 0., 0.]
        param_i.update({'x0': x_0_i})
        param_iter.append(tuple([dict(param_i)]))
        x_0.append(x_0_i)


In [None]:
    
# Serial integration: one solve_ivp run per (parameters, initial state) pair.
partial_iter = partial(par_solve_ivp, m_notch_epsin_coop, time_l)

# Worker count for the optional multiprocessing variant (not executed):
#   ncpus = mp.cpu_count() - 2
#   def partial_iter_mp(param_iter):
#       X_0 = param_iter[0]['x0']
#       return partial_iter(X_0,'Radau',param_iter)
#   with mp.Pool(ncpus) as pool:
#       sol = pool.map(partial_iter_mp, param_iter)
ncpus = None

sol = [partial_iter(x0_run, 'Radau', run_args)
       for run_args, x0_run in zip(param_iter, x_0)]


In [None]:
vv = 9  # state index of Cit
# Final value and peak value of Cit for every run.
cit_stat = [run.y[vv][-1] for run in sol]
cit_max = [max(run.y[vv]) for run in sol]


In [None]:

# Final Cit versus beta_L for the single parameter set, restricted to scan
# samples start:end.
fig,ax = plt.subplots()

i = 0  # index into list_dicts (only one parameter set here)
start = 90
end = 205
ax.plot(k_i[start:end], cit_stat[i*len(k_i):(i+1)*len(k_i)][start:end],color = cm_variables[0], label = 'Dll1')

ax.legend()

ax.set_xscale('log')
ax.set_yticks([])
ax.set_xticks([])
ax.set_xticks([], minor = True)

ax.set_ylabel('Activity', fontsize = 15)
ax.set_xlabel('Ligand Concentration', fontsize = 15)

fig.savefig('Fig1_D.png', bbox_inches="tight", dpi = 300, transparent = True)


# Fig 4

In [None]:
filename0 = 'Fig4'
# Output folder for this figure (later cells save into the cwd).
# NOTE: `dir` shadows the builtin; the name is kept for compatibility.
dir = os.path.join(main_dir, filename0)
# exist_ok avoids the check-then-create race and is a no-op if it exists.
os.makedirs(dir, exist_ok=True)
os.chdir(dir)


## Panels A & B

In [None]:

param_iter = []
list_dicts = []

# Candidate parameter sets (n = m = 5). Only T1_km1 and T1 are plotted below;
# the S1/S3 variants are not referenced in the visible cells.
params_T1 = dict(params_base)
params_T1.update({"E0":7e2, 'n':5, 'm':5, "kr":1e-4, 'Case': 'T1_n5m5'})
params_T1_km1 = dict(params_base)
params_T1_km1.update({"E0":7e2, 'n':5, 'm':5, "kr":1e-4,'km':1, 'Case': 'T1_n5m5_km1'})
params_S1 = dict(params_base)
params_S1.update({"E0":7e2, 'n':5, 'm':5, "kr":2e-3, 'Case': 'S1_n5m5'})
params_S1_km1 = dict(params_base)
params_S1_km1.update({"E0":7e2, 'n':5, 'm':5, "kr":2e-3,'km':1, 'Case': 'S1_n5m5_km1'})
params_S3 = dict(params_base)
params_S3.update({"E0":3e3, 'n':5, 'm':5, "kr":2e-3, 'Case': 'S3_n5m5'})

# Panel A uses km = 1, panel B the base km.
list_dicts.extend([params_T1_km1,params_T1])
figs = ['A', 'B']

k_i = [1,2,4]  # beta_L values

p_p_names = ["bl"]
p_p_values = [k_i]

p_p_values_iter = list(itertools.product(*p_p_values))

p_p_values_iter.sort(key=lambda a: a[0])

# One run per (parameter set, beta_L); set-major order, which the plotting
# cell relies on (index j*len(k_i)+i).
for j in range(len(list_dicts)):
    params = dict(list_dicts[j])
    for  e in p_p_values_iter:
        param_i = dict(params)
        param_i.update(zip(p_p_names, e))
        param_iter.append(tuple([dict(param_i)]))

# Initial states: R and L at b/d, E at E0, everything else zero.
x_0 = [[round(e[0]["br"]/e[0]["dr"],2), 
       round(e[0]["bl"]/e[0]["dl"],2),
       round(e[0]["E0"],2),
       0., 0., 0., 0., 0., 0., 0.] for e in param_iter]


In [None]:
partial_iter = partial(par_solve_ivp, m_notch_epsin_coop, time_l)
sol = [partial_iter(x0_run, 'Radau', run_args) for run_args, x0_run in zip(param_iter, x_0)]

# Solver time is in minutes; convert to hours for plotting.
T_H = [run.t / 60. for run in sol]


In [None]:
# One figure per parameter set: Cit (solid) and S (dashed) for each beta_L.
for j in range(len(list_dicts)):
    fig, ax = plt.subplots()
    for i in range(len(k_i)):
        ax.plot(T_H[j*len(k_i)+i], sol[j*len(k_i)+i].y[9],c = color_bars[i])  # y[9] = Cit
        ax.plot(T_H[j*len(k_i)+i], sol[j*len(k_i)+i].y[7], ls = '--',c = color_bars[i])  # y[7] = S

    # Legend mapping line style to variable, built from empty proxy lines.
    # NOTE: this rebinds the module-level `llss` from the constants cell. The
    # comprehension's `j` is local to it and does not clobber the loop's `j`.
    llss = ['-','--']
    labels2 = [r"$Cit$",r"$S$"]
    h2 = [plt.plot([],[], color='#161616', ls=llss[j])[0] for j in range(len(llss))]
    leg2 = ax.legend(handles=h2, labels=labels2, fontsize = 12, facecolor='white', framealpha=1, loc = (0.8,0.83))
    
    # Legend mapping colour to beta_L.
    labels1 = [r'$\beta_L$ = ' + str(int(k_i[i])) for i in range(len(k_i))]
    h1 = [plt.plot([],[], color=color_bars[j], ls=llss[0])[0] for j in range(len(k_i))]
    leg1= ax.legend(handles=h1, labels=labels1, fontsize = 12, facecolor='white', framealpha=1, loc = (0.743,0.58))

    # Each ax.legend call replaces the previous one, so add both as artists.
    ax.add_artist(leg1)
    ax.add_artist(leg2)

    ax.set_ylabel("Cit, S [a.u.]", fontsize=15)
    
    ax.set_xlabel("t [h]", fontsize=15)
    
    axes_tick_formatter(ax, 5, 2.5)
    ax.yaxis.set_minor_locator(AutoMinorLocator(n=2))

    # NOTE(review): saves 'Fig4A.png' / 'Fig4B.png' (no underscore), unlike the
    # other panels ('Fig4_C.png', ...). Confirm the intended names.
    filename = filename0 + figs[j]
    
    fig.savefig(filename + '.png', bbox_inches="tight", dpi = 300, transparent = True)


## Panel C

In [None]:
# Scan kr (rows) x E0 (columns) at n = m = 6.
params = dict(params_base)
params.update({'n':6,'m':6})

k_i = np.logspace(-5,-2,50)
k_j = np.logspace(0,4,50)

p_p_names = ["kr","E0"]
p_p_values = [k_i, k_j]

# Must follow the same itertools.product order that solve_resolution uses,
# because x_0 below is matched to its runs by position. The stable sort on the
# (already ascending) first value keeps that order.
p_p_values_iter = list(itertools.product(*p_p_values))

p_p_values_iter.sort(key=lambda a: a[0])

XLabel = r"$k_r\ [min^{-1}]$"
YLabel = r"$E_0$ [a.u.]"
save=1  # NOTE(review): not used in the visible cells

# Not passed to solve_resolution below (it rebuilds the runs from p_p_values).
param_iter = []

for  e in p_p_values_iter:
    param_i = dict(params)
    param_i.update(zip(p_p_names, e))
    param_iter.append(tuple([dict(param_i)]))


# One initial state per run; E starts at that run's E0 (e[1]).
x_0 = [[round(params["br"]/params["dr"],2), 
       round(params["bl"]/params["dl"],2),
       round(e[1],2),
       0., 0., 0., 0., 0., 0., 0.] for e in p_p_values_iter]


In [None]:

# Score maps for Cit (state index 9) over the kr x E0 grid.
maps_Citrine = solve_resolution(k_i, k_j, p_p_names, p_p_values, x_0, params,
                                variable=9, p_gauss=p_gauss_Citrine)


In [None]:
fig, ax = plt.subplots()

# Copy so set_bad does not alter the shared colormap. plt.get_cmap is used
# because mpl.cm.get_cmap was removed in Matplotlib 3.9.
cm = copy.copy(plt.get_cmap('Blues'))
cm.set_bad(color= '#d3d3d3')

X, Y = np.meshgrid(k_i,k_j, indexing = 'ij')

Z = maps_Citrine[0]  # RSM map: (max - end) / max of Cit

# Colour limits over all entries, ignoring NaNs (drawn in the "bad" colour).
# The previous filter tested Z[0][j] (row 0) for every row instead of
# Z[i][j], and a NaN could poison min()/max().
Z_arr = np.asarray(Z, dtype=float)
VMIN = round(np.nanmin(Z_arr),1)
VMAX = round(np.nanmax(Z_arr),1)

TICKS =  np.around(np.linspace(VMIN,VMAX, 5),2)

c = ax.pcolormesh(X, Y ,Z, cmap=cm, vmin = VMIN, vmax = VMAX)

CS =ax.contour(X, Y, Z, levels = [.7,.8], colors=["#36454F","#808080"])

cbar = fig.colorbar(c, ticks = TICKS)
cbar.set_label(r"$\rho$", fontsize = 13)

ax.set_yscale('log')
ax.set_xscale('log')

ax.set_xlabel(XLabel, fontsize = 13)
ax.set_ylabel(YLabel, fontsize = 13)


ax.set_box_aspect(1)

fig.tight_layout()

fig.savefig(filename0 + '_C.png', bbox_inches="tight", dpi = 300, transparent = True)


## Panel D

In [None]:

params = dict(params_base)
params.update({'kr':2e-3})

k_i = [i+1 for i in range(15)]  # n (= m) from 1 to 15

k_j = np.logspace(-2,2,50)  # kp values
p_p_names = ["n","kp"]

p_p_values = [k_i, k_j]

p_p_values_iter = list(itertools.product(*p_p_values))

p_p_values_iter.sort(key=lambda a: a[0])

XLabel = r"n=m"
YLabel = r"$k_{P}\ [min^{-1}]$"

param_iter = []

for  e in p_p_values_iter:
    param_i = dict(params)
    param_i.update(zip(p_p_names, e))
    param_iter.append(tuple([dict(param_i)]))

save = 1  # NOTE(review): not used in the visible cells

# The scan is over n = m: tie m to n in every run. This is why param_iter is
# passed explicitly to solve_resolution below.
for e in param_iter:
    N = e[0]["n"]
    e[0].update({"m": N})
    
# A single initial state shared by all runs (E0 is not scanned here).
x_0 = [round(params["br"]/params["dr"],2), 
       round(params["bl"]/params["dl"],2),
       round(params["E0"],2),
       0., 0., 0., 0., 0., 0., 0.]


In [None]:

# Score maps for Cit (state index 9); runs come pre-built with m tied to n.
maps_Citrine = solve_resolution(k_i, k_j, p_p_names, p_p_values, x_0, params,
                                variable=9, p_gauss=p_gauss_Citrine,
                                param_iter=param_iter)


In [None]:
fig, ax = plt.subplots()

# Copy so set_bad does not alter the shared colormap. plt.get_cmap is used
# because mpl.cm.get_cmap was removed in Matplotlib 3.9.
cm = copy.copy(plt.get_cmap('Blues'))
cm.set_bad(color= '#d3d3d3')

X, Y = np.meshgrid(k_i,k_j, indexing = 'ij')

Z = maps_Citrine[0]  # RSM map: (max - end) / max of Cit

# Colour limits over all entries, ignoring NaNs (drawn in the "bad" colour).
# The previous filter tested Z[0][j] (row 0) for every row instead of Z[i][j].
Z_arr = np.asarray(Z, dtype=float)
VMIN = round(np.nanmin(Z_arr),1)
VMAX = round(np.nanmax(Z_arr),1)

TICKS =  np.around(np.linspace(VMIN,VMAX, 5),2)

c = ax.pcolormesh(X, Y ,Z, cmap=cm, vmin = VMIN, vmax = VMAX)

CS =ax.contour(X, Y, Z, levels = [.7,.8], colors=["#36454F","#808080"])

cbar = fig.colorbar(c, ticks = TICKS)
cbar.set_label(r"$\rho$", fontsize = 13)

ax.set_yscale('log')

ax.set_xlabel(XLabel, fontsize = 15)
ax.set_ylabel(YLabel, fontsize = 15)

# Markers for the four cases C_1..C_4 simulated in Panel E
# (n = m in {1, 6}, kp in {2e-1, 2e1}).
ax.text(1.5,2e1,r'$C_1$',color='#DC582A')
ax.scatter(1,2e1, marker = '+',color='#DC582A')
ax.text(1.5,2e-1,r'$C_2$',color='#DC582A')
ax.scatter(1,2e-1, marker = '+',color='#DC582A', zorder = 3)
ax.text(6.5,2e1, r'$C_3$',color='#DC582A')
ax.scatter(6,2e1, marker = '+',color='#DC582A', zorder = 3)
ax.text(6.5,2e-1,r'$C_4$',color='#DC582A')
ax.scatter(6,2e-1, marker = '+',color='#DC582A', zorder = 3)


# x ticks: majors every 5, minors every 1. A preceding set_xticks call was
# dropped because axes_tick_formatter replaced its locator immediately.
ax.set_xlim([0,16])
axes_tick_formatter(ax, 5, 1)

ax.set_box_aspect(1)

fig.tight_layout()

fig.savefig(filename0 + '_D.png', bbox_inches="tight", dpi = 300, transparent = True)


## Panel E

In [None]:
params = dict(params_base)

k_i = [1,2,4]  # beta_L values

param_iter_C = []
list_dics_C = []

# The four cases marked C_1..C_4 on the Panel D map
# (n = m in {1, 6}, kp in {2e1, 2e-1}), all with kr = 2e-3.
list_dics_C.append({'kr':2e-3,'kp': 2e1, 'n': 1, 'm': 1, 'Case':r'$C_1$'})
list_dics_C.append({'kr':2e-3,'kp': 2e-1, 'n': 1, 'm': 1, 'Case':r'$C_2$'})
list_dics_C.append({'kr':2e-3,'kp': 2e1, 'n': 6, 'm': 6, 'Case':r'$C_3$'})
list_dics_C.append({'kr':2e-3,'kp': 2e-1, 'n': 6, 'm': 6, 'Case':r'$C_4$'})
        
# One run per (case, beta_L), in case-major order.
for e in list_dics_C:
    for i in k_i:
        param_i = dict(params)
        param_i.update(e)
        param_i.update({'bl':i})
        param_iter_C.append(tuple([dict(param_i)]))
    
labels_C = [e[0]["Case"] for e in param_iter_C]

# Initial states: R and L at their b/d balance (br, dr, dl from params_base,
# which the cases do not override), E at E0, everything else zero.
x_0_C = [[round(params_base["br"]/params_base["dr"],2), 
       round(e[0]["bl"]/params_base["dl"],2),
       round(e[0]["E0"],2),
       0., 0., 0., 0., 0., 0., 0.] for e in param_iter_C]


In [None]:
partial_iter = partial(par_solve_ivp, m_notch_epsin_coop, time_l)

# Integrate each case; solver time is in minutes, converted to hours.
sol_C = [partial_iter(x0_run, 'Radau', run_args) for run_args, x0_run in zip(param_iter_C, x_0_C)]
T_H_C = [run.t / 60. for run in sol_C]

In [None]:

# Peak Cit (state index 9) of each Panel E run, tabulated with its parameters.
Cit_Max_C = [max(run.y[9]) for run in sol_C]

columns_C = ['bl','E0','kr', 'Case','Cit_Max']
data_C = [[run_args[0]['bl'], run_args[0]['E0'], run_args[0]['kr'], run_args[0]['Case'], cit_peak]
          for run_args, cit_peak in zip(param_iter_C, Cit_Max_C)]
df_C = pd.DataFrame(data_C,columns = columns_C)


In [None]:

fig,ax = plt.subplots()

# Grouped bars: one group per case (rows of the pivot), one bar per beta_L.
df_C.pivot(index='Case',columns='bl',values='Cit_Max').plot(kind='bar', color=color_bars, rot = 0, ax = ax)
ax.legend([r"$\beta_L = $" +str(i) for i in k_i ], loc = 4, facecolor = 'white', framealpha=1)
ax.set_ylabel(r'$Cit_{Max}$ [a.u]')
ax.set_xlabel(r'')
axes_tick_formatter(ax, 200, 100, axis = 'y')

fig.tight_layout()

fig.savefig(filename0 + '_E.png', bbox_inches="tight", dpi = 300, transparent = True)


## Panel F

In [None]:

param_iter = []
list_dicts = []

# Two variants of one parameter set: E dynamic (Ecte = 1 from params_base)
# versus E held at E0 (Ecte = 0).
params_0 = dict(params_base)
params_0.update({'kr':2e-3,'kp': 2e1, 'n': 6, 'm': 6})

params = dict(params_0)
params.update({'Label': r'$E = E(t)$'})
list_dicts.append(params)

params = dict(params_0)
params.update({'Ecte':0,'Label': r'$E = E_0$'})
list_dicts.append(params)


k_i = np.logspace(-2,3,300)  # beta_L scan

p_p_names = ["bl"]
p_p_values = [k_i]

p_p_values_iter = list(itertools.product(*p_p_values))

p_p_values_iter.sort(key=lambda a: a[0])

x_0 = [] 

# One run per (variant, bl), variant-major. Each initial state is stored both
# in the run's params (key 'x0') and in x_0.
for j in range(len(list_dicts)):
    params = dict(list_dicts[j])
    for  e in p_p_values_iter:
        param_i = dict(params)
        param_i.update(zip(p_p_names, e))
        # Initial L from its production/decay balance bl/dl. A former
        # `if j == 1:` branch recomputed this identical value (dead code).
        L0 = round(param_i["bl"]/param_i["dl"],2)

        # R and L at b/d, E at E0, all other species zero.
        x_0_i = [round(param_i["br"]/param_i["dr"],2), 
           L0,
           round(param_i["E0"],2),
           0., 0., 0., 0., 0., 0., 0.]
        param_i.update({'x0': x_0_i})
        param_iter.append(tuple([dict(param_i)]))
        x_0.append(x_0_i)


In [None]:

# Serial integration: one solve_ivp run per (parameters, initial state) pair.
partial_iter = partial(par_solve_ivp, m_notch_epsin_coop, time_l)

# Worker count for the optional multiprocessing variant (not executed):
#   ncpus = mp.cpu_count() - 2
#   def partial_iter_mp(param_iter):
#       X_0 = param_iter[0]['x0']
#       return partial_iter(X_0,'Radau',param_iter)
#   with mp.Pool(ncpus) as pool:
#       sol = pool.map(partial_iter_mp, param_iter)
ncpus = None

sol = [partial_iter(x0_run, 'Radau', run_args)
       for run_args, x0_run in zip(param_iter, x_0)]


In [None]:
vv = 9  # state index of Cit
# Final value and peak value of Cit for every run.
cit_stat = [run.y[vv][-1] for run in sol]
cit_max = [max(run.y[vv]) for run in sol]


In [None]:

fig,ax = plt.subplots()

# Per variant: final Cit (solid) and peak Cit (dashed) versus beta_L.
for i in range(len(list_dicts)):
    ax.plot(k_i, cit_stat[i*len(k_i):(i+1)*len(k_i)],color = cm_variables[i])
    ax.plot(k_i, cit_max[i*len(k_i):(i+1)*len(k_i)],color = cm_variables[i], ls = '--')
       
# Legend mapping line style to statistic, built from empty proxy lines.
# NOTE: rebinds the module-level `llss` from the constants cell.
llss = ['-','--']
labels2 = [variables[vv] + r"$_{Stat}$", variables[vv] + r"$_{Max}$"]
h2 = [plt.plot([],[], color='#161616', ls=llss[j])[0] for j in range(len(llss))]
leg2 = ax.legend(handles=h2, labels=labels2, fontsize = 12, facecolor='white', framealpha=1, loc = (0.036,0.79))

# Legend mapping colour to variant label.
labels1 = [list_dicts[i]['Label'] for i in range(len(list_dicts))]
h1 = [plt.plot([],[], color=cm_variables[j], ls=llss[0])[0] for j in range(len(list_dicts))]
leg1= ax.legend(handles=h1, labels=labels1, fontsize = 12, facecolor='white', framealpha=1, loc = (0.036,0.6))   

# Each ax.legend call replaces the previous one, so add both as artists.
ax.add_artist(leg1)
ax.add_artist(leg2)

ax.set_box_aspect(0.7)
ax.set_xscale('log')
ax.set_xlabel(r'$\beta_L\ [min^{-1}]$')
ax.set_ylabel(variables[vv] + ' [a.u.]')

axes_tick_formatter(ax, 300, 150, axis = 'y')

fig.savefig(filename0 + '_F.png', bbox_inches="tight", dpi = 300, transparent = True)
