# Filter Zarr file

Only keep segmentation IDs that show up in at least K layers

Parallelized for multi-core utilization

In [19]:
import zarr
import os
import numpy as np
from multiprocessing import Pool, cpu_count, Manager
from tqdm import tqdm
import time
import logging

# Initialize logging: records at INFO level and above are written to the
# file 'parallel_processing.log' instead of the console.
logging.basicConfig(filename='parallel_processing.log', level=logging.INFO)

def count_segmentation_for_layers(args):
  """Count, per segmentation ID, how many layers of a range contain it.

  Args:
    args: A 4-tuple ``(zarr_array, start_layer, end_layer, progress_bar)``.
      ``zarr_array`` is indexed as ``zarr_array[layer, :, :]``; the layers
      scanned are the half-open range ``[start_layer, end_layer)``;
      ``progress_bar`` is an object whose ``value`` attribute is increased
      by one after every processed layer.

  Returns:
    A dict mapping each non-zero ID to the number of scanned layers in
    which it occurs at least once.
  """
  volume, first, stop, counter = args
  layers_per_id = {}
  for layer_index in range(first, stop):
    logging.info(f"Processing layer {layer_index} in range {first}-{stop}")
    ids_in_layer = np.unique(volume[layer_index, :, :])
    # ID 0 is never counted; every other ID present scores one hit for
    # this layer, however many pixels it covers.
    for seg_id in ids_in_layer[ids_in_layer != 0]:
      if seg_id in layers_per_id:
        layers_per_id[seg_id] += 1
      else:
        layers_per_id[seg_id] = 1
    # Report one more finished layer through the shared counter.
    counter.value = counter.value + 1
  return layers_per_id

def get_segmentation_counts_parallel(zarr_array, num_cores):
  """Count in how many layers each segmentation ID occurs, using a pool.

  The layers along axis 0 are split into ``num_cores`` contiguous ranges
  (the first ``total_layers % num_cores`` ranges get one extra layer), each
  range is handed to ``count_segmentation_for_layers`` in a worker process,
  and the per-range counts are summed.

  Args:
    zarr_array: Array whose first axis enumerates the layers.
    num_cores: Number of worker processes / layer ranges; must be >= 1.

  Returns:
    A dict mapping segmentation ID to the total number of layers in which
    the workers found it.

  Raises:
    ValueError: If ``num_cores`` is smaller than 1.
  """
  if num_cores < 1:
    raise ValueError(f'num_cores must be at least 1, got {num_cores}')

  total_layers = zarr_array.shape[0]
  layers_per_core, remainder = divmod(total_layers, num_cores)

  with Manager() as manager:
    progress_bar = manager.Value('i', 0)  # shared counter

    # Distribute layers among cores
    args = []
    start_layer = 0
    for i in range(num_cores):
      # The first `remainder` cores take one extra layer each.
      end_layer = start_layer + layers_per_core + (1 if i < remainder else 0)
      args.append((zarr_array, start_layer, end_layer, progress_bar))
      start_layer = end_layer

    logging.info(f'Layers per core: {layers_per_core}, with {remainder} cores processing an extra layer.')

    # Process in parallel
    with Pool(num_cores) as pool:
      async_result = pool.map_async(count_segmentation_for_layers, args)

      # Update tqdm while the processes are running
      with tqdm(total=total_layers, position=0, leave=True) as pbar:
        last_count = 0
        while not async_result.ready():
          current_count = progress_bar.value
          # Workers bump the counter with an unsynchronized read-modify-write,
          # so it can lag or even step backwards; only ever move forward.
          if current_count > last_count:
            pbar.update(current_count - last_count)
            last_count = current_count
          # Wakes up as soon as the work is done instead of always sleeping
          # the full interval.
          async_result.wait(0.5)
        # get() re-raises any worker exception; reaching the next line means
        # every layer was processed, so complete the bar.
        results = async_result.get()
        pbar.update(total_layers - last_count)

    # Combine results
    logging.info('Combining results')
    combined_counts = {}
    for segmentation_counts in results:
      for seg_id, count in segmentation_counts.items():
        combined_counts[seg_id] = combined_counts.get(seg_id, 0) + count

  logging.info(f'Total count: {len(combined_counts)}')
  return combined_counts

# Open the Zarr store in read-only mode and select the segmentation dataset.
data = zarr.open('3M-APP-SCN.zarr', mode='r')
zarr_array = data['segmentation_0.1']
logging.info(f"Zarr array shape: {zarr_array.shape}")
# Never request more workers than the machine has CPUs, and at most 15.
num_cores = min(cpu_count(), 15)  # Use up to 15 cores
logging.info(f"Number of cores: {num_cores}")
# Maps segmentation ID -> number of layers containing it; consumed by the
# filtering cell further down.
segmentation_counts = get_segmentation_counts_parallel(zarr_array, num_cores)


 28%|██▊       | 366/1286 [55:11<2:18:43,  9.05s/it]


KeyboardInterrupt: 

In [22]:

def find_max_value(zarr_array):
  """Return the largest value found while scanning the array layer by layer.

  The running maximum starts at 0, so 0 is returned when no layer holds a
  value greater than 0. The running maximum is printed after every layer.

  Args:
    zarr_array: Array indexed as ``zarr_array[layer, :, :]`` whose first
      axis enumerates the layers.

  Returns:
    The greatest per-layer maximum seen, or 0 if none exceeds 0.
  """
  running_max = 0
  layer_count = zarr_array.shape[0]
  for layer_index in range(layer_count):
    # Only one layer is read at a time.
    peak = zarr_array[layer_index, :, :].max()
    running_max = max(running_max, peak)
    print(f"Layer index: {layer_index}, Current max value: {running_max}")
  return running_max

# Re-open the store read-only and select the segmentation dataset.
data = zarr.open('3M-APP-SCN.zarr', mode='r')
zarr_array = data['segmentation_0.1']

# Scan every layer for the largest value (find_max_value prints per-layer
# progress as it goes).
max_value = find_max_value(zarr_array)
print(f"The maximum value in the entire array is: {max_value}")

Layer index: 0, Current max value: 4227121


KeyboardInterrupt: 

In [None]:
from multiprocessing import Pool, cpu_count
import numpy as np
import zarr

def filter_layers(args):
    """Zero out unwanted segmentation IDs in a contiguous range of layers.

    Args:
        args: A 4-tuple ``(start_layer, end_layer, zarr_array,
            ids_to_remove)``. The layers processed are the half-open range
            ``[start_layer, end_layer)`` of ``zarr_array`` (indexed as
            ``zarr_array[layer, :, :]``); ``ids_to_remove`` is an iterable
            of IDs to replace with 0.

    Returns:
        A list with one array per processed layer, in layer order, where
        every element found in ``ids_to_remove`` has been set to 0.
    """
    start_layer, end_layer, zarr_array, ids_to_remove = args
    # np.isin does not treat a Python set as a collection of values, so the
    # IDs are materialized as an array -- once, rather than once per layer.
    ids_array = np.asarray(list(ids_to_remove))
    modified_layers = []
    for layer_idx in range(start_layer, end_layer):
        layer = zarr_array[layer_idx, :, :]
        modified_layers.append(np.where(np.isin(layer, ids_array), 0, layer))
    return modified_layers

def parallel_filter(zarr_array, ids_to_remove, num_cores):
    """Zero out the given segmentation IDs in every layer, using a pool.

    The layers along axis 0 are split into contiguous ranges, each range is
    processed by ``filter_layers`` in a worker process, and the filtered
    layers are stacked back together in their original order.

    Note that the whole filtered volume is assembled in memory.

    Args:
        zarr_array: Array whose first axis enumerates the layers.
        ids_to_remove: Iterable of segmentation IDs to replace with 0.
        num_cores: Maximum number of worker processes; must be >= 1.

    Returns:
        An array with one filtered layer per input layer along axis 0.

    Raises:
        ValueError: If ``num_cores`` is smaller than 1.
    """
    if num_cores < 1:
        raise ValueError(f'num_cores must be at least 1, got {num_cores}')

    total_layers = zarr_array.shape[0]
    if total_layers == 0:
        # Nothing to filter: return an empty volume of the same shape/dtype.
        return np.empty(zarr_array.shape, dtype=zarr_array.dtype)

    # Never start more workers than there are layers; an empty range would
    # yield an empty result list that cannot be stacked with the others.
    num_workers = min(num_cores, total_layers)
    layers_per_core, remainder = divmod(total_layers, num_workers)

    args = []
    start_layer = 0
    for i in range(num_workers):
        # The first `remainder` workers take one extra layer each.
        end_layer = start_layer + layers_per_core + (1 if i < remainder else 0)
        args.append((start_layer, end_layer, zarr_array, ids_to_remove))
        start_layer = end_layer

    with Pool(num_workers) as pool:
        results = pool.map(filter_layers, args)

    # Combine results: pool.map keeps the order of `args`, so flattening the
    # per-worker lists restores the original layer order.
    return np.stack([layer for chunk in results for layer in chunk])

# Select every segmentation ID that occurs in fewer than `threshold` layers.
threshold = 5
ids_to_remove = {
    seg_id
    for seg_id, count in segmentation_counts.items()
    if count < threshold
}

# Filter the volume with at most 16 worker processes.
num_cores = min(cpu_count(), 16)
modified_segmentation = parallel_filter(zarr_array, ids_to_remove, num_cores)

# Create the output Zarr store (mode 'w').
filtered_zarr = zarr.open('8M-APP-retina-100M_filtered.zarr', mode='w')

# Write the raw data and the filtered segmentation, both with their three
# axes reversed.
filtered_zarr['raw'] = np.transpose(data['raw'], axes=(2, 1, 0))
filtered_zarr['segmentation_0.1'] = np.transpose(
    modified_segmentation, axes=(2, 1, 0)
)
