In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import itertools
import json
import numpy as np

from matplotlib.gridspec import GridSpec
from numba import njit
from scipy.optimize import minimize, Bounds
from scipy.stats.qmc import Sobol
from sympy import symbols, solve, lambdify
from prototypefour import *

In [2]:
with open("data.json") as f:
    # NOTE(review): the unpacking relies on the top-level keys of data.json
    # appearing in exactly this order (json preserves file order) — confirm
    # against the file if it is ever regenerated.
    POSITION, LENGTH, BES1, CPD, ROT3, AUXIN = map(np.array, json.load(f).values())

# Degree-4 polynomial maps between time (POSITION column 0) and position
# (POSITION column 1), fitted once in each direction
time_to_position = np.poly1d(np.polyfit(POSITION[:, 0], POSITION[:, 1], 4))
position_to_time = np.poly1d(np.polyfit(POSITION[:, 1], POSITION[:, 0], 4))

# Simulation time grid: 0..T_END inclusive, spaced STEP apart.
# np.arange with a non-integer step can include an extra sample just past the
# intended end point (the length is ceil((stop - start) / step) computed in
# floating point); linspace with an explicit point count avoids that.
STEP = 0.01
T_END = 18
vT = np.linspace(0, T_END, int(round(T_END / STEP)) + 1)
vP = time_to_position(vT)

# Remove outliers from the BES1 data (threshold is in the raw units, before
# the values are divided by 100 further below)
BES1 = BES1[BES1[:, 1] < 100]

# Filter the data to only include positions between 150 and 600um, transform
# positions into times, and divide the values by 100 so that the numerics of
# the model are less erratic.
def filter_transform(data, lower=150, upper=600, scale=100, to_time=None):
    """Keep rows whose position lies strictly inside (lower, upper), then rescale them.

    Parameters
    ----------
    data : 2-D array
        Rows of (position in um, value).
    lower, upper : float
        Exclusive position bounds (defaults 150 and 600 um).
    scale : float
        Divisor applied to the value column (default 100).
    to_time : callable, optional
        Maps positions to times; defaults to the module-level ``position_to_time``.

    Returns
    -------
    numpy.ndarray
        A new float64 array of (time, value / scale) rows; ``data`` is not modified.
    """
    if to_time is None:
        to_time = position_to_time

    positions = data[:, 0]
    # Boolean-mask indexing copies the rows; astype guarantees a float array so
    # the column writes below are not truncated when `data` holds integers.
    kept = data[(positions > lower) & (positions < upper)].astype(np.float64)
    kept[:, 0] = to_time(kept[:, 0])
    kept[:, 1] = kept[:, 1] / scale
    return kept

# Restrict both datasets to the 150-600um window, convert positions to times
# and rescale the values by 1/100 (BES1 first, then LENGTH)
BES1, LENGTH = filter_transform(BES1), filter_transform(LENGTH)

In [3]:
def setup(config):
    """Build the jitted right-hand sides (dS, dC, dL) for one model variant.

    Parameters
    ----------
    config : (scaled, tfbs) pair of bools
        scaled -- multiply the length growth rate by the current length L.
        tfbs   -- divide the length growth rate by (1 + g_C * C / L).

    Returns
    -------
    [dS, dC, dL]
        Each takes the five free parameters (c_max, p, q, g_B, g_C) followed by
        its state arguments: dS(..., B, S, C), dC(..., S, C), dL(..., S, C, L).
    """
    scaled, tfbs = config

    # Free parameters, in the order every generated function expects them
    params = symbols('c_max p q g_B g_C')
    c_max, p, q, g_B, g_C = params

    # Constants carried over from prototype-4a (s_0 is not used by the equations here)
    s_0, s_in, s_out, s_C = 0.0517, 0.0157, 0, 0

    # State symbols: B (input series), S (signal), C (CLASP), L (length)
    b_sym, s_sym, c_sym, l_sym = symbols('B S C L')

    if scaled:
        growth_scale = l_sym
    else:
        growth_scale = 1

    if tfbs:
        growth_damping = 1 + (g_C * c_sym / l_sym)
    else:
        growth_damping = 1

    dS = njit(lambdify([*params, b_sym, s_sym, c_sym],
                       s_in * (1 + s_C * c_sym) * b_sym - s_out * s_sym))
    dC = njit(lambdify([*params, s_sym, c_sym],
                       c_max - p * (1 + q * s_sym) * c_sym))
    dL = njit(lambdify([*params, s_sym, c_sym, l_sym],
                       (g_B * s_sym * growth_scale) / growth_damping))

    return [dS, dC, dL]

# Every (scaled, tfbs) switch combination, each paired with a readable label
config_space = [
    (False, False),  # Base-Model
    (True, False),   # L-Scaled
    (False, True),   # TFBs
    (True, True),    # TFBs, L-Scaled
]
descriptions = ["Base-Model", "L-Scaled", "TFBs", "TFBs, L-Scaled"]

# One entry per variant: (config, description, [dS, dC, dL])
simulations = [
    (cfg, label, setup(cfg))
    for cfg, label in zip(config_space, descriptions)
]

In [4]:
# Error metric comparing a simulated trajectory (result) with observations (data)
@njit
def RSS(data, result):
    """Mean squared error between a simulated trajectory and observed data.

    Despite the name, the sum of squared residuals is divided by the number of
    observations, so the value returned is a mean (MSE), not a raw sum.

    data   : 2-D array; column 0 holds observation times, column 1 observed values.
    result : simulated values sampled on the module-level time grid ``vT``.
    """
    times, observed = data[:, 0], data[:, 1]
    diff = np.interp(times, vT, result) - observed
    return np.sum(np.square(diff)) / diff.shape[0]

# Run a complete simulation
@njit
def simulate_cell(params, dS, dC, dL):
    """Integrate one model variant over the time grid with forward Euler.

    Parameters
    ----------
    params : array of 5 floats
        (c_max, p, q, g_B, g_C), in the order used by ``setup``.
    dS, dC, dL : jitted callables
        Right-hand sides for signal, CLASP and length as returned by ``setup``.

    Returns
    -------
    ([vB, vS, vC, vL], error)
        The input from ``get_br`` and the simulated signal, CLASP and length
        trajectories (one entry per element of vB), plus ``RSS(LENGTH, vL)``.
    """
    c_max, p, q, g_B, g_C = params
    s_0 = 0.0517  # initial signal level, fixed from prototype-4a

    # NOTE: inside an @njit function, module-level globals (vP, CPD, ROT3,
    # LENGTH, STEP) are frozen at compile time; re-running the cells that
    # define them does not affect an already-compiled version.
    # Assumes get_br returns one value per entry of vP — TODO confirm.
    vB = get_br(vP, CPD, ROT3)
    n = len(vB)

    # Preallocate the trajectories; np.append in the loop copied the whole
    # array on every step (quadratic in the number of steps).
    vS = np.empty(n, np.float64)
    vC = np.empty(n, np.float64)
    vL = np.empty(n, np.float64)

    # NOTE(review): the initial length is the first observed LENGTH value,
    # which was observed at time LENGTH[0, 0] rather than t = 0 — confirm intended.
    s, c, l = s_0, 1.0, LENGTH[0][1]
    for i in range(n):
        b = vB[i]
        # Sequential Euler update: dC sees the new s, dL sees the new s and c
        s = s + dS(c_max, p, q, g_B, g_C, b, s, c) * STEP
        c = c + dC(c_max, p, q, g_B, g_C, s, c) * STEP
        l = l + dL(c_max, p, q, g_B, g_C, s, c, l) * STEP
        vS[i] = s
        vC[i] = c
        vL[i] = l

    return [vB, vS, vC, vL], RSS(LENGTH, vL)

## Model Fitting

In [15]:
# Box bounds [lower, upper] for each free parameter
bounds = np.array([
    [0, 1],    # c_max
    [0, 1],    # p
    [0, 1],    # q
    [0, 1],    # g_B
    [0, 1],    # g_C
])

# Initial guess shared by every model's optimisation (not a fitted result)
params = np.array([0.5, 0.5, 0.5, 0.5, 0.5])

results = []         # (description, trajectories, error) per model; used by the plotting cells
fitted_params = {}   # description -> best-fit parameter vector
for sim in simulations:

    config, description, (dS, dC, dL) = sim

    # Cost = simulated length error; jitted so the jitted RHS functions are
    # called without Python overhead (dS, dC, dL are captured at compile time)
    @njit
    def cost(params):
        _, error = simulate_cell(params, dS, dC, dL)
        return error

    # Find the parameters of best fit
    fit = minimize(
        cost, 
        params, 
        method = "trust-constr",
        bounds = bounds, 
        options = {"maxiter": 50000}
    )
    fitted_params[description] = fit.x

    # Run a simulation with the optimal parameters
    data, error = simulate_cell(fit.x, dS, dC, dL)
    results.append((description, data, error))

    # Run a sensitivity analysis (NOTE(review): `tei` is computed but not reported)
    fos, tei = quasi_monte_carlo(8, 5, bounds, cost)
        
    # Log the simulation.
    # `error` is a mean squared error of lengths in units of 100 um, so
    # multiplying by 10^4 expresses it in um^2.
    print(description)
    print("Success: ", fit.success, fit.message)
    print("Params: ", [round(n, 4) for n in fit.x])
    print("Length Error: ", round(error * 10000, 4))
    print("First Order Sensitivities: ", [round(n, 4) for n in fos])

Base-Model
Success:  True `gtol` termination condition is satisfied.
Params:  [0.5, 0.5, 0.5, 0.1119, 0.5]
Length Error:  54.6697
First Order Sensitivities:  [0.0, 0.0, 0.0, 1.2873, 0.0]
L-Scaled
Success:  True `gtol` termination condition is satisfied.
Params:  [0.5, 0.5, 0.5, 0.4988, 0.5]
Length Error:  11.8079
First Order Sensitivities:  [0.0, 0.0, 0.0, 1.2205, 0.0]


  self.H.update(self.x - self.x_prev, self.g - self.g_prev)


TFBs
Success:  True `gtol` termination condition is satisfied.
Params:  [0.0001, 0.1288, 0.9869, 0.2418, 0.9953]
Length Error:  2.3781
First Order Sensitivities:  [0.0048, 0.1453, 0.0269, 0.6148, 0.5055]


  self.H.update(self.x - self.x_prev, self.g - self.g_prev)


TFBs, L-Scaled
Success:  True `gtol` termination condition is satisfied.
Params:  [0.0215, 0.8214, 0.513, 0.4988, 0.0]
Length Error:  11.808
First Order Sensitivities:  [0.0007, 0.0146, 0.0028, 0.2993, 0.0477]


## Mutant Roots

In [16]:
# Simulate a mutant
def simulate_mutant(params, dS, dC, dL, mutant):

    c_max, p, q, g_B, g_C = params
    s_0, s_in, s_out, s_C, c_0 = 0.0517, 0.0157, 0, 0, 1

    match mutant:
        case "CLASP":
            c_0 = 0
            dC = lambda c_max, p, q, g_B, g_C, s, c : 0
        case "BRIN":
            q = 0
        case "TORIN2":
            s_in = 0.01
            dS = lambda c_max, p, q, g_B, g_C, b, s, c : s_in * b - s_out * s
    
    vB = get_br(vP, CPD, ROT3)
    vS, s = np.array([np.float64(x) for x in range(0)]), s_0
    vC, c = np.array([np.float64(x) for x in range(0)]), c_0
    vL, l = np.array([np.float64(x) for x in range(0)]), LENGTH[0][1]
    
    for b in vB:
        s = s + dS(c_max, p, q, g_B, g_C, b, s, c) * STEP
        c = c + dC(c_max, p, q, g_B, g_C, s, c) * STEP
        l = l + dL(c_max, p, q, g_B, g_C, s, c, l) * STEP
        vS = np.append(vS, s)
        vC = np.append(vC, c)
        vL = np.append(vL, l)

    return [vB, vS, vC, vL]

# Simulate all four root types with the last model variant in `simulations`
# ("TFBs, L-Scaled") and its fitted parameters. `fit` is left over from the
# final iteration of the model-fitting loop, which fitted exactly that variant;
# the optimiser's initial guess `params` must not be used here.
_, mutant_model, (mutant_dS, mutant_dC, mutant_dL) = simulations[-1]
mutant_params = fit.x

mutants = [
    (root_type, simulate_mutant(mutant_params, mutant_dS, mutant_dC, mutant_dL, root_type))
    for root_type in ["Wild", "CLASP", "BRIN", "TORIN2"]
]

## Results

In [None]:
# Large fonts and square figures for every plot below
mpl.rcParams.update({
    'font.size': 22,
    'figure.figsize': (10, 10),
})

In [None]:
def plot_model(data, desc):
    """Plot one fitted model's BES1, CLASP and length trajectories against the data.

    Parameters
    ----------
    data : [vB, vS, vC, vL]
        Trajectories as returned by ``simulate_cell`` (vB is not plotted).
    desc : str
        Model description, used as the title and in the output file name.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, which is also saved to ``img/prototype-4c-{desc}.png``.
    """
    _, vS, vC, vL = data
    fig, (a1, a2, a3) = plt.subplots(nrows = 3, ncols = 1, sharex = True)
    
    a1.set_title(desc)
    a1.plot(vT, vS, color = "orange", label = "Predicted BES1")
    a1.scatter(BES1[:, 0], BES1[:, 1], label = "Observed BES1")
    a1.set_ylabel("BES1 Signalling (au)")
    a1.legend()

    a2.plot(vT, vC, color = "orange", label = "Predicted CLASP")
    a2.set_ylabel("CLASP (au)")
    a2.legend()
    
    # Observed lengths were divided by 100 when loaded, so this axis is in units of 100 um
    a3.plot(vT, vL, color = "orange", label = "Predicted Lengths")
    a3.scatter(LENGTH[:, 0], LENGTH[:, 1], label = "Mean Observed Lengths")
    a3.set_xlabel("Time (h)")
    a3.set_ylabel(r"Length (100 $\mu$m)")
    a3.legend()
    
    fig.savefig(f"img/prototype-4c-{desc}.png", bbox_inches = "tight")
    return fig

# Render and save one three-panel figure per fitted model variant
for model_name, trajectories, _err in results:
    plot_model(desc=model_name, data=trajectories)

In [None]:
# Compare the fitted length trajectory of each model variant on a 2x2 grid
# (lengths are in units of 100 um)
fig, axes = plt.subplots(nrows = 2, ncols = 2, sharex = True)

for (desc, data, errors), ax in zip(results, axes.flat):
    _, _, _, vL = data
    
    ax.set_title(desc)
    ax.plot(vT, vL, color = "orange", label = "Predicted Lengths")
    ax.scatter(LENGTH[:, 0], LENGTH[:, 1], label = "Mean Observed Lengths")
    ax.set_ylabel("Length (100um)")
    ax.set_ylim((0, 1))

for ax in axes[1]:
    ax.set_xlabel("Time (h)")

# Lay out only once titles and labels exist; called earlier it cannot account for them
fig.tight_layout()
fig.savefig("img/prototype-4c-comparison.png", bbox_inches = "tight")

In [None]:
# Plot the simulated length of each root type against the observed (wild-type)
# lengths on a 2x2 grid (lengths are in units of 100 um)
fig, axes = plt.subplots(nrows = 2, ncols = 2, sharex = True)

for (desc, data), ax in zip(mutants, axes.flat):
    _, _, _, vL = data
    
    ax.set_title(desc)
    ax.plot(vT, vL, color = "orange", label = "Predicted Lengths")
    ax.scatter(LENGTH[:, 0], LENGTH[:, 1], label = "Mean Observed Lengths")
    ax.set_ylabel("Length (100um)")
    ax.set_ylim((0, 1))

for ax in axes[1]:
    ax.set_xlabel("Time (h)")

# Lay out only once titles and labels exist; called earlier it cannot account for them
fig.tight_layout()
fig.savefig("img/prototype-4c-mutants.png", bbox_inches = "tight")

## Other Data Visualizations

In [None]:
# Plot the transformed and filtered length data (values were divided by 100
# when loaded, so the y axis is in units of 100 um)
fig, ax = plt.subplots()
ax.scatter(LENGTH[:, 0], LENGTH[:, 1], label = "Mean Lengths")
ax.set_title("Time vs. Mean Cell Length")
ax.set_xlabel("Time (h)")
ax.set_ylabel("Mean Length (100um)")
ax.legend()
# Saved under this notebook's prototype-4c prefix so that the prototype-4b figure is not overwritten
fig.savefig("img/prototype-4c-lengths.png", bbox_inches = "tight")