In [5]:
# This notebook reads the EM centroids HDF5 file, queries the agglomeration ID (agglo_id) at each centroid and writes the results to a lookup-table (LUT) HDF5 file.

# Standard library imports
import os
import time
from tqdm import tqdm

# Third-party imports
import numpy as np
import h5py
import tifffile
import napari
import matplotlib.pyplot as plt
from scipy.interpolate import RBFInterpolator
import dill
from google.cloud import bigquery
from google.auth import exceptions

# Custom imports
from brainmaps_api_fcn.equivalence_requests import EquivalenceRequests
from brainmaps_api_fcn.subvolume_requests import SubvolumeRequest
from scripts.sample_db import SampleDB
from scripts.utils.image_utils import load_tiff_as_hyperstack

In [6]:
def get_agglo_group_with_retry(sa, volume_id, stack_change, centroid_xyz, max_retries=5,
                               backoff_s=1.0):
    """Get the agglomeration group at a point, retrying on credential-refresh failures.

    Reads a 1x1x1 subvolume at ``centroid_xyz`` (with change stack ``stack_change``
    applied), takes the smallest non-zero label found as the agglomeration ID and
    asks the equivalence service for its group.

    Args:
        sa: Credentials passed to the Brainmaps request helpers (in this notebook,
            a path to a service-account JSON file).
        volume_id: Brainmaps volume identifier.
        stack_change: Change-stack ID applied to both requests.
        centroid_xyz: Voxel coordinate passed to ``get_subvolume``.
        max_retries: Maximum number of attempts; must be >= 1.
        backoff_s: Base delay in seconds. After failed attempt ``k`` (0-based) the
            function waits ``backoff_s * 2**k`` seconds before retrying.

    Returns:
        Whatever ``EquivalenceRequests.get_groups`` returns for the agglomeration ID.

    Raises:
        ValueError: If ``max_retries`` < 1 or no non-zero label is found at the point.
        google.auth.exceptions.RefreshError: If every attempt fails to refresh credentials.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    for attempt in range(max_retries):
        try:
            sr = SubvolumeRequest(sa, volume_id)
            vol = sr.get_subvolume(centroid_xyz, size=[1, 1, 1], change_stack_id=stack_change)
            labels = np.unique(vol[vol > 0])
            if labels.size == 0:
                # Background voxel: report it clearly instead of an IndexError on labels[0].
                raise ValueError(f"No agglomeration ID found at {centroid_xyz}")
            agglo_id = int(labels[0])

            er = EquivalenceRequests(sa, volume_id, stack_change)
            return er.get_groups(agglo_id)
        except exceptions.RefreshError:
            if attempt == max_retries - 1:
                raise
            print(f"Try {attempt+1}/{max_retries}")
            # Exponential backoff; never retry immediately (sleep(0) on the first retry).
            time.sleep(backoff_s * 2 ** attempt)


def get_neuron_segments(lut_path, neuron_id):
    """Return all segment IDs stored for one neuron in the LUT file.

    Two storage locations are checked, in this order:
      1. the member names of the ``neurons/neuron_<id>/segments`` group;
      2. the ``neurons/neuron_<id>/agglo_segments`` dataset. This is where the
         LUT-building cell of this notebook writes the IDs; it creates the
         ``segments`` group but leaves it empty.

    Args:
        lut_path: Path to the LUT HDF5 file.
        neuron_id: Index used in the ``neuron_<id>`` group name.

    Returns:
        List of ``int`` segment IDs; empty if the neuron or its segments are missing.
    """
    with h5py.File(lut_path, 'r') as f:
        neuron_path = f'neurons/neuron_{neuron_id}'
        if neuron_path not in f:
            return []
        neuron_group = f[neuron_path]

        # Segment IDs stored as member names (strings) of the 'segments' group.
        if 'segments' in neuron_group and len(neuron_group['segments']) > 0:
            return [int(seg) for seg in neuron_group['segments'].keys()]

        # Segment IDs stored as a dataset.
        if 'agglo_segments' in neuron_group:
            return [int(seg) for seg in np.ravel(neuron_group['agglo_segments'][()])]

        return []


def get_segments_by_agglo_id(lut_path, agglo_id):
    """Return the segment IDs of the first neuron whose ``agglo_id`` attribute matches.

    Neurons are scanned in HDF5 iteration order. For a matching neuron, segment IDs
    come from the member names of its ``segments`` group when that group is
    non-empty. Otherwise they come from its ``agglo_segments`` dataset, which is
    where the LUT-building cell of this notebook stores them. Matching neurons with
    neither source are skipped.

    Args:
        lut_path: Path to the LUT HDF5 file.
        agglo_id: Agglomeration ID compared with ``==`` against each neuron's
            ``agglo_id`` attribute.

    Returns:
        List of ``int`` segment IDs; empty if no matching neuron has segments.
    """
    with h5py.File(lut_path, 'r') as f:
        for neuron_group in f['neurons'].values():
            attrs = neuron_group.attrs
            if 'agglo_id' not in attrs or not attrs['agglo_id'] == agglo_id:
                continue

            if 'segments' in neuron_group and len(neuron_group['segments']) > 0:
                return [int(seg) for seg in neuron_group['segments'].keys()]

            if 'agglo_segments' in neuron_group:
                return [int(seg) for seg in np.ravel(neuron_group['agglo_segments'][()])]
    return []


def find_neurons_in_mask(lut_path, mask_path, scale=16, offset_zyx=None, verbose=False):
    """Find neurons whose centroid lies on a non-zero voxel of a 3D mask.

    Each neuron's ``em_centroid_ng`` (stored in XYZ order) is floor-divided by
    ``scale``, reordered to ZYX, shifted by ``-offset_zyx`` and rounded. The result
    indexes the mask as ``mask[z, y, x]``.

    Args:
        lut_path: Path to the LUT HDF5 file.
        mask_path: Path to a TIFF mask volume indexed ``[z, y, x]``.
        scale: Factor between Neuroglancer voxels and mask voxels (default 16).
        offset_zyx: Optional ZYX offset subtracted after downscaling (e.g. the
            ``cropped_shift`` used by later cells). ``None`` means no offset.
        verbose: If True, print every downscaled centroid.

    Returns:
        List of dicts with keys ``neuron_id`` (int parsed from the group name),
        ``agglo_id`` (attribute value or ``None``) and ``centroid`` (downscaled XYZ).
    """
    # Read the TIFF mask.
    mask = tifffile.imread(mask_path, mode='r')
    offset = np.zeros(3) if offset_zyx is None else np.asarray(offset_zyx)
    neurons_inside = []

    with h5py.File(lut_path, 'r') as f:
        for neuron_name, neuron_group in f['neurons'].items():
            # Neuroglancer-space centroid (XYZ), downscaled to the mask resolution.
            centroid = neuron_group['em_centroid_ng'][:] // scale
            if verbose:
                print(centroid)

            # Reorder to ZYX so the same axis order is used for the bounds check and
            # the lookup (the previous code checked ZYX but indexed mask[x, y, z]).
            z, y, x = np.round(centroid[::-1] - offset).astype(int)

            in_bounds = (0 <= z < mask.shape[0] and
                         0 <= y < mask.shape[1] and
                         0 <= x < mask.shape[2])
            if in_bounds and mask[z, y, x] > 0:
                neurons_inside.append({
                    'neuron_id': int(neuron_name.split('_')[1]),
                    'agglo_id': neuron_group.attrs.get('agglo_id', None),
                    'centroid': centroid
                })

    return neurons_inside


def get_neurons_with_attribute(lut_path, attribute_name, attribute_value, operator="=="):
    """Return the agglomeration IDs of neurons whose attribute satisfies a comparison.

    For every group under ``neurons/`` that has both ``attribute_name`` and
    ``agglo_id`` attributes, evaluates ``stored_value <operator> attribute_value``.
    The neuron's ``agglo_id`` is collected when the comparison is true.

    Args:
        lut_path: Path to the HDF5 file.
        attribute_name: Name of the attribute to check.
        attribute_value: Value the stored attribute is compared against.
        operator: Comparison string: '>', '<', '>=', '<=', '==' or '!='.

    Returns:
        List of ``agglo_id`` attribute values (not neuron names), in group
        iteration order.

    Raises:
        ValueError: If ``operator`` is not one of the supported strings.
    """
    comparisons = {
        '>': lambda a, b: a > b,
        '<': lambda a, b: a < b,
        '>=': lambda a, b: a >= b,
        '<=': lambda a, b: a <= b,
        '==': lambda a, b: a == b,
        '!=': lambda a, b: a != b,
    }
    if operator not in comparisons:
        raise ValueError(f"Operator must be one of {list(comparisons.keys())}")
    compare = comparisons[operator]

    agglo_ids = []
    with h5py.File(lut_path, 'r') as f:
        for neuron_group in f['neurons'].values():
            attrs = neuron_group.attrs
            if (attribute_name in attrs and "agglo_id" in attrs
                    and compare(attrs[attribute_name], attribute_value)):
                agglo_ids.append(attrs["agglo_id"])

    return agglo_ids

In [7]:
# BigQuery client for the "aggloproofreading" project.
# NOTE(review): `client` is not used by any cell visible here — confirm it is still needed.
client = bigquery.Client(project="aggloproofreading")

# Step 1: Load the sample database
db_path = r'\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\sample_db.csv'
sample_db = SampleDB()
sample_db.load(db_path)
# Step 2: Load experiment configuration
sample_id = '20220426_RM0008_130hpf_fP1_f3'
exp = sample_db.get_sample(sample_id)

# TODO: add paths to sample
# Input: volume data
# `sa`: path to a service-account JSON file, passed as credentials to the Brainmaps
# request helpers. NOTE(review): keep the key file itself out of version control.
sa = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\Reconstruction\fmi-friedrich-ebb67584e50d.json"
# Brainmaps volume ID (presumably "<project>:<dataset>:<volume>" — confirm).
volume_id = r"280984173682:montano_rm2_ngff:raw_230701_seg_240316fb"
# Change-stack ID passed to every subvolume and equivalence query.
stack_change = "240705d_rsg9_spl"

In [8]:
# HDF5 file with the EM centroids and the BigWarp -> Neuroglancer transform parameters.
em_centroids_path = os.path.join(exp.paths.em_path, "{}_em_centroids.h5".format(exp.sample.id))


In [9]:
# Read centroids from HDF5 file and do transformation to neuroglancer
# Read the EM centroids and the BigWarp -> Neuroglancer transform parameters,
# then map the centroids into Neuroglancer coordinates.
bw2ng_key = 'metadata/transformations/bigwarp2neuroglancer'
with h5py.File(em_centroids_path, 'r') as f:
    # Column 0 is dropped; columns 1: are the coordinates used below
    # (presumably column 0 is a centroid ID — confirm).
    data = f['centroids/coordinates'][:]
    centroids = data[:, 1:]

    bw2ng = f[bw2ng_key]
    transformation_matrix = bw2ng['transformation_matrix'][:]
    rotated_cropped_stack_center_shift = bw2ng['rotated_cropped_stack_center_shift'][:]
    rotated_stack_center_shift = bw2ng['rotated_stack_center_shift'][:]
    cropped_shift = bw2ng['cropped_shift'][:]
    ng_shift = bw2ng['ng_shift'][:]
    # Truth-testing an h5py Dataset only reports whether it is open, never its contents.
    # So test membership to make the factor genuinely optional (1 = no downsampling).
    downsampled_factor = bw2ng['downsampled_factor'][:] if 'downsampled_factor' in bw2ng else 1
    zyx2xyz = bw2ng.attrs['zyx2xyz']

# 1. Center points around origin
centered_points = centroids - rotated_cropped_stack_center_shift

# 2. Apply transformation matrix (points are row vectors, hence the transpose)
transformed_centered = np.dot(centered_points, transformation_matrix.T)

# 3. Move to target space, apply shifts and undo the downsampling
transformed_points = (transformed_centered +
                      rotated_stack_center_shift +
                      cropped_shift)
transformed_points = transformed_points * downsampled_factor

# 4. Convert from ZYX to XYZ if needed
if zyx2xyz:
    transformed_points = transformed_points[:, ::-1]

# 5. Correct for shift in neuroglancer
transformed_points = transformed_points + ng_shift


In [10]:
# Output lookup table: one group per EM centroid with its agglomeration data.
lut_path = os.path.join(exp.paths.clem_path, "{}_lut.h5".format(exp.sample.id))

In [None]:
# Build the LUT. For every EM centroid, store its BigWarp and Neuroglancer
# coordinates, read a 5x5x5 segmentation subvolume at the Neuroglancer position,
# and record the agglomeration ID and its segment IDs. Per-neuron failures are
# collected in `error_neurons` (id, message) instead of stopping the loop.
error_neurons = []
# NOTE(review): mode 'w' truncates an existing LUT. Re-running this (hours-long) cell
# discards all previous results, including attributes added by later cells (in_OB, g/r_coeff, ...).
with h5py.File(lut_path, 'w') as hdf5_file:
    # Create metadata group with source information
    metadata_group = hdf5_file.create_group('metadata')
    metadata_group.attrs['source_stack'] = os.path.join(exp.paths.em_path, '20220426_RM0008_130hpf_fP1_f3_fine_aligned_downsampled_16_em_stack_cropped_woResin_rough_rotated_to_LM.tif')
    metadata_group.attrs['source_mask'] = os.path.join(exp.paths.em_path, '20220426_RM0008_130hpf_fP1_f3_fine_aligned_downsampled_16_em_stack_cropped_woResin_rough_rotated_to_LM_mask_filtered.tif')
    metadata_group.attrs['sample_id'] = exp.sample.id
    metadata_group.attrs['volume_id'] = volume_id
    metadata_group.attrs['stack_change'] = stack_change
    
    # Store transformation parameters.
    # NOTE(review): the group is named 'neuroglancer2bigwarp' but holds the parameters read
    # from 'bigwarp2neuroglancer' above; later cells read this name, so do not rename lightly.
    transformations_group = metadata_group.create_group('transformations')
    ng2bw_group = transformations_group.create_group('neuroglancer2bigwarp')
    ng2bw_group.create_dataset('transformation_matrix', data=transformation_matrix)
    ng2bw_group.create_dataset('rotated_cropped_stack_center_shift', data=rotated_cropped_stack_center_shift)
    ng2bw_group.create_dataset('rotated_stack_center_shift', data=rotated_stack_center_shift)
    ng2bw_group.create_dataset('cropped_shift', data=cropped_shift)
    ng2bw_group.create_dataset('ng_shift', data=ng_shift)
    ng2bw_group.attrs['zyx2xyz'] = zyx2xyz
    # Stored as an attribute (not a dataset), so readers that iterate the group's
    # datasets via `.items()` will not see it.
    ng2bw_group.attrs['downsampled_factor'] = downsampled_factor
    
    # Create neurons group
    neurons_group = hdf5_file.create_group('neurons')
    
    # Add data for each neuron; the group name neuron_<i> is the row index into `centroids`.
    for neuron_id in tqdm(range(len(centroids)), desc="Finding agglomeration ids and their segments"):
        try:
            neuron_group = neurons_group.create_group(f'neuron_{neuron_id}')
            neuron_group.create_dataset('em_centroid_bw', data=centroids[neuron_id])
            neuron_group.create_dataset('em_centroid_ng', data=transformed_points[neuron_id])
            
            ng_centroid = transformed_points[neuron_id]
            #print(f"Processing neuron {neuron_id} at position {ng_centroid}")
            
            # NOTE(review): this group is created but never filled here; segment IDs are
            # written to the 'agglo_segments' dataset below.
            segments_group = neuron_group.create_group('segments')
            
            # Get agglomeration ID and segments (new request objects for every neuron).
            sr = SubvolumeRequest(sa, volume_id)
            centroid_xyz = np.round(ng_centroid).astype(int)
            vol = sr.get_subvolume(centroid_xyz, size=[5,5,5], change_stack_id=stack_change)
            
            # Check if any positive values exist in the volume
            if np.any(vol > 0):
                # NOTE(review): np.unique sorts, so this is the smallest non-zero label in the
                # 5x5x5 window, not necessarily the label at the centroid voxel itself.
                agglo_id = int(np.unique(vol[vol>0])[0])
                er = EquivalenceRequests(sa, volume_id, stack_change)
                agglo_group = er.get_groups(agglo_id)
                
                # Presumably get_groups returns a single {agglo_id: segment_ids} entry. With more
                # than one entry, the second create_dataset('agglo_segments') would fail because
                # the name already exists, and the neuron would land in error_neurons.
                for agglo_id, segment_ids in agglo_group.items():
                    neuron_group.attrs['agglo_id'] = agglo_id
                    neuron_group.create_dataset('agglo_segments', data=segment_ids)
            else:
                error_neurons.append((neuron_id, "No agglomeration ID found at position"))
                #print(f"No agglomeration ID found for neuron {neuron_id}")
                
        except Exception as e:
            # Best-effort: record and continue. The neuron group may be left partially written.
            error_neurons.append((neuron_id, str(e)))
            print(f"Error processing neuron {neuron_id}: {e}")

print(f"Neurons with errors: {error_neurons}")


Finding agglomeration ids and their segments:   7%|▋         | 623/9462 [32:26<7:53:42,  3.22s/it] 

In [None]:
# Get segments for neuron 0
# Example lookups against the LUT (IDs named once so comments and output cannot drift)
example_neuron_id = 0
segments = get_neuron_segments(lut_path, example_neuron_id)
print(f"Segments for neuron {example_neuron_id}: {segments}")

example_agglo_id = 55233441
segments = get_segments_by_agglo_id(lut_path, example_agglo_id)
print(f"Segments for agglo_id {example_agglo_id}: {segments}")

# Finding neurons within the OB mask

In [None]:
# Locate the OB mask, find neurons whose centroids fall inside it, and show the
# mask both as an intensity image and as a label layer.
ob_mask_path = os.path.join(exp.paths.em_path, 'masks', 'ob_mask.tif')
neurons_inside = find_neurons_in_mask(lut_path, ob_mask_path)
ob_mask = tifffile.imread(ob_mask_path)

viewer = napari.Viewer()
for add_layer in (viewer.add_image, viewer.add_labels):
    add_layer(ob_mask)


In [None]:

# Print the Neuroglancer centroid of every neuron assigned to one agglomeration ID.
# The last match is kept in `centroid` for the following cells.
agglo_id = 61815715
centroid = None
with h5py.File(lut_path, 'r') as f:
    for neuron_group in f['neurons'].values():
        attrs = neuron_group.attrs
        if 'agglo_id' in attrs and attrs['agglo_id'] == agglo_id:
            centroid = neuron_group['em_centroid_ng'][:]
            print(centroid)

# Fail here with a clear message rather than with a NameError in a later cell.
if centroid is None:
    raise LookupError(f"No neuron with agglo_id {agglo_id} found in {lut_path}")
                
 

In [None]:
               
# Collect every dataset of the stored transformation group into a plain dict.
with h5py.File(lut_path, 'r') as f:
    ng2bw = f['metadata/transformations/neuroglancer2bigwarp']
    transformation = dict((name, dset[()]) for name, dset in ng2bw.items())

transformation

In [None]:


# Map the selected XYZ centroid to the downsampled mask space (ZYX, //16, minus crop shift)
# and mark it in the viewer.
x, y, z = centroid
print(centroid)
print(x, y, z)
centroid_zyx = np.array([z, y, x])
centroid_ds = centroid_zyx // 16 - transformation['cropped_shift']
print(centroid_ds)
viewer.add_points(centroid_ds, size=10)


In [None]:

# Load the fine-aligned EM stack (4x downsampled, per its file name).
# NOTE(review): `em_stack` is not used by any cell visible here.
em_stack_path = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\CLEM_Analyses\CLEM_20220426_RM0008_130hpf_fP1_f3\fine_aligned_downsampled_4_em_stack.tif"
em_stack = tifffile.imread(em_stack_path)

In [None]:
         
# Flag every neuron with `in_OB` from the OB mask value at its centroid.
# Centroids are mapped to mask space as in the cells above: XYZ -> ZYX, //16, minus cropped_shift.
with h5py.File(lut_path, 'r+') as f:
    transformation = {key: val[()] for key, val in f['metadata/transformations/neuroglancer2bigwarp'].items()}
    cropped_shift_zyx = transformation['cropped_shift']
    neurons = f['neurons']
    for neuron_name, neuron_group in tqdm(neurons.items(), total=len(neurons)):
        centroid = neuron_group['em_centroid_ng'][:]
        centroid_transformed = centroid[::-1] // 16 - cropped_shift_zyx
        index = tuple(centroid_transformed.astype(int))

        # Check bounds explicitly. A negative index would silently wrap around in NumPy,
        # and an index past the end would raise and abort the loop half-way.
        inside_bounds = all(0 <= i < n for i, n in zip(index, ob_mask.shape))
        mask_value = ob_mask[index] if inside_bounds else 0

        # Label 1 means OB (as before). Every other value, and centroids outside the mask,
        # are stored as False so that every neuron carries the attribute.
        neuron_group.attrs['in_OB'] = bool(mask_value == 1)
            
     

In [None]:
       # Read all neurons that are in OB 
# Split neurons that have an agglomeration ID into inside / outside the OB.
neurons_in_ob = []
centroids_in_ob = []
centroids_outside_ob = []
with h5py.File(lut_path, 'r') as f:
    for neuron_name, neuron_group in tqdm(f['neurons'].items()):
        attrs = neuron_group.attrs
        # Only neurons with a (non-zero) agglomeration ID are considered.
        if not attrs.get("agglo_id", False):
            continue
        # `.get` so a neuron without the flag counts as outside instead of raising KeyError.
        if attrs.get("in_OB", False):
            neurons_in_ob.append(attrs['agglo_id'])
            centroids_in_ob.append(neuron_group['em_centroid_ng'][:])
        else:
            centroids_outside_ob.append(neuron_group['em_centroid_ng'][:])

print(len(centroids_in_ob))
print(len(centroids_outside_ob))
if centroids_in_ob:  # np.stack raises on an empty list
    print(np.stack(centroids_in_ob[:10]))
centroids_in_ob_transformed = [centroid[::-1] for centroid in centroids_in_ob]
print(centroids_in_ob_transformed[:10])

In [None]:

# Show neurons inside (green) and outside (red) the OB in downsampled mask space.
shift_zyx = transformation['cropped_shift']
in_ob_zyx = np.array([c[::-1] for c in centroids_in_ob]) // 16 - shift_zyx
viewer.add_points(in_ob_zyx, size=10, face_color="green")
out_ob_zyx = np.array([c[::-1] for c in centroids_outside_ob]) // 16 - shift_zyx
viewer.add_points(out_ob_zyx, size=10, face_color="red")

In [None]:
# Adding LM stack centroids: first inspect what the CLEM directory contains.
clem_dir = exp.paths.clem_path
os.listdir(clem_dir)

In [None]:
# Load the previously fitted EM -> LM interpolator.
# SECURITY: dill.load can execute arbitrary code; only load files produced by this pipeline.
interpolator_path = os.path.join(exp.paths.clem_path, f'{exp.sample.id}_em2lm_interpolator.dill')
with open(interpolator_path, 'rb') as f:
    rbf_interpolator = dill.load(f)

with h5py.File(em_centroids_path, 'r') as f:
    # Same coordinate columns as `centroids` in the transformation cell.
    data = f['centroids/coordinates'][:]
    em_centroids = data[:, 1:]

# Map EM centroids into LM space and show both point clouds.
lm_centroids_proxy = rbf_interpolator(em_centroids)

viewer = napari.Viewer()
viewer.add_points(em_centroids, face_color="green")
viewer.add_points(lm_centroids_proxy, face_color="red")

In [None]:

# Load the processed (flipped, upsampled, CLAHE) LM anatomy stack.
lm_processed_dir = os.path.join(exp.paths.anatomy_path, 'processed')
os.listdir(lm_processed_dir)
lm_stack_file = 'flipped_upsampled_clahe_20220426_RM0008_130hpf_fP1_f3_anatomyGFRF_001_.tif'
lm_stack = load_tiff_as_hyperstack(os.path.join(lm_processed_dir, lm_stack_file))
lm_stack.shape

In [None]:
# Show index 0 of axis 1 (presumably the first channel — confirm) of the LM stack.
viewer.add_image(lm_stack[:, 0], name="lm_stack")

# Record where the EM -> LM interpolator lives. require_group plus replacing the dataset
# keeps the cell re-runnable (create_group/create_dataset raise if the name already exists).
interpolator_path = os.path.join(exp.paths.clem_path, f'{exp.sample.id}_em2lm_interpolator.dill')
with h5py.File(lut_path, 'r+') as f:
    em2lm_group = f['metadata']['transformations'].require_group('em2lmstack_bigwarp')
    # The dataset key keeps its original 'rfb' spelling so existing readers still find it.
    if 'rfb_interpolator_path' in em2lm_group:
        del em2lm_group['rfb_interpolator_path']
    em2lm_group.create_dataset('rfb_interpolator_path', data=interpolator_path)

In [None]:

# Map a few EM centroids to LM space as a visual sanity check.
test_point = em_centroids[10:20]
print(test_point)
test_lm_point = rbf_interpolator(test_point)

viewer.add_points(test_point, face_color="blue")
viewer.add_points(test_lm_point, face_color="cyan")

# Store each neuron's LM-space centroid.
# NOTE(review): `lut_minimal_path` is not defined in any visible cell. Define it in the
# configuration cell (or use `lut_path`) so the notebook runs from a fresh kernel.
with h5py.File(lut_minimal_path, 'r+') as f:
    for neuron, neuron_group in f['neurons'].items():
        # h5py iterates members by name ('neuron_0', 'neuron_1', 'neuron_10', ...), not
        # numerically, so the row index must come from the group name, not a loop counter.
        row = int(neuron.split('_')[1])
        if 'lm_centroid_bw' in neuron_group:
            del neuron_group['lm_centroid_bw']  # keep the cell re-runnable
        neuron_group.create_dataset('lm_centroid_bw', data=lm_centroids_proxy[row])

# LM bw to LM raw transformation

# Assign IN information

# Per-voxel Manders-coefficient masks for channel 0 and channel 1 (per the file names).
# Later cells store the c0 value as `g_coeff` and the c1 value as `r_coeff`.
mask_colored_stack_c0_path = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\2P_RawData\2022-04-26\f3\anatomy\masks\mask_manderscoeff_c0_20220426_RM0008_130hpf_fP1_f3_anatomyGFRF_001_.tif"
mask_colored_stack_c1_path = r"\\tungsten-nas.fmi.ch\tungsten\scratch\gfriedri\montruth\2P_RawData\2022-04-26\f3\anatomy\masks\mask_manderscoeff_c1_20220426_RM0008_130hpf_fP1_f3_anatomyGFRF_001_.tif"

mask_colored_stack_c0 = tifffile.imread(mask_colored_stack_c0_path)
mask_colored_stack_c1 = tifffile.imread(mask_colored_stack_c1_path)

viewer.add_image(mask_colored_stack_c0, name="mask_colored_stack_c0")
viewer.add_image(mask_colored_stack_c1, name="mask_colored_stack_c1")
# Classify LM-space centroids against the c1 coefficient mask:
#   r_in / r_out  : inside / outside the mask volume bounds
#   r_pos / r_neg : inside, with c1 coefficient above / at-or-below `c1_threshold`
r_in_centroids = []
r_out_centroids = []

r_pos_centroids = []
r_neg_centroids = []
c1_threshold = 0.8

# Get mask dimensions for bounds checking
z_max, y_max, x_max = mask_colored_stack_c1.shape

for centroid in lm_centroids_proxy:
    # Round to the nearest voxel, as the g/r-coefficient cell does. Truncating with
    # astype(int) would sample a neighbouring voxel for fractional parts >= 0.5 and
    # accept coordinates in (-1, 0) as inside.
    z_int, y_int, x_int = np.round(centroid).astype(int)

    if not (0 <= z_int < z_max and 0 <= y_int < y_max and 0 <= x_int < x_max):
        r_out_centroids.append(centroid)
        continue

    # Index is known to be in bounds here, so no IndexError handling is needed.
    r_in_centroids.append(centroid)
    c1_coeff = mask_colored_stack_c1[z_int, y_int, x_int]
    if c1_coeff > c1_threshold:
        r_pos_centroids.append(centroid)
    else:
        r_neg_centroids.append(centroid)

In [None]:

# Fresh viewer: LM channel 1 with the c1 mask, plus the classified centroids.
viewer = napari.Viewer()
viewer.add_image(lm_stack[:, 1], blending="additive", name="lm_stack")
viewer.add_image(mask_colored_stack_c1, blending="additive", name="c1_mask")
viewer.add_points(r_in_centroids, face_color="green", name="r_in_centroids")
viewer.add_points(r_out_centroids, face_color="red", name="r_out_centroids")
viewer.add_points(r_pos_centroids, face_color="cyan", name="r_pos_centroids")
viewer.add_points(r_neg_centroids, face_color="magenta", name="r_neg_centroids")

# Read back the stored LM / EM / Neuroglancer centroids. The three lists share one
# order (HDF5 iteration order), so index i refers to the same neuron in each.
lm_centroids_stored = []
em_centroids_stored = []
ng_centroids_stored = []
# Read-only pass: open with 'r' so the file cannot be modified by accident.
with h5py.File(lut_minimal_path, 'r') as f:
    for neuron, neuron_group in f['neurons'].items():
        if 'lm_centroid_bw' in neuron_group:
            lm_centroids_stored.append(list(neuron_group['lm_centroid_bw'][:]))
            em_centroids_stored.append(list(neuron_group['em_centroid_bw'][:]))
            ng_centroids_stored.append(list(neuron_group['em_centroid_ng'][:]))

viewer.add_points(lm_centroids_stored, name="lm_centroids_stored")
viewer.add_points(em_centroids_stored, name="em_centroids_stored")

In [None]:

# Colour points by their index with a 10 000-entry "coolwarm" colour cycle. The same
# index gets the same colour in the LM, EM and NG layers, so mismatches stand out.
num_colors = 10000
colors = plt.cm.coolwarm(np.linspace(0, 1, num_colors))

properties = {
    'point_index': np.arange(len(lm_centroids_stored))
}

# NOTE(review): ng_centroids_stored holds Neuroglancer coordinates (XYZ, not the
# downsampled space), presumably why its points are drawn much larger.
for layer_name, layer_points, point_size in (
    ("lm_centroids_stored", lm_centroids_stored, 10),
    ("em_centroids_stored", em_centroids_stored, 10),
    ("ng_centroids_stored", ng_centroids_stored, 100),
):
    viewer.add_points(
        layer_points,
        name=layer_name,
        properties=properties,
        face_color='point_index',
        face_color_cycle=colors,
        size=point_size
    )


In [None]:

# Store each neuron's green (c0) and red (c1) mask coefficients at its LM centroid as
# `g_coeff` / `r_coeff`; centroids outside the masks get NaN and are collected.
outside_lm_points = []

# Get mask dimensions for bounds checking (c1 is additionally guarded below in case
# its shape differs from c0)
z_max, y_max, x_max = mask_colored_stack_c0.shape

with h5py.File(lut_minimal_path, 'r+') as f:
    for neuron, neuron_group in f['neurons'].items():
        if 'lm_centroid_bw' not in neuron_group:
            continue

        # Nearest voxel of the LM-space centroid
        lm_centroid = neuron_group['lm_centroid_bw'][:].round().astype(int)
        z, y, x = lm_centroid

        if not (0 <= z < z_max and 0 <= y < y_max and 0 <= x < x_max):
            outside_lm_points.append(lm_centroid)
            neuron_group.attrs['g_coeff'] = np.nan
            neuron_group.attrs['r_coeff'] = np.nan
            continue

        try:
            neuron_group.attrs['g_coeff'] = mask_colored_stack_c0[z, y, x]
        except IndexError:
            outside_lm_points.append(lm_centroid)
            neuron_group.attrs['g_coeff'] = np.nan

        try:
            neuron_group.attrs['r_coeff'] = mask_colored_stack_c1[z, y, x]
        except IndexError:
            outside_lm_points.append(lm_centroid)
            neuron_group.attrs['r_coeff'] = np.nan

# Agglomeration IDs whose red coefficient exceeds the same threshold used for r_pos/r_neg.
INs = get_neurons_with_attribute(lut_minimal_path, "r_coeff", c1_threshold, ">")
print(INs)

viewer.add_points(outside_lm_points, face_color="magenta")
# `outside_lm_points` is a list; convert it before asking for its shape.
np.array(outside_lm_points).shape
