In [1]:
%matplotlib notebook

import json
import numpy as np
import numpy.random as npr
import scipy.stats as scs
import matplotlib.pyplot as plt
import matplotlib.ticker as mpt

from ast import literal_eval

In [2]:
BURN_IN = 20

base_parameters = {
    'k_a': 0.002,
    'k_d': 0.1,
    'mu': 3.0,
    'kappa': 1.0,
    'gamma': 0.04,
    'diffusion': 0.6,
    'time_step': 0.1,
    'cell_radius': 6.0,
    'nucleus_radius': 2.5,
}

def log_value(p, name=None):
    """Convert log2 into log10 and adjust for base value"""
    return (np.log10(base_parameters[name]) if name else 0) + p * np.log10(2)

def exp_value(p, name=None):
    """Convert log10 into log2 and adjust for base value"""
    return (p - (np.log10(base_parameters[name]) if name else 0))/np.log10(2)


def get_data(filename):
    with open(filename, 'r') as f:
        data = {model: {literal_eval(k): v for k, v in expvalues.items()}
                for model, expvalues in json.load(f).items()}
    return data

In [6]:
files = [
    'kg/logexpvalue_kg_all_hh.json',
    'kg/logexpvalue_kg_all_hmh.json',
    'kg/logexpvalue_kg_all_hml.json',
    'kg/logexpvalue_kg_all_hl.json',
    'kg/logexpvalue_kg_all_ll.json',
    'kg/logexpvalue_kg_P_hh.json',
    'kg/logexpvalue_kg_P_hl.json',
    'kg/logexpvalue_kg_P_lh.json',
    'kg/logexpvalue_kg_P_ll.json',
    'kg/logexpvalue_kg_RNA_hh.json',
    'ss/logexpvalue_ss_all_hh.json',
    'ss/logexpvalue_ss_all_ll.json',
    'ss/logexpvalue_ss_P_hh.json',
    'ss/logexpvalue_ss_P_hl.json',
    'ss/logexpvalue_ss_P_lh.json',
    'ss/logexpvalue_ss_P_ll.json',
    'ss/logexpvalue_ss_RNA_hh.json',
    'ssAdv/logexpvalue_ssAdv_all_hh.json',
    'ssAdv/logexpvalue_ssAdv_all_ll.json',
    'ssAdv/logexpvalue_ssAdv_P_hh.json',
    'ssAdv/logexpvalue_ssAdv_P_hl.json',
    'ssAdv/logexpvalue_ssAdv_P_lh.json',
    'ssAdv/logexpvalue_ssAdv_P_ll.json',
    'ssAdv/logexpvalue_ssAdv_RNA_hh.json',
    #'sskg/logexpvalue_sskg_all_hh.json',
]

In [10]:
params = {'legend.fontsize': 20,
          #'figure.figsize': (6, 6),
         'figure.figsize': (8, 3),
         'figure.titlesize': 28,
         'axes.labelsize': 24,
         'axes.titlesize':24,
         'xtick.labelsize':20,
         'ytick.labelsize':20,
         'text.usetex':True,
         'figure.autolayout':False,
         }

plt.rcParams.update(params)

import scipy.interpolate as sci
import matplotlib.colors as colors


for filename in files:
    data = get_data(filename)

    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16,6))
    #fig.tight_layout(w_pad=0)

    for i, solver in enumerate(['WMM', 'CBM', 'smoldyn']):
        min_x = log_value(min(data[solver].keys(), key=lambda t: t[1])[1])
        max_x = log_value(max(data[solver].keys(), key=lambda t: t[1])[1])
        min_D = log_value(min(data[solver].keys(), key=lambda t: t[0])[0], 'diffusion')
        max_D = log_value(max(data[solver].keys(), key=lambda t: t[0])[0], 'diffusion')

        raw_grid = np.array([[
                log_value(chi), log_value(D, 'diffusion'),
                expvalue
            ]
            for ((D, chi, k_d), expvalue) in data[solver].items()
            if k_d == 0])

        XY = raw_grid[:, :2]
        Z = raw_grid[:, 2]

        xi, yi = np.mgrid[min_x:max_x:100j, min_D:max_D:100j]
        grid_z2 = sci.griddata(XY, Z, (xi, yi) , method='nearest')


        pcm = ax[i].imshow(grid_z2.T, extent=(min_x, max_x, min_D, max_D),
                   vmin=0, vmax=10,
                   #norm=colors.LogNorm(vmin=0.1, vmax=10),
                   origin='lower', cmap='RdBu_r')

        if not i:
            ax[i].set_ylabel('Diffusion ($\mu m^2$ min$^{-1}$)')

        ax[i].set_xlabel('$\chi$')

        ax[i].set_title(f"{solver} ({np.mean(Z):.2f})")

        xticks = list(range(int(np.ceil(min_x)), int(np.floor(max_x))+1))
        ax[i].set_xticks(xticks)
        ax[i].set_xticklabels([f'$10^{{{t}}}$' for t in xticks])
        ax[i].xaxis.set_minor_locator(mpt.FixedLocator(sum([[np.log10(i*c) for i in range(2,10)] for c in [0.001, 0.01, 0.1, 1, 10]],[])))

        yticks = list(range(int(np.ceil(min_D)), int(np.floor(max_D))+1))
        ax[i].set_yticks(yticks)
        ax[i].set_yticklabels([f'$10^{{{t}}}$' for t in yticks])
        ax[i].yaxis.set_minor_locator(mpt.FixedLocator(sum([[np.log10(i*c) for i in range(2,10)] for c in [0.001, 0.01, 0.1, 1, 10]],[])))

    cb = fig.colorbar(pcm, ax=ax[:], location='right', shrink=1)
    cb.set_label('Expected relative log-error')
    fig.suptitle(f"{filename.replace('_', ' ')}", y=1)
    plt.savefig(f"{filename}.pdf")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

  fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16,6))


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# All vs P only vs mRNA only

In [22]:
params = {
    'legend.fontsize': 16,
    'legend.title_fontsize': 18,
    'figure.figsize': (10, 16),
    'figure.titlesize': 28,
    'axes.labelsize': 24,
    'axes.titlesize':24,
    'xtick.labelsize':20,
    'ytick.labelsize':20,
    'text.usetex':True,
    'figure.autolayout':False,
}

import seaborn as sns
import pandas as pd

sns.set_theme(style="whitegrid")
plt.style.use('tableau-colorblind10')

plt.rcParams.update(params)

solvers_groups = [['WMM', 'CBM'], ['smoldyn']]

for solvers in solvers_groups:
    fig, axes = plt.subplots(len(solvers), 1, sharex=False, sharey=True, figsize=(10, 5*len(solvers)))

    if len(solvers) == 1:
        axes = [axes]

    for i, (ax, solver) in enumerate(zip(axes, solvers)):
        traj = 'h'
        tsamp = 'h'

        labels = [
            (species, dist)
            for species in ['all', 'P', 'RNA']
            for dist in ['ss', 'ssAdv', 'kg']
        ]

        expvalues = [
                (v, species, dist)
            for species, dist in labels
            for (_, _, k_d), v in get_data(
                    '{dist}/logexpvalue_{dist}_{species}_{traj}{tsamp}.json'.format(
                        dist=dist, species=species, traj=traj, tsamp=tsamp)
                )[solver].items() if k_d == 0
        ]

        df_expvalues = pd.DataFrame(expvalues, columns=['expvalue', 'species', 'dist'])

        sns.boxenplot(ax=ax, data=df_expvalues, y='expvalue', hue='species', x='dist')

        #ax.set_yscale('log')

        ax.set_ylabel("Exp. Log-Err.")

        if len(solvers) > 1:
            ax.set_title(solver)
        else:
            ax.set_title('Expected Log-Error')
        if i != 0:
            ax.legend().set_visible(False)
        if i != len(solvers) - 1:
            ax.set_xlabel(None)

    fig.tight_layout()

    #plt.savefig(f"RNAvsP{'-'.join(solvers)}_error.pdf", dpi=300)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>