In [None]:
%matplotlib widget

import gzip
from dataclasses import dataclass
from functools import partial
from os import listdir, path
from typing import Callable, Optional, Tuple, Union

import msgpack
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection
from matplotlib.figure import Figure
from scipy import interpolate
from scipy.optimize import curve_fit


In [None]:
# Sticking probabilities to analyse. The order of this list defines the order of the
# sticking-probability axis of the tensors built by `to_data_tensor`.
STICKING_PROBABILITIES = [1.0, 0.8, 0.6, 0.4, 0.2, 0.1]
# Directory holding the `.msgpack.gz` result files.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
DATA_DIR = "/run/media/life/barry/ViennaTools/spherical"


# Function Setup

## Data Loading Routines

In [None]:
# Some handy type definitions
# A loader receives a file name and returns the node data dictionary of that file
DataLoadingFunction = Callable[[str], dict]
# Either a plain float or a numpy array
NumpyOrFloat = Union[float, np.ndarray]


@dataclass(frozen=True, slots=True)
class ThicknessData:
    """Container class for storing extracted layer thickness information.

    Instances are immutable (`frozen=True`). The lists `thickness`,
    `sticking_probabilities` and `max_times` are index-aligned: entry `k` of each
    belongs to the same loaded file."""

    # Node positions as read from the "pos" field of the node data
    pos: np.ndarray
    # For every loaded file: one array per extracted field
    thickness: list[list[np.ndarray]]
    # Sticking probability of every loaded file
    sticking_probabilities: list[float]
    # Number of extracted fields of every loaded file
    max_times: list[int]
    # Number of entries along the first axis shared by all extracted arrays
    num_points: int


def map_configs(fnames: list[str]) -> dict[str, list[str]]:
    """Group file names by geometry.

    Every name is cut at fixed offsets: the first five characters form the sticking
    probability tag, the sixth character (the separator) is dropped and the remainder
    identifies the geometry. The result maps each geometry to the tags found for it,
    keeping the order in which the names were supplied."""
    grouped: dict[str, list[str]] = {}
    for entry in fnames:
        tag = entry[:5]
        geometry = entry[6:]
        grouped.setdefault(geometry, []).append(tag)
    return grouped


def load_data(data_dir: str, filename: str) -> dict:
    """Read the gzip-compressed msgpack file `filename` located in `data_dir` and
    return the value stored under its "nodes" key."""
    archive_path = path.join(data_dir, filename)
    with gzip.open(archive_path) as stream:
        graph = msgpack.unpack(stream, use_list=False, raw=False)
    nodes = graph["nodes"]
    return nodes


def extract_data(
    loader: DataLoadingFunction,
    name: str,
    sticking: list[str],
    prefix: str = "extracted_",
) -> ThicknessData:
    """Uses the provided data loader to get the nodal data of a geometry and
    subsequently extracts all fields starting with prefix.

    Args:
        loader: Called with the file name `"{tag}_{name}"` for every tag in `sticking`;
            must return a mapping of field names to array-like values that also
            contains a "pos" entry.
        name: Geometry part of the file name.
        sticking: Sticking probability tags. The characters after the first one are
            parsed as a number and divided by 1000 (e.g. "s0800" -> 0.8).
        prefix: Only fields whose name starts with this prefix are extracted.

    Returns:
        A `ThicknessData` instance. Positions are taken from the first file that
        provides a non-empty "pos" field.

    Raises:
        ValueError: If `sticking` is empty or no field matches `prefix` in any file.
        AssertionError: If the extracted arrays differ in their first dimension.
    """
    if len(sticking) == 0:
        raise ValueError(f"No sticking probability tags provided for geometry {name!r}")

    data_list: list[list[np.ndarray]] = []
    sticking_probabilities: list[float] = []
    max_times: list[int] = []
    num_points: list[int] = []
    pos: Optional[np.ndarray] = None
    for ps in sticking:
        node_data = loader(f"{ps}_{name}")
        labels = [key for key in node_data if key.startswith(prefix)]
        fields = [np.array(node_data[label]) for label in labels]
        data_list.append(fields)
        sticking_probabilities.append(float(ps[1:]) / 1000)
        max_times.append(len(labels))
        num_points.extend(field.shape[0] for field in fields)
        # Keep looking for positions until a file provides a non-empty array
        if pos is None or len(pos) == 0:
            pos = np.array(node_data["pos"])

    # Without any extracted field the number of points cannot be determined
    if not num_points:
        raise ValueError(
            f"No fields starting with {prefix!r} found for geometry {name!r}"
        )

    # Ensure that the number of nodes is the same for all data
    assert all(
        count == num_points[0] for count in num_points
    ), "Extracted fields differ in their number of points"

    # Construct the data container instance
    return ThicknessData(
        thickness=data_list,
        sticking_probabilities=sticking_probabilities,
        max_times=max_times,
        pos=pos,
        num_points=num_points[0],
    )


## Data Transformation Routines

In [None]:
def to_data_tensor(
    data: ThicknessData,
    sticking_probabilities: list[float],
    max_time: int = 50,
) -> np.ndarray:
    """Converts the `ThicknessData` instance to a numpy tensor with missing data
    indicated by np.nan. First axis `point_index`, second axis `sticking_probability` and
    third axis `time`.

    Args:
        data: Extracted thickness information.
        sticking_probabilities: Sticking probabilities defining the order of the
            second axis. Values without matching data keep an all-NaN row and a
            message is printed.
        max_time: Length of the time axis. Index 0 is set to 0 for every sticking
            probability with data, so at most `max_time - 1` extracted steps are
            stored; additional steps are dropped.

    Returns:
        Array of shape `(data.num_points, len(sticking_probabilities), max_time)`.

    Raises:
        ValueError: If `max_time` is smaller than 1.
    """
    if max_time < 1:
        raise ValueError("max_time must be at least 1")

    mapped = np.full(
        (data.num_points, len(sticking_probabilities), max_time),
        np.nan,
    )
    for i, s in enumerate(sticking_probabilities):
        # First, determine the index (tolerant comparison, the values are floats)
        data_index = next(
            (
                k
                for k, sp in enumerate(data.sticking_probabilities)
                if np.isclose(sp, s)
            ),
            None,
        )
        if data_index is None:
            print(f"Couldn't find data for sticking probability {s}")
            continue

        mapped[:, i, 0] = 0
        # Clip the source as well as the destination, otherwise data with more
        # than `max_time - 1` steps cannot be assigned to the shorter slice
        num_steps = min(data.max_times[data_index], max_time - 1)
        if num_steps > 0:
            mapped[:, i, 1 : num_steps + 1] = np.array(
                data.thickness[data_index][:num_steps]
            ).T
    return mapped


def to_grid_data(
    z: np.ndarray,
    sticking_probabilities: list[float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Builds the coordinate grids needed by matplotlib surface plots.

    The x grid enumerates the second axis of `z`, the y grid repeats the sticking
    probabilities, and the returned z values have every entry that is not `>= 0`
    (negative values as well as NaN) replaced by NaN."""
    column_index = np.arange(z.shape[1])
    grid_x, grid_y = np.meshgrid(column_index, sticking_probabilities)
    # Everything that fails the `>= 0` test ends up as NaN
    grid_z = np.where(z >= 0, z, np.nan)
    return grid_x, grid_y, grid_z


def fill_along_axis(
    x: np.ndarray,
    mask: np.ndarray,
    axis: int = 0,
) -> np.ndarray:
    """Returns a copy of `x` in which entries flagged by the boolean `mask` are
    replaced using linear interpolation (`np.interp`) between unflagged entries.

    Axis 0 of the array is moved to position `axis` and the result is traversed along
    its first dimension. For a 2D array this means `axis=0` treats every row as one
    sequence, while `axis=1` treats every column as one sequence. Sequences that are
    either completely flagged or not flagged at all are left untouched."""
    filled = x.copy()
    sequences = np.moveaxis(filled, 0, axis)
    flags = np.moveaxis(mask, 0, axis)
    for values, missing in zip(sequences, flags):
        present = ~missing
        # Nothing to fill, or nothing to interpolate from
        if not missing.any() or not present.any():
            continue
        # `values` is a view into `filled`, so the assignment updates the copy
        values[missing] = np.interp(
            np.flatnonzero(missing),
            np.flatnonzero(present),
            values[present],
        )
    return filled


def check_monotonic(
    x: np.ndarray,
    axis: int = 0,
    eps: float = 0.1,
    increasing: bool = True,
) -> np.ndarray:
    """Checks if the provided array is monotonically increasing or decreasing along the
    provided axis.

    Args:
        x: Array to check.
        axis: Axis along which consecutive entries are compared.
        eps: Tolerance; a step against the requested direction that is smaller than
            `eps` is still accepted.
        increasing: Check for increasing (True) or decreasing (False) values.

    Returns:
        Boolean array with the shape of `x`. For `increasing=True` a zero is
        prepended, so entry `i` compares `x[i]` with `x[i - 1]` (and `x[0]` with 0).
        For `increasing=False` a zero is appended, so entry `i` compares `x[i + 1]`
        with `x[i]` (and 0 with the last entry). Comparisons involving NaN are False.
    """
    assert axis < x.ndim
    if increasing:
        return np.diff(x, 1, axis=axis, prepend=0) > -eps
    # A decreasing sequence has differences that stay below the tolerance
    return np.diff(x, 1, axis=axis, append=0) < eps


## Data Visualization Routines

In [None]:
def plot_data_2d(
    mapped: np.ndarray,
    sticking_probabilities: list[float],
    point_index: int,
    normalize: bool = False,
    log: bool = False,
) -> Tuple[Figure, Axes]:
    """Creates two 2D plots which show the relation of the deposition radius with
    respect to the time dimension as well as the sticking probability.

    Args:
        mapped: Tensor indexed as (point, sticking probability, time).
        sticking_probabilities: Values belonging to the second axis of `mapped`.
        point_index: Index of the point whose data is plotted.
        normalize: If True, zeros are replaced by NaN and the value at time index
            `j` is divided by `j + 1`.
        log: If True, both y axes use a logarithmic scale.

    Returns:
        The figure and the array of axes created by `plt.subplots` with
        `squeeze=False` (shape `(1, 2)`).
    """
    fig, ax = plt.subplots(nrows=1, ncols=2, squeeze=False, figsize=(10, 4))
    ax_sticking, ax_time = ax[0, 0], ax[0, 1]

    # Slice of the selected point, indexed as (sticking probability, time)
    point_data = mapped[point_index, :, :]
    cmap = plt.get_cmap("viridis", mapped.shape[2])

    # Left panel: one curve per time index, plotted over the sticking probability
    for i, r in enumerate(point_data.T):
        if normalize:
            r = np.where(r != 0, r, np.nan) / (i + 1)
        ax_sticking.plot(sticking_probabilities, r, marker="x", color=cmap(i))

    # Invert the x-axis so that we start with a sticking probability of 1 at the left
    ax_sticking.invert_xaxis()

    # Right panel: one curve per sticking probability, plotted over the time index.
    # The divisor is derived from the data so that it always matches the length of
    # the time axis (same `index + 1` rule as in the left panel).
    # NOTE(review): for tensors from `to_data_tensor`, index 0 is the synthetic zero
    # entry, so `index + 1` may be off by one versus the step count - confirm.
    divisors = 1 + np.arange(point_data.shape[1])
    for i, r in enumerate(point_data):
        if normalize:
            r = np.where(r != 0, r, np.nan) / divisors
        ax_time.plot(r, marker="x", label=sticking_probabilities[i])

    ax_sticking.grid(which="both")
    ax_sticking.set_xlabel("sticking probability")
    ax_sticking.set_ylabel("radius")
    ax_time.set_xlabel("time")
    ax_time.set_ylabel("radius")
    ax_time.grid(which="both")

    if log:
        ax_sticking.set_yscale("log")
        ax_time.set_yscale("log")
    ax_time.legend()
    return fig, ax


def draw_3d_plot(
    ax: Axes,
    data: np.ndarray,
    sticking_probabilities: list[float],
    interp: bool = False,
    fill: bool = False,
    monotonic: bool = False,
) -> None:
    """Creates a 3D surface plot of the data in the provided plot axes object.

    Args:
        ax: 3D axes to draw into; it is cleared first.
        data: 2D array indexed as (sticking probability, time).
        sticking_probabilities: Values belonging to the first axis of `data`.
        interp: If True, the surface is resampled on a 100 x 100 grid using linear
            interpolation before plotting.
        fill: If True, NaN entries are filled by interpolation within every row.
        monotonic: Only evaluated when `fill` is True. Entries violating
            `check_monotonic(..., axis=0)` are re-interpolated within every column.
    """
    ax.clear()
    xx, yy, zz = to_grid_data(data, sticking_probabilities)
    cmap = plt.get_cmap("coolwarm")
    if fill:
        zz = fill_along_axis(zz, np.isnan(zz), 0)
        if monotonic:
            mono_mask = ~check_monotonic(zz, axis=0)
            zz = fill_along_axis(zz, mono_mask, axis=1)

    if interp:
        # The sample grid covers x = 0 .. shape[1] - 1. Going beyond the last sample
        # would leave the convex hull of the data, where griddata yields NaN.
        xnew = np.linspace(0, zz.shape[1] - 1, 100)
        ynew = np.linspace(
            np.min(sticking_probabilities), np.max(sticking_probabilities), 100
        )
        xxnew, yynew = np.meshgrid(xnew, ynew)

        zznew = interpolate.griddata(
            (xx.ravel(), yy.ravel()),
            zz.ravel(),
            (xxnew.ravel(), yynew.ravel()),
            method="linear",
        )
        ax.plot_surface(
            xxnew,
            yynew,
            zznew.reshape(xxnew.shape),
            linewidth=2,
            cmap=cmap,
            rstride=1,
            cstride=1,
            antialiased=True,
        )
    else:
        ax.plot_surface(xx, yy, zz, linewidth=1, cmap=cmap, antialiased=True)

    ax.set_xlabel("time")
    ax.set_ylabel("sticking probability")
    ax.set_zlabel("radius")


def plot_data_3d(
    pos: np.ndarray,
    mapped: np.ndarray,
    sticking_probabilities: list[float],
    interp: bool = False,
    fill: bool = False,
    monotonic: bool = False,
    figsize: Tuple[float, float] = (10, 5),
    selection_callback: Optional[Callable[[int, Axes], None]] = None,
) -> Tuple[Figure, Axes]:
    """Plots the geometry and also creates a 3D surface plot of the data corresponding
    to the selected surface point.

    The left subplot shows the points `pos[:, 0]`, `pos[:, 1]`; picking one of them
    redraws the 3D plot on the right for that point. Point 0 is selected initially.

    Args:
        pos: Point coordinates; the first two columns are plotted.
        mapped: Tensor indexed as (point, sticking probability, time).
        sticking_probabilities: Values belonging to the second axis of `mapped`.
        interp, fill, monotonic: Forwarded to `draw_3d_plot`.
        figsize: Size of the created figure.
        selection_callback: Optional function called with the selected point index
            and the 3D axes after every (re)draw of the surface plot.

    Returns:
        The figure and the 3D axes.
    """
    point_index = 0

    fig = plt.figure(figsize=figsize)
    ax_selector = fig.add_subplot(121)
    ax_3d = fig.add_subplot(122, projection="3d")
    fig.subplots_adjust(hspace=0.4)

    def show_point(index: int) -> None:
        """Draws the surface of one point and notifies the selection callback."""
        draw_3d_plot(
            ax_3d,
            mapped[index, :, :],
            sticking_probabilities,
            interp,
            fill,
            monotonic,
        )
        if selection_callback:
            selection_callback(index, ax_3d)

    show_point(point_index)

    ax_selector.set_title(f"Geometry\n(selected: {point_index})")
    ax_selector.scatter(
        pos[:, 0],
        pos[:, 1],
        color="black",
        linewidths=0,
        marker=".",
        picker=True,
        pickradius=0.2,
    )

    # Marker highlighting the currently selected point
    scat = ax_selector.scatter(
        pos[point_index, 0], pos[point_index, 1], color="red", marker="o"
    )

    def onpick(event):
        if not isinstance(event.artist, PathCollection):
            return
        # Several points may lie within the pick radius; use the first one
        picked = event.ind[0]
        scat.set_offsets((pos[picked, 0], pos[picked, 1]))
        ax_selector.set_title(f"Geometry\n(selected: {picked})")

        # Now redraw 3d plot
        show_point(picked)
        fig.canvas.draw_idle()

    ax_selector.axis("scaled")
    fig.canvas.mpl_connect("pick_event", onpick)
    return fig, ax_3d


# Data Exploration

## Loading and Transforming

In [None]:
# Get all available configurations. `listdir` returns entries in arbitrary order, so
# the listing is sorted to make `geometry_index` select the same geometry on every run
# and on every machine.
fnames = sorted(listdir(DATA_DIR))
configs = map_configs(fnames)

config_list = list(configs.items())

# Select a particular geometry (index into the configurations of the sorted listing)
geometry_index = 57
name, sticking = config_list[geometry_index]

print(name)
# Characters 1-3 of the first "_"-separated part of the name hold the trench diameter
trench_diameter = float(name.split("_")[0][1:4])

# Load the data of this geometry
data_loader = partial(load_data, DATA_DIR)
extracted = extract_data(data_loader, name, sticking)
data = to_data_tensor(extracted, STICKING_PROBABILITIES)
# The view factor fields are only read from the file with sticking tag "s1000"; for
# the remaining sticking probabilities `to_data_tensor` reports missing data
viewfactor_data = to_data_tensor(
    extract_data(data_loader, name, ["s1000"], prefix="viewfactor_"),
    STICKING_PROBABILITIES,
)


In [None]:
# Difference between every sticking probability slice and the view factor data stored
# at index 0 of the sticking probability axis (broadcast over that axis)
deviation_data = data - viewfactor_data[:, :1, :]


### Extend the dataset with the expected conformal thickness data

In [None]:
time = np.arange(data.shape[2])
# Second position coordinate with non-positive values replaced by zero
y = np.where(extracted.pos[:, 1] > 0, extracted.pos[:, 1], 0)

# Time index minus the clamped coordinate, limited from above by half the diameter
half_diameter = trench_diameter / 2
conformal = time[np.newaxis, :] - y[:, np.newaxis]
conformal = np.where(conformal <= half_diameter, conformal, half_diameter)

# Append the result as an additional slice along the sticking probability axis
expanded_data = np.zeros((data.shape[0], data.shape[1] + 1, data.shape[2]))
expanded_data[:, :-1, :] = data
expanded_data[:, -1, :] = conformal

expanded_sticking_probabilities = STICKING_PROBABILITIES + [0.0]


## 2D Plots

In [None]:
point_index = 300

# One figure per dataset, each paired with the sticking probabilities of its
# second axis
plot_inputs = [
    (data, STICKING_PROBABILITIES),
    (viewfactor_data, STICKING_PROBABILITIES),
    (deviation_data, STICKING_PROBABILITIES),
    (expanded_data, expanded_sticking_probabilities),
]
for tensor, probabilities in plot_inputs:
    fig, _ = plot_data_2d(
        tensor,
        probabilities,
        point_index=point_index,
        normalize=False,
        log=False,
    )


## 3D Plots

In [None]:
# Interactive view of the loaded data: NaN gaps are filled and entries flagged as
# non-monotonic are re-interpolated
fig, ax = plot_data_3d(
    pos=extracted.pos,
    mapped=data,
    sticking_probabilities=STICKING_PROBABILITIES,
    interp=False,
    fill=True,
    monotonic=True,
)


### Extended dataset (+conformal deposition thickness)

In [None]:
# Same interactive view for the dataset extended by the additional slice at sticking
# probability 0.0
fig, ax = plot_data_3d(
    pos=extracted.pos,
    mapped=expanded_data,
    sticking_probabilities=expanded_sticking_probabilities,
    interp=False,
    fill=True,
    monotonic=True,
)


## Curve Fitting

In [None]:
def optimize(
    data: np.ndarray,
    sticking_probabilities: list[float],
) -> None:
    """Opens the interactive 3D view and overlays a fitted model surface for the
    selected point.

    For every selected point the model
    `a * tanh(b * x * (c + d * y)) * (e + f * y**2 + g * y**4)` with `x` the time index
    and `y = 1 - sticking probability` is fitted to the filled data using
    `scipy.optimize.curve_fit`. The fitted parameters are printed and the model
    surface is added to the 3D axes. If the fit is not possible, a message is printed
    and only the data surface is shown.

    Args:
        data: Tensor indexed as (point, sticking probability, time).
        sticking_probabilities: Values belonging to the second axis of `data`.
    """

    def fun(
        X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        a: float,
        b: float,
        c: float,
        d: float,
        e: float,
        f: float,
        g: float,
    ) -> np.ndarray:
        """Model function; `X[0]` is the time and `X[1]` the sticking probability."""
        x = X[0]
        y = 1 - X[1]

        return a * np.tanh(b * x * (c + d * y)) * (e + f * y**2 + g * y**4)

    # Number of free parameters of `fun` (a .. g)
    num_params = 7

    def callback(point_index: int, ax: Axes) -> None:
        x, y, z = to_grid_data(data[point_index, :, :], sticking_probabilities)
        z = fill_along_axis(z, np.isnan(z))
        mono_mask = ~check_monotonic(z, axis=0)
        z = fill_along_axis(z, mono_mask, axis=1)

        # Rows without any valid sample stay NaN after filling; curve_fit rejects
        # non-finite input, so only the finite samples are passed on
        x_flat, y_flat, z_flat = x.ravel(), y.ravel(), z.ravel()
        finite = np.isfinite(z_flat)
        if np.count_nonzero(finite) < num_params:
            print(f"Not enough valid samples to fit point {point_index}")
            return

        try:
            popt, _ = curve_fit(
                fun, (x_flat[finite], y_flat[finite]), z_flat[finite]
            )
        except RuntimeError as err:
            # Raised by curve_fit when the least-squares minimization fails
            print(f"Curve fit failed for point {point_index}: {err}")
            return
        print(popt)

        # Evaluate the fitted model on a regular grid spanning the sampled domain
        xnew = np.linspace(0, z.shape[1] - 1, 100)
        ynew = np.linspace(
            np.min(sticking_probabilities), np.max(sticking_probabilities), 100
        )
        xxnew, yynew = np.meshgrid(xnew, ynew)
        zznew = fun((xxnew.ravel(), yynew.ravel()), *popt)
        ax.plot_surface(
            xxnew,
            yynew,
            zznew.reshape(xxnew.shape),
            linewidth=1,
            antialiased=True,
        )

    plot_data_3d(
        extracted.pos,
        data,
        sticking_probabilities,
        interp=False,
        fill=True,
        monotonic=True,
        selection_callback=callback,
    )


optimize(data, STICKING_PROBABILITIES)
