In [None]:
"Om Namah Shivaya!! 🙏🙏"

## Updates
- Version 2: 

In [None]:
!pip install -q git+https://github.com/copick/copick-utils.git
!pip install -q copick zarr

In [None]:
# Make a copick project
import os
import shutil

# JSON text of the copick project configuration. It is written to
# `copick_config_path` and loaded with `copick.from_file` in later cells.
#   * pickable_objects: six particle classes (labels 1-6, each with a radius and
#     map threshold) plus the non-particle entries "membrane" (label 8) and
#     "background" (label 9); label 7 is not used.
#   * overlay_root: the writable copy under /kaggle/working/overlay.
#   * static_root: the competition input directory.
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": true,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153,  63,   0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        },
        {
            "name": "membrane",
            "is_particle": false,
            "label": 8,
            "color": [100, 100, 100, 128]
        },
        {
            "name": "background",
            "is_particle": false,
            "label": 9,
            "color": [10, 150, 200, 128]
        }
    ],

    "overlay_root": "/kaggle/working/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "/kaggle/input/czii-cryo-et-object-identification/train/static"
}"""

In [None]:
# Output locations inside the Kaggle working directory.
copick_config_path = "/kaggle/working/copick.config"
output_png = "/kaggle/working/train_png_normalized"

# Persist the JSON configuration so the copick project can be loaded from disk.
with open(copick_config_path, "w") as f:
    f.write(config_blob)

# Mirror the competition overlay into the writable overlay directory, making
# sure every copied file name carries the "curation_0_" prefix.
# NOTE(review): presumably the prefix is what lets the later
# `run.get_picks(..., user_id="curation")` calls find these picks - confirm.
source_dir = '/kaggle/input/czii-cryo-et-object-identification/train/overlay'
destination_dir = '/kaggle/working/overlay'
overlay_prefix = "curation_0_"

# The walk variable is called `src_root` (not `root`) so it cannot be confused
# with the copick project handle that a later cell binds to `root`.
for src_root, _dirs, files in os.walk(source_dir):
    # Re-create the same sub-directory layout below the destination.
    relative_path = os.path.relpath(src_root, source_dir)
    target_dir = os.path.join(destination_dir, relative_path)
    os.makedirs(target_dir, exist_ok=True)

    for file in files:
        # Files that already carry the prefix keep their name.
        if file.startswith(overlay_prefix):
            new_filename = file
        else:
            new_filename = f"{overlay_prefix}{file}"

        source_file = os.path.join(src_root, file)
        destination_file = os.path.join(target_dir, new_filename)

        # copy2 also preserves file metadata; report each file exactly once.
        shutil.copy2(source_file, destination_file)
        print(f"Copied {source_file} to {destination_file}")

## Prepare Dataset

In [None]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torchinfo
import zarr, copick
from tqdm import tqdm
import matplotlib.pyplot as plt

### 1. Get copick root

In [None]:
# User / segmentation names under which the painted masks are written by the
# mask-generation cell and read back by the export cells.
copick_user_name = "copickUtils"
copick_segmentation_name = "paintedPicks"

# Which tomogram to use for every run: voxel spacing and processing type.
# NOTE(review): the unit of `voxel_size` is not stated in this notebook - confirm.
voxel_size = 10
tomo_type = "denoised"

# Bit depth of the first PNG export pass ("8bit" selects uint8 output).
precision = "8bit"

# Open the copick project described by the config file written earlier.
root = copick.from_file(copick_config_path)

### 2. Generate multi-class segmentation masks from picks and save them to the copick overlay directory (one-time)

In [None]:
%%writefile boundingboxes_from_picks.py
import numpy as np
import pandas as pd
import zarr
import copick

def from_picks(pick, 
               run, 
               seg_volume,
               radius: float = 10.0, 
               label_value: int = 1,
               voxel_spacing: float = 10) -> pd.DataFrame:
    """
    Build per-frame 2D bounding boxes for the spheres centred on each pick.

    Every pick is treated as a sphere of `radius`. For each z-frame that the
    sphere intersects, one square box is emitted whose half-size is the radius
    of the intersection circle in that frame. All coordinates and sizes are
    normalised to [0, 1] by the frame dimensions.

    Parameters:
    -----------
    pick : copick.models.CopickPicks
        Object with a `points` sequence; each point has a `location` with
        `x`, `y`, `z` attributes. An empty (or `None`) `points` yields an
        empty result.
    run : copick.impl.filesystem.CopickRunFSSpec
        Only `run.name` is read; it is stored in the `exp_name` column.
    seg_volume : numpy.ndarray
        3D volume of shape (Z, Y, X). Only its shape is used; nothing is
        written into it.
    radius : float, optional
        Sphere radius in the same units as the pick locations (not voxels).
        Default is 10.0.
    label_value : int, optional
        Value stored in the `label` column for every box. Default is 1.
    voxel_spacing : float, optional
        Divisor that converts pick locations and `radius` to voxel units.
        Default is 10.

    Returns:
    --------
    pandas.DataFrame
        One row per (frame, pick) intersection, ordered by frame and then by
        pick index, with columns `exp_name`, `frame`, `label`, `x_center`,
        `y_center`, `height`, `width`. `x_center`/`width` are normalised by
        the X dimension and `y_center`/`height` by the Y dimension.
    """

    # Adjust radius for voxel spacing; `delta` is the integer margin used to
    # keep box centres away from the frame border.
    radius_voxel = radius / voxel_spacing
    delta = int(np.ceil(radius_voxel))

    # Volume layout is (Z, Y, X). Frames need not be square: x is normalised
    # by the X extent and y by the Y extent.
    vol_shape_z, vol_shape_y, vol_shape_x = seg_volume.shape

    # Pick centres converted to voxel coordinates.
    points = pick.points or []
    cx_voxel = np.array([pnt.location.x for pnt in points], dtype=float) / voxel_spacing
    cy_voxel = np.array([pnt.location.y for pnt in points], dtype=float) / voxel_spacing
    cz_voxel = np.array([pnt.location.z for pnt in points], dtype=float) / voxel_spacing

    # rframe[z, i] is the radius of the circle where sphere i cuts frame z
    # (0 where the sphere does not reach that frame). Broadcasting replaces
    # the per-frame Python loop.
    zframes = np.arange(vol_shape_z)
    dz = cz_voxel[np.newaxis, :] - zframes[:, np.newaxis]
    rframe = np.sqrt(np.maximum(radius_voxel**2 - dz**2, 0.))

    # Keep only real intersections; np.nonzero walks the array in row-major
    # order, so rows come out sorted by frame and then by pick index.
    frame_idx, point_idx = np.nonzero(rframe > 0)
    r_hit = np.clip(rframe[frame_idx, point_idx], 0, delta)

    # Clip centres so a box of half-size `delta` stays inside the frame,
    # then normalise to the frame dimensions.
    cx = np.clip(cx_voxel[point_idx], delta, vol_shape_x - delta) / vol_shape_x
    cy = np.clip(cy_voxel[point_idx], delta, vol_shape_y - delta) / vol_shape_y

    # Scalars (`run.name`, `label_value`) are broadcast over all rows; with no
    # intersections this is an empty frame that still has every column.
    bboxes = pd.DataFrame({"exp_name": run.name,
                           "frame": frame_idx,
                           "label": label_value,
                           "x_center": cx,
                           "y_center": cy,
                           "height": (r_hit / vol_shape_y) * 2,
                           "width": (r_hit / vol_shape_x) * 2})
    return bboxes

In [None]:
import boundingboxes_from_picks
from copick_utils.segmentation import segmentation_from_picks
import copick_utils.writers.write as write
from collections import defaultdict

# Just do this once
generate_masks = True
generate_bboxs = True

# Picks are read from this copick user id.
picks_user_id = "curation"
# Painted spheres use this fraction of the configured particle radius.
mask_radius_scale = 0.8

# Label / radius lookup for every particle class in the copick config.
# The loop variable is not called `object`, so the builtin stays usable.
target_objects = defaultdict(dict)
for pickable in root.pickable_objects:
    if pickable.is_particle:
        target_objects[pickable.name]['label'] = pickable.label
        target_objects[pickable.name]['radius'] = pickable.radius

# Generate masks for each pick
if generate_masks:
    for run in tqdm(root.runs):
        # A label volume with the same shape as this run's tomogram.
        tomo = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy()
        target = np.zeros(tomo.shape, dtype=np.uint8)
        for pickable_object in root.pickable_objects:
            if not pickable_object.is_particle:
                continue
            pick = run.get_picks(object_name=pickable_object.name, user_id=picks_user_id)
            if len(pick):
                # Paint the first pick set of this class into the label volume.
                target = segmentation_from_picks.from_picks(pick[0], 
                                                            target, 
                                                            target_objects[pickable_object.name]['radius'] * mask_radius_scale,
                                                            target_objects[pickable_object.name]['label'],
                                                            voxel_spacing=voxel_size
                                                            )

        # Hand the finished label volume to copick_utils for writing.
        write.segmentation(run, target, copick_user_name, name=copick_segmentation_name)

# Generate bbox for each pick
bboxes_all = pd.DataFrame()
if generate_bboxs:
    bbox_frames = []
    for run in tqdm(root.runs):
        # Use THIS run's tomogram shape. Reading it here (instead of reusing
        # `tomo` from the mask loop) keeps the cell working when
        # generate_masks is False and avoids borrowing another run's shape.
        tomo_shape = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy().shape
        target = np.zeros(tomo_shape, dtype=np.uint8)
        for pickable_object in root.pickable_objects:
            if not pickable_object.is_particle:
                continue
            pick = run.get_picks(object_name=pickable_object.name, user_id=picks_user_id)
            if len(pick):
                target_bboxs_df = boundingboxes_from_picks.from_picks(pick[0], 
                                                                      run,
                                                                      target, 
                                                                      target_objects[pickable_object.name]['radius'],
                                                                      target_objects[pickable_object.name]['label'],
                                                                      voxel_spacing=voxel_size
                                                                      )
                bbox_frames.append(target_bboxs_df)

    # Concatenate once at the end (linear) instead of appending inside the
    # loop through the private DataFrame._append (quadratic).
    if bbox_frames:
        bboxes_all = pd.concat(bbox_frames, ignore_index=True)

    # Written only when boxes were generated, so a re-run with the flag off
    # does not overwrite an existing CSV with an empty one.
    bboxes_all.to_csv('train_bounding_boxes.csv', index=False)

### 3. Get tomogram arrays and their segmentation masks (from picks)

In [None]:
from PIL import Image

def normalise_by_percentile(data, min=1, max=99):
    """Clip `data` to a percentile window and rescale that window to [0, 1].

    Parameters
    ----------
    data : array-like
        Values to normalise.
    min, max : float
        Lower and upper percentiles (0-100) that define the clipping window.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as `data` with values in [0, 1]. If both
        percentiles coincide (e.g. constant input) the result is all zeros
        instead of NaN from a division by zero.
    """
    # One call computes both percentiles in a single pass over the data.
    lo, hi = np.percentile(data, [min, max])
    shifted = np.clip(data, lo, hi) - lo
    span = hi - lo
    if span == 0:
        # Everything was clipped onto one value, so `shifted` is already zero.
        return shifted
    return shifted / span

def write_tomogram(data, fpath):
    """Save each slice along the first axis of a 3D array as a PNG file.

    Parameters
    ----------
    data : numpy.ndarray
        3D array; slice `data[i]` becomes one image. A non-3D array raises
        `ValueError` when its shape is unpacked.
    fpath : str
        Output directory, created if missing. Files are named `000.png`,
        `001.png`, ... inside it, with or without a trailing separator on
        `fpath`.
    """
    depth, _height, _width = data.shape
    os.makedirs(fpath, exist_ok=True)
    for i in range(depth):
        im = Image.fromarray(data[i])
        # os.path.join keeps the file inside `fpath` even when the caller
        # omits the trailing "/" (plain concatenation would not).
        im.save(os.path.join(fpath, f"{i:03d}.png"))

# Export every run as per-slice PNG stacks (images + labels) at the bit depth
# currently held in `precision`.
for run in tqdm(root.runs):
    raw_volume = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy()
    segmentation = run.get_segmentations(
        name=copick_segmentation_name,
        user_id=copick_user_name,
        voxel_size=voxel_size,
        is_multilabel=True,
    )[0].numpy()

    # Percentile-based normalisation to [0, 1], then quantisation to the
    # integer range of the requested bit depth.
    unit_volume = normalise_by_percentile(raw_volume)
    if precision == '8bit':
        scale, out_dtype = 255, np.uint8
    else:
        scale, out_dtype = 65535, np.uint16
    tomogram = (unit_volume * scale).astype(out_dtype)

    # One folder per run, with separate sub-folders for images and labels.
    write_tomogram(tomogram, f"{output_png}_{precision}/{run.name}/images/")
    write_tomogram(segmentation, f"{output_png}_{precision}/{run.name}/labels/")

In [None]:
# Second export pass: the same steps as the previous export cell, now at
# 16-bit depth.
precision = '16bit'

for run in tqdm(root.runs):
    spacing_entry = run.get_voxel_spacing(voxel_size)
    tomogram = spacing_entry.get_tomogram(tomo_type).numpy()

    matching_segmentations = run.get_segmentations(name=copick_segmentation_name, user_id=copick_user_name, voxel_size=voxel_size, is_multilabel=True)
    segmentation = matching_segmentations[0].numpy()

    # Normalise by percentiles, then scale to the integer range of the
    # selected bit depth.
    tomogram = normalise_by_percentile(tomogram)
    is_8bit = precision == '8bit'
    tomogram = (tomogram * (255 if is_8bit else 65535)).astype(np.uint8 if is_8bit else np.uint16)

    for volume, destination in (
        (tomogram, f"{output_png}_{precision}/{run.name}/images/"),
        (segmentation, f"{output_png}_{precision}/{run.name}/labels/"),
    ):
        write_tomogram(volume, destination)

In [None]:
# Delete the working overlay copy under /kaggle/working/overlay (the directory
# the competition picks were copied into earlier in this notebook).
# NOTE(review): cells that read from the overlay presumably need the copy and
# mask-generation cells re-run after this - confirm before re-running them.
!rm -r /kaggle/working/overlay

### 4. Visualize the tomogram and painted segmentation from ground-truth picks

In [None]:
import matplotlib.pyplot as plt

# Side-by-side view of z-slice 100: tomogram on the left, painted labels on
# the right. `tomogram` and `segmentation` are the arrays left over from the
# last iteration of the preceding export loop.
fig, (ax_tomo, ax_seg) = plt.subplots(1, 2, figsize=(15, 5))

ax_tomo.set_title('Tomogram')
ax_tomo.imshow(tomogram[100], cmap='gray')
ax_tomo.axis('off')

ax_seg.set_title('Painted Segmentation from Picks')
ax_seg.imshow(segmentation[100], cmap='viridis')
ax_seg.axis('off')

fig.tight_layout()
plt.show()