In [12]:
import os
import math
import json
import time
import tempfile

import numpy as np
import SimpleITK as sitk
from PIL import Image

from ipycanvas import MultiCanvas
from ipywidgets import (
    VBox, HBox, Button, ColorPicker, IntSlider,
    ToggleButtons, FileUpload, Checkbox, Label
)

# =========================
# Global state
# =========================
# Edge length, in pixels, of every square 2D canvas view.
VIEW_SIZE = 256

# Image data:
#   volume3d -> float32 (Z, Y, X) scaled to [0, 1]
#   mask3d   -> uint8 (Z, Y, X) holding 0/1 labels
volume3d = None
mask3d = None

# Volume extent along each axis (0 until something is loaded).
dim_z, dim_y, dim_x = 0, 0, 0

# Slice currently shown in each orthogonal view.
axial_index, coronal_index, sagittal_index = 0, 0, 0

# Mouse-drag state for the axial view: whether a stroke is in progress and
# the last canvas position it reached.
drawing = False
last_canvas_x, last_canvas_y = None, None

# Prompts placed on the axial view, in canvas coordinates:
#   box_prompts   -> dicts with keys x0, y0, x1, y1, z
#   point_prompts -> dicts with keys x, y, z
box_prompts = []
point_prompts = []

# Colour used to render the mask overlay (default: red).
overlay_color_hex = "#ff0000"
# =========================
# Canvas widgets
# =========================
def _make_view_canvas(n_layers):
    """Create a square VIEW_SIZE MultiCanvas and return it followed by each of its layers."""
    multi = MultiCanvas(n_layers, width=VIEW_SIZE, height=VIEW_SIZE)
    return (multi,) + tuple(multi[layer] for layer in range(n_layers))


# Orthogonal views stack three layers: image, mask overlay, UI annotations.
axial_canv, axial_bg, axial_ol, axial_ui = _make_view_canvas(3)
coronal_canv, coronal_bg, coronal_ol, coronal_ui = _make_view_canvas(3)
sagittal_canv, sagittal_bg, sagittal_ol, sagittal_ui = _make_view_canvas(3)

# The "3D" panel is a maximum-intensity projection: image + mask overlay only.
mip_canv, mip_bg, mip_ol = _make_view_canvas(2)

# =========================
# Controls
# =========================
uploader = FileUpload(
    description='Upload MRI',
    accept='.nii,.nii.gz,.mha,.mhd,.nrrd,.dcm,.png,.jpg,.jpeg',
    multiple=True,
)

# Which interaction the axial canvas performs on click/drag.
tool_selector = ToggleButtons(description='Tool:', options=['Brush', 'Box', 'Point'], value='Brush')
# Whether the brush sets mask voxels to 1 (Paint) or 0 (Erase).
mode_selector = ToggleButtons(description='Mode:', options=['Paint', 'Erase'], value='Paint')

# Brush radius, in axial-canvas pixels.
brush_size_slider = IntSlider(description='Brush size:', value=8, min=1, max=40, step=1)

color_picker = ColorPicker(description='Mask color:', value=overlay_color_hex, concise=False)
show_mask_checkbox = Checkbox(description='Show mask overlay', value=True)


def _make_slice_slider(label):
    """Build a disabled 0..0 slice slider; its range is set once a volume loads."""
    return IntSlider(description=label, value=0, min=0, max=0, step=1, disabled=True)


axial_slider = _make_slice_slider('Axial (Z):')
coronal_slider = _make_slice_slider('Coronal (Y):')
sagittal_slider = _make_slice_slider('Sagittal (X):')

clear_button = Button(description='Clear mask+prompts', button_style='warning')
save_button = Button(description='Save mask+prompts', button_style='success')

status_label = Label(value="Upload a 2D/3D MRI (NIfTI, DICOM, etc.) to start.")


# =========================
# Helpers
# =========================

def log_status(msg: str):
    """Replace the text shown in the status label with ``msg``."""
    target = status_label
    target.value = msg

def hex_to_rgb(hex_color):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0..255.

    Accepts ``#rrggbb`` and the shorthand ``#rgb`` form; the leading ``#`` is
    optional and surrounding whitespace is ignored. Shorthand values such as
    ``#f00`` can be typed into the ColorPicker, so both forms are handled.

    Raises:
        ValueError: if ``hex_color`` is not a 3- or 6-digit hex colour.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        # "#f80" is shorthand for "#ff8800".
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6:
        # Without this check a 5-digit string would silently parse to a wrong colour.
        raise ValueError(f"Expected a #rgb or #rrggbb colour, got {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def _get_file_name(info):
    """Return the uploaded file's name, or 'uploaded_file' if none can be found.

    Older FileUpload entries nest the name as ``info['metadata']['name']``;
    newer ones store it directly as ``info['name']``. The nested form wins
    when both are present.
    """
    if not isinstance(info, dict):
        return 'uploaded_file'
    metadata = info.get('metadata')
    if isinstance(metadata, dict) and 'name' in metadata:
        return metadata['name']
    return info.get('name', 'uploaded_file')

def _get_file_content(info):
    """Return the raw payload stored in one FileUpload entry.

    The payload is looked up under 'content' first, then under 'data'.

    Raises:
        ValueError: if ``info`` is not a dict or has neither key.
    """
    if isinstance(info, dict):
        for key in ('content', 'data'):
            if key in info:
                return info[key]
    raise ValueError("Unknown FileUpload entry format.")

def sitk_from_upload(widget):
    """
    Load a SimpleITK image from a FileUpload widget.

    - If 1 file: read directly (NIfTI, PNG, single-slice DICOM, etc.)
    - If >1 file: assume a DICOM series (3D) and read the first series found.

    Uploads are written to a temporary directory that is deleted before this
    function returns. SimpleITK's readers return in-memory images, so the
    files are not needed afterwards. Only the base name of each upload is used
    on disk, so a client-supplied name such as "../x" cannot escape that
    directory.

    Returns:
        (image, message): ``image`` is a ``sitk.Image``, or ``None`` on failure;
        ``message`` is a human-readable status string.
    """
    value = widget.value
    if not value:
        return None, "No files uploaded."

    # ipywidgets 7 vs 8 difference: value can be dict or tuple
    if isinstance(value, dict):
        file_list = list(value.values())
    else:
        # assume iterable (tuple/list) of dict entries
        file_list = list(value)

    with tempfile.TemporaryDirectory() as tmpdir:
        if len(file_list) == 1:
            return _read_single_upload(file_list[0], tmpdir)
        return _read_dicom_upload(file_list, tmpdir)


def _safe_upload_name(info):
    """Return the upload's file name reduced to a non-empty base name."""
    # Normalise Windows separators so basename() strips them on POSIX too.
    name = os.path.basename(str(_get_file_name(info)).replace('\\', '/'))
    return name or 'uploaded_file'


def _write_upload(info, path):
    """Write the raw bytes of one FileUpload entry to ``path``."""
    with open(path, 'wb') as f:
        f.write(_get_file_content(info))


def _read_single_upload(info, tmpdir):
    """Write one upload into ``tmpdir`` and read it with ``sitk.ReadImage``."""
    name = _safe_upload_name(info)
    tmp_path = os.path.join(tmpdir, name)
    lower = name.lower()
    # splitext("scan.nii.gz") would report only ".gz".
    ext = '.nii.gz' if lower.endswith('.nii.gz') else os.path.splitext(lower)[1]
    try:
        # The full name is kept on disk because SimpleITK picks the reader by extension.
        _write_upload(info, tmp_path)
        img = sitk.ReadImage(tmp_path)
    except Exception as e:
        return None, f"Failed to read {name}: {e}"
    return img, f"Loaded {name} ({ext or 'unknown'})"


def _read_dicom_upload(file_list, tmpdir):
    """Write all uploads into ``tmpdir`` and read the first DICOM series found there."""
    try:
        for i, info in enumerate(file_list):
            # Index prefix stops uploads that share a name from overwriting each other.
            _write_upload(info, os.path.join(tmpdir, f"{i:05d}_{_safe_upload_name(info)}"))
    except Exception as e:
        return None, f"Failed to store uploaded files: {e}"

    reader = sitk.ImageSeriesReader()
    series_ids = reader.GetGDCMSeriesIDs(tmpdir)
    if not series_ids:
        return None, "No DICOM series found among uploaded files."

    series_file_names = reader.GetGDCMSeriesFileNames(tmpdir, series_ids[0])
    reader.SetFileNames(series_file_names)
    try:
        img = reader.Execute()
    except Exception as e:
        return None, f"Failed to read DICOM series: {e}"
    return img, f"Loaded DICOM series with {len(series_file_names)} slices."


def prepare_volume(img_sitk):
    """
    Convert a SimpleITK image to a (Z, Y, X) float32 volume normalized to [0, 1].

    - Multi-component pixels (e.g. an RGB PNG/JPEG) are first averaged into a
      single channel. ``GetArrayFromImage`` puts components on the last axis,
      so without this a 2D RGB image would be read as a 3D volume with X == 3.
    - 2D images become a single-slice volume; 4D arrays (t, Z, Y, X) keep t=0.
    - Intensities are windowed to the 2nd..98th percentile of the finite
      voxels and clipped to [0, 1]; NaN voxels become 0.

    Raises:
        ValueError: if the (scalar) array is not 2D, 3D or 4D.
    """
    arr = sitk.GetArrayFromImage(img_sitk)  # (Z,Y,X) or (Y,X), plus a trailing component axis for vector images
    if img_sitk.GetNumberOfComponentsPerPixel() > 1:
        arr = arr.astype(np.float32).mean(axis=-1)
    return _normalize_intensity(_to_zyx(arr))


def _to_zyx(arr):
    """Reshape a scalar 2D/3D/4D array to (Z, Y, X); 4D keeps only index 0 of axis 0."""
    if arr.ndim == 2:
        return arr[None, ...]  # single-slice volume (1, Y, X)
    if arr.ndim == 3:
        return arr
    if arr.ndim == 4:
        return arr[0]  # e.g. (t, Z, Y, X) -> t = 0
    raise ValueError(f"Unsupported dimension: {arr.ndim}")


def _normalize_intensity(arr, low_pct=2.0, high_pct=98.0):
    """Percentile-window ``arr`` to float32 in [0, 1], ignoring non-finite voxels."""
    arr = arr.astype(np.float32)
    finite = np.isfinite(arr)
    if not finite.any():
        return np.zeros_like(arr)
    # Plain Python floats keep the arithmetic in float32; NumPy >= 2 would
    # otherwise promote float32 - float64-scalar to float64.
    lo, hi = (float(v) for v in np.percentile(arr[finite], [low_pct, high_pct]))
    if hi <= lo:
        hi = lo + 1.0
    norm = np.clip((arr - lo) / (hi - lo), 0.0, 1.0)  # +/-inf clip to 1/0
    return np.nan_to_num(norm, nan=0.0).astype(np.float32, copy=False)


def init_mask():
    """(Re)allocate the global label mask to match the loaded volume.

    ``mask3d`` becomes an all-zero uint8 array laid out like ``volume3d``, or
    ``None`` when no volume is loaded.
    """
    global mask3d
    mask3d = None if volume3d is None else np.zeros_like(volume3d, dtype=np.uint8)


def slice_to_rgba(slice2d_float):
    """
    Render a 2D float image in [0, 1] as an opaque grayscale RGBA array.

    The slice is quantized to 8 bits, bilinearly resized to
    (VIEW_SIZE, VIEW_SIZE), and returned as a uint8 array of shape
    (VIEW_SIZE, VIEW_SIZE, 4) with alpha fixed at 255.
    """
    gray8 = (slice2d_float * 255).astype(np.uint8)
    resized = Image.fromarray(gray8).resize((VIEW_SIZE, VIEW_SIZE), resample=Image.BILINEAR)
    rgb = np.asarray(resized.convert('RGB'), dtype=np.uint8)
    opaque = np.full(rgb.shape[:2] + (1,), 255, dtype=np.uint8)
    return np.concatenate([rgb, opaque], axis=-1)


def mask_slice_to_rgba(mask_slice):
    """
    Build a translucent RGBA overlay of shape (VIEW_SIZE, VIEW_SIZE, 4) from a 2D mask.

    Nonzero mask pixels take the current ``overlay_color_hex`` colour with
    alpha 120; every other pixel is fully transparent. The slice is resized
    with nearest-neighbour sampling so labels stay crisp. ``None`` yields a
    fully transparent overlay.
    """
    overlay = np.zeros((VIEW_SIZE, VIEW_SIZE, 4), dtype=np.uint8)
    if mask_slice is not None:
        scaled = Image.fromarray((mask_slice * 255).astype(np.uint8))
        scaled = scaled.resize((VIEW_SIZE, VIEW_SIZE), resample=Image.NEAREST)
        inside = np.array(scaled) > 0
        # RGB from the picker, alpha 120 keeps the underlying image visible.
        overlay[inside] = (*hex_to_rgb(overlay_color_hex), 120)
    return overlay


# =========================
# Drawing functions
# =========================

def draw_axial():
    """Repaint the axial view: image layer, optional mask overlay, then prompts."""
    for layer in (axial_bg, axial_ol, axial_ui):
        layer.clear()
    if volume3d is None:
        return

    axial_bg.put_image_data(slice_to_rgba(volume3d[axial_index]), 0, 0)  # (Y, X) plane

    if show_mask_checkbox.value and mask3d is not None:
        axial_ol.put_image_data(mask_slice_to_rgba(mask3d[axial_index]), 0, 0)

    draw_axial_ui()


def draw_axial_ui():
    """Redraw the prompt layer of the axial view for the current slice.

    Only prompts whose ``z`` equals ``axial_index`` are drawn: boxes as yellow
    2 px outlines, points as filled cyan dots of radius 4 px.
    """
    axial_ui.clear()
    axial_ui.stroke_style = 'yellow'
    axial_ui.line_width = 2

    for box in (b for b in box_prompts if b['z'] == axial_index):
        # Corners may have been dragged in any direction; normalise them.
        left, right = sorted((box['x0'], box['x1']))
        top, bottom = sorted((box['y0'], box['y1']))
        axial_ui.stroke_rect(left, top, right - left, bottom - top)

    axial_ui.fill_style = 'cyan'
    for pt in (q for q in point_prompts if q['z'] == axial_index):
        axial_ui.begin_path()
        axial_ui.arc(pt['x'], pt['y'], 4, 0, 2 * math.pi)
        axial_ui.fill()


def draw_coronal():
    """Repaint the coronal view at Y = ``coronal_index``.

    The (Z, X) plane is drawn with Z on the vertical canvas axis (array row 0
    at the top) and stretched to fill the square canvas.
    """
    for layer in (coronal_bg, coronal_ol, coronal_ui):
        layer.clear()
    if volume3d is None:
        return

    plane = volume3d[:, coronal_index, :]
    coronal_bg.put_image_data(slice_to_rgba(plane), 0, 0)

    if not show_mask_checkbox.value or mask3d is None:
        return
    coronal_ol.put_image_data(mask_slice_to_rgba(mask3d[:, coronal_index, :]), 0, 0)


def draw_sagittal():
    """Repaint the sagittal view at X = ``sagittal_index``.

    The (Z, Y) plane is drawn with Z on the vertical canvas axis (array row 0
    at the top) and stretched to fill the square canvas.
    """
    for layer in (sagittal_bg, sagittal_ol, sagittal_ui):
        layer.clear()
    if volume3d is None:
        return

    plane = volume3d[:, :, sagittal_index]
    sagittal_bg.put_image_data(slice_to_rgba(plane), 0, 0)

    if not show_mask_checkbox.value or mask3d is None:
        return
    sagittal_ol.put_image_data(mask_slice_to_rgba(mask3d[:, :, sagittal_index]), 0, 0)


def draw_mip():
    """Repaint the "3D" panel as a maximum-intensity projection along Z.

    The image layer shows the per-(Y, X) maximum over all slices; the overlay
    marks every (Y, X) column that contains at least one mask voxel.
    """
    for layer in (mip_bg, mip_ol):
        layer.clear()
    if volume3d is None:
        return

    mip_bg.put_image_data(slice_to_rgba(np.max(volume3d, axis=0)), 0, 0)

    if show_mask_checkbox.value and mask3d is not None:
        footprint = np.any(mask3d > 0, axis=0).astype(np.uint8)
        mip_ol.put_image_data(mask_slice_to_rgba(footprint), 0, 0)


def redraw_all_views():
    """Repaint every panel in order: axial, coronal, sagittal, then the MIP."""
    for render in (draw_axial, draw_coronal, draw_sagittal, draw_mip):
        render()


# =========================
# Painting on axial view
# =========================

def canvas_to_voxel_axial(cx, cy):
    """
    Convert axial-canvas pixel coordinates to voxel (x, y) indices.

    The canvas spans the whole slice, so each coordinate is scaled by
    dim / VIEW_SIZE, truncated to an int and clamped to [0, dim - 1].
    Returns (0, 0) when either in-plane dimension is 0 (e.g. before a volume
    has been loaded).
    """
    if 0 in (dim_x, dim_y):
        return 0, 0

    def _to_index(coord, dim):
        return np.clip(int(coord * dim / VIEW_SIZE), 0, dim - 1)

    return _to_index(cx, dim_x), _to_index(cy, dim_y)


def paint_axial_circle(cx, cy, redraw=True):
    """
    Paint or erase a filled disc on ``mask3d`` at the current axial slice.

    Args:
        cx, cy: Disc centre in axial-canvas pixel coordinates.
        redraw: Repaint all views afterwards (default). Callers that stamp
            many discs in a row, such as ``paint_axial_line``, pass False and
            redraw once at the end.

    The brush slider value is a radius in canvas pixels. It is converted to
    voxels with the mean of the X and Y canvas->voxel scale factors.
    ``mode_selector`` decides whether covered voxels become 1 (Paint) or 0.
    """
    if mask3d is None:
        return

    vx_c, vy_c = canvas_to_voxel_axial(cx, cy)
    z = axial_index

    # Approximate radius in voxels (X and Y may be scaled differently).
    scale = 0.5 * (dim_x / VIEW_SIZE + dim_y / VIEW_SIZE)
    r_vox = max(1, int(brush_size_slider.value * scale))

    # Bounding box of the disc, clipped to the slice.
    y0, y1 = max(0, vy_c - r_vox), min(dim_y, vy_c + r_vox + 1)
    x0, x1 = max(0, vx_c - r_vox), min(dim_x, vx_c + r_vox + 1)

    yy, xx = np.ogrid[y0:y1, x0:x1]
    region = (xx - vx_c) ** 2 + (yy - vy_c) ** 2 <= r_vox * r_vox

    # In-place write into the global array; no rebinding, so no ``global`` needed.
    mask3d[z, y0:y1, x0:x1][region] = 1 if mode_selector.value == 'Paint' else 0

    if redraw:
        redraw_all_views()


def paint_axial_line(cx0, cy0, cx1, cy1):
    """
    Stamp brush discs along the canvas segment (cx0, cy0) -> (cx1, cy1).

    Discs are spaced at most half a brush radius apart so the stroke has no
    gaps. All views are redrawn once after the whole segment, not once per
    disc: each redraw re-renders four canvases including a full-volume MIP.
    """
    if mask3d is None:
        return

    spacing = max(brush_size_slider.value / 2, 1)
    steps = int(max(abs(cx1 - cx0), abs(cy1 - cy0)) / spacing) + 1
    for t in np.linspace(0, 1, steps):
        paint_axial_circle(cx0 + t * (cx1 - cx0), cy0 + t * (cy1 - cy0), redraw=False)
    redraw_all_views()


# =========================
# Event handlers
# =========================

def on_upload_change(change):
    """
    Load the uploaded image, reset mask/prompts/sliders, and redraw all views.

    Registered as an observer on ``uploader.value``. If reading or preparing
    the image fails, the error is shown in the status label and the currently
    loaded volume is left untouched.
    """
    global volume3d, mask3d, dim_z, dim_y, dim_x
    global axial_index, coronal_index, sagittal_index
    global box_prompts, point_prompts
    global drawing, last_canvas_x, last_canvas_y

    if not uploader.value:
        return

    img_sitk, msg = sitk_from_upload(uploader)
    if img_sitk is None:
        log_status("Error: " + msg)
        return

    try:
        volume = prepare_volume(img_sitk)
    except Exception as e:
        log_status(f"Error preparing volume: {e}")
        return

    volume3d = volume
    dim_z, dim_y, dim_x = volume3d.shape
    init_mask()

    # Drop prompts and any half-finished drag from the previous volume *before*
    # the slider updates below trigger redraws.
    box_prompts = []
    point_prompts = []
    drawing = False
    last_canvas_x = last_canvas_y = None

    # Start every view at its mid-slice.
    mids = (dim_z // 2, dim_y // 2, dim_x // 2)
    axial_index, coronal_index, sagittal_index = mids

    sliders = (axial_slider, coronal_slider, sagittal_slider)
    for slider, dim, mid in zip(sliders, (dim_z, dim_y, dim_x), mids):
        slider.min = 0
        # Shrinking ``max`` can clamp the slider's value and fire its observer,
        # which overwrites the global index. So set the value from the local
        # ``mid``, never from the (possibly clobbered) global.
        slider.max = max(dim - 1, 0)
        slider.value = mid
        slider.disabled = False

    # Re-assert the mid-slices in case a clamp-triggered observer changed them
    # and the final value assignment was a no-op (no observer call).
    axial_index, coronal_index, sagittal_index = mids

    redraw_all_views()
    log_status(msg + f" Volume shape: (Z={dim_z}, Y={dim_y}, X={dim_x}).")


uploader.observe(on_upload_change, names='value')


def on_color_change(change):
    """Adopt the newly picked overlay colour; repaint only if a volume is loaded."""
    global overlay_color_hex
    overlay_color_hex = change['new']
    if volume3d is None:
        return
    redraw_all_views()


color_picker.observe(on_color_change, names='value')


def on_show_mask_change(change):
    """Show/hide mask overlays by repainting every view (no-op before a volume loads)."""
    if volume3d is None:
        return
    redraw_all_views()


show_mask_checkbox.observe(on_show_mask_change, names='value')


def on_axial_slider_change(change):
    """Track the axial slider: store the new Z index and repaint the axial view."""
    global axial_index
    new_z = change['new']
    axial_index = new_z
    draw_axial()


axial_slider.observe(on_axial_slider_change, names='value')


def on_coronal_slider_change(change):
    """Track the coronal slider: store the new Y index and repaint the coronal view."""
    global coronal_index
    new_y = change['new']
    coronal_index = new_y
    draw_coronal()


coronal_slider.observe(on_coronal_slider_change, names='value')


def on_sagittal_slider_change(change):
    """Track the sagittal slider: store the new X index and repaint the sagittal view."""
    global sagittal_index
    new_x = change['new']
    sagittal_index = new_x
    draw_sagittal()


sagittal_slider.observe(on_sagittal_slider_change, names='value')


# True while a Box-tool drag is in progress on the axial canvas
# (brush strokes use the module-level ``drawing`` flag instead).
_box_dragging = False


def on_axial_mouse_down(x, y):
    """
    Start an interaction on the axial canvas at canvas position (x, y).

    - Brush: begin a stroke and stamp the first disc.
    - Box:   begin dragging a new box anchored at (x, y) on the current slice.
    - Point: drop a point prompt on the current slice.
    """
    global drawing, last_canvas_x, last_canvas_y, _box_dragging
    if volume3d is None:
        return

    tool = tool_selector.value

    if tool == 'Brush':
        drawing = True
        last_canvas_x, last_canvas_y = x, y
        paint_axial_circle(x, y)

    elif tool == 'Box':
        # start a new box (canvas coords, linked to current slice)
        _box_dragging = True
        box_prompts.append({'x0': x, 'y0': y, 'x1': x, 'y1': y, 'z': axial_index})
        draw_axial_ui()
        log_status(f"Box start at ({int(x)}, {int(y)}) on Z={axial_index}")

    elif tool == 'Point':
        point_prompts.append({'x': x, 'y': y, 'z': axial_index})
        draw_axial_ui()
        log_status(f"Added point at ({int(x)}, {int(y)}) on Z={axial_index}")


def on_axial_mouse_move(x, y):
    """
    Continue the active drag on the axial canvas.

    Acts only while a drag started by ``on_axial_mouse_down`` is in progress:
    the brush paints from the previous position to (x, y), and the box tool
    moves the free corner of the box being drawn. Plain hovering (no button
    held) changes nothing.
    """
    global last_canvas_x, last_canvas_y
    if volume3d is None:
        return

    tool = tool_selector.value

    if tool == 'Brush' and drawing and last_canvas_x is not None:
        paint_axial_line(last_canvas_x, last_canvas_y, x, y)
        last_canvas_x, last_canvas_y = x, y

    elif tool == 'Box' and _box_dragging and box_prompts:
        # Update last box x1,y1
        box_prompts[-1]['x1'] = x
        box_prompts[-1]['y1'] = y
        draw_axial_ui()


def _end_axial_drag():
    """Clear both drag flags so later hovering neither paints nor resizes boxes."""
    global drawing, _box_dragging
    drawing = False
    _box_dragging = False


def on_axial_mouse_up(x, y):
    """
    Finish the active drag on the axial canvas.

    Always ends any brush stroke or box drag. If a box was being dragged, its
    free corner is fixed at (x, y). A mouse-up with no preceding mouse-down on
    this canvas leaves existing boxes untouched.
    """
    was_box_drag = _box_dragging
    _end_axial_drag()
    if volume3d is None:
        return

    if was_box_drag and box_prompts:
        box = box_prompts[-1]
        box['x1'] = x
        box['y1'] = y
        draw_axial_ui()
        log_status(
            f"Box finished: ({int(box['x0'])}, {int(box['y0'])}) -> "
            f"({int(x)}, {int(y)}) on Z={axial_index}"
        )


def on_axial_mouse_out(*_args):
    """End any drag when the pointer leaves the canvas.

    A button released outside the canvas produces no mouse-up here. Without
    this, the next hover would keep painting or resizing the last box.
    """
    _end_axial_drag()


axial_canv.on_mouse_down(on_axial_mouse_down)
axial_canv.on_mouse_move(on_axial_mouse_move)
axial_canv.on_mouse_up(on_axial_mouse_up)
# Register mouse-out only if this ipycanvas object provides it.
if hasattr(axial_canv, 'on_mouse_out'):
    axial_canv.on_mouse_out(on_axial_mouse_out)


def on_clear_clicked(b):
    """Reset the mask and all box/point prompts, then repaint (requires a loaded volume)."""
    global box_prompts, point_prompts
    if volume3d is None:
        return
    init_mask()  # rebinds the global mask3d itself
    box_prompts, point_prompts = [], []
    redraw_all_views()
    log_status("Cleared mask and prompts.")


clear_button.on_click(on_clear_clicked)


def on_save_clicked(b):
    """
    Save the mask (NIfTI) and prompts (JSON) to timestamped files in the CWD.

    The mask is written in array index space only. No spacing, origin or
    direction from the source image is attached. Prompt coordinates are in
    axial-canvas pixels. ``view_size`` is stored alongside them so they can be
    mapped to voxels (voxel = canvas * dim / view_size). Write failures are
    reported in the status label rather than raised inside the widget
    callback, where they are easy to miss.
    """
    if volume3d is None or mask3d is None:
        log_status("Nothing to save (no volume/mask).")
        return

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    mask_filename = f"mask3d_{timestamp}.nii.gz"
    prompts_filename = f"prompts_{timestamp}.json"

    prompts = {
        "box_prompts": box_prompts,
        "point_prompts": point_prompts,
        "shape": {"Z": dim_z, "Y": dim_y, "X": dim_x},
        "view_size": VIEW_SIZE,
        "note": "Box/point coords are in axial canvas space (0..VIEW_SIZE-1), with slice index z."
    }

    try:
        # Save 3D mask as NIfTI (in index space; you can add metadata later)
        mask_sitk = sitk.GetImageFromArray(mask3d.astype(np.uint8))
        sitk.WriteImage(mask_sitk, mask_filename)
        with open(prompts_filename, "w") as f:
            json.dump(prompts, f, indent=2)
    except Exception as e:
        log_status(f"Save failed: {e}")
        return

    log_status(f"Saved mask to {mask_filename}, prompts to {prompts_filename}.")


save_button.on_click(on_save_clicked)


# =========================
# Layout
# =========================
def _titled(title, widget):
    """Stack a text label above ``widget``."""
    return VBox([Label(title), widget])


controls_top = HBox([uploader])
controls_tools = HBox([
    tool_selector, mode_selector, brush_size_slider, color_picker, show_mask_checkbox,
])
controls_slices = HBox([axial_slider, coronal_slider, sagittal_slider])
controls_actions = HBox([clear_button, save_button])

# 2x2 grid of views: axial | coronal on top, sagittal | MIP below.
views_row1 = HBox([_titled("Axial", axial_canv), _titled("Coronal", coronal_canv)])
views_row2 = HBox([_titled("Sagittal", sagittal_canv), _titled("3D MIP (Z-max)", mip_canv)])

ui = VBox([
    status_label,
    controls_top, controls_tools, controls_slices, controls_actions,
    views_row1, views_row2,
])

ui


VBox(children=(Label(value='Upload a 2D/3D MRI (NIfTI, DICOM, etc.) to start.'), HBox(children=(FileUpload(val…

In [13]:
from ipycanvas import Canvas

# Scratch check that a bare 200x200 ipycanvas Canvas renders in this frontend.
Canvas(height=200, width=200)

Canvas(height=200, width=200)