In [None]:
from cellpose.contrib.distributed_segmentation import distributed_eval, numpy_array_to_zarr
from aicspylibczi import CziFile
import numpy as np
import napari
import zarr
from dask.array import from_zarr
import zarr

In [None]:
import scipy.ndimage
from cellpose.contrib import distributed_segmentation

# Keep a handle on the upstream implementation so the monkey-patch applied
# further down can be undone in this session, e.g.:
#   distributed_segmentation.block_face_adjacency_graph = original_block_face_adjacency_graph
original_block_face_adjacency_graph = distributed_segmentation.block_face_adjacency_graph

def fixed_block_face_adjacency_graph(faces, nlabels):
    """Find which labels touch across block faces, for 2D or 3D label data.

    Drop-in replacement for
    ``distributed_segmentation.block_face_adjacency_graph``. The connectivity
    structure is derived from the dimensionality of the faces themselves
    instead of being fixed, so 2D segmentations can be merged as well.

    Parameters
    ----------
    faces : sequence of numpy.ndarray
        Label arrays spanning block boundaries. Every axis of length 2 is
        split into its first and second slice (one per neighbouring block).
        NOTE(review): this assumes the boundary axis is the only length-2
        axis of each face; TODO confirm against how upstream builds faces.
    nlabels : int
        Largest label id; may arrive as a NumPy fixed-size integer.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix of shape ``(nlabels + 1, nlabels + 1)`` whose nonzero entries
        ``(i, j)`` mark labels ``i`` and ``j`` as touching across a face.
        All zeros when ``faces`` is empty.
    """
    import scipy.sparse  # ensure the submodule is loaded, not just scipy.ndimage

    nlabels = int(nlabels)  # callers may pass a NumPy fixed-size integer
    shape = (nlabels + 1, nlabels + 1)

    # No block boundaries (e.g. the whole image fits in a single block):
    # nothing can touch, and np.concatenate([]) would raise ValueError.
    if len(faces) == 0:
        return scipy.sparse.csr_matrix(shape, dtype=int)

    # Private dask-image helper; imported once here rather than on every
    # loop iteration.
    from dask_image.ndmeasure._utils._label import _across_block_label_grouping

    # Connectivity 1 (face neighbours only, no diagonals) in the faces' own
    # dimensionality -- 2 for 2D segmentations, 3 for 3D.
    structure = scipy.ndimage.generate_binary_structure(faces[0].ndim, 1)

    all_mappings = []
    for face in faces:
        sl0 = tuple(slice(0, 1) if d == 2 else slice(None) for d in face.shape)
        sl1 = tuple(slice(1, 2) if d == 2 else slice(None) for d in face.shape)
        # Shrink labels on each side of the boundary before matching;
        # presumably so labels that merely graze the boundary are not merged
        # (shrink_labels' exact behaviour is not visible here).
        a = distributed_segmentation.shrink_labels(face[sl0], 1.0)
        b = distributed_segmentation.shrink_labels(face[sl1], 1.0)
        # Re-join the two sides along their thinnest axis (the boundary axis).
        joined = np.concatenate((a, b), axis=np.argmin(a.shape))
        all_mappings.append(_across_block_label_grouping(joined, structure))

    i, j = np.concatenate(all_mappings, axis=1)
    v = np.ones_like(i)
    return scipy.sparse.coo_matrix((v, (i, j)), shape=shape).tocsr()

# Apply the patch: rebind the module attribute so code that resolves
# `block_face_adjacency_graph` through distributed_segmentation's module
# globals at call time uses the dimension-aware version.
# NOTE(review): this only changes the function in the current Python process;
# if the merge step ever runs inside dask worker processes, they would still
# see the original -- confirm it runs on the client.
distributed_segmentation.block_face_adjacency_graph = fixed_block_face_adjacency_graph


In [None]:
# IPython help: show distributed_eval's signature and docstring.
# Interactive aid only -- safe to drop from a Restart-&-Run-All pipeline.
distributed_eval?

In [None]:
# Source CZI mosaic. Absolute, machine-specific path -- adjust per workstation.
CZI_PATH = r"E:\Steensma_Lab\OIC-222_YAP_pERK_Ki67\TMA_10272025\TMA_555only.czi"
test_image = CziFile(CZI_PATH)

In [None]:
# Interactive check: inspect the CZI's dimension metadata before reading it.
test_image.get_dims_shape()

In [None]:
# Interactive check: is this a mosaic file? The next cell reads it with read_mosaic().
test_image.is_mosaic()

In [None]:
# Channel 0 of the whole mosaic at native resolution (scale_factor=1).
# np.squeeze drops length-1 axes; the result is expected to be 2D (H, W) --
# verify with full_res.shape. The whole mosaic is held in memory.
full_res = np.squeeze(
    test_image.read_mosaic(C=0, scale_factor=1)
)

In [None]:
# Add a leading singleton axis, (H, W) -> (1, H, W), assuming full_res is 2D.
# Basic indexing with np.newaxis returns a view of full_res, not a copy.
# Only needed when exporting a 3D-shaped zarr.
added_dim = full_res[np.newaxis, :, :]

In [None]:
# Export the 2D mosaic to the zarr store that the segmentation step opens
# ('test_ch1_zarr.zarr'). Chunk shape equals the 2D blocksize (2048, 2048)
# passed to distributed_eval. Assumes full_res is 2D (H, W).
# Previously this wrote the 3D `added_dim` to 'test_ch1_zarr_3D.zarr', which
# no later cell read, so a fresh Run-All silently used a stale input store.
# For a 3D run instead, use:
#   numpy_array_to_zarr('test_ch1_zarr_3D.zarr', added_dim, chunks=(1, 2048, 2048))
# together with a 3D blocksize and eval_kwargs={'z_axis': 0, 'do_3D': True}.
data_zarr = numpy_array_to_zarr('test_ch1_zarr.zarr', full_res, chunks=(2048, 2048))

In [None]:
# Drop the in-memory copies now that the zarr store has been written.
# added_dim is a view of full_res, so both names must go before the
# mosaic's memory can be reclaimed.
del full_res, data_zarr, added_dim

In [None]:
# Cellpose model construction: run on GPU with the 'cpsam' pretrained weights.
model_kwargs = dict(
    gpu=True,
    pretrained_model='cpsam',
)

# For 3D input, also build eval_kwargs = dict(z_axis=0, do_3D=True) and pass
# it to distributed_eval.

# Dask cluster settings handed to distributed_eval.
# NOTE(review): 'ncpus' and 'memory_limit' should fit the host -- see the
# RAM/CPU check cell; exact semantics are defined by cellpose's cluster wrapper.
cluster_kwargs = dict(
    n_workers=1,
    ncpus=12,
    memory_limit='400GB',
    threads_per_worker=1,
)

In [None]:
# Input image store for segmentation, opened read-only. Expected to be 2D to
# match the 2D blocksize used by distributed_eval.
INPUT_ZARR_PATH = 'test_ch1_zarr.zarr'
array = zarr.open(INPUT_ZARR_PATH, mode='r')

In [None]:
# Distributed Cellpose segmentation of `array` in 2048 x 2048 blocks; the
# label image is written to 'test_output.zarr'.
segments, boxes = distributed_eval(
    input_zarr=array,
    blocksize=(2048, 2048),
    write_path='test_output.zarr',
    cluster_kwargs=cluster_kwargs,
    model_kwargs=model_kwargs,
    # For 3D input, also pass eval_kwargs (e.g. {'z_axis': 0, 'do_3D': True}).
)

In [None]:
import os

import psutil

# Report host resources, e.g. to sanity-check cluster_kwargs
# ('ncpus', 'memory_limit') against this machine.
cpu_count = os.cpu_count()
total_ram = psutil.virtual_memory().total / 1024 ** 3  # bytes -> GB (GiB)
for report_line in (f"Total RAM: {total_ram:.1f} GB", f"CPU cores: {cpu_count}"):
    print(report_line)

In [None]:
# Re-open, read-only, the label image that distributed_eval wrote to
# 'test_output.zarr', for viewing.
labels = zarr.open('test_output.zarr',mode='r')

In [None]:
# Overlay the segmentation on the input image for visual inspection.
# napari.Viewer() + add_* is the stable API; the napari.view_* convenience
# functions are deprecated in recent napari releases.
viewer = napari.Viewer()
viewer.add_image(array, name='image')
viewer.add_labels(labels, name='labels')