In [None]:
import os
import math

import numpy as np
import SimpleITK as sitk
from PIL import Image

from ipycanvas import MultiCanvas
from ipywidgets import (
    VBox, HBox, Button, ColorPicker, IntSlider,
    ToggleButtons, Checkbox, Label, Text, Dropdown
)
from ipyfilechooser import FileChooser

#############################################################
# Global state
# All interactive state lives in module-level globals; the widget callbacks
# below read them and rebind them via ``global``.
VIEW_SIZE = 256  # canvas resolution
# (every slice / MIP is resized to VIEW_SIZE x VIEW_SIZE for display)

volume3d = None      # float32 (Z, Y, X)
# ^ intensity-normalised to [0, 1] by prepare_volume
mask3d = None        # uint8  (Z, Y, X)
# ^ same shape as volume3d; 1 = painted voxel, 0 = background

# Shape of volume3d, set by load_from_path.
dim_z = dim_y = dim_x = 0

# Slice currently shown along each axis (kept in sync with the sliders).
axial_index = 0
coronal_index = 0
sagittal_index = 0

# Active view: "axial", "coronal", "sagittal" or "mip".
current_view = "axial"

# Brush-stroke state: True while the mouse button is held with the Brush tool,
# plus the previous mouse position of the stroke (canvas pixels).
drawing = False
last_canvas_x = None
last_canvas_y = None

# True while a box prompt is being dragged out.
box_drawing = False

# Prompt records, in canvas-pixel coordinates:
#   box_prompts:   {'view', 'slice', 'x0', 'y0', 'x1', 'y1'}
#   point_prompts: {'view', 'slice', 'x', 'y'}
box_prompts = []
point_prompts = []

# Colour used for the mask overlay (read by mask_slice_to_rgba).
overlay_color_hex = "#ff0000"

# For preserving metadata and save defaults
image_sitk = None          # last loaded SimpleITK image
current_image_path = None  # path of last loaded image
last_save_dir = None       # last folder used to save mask

############################################################
# Canvas (3 layers: background, mask overlay, UI)
# Each layer is cleared and repainted independently by the draw_* helpers.

main_canv = MultiCanvas(3, width=VIEW_SIZE, height=VIEW_SIZE)
main_bg = main_canv[0]  # grayscale slice (slice_to_rgba)
main_ol = main_canv[1]  # semi-transparent mask overlay (mask_slice_to_rgba)
main_ui = main_canv[2]  # box / point prompts (_draw_ui)

############################################################
# Widgets

# Single-line status area; updated through log_status().
status_label = Label(
    value="Choose a file → Load from path → View appears. Default view: Axial mid-slice."
)

# --- File chooser for loading MRI ---
fc = FileChooser(
    os.getcwd(),
    select_default=True,
    show_only_dirs=False
)
fc.title = "<b>Choose MRI file</b>"
fc.filter_pattern = [
    "*.nii", "*.nii.gz",
    "*.mha", "*.mhd",
    "*.nrrd",
    "*.dcm",
    "*.png", "*.jpg", "*.jpeg",
    "*.gz"
]

# Reads fc.selected and loads it (see load_from_path).
load_path_button = Button(
    description='Load from path',
    button_style=''
)

# Brush edits mask voxels; Box and Point only record prompts
# (box_prompts / point_prompts) and never touch the mask.
tool_selector = ToggleButtons(
    options=['Brush', 'Box', 'Point'],
    value='Brush',
    description='Tool:'
)

# Whether the brush sets voxels to 1 (Paint) or 0 (Erase).
mode_selector = ToggleButtons(
    options=['Paint', 'Erase'],
    value='Paint',
    description='Mode:'
)

# Brush radius in canvas pixels; the painting code rescales it to voxels.
brush_size_slider = IntSlider(
    value=1, min=1, max=40, step=1,
    description='Brush size:'
)

# Colour of the mask overlay.
color_picker = ColorPicker(
    concise=False,
    description='Mask color:',
    value=overlay_color_hex
)

show_mask_checkbox = Checkbox(
    value=True,
    description='Show mask overlay'
)

# Slice sliders start disabled with range 0..0; load_from_path sets their
# range to the loaded volume's shape and enables them.
axial_slider = IntSlider(
    value=0, min=0, max=0, step=1,
    description='Axial (Z):',
    disabled=True
)
coronal_slider = IntSlider(
    value=0, min=0, max=0, step=1,
    description='Coronal (Y):',
    disabled=True
)
sagittal_slider = IntSlider(
    value=0, min=0, max=0, step=1,
    description='Sagittal (X):',
    disabled=True
)

clear_prompt_button = Button(
    description='Clear prompts',
    button_style=''
)

clear_mask_button = Button(
    description='Clear mask',
    button_style='warning'
)

save_mask_button = Button(
    description='Save mask',
    button_style='success'
)

# View switching buttons (under canvas)
axial_view_button = Button(description="Axial")
coronal_view_button = Button(description="Coronal")
sagittal_view_button = Button(description="Sagittal")
mip_view_button = Button(description="3D MIP")

# --- Save mask controls ---
save_dir_chooser = FileChooser(
    os.getcwd(),
    select_default=True,
    show_only_dirs=True
)
save_dir_chooser.title = "<b>Choose folder to save mask</b>"

# File stem for the saved mask; save_mask falls back to "mask3d" when blank.
save_filename_text = Text(
    value='',
    description='Name:',
    placeholder='mask3d'
)

# Extension appended to the stem when saving.
save_format_dropdown = Dropdown(
    options=['.nii.gz', '.mhd', '.mha', '.nrrd', '.dcm'],
    value='.nii.gz',
    description='Format:'
)


############################################################
# Helper functions
############################################################

def log_status(msg):
    """Show *msg* in the status line at the top of the tool."""
    label = status_label
    label.value = msg

def hex_to_rgb(hex_color):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0..255.

    Accepts ``#rrggbb`` / ``rrggbb``, the CSS shorthand ``#rgb``, and the
    alpha forms ``#rrggbbaa`` / ``#rgba`` (the alpha digits are ignored).

    Raises:
        ValueError: if the string is not one of those forms or contains
            non-hex digits.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) in (3, 4):
        # Shorthand: each digit is doubled ("0f8" -> "00ff88").
        digits = ''.join(ch * 2 for ch in digits[:3])
    elif len(digits) in (6, 8):
        digits = digits[:6]
    else:
        raise ValueError(f"Not a hex colour: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def prepare_volume(img_sitk_local):
    """Convert a SimpleITK image into a float32 (Z, Y, X) array scaled to [0, 1].

    Steps:
    * Multi-component pixels (RGB/RGBA photos, vector images) are collapsed
      to one channel: the mean of the first three components, or the first
      component when there are fewer than three. ``GetArrayFromImage`` puts
      components on a trailing axis. Without this step a 2-D RGB PNG
      (shape (Y, X, C)) would be mistaken for a 3-D volume with C sagittal
      slices.
    * 2-D arrays gain a leading Z axis of length 1; 4-D arrays keep only the
      first entry along their leading axis.
    * Non-finite voxels (NaN/inf) are replaced by the smallest finite value,
      so they cannot poison the percentile window.
    * Intensities are windowed to the 2nd..98th percentile and scaled to [0, 1].

    Raises:
        ValueError: if the (channel-collapsed) array is not 2-, 3- or 4-D.
    """
    arr = sitk.GetArrayFromImage(img_sitk_local)

    if img_sitk_local.GetNumberOfComponentsPerPixel() > 1:
        n_comp = arr.shape[-1]
        channels = arr[..., :3] if n_comp >= 3 else arr[..., :1]
        arr = channels.astype(np.float32).mean(axis=-1)

    if arr.ndim == 2:
        arr = arr[None, ...]
    elif arr.ndim == 4:
        arr = arr[0]
    elif arr.ndim != 3:
        raise ValueError(f"Unsupported dimensions: {arr.shape}")

    arr = arr.astype(np.float32)

    finite = np.isfinite(arr)
    if not finite.all():
        fill = float(arr[finite].min()) if finite.any() else 0.0
        arr = np.where(finite, arr, np.float32(fill))

    # Python floats (not NumPy float64 scalars) keep the arithmetic in float32
    # under NumPy 2's promotion rules.
    lo, hi = (float(v) for v in np.percentile(arr, [2, 98]))
    hi = max(hi, lo + 1e-6)  # avoid division by zero on constant images
    arr = (arr - lo) / (hi - lo)
    return np.clip(arr, 0, 1)


def init_mask():
    """(Re)create ``mask3d`` as an all-zero uint8 array shaped like ``volume3d``.

    With no volume loaded the mask is set to ``None``.
    """
    global mask3d
    mask3d = None if volume3d is None else np.zeros_like(volume3d, dtype=np.uint8)


def slice_to_rgba(slice2d_float):
    """Render a [0, 1] float slice as an opaque VIEW_SIZE x VIEW_SIZE RGBA image.

    The slice is quantised to 8 bits, bilinearly resized to the canvas size
    (ignoring aspect ratio), replicated into R, G and B, and given alpha 255.
    """
    gray8 = (slice2d_float * 255).astype(np.uint8)
    resized = Image.fromarray(gray8).resize((VIEW_SIZE, VIEW_SIZE), Image.BILINEAR)
    rgb = np.array(resized.convert("RGB"), dtype=np.uint8)
    alpha = np.full(rgb.shape[:2] + (1,), 255, dtype=np.uint8)
    return np.concatenate([rgb, alpha], axis=-1)


def mask_slice_to_rgba(mask_slice):
    """Render a 2-D mask as a VIEW_SIZE x VIEW_SIZE RGBA overlay.

    Non-zero mask pixels take the current ``overlay_color_hex`` with alpha 120
    (semi-transparent). All other pixels stay fully transparent. Resizing is
    nearest-neighbour so mask edges stay crisp. ``None`` yields an empty overlay.
    """
    rgba = np.zeros((VIEW_SIZE, VIEW_SIZE, 4), dtype=np.uint8)
    if mask_slice is not None:
        as_image = Image.fromarray((mask_slice * 255).astype(np.uint8))
        hit = np.array(as_image.resize((VIEW_SIZE, VIEW_SIZE), Image.NEAREST)) > 0
        red, green, blue = hex_to_rgb(overlay_color_hex)
        rgba[hit] = (red, green, blue, 120)
    return rgba


def get_current_slice():
    """Return the slice index of the active view.

    The MIP view (or any unrecognised view name) reports the axial index.
    """
    if current_view == "coronal":
        return coronal_index
    if current_view == "sagittal":
        return sagittal_index
    return axial_index


def split_name_and_ext(path):
    """
    Given a path, return (stem, ext), handling .nii.gz specially.

    Matching is case-insensitive; the returned ``ext`` is the lower-case
    known extension, and the stem keeps its original case. The 2-D image
    formats offered by the file chooser (.png/.jpg/.jpeg) are recognised too.

    e.g. /path/brain.nii.gz -> ('brain', '.nii.gz')
         /path/brain.mhd    -> ('brain', '.mhd')
         /path/photo.PNG    -> ('photo', '.png')
         /path/notes        -> ('notes', '')
    """
    base = os.path.basename(path)
    lowered = base.lower()
    # '.nii.gz' must be tried before '.gz'.
    for ext in ('.nii.gz', '.nii', '.mhd', '.mha', '.nrrd', '.dcm',
                '.png', '.jpg', '.jpeg', '.gz'):
        if lowered.endswith(ext):
            return base[:-len(ext)], ext
    # fallback: no known extension
    return base, ''


############################################################
# Drawing slices
############################################################

def _draw_ui(view_name, slice_idx):
    """Repaint the prompt layer with the prompts recorded on one slice of one view.

    Boxes are stroked in yellow (2 px wide). Points are filled cyan dots of
    radius 4. Everything is in canvas-pixel coordinates. With no volume
    loaded the layer is just cleared.
    """
    main_ui.clear()
    if volume3d is None:
        return

    def on_this_slice(prompt):
        return prompt['view'] == view_name and prompt['slice'] == slice_idx

    main_ui.stroke_style = 'yellow'
    main_ui.line_width = 2
    for box in filter(on_this_slice, box_prompts):
        left = min(box['x0'], box['x1'])
        top = min(box['y0'], box['y1'])
        width = abs(box['x1'] - box['x0'])
        height = abs(box['y1'] - box['y0'])
        main_ui.stroke_rect(left, top, width, height)

    main_ui.fill_style = 'cyan'
    for point in filter(on_this_slice, point_prompts):
        main_ui.begin_path()
        main_ui.arc(point['x'], point['y'], 4, 0, 2 * math.pi)
        main_ui.fill()


def _clear_layers():
    """Erase the background, overlay and prompt layers of the canvas."""
    main_bg.clear()
    main_ol.clear()
    main_ui.clear()


def _overlay_enabled():
    """True when a mask exists and the "Show mask overlay" box is ticked."""
    return show_mask_checkbox.value and mask3d is not None


def _render(image2d, mask2d):
    """Put *image2d* on the background layer and, unless None, *mask2d* on the overlay."""
    main_bg.put_image_data(slice_to_rgba(image2d), 0, 0)
    if mask2d is not None:
        main_ol.put_image_data(mask_slice_to_rgba(mask2d), 0, 0)


def draw_axial():
    """Show axial slice ``axial_index`` (a Y x X plane) with its mask and prompts."""
    _clear_layers()
    if volume3d is None:
        return
    _render(volume3d[axial_index],
            mask3d[axial_index] if _overlay_enabled() else None)
    _draw_ui("axial", axial_index)


def draw_coronal():
    """Show coronal slice ``coronal_index`` (a Z x X plane) with its mask and prompts."""
    _clear_layers()
    if volume3d is None:
        return
    _render(volume3d[:, coronal_index, :],
            mask3d[:, coronal_index, :] if _overlay_enabled() else None)
    _draw_ui("coronal", coronal_index)


def draw_sagittal():
    """Show sagittal slice ``sagittal_index`` (a Z x Y plane) with its mask and prompts."""
    _clear_layers()
    if volume3d is None:
        return
    _render(volume3d[:, :, sagittal_index],
            mask3d[:, :, sagittal_index] if _overlay_enabled() else None)
    _draw_ui("sagittal", sagittal_index)


def draw_mip():
    """Show the maximum-intensity projection along Z (no prompts are drawn)."""
    _clear_layers()
    if volume3d is None:
        return
    # Only compute the (full-volume) mask projection when it will be shown.
    mip_mask = (mask3d > 0).max(axis=0) if _overlay_enabled() else None
    _render(volume3d.max(axis=0), mip_mask)


def redraw():
    """Repaint the active view; any unrecognised view name falls back to the MIP."""
    if current_view == "axial":
        draw_axial()
    elif current_view == "coronal":
        draw_coronal()
    elif current_view == "sagittal":
        draw_sagittal()
    else:
        draw_mip()


############################################################
# Painting
############################################################

def _canvas_to_index(c, n):
    """Scale canvas coordinate *c* onto an axis of *n* voxels, clamped to [0, n-1]."""
    return np.clip(int(c * n / VIEW_SIZE), 0, n - 1)


def canvas_to_voxel_axial(cx,cy):
    """Axial view: canvas x -> X, canvas y -> Y. Returns ``(vx, vy)``."""
    vx = _canvas_to_index(cx, dim_x)
    vy = _canvas_to_index(cy, dim_y)
    return vx, vy


def canvas_to_voxel_coronal(cx,cy):
    """Coronal view: canvas x -> X, canvas y -> Z. Returns ``(vz, vx)``."""
    vz = _canvas_to_index(cy, dim_z)
    vx = _canvas_to_index(cx, dim_x)
    return vz, vx


def canvas_to_voxel_sagittal(cx,cy):
    """Sagittal view: canvas x -> Y, canvas y -> Z. Returns ``(vz, vy)``."""
    vz = _canvas_to_index(cy, dim_z)
    vy = _canvas_to_index(cx, dim_y)
    return vz, vy


def _stamp_disk(cx, cy):
    """Paint or erase one disk of mask voxels under canvas point (cx, cy), without redrawing.

    Acts on the slice shown by the current view. It does nothing in the MIP
    view or when there is no mask. The brush radius (canvas pixels, from
    ``brush_size_slider``) is converted to voxels using the mean
    canvas-to-voxel scale of the two in-plane axes, and is at least 1 voxel.
    """
    if mask3d is None:
        return

    # `plane` is a writable 2-D view into mask3d (basic slicing), so writes
    # through it land in the volume. (row, col) is the brush centre in it.
    if current_view == "axial":
        col, row = canvas_to_voxel_axial(cx, cy)
        plane = mask3d[axial_index]
    elif current_view == "coronal":
        row, col = canvas_to_voxel_coronal(cx, cy)
        plane = mask3d[:, coronal_index, :]
    elif current_view == "sagittal":
        row, col = canvas_to_voxel_sagittal(cx, cy)
        plane = mask3d[:, :, sagittal_index]
    else:
        return

    n_rows, n_cols = plane.shape
    scale = (n_cols / VIEW_SIZE + n_rows / VIEW_SIZE) / 2
    radius = max(1, int(brush_size_slider.value * scale))

    r0, r1 = max(0, row - radius), min(n_rows, row + radius + 1)
    c0, c1 = max(0, col - radius), min(n_cols, col + radius + 1)
    rr, cc = np.ogrid[r0:r1, c0:c1]
    disk = (rr - row) ** 2 + (cc - col) ** 2 <= radius * radius

    plane[r0:r1, c0:c1][disk] = 1 if mode_selector.value == "Paint" else 0


def paint_circle(cx,cy):
    """Paint or erase a single brush dab at canvas point (cx, cy), then repaint."""
    if mask3d is None:
        return
    _stamp_disk(cx, cy)
    redraw()


def paint_line(x0,y0,x1,y1):
    """Stamp brush dabs along the canvas segment (x0, y0) -> (x1, y1), then repaint once.

    Dabs are spaced about half a brush size apart (at least 1 canvas pixel),
    so fast mouse moves still leave a continuous stroke. The view is repainted
    once per segment rather than once per dab, because each repaint
    re-renders and re-sends every canvas layer.
    """
    if mask3d is None:
        return
    spacing = max(brush_size_slider.value / 2, 1)
    steps = int(max(abs(x1 - x0), abs(y1 - y0)) / spacing) + 1
    for t in np.linspace(0, 1, steps):
        _stamp_disk(x0 + t * (x1 - x0), y0 + t * (y1 - y0))
    redraw()


############################################################
# Mouse events
############################################################

def on_mouse_down(x, y):
    """Start a gesture at canvas pixel (x, y) with the selected tool.

    * Brush: begins a stroke and paints the first dab immediately.
    * Box:   appends a zero-size box prompt that mouse moves then stretch.
    * Point: appends a point prompt.
    Ignored while no volume is loaded or the MIP view is active.
    """
    global drawing, last_canvas_x, last_canvas_y, box_drawing

    if volume3d is None or current_view == "mip":
        return

    tool = tool_selector.value
    if tool == "Brush":
        drawing = True
        last_canvas_x, last_canvas_y = x, y
        paint_circle(x, y)
        return
    if tool not in ("Box", "Point"):
        return

    slice_idx = get_current_slice()
    if tool == "Box":
        box_prompts.append({
            'view': current_view,
            'slice': slice_idx,
            'x0': x, 'y0': y, 'x1': x, 'y1': y,
        })
        box_drawing = True
    else:
        point_prompts.append({
            'view': current_view,
            'slice': slice_idx,
            'x': x, 'y': y,
        })
    _draw_ui(current_view, slice_idx)


def on_mouse_move(x, y):
    """Continue the gesture in progress as the pointer moves to canvas pixel (x, y).

    With the Brush, paints from the previous position to (x, y). With the Box
    tool, drags the opposite corner of the newest box prompt. Ignored while no
    volume is loaded or the MIP view is active.
    """
    global last_canvas_x, last_canvas_y

    if volume3d is None or current_view == "mip":
        return

    active_tool = tool_selector.value
    if active_tool == "Brush":
        if drawing:
            paint_line(last_canvas_x, last_canvas_y, x, y)
            last_canvas_x, last_canvas_y = x, y
        return

    if active_tool == "Box" and box_drawing and box_prompts:
        newest = box_prompts[-1]
        newest['x1'] = x
        newest['y1'] = y
        _draw_ui(current_view, get_current_slice())


def on_mouse_up(x,y):
    """Finish the gesture in progress at canvas pixel (x, y).

    Both gesture flags are always cleared, whatever tool is selected now.
    Otherwise switching tools mid-drag could leave a brush stroke or box drag
    "stuck" on, painting or stretching on later mouse moves without a button
    held. A box drag in progress is finalised at (x, y).
    """
    global drawing, box_drawing
    finishing_box = box_drawing
    drawing = False
    box_drawing = False

    if volume3d is None or current_view == "mip":
        return

    if finishing_box and box_prompts:
        box_prompts[-1]['x1'] = x
        box_prompts[-1]['y1'] = y
        _draw_ui(current_view, get_current_slice())


main_canv.on_mouse_down(on_mouse_down)
main_canv.on_mouse_move(on_mouse_move)
main_canv.on_mouse_up(on_mouse_up)


############################################################
# Load button
############################################################

def _configure_slice_slider(slider, size, index):
    """Enable *slider* for an axis of *size* voxels and move it to *index*.

    Assigning ``value`` fires the slider's observer, which updates the
    matching ``*_index`` global (and redraws if that view is active).
    """
    slider.min = 0
    slider.max = size - 1
    slider.value = index
    slider.disabled = False


def load_from_path(_):
    """Load the file selected in ``fc`` and reset the editor for it.

    On success this replaces the volume and metadata, creates an empty mask,
    centres all three slice sliders, drops all prompts, switches to the axial
    view and fills in save defaults. On any failure the previous state is
    kept and the reason is shown in the status line.
    """
    global volume3d, mask3d
    global dim_z, dim_y, dim_x
    global axial_index, coronal_index, sagittal_index
    global box_prompts, point_prompts, drawing, box_drawing, current_view
    global image_sitk, current_image_path, last_save_dir

    path = fc.selected
    if not path:
        log_status("Please choose a file first.")
        return

    path = os.path.abspath(path)
    if not os.path.isfile(path):
        log_status(f"File not found: {path}")
        return

    try:
        img_local = sitk.ReadImage(path)
    except Exception as e:
        log_status(f"SimpleITK read failed: {e}")
        return

    try:
        vol = prepare_volume(img_local)
    except Exception as e:
        log_status(f"Normalization failed: {e}")
        return

    volume3d = vol
    image_sitk = img_local
    current_image_path = path

    dim_z, dim_y, dim_x = volume3d.shape
    init_mask()

    axial_index = dim_z // 2
    coronal_index = dim_y // 2
    sagittal_index = dim_x // 2

    # Reset prompts and view *before* touching the sliders: their observers
    # redraw immediately, and the old prompts belong to the previous image.
    box_prompts = []
    point_prompts = []
    drawing = False
    box_drawing = False
    current_view = "axial"

    _configure_slice_slider(axial_slider, dim_z, axial_index)
    _configure_slice_slider(coronal_slider, dim_y, coronal_index)
    _configure_slice_slider(sagittal_slider, dim_x, sagittal_index)

    # Default save directory & filename based on image. The "_mask" suffix
    # matters: the default folder and format equal the input's, so the bare
    # stem would make "Save mask" overwrite the image that was just loaded.
    stem, ext = split_name_and_ext(path)
    last_save_dir = os.path.dirname(path)
    save_filename_text.value = f"{stem}_mask"
    if ext in save_format_dropdown.options:
        save_format_dropdown.value = ext
    else:
        save_format_dropdown.value = '.nii.gz'
    save_dir_chooser.default_path = last_save_dir
    save_dir_chooser.reset()

    redraw()
    log_status(f"Loaded {path} (Z={dim_z},Y={dim_y},X={dim_x}). View: Axial.")


load_path_button.on_click(load_from_path)


############################################################
# Clear / Save mask / Clear prompts
############################################################

def clear_prompts(_):
    """Forget every box and point prompt (on all views) and repaint the prompt layer."""
    global box_prompts, point_prompts, box_drawing
    if volume3d is not None:
        box_prompts, point_prompts = [], []
        box_drawing = False
        _draw_ui(current_view, get_current_slice())
        log_status("Cleared prompts.")

clear_prompt_button.on_click(clear_prompts)


def clear_mask(_):
    """Reset every mask voxel to 0 and repaint; does nothing until a volume is loaded."""
    if volume3d is not None:
        init_mask()
        redraw()
        log_status("Mask cleared.")

clear_mask_button.on_click(clear_mask)


def _mask_image_for_saving():
    """Wrap ``mask3d`` in a uint8 SimpleITK image carrying the source image's geometry.

    ``prepare_volume`` gives 2-D inputs a leading Z axis and keeps one 3-D
    volume of 4-D inputs, so ``CopyInformation`` (which needs matching
    dimension and size) cannot always be applied directly:

    * no source image -> default geometry
    * 2-D source      -> the single mask plane is saved as a 2-D image
    * 3-D source      -> geometry copied as-is
    * 4-D source      -> spatial (first three) spacing/origin/direction terms

    Raises whatever SimpleITK raises if the geometry still does not fit.
    """
    mask_u8 = mask3d.astype(np.uint8)
    src = image_sitk
    if src is None:
        return sitk.GetImageFromArray(mask_u8)

    dim = src.GetDimension()
    if dim == 2:
        mask_img = sitk.GetImageFromArray(mask_u8[0])
        mask_img.CopyInformation(src)
    elif dim == 3:
        mask_img = sitk.GetImageFromArray(mask_u8)
        mask_img.CopyInformation(src)
    else:
        mask_img = sitk.GetImageFromArray(mask_u8)
        mask_img.SetSpacing(src.GetSpacing()[:3])
        mask_img.SetOrigin(src.GetOrigin()[:3])
        direction = src.GetDirection()  # row-major dim x dim matrix
        mask_img.SetDirection(
            [direction[r * dim + c] for r in range(3) for c in range(3)]
        )
    return mask_img


def save_mask(_):
    """Write ``mask3d`` to ``<folder>/<name><format>`` using the save controls.

    Folder: the save chooser's selection, else the last save folder, else the
    working directory. Name: the text box, or "mask3d" when blank. Saving
    onto the loaded image's own path is refused. Errors while building or
    writing the image are reported in the status line instead of escaping
    the button callback.
    """
    global last_save_dir
    if volume3d is None or mask3d is None:
        log_status("Nothing to save (no volume or mask).")
        return

    # Determine directory
    dir_path = os.path.abspath(save_dir_chooser.selected or last_save_dir or os.getcwd())
    if not os.path.isdir(dir_path):
        log_status(f"Invalid save directory: {dir_path}")
        return

    # Determine file name and extension
    stem = save_filename_text.value.strip() or "mask3d"
    save_path = os.path.join(dir_path, stem + save_format_dropdown.value)

    # Never clobber the source image (same folder + name + format).
    def _same_file(a, b):
        return os.path.normcase(os.path.realpath(a)) == os.path.normcase(os.path.realpath(b))

    if current_image_path and _same_file(save_path, current_image_path):
        log_status(f"Refusing to overwrite the loaded image {save_path}; choose another name.")
        return

    try:
        sitk.WriteImage(_mask_image_for_saving(), save_path)
    except Exception as e:  # UI boundary: report, don't crash the callback
        log_status(f"Failed to save mask: {e}")
        return

    last_save_dir = dir_path
    save_dir_chooser.default_path = last_save_dir
    save_dir_chooser.reset()

    log_status(f"Mask saved to {save_path}")

save_mask_button.on_click(save_mask)


############################################################
# View switches
############################################################

def _switch_view(view_name):
    """Make *view_name* active, abandon any gesture in progress and repaint."""
    global current_view, drawing, box_drawing
    current_view = view_name
    drawing = False
    box_drawing = False
    redraw()


def set_axial(_):
    """Button handler: show the axial view."""
    _switch_view("axial")


def set_coronal(_):
    """Button handler: show the coronal view."""
    _switch_view("coronal")


def set_sagittal(_):
    """Button handler: show the sagittal view."""
    _switch_view("sagittal")


def set_mip(_):
    """Button handler: show the maximum-intensity projection."""
    _switch_view("mip")


axial_view_button.on_click(set_axial)
coronal_view_button.on_click(set_coronal)
sagittal_view_button.on_click(set_sagittal)
mip_view_button.on_click(set_mip)


############################################################
# Sliders & other observers
############################################################

def on_axial_change(change):
    """Track the axial slider; repaint only if the axial view is on screen."""
    global axial_index
    axial_index = change['new']
    if current_view != "axial":
        return
    draw_axial()


def on_coronal_change(change):
    """Track the coronal slider; repaint only if the coronal view is on screen."""
    global coronal_index
    coronal_index = change['new']
    if current_view != "coronal":
        return
    draw_coronal()


def on_sag_change(change):
    """Track the sagittal slider; repaint only if the sagittal view is on screen."""
    global sagittal_index
    sagittal_index = change['new']
    if current_view != "sagittal":
        return
    draw_sagittal()


axial_slider.observe(on_axial_change, names='value')
coronal_slider.observe(on_coronal_change, names='value')
sagittal_slider.observe(on_sag_change, names='value')

def on_color_change(_):
    """Adopt the colour chosen in ``color_picker`` for the mask overlay and repaint.

    ``_`` is the traitlets change dict. Its ``'new'`` entry may be any CSS
    colour the picker accepts (``'#f00'``, ``'red'``, ...). It is normalised
    to ``#rrggbb`` before being stored in ``overlay_color_hex``, which is what
    the overlay renderer reads. An unparseable value is reported, and the
    previous colour is kept.
    """
    global overlay_color_hex
    from PIL import ImageColor  # Pillow is already a dependency of this file

    picked = _['new']
    try:
        r, g, b = ImageColor.getrgb(picked)[:3]
    except ValueError:
        log_status(f"Unsupported colour: {picked!r}")
        return

    overlay_color_hex = f"#{r:02x}{g:02x}{b:02x}"
    if volume3d is not None:
        redraw()

color_picker.observe(on_color_change, names='value')

def on_show_mask_change(_):
    """Repaint so the overlay layer follows the "Show mask overlay" checkbox."""
    if volume3d is None:
        return
    redraw()

show_mask_checkbox.observe(on_show_mask_change, names='value')


############################################################
# Layout
############################################################

# Rows of controls, grouped by purpose.
controls_tools = HBox([
    tool_selector, mode_selector, brush_size_slider,
    color_picker, show_mask_checkbox
])
controls_prompts = HBox([clear_prompt_button, clear_mask_button, save_mask_button])
controls_slices   = HBox([axial_slider, coronal_slider, sagittal_slider])
controls_view     = HBox([axial_view_button, coronal_view_button, sagittal_view_button, mip_view_button])


# Output folder chooser plus file name / format on one row.
save_controls = VBox([
    save_dir_chooser,
    HBox([save_filename_text, save_format_dropdown])
])

# Top-to-bottom: status line, loader, editing controls, canvas with view
# buttons, then the save section.
ui = VBox([
    status_label,
    fc,
    HBox([load_path_button]),
    controls_tools,
    controls_prompts,
    controls_slices,
    VBox([
        Label("Current view"),
        main_canv,
        controls_view
    ]),
    Label("Save mask:"),
    save_controls,
])

# Bare expression as the cell's last line so the notebook displays the widget tree.
ui

VBox(children=(Label(value='Choose a file → Load from path → View appears. Default view: Axial mid-slice.'), F…