# Test Fusion on BigStitcher Example Data

Here, we test the calm-utils image fusion code using the (already aligned) 3D dataset from BigStitcher: https://imagej.net/plugins/bigstitcher/#example-datasets

In [1]:
from xml.etree import ElementTree as et
from pathlib import Path

from h5py import File
import numpy as np

## Parse transforms from XML dataset definition

In [2]:
# path of XML dataset definition
dataset = Path('/Users/david/Desktop/scratch_data/grid-3d-stitched-h5/dataset.xml')

# whether to preserve anisotropy (un-do scaling to isotropic coordinates)
preserve_anisotropy = True

# parse XML
dataset_xml = et.parse(dataset)

# get path of H5 data file (NOTE: we assume it is relative to XML)
# fail with a clear message instead of an AttributeError on None if the entry is missing
h5_element = dataset_xml.find('SequenceDescription/ImageLoader/hdf5')
if h5_element is None or not h5_element.text:
    raise ValueError(f'no HDF5 file reference found in {dataset}')
h5file = dataset.parent / h5_element.text.strip()

# maps (timepoint, setup) view id -> accumulated 4x4 affine matrix (zyx order)
transforms = {}

# name the child elements explicitly instead of relying on the implicit
# "all children" meaning of a trailing '/' in the path
for vr in dataset_xml.iterfind('./ViewRegistrations/ViewRegistration'):

    # (tp, setup) view id tuple
    # look the attributes up by name so the id does not depend on attribute order in the XML
    view_id = (int(vr.attrib['timepoint']), int(vr.attrib['setup']))

    # load and accumulate transforms
    tr = np.eye(4)
    # last transform of THIS view; stays None if the view has no transforms,
    # so we never re-use a matrix left over from a previous view
    last_tr = None
    for vt in vr.iterfind('ViewTransform/affine'):
        # the XML stores the upper 3 rows of the homogeneous matrix as 12 numbers
        tr_i = np.eye(4)
        tr_i[:3] = np.array(vt.text.split(), dtype=float).reshape((3, 4))
        tr = tr @ tr_i
        last_tr = tr_i

    # undo calibration scale to relative pixel sizes
    # (we pre-concatenate inverse transform of the last transform in list := calibration)
    if preserve_anisotropy and last_tr is not None:
        tr = np.linalg.inv(last_tr) @ tr

    # shuffle to zyx: swap first and third axis in both columns and rows,
    # the homogeneous (4th) row/column stays in place
    tr = tr[:, [2,1,0,3]][[2,1,0,3]]

    transforms[view_id] = tr

## Load images from H5

Here, we load the image data (at full resolution) for every view id found in the XML dataset definition above:

In [4]:
from calmutils.misc.string_utils import pad

# pixel data per view, keyed by the same (timepoint, setup) tuples as `transforms`
images = {}

with File(h5file) as h5:
    for view_id in transforms:
        timepoint, setup_id = view_id
        # dataset path inside the H5 file; '/0/' presumably selects the
        # full-resolution level (see the note above this cell)
        cells_key = f'/t{pad(timepoint, 5)}/s{pad(setup_id, 2)}/0/cells'
        # [...] reads the whole dataset into memory
        images[view_id] = h5[cells_key][...]


### Visualize one image in napari

In [5]:
import napari

# close the currently active viewer (if any) before opening a new one
open_viewer = napari.current_viewer()
if open_viewer is not None:
    open_viewer.close()

# show the image of view (timepoint 0, setup 0); the returned viewer is the cell output
napari.view_image(images[(0,0)])

Viewer(camera=Camera(center=(0.0, np.float64(255.5), np.float64(255.5)), zoom=np.float64(1.88330078125), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(42.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(85.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(511.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(511.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(42.0), np.float64(255.0), np.float64(255.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'Image' at 0x15716b7a0>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), theme='dark', ti

## Do Fusion

In [16]:
from calmutils.stitching.fusion import fuse_image_blockwise, fuse_image
from calmutils.stitching.phase_correlation import get_axes_aligned_bbox

# (timepoint, setup) ids of the views to fuse:
# setups 0-5 of timepoint 0, i.e. the 6 tiles of the first channel
view_ids_to_fuse = [(0, setup_id) for setup_id in range(6)]

# images and their transforms, in matching order
tiles = [images[view_id] for view_id in view_ids_to_fuse]
tile_transforms = [transforms[view_id] for view_id in view_ids_to_fuse]

# optional: estimate the bounding box of the fused volume from the transforms
# bbox = get_axes_aligned_bbox([t.shape for t in tiles], tile_transforms)
# # to int and to list of (min, max) tuples
# bbox = (b.astype(int) for b in bbox)
# list(zip(*bbox))

# block-wise fusion; set block size to something other than None to do multi-threaded
fused = fuse_image_blockwise(
    tiles,
    tile_transforms,
    block_size=(128,128,128),
    interpolation_mode='linear',
)
# alternative without blocks:
# fused = fuse_image(tiles, tile_transforms, interpolation_mode='linear')

100%|██████████| 96/96 [00:18<00:00,  5.21it/s]


### Visualize fusion results

In [10]:
# close the currently active viewer (if any) before opening a new one
open_viewer = napari.current_viewer()
if open_viewer is not None:
    open_viewer.close()

# show the fusion result; the returned viewer is the cell output
napari.view_image(fused)

Viewer(camera=Camera(center=(0.0, np.float64(710.5), np.float64(488.0)), zoom=np.float64(0.9041256446319736), angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True), cursor=Cursor(position=(np.float64(43.0), 1.0, 0.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=3, ndisplay=2, order=(0, 1, 2), axis_labels=('0', '1', '2'), rollable=(True, True, True), range=(RangeTuple(start=np.float64(0.0), stop=np.float64(87.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(1421.0), step=np.float64(1.0)), RangeTuple(start=np.float64(0.0), stop=np.float64(976.0), step=np.float64(1.0))), margin_left=(0.0, 0.0, 0.0), margin_right=(0.0, 0.0, 0.0), point=(np.float64(43.0), np.float64(710.0), np.float64(488.0)), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False), layers=[<Image layer 'fused' at 0x32f2e11c0>], help='use <2> for transform', status='Ready', tooltip=Tooltip(visible=False, text=''), theme='dar