# Multiple model parameter fitting

## Starting point
- $N$ segmented geometry images $G_i$, each with a target concentration image $T_i$ at some time $t_i$
- a model using one of these geometry images with reaction parameters $k^{(j)}$ to be fitted

## Strategy
- make $N$ identical copies $M_i$ of the initial sme model, setting the parameters to the same input values in every copy
- update the geometry image for model $M_i$ to $G_i$
- simulate each model $M_i$ for time $t_i$
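- calculate the root-mean-square (RMS) difference between the simulated concentration of model $M_i$ and the target concentration image $T_i$
- sum these differences over all models to obtain the total cost for this set of parameters, $C(k) = \sum_{i=1}^{N} \mathrm{RMS}\left(S_i(k; t_i) - T_i\right)$, where $S_i(k; t_i)$ is the concentration simulated by model $M_i$ at time $t_i$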

This cost function is then minimised by an optimisation algorithm to fit the parameters $k^{(j)}$.

## Navigation

- Press `Space` to show the next page
- Press `Shift+Space` to show the previous page
- Press `Escape` to zoom out

## Utility functions

In [None]:
from itertools import cycle

import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D

from IPython.display import Image, display, HTML, Video

import numpy as np
import imageio.v3 as iio
import tifffile
import skimage
from scipy import ndimage as ndi
import sme
import pyvista as pv

# Global plotting defaults for the notebook:
# show the axes widget in every pyvista plot and request interactive pyvista plots
# (interactivity only takes effect with an interactive jupyter backend).
pv.global_theme.axes.show = True
pv.global_theme.interactive = True
# Default size of matplotlib figures (width, height) in inches.
plt.rcParams["figure.figsize"] = (8, 8)

In [None]:
# Render pyvista plots as static images in the notebook output (works without a running kernel).
pv.set_jupyter_backend("static")

# Uncomment to switch to the trame backend when running locally:
# pv.set_jupyter_backend("trame") # for interactive plots

In [None]:
def make_discrete_colormap(
    cmap: str = "tab10", values: np.ndarray | None = None
) -> list[tuple[float, float, float, float]]:
    """Create a list of colors with exactly one entry per element of `values`.

    The first entry is always opaque black (typically used for the background
    label 0). The remaining entries cycle through the discrete colors of the
    matplotlib colormap `cmap`, repeating them when `values` has more elements
    than the colormap has colors.

    Args:
        cmap (str, optional): name of a matplotlib colormap that exposes a
            discrete ``colors`` list (e.g. "tab10", "tab20", "Set1").
            Defaults to "tab10".
        values (np.ndarray, optional): values to be mapped to colors; only its
            length is used. ``None`` (the default) is treated as empty.

    Returns:
        list[tuple[float, float, float, float]]: ``len(values)`` colors in RGBA
        format (an empty list when `values` is empty).
    """
    n_colors = 0 if values is None else len(values)
    if n_colors == 0:
        return []
    palette = plt.get_cmap(cmap).colors
    background = (0, 0, 0, 1)
    # Index modulo the palette size so that colors repeat for long `values`.
    return [background] + [
        mcolors.to_rgba(palette[i % len(palette)]) for i in range(n_colors - 1)
    ]

In [None]:
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    """Convert an array of colors (RGB, RGBA, ...) to integer labels.

    This function is necessary because pyvista does not support RGB values
    directly as mesh data. Each distinct color is replaced by its index in the
    lexicographically sorted list of unique colors.

    Args:
        img (np.ndarray): colors stored along the last axis, of shape (..., C),
            e.g. (n, m, 3) for a 2D RGB image or (z, y, x, 3) for a 3D one.

    Returns:
        np.ndarray: integer labels of shape ``img.shape[:-1]``.
    """
    # Plain reshape: it copies only when needed (e.g. for a swapaxes view).
    # The `copy=` keyword of ndarray.reshape requires NumPy >= 2.1.
    pixels = img.reshape(-1, img.shape[-1])
    _, labels = np.unique(pixels, axis=0, return_inverse=True)
    # The inverse already is the label of each pixel. Reshaping also normalizes
    # its shape, which differs between NumPy releases.
    return labels.reshape(img.shape[:-1])

In [None]:
def plot3D(
    data: np.ndarray,
    title: str | list[str],
    threshold_value: int | list[int] | tuple[int, int] = (1, 0),
    cmap: str | list[str] | pv.LookupTable = "tab10",
    with_swap: bool = True,
    with_aux: bool = True,
    with_cbar: bool = False,
    mesh_kwargs: dict | None = None,
) -> pv.Plotter:
    """Plot a 3D image, optionally next to a second, differently thresholded view of the same mesh.

    Args:
        data (np.ndarray): Data to plot: a 3D array of scalars, or a 4D array
            whose last axis holds RGB values. RGB data is converted to integer
            labels with `rgb_to_scalar`.
        title (str | list[str]): Title of each subplot. Must be a list of two
            strings when `with_aux` is True.
        threshold_value (int | list[int], optional): Threshold value of each
            subplot. Values below the threshold are not shown. A single number
            is only accepted when `with_aux` is False. Defaults to (1, 0).
        cmap (str | list[str] | pv.LookupTable, optional): Name of a matplotlib
            colormap, a list of colors in RGBA or hex format, or a pyvista
            LookupTable. Defaults to "tab10".
        with_swap (bool, optional): Whether axes 0 and 2 should be swapped, so
            that the last array axis is shown as pyvista's x axis. Defaults to True.
        with_aux (bool, optional): Enable the second subplot. Defaults to True.
        with_cbar (bool, optional): Show a colorbar. Defaults to False.
        mesh_kwargs (dict, optional): Extra keyword arguments forwarded to the
            `pv.ImageData` constructor (e.g. `spacing`, `origin`). Defaults to None.

    Raises:
        ValueError: When the input data is not 3D or 4D (for RGB values).
        ValueError: When the title is not a list of two strings when with_aux is True.
        ValueError: When threshold_value is not a list of two values when with_aux is True.

    Returns:
        pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot.
    """
    if data.ndim not in (3, 4):
        raise ValueError("Image must be 3D or 4D (for rgb values)")

    # Validate before creating the plotter. The type checks come first, so a
    # string title or a scalar threshold gets the documented ValueError instead
    # of being measured with len(), which raises TypeError for numbers.
    if with_aux and (isinstance(title, str) or len(title) != 2):
        raise ValueError("Two title must be provided for the two subplots")

    scalar_threshold = np.ndim(threshold_value) == 0
    if with_aux and (scalar_threshold or len(threshold_value) != 2):
        raise ValueError("Two threshold values must be provided for the two subplots")

    n_subplots = 2 if with_aux else 1
    thresholds = [threshold_value] * 2 if scalar_threshold else list(threshold_value)
    titles = [title] * n_subplots if isinstance(title, str) else list(title)

    # After swapping axes 0 and 2, the last axis of `data` becomes pyvista's x axis.
    _data = np.swapaxes(data, 0, 2) if with_swap else data
    if _data.ndim == 4:
        _data = rgb_to_scalar(_data)

    img_data = pv.ImageData(dimensions=_data.shape, **(mesh_kwargs or {}))
    # pyvista/VTK order ImageData points with the first axis varying fastest,
    # so flatten in Fortran order for point (i, j, k) to receive _data[i, j, k].
    # C order scrambles the image unless shape[0] == shape[2], and even then
    # transposes it.
    img_data.point_data["Data"] = _data.flatten(order="F")
    img_data = img_data.points_to_cells(scalars="Data")

    plotter = pv.Plotter(shape=(1, n_subplots), border=False, notebook=True)
    for col in range(n_subplots):
        plotter.subplot(0, col)
        plotter.add_text(titles[col])
        plotter.add_mesh(
            img_data.threshold(thresholds[col]),
            show_edges=True,
            show_scalar_bar=with_cbar,
            cmap=cmap,
        )
    if with_aux:
        # Keep the camera of both subplots in sync.
        plotter.link_views()
    return plotter

In [None]:
def plot_indexed(img_indexed, cmap: str = "tab10"):
    """Plot a segmented 3D image with one discrete color per integer label.

    Two linked subplots are drawn with `plot3D`:
    - the first is thresholded at 1, so label 0 is hidden;
    - the second is thresholded at 0.

    Args:
        img_indexed (np.ndarray): 3D array of integer labels (assumed
            non-negative, e.g. the output of `geometry_image`).
        cmap (str, optional): matplotlib colormap name passed to
            `make_discrete_colormap`. Defaults to "tab10".
    """
    # Size the lookup table by the largest label rather than by the number of
    # distinct labels. The table then spans [0, n_labels) with unit-width bins,
    # so each integer label k keeps its own color even when some labels in
    # 0..max are missing from the image.
    n_labels = int(np.max(img_indexed)) + 1
    colors = np.array(make_discrete_colormap(cmap=cmap, values=np.arange(n_labels)))
    lt = pv.LookupTable(
        values=colors * 255,  # RGBA in [0, 1] -> [0, 255]
        scalar_range=(0, n_labels),
        n_values=n_labels,
    )
    plotter = plot3D(
        img_indexed,
        ["Segmented image (voxel)", "Segmented image (surface)"],
        threshold_value=[1, 0],
        cmap=lt,
    )
    plotter.show()

To interact with the plots when running the notebook locally, select the 'trame' Jupyter backend instead of 'static' (see the `pv.set_jupyter_backend` cell near the top of the notebook). See the PyVista documentation for other backends.

## Geometry images

In [None]:
def sphere_mask(grid_shape, center, radius, deformation):
    """Return a boolean mask of an axis-aligned ellipsoid on a 3D grid.

    A voxel (z, y, x) is inside the ellipsoid when
    ``dx*(x-x0)**2 + dy*(y-y0)**2 + dz*(z-z0)**2 <= radius**2``.
    A larger weight along an axis therefore shrinks the ellipsoid along that axis.

    Args:
        grid_shape: (Z, Y, X) shape of the grid.
        center: (z0, y0, x0) center of the ellipsoid.
        radius: radius for unit weights.
        deformation: (dz, dy, dx) weights applied to the squared distances.

    Returns:
        np.ndarray: boolean array of shape ``grid_shape``.
    """
    n_z, n_y, n_x = grid_shape
    cz, cy, cx = center
    wz, wy, wx = deformation
    zz, yy, xx = np.ogrid[:n_z, :n_y, :n_x]
    # Accumulate in the same x, y, z order to keep floating-point results identical.
    dist_sq = wx * (xx - cx) ** 2
    dist_sq = dist_sq + wy * (yy - cy) ** 2
    dist_sq = dist_sq + wz * (zz - cz) ** 2
    return dist_sq <= radius**2

In [None]:
def geometry_image(n_pixels, seed=None):
    """Generate a random segmented 3D image of one cell containing a nucleus.

    The image is built in two steps:
    1. A randomly placed, sized and deformed ellipsoid (the cell, label 2) is
       drawn into a cubic volume.
    2. A smaller ellipsoid with the same center and deformation (the nucleus,
       label 1) is drawn over it.

    Background voxels are 0.

    Args:
        n_pixels (int): edge length of the cubic volume. Must be at least 12;
            smaller sizes leave an empty range for the nuclear radius.
        seed (int, optional): seed for a private ``np.random.RandomState``,
            giving reproducible images. When None (the default), the global
            ``np.random`` state is used.

    Raises:
        ValueError: if n_pixels < 12.

    Returns:
        np.ndarray: uint16 array of shape (n_pixels, n_pixels, n_pixels).
    """
    if n_pixels < 12:
        raise ValueError("n_pixels must be at least 12")
    rs = np.random if seed is None else np.random.RandomState(seed)
    max_radius = n_pixels / 3
    max_deform = 1.2
    voxels = np.zeros((n_pixels,) * 3, dtype=np.uint16)
    center = rs.randint(2, n_pixels - 2, 3)
    # Explicit int bounds: randint truncates float bounds anyway.
    nuclear_radius = rs.randint(1, int(max_radius / 2))
    # Round the lower bound up so that cell_radius >= 1.5 * nuclear_radius
    # strictly exceeds nuclear_radius. Otherwise (e.g. nuclear_radius == 1) the
    # nucleus could overwrite the whole cell and label 2 would vanish.
    cell_radius = rs.randint(int(np.ceil(1.5 * nuclear_radius)), int(max_radius))
    deformation = rs.uniform(1 / max_deform, max_deform, 3)
    voxels[sphere_mask(voxels.shape, center, cell_radius, deformation)] = 2
    voxels[sphere_mask(voxels.shape, center, nuclear_radius, deformation)] = 1
    return voxels

In [None]:
# Number of voxels labelled 2 (cell outside the nucleus) in a fresh random geometry.
(geometry_image(30) == 2).sum()

In [None]:
img_indexed = geometry_image(30)
values = np.unique(img_indexed)
# One RGBA color per distinct label (black first), scaled from [0, 1] to [0, 255].
# The table spans [0, len(values)) with len(values) entries. Assuming labels
# are 0..len(values)-1, each label therefore gets its own color.
lt = pv.LookupTable(
    values=255 * np.array(make_discrete_colormap(values=values, cmap="tab10")),
    n_values=len(values),
    scalar_range=(0, len(values)),
)

In [None]:
# Display the lookup table in the notebook output.
lt

In [None]:
# Left subplot: thresholded at 1, so label 0 is not shown; right subplot: thresholded at 0.
plotter = plot3D(
    data=img_indexed,
    title=["Segmented image (voxel)", "Segmented image (surface)"],
    threshold_value=[1, 0],
    cmap=lt,
)

In [None]:
# Render the two linked subplots created by plot3D.
plotter.show()

In [None]:
# Same kind of plot via the plot_indexed helper, for a newly generated random geometry.
plot_indexed(geometry_image(30))

## SME model

In [None]:
import sme

In [None]:
# do example simulation, plot geometry image

## Geometry images

In [None]:
# load geometry images
# make copies of model
# assign geometry to each
# do example simulation for each

## Target concentration images

In [None]:
# plot target conc, simulated conc for each model
# calculate cost function for each, total cost

## Particle swarm parameter optimization

In [None]:
# run a particle swarm optimisation with the total cost function as the objective
# to fit the reaction parameters k^(j)

## Final results

In [None]:
# plot final simulation concs vs target concs for all models