In [None]:
from pathlib import Path
from collections import namedtuple

import numpy as np
from pyometiff import OMETIFFWriter
import webknossos as wk

import dask.array as da

import pandas as pd
from skimage.measure import label, regionprops_table
from skimage.color import label2rgb

from webknossos import BoundingBox
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.transforms as mtransforms
#matplotlib.use('Agg')

from webknossos_utils import Pixel_size, Annotation
import napari

import os
from getpass import getpass

# Credentials come from the environment (WK_TOKEN) or an interactive prompt and are
# never stored in the notebook.
# NOTE(review): a token used to be committed here in plain text; revoke it on webknossos.
AUTH_TOKEN = os.environ.get("WK_TOKEN") or getpass("WEBKNOSSOS auth token: ")
WK_TIMEOUT = 3600  # HTTP timeout in seconds; webknossos_context takes an int, not a string
ORG_ID = "83d574429f8bc523" # gisela's webknossos

#sample1   = Annotation("6644c04d0100004a01fa11af", "deprecated","60_2_5R")
#sample2    = Annotation("664316880100008a049e890e", "deprecated","60_2_3L")
id_1 = "6644c04d0100004a01fa11af"
id_2 = "664316880100008a049e890e"
id_3 = "664606440100005102550210"

# 666c4ae70100002b015e2344  (another annotation id, currently unused)
# Annotation ids come from the WEBKNOSSOS website: open the annotation of interest
# from the dashboard and take the id from its URL (the view information was removed).

ANNOTATION_ID = id_2

MAG_INDEX = 3  # position in the image layer's mag list of the resolution used for overviews
MYELIN_LAYER = "Myelin"  # name of the segmentation layer that holds the myelin labels

with wk.webknossos_context(token=AUTH_TOKEN):
    annotations = wk.Annotation.open_as_remote_dataset(annotation_id_or_url=ANNOTATION_ID)
    lbl_layers = annotations.get_segmentation_layers()

    # segmentation-layer name -> position in lbl_layers
    label_indices = {layer.name: pos for pos, layer in enumerate(lbl_layers)}
    assert MYELIN_LAYER in label_indices, (
        f"annotation has no '{MYELIN_LAYER}' segmentation layer; found {sorted(label_indices)}"
    )

    # NOTE(review): relies on the private `_properties` attribute of the annotation dataset.
    DATASET_NAME = annotations._properties.id['name']
    ds = wk.Dataset.open_remote(dataset_name_or_url=DATASET_NAME, organization_id=ORG_ID)
    img_layer = ds.get_color_layers()
    assert len(img_layer) == 1, "more than an image, this is unexpected for this project"
    img_layer = img_layer[0]

    voxel_size = ds.voxel_size
    mag_list = list(img_layer.mags.keys())
    assert len(mag_list) > MAG_INDEX, (
        f"image layer has only {len(mag_list)} mags ({mag_list}); MAG_INDEX={MAG_INDEX} is out of range"
    )

    MAG = mag_list[MAG_INDEX]
    # physical pixel size at the chosen mag (voxel size scaled per axis by the mag factor)
    pSize = Pixel_size(voxel_size[0] * MAG.x, voxel_size[1] * MAG.y, voxel_size[2] * MAG.z, MAG=MAG, unit="nm")

    # full-layer reads (read() with no arguments) of the image and myelin labels at pSize.MAG
    img_data = img_layer.get_mag(pSize.MAG).read()
    lbl_data = lbl_layers[label_indices[MYELIN_LAYER]].get_mag(pSize.MAG).read()

# wk reads are (channel, x, y, z); swap to (channel, z, y, x) so that [c, z] slices
# are row=y / column=x images for imshow and regionprops.
img_dask = da.from_array(np.swapaxes(img_data, -1, -3), chunks=(1, 1, 512, 512))

unique_lbls = np.unique(lbl_data)

# Downcast only when every label id fits into uint8 (0..255). The previous
# `< 512` threshold let ids 256..511 wrap around onto other (wrong) ids.
if unique_lbls.min() >= 0 and unique_lbls.max() <= np.iinfo(np.uint8).max:
    lbl_data = lbl_data.astype(np.uint8)

lbl_dask = da.from_array(np.swapaxes(lbl_data, -1, -3), chunks=(1, 5, 512, 512))

# quick look at channel 0 / z-slice 0 of the labels
plt.imshow(lbl_dask[0, 0, :, :])
plt.show()

from utils import make_block_bb

# Indices of all labelled pixels in channel 0 / z-slice 0; lbl_dask[0, 0] is indexed
# [y, x] in pSize.MAG pixels (see the swapaxes above).
segmentation = np.nonzero(lbl_dask[0, 0].compute())
assert segmentation[0].size > 0, "no labelled pixels in the first z-slice of the label layer"

y_min, y_max = int(segmentation[0].min()), int(segmentation[0].max())
x_min, x_max = int(segmentation[1].min()), int(segmentation[1].max())

block_size = 160  # size of resulting training images in pixels (both x and y)

large_bbox = make_block_bb(x_min, y_min, x_max, y_max, block_size)
# NOTE(review): large_bbox is not used below; wk_bbox is built from the tight label extent.

# Tight label box in pSize.MAG coordinates; x_max / y_max are inclusive indices, hence +1.
wk_bbox = wk.BoundingBox(topleft=(x_min, y_min, 0), size=(x_max - x_min + 1, y_max - y_min + 1, 1))
# Scale to Mag(1). The result is mag-aligned by construction; align_with_mag expects
# Mag(1) coordinates and would wrongly snap these mag-space values to multiples of MAG.
wk_bbox = wk_bbox.from_mag_to_mag1(pSize.MAG)

# img_data already holds img_layer.get_mag(pSize.MAG).read(); reuse it instead of
# downloading the whole layer a second time.
img_data_ha = img_data
with wk.webknossos_context(token=AUTH_TOKEN, timeout=WK_TIMEOUT):
    # Mag(1) crop covering the labelled region (wk_bbox is in Mag(1) coordinates)
    img_data_small = img_layer.get_finest_mag().read(absolute_bounding_box=wk_bbox)

img_dask_ha = da.from_array(np.swapaxes(img_data_ha, -1, -3), chunks=(1, 1, 512, 512))
img_dask_small = da.from_array(np.swapaxes(img_data_small, -1, -3), chunks=(1, 2, 512, 512))

from matplotlib.patches import Rectangle
from PIL import Image

plt.imshow(img_dask_ha[0, 0, :, :])
ax = plt.gca()

# The image is in pSize.MAG pixels but wk_bbox is in Mag(1) voxels: scale the outline down.
ha_mag = pSize.MAG
bx = wk_bbox.topleft[0] / ha_mag.x
by = wk_bbox.topleft[1] / ha_mag.y
bw = wk_bbox.size.x / ha_mag.x
bh = wk_bbox.size.y / ha_mag.y
rect = Rectangle((bx, by), bw, bh, linewidth=1, edgecolor='r', facecolor='none')

ax.add_patch(rect)
plt.show()




# Per-object table (label id, bbox, centroid) of the myelin labels in channel 0 /
# z-slice 0, in pSize.MAG pixels.
import myelin_morphometrics as morpho
from webknossos_utils import skibbox2wkbbox
from scipy import ndimage

properties = ['label', 'bbox', 'centroid']
label_img = lbl_dask[0, 0].compute()
reg_table = pd.DataFrame(regionprops_table(label_img, properties=properties))

Mag1 = wk.Mag("1")


# Preview the first few labelled objects. For each one: read its myelin label crop at
# Mag(1), fill the myelin ring, and overlay the result on the Mag(1) image crop
# img_dask_small, whose pixel (0, 0) is wk_bbox.topleft.
from utils import fill_with_convex_hull  # previously commented out -> NameError in the fallback

N_PREVIEW_OBJECTS = 3  # replaces the old `if index > 1: break`
FILL_VALUE = 25  # label value painted into the filled interior (only used for display here)

for index, row in reg_table.head(N_PREVIEW_OBJECTS).iterrows():
    # iterrows upcasts the mixed int/float row to float, so cast the label id back
    obj_idx = int(row['label'])

    bbox = skibbox2wkbbox(row.to_dict(), pSize)
    print(f"bbox size: {bbox.size}")

    # Mag(1) crop of the myelin layer; wk reads are (channel, x, y, z), so after
    # squeeze this is (x, y) -- assumes bbox is one voxel deep in z, TODO confirm.
    myelin_lbl = lbl_layers[label_indices["Myelin"]].get_finest_mag().read(absolute_offset=bbox.topleft, size=bbox.size).squeeze()

    myelin_bw = morpho.get_BW_from_lbl(myelin_lbl, obj_idx)
    # clean myelin label map, in case other neurons are close by
    myelin_lbl[np.logical_not(myelin_bw)] = 0
    # fill the inside of the myelin ring
    myelin_bw_fill = ndimage.binary_fill_holes(myelin_bw)

    if not myelin_bw_fill.any():
        # center_of_mass of an empty mask is NaN and int(NaN) raises
        print(f"label {obj_idx}: empty mask in its bbox, skipped")
        continue

    # If the centre of mass lies inside the mask, fill_holes closed the ring;
    # otherwise fall back to filling with the convex hull.
    com = ndimage.center_of_mass(myelin_bw_fill)
    expected_filled = myelin_bw_fill[int(com[0]), int(com[1])]
    if not expected_filled:
        myelin_bw_fill = fill_with_convex_hull(myelin_bw_fill)

    myelin_lbl[np.nonzero(myelin_bw_fill)] = FILL_VALUE
    print(f"label size: {myelin_lbl.shape}")

    tx = bbox.topleft[0]
    ty = bbox.topleft[1]
    wx = bbox.size.x
    wy = bbox.size.y

    # One figure per object, so every overlay has the background under it
    # (the old shared figure was closed by plt.show() after the first object).
    overlay_fig, overlay_ax = plt.subplots()
    overlay_ax.imshow(img_dask_small[0, 0, :, :])
    # (x, y) -> (y, x) so rows are y, like the background image
    im = overlay_ax.imshow(myelin_lbl.T)
    # offset the crop relative to the background's origin (wk_bbox.topleft, Mag(1))
    transform = mtransforms.Affine2D().translate(tx - wk_bbox.topleft[0], ty - wk_bbox.topleft[1])
    im.set_transform(transform + overlay_ax.transData)

    plt.show()

In [None]:
# Overview slice (channel 0, z 0) of the pSize.MAG image.
plt.imshow(img_dask_ha[0, 0])

In [None]:
# Outline of wk_bbox on its own, in Mag(1) coordinates (there is no image underneath).
ax = plt.gca()

bx = wk_bbox.topleft[0]
by = wk_bbox.topleft[1]
bw = wk_bbox.size.x
bh = wk_bbox.size.y
rect = Rectangle((bx, by), bw, bh, linewidth=1, edgecolor='r', facecolor='none')

ax.add_patch(rect)
# add_patch only updates the data limits. Without an image to set the view, rescale
# explicitly, or the rectangle stays outside the default 0..1 view.
ax.autoscale_view()
plt.show()

In [None]:
# Overview image at pSize.MAG with the labelled region outlined. wk_bbox is in Mag(1)
# voxels while img_dask_ha is in pSize.MAG pixels, so scale the outline down.
plt.imshow(img_dask_ha[0, 0, :, :])
ax = plt.gca()

bx = wk_bbox.topleft[0] / pSize.MAG.x
by = wk_bbox.topleft[1] / pSize.MAG.y
bw = wk_bbox.size.x / pSize.MAG.x
bh = wk_bbox.size.y / pSize.MAG.y
rect = Rectangle((bx, by), bw, bh, linewidth=1, edgecolor='r', facecolor='none')

ax.add_patch(rect)
plt.show()

In [None]:
# Mag(1) image crop (channel 0, z 0) covering wk_bbox.
plt.imshow(img_dask_small[0, 0])


In [None]:
# myelin_lbl is (x, y) from the wk read; transpose so rows are y like the other images.
im = plt.imshow(myelin_lbl.T)

In [None]:
# myelin_lbl is (x, y) from the wk read; transpose so rows are y like the other images.
im = plt.imshow(myelin_lbl.T)

In [None]:
# Shift the last label image `im` to its bbox origin (tx, ty). Use the axes `im` actually
# belongs to; the global `ax` comes from an earlier, already-shown figure.
# NOTE(review): with the inline backend, `im`'s figure was already rendered by its own cell.
transform = mtransforms.Affine2D().translate(tx, ty)
im.set_transform(transform + im.axes.transData)

In [None]:
# Label crop drawn at its absolute Mag(1) position (tx, ty). myelin_lbl is (x, y), so
# transpose to rows=y / cols=x. Placing it via `extent` (instead of a post-hoc transform
# on the stale global `ax`) also moves the view limits to where the image is.
lbl_yx = myelin_lbl.T
lbl_h, lbl_w = lbl_yx.shape
im = plt.imshow(lbl_yx, extent=(tx - 0.5, tx + lbl_w - 0.5, ty + lbl_h - 0.5, ty - 0.5))

In [None]:
# Overlay the last label crop on the Mag(1) image crop (whose origin is wk_bbox.topleft).
plt.imshow(img_dask_small[0, 0, :, :])
# (x, y) -> (y, x) so rows are y, like the background
im = plt.imshow(myelin_lbl.T)

# offset relative to the background's origin, on the axes `im` lives on
# (the global `ax` belongs to an earlier, already-shown figure)
transform = mtransforms.Affine2D().translate(tx - wk_bbox.topleft[0], ty - wk_bbox.topleft[1])
im.set_transform(transform + im.axes.transData)


In [None]:
# Overlay the last label crop on the Mag(1) image crop (whose origin is wk_bbox.topleft).
plt.imshow(img_dask_small[0, 0, :, :])
# (x, y) -> (y, x) so rows are y, like the background
im = plt.imshow(myelin_lbl.T)

ax = plt.gca()
# offset relative to the background's origin rather than to absolute (0, 0)
transform = mtransforms.Affine2D().translate(tx - wk_bbox.topleft[0], ty - wk_bbox.topleft[1])
im.set_transform(transform + ax.transData)


In [None]:
# Overlay the filled myelin mask of the last object on the Mag(1) image crop
# (whose origin is wk_bbox.topleft).
plt.imshow(img_dask_small[0, 0, :, :])
# NOTE(review): this label image is left untranslated, i.e. at the crop origin, not at its bbox.
im = plt.imshow(myelin_lbl.T)

ax = plt.gca()
# (x, y) -> (y, x) so rows are y, like the background
imgplot = ax.imshow(myelin_bw_fill.T)
transform = mtransforms.Affine2D().translate(tx - wk_bbox.topleft[0], ty - wk_bbox.topleft[1])
imgplot.set_transform(transform + ax.transData)
