Skip to content

Commit

Permalink
refactor: Use enum to represent plane
Browse files Browse the repository at this point in the history
  • Loading branch information
shernshiou committed Feb 1, 2024
1 parent ab4c10c commit 0ec1426
Showing 1 changed file with 86 additions and 62 deletions.
148 changes: 86 additions & 62 deletions darwin/exporter/formats/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json as native_json
import re
from dataclasses import dataclass
from enum import Enum
from numbers import Number
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

Expand All @@ -10,7 +12,7 @@
console = Console()
try:
import nibabel as nib
from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform
from nibabel.orientations import io_orientation, ornt_transform
except ImportError:
import_fail_string = """
You must install darwin-py with pip install darwin-py\[medical]
Expand All @@ -24,6 +26,12 @@
from darwin.utils import convert_polygons_to_mask


class Plane(Enum):
XY = 0
XZ = 1
YZ = 2


@dataclass
class Volume:
pixel_array: np.ndarray
Expand Down Expand Up @@ -169,7 +177,7 @@ def build_output_volumes(

def check_for_error_and_return_imageid(
video_annotation: dt.AnnotationFile, output_dir: Path
):
) -> str:
"""
Given the video_annotation file and the output directory, checks for a range of errors and
returns messages accordingly.
Expand Down Expand Up @@ -240,6 +248,39 @@ def check_for_error_and_return_imageid(
return image_id


def update_pixel_array(
volume: Dict,
annotation_class_name: str,
im_mask: np.ndarray,
plane: Plane,
frame_idx: int,
) -> Dict:
"""Updates the pixel array of the given volume with the given mask.
Args:
volume (Dict): Volume with pixel array to be updated
annotation_class_name (str): Name of the annotation class
im_mask (np.ndarray): Mask to be added to the pixel array
plane (Plane): Plane of the mask
frame_idx (int): Frame index of the mask
Returns:
Dict: Updated volume
"""
plane_to_slice = {
Plane.XY: volume[:, :, frame_idx],
Plane.XZ: volume[:, frame_idx, :],
Plane.YZ: volume[frame_idx, :, :],
}
if plane in plane_to_slice:
slice_ = plane_to_slice[plane]
volume[annotation_class_name].pixel_array[slice_] = np.logical_or(
im_mask,
volume[annotation_class_name].pixel_array[slice_],
)
return volume


def populate_output_volumes_from_polygons(
annotations: List[Union[dt.Annotation, dt.VideoAnnotation]],
slot_map: Dict,
Expand Down Expand Up @@ -273,30 +314,17 @@ def populate_output_volumes_from_polygons(
volume = output_volumes.get(series_instance_uid)
frames = annotation.frames

# define the different planes
XYPLANE = 0
XZPLANE = 1
YZPLANE = 2

for frame_idx in frames.keys():
view_idx = get_view_idx_from_slot_name(
plane = get_plane_from_slot_name(
slot_name, slot.metadata.get("orientation")
)
if view_idx == XYPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[0],
volume[annotation.annotation_class.name].dims[1],
)
elif view_idx == XZPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[0],
volume[annotation.annotation_class.name].dims[2],
)
elif view_idx == YZPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[1],
volume[annotation.annotation_class.name].dims[2],
)
dims = volume[annotation.annotation_class.name].dims
if plane == Plane.XY:
height, width = dims[0], dims[1]
elif plane == Plane.XZ:
height, width = dims[0], dims[2]
elif plane == Plane.YZ:
height, width = dims[1], dims[2]
if "paths" in frames[frame_idx].data:
# Dealing with a complex polygon
polygons = [
Expand All @@ -313,36 +341,14 @@ def populate_output_volumes_from_polygons(
)
else:
continue
frames[frame_idx].annotation_class.name
im_mask = convert_polygons_to_mask(polygons, height=height, width=width)
volume = output_volumes[series_instance_uid]
if view_idx == 0:
volume[annotation.annotation_class.name].pixel_array[
:, :, frame_idx
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
:, :, frame_idx
],
)
elif view_idx == 1:
volume[annotation.annotation_class.name].pixel_array[
:, frame_idx, :
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
:, frame_idx, :
],
)
elif view_idx == 2:
volume[annotation.annotation_class.name].pixel_array[
frame_idx, :, :
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
frame_idx, :, :
],
)
volume = update_pixel_array(
output_volumes[series_instance_uid],
annotation.annotation_class.name,
im_mask,
plane,
frame_idx,
)
return volume


Expand Down Expand Up @@ -438,7 +444,7 @@ def unnest_dict_to_list(d: Dict) -> List:
nib.save(img=img, filename=output_path)


def shift_polygon_coords(polygon, pixdim):
def shift_polygon_coords(polygon: List[Dict], pixdim: List[Number]) -> List:
# Need to make it clear that we flip x/y because we need to take the transpose later.
if pixdim[1] > pixdim[0]:
return [{"x": p["y"], "y": p["x"] * pixdim[1] / pixdim[0]} for p in polygon]
Expand All @@ -456,13 +462,12 @@ def get_view_idx(frame_idx, groups):
return view_idx


def get_view_idx_from_slot_name(slot_name: str, orientation: Union[str, None]) -> int:
def get_plane_from_slot_name(slot_name: str, orientation: Union[str, None]) -> Plane:
if orientation is None:
orientation_dict = {"0.1": 0, "0.2": 1, "0.3": 2}
return orientation_dict.get(slot_name, 0)
else:
orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2}
return orientation_dict.get(orientation, 0)
return Plane(orientation_dict.get(slot_name, 0))
orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2}
return Plane(orientation_dict.get(orientation, 0))


def process_metadata(metadata: Dict) -> Tuple:
Expand All @@ -489,9 +494,19 @@ def process_metadata(metadata: Dict) -> Tuple:
return volume_dims, pixdim, affine, original_affine


def process_affine(affine):
def process_affine(affine: Union[str, List, np.ndarray]) -> Optional[np.ndarray]:
"""Converts affine to numpy array if it is not already.
Args:
affine (Union[str, List, np.ndarray]): affine object to be converted
Returns:
Optional[np.ndarray]: affine as numpy array
"""
if isinstance(affine, str):
affine = np.squeeze(np.array([ast.literal_eval(l) for l in affine.split("\n")]))
affine = np.squeeze(
np.array([ast.literal_eval(lst) for lst in affine.split("\n")])
)
elif isinstance(affine, list):
affine = np.array(affine).astype(float)
else:
Expand All @@ -512,8 +527,17 @@ def create_error_message_json(
return False


def decode_rle(rle_data, width, height):
"""Decodes run-length encoding (RLE) data into a mask array."""
def decode_rle(rle_data: List[int], width: int, height: int) -> np.ndarray:
"""Decodes run-length encoding (RLE) data into a mask array.
Args:
rle_data (List[int]): List of RLE data
width (int): Width of the data
height (int): Height of the data
Returns:
np.ndarray: RLE data
"""
total_pixels = width * height
mask = np.zeros(total_pixels, dtype=np.uint8)
pos = 0
Expand Down

0 comments on commit 0ec1426

Please sign in to comment.