# Timepix Images (Fe Sample and Open Beam)

In [None]:
%matplotlib widget

In [None]:
from ess.imaging.data import get_fe_timepix_image_path, get_ob_timepix_image_path
from ess.imaging.io import load_tiff

# Paths to the example TIFF images (Fe sample and open beam) provided by
# ``ess.imaging.data``.
sample_image_path, openbeam_image_path = (
    get_fe_timepix_image_path(),
    get_ob_timepix_image_path(),
)

## Define Workflow

In [None]:
import pathlib

import scipp as sc
import sciline as sl

from typing import NewType, TypeVar

# Marker types distinguishing the two measurements; they are used as type
# parameters of the generic domain types below (e.g. ``FilePath[Sample]``).
Sample = NewType('Sample', int)
OpenBeam = NewType('OpenBeam', int)
# Constrained type variable so a generic provider serves either run type.
RunType = TypeVar('RunType', Sample, OpenBeam)


# Path of the TIFF file for a given run type (set as a pipeline parameter).
class FilePath(sl.Scope[RunType, pathlib.Path], pathlib.Path): ...


# Image loaded from the TIFF file of a given run type (see ``load_image``).
class TiffImage(sl.Scope[RunType, sc.DataArray], sc.DataArray): ...


# Sample image divided by the open-beam image (see ``normalize_image``).
NormalizedImage = NewType('NormalizedImage', sc.DataArray)


def load_image(path: FilePath[RunType]) -> TiffImage[RunType]:
    """
    Load a TIFF stack and attach hard-coded bin-edge coordinates.

    The raw dims ``dim_0``, ``dim_1`` and ``dim_2`` are renamed to ``x``,
    ``y`` and ``t``. Every dim then receives an evenly spaced bin-edge
    coordinate (``size + 1`` values): ``t`` spans 0-40 ns, ``x`` and ``y``
    span 0-10 cm. These ranges are placeholders and are not read from the file.
    """
    img = load_tiff(path).rename_dims({'dim_0': 'x', 'dim_1': 'y', 'dim_2': 't'})
    # (dim, start, stop, unit) of each placeholder bin-edge coordinate.
    for dim, start, stop, unit in (
        ('t', 0.0, 40.0, 'ns'),
        ('x', 0, 10.0, 'cm'),
        ('y', 0, 10.0, 'cm'),
    ):
        img.coords[dim] = sc.linspace(
            dim=dim, start=start, stop=stop, num=img.sizes[dim] + 1, unit=unit
        )
    return img


def normalize_image(
    sample_img: TiffImage[Sample], openbeam_img: TiffImage[OpenBeam]
) -> NormalizedImage:
    """
    Normalize the sample image by the open beam image, per pixel and per
    time bin (element-wise division).

    Bins where the open beam has ``0`` counts give non-finite ratios:
    ``nan`` when the sample count is also ``0``, and ``+inf`` when the
    sample count is positive. All non-finite ratios are flagged by the
    ``'no-ob'`` (no open beam) mask. ``+inf`` is additionally replaced by
    ``0`` so it cannot blow up plots; ``nan`` values stay in place, masked.
    """
    normalized = sample_img / openbeam_img
    # Mask nan *and* inf: both come from bins without open-beam counts,
    # so both belong under the 'no-ob' mask.
    normalized.masks['no-ob'] = ~sc.isfinite(normalized.data)
    normalized.data = sc.nan_to_num(
        normalized.data, posinf=sc.scalar(0.0, unit='dimensionless')
    )
    return NormalizedImage(normalized)


# Pipeline parameters: one input file per run type. ``normalize_image``
# consumes both run types to produce ``NormalizedImage``.
pipeline_params = {
    FilePath[Sample]: sample_image_path,
    FilePath[OpenBeam]: openbeam_image_path,
}
wf = sl.Pipeline(
    providers=[load_image, normalize_image],
    params=pipeline_params,
)
# Show the task graph left-to-right.
wf.visualize(NormalizedImage, graph_attr={'rankdir': 'LR'})

## Compute Images

In [None]:
# Compute both raw images and the normalized image in one pipeline call,
# then unpack them in the same order as requested.
targets = (TiffImage[Sample], TiffImage[OpenBeam], NormalizedImage)
images = wf.compute(targets)
sample_img, openbeam_img, normalized_img = (images[key] for key in targets)

## Visualize Images
### Raw Images (Sample and Open Beam)

In [None]:
# 2D images: counts accumulated over the time dimension.
sample_2d = sample_img.sum('t').plot(
    title='Fe Image (Accumulated)', grid=True, cbar=False
)
openbeam_2d = openbeam_img.sum('t').plot(
    title='Open Beam Image (Accumulated)', grid=True
)

# 1D spectra: counts accumulated over both spatial dimensions.
sample_1d = sample_img.sum('x').sum('y').plot(title='Fe Spectrum', grid=True)
openbeam_1d = openbeam_img.sum('x').sum('y').plot(
    title='Open Beam Spectrum', grid=True
)

# Compose the four figures into a single layout.
(sample_2d + openbeam_2d) / (sample_1d + openbeam_1d)

### Normalized Images

In [None]:
# NaN/inf handling (the 'no-ob' mask and inf -> 0 replacement) is already
# done by ``normalize_image`` inside the workflow, so the computed result is
# used as-is rather than being re-masked and mutated here.

# 2D: accumulate over time. Scipp drops masks along the reduced dim, so
# re-attach a pixel mask for pixels that are masked in every time bin.
normed_2d_da = normalized_img.sum('t')
normed_2d_da.masks['no-ob'] = normalized_img.masks['no-ob'].all('t')
normed_2d = normed_2d_da.plot(
    title='Normalized Image (Accumulated)', grid=True, cbar=False
)
# 1D: accumulate over both spatial dimensions.
normed_1d_da = normalized_img.sum('x').sum('y')
normed_1d = normed_1d_da.plot(title='Normalized Spectrum', grid=True)
normed_2d + normed_1d

## ROI Selection Tool

In [None]:
import plopp as pp
from plopp.widgets.tools import ToggleTool
from mpltoolbox import Rectangles


class RectInfo(sc.DataGroup):
    """
    Extent of one selected rectangle, stored as a scipp ``DataGroup`` with
    the entries ``'min_x'``, ``'max_x'``, ``'min_y'`` and ``'max_y'``.
    The properties below give attribute access to those entries.
    """

    def __init__(
        self,
        min_x: sc.Variable,
        max_x: sc.Variable,
        min_y: sc.Variable,
        max_y: sc.Variable,
    ):
        bounds = dict(min_x=min_x, max_x=max_x, min_y=min_y, max_y=max_y)
        super().__init__(bounds)

    @property
    def min_x(self) -> sc.Variable:
        """Lower bound along ``x``."""
        return self['min_x']

    @property
    def max_x(self) -> sc.Variable:
        """Upper bound along ``x``."""
        return self['max_x']

    @property
    def min_y(self) -> sc.Variable:
        """Lower bound along ``y``."""
        return self['min_y']

    @property
    def max_y(self) -> sc.Variable:
        """Upper bound along ``y``."""
        return self['max_y']


def _get_rect_info(artist, figure) -> RectInfo:
    """
    Convert a drawn rectangle artist to a ``RectInfo`` object.

    This assumes the image is displayed with the data dimension ``'x'`` on
    the vertical (canvas ``'y'``) axis and ``'y'`` on the horizontal
    (canvas ``'x'``) axis. The rectangle's vertical extent (``xy[1]``,
    ``height``) is therefore the ``x`` range, and its horizontal extent
    (``xy[0]``, ``width``) is the ``y`` range. Each range takes the unit of
    the canvas axis it was read from. Min/max are taken per range, so
    rectangles drawn in any direction (negative width/height) work.
    """
    # Data 'x' lies along the vertical canvas axis -> canvas 'y' unit.
    x_range = sc.array(
        dims=['x'],
        values=[artist.xy[1], artist.xy[1] + artist.height],
        unit=figure.canvas.units['y'],
    )
    # Data 'y' lies along the horizontal canvas axis -> canvas 'x' unit.
    y_range = sc.array(
        dims=['y'],
        values=[artist.xy[0], artist.xy[0] + artist.width],
        unit=figure.canvas.units['x'],
    )
    return RectInfo(x_range.min(), x_range.max(), y_range.min(), y_range.max())


class MergedRectanglesTool(ToggleTool):
    """
    Toolbar toggle for drawing rectangles on a 2D figure.

    The list of all current rectangle extents (``RectInfo`` objects) is
    exposed through a plopp node that is appended as an extra parent of
    ``destination``. That node is refreshed whenever a rectangle is created,
    edited, moved or removed.

    Parameters
    ----------
    figure:
        plopp figure to draw on (its ``ax`` hosts the rectangles and its
        ``canvas.units`` give the coordinate units).
    destination:
        Node receiving the rectangle list as an additional parent.
    value:
        Initial toggle state of the button.
    **kwargs:
        Forwarded to ``ToggleTool``, except ``autostart`` which only
        configures the ``Rectangles`` drawing tool.
    """

    def __init__(self, figure, destination, value: bool = False, **kwargs):
        # ``autostart`` belongs to the drawing tool, not to the button, so
        # remove it before forwarding the remaining kwargs.
        autostart = kwargs.pop('autostart', False)
        super().__init__(callback=self.start_stop, value=value, **kwargs)

        self._figure = figure
        self._tool = Rectangles(ax=self._figure.ax, autostart=autostart)
        self.rectangles = []
        # The lambda reads ``self.rectangles`` when called; empty initially.
        self._draw_node = pp.Node(lambda: self.rectangles)
        self._destination = destination
        self._destination.add_parents(self._draw_node)
        # Refresh the rectangle list after every kind of edit.
        for register in (
            self._tool.on_create,
            self._tool.on_vertex_release,
            self._tool.on_drag_release,
            self._tool.on_remove,
        ):
            register(self.update_node)

    def update_node(self, _):
        """Recompute all rectangle extents and notify downstream nodes."""
        self.rectangles = [
            _get_rect_info(artist=artist, figure=self._figure)
            for artist in self._tool.children
        ]
        # Re-bind the node's producer to return the fresh list.
        self._draw_node.func = lambda: self.rectangles
        self._draw_node.notify_children("")  # Empty message.

    def start_stop(self):
        """
        Start the drawing tool when the toggle is on, stop it when off.
        """
        if self.value:
            self._tool.start()
        else:
            self._tool.stop()


In [None]:
# Prepare the merge method
from functools import reduce


def _show_histogram(da: sc.DataArray, rect_infos: list[RectInfo]) -> sc.DataArray:
    """
    Sum ``da`` over ``x`` and ``y``, restricted to the union of the ROIs.

    Parameters
    ----------
    da:
        Data with ``x`` and ``y`` dims and coords; other dims are kept.
    rect_infos:
        Selected rectangles. If empty, the whole image is summed.

    Returns
    -------
    :
        ``da`` summed over ``x`` and ``y``. When rectangles are given, pixels
        outside every rectangle are excluded via a ``'roi'`` mask on a
        shallow copy (``da`` itself is not modified).
    """
    if not rect_infos:
        return da.sum('x').sum('y')

    def _pixel_positions(dim: str) -> sc.Variable:
        # For bin-edge coords, use each pixel's lower edge as its position.
        coord = da.coords[dim]
        return coord[dim, :-1] if da.coords.is_edges(dim) else coord

    # Coordinates come from ``da`` (not a global), so any input works.
    x = _pixel_positions('x')
    y = _pixel_positions('y')
    # One "outside this rectangle" mask per ROI. A pixel is masked only if
    # it is outside all of them, i.e. the selection is their union.
    outside = [
        (x > rect.max_x) | (x < rect.min_x)
        | (y > rect.max_y) | (y < rect.min_y)
        for rect in rect_infos
    ]
    selected = da.copy(deep=False)
    selected.masks['roi'] = reduce(lambda a, b: a & b, outside)
    return selected.sum('x').sum('y')


# ``merge_node`` computes the ROI spectrum. Its first input is a constant
# node holding the normalized data; the rectangle list is attached as a
# second parent later by ``MergedRectanglesTool``.
merge_node = pp.Node(_show_histogram)
original_data_node = pp.Node(normalized_img)
merge_node.add_parents(original_data_node)

In [None]:
# Time-integrated image on which ROI rectangles are drawn.
data_node = pp.Node(normalized_img.sum('t'))
f2d = pp.imagefigure(
    data_node,
    title="Region of Interest Selection",
    norm='log',
    cbar=True,
)

# Toggle tool that feeds the drawn rectangles into ``merge_node``.
roi_merge_tool = MergedRectanglesTool(
    figure=f2d, destination=merge_node, icon='vector-square'
)
# Spectrum of the selected region; it is refreshed when the ROIs change.
f1d = pp.linefigure(
    merge_node, title="Wavelength Spectrum in ROI", grid=True, vmin=0
)

In [None]:
from plopp.widgets import HBar

# Register the ROI toggle in the image figure's toolbar, then lay out the
# image and the ROI spectrum in a horizontal bar.
# TODO: These ROI tools should be reusable easily with other data.
f2d.toolbar['roi-merged'] = roi_merge_tool
box = HBar([f2d, f1d])
box