In [None]:
# -*- coding: utf-8 -*-
"""
Landmark-based co-registration of multi-scale HiP-CT data.
Using ORB features, estimate a 4x4 spatial transformation between the overview and zoom datasets,
and export all transformation matrices (initial, landmark, and final combined) to sheets of an Excel workbook.
Transformation coefficients can be entered directly into Neuroglancer.

Not for clinical use.
SPDX-FileCopyrightText: 2025 University College London, UK
SPDX-FileCopyrightText: 2025 Thierry L. Lefebvre
SPDX-License-Identifier: MIT
"""

def register_zoom_to_overview_orb_matrix_out(
    down_level: int,
    overview_name: str,
    zoom_name: str,
    num_slices: int = 5,
    search_range: int = 10,
    verbose: bool = True,
):
    """Register a zoom HiP-CT dataset to its overview using ORB landmarks.

    The whole zoom dataset is read as a VOI at ``down_level``. The overview
    data is read on that same VOI grid (``get_data_array_on_voi``). ORB
    features are matched between 2D slices of both volumes along each of the
    three axes. The matches are lifted to 3D voxel coordinates, filtered for
    outliers and converted to physical points. An affine correction is then
    fitted with ``sitk.LandmarkBasedTransformInitializer``.

    Parameters
    ----------
    down_level : int
        Downsampling level at which the volumes are read. The VOI size is the
        full dataset shape divided by ``2**down_level``.
    overview_name, zoom_name : str
        Dataset names passed to ``hoa_tools.dataset.get_dataset``.
    num_slices : int
        Number of evenly spaced overview slices sampled per axis.
    search_range : int
        Half-width, in slices, of the window of zoom slices searched for the
        best match of each overview slice.
    verbose : bool
        Print timestamped progress messages.

    Returns
    -------
    sitk.CompositeTransform
        The inventory ``pre_transform`` composed with the landmark affine.
        SimpleITK applies the last added transform first, so this is
        ``pre_transform(landmark(p))``.

    Side effects
    ------------
    Writes ``{overview_name}_to_{zoom_name}_transforms_physical_um.xlsx`` to
    the working directory. It has the sheets ``Pre_Transform``,
    ``Landmark_Only`` and ``Combined_Transform``, each a 4x4 homogeneous
    matrix. ``Combined_Transform`` corresponds to the returned composite.

    Raises
    ------
    RuntimeError
        If no feature match is found or survives outlier filtering.
    TypeError
        If a transform is not an affine or similarity transform.
    """
    import time
    import numpy as np
    import pandas as pd
    import cv2 as cv
    import SimpleITK as sitk
    import hoa_tools.dataset
    import hoa_tools.voi
    import hoa_tools.registration

    start_time = time.time()

    # Data-array dimension sliced for each axis index (``isel(z=...)`` for
    # axis 0, etc.). A lifted 3D point stores its slice index in column
    # ``2 - axis``.
    axis_dims = ("z", "y", "x")

    def log(msg):
        if verbose:
            print(f"[{(time.time() - start_time)/60:.2f}min] {msg}")

    def normalize_slice_to_uint8(img_2d, clip_z=3.0):
        """Map a 2D slice to uint8 via percentile clipping and a clipped z-score.

        A slice whose clipped standard deviation is zero (constant, e.g.
        padding) or NaN is returned as all zeros. This avoids a 0/0 division.
        """
        p_low, p_high = np.percentile(img_2d, [0.05, 99.995])
        img_clipped = np.clip(img_2d, p_low, p_high)
        std = np.std(img_clipped)
        if not std > 0:
            return np.zeros(np.shape(img_2d), dtype=np.uint8)
        zscore = (img_clipped - np.mean(img_clipped)) / std
        zscore = np.clip(zscore, -clip_z, clip_z)
        return ((zscore + clip_z) / (2 * clip_z) * 255).astype(np.uint8)

    def match_slices_along_axis(img_zoom, img_parent, axis, slice_indices, search_range=10):
        """Match ORB features between parent slices and nearby zoom slices.

        For each parent slice index along ``axis``, every zoom slice within
        ``±search_range`` (clamped to the volume) is tried. The zoom slice
        with the most Lowe-ratio (0.7) matches is kept. Matched 2D keypoints
        ``(col, row)`` are lifted to 3D by inserting the slice index at
        column ``2 - axis``.

        Returns two ``(N, 3)`` float arrays: zoom points and parent points.
        Both are ``(0, 3)`` when nothing matched, so results can always be
        stacked.
        """
        dim = axis_dims[axis]
        orb = cv.ORB_create(nfeatures=2000)
        matcher = cv.BFMatcher(cv.NORM_HAMMING, crossCheck=False)
        n_zoom_slices = img_zoom.shape[axis]
        zoom_chunks, parent_chunks = [], []

        for slice_idx in slice_indices:
            img2 = normalize_slice_to_uint8(img_parent.isel({dim: slice_idx}).values)
            kp2, des2 = orb.detectAndCompute(img2, None)
            if des2 is None or len(kp2) < 2:
                continue

            best_matches, best_kp1, best_zoom_slice_idx = [], None, None
            lo = max(slice_idx - search_range, 0)
            hi = min(slice_idx + search_range, n_zoom_slices - 1)
            for zoom_slice_idx in range(lo, hi + 1):
                img1 = normalize_slice_to_uint8(img_zoom.isel({dim: zoom_slice_idx}).values)
                kp1, des1 = orb.detectAndCompute(img1, None)
                if des1 is None or len(kp1) < 2:
                    continue
                # Lowe's ratio test on the two nearest neighbours.
                good = [
                    pair[0]
                    for pair in matcher.knnMatch(des1, des2, k=2)
                    if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
                ]
                if len(good) > len(best_matches):
                    best_matches, best_kp1, best_zoom_slice_idx = good, kp1, zoom_slice_idx

            if not best_matches:
                continue
            zoom_xy = np.array([best_kp1[m.queryIdx].pt for m in best_matches], dtype=np.float64)
            parent_xy = np.array([kp2[m.trainIdx].pt for m in best_matches], dtype=np.float64)
            zoom_chunks.append(np.insert(zoom_xy, 2 - axis, best_zoom_slice_idx, axis=1))
            parent_chunks.append(np.insert(parent_xy, 2 - axis, slice_idx, axis=1))

        if not zoom_chunks:
            return np.empty((0, 3)), np.empty((0, 3))
        return np.vstack(zoom_chunks), np.vstack(parent_chunks)

    def convert_voxel_to_physical(points_3d, sitk_image):
        """Convert 3D voxel points to physical points on ``sitk_image``'s grid.

        Columns are passed to SimpleITK as ``(pt[2], pt[1], pt[0])``. The
        continuous index is used, so sub-voxel keypoint positions are kept
        rather than rounded.
        """
        # NOTE(review): SimpleITK indices are ordered (i, j, k) ~ (x, y, z);
        # with the lifting above, the z-slice index lands in i. Confirm this
        # matches the axis order of the hoa_tools data arrays / sitk images.
        return [
            sitk_image.TransformContinuousIndexToPhysicalPoint(
                [float(pt[2]), float(pt[1]), float(pt[0])]
            )
            for pt in points_3d
        ]

    def sitk_transform_to_matrix_4x4(tfm: sitk.Transform) -> np.ndarray:
        """Return the 4x4 homogeneous matrix of an affine/similarity transform.

        ITK evaluates ``p -> A (p - c) + c + t``. The homogeneous translation
        column is therefore the offset ``t + c - A c``, which equals ``t``
        only when the center ``c`` is the origin.
        """
        if not isinstance(tfm, (sitk.AffineTransform, sitk.Similarity3DTransform)):
            raise TypeError(f"Unsupported transform type: {type(tfm)}")
        matrix = np.array(tfm.GetMatrix()).reshape(3, 3)
        center = np.array(tfm.GetCenter())
        offset = np.array(tfm.GetTranslation()) + center - matrix @ center
        mat4x4 = np.eye(4)
        mat4x4[:3, :3] = matrix
        mat4x4[:3, 3] = offset
        return mat4x4

    # Step 1: load the zoom VOI and the overview data read on the same grid.
    log("Loading datasets")
    overview_dataset = hoa_tools.dataset.get_dataset(overview_name)
    zoom_dataset = hoa_tools.dataset.get_dataset(zoom_name)

    zoom_voi = hoa_tools.voi.VOI(
        dataset=zoom_dataset,
        downsample_level=down_level,
        lower_corner={"x": 0, "y": 0, "z": 0},
        size={k: round(v / (2**down_level)) for k, v in zip(["x", "y", "z"], zoom_dataset.data.shape)},
    )
    overview_voi = zoom_voi.transform_to(overview_dataset)
    img_zoom = zoom_voi.get_data_array()
    img_parent = overview_voi.get_data_array_on_voi(zoom_voi, interpolator=sitk.sitkNearestNeighbor)

    # Step 2: match features on evenly spaced slices along each axis.
    log("Finding feature matches")
    zoom_matches_all, parent_matches_all = [], []
    for axis in range(3):
        n_axis = img_zoom.shape[axis]
        # Clip so indices stay inside the volume (e.g. search_range=0 would
        # otherwise put the last index at n_axis).
        slice_indices = np.clip(
            np.linspace(search_range, n_axis - search_range, num_slices, dtype=int),
            0,
            n_axis - 1,
        )
        z, p = match_slices_along_axis(img_zoom, img_parent, axis, slice_indices, search_range)
        zoom_matches_all.append(z)
        parent_matches_all.append(p)

    zoom_matches = np.vstack(zoom_matches_all)
    parent_matches = np.vstack(parent_matches_all)
    if len(zoom_matches) == 0:
        raise RuntimeError("No ORB feature matches found between zoom and overview slices")

    # Discard matches whose voxel displacement is well above the average one.
    distances = np.linalg.norm(zoom_matches - parent_matches, axis=1)
    mask = distances < (np.mean(distances) + 0.2 * np.std(distances))
    zoom_matches = zoom_matches[mask]
    parent_matches = parent_matches[mask]
    log(f"{len(zoom_matches)} matches retained")
    if len(zoom_matches) == 0:
        raise RuntimeError("All ORB feature matches were rejected by the outlier filter")

    # Step 3: fit the landmark affine in physical space.
    log("Computing transform in µm")
    zoom_img_sitk = zoom_voi.get_sitk_image()

    pre_transform = hoa_tools.registration.Inventory.get_registration(
        source_dataset=zoom_voi.dataset,
        target_dataset=overview_voi.dataset,
    )

    # img_parent was read on the zoom VOI grid, so both landmark sets use the
    # zoom image geometry.
    fixed = convert_voxel_to_physical(zoom_matches, zoom_img_sitk)
    moving = convert_voxel_to_physical(parent_matches, zoom_img_sitk)

    transform = sitk.LandmarkBasedTransformInitializer(
        sitk.AffineTransform(3),
        fixedLandmarks=np.ravel(fixed),
        movingLandmarks=np.ravel(moving),
    )

    # SimpleITK applies the last added transform first:
    # composite(p) = pre_transform(transform(p)).
    composite = sitk.CompositeTransform(3)
    composite.AddTransform(pre_transform)
    composite.AddTransform(transform)

    # === Save transform matrices to Excel ===
    log("Saving transform matrices to Excel")
    mat_pre = sitk_transform_to_matrix_4x4(pre_transform)
    mat_landmark = sitk_transform_to_matrix_4x4(transform)
    # Same order as `composite`: landmark affine first, then pre_transform.
    mat_combined = mat_pre @ mat_landmark

    excel_path = f"{overview_name}_to_{zoom_name}_transforms_physical_um.xlsx"
    with pd.ExcelWriter(excel_path) as writer:
        pd.DataFrame(mat_pre).to_excel(writer, sheet_name="Pre_Transform", index=False)
        pd.DataFrame(mat_landmark).to_excel(writer, sheet_name="Landmark_Only", index=False)
        pd.DataFrame(mat_combined).to_excel(writer, sheet_name="Combined_Transform", index=False)

    log(f"Transform matrices saved to {excel_path}")

    return composite


In [None]:
# Only run for private data
import pathlib
import hoa_tools.dataset
# Point hoa_tools at the private metadata inventory (machine-specific path).
private_metadata_dir = pathlib.Path('/hdd2/thierry/private-hoa-metadata/metadata/')
hoa_tools.dataset.change_metadata_directory(private_metadata_dir)

# Dataset pair to co-register; edit these names for another pair.
registration_kwargs = dict(
    down_level=2,
    overview_name="SH1_skull_complete-organ_16.545um_bm18",
    zoom_name="SH1_skull_VOI-09_2.20um_bm18",
)
transform = register_zoom_to_overview_orb_matrix_out(**registration_kwargs)




[0.00min] Loading datasets
[0.52min] Normalizing
[0.88min] Finding feature matches
[1.66min] 1794 matches retained
[1.66min] Computing transform in µm
[2.18min] Saving transform matrices to Excel
[2.18min] Transform matrices saved to SH1_skull_complete-organ_16.545um_bm18_to_SH1_skull_VOI-09_2.20um_bm18_transforms_physical_um.xlsx
