In [None]:
import zarr
import re
import numpy as np
import xml.etree.ElementTree as ET
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift
from scipy.spatial import KDTree
from collections import defaultdict, deque


In [None]:

# ----- File Paths -----
# Dataset stem = <base_path><xml_folder><file_name>; each file type adds its own extension.
base_path = "/mnt/gdrive/ThNe/development_retina/raw_images_n5/"
xml_folder = "P15/F/L/"
file_name = "retina_Age_P15_Sex_F_Side_L_Animal_1"
file_path = "".join([base_path, xml_folder, file_name])

# The N5 image container and the matching XML file share that stem.
n5_path, xml_path = (f"{file_path}{ext}" for ext in (".n5", ".xml"))

In [None]:
# ----- Open N5 dataset with zarr (read-only) -----
z = zarr.open(n5_path, mode='r')


def _setup_number(key):
    """Numeric id of a 'setupN' group key (the digits after the 'setup' prefix)."""
    return int(key[5:])


# ----- Keep only setups of the channel of interest (setup id % 4 == 2), in numeric order -----
selected_setups = [
    key
    for key in z.group_keys()
    if re.match(r'setup\d+$', key) and _setup_number(key) % 4 == 2
]
selected_setups.sort(key=_setup_number)
print("Selected setups:", selected_setups)

In [None]:
# # ----- Precompute Z-max Projections (store in a dictionary keyed by setup id) -----
# zmax_tiles = {}
# for setup in selected_setups:
#     sid = int(setup[5:])  # extract numeric setup id
#     print("Processing", setup)
#     tile_data = z[f'{setup}/timepoint0/s0'][:]  # load tile (3D array: Z, Y, X)
#     z_max_projection = np.max(tile_data, axis=0)  # Z-max projection (2D image)
#     zmax_tiles[sid] = z_max_projection

# print("Number of Z-max projected tiles:", len(zmax_tiles))

In [None]:
def select_z_slices(tile, start=0, end=100, step=2):
    """Return the z-planes ``start, start+step, ... (<= end)`` of a (Z, Y, X) stack.

    Both ``start`` and ``end`` are inclusive. Integer-array (fancy) indexing is
    used deliberately: for NumPy input it returns a *copy*, so the full stack the
    planes were taken from is not kept alive by the result.

    Parameters
    ----------
    tile : array-like
        Stack indexed as ``tile[z, y, x]``; only ``tile.shape[0]`` is inspected.
    start, end : int
        First and last z index to consider (inclusive). Must satisfy
        ``0 <= start <= end < tile.shape[0]``.
    step : int
        Positive stride between selected planes.

    Returns
    -------
    array of shape ``(len(range(start, end + 1, step)), Y, X)``

    Raises
    ------
    ValueError
        If ``step`` is not positive or ``start > end``.
    IndexError
        If ``[start, end]`` falls outside ``[0, tile.shape[0])``.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if start > end:
        raise ValueError(f"start ({start}) must not exceed end ({end})")
    depth = tile.shape[0]
    # Reject negative indices explicitly: NumPy would silently wrap them to the end of the stack.
    if start < 0 or end >= depth:
        raise IndexError(f"z-range [{start}, {end}] is outside the stack depth {depth}")
    z_indices = np.arange(start, end + 1, step)  # inclusive of `end`
    return tile[z_indices, :, :]

In [None]:
# ----- Load z-sub-stacks of the selected setups (dictionary keyed by numeric setup id) -----
# Planes Z_START..Z_END (inclusive, every Z_STEP-th) of each tile are used for registration.
Z_START, Z_END, Z_STEP = 5, 35, 1

zsub_tiles = {}
for setup in selected_setups:
    sid = int(setup[5:])  # extract numeric setup id
    print("Processing", setup)
    tile_arr = z[f'{setup}/timepoint0/s0']  # zarr array (Z, Y, X); data is read only when sliced
    # Read just the needed z-slab instead of the whole stack ([:]) to cut I/O.
    # The indices given to select_z_slices are therefore relative to the slab.
    tile_data = tile_arr[Z_START:Z_END + 1]
    zsub_tiles[sid] = select_z_slices(tile_data, start=0, end=Z_END - Z_START, step=Z_STEP)


In [None]:
import matplotlib.pyplot as plt
from skimage import exposure

# ----- Visual check: one CLAHE-enhanced plane of one selected tile -----
preview_setup = 16 * 4 + 2  # channel-of-interest setup of tile 16 (4 setups per tile, id % 4 == 2)
preview_slice = 6           # index into the z-sub-stack, not into the raw stack

tile = zsub_tiles[preview_setup][preview_slice]
if hasattr(tile, 'compute'):  # materialize lazy arrays that expose .compute() (e.g. dask)
    tile = tile.compute()
tile_clahe = exposure.equalize_adapthist(tile, clip_limit=0.1)

# Show the tile. The title previously interpolated `z` (the zarr group), not the slice index.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(tile_clahe, cmap='gray')
ax.set_title(f"setup{preview_setup}, sub-stack z-slice {preview_slice}")
ax.axis('off')
plt.show()

In [None]:
# ----- Parse XML to Extract Tile Positions -----
tree = ET.parse(xml_path)
root = tree.getroot()

# positions[setup_id] = flattened affine (expected: 12 floats = row-major 3x4 matrix)
# for every ViewRegistration of the channel of interest (setup id % 4 == 2).
# NOTE(review): `.find('.//affine')` returns only the FIRST <affine> below each
# ViewRegistration. If a registration lists several transforms, the others are
# ignored here; confirm that the first one carries the tile position.
positions = {}
for view_reg in root.findall('.//ViewRegistration'):
    reg_setup = int(view_reg.attrib['setup'])
    if reg_setup % 4 != 2:
        continue  # other channel
    affine_text = view_reg.find('.//affine').text
    positions[reg_setup] = [float(token) for token in affine_text.strip().split()]

In [None]:
# ----- Build Neighbor Pairs using KDTree (based on translation vectors) -----
# Translation column (tx, ty, tz) of each tile's 3x4 affine, in sorted setup-id order.
# NOTE(review): these translations are in the XML's world units. The pixel-based
# threshold below is only meaningful if that is 1 unit per pixel — confirm.
setup_ids = sorted(positions.keys())
if not setup_ids:
    raise ValueError("No ViewRegistration entries found for the channel of interest.")
coords = np.array([np.array(positions[sid]).reshape(3, 4)[:, 3] for sid in setup_ids])
print("Tile Coordinates:\n", coords)

# --- Tile Specifications ---
tile_size_in_pixels = 2048  # tile edge length in pixels
overlap_percent = 0.05      # fractional overlap between adjacent tiles (0.05 = 5%)

# Expected center-to-center distance of adjacent tiles:
# tile_size * (1 - overlap) = 2048 * 0.95 = 1945.6 px.
effective_distance_pixels = tile_size_in_pixels * (1 - overlap_percent)
print("Effective center-to-center distance (pixels):", effective_distance_pixels)

# Slack (in pixels) for small deviations from the nominal grid.
margin = 50

# Neighbor search radius (here 1945.6 + 50 = 1995.6 px). On a regular grid,
# diagonal tiles (~sqrt(2) * 1945.6 ≈ 2751 px apart) fall outside it and are not paired.
threshold = effective_distance_pixels + margin
print("Threshold for neighbor search (pixels):", threshold)

# query_pairs yields each index pair (i, j), i < j, at distance <= threshold exactly once.
# Sorting fixes the pair order (and thus the adjacency order used by the BFS later).
tree_kdt = KDTree(coords)
neighbor_pairs = [
    (setup_ids[i], setup_ids[j]) for i, j in sorted(tree_kdt.query_pairs(r=threshold))
]
print("Neighbor pairs:", neighbor_pairs)

In [None]:
# Number of neighbor pairs found (shown as this cell's output).
n_neighbor_pairs = len(neighbor_pairs)
n_neighbor_pairs

In [None]:
# ----- Compute Relative Shifts for Each Neighbor Pair using Phase Correlation -----
# relative_shifts[(a, b)]: the [z, y, x] shift (pixels) reported by phase_cross_correlation
# with tile a as reference and tile b as moving image. The inputs are 3D sub-stacks, so the
# shift has three components, not [y, x]. The reverse key (b, a) stores the negated shift.
# pair_errors[(a, b)]: the error value returned alongside, kept for later quality filtering.
# NOTE(review): whole sub-stacks are correlated, not just their overlap strips. Phase
# correlation is circular, so an offset near a full tile width can wrap to a small +/- value.
# Sanity-check these shifts against the expected tile offsets.
upsample_factor = 1  # 1 = integer-pixel precision
relative_shifts = {}
pair_errors = {}
for a, b in neighbor_pairs:
    shift_zyx, error, _ = phase_cross_correlation(
        zsub_tiles[a], zsub_tiles[b], upsample_factor=upsample_factor
    )
    print(f"Shift from setup{a} to setup{b}: {shift_zyx}, error: {error:.2f}")
    relative_shifts[(a, b)] = shift_zyx
    relative_shifts[(b, a)] = -shift_zyx  # inverse shift
    pair_errors[(a, b)] = error

In [None]:
# ----- Build the Neighbor Graph -----
# Undirected adjacency lists: each neighbor pair is recorded in both directions.
graph = defaultdict(list)
for pair in neighbor_pairs:
    for src, dst in (pair, pair[::-1]):
        graph[src].append(dst)

In [None]:
# ----- Propagate Global 3D Shifts via BFS -----
# Breadth-first walk of the neighbor graph from a reference tile, accumulating pairwise
# shifts along the BFS tree: global_shifts[t] = global_shifts[parent] + rel(parent, t).
# NOTE(review): errors accumulate along each path and redundant loops in the graph are
# ignored. A least-squares solve over all pairs would be more robust.
global_shifts = {}  # key: tile (setup) id, value: global shift vector [z, y, x]

# Use the first tile in our sorted list as the reference (global shift [0, 0, 0])
ref_tile = setup_ids[0]
global_shifts[ref_tile] = np.array([0.0, 0.0, 0.0])
visited = {ref_tile}
queue = deque([ref_tile])

while queue:
    current = queue.popleft()
    for neighbor in graph[current]:
        if neighbor not in visited:
            # Relative shift from current to neighbor (z, y, x)
            rel = relative_shifts.get((current, neighbor), np.array([0.0, 0.0, 0.0]))
            global_shifts[neighbor] = global_shifts[current] + rel
            visited.add(neighbor)
            queue.append(neighbor)

# Tiles in a component not connected to the reference get no entry in global_shifts.
# Report them instead of dropping them silently.
unreached = [sid for sid in setup_ids if sid not in global_shifts]
if unreached:
    print(f"WARNING: {len(unreached)} tile(s) not connected to reference tile {ref_tile}: {unreached}")

print("Global 3D Shifts:")
for tile in sorted(global_shifts.keys()):
    print(f"Tile {tile}: Shift {global_shifts[tile]}")


In [None]:
import xml.etree.ElementTree as ET
import numpy as np

# ----- Paths -----
# Input is the dataset's original XML (`xml_path`). Corrected registrations go to a
# sibling "_new.xml" file, so the original is never overwritten.
xml_path_input = xml_path
xml_path_output = file_path + '_new.xml'

# ----- Load Original XML -----
tree = ET.parse(xml_path_input)
root = tree.getroot()

# ----- Update Affine Translations Based on 3D Global Shifts -----
# global_shifts is keyed by the SETUP ids of the registration channel (id % 4 == 2),
# not by tile index. Every setup of a tile (4 setups per tile) must receive that tile's
# shift, so map each setup to its tile's registration-channel setup: (id // 4) * 4 + 2.
SETUPS_PER_TILE = 4
REGISTRATION_CHANNEL = 2
n_updated = 0
for reg in root.findall('.//ViewRegistration'):
    setup_id = int(reg.attrib['setup'])
    ref_setup = (setup_id // SETUPS_PER_TILE) * SETUPS_PER_TILE + REGISTRATION_CHANNEL
    if ref_setup not in global_shifts:
        continue  # tile was not reached by the shift propagation; keep its affine as-is
    # Named tile_shift (not `shift`) so the imported scipy.ndimage.shift is not shadowed.
    tile_shift = global_shifts[ref_setup]  # [z, y, x]

    affine_elem = reg.find('.//affine')
    affine = [float(v) for v in affine_elem.text.strip().split()]

    # Row-major 3x4 affine: translations tx, ty, tz sit at flat indices 3, 7, 11.
    # NOTE(review): tile_shift is in sub-stack pixels, and only the first <affine> of the
    # registration is edited. This assumes 1 world unit per pixel on each axis — confirm.
    affine[3] += tile_shift[2]   # X
    affine[7] += tile_shift[1]   # Y
    affine[11] += tile_shift[0]  # Z

    affine_elem.text = ' '.join(map(str, affine))
    n_updated += 1

# ----- Save Updated XML -----
# Write an explicit UTF-8 declaration. ElementTree's default (us-ascii) omits it.
tree.write(xml_path_output, encoding="UTF-8", xml_declaration=True)
print(f"Updated {n_updated} ViewRegistration affine(s).")
print(f"Affine matrices for all channels updated with 3D shifts and saved to {xml_path_output}")
