In [5]:
# -*- coding: utf-8 -*-
"""
Landmark-based co-registration of multi-scale HiP-CT data.
Using ORB features, estimate a 4x4 spatial transformation between overview and zoom datasets
and export all transformation matrices (initial, landmark, and final combined) to CSV files.
The transformation coefficients can be entered directly into Neuroglancer.

Not for clinical use.
SPDX-FileCopyrightText: 2025 University College London, UK
SPDX-FileCopyrightText: 2025 Thierry L. Lefebvre
SPDX-License-Identifier: MIT
"""

def register_zoom_to_overview_orb_matrix_out(
    down_level: int,
    overview_name: str,
    zoom_name: str,
    num_slices: int = None,
    search_range: int = None,
    verbose: bool = True,
    use_xyz = False,
    random_seed: int = 42,
):
    """
    Fit a landmark-based affine correction between a zoom and an overview dataset (draft).

    NOTE(review): this is an earlier draft. A later cell in this notebook defines
    ``register_zoom_to_overview_orb_matrix_out`` again (adding ``apply_transform``,
    ``show_vedo`` and ``save_picture``). When the notebook is run top to bottom,
    that later definition replaces this one, so consider deleting this cell.

    Slices are sampled along each axis (0=z, 1=y, 2=x) of the zoom VOI. For each one:

    - the corresponding overview slice is located through the stored hoa_tools
      registration;
    - ORB keypoints are matched against zoom slices within ``±search_range``;
    - the zoom slice with the most ratio-test matches is kept.

    The matched points are converted to physical coordinates. An affine transform
    is then fitted with ``sitk.LandmarkBasedTransformInitializer`` and composed
    with the stored registration.

    Parameters
    ----------
    down_level : int
        Downsample level of the zoom VOI. Its size is the dataset shape divided
        by ``2**down_level``.
    overview_name, zoom_name : str
        hoa_tools dataset names.
    num_slices : int, optional
        Slices sampled per axis. Derived from ``down_level`` when None.
    search_range : int, optional
        Half-width, in slices, of the zoom search window around each sampled
        slice. Derived from ``down_level`` when None.
    verbose : bool
        Print timestamped progress messages.
    use_xyz : bool
        Reverse the order of the three spatial axes in the exported matrices.
    random_seed : int
        Seed passed to ``numpy`` and ``random``.

    Returns
    -------
    sitk.CompositeTransform
        The stored registration composed with the landmark transform. SimpleITK
        applies the last-added transform first, so points go through the
        landmark transform and then the stored registration.

    Raises
    ------
    RuntimeError
        If no feature match is found or none survives the distance filter.

    Notes
    -----
    Writes three CSV files (pre, landmark-only and combined 4x4 matrices) to
    the current working directory.
    """
    import time
    import random
    from typing import Any

    import numpy as np
    import pandas as pd
    import cv2 as cv
    import SimpleITK as sitk
    import hoa_tools.dataset
    import hoa_tools.voi
    import hoa_tools.registration
    from hoa_tools.registration import Inventory as RegInventory

    start_time = time.time()
    np.random.seed(random_seed)
    random.seed(random_seed)

    # Default sampling density: more slices and a wider search window at finer levels.
    base_slices = 8
    base_range = 20
    scale = max(1, 2 - down_level) if down_level < 2 else 1 / (down_level - 1 + 1e-5)
    if num_slices is None:
        num_slices = max(5, int(round(base_slices * scale)))
    if search_range is None:
        search_range = max(10, int(round(base_range * scale)))

    def log(msg):
        """Print ``msg`` prefixed with the elapsed minutes when ``verbose`` is set."""
        if verbose:
            print(f"[{(time.time() - start_time)/60:.2f}min] {msg}")

    def get_matching_slice(
        zoom_voi,
        overview_voi,
        zoom_slice_index: int,
        axis: int = 0,  # 0=z, 1=y, 2=x
        interpolator: Any = sitk.sitkLinear,
    ) -> np.ndarray:
        """
        Return the overview slice that the stored registration maps a zoom slice to.

        The zoom voxel at the VOI centre, with coordinate ``axis`` set to
        ``zoom_slice_index``, is converted to physical units. It is then mapped
        with the stored zoom->overview registration and converted back to an
        overview index, and the overview data array is sliced at that index
        along ``axis``. ``interpolator`` is currently unused.
        """
        tfm = RegInventory.get_registration(
            source_dataset=zoom_voi.dataset,
            target_dataset=overview_voi.dataset,
        )

        # Physical point of the selected zoom voxel, in (z, y, x) order.
        zoom_spacing = zoom_voi.voxel_size_um
        zoom_origin = [
            zoom_voi.lower_corner.z * zoom_spacing,
            zoom_voi.lower_corner.y * zoom_spacing,
            zoom_voi.lower_corner.x * zoom_spacing,
        ]
        zoom_index = [
            zoom_voi.size.z // 2,
            zoom_voi.size.y // 2,
            zoom_voi.size.x // 2,
        ]
        zoom_index[axis] = zoom_slice_index
        zoom_phys = [zoom_origin[i] + zoom_spacing * zoom_index[i] for i in range(3)]

        # Map into overview physical space, then back to an overview index.
        overview_phys = tfm.TransformPoint(zoom_phys)
        overview_spacing = overview_voi.voxel_size_um
        overview_origin = [
            overview_voi.lower_corner.z * overview_spacing,
            overview_voi.lower_corner.y * overview_spacing,
            overview_voi.lower_corner.x * overview_spacing,
        ]
        overview_index = [
            int(round((overview_phys[i] - overview_origin[i]) / overview_spacing))
            for i in range(3)
        ]

        da = overview_voi.get_data_array()
        dim_names = ['z', 'y', 'x']
        return da.isel(**{dim_names[axis]: overview_index[axis]}).values

    def get_slice_from_voi(voi, axis, index):
        """
        Return slice ``index`` (relative to the VOI) along ``axis`` (0=z, 1=y, 2=x) as a 2D array.

        The other two dimensions are restricted to the VOI bounds.
        """
        da = voi.dataset.data_array(downsample_level=voi.downsample_level)
        slicers = {
            "x": slice(voi.lower_corner.x, voi.upper_corner.x),
            "y": slice(voi.lower_corner.y, voi.upper_corner.y),
            "z": slice(voi.lower_corner.z, voi.upper_corner.z),
        }
        # Same axis convention as get_matching_slice and the size bounds used below.
        dim = ('z', 'y', 'x')[axis]
        # Replacing one slicer with an int drops that dimension and yields a 2D slice.
        slicers[dim] = getattr(voi.lower_corner, dim) + index
        return da.isel(**slicers).values

    def normalize_slice_to_uint8(slice2d: np.ndarray, clip_z: float = 3.0, out_range=(0, 255)) -> np.ndarray:
        """
        Normalize a 2D slice to uint8 using z-score clipping and linear scaling.

        Parameters
        ----------
        slice2d : np.ndarray
            2D grayscale input image.
        clip_z : float
            Z-score clipping range (e.g. 2.0 clips to [-2, 2]).
        out_range : tuple
            Output intensity range, usually (0, 255) for uint8.

        Returns
        -------
        np.ndarray
            Normalized 2D uint8 image.
        """
        if slice2d.dtype != np.float32:
            slice2d = slice2d.astype(np.float32)

        mean = np.mean(slice2d)
        std = np.std(slice2d)
        if std == 0:
            std = 1e-5  # avoid division by zero

        zscore = np.clip((slice2d - mean) / std, -clip_z, clip_z)
        norm = ((zscore + clip_z) / (2 * clip_z)) * (out_range[1] - out_range[0])
        return np.clip(norm, *out_range).astype(np.uint8)

    def sorted_orb_features(orb, img):
        """Run ORB on ``img`` and return (keypoints, descriptors) sorted by response, or None."""
        kp, des = orb.detectAndCompute(img, None)
        if des is None or len(kp) <= 1:
            return None
        kp, des = zip(*sorted(zip(kp, des), key=lambda x: x[0].response, reverse=True))
        return kp, np.array(des)

    def match_slices_along_axis(img_zoom, img_parent, axis, slice_indices, search_range=10):
        """
        Match ORB features between overview slices and nearby zoom slices along one axis.

        ``img_zoom`` and ``img_parent`` are the zoom and overview VOIs. For each index
        in ``slice_indices``, the zoom slice within ``±search_range`` that has the most
        ratio-test matches is kept.

        Returns two float arrays of shape (N, 3), both (0, 3) when nothing matched.

        NOTE(review): parent points come from the overview slice returned by
        get_matching_slice, i.e. in overview pixel coordinates. The third coordinate
        recorded for them is the zoom ``slice_idx``, and the caller converts them
        with the zoom image geometry. These frames do not agree — TODO confirm the
        intended convention before relying on this draft.
        """
        orb = cv.ORB_create(nfeatures=2000)
        bf = cv.BFMatcher(cv.NORM_HAMMING, crossCheck=False)
        axis_len = [img_zoom.size.z, img_zoom.size.y, img_zoom.size.x][axis]
        zoom_matches_3d = []
        parent_matches_3d = []

        for slice_idx in slice_indices:
            img2 = normalize_slice_to_uint8(get_matching_slice(img_zoom, img_parent, slice_idx, axis))
            parent_feats = sorted_orb_features(orb, img2)
            if parent_feats is None:
                continue
            kp2, des2 = parent_feats

            max_good_matches = []
            best_kp1 = None
            best_zoom_slice_idx = None

            for zoom_slice_idx in range(slice_idx - search_range, slice_idx + search_range + 1):
                if zoom_slice_idx < 0 or zoom_slice_idx >= axis_len:
                    continue
                # Load the candidate zoom slice itself (not the fixed slice_idx).
                img1 = normalize_slice_to_uint8(get_slice_from_voi(img_zoom, axis, zoom_slice_idx))
                zoom_feats = sorted_orb_features(orb, img1)
                if zoom_feats is None:
                    continue
                kp1, des1 = zoom_feats

                # Lowe ratio test on the two nearest Hamming neighbours.
                good_matches = [
                    pair[0]
                    for pair in bf.knnMatch(des1, des2, k=2)
                    if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
                ]

                if len(good_matches) > len(max_good_matches):
                    max_good_matches = good_matches
                    best_kp1 = kp1
                    best_zoom_slice_idx = zoom_slice_idx

            for m in max_good_matches:
                pt1 = best_kp1[m.queryIdx].pt
                pt2 = kp2[m.trainIdx].pt
                # ORB points are (column, row) of the 2D slice; insert the slice index
                # at the position of the sliced axis.
                if axis == 0:
                    zoom_matches_3d.append([pt1[0], pt1[1], best_zoom_slice_idx])
                    parent_matches_3d.append([pt2[0], pt2[1], slice_idx])
                elif axis == 1:
                    zoom_matches_3d.append([pt1[0], best_zoom_slice_idx, pt1[1]])
                    parent_matches_3d.append([pt2[0], slice_idx, pt2[1]])
                elif axis == 2:
                    zoom_matches_3d.append([best_zoom_slice_idx, pt1[0], pt1[1]])
                    parent_matches_3d.append([slice_idx, pt2[0], pt2[1]])

        # reshape keeps an (0, 3) shape when empty so np.vstack across axes works.
        return (
            np.array(zoom_matches_3d, dtype=np.float64).reshape(-1, 3),
            np.array(parent_matches_3d, dtype=np.float64).reshape(-1, 3),
        )

    def convert_voxel_to_physical(points_3d, sitk_image):
        """Convert rows of ``points_3d`` to physical points; columns are passed to SimpleITK reversed."""
        return [sitk_image.TransformIndexToPhysicalPoint([int(round(pt[2])), int(round(pt[1])), int(round(pt[0]))])
                for pt in points_3d]

    def sitk_transform_to_matrix_4x4(tfm: sitk.Transform) -> np.ndarray:
        """
        Return the homogeneous 4x4 matrix of a SimpleITK affine or similarity transform.

        These transforms compute ``A @ (p - c) + c + t`` (with centre c and
        translation t). The homogeneous offset is therefore T(0) = t + c - A @ c,
        not ``GetTranslation()`` alone. Evaluating the transform at the origin
        gives that offset directly, and it is also correct when the centre is zero.
        """
        if not isinstance(tfm, (sitk.AffineTransform, sitk.Similarity3DTransform)):
            raise TypeError(f"Unsupported transform type: {type(tfm)}")
        mat4x4 = np.eye(4)
        mat4x4[:3, :3] = np.array(tfm.GetMatrix()).reshape(3, 3)
        mat4x4[:3, 3] = tfm.TransformPoint((0.0, 0.0, 0.0))
        return mat4x4

    # Step 1: Load VOIs
    log("Loading datasets")
    overview_dataset = hoa_tools.dataset.get_dataset(overview_name)
    zoom_dataset = hoa_tools.dataset.get_dataset(zoom_name)

    zoom_voi = hoa_tools.voi.VOI(
        dataset=zoom_dataset,
        downsample_level=down_level,
        lower_corner={"x": 0, "y": 0, "z": 0},
        size={k: round(v / (2**down_level)) for k, v in zip(["x", "y", "z"], zoom_dataset.data.shape)}
    )
    overview_voi = zoom_voi.transform_to(overview_dataset)
    # Axis lengths indexed by the 0=z, 1=y, 2=x convention used by the slice helpers.
    size_zyx = [zoom_voi.size.z, zoom_voi.size.y, zoom_voi.size.x]

    # Step 2: Match features
    log("Finding feature matches")
    base_margin = 20
    margin_exp = 2 - down_level
    edge_margin = max(int(base_margin * (2 ** margin_exp)), 20)

    zoom_matches_all, parent_matches_all = [], []
    for axis in range(3):
        start = int(edge_margin)
        end = max(int(size_zyx[axis] - edge_margin), start + 1)
        slice_indices = np.linspace(start, end - 1, num_slices, dtype=int)

        z, p = match_slices_along_axis(zoom_voi, overview_voi, axis, slice_indices, search_range)
        zoom_matches_all.append(z)
        parent_matches_all.append(p)

    zoom_matches = np.vstack(zoom_matches_all)
    parent_matches = np.vstack(parent_matches_all)
    if len(zoom_matches) == 0:
        raise RuntimeError("No ORB feature matches found between zoom and overview slices")

    # Keep matches whose displacement is below mean + 0.2 * std.
    distances = np.linalg.norm(zoom_matches - parent_matches, axis=1)
    mask = distances < (np.mean(distances) + 0.2 * np.std(distances))
    zoom_matches = zoom_matches[mask]
    parent_matches = parent_matches[mask]
    log(f"{len(zoom_matches)} matches retained")
    if len(zoom_matches) == 0:
        raise RuntimeError("No feature matches survived the distance filter")

    # Step 3: Compute transform in physical space
    log("Computing transform in µm")
    zoom_img_sitk = zoom_voi.get_sitk_image()

    pre_transform = hoa_tools.registration.Inventory.get_registration(
        source_dataset=zoom_voi.dataset,
        target_dataset=overview_voi.dataset
    )

    fixed = convert_voxel_to_physical(zoom_matches, zoom_img_sitk)
    moving = convert_voxel_to_physical(parent_matches, zoom_img_sitk)

    transform = sitk.LandmarkBasedTransformInitializer(
        sitk.AffineTransform(3),
        fixedLandmarks=np.ravel(fixed),
        movingLandmarks=np.ravel(moving)
    )

    # SimpleITK applies the last-added transform first: composite(p) = pre(landmark(p)).
    composite = sitk.CompositeTransform(3)
    composite.AddTransform(pre_transform)
    composite.AddTransform(transform)

    # Step 4: Save transform matrices to CSV
    log("Saving transform matrices to CSV")
    mat_pre = sitk_transform_to_matrix_4x4(pre_transform)
    mat_landmark = sitk_transform_to_matrix_4x4(transform)
    # Matches the composite above: the landmark matrix acts on the point first.
    mat_combined = mat_pre @ mat_landmark

    if use_xyz:
        mat_pre = mat_pre[[2,1,0], :][:, [2,1,0,3]]
        mat_landmark = mat_landmark[[2,1,0], :][:, [2,1,0,3]]
        mat_combined = mat_combined[[2,1,0], :][:, [2,1,0,3]]
        xyz_suffix = "_xyz"
    else:
        xyz_suffix = ""

    csv_pre_path = f"{overview_name}_to_{zoom_name}_pre_transform_physical_um{xyz_suffix}.csv"
    csv_landmark_path = f"{overview_name}_to_{zoom_name}_landmark_only_physical_um{xyz_suffix}.csv"
    csv_combined_path = f"{overview_name}_to_{zoom_name}_combined_transform_physical_um{xyz_suffix}.csv"

    pd.DataFrame(mat_pre).to_csv(csv_pre_path, index=False, header=False)
    pd.DataFrame(mat_landmark).to_csv(csv_landmark_path, index=False, header=False)
    pd.DataFrame(mat_combined).to_csv(csv_combined_path, index=False, header=False)

    log(f"Transform matrices saved to {csv_pre_path}, {csv_landmark_path}, and {csv_combined_path}")

    return composite

In [2]:
# -*- coding: utf-8 -*-
"""
Landmark-based co-registration of multi-scale HiP-CT data.
Using ORB features, estimate a 4x4 spatial transformation between overview and zoom datasets
and export all transformation matrices (initial, landmark, and final combined) to CSV files.
The transformation coefficients can be entered directly into Neuroglancer.

Not for clinical use.
SPDX-FileCopyrightText: 2025 University College London, UK
SPDX-FileCopyrightText: 2025 Thierry L. Lefebvre
SPDX-License-Identifier: MIT
"""

def register_zoom_to_overview_orb_matrix_out(
    down_level: int,
    overview_name: str,
    zoom_name: str,
    num_slices: int = None,
    search_range: int = None,
    verbose: bool = True,
    apply_transform: bool = True,
    show_vedo: bool = True,
    save_picture: bool = True,
    use_xyz = False,
    random_seed: int = 42,
):
    """
    Fit a landmark-based affine correction between a zoom and an overview dataset.

    Processing steps:

    1. Resample the overview onto the zoom VOI grid (nearest neighbour) and
       normalize it to uint8.
    2. Along each axis (0=z, 1=y, 2=x), match ORB keypoints between overview
       slices and zoom slices within ``±search_range`` slices.
    3. Convert the matched points to physical coordinates with the zoom grid
       geometry.
    4. Fit an affine transform with ``sitk.LandmarkBasedTransformInitializer``
       and compose it with the registration stored in
       ``hoa_tools.registration.Inventory``.

    Parameters
    ----------
    down_level : int
        Downsample level of the zoom VOI. Its size is the dataset shape divided
        by ``2**down_level``.
    overview_name, zoom_name : str
        hoa_tools dataset names.
    num_slices : int, optional
        Slices sampled per axis. Derived from ``down_level`` when None.
    search_range : int, optional
        Half-width, in slices, of the zoom search window around each sampled
        slice. Derived from ``down_level`` when None.
    verbose : bool
        Print timestamped progress messages.
    apply_transform : bool
        Resample the overview onto the zoom VOI with the fitted transform.
    show_vedo : bool
        Show the vedo overlay interactively. Only used when ``apply_transform``
        is set.
    save_picture : bool
        Save a PNG screenshot of the overlay. Only used when ``apply_transform``
        is set.
    use_xyz : bool
        Reverse the order of the three spatial axes in the exported matrices.
    random_seed : int
        Seed passed to ``numpy`` and ``random``.

    Returns
    -------
    sitk.CompositeTransform
        The stored registration composed with the landmark transform. SimpleITK
        applies the last-added transform first, so points go through the
        landmark transform and then the stored registration.

    Raises
    ------
    RuntimeError
        If no feature match is found or none survives the distance filter.

    Notes
    -----
    Writes three CSV files (pre, landmark-only and combined 4x4 matrices) to
    the current working directory. It may also write
    ``{zoom_name}_post_registration.png``.
    """
    import time
    import random

    import numpy as np
    import pandas as pd
    import cv2 as cv
    import SimpleITK as sitk
    import hoa_tools.dataset
    import hoa_tools.voi
    import hoa_tools.registration
    import vedo
    from vedo import Volume, Plane, Plotter
    vedo.settings.default_backend = "jupyter"  # For Jupyter Notebook plotting

    start_time = time.time()
    np.random.seed(random_seed)
    random.seed(random_seed)

    # Default sampling density: more slices and a wider search window at finer levels.
    base_slices = 8
    base_range = 20
    scale = max(1, 2 - down_level) if down_level < 2 else 1 / (down_level - 1 + 1e-5)
    if num_slices is None:
        num_slices = max(5, int(round(base_slices * scale)))
    if search_range is None:
        search_range = max(10, int(round(base_range * scale)))

    def log(msg):
        """Print ``msg`` prefixed with the elapsed minutes when ``verbose`` is set."""
        if verbose:
            print(f"[{(time.time() - start_time)/60:.2f}min] {msg}")

    def make_rgb_overlay_slice(v1s, v2s, gamma=1.0):
        """Return a uint8 RGB image with ``v1s`` in red+blue (magenta) and ``v2s`` in green."""
        v1s = (v1s - v1s.min()) / (np.ptp(v1s) + 1e-5)
        v2s = (v2s - v2s.min()) / (np.ptp(v2s) + 1e-5)
        v1s = v1s**gamma
        v2s = v2s**gamma
        rgb = np.zeros((*v1s.shape, 3), dtype=np.float32)
        rgb[..., 0] = v1s
        rgb[..., 1] = v2s
        rgb[..., 2] = v1s
        return (np.clip(rgb, 0, 1) * 255).astype(np.uint8)

    def vedo_overlay(vol1, vol2, spacing, zoom_name, suffix, interactive, save):
        """
        Render two volumes as magenta/green overlays with three RGB cross-section planes.

        Both volumes are min/max normalized. ``vol2`` is masked to the cylinder
        inscribed in each (axis-1, axis-2) slice, and one octant of both volumes
        is cut away. When ``save`` is set, a screenshot is written to
        ``{zoom_name}_{suffix}_registration.png``.
        """
        vol1 = (vol1 - vol1.min()) / (np.ptp(vol1) + 1e-5)
        vol2 = (vol2 - vol2.min()) / (np.ptp(vol2) + 1e-5)
        vol1_cut = vol1.copy()
        vol2_cut = vol2.copy()
        z, y, x = vol2.shape

        # Cylindrical mask inscribed in each slice of vol2.
        cy, cx = y // 2, x // 2
        radius = min(cy, cx)
        Y, X = np.ogrid[:y, :x]
        circular_mask = (X - cx) ** 2 + (Y - cy) ** 2 <= radius ** 2
        cylindrical_mask = np.broadcast_to(circular_mask, (z, y, x))
        vol2_cut[~cylindrical_mask] = 0

        # Cut away one octant so the interior is visible.
        vol1_cut[:z//2, :y//2, x//2:] = 0
        vol2_cut[:z//2, :y//2, x//2:] = 0

        alpha_ramp = [0, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1]
        v1 = Volume(vol1_cut.transpose(2,1,0)).spacing(spacing).cmap([(0, 0, 0), (1, 0, 1)]).alpha(alpha_ramp).shade(False)
        v2 = Volume(vol2_cut.transpose(2,1,0)).spacing(spacing).cmap([(0, 0, 0), (0, 1, 0)]).alpha(alpha_ramp).shade(False)

        z_idx = z // 2
        y_idx = y // 2
        x_idx = x // 2
        sx, sy, sz = spacing
        px = x_idx * sx
        py = y_idx * sy
        pz = z_idx * sz

        rgb_xy = make_rgb_overlay_slice(vol1[z_idx], vol2[z_idx])
        rgb_xz = make_rgb_overlay_slice(vol1[:, y_idx, :], vol2[:, y_idx, :]).transpose(1, 0, 2)
        rgb_yz = make_rgb_overlay_slice(vol1[:, :, x_idx], vol2[:, :, x_idx]).transpose(1, 0, 2)

        plane_xy = Plane(pos=(px, py, pz), normal=(0, 0, 1), s=(x*sx, y*sy)).texture(rgb_xy)
        plane_xz = Plane(pos=(px, py, z*sz/2), normal=(0, 1, 0), s=(x*sx, z*sz)).texture(rgb_xz)
        plane_yz = Plane(pos=(px, y*sy/2, z*sz/2), normal=(1, 0, 0), s=(y*sy, z*sz)).texture(rgb_yz)

        plt = Plotter(bg='black', size=(1600, 1200))
        plt.show(v1, v2, plane_xy, plane_xz, plane_yz, axes=0, interactive=interactive)

        # NOTE(review): the camera is positioned after show(); with interactive=True
        # this only affects the screenshot taken below.
        cx = x * sx / 2
        cy = y * sy / 2
        cz = z * sz / 2
        cam_pos = (x * sx * 1.5, -y * sy * 0.5, -z * sz * 0.5)
        plt.camera.SetPosition(cam_pos)
        plt.camera.SetFocalPoint((cx, cy, cz))
        plt.camera.SetViewUp((0, 0, 1))

        if save:
            fn = f"{zoom_name}_{suffix}_registration.png"
            plt.screenshot(fn, scale=3)
            if not interactive:
                plt.close()

    def get_slice_from_voi(voi, axis, index):
        """
        Return slice ``index`` (relative to the VOI) along ``axis`` (0=z, 1=y, 2=x).

        The other two dimensions are restricted to the VOI bounds. Pixels outside
        the inscribed circle are set to 0.
        """
        dims = ['z', 'y', 'x']
        dim = dims[axis]
        abs_index = getattr(voi.lower_corner, dim) + index

        slicers = {
            d: abs_index if d == dim else slice(getattr(voi.lower_corner, d), getattr(voi.upper_corner, d))
            for d in dims
        }

        da = voi.dataset.data_array(downsample_level=voi.downsample_level)
        # Copy so the in-place masking below cannot touch (or fail on) a view of the source array.
        slice2d = np.array(da.isel(**slicers).values, copy=True)
        h, w = slice2d.shape
        cy, cx = h // 2, w // 2
        radius = min(cy, cx)
        y, x = np.ogrid[:h, :w]
        mask = (x - cx) ** 2 + (y - cy) ** 2 <= radius ** 2
        slice2d[~mask] = 0
        return slice2d

    def normalize_to_uint8(xr_data, sample_frac=1e-1, bins=int(1e7), clip_z=3.0, out_range=(0, 255)):
        """
        Normalize a 3D xarray to uint8 using percentile clipping then z-score scaling.

        Percentiles (0.05 and 99.95) are estimated from a strided subsample via a
        histogram. The clipped data is z-scored, clipped to ``±clip_z`` and
        linearly mapped to ``out_range``.
        """
        def sample_volume(array, frac):
            total_voxels = np.prod(array.shape)
            sample_size = int(total_voxels * frac)
            stride = int((total_voxels / sample_size) ** (1/3)) + 1
            return array[::stride, ::stride, ::stride]

        def percentile(p):
            return np.interp(p / 100.0, cdf, bin_edges[1:])

        sampled = sample_volume(xr_data.data, sample_frac)
        flat_sample = sampled.ravel()
        if hasattr(flat_sample, "compute"):
            flat_sample = flat_sample.compute()

        hist, bin_edges = np.histogram(flat_sample, bins=bins)
        cdf = np.cumsum(hist) / np.sum(hist)

        p05 = percentile(0.05)
        p995 = percentile(99.95)
        clipped = xr_data.clip(p05, p995)

        mean = clipped.mean().compute()
        std = clipped.std().compute()
        zscore = (clipped - mean) / std
        zscore = zscore.clip(-clip_z, clip_z)

        norm = ((zscore + clip_z) / (2 * clip_z)) * (out_range[1] - out_range[0])
        return norm.clip(*out_range).astype(np.uint8)

    def normalize_slice_to_uint8(slice2d: np.ndarray, clip_z: float = 3.0, out_range=(0, 255)) -> np.ndarray:
        """
        Normalize a 2D slice to uint8, computing statistics over pixels > 0 only.

        Non-positive pixels (e.g. the masked background) are returned as 0.
        """
        if slice2d.dtype != np.float32:
            slice2d = slice2d.astype(np.float32)

        valid_mask = slice2d > 0
        if not np.any(valid_mask):
            return np.zeros_like(slice2d, dtype=np.uint8)

        valid_pixels = slice2d[valid_mask]
        mean = np.mean(valid_pixels)
        std = np.std(valid_pixels)
        std = std if std > 1e-5 else 1e-5  # avoid division by zero

        zscore = np.clip((slice2d - mean) / std, -clip_z, clip_z)
        norm = ((zscore + clip_z) / (2 * clip_z)) * (out_range[1] - out_range[0])
        norm = np.clip(norm, *out_range).astype(np.uint8)

        # Reapply mask: set background to 0
        norm[~valid_mask] = 0
        return norm

    def sorted_orb_features(orb, img):
        """Run ORB on ``img`` and return (keypoints, descriptors) sorted by response, or None."""
        kp, des = orb.detectAndCompute(img, None)
        if des is None or len(kp) <= 1:
            return None
        kp, des = zip(*sorted(zip(kp, des), key=lambda x: x[0].response, reverse=True))
        return kp, np.array(des)

    def match_slices_along_axis(img_zoom, img_parent, axis, slice_indices, search_range=10):
        """
        Match ORB features between parent slices and nearby zoom slices along one axis.

        ``img_zoom`` is the zoom VOI, read through get_slice_from_voi. ``img_parent``
        is the uint8 overview already resampled onto the zoom grid, so both share
        slice indices. For each parent slice, the zoom slice within
        ``±search_range`` with the most ratio-test matches is kept.

        Returns two float arrays of shape (N, 3), both (0, 3) when nothing matched.
        Rows are presumably [x, y, z] voxel coordinates, assuming the data arrays
        are ordered (z, y, x) — TODO confirm.
        """
        orb = cv.ORB_create(nfeatures=2000)
        bf = cv.BFMatcher(cv.NORM_HAMMING, crossCheck=False)
        axis_len = [img_zoom.size.z, img_zoom.size.y, img_zoom.size.x][axis]
        dim = ('z', 'y', 'x')[axis]
        zoom_matches_3d = []
        parent_matches_3d = []

        for slice_num, slice_idx in enumerate(slice_indices, start=1):
            # Skip first two slices in XZ, YZ, with cylinder VOI mostly background
            if axis in (1, 2) and slice_num < 3:
                continue
            img2 = img_parent.isel(**{dim: slice_idx}).values

            parent_feats = sorted_orb_features(orb, img2)
            if parent_feats is None:
                continue
            kp2, des2 = parent_feats

            max_good_matches = []
            best_kp1 = None
            best_zoom_slice_idx = None

            for zoom_slice_idx in range(slice_idx - search_range, slice_idx + search_range + 1):
                if zoom_slice_idx < 0 or zoom_slice_idx >= axis_len:
                    continue
                # Load the candidate zoom slice itself; using slice_idx here made every
                # candidate identical and recorded the wrong slice coordinate.
                img1 = normalize_slice_to_uint8(get_slice_from_voi(img_zoom, axis, zoom_slice_idx))
                zoom_feats = sorted_orb_features(orb, img1)
                if zoom_feats is None:
                    continue
                kp1, des1 = zoom_feats

                # Lowe ratio test on the two nearest Hamming neighbours.
                good_matches = [
                    pair[0]
                    for pair in bf.knnMatch(des1, des2, k=2)
                    if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
                ]

                if len(good_matches) > len(max_good_matches):
                    max_good_matches = good_matches
                    best_kp1 = kp1
                    best_zoom_slice_idx = zoom_slice_idx

            for m in max_good_matches:
                pt1 = best_kp1[m.queryIdx].pt
                pt2 = kp2[m.trainIdx].pt
                # ORB points are (column, row) of the 2D slice; insert the slice index
                # at the position of the sliced axis.
                if axis == 0:
                    zoom_matches_3d.append([pt1[0], pt1[1], best_zoom_slice_idx])
                    parent_matches_3d.append([pt2[0], pt2[1], slice_idx])
                elif axis == 1:
                    zoom_matches_3d.append([pt1[0], best_zoom_slice_idx, pt1[1]])
                    parent_matches_3d.append([pt2[0], slice_idx, pt2[1]])
                elif axis == 2:
                    zoom_matches_3d.append([best_zoom_slice_idx, pt1[0], pt1[1]])
                    parent_matches_3d.append([slice_idx, pt2[0], pt2[1]])

        # reshape keeps an (0, 3) shape when empty so np.vstack across axes works.
        return (
            np.array(zoom_matches_3d, dtype=np.float64).reshape(-1, 3),
            np.array(parent_matches_3d, dtype=np.float64).reshape(-1, 3),
        )

    def convert_voxel_to_physical(points_3d, sitk_image):
        """Convert rows of ``points_3d`` to physical points; columns are passed to SimpleITK reversed."""
        return [sitk_image.TransformIndexToPhysicalPoint([int(round(pt[2])), int(round(pt[1])), int(round(pt[0]))])
                for pt in points_3d]

    def sitk_transform_to_matrix_4x4(tfm: sitk.Transform) -> np.ndarray:
        """
        Return the homogeneous 4x4 matrix of a SimpleITK affine or similarity transform.

        These transforms compute ``A @ (p - c) + c + t`` (with centre c and
        translation t). The homogeneous offset is therefore T(0) = t + c - A @ c,
        not ``GetTranslation()`` alone. Evaluating the transform at the origin
        gives that offset directly, and it is also correct when the centre is zero.
        """
        if not isinstance(tfm, (sitk.AffineTransform, sitk.Similarity3DTransform)):
            raise TypeError(f"Unsupported transform type: {type(tfm)}")
        mat4x4 = np.eye(4)
        mat4x4[:3, :3] = np.array(tfm.GetMatrix()).reshape(3, 3)
        mat4x4[:3, 3] = tfm.TransformPoint((0.0, 0.0, 0.0))
        return mat4x4

    # Step 1: Load VOIs and resample the overview onto the zoom grid
    log("Loading datasets")
    overview_dataset = hoa_tools.dataset.get_dataset(overview_name)
    zoom_dataset = hoa_tools.dataset.get_dataset(zoom_name)

    zoom_voi = hoa_tools.voi.VOI(
        dataset=zoom_dataset,
        downsample_level=down_level,
        lower_corner={"x": 0, "y": 0, "z": 0},
        size={k: round(v / (2**down_level)) for k, v in zip(["x", "y", "z"], zoom_dataset.data.shape)}
    )
    overview_voi = zoom_voi.transform_to(overview_dataset)
    resampled_overview = overview_voi.get_data_array_on_voi(zoom_voi, interpolator=sitk.sitkNearestNeighbor)
    img_parent = normalize_to_uint8(resampled_overview)

    # Step 2: Match features
    log("Finding feature matches")
    size_zyx = [zoom_voi.size.z, zoom_voi.size.y, zoom_voi.size.x]
    zoom_matches_all, parent_matches_all = [], []
    for axis in range(3):
        slice_indices = np.linspace(search_range, size_zyx[axis] - search_range, num_slices, dtype=int)
        z, p = match_slices_along_axis(zoom_voi, img_parent, axis, slice_indices, search_range)
        zoom_matches_all.append(z)
        parent_matches_all.append(p)

    zoom_matches = np.vstack(zoom_matches_all)
    parent_matches = np.vstack(parent_matches_all)
    if len(zoom_matches) == 0:
        raise RuntimeError("No ORB feature matches found between zoom and overview slices")

    # Keep matches whose displacement is below mean + 0.2 * std.
    distances = np.linalg.norm(zoom_matches - parent_matches, axis=1)
    mask = distances < (np.mean(distances) + 0.2 * np.std(distances))
    zoom_matches = zoom_matches[mask]
    parent_matches = parent_matches[mask]
    log(f"{len(zoom_matches)} matches retained")
    if len(zoom_matches) == 0:
        raise RuntimeError("No feature matches survived the distance filter")

    # Step 3: Compute transform in physical space
    log("Computing transform in µm")

    # Geometry-only SimpleITK image describing the zoom grid, in (z, y, x) order.
    zoom_size = [zoom_voi.size.z, zoom_voi.size.y, zoom_voi.size.x]
    zoom_spacing = [zoom_voi.voxel_size_um] * 3
    zoom_origin = [
        zoom_voi.lower_corner.z * zoom_voi.voxel_size_um,
        zoom_voi.lower_corner.y * zoom_voi.voxel_size_um,
        zoom_voi.lower_corner.x * zoom_voi.voxel_size_um,
    ]
    zoom_img_sitk = sitk.Image(zoom_size, sitk.sitkUInt8)
    zoom_img_sitk.SetSpacing(zoom_spacing)
    zoom_img_sitk.SetOrigin(zoom_origin)

    pre_transform = hoa_tools.registration.Inventory.get_registration(
        source_dataset=zoom_voi.dataset,
        target_dataset=overview_voi.dataset
    )

    # Both point sets live on the zoom grid (the overview was resampled onto it).
    fixed = convert_voxel_to_physical(zoom_matches, zoom_img_sitk)
    moving = convert_voxel_to_physical(parent_matches, zoom_img_sitk)

    transform = sitk.LandmarkBasedTransformInitializer(
        sitk.AffineTransform(3),
        fixedLandmarks=np.ravel(fixed),
        movingLandmarks=np.ravel(moving)
    )

    # SimpleITK applies the last-added transform first: composite(p) = pre(landmark(p)).
    composite = sitk.CompositeTransform(3)
    composite.AddTransform(pre_transform)
    composite.AddTransform(transform)

    # Step 4: Apply transform (optional)
    transformed_array = None
    if apply_transform:
        log("Applying transform to overview")
        transformed_array = overview_voi.get_data_array_on_voi(
            zoom_voi,
            interpolator=sitk.sitkLinear,
            transform=composite.GetInverse()
        )

    # Step 5: Optional 3D overlay (only available after apply_transform)
    spacing = (1.0, 1.0, 1.0)  # isotropic default for display
    if (show_vedo or save_picture) and apply_transform:
        vedo_overlay(transformed_array.values, normalize_to_uint8(zoom_voi.get_data_array()).values, spacing, zoom_name, "post", show_vedo, save_picture)

    # Step 6: Save transform matrices to CSV
    log("Saving transform matrices to CSV")
    mat_pre = sitk_transform_to_matrix_4x4(pre_transform)
    mat_landmark = sitk_transform_to_matrix_4x4(transform)
    # Matches the composite above: the landmark matrix acts on the point first.
    mat_combined = mat_pre @ mat_landmark

    if use_xyz:
        mat_pre = mat_pre[[2,1,0], :][:, [2,1,0,3]]
        mat_landmark = mat_landmark[[2,1,0], :][:, [2,1,0,3]]
        mat_combined = mat_combined[[2,1,0], :][:, [2,1,0,3]]
        xyz_suffix = "_xyz"
    else:
        xyz_suffix = ""

    csv_pre_path = f"{overview_name}_to_{zoom_name}_down{down_level}_pre_transform_physical_um{xyz_suffix}.csv"
    csv_landmark_path = f"{overview_name}_to_{zoom_name}_down{down_level}_landmark_only_physical_um{xyz_suffix}.csv"
    csv_combined_path = f"{overview_name}_to_{zoom_name}_down{down_level}_combined_transform_physical_um{xyz_suffix}.csv"

    pd.DataFrame(mat_pre).to_csv(csv_pre_path, index=False, header=False)
    pd.DataFrame(mat_landmark).to_csv(csv_landmark_path, index=False, header=False)
    pd.DataFrame(mat_combined).to_csv(csv_combined_path, index=False, header=False)

    log(f"Transform matrices saved to {csv_pre_path}, {csv_landmark_path}, and {csv_combined_path}")

    return composite

In [None]:

# Run configuration. Replace the placeholder scan IDs with real hoa_tools dataset names.
OVERVIEW_NAME = "OVERVIEW SCAN ID"
ZOOM_NAME = "ZOOM SCAN ID"
# Zoom VOI size is the dataset shape / 2**DOWN_LEVEL, so 0 means full resolution.
# down_level has no default in the function, so it must be passed explicitly.
DOWN_LEVEL = 2  # TODO: choose per dataset

transform = register_zoom_to_overview_orb_matrix_out(
    down_level=DOWN_LEVEL,
    overview_name=OVERVIEW_NAME,
    zoom_name=ZOOM_NAME,
    verbose=True,
    apply_transform=False,
    show_vedo=False,
    save_picture=False,
)

