In [None]:
from pathlib import Path
from collections import namedtuple

import numpy as np
from pyometiff import OMETIFFWriter
import webknossos as wk

import dask.array as da

import pandas as pd
from skimage.measure import label, regionprops_table
from skimage.color import label2rgb

from webknossos import BoundingBox
import matplotlib.pyplot as plt

from webknossos_utils import Pixel_size, Annotation
import napari


In [None]:
import os
from getpass import getpass

# WEBKNOSSOS API token: create it on the WEBKNOSSOS website under the user
# settings. It is read from the WK_AUTH_TOKEN environment variable (or asked
# for interactively) so the secret is never stored in the notebook.
# NOTE: tokens that were previously hard-coded in this notebook should be revoked.
AUTH_TOKEN = os.environ.get("WK_AUTH_TOKEN") or getpass("WEBKNOSSOS auth token: ")
# HTTP timeout for WEBKNOSSOS requests, in seconds (a number, not a string)
WK_TIMEOUT = 3600
ORG_ID = "83d574429f8bc523"  # organization id of gisela's WEBKNOSSOS

# Annotation ids of the samples in this project
# (id_1 was sample "60_2_5R", id_2 was sample "60_2_3L").
id_1 = "6644c04d0100004a01fa11af"
id_2 = "664316880100008a049e890e"
id_3 = "664606440100005102550210"
# another annotation id, not used below: 666c4ae70100002b015e2344

# The annotation id comes from the WEBKNOSSOS website: open the annotation from
# the dashboard and copy the id from its URL (without the view information).
ANNOTATION_ID = id_2

with wk.webknossos_context(token=AUTH_TOKEN, timeout=WK_TIMEOUT):
    annotations = wk.Annotation.open_as_remote_dataset(annotation_id_or_url=ANNOTATION_ID)
    lbl_layers = annotations.get_segmentation_layers()

    # NOTE(review): `_properties` is a private attribute of the webknossos
    # library and may change between versions.
    DATASET_NAME = annotations._properties.id['name']
    ds = wk.Dataset.open_remote(dataset_name_or_url=DATASET_NAME, organization_id=ORG_ID)
    color_layers = ds.get_color_layers()
    assert len(color_layers) == 1, "more than an image, this is unexpected for this project"
    img_layer = color_layers[0]

In [None]:
# Position of each segmentation layer in `lbl_layers`, keyed by layer name
# (if a name occurs twice the last occurrence wins).
layer_index_by_name = {lbl.name: i for i, lbl in enumerate(lbl_layers)}

# The myelin layer is required by the rest of the notebook: fail with a clear
# message instead of a NameError further down.
if "Myelin" not in layer_index_by_name:
    raise KeyError(
        f"No 'Myelin' segmentation layer in the annotation; found {sorted(layer_index_by_name)}"
    )

myelin_idx = layer_index_by_name["Myelin"]
# The remaining layers are optional; None marks a layer that is absent.
axon_idx = layer_index_by_name.get("Axon")
mito_idx = layer_index_by_name.get("Mitochondria")
dystrophic_idx = layer_index_by_name.get("Dystrophic_myelin")

print(myelin_idx)
print(axon_idx)

In [None]:
# Position in the image layer's list of magnifications to read from; a
# downsampled mag keeps the download small.
MAG_INDEX = 3

voxel_size = ds.voxel_size
mag_list = list(img_layer.mags.keys())
print(mag_list)
if MAG_INDEX >= len(mag_list):
    raise IndexError(
        f"MAG_INDEX={MAG_INDEX}, but the image layer only has {len(mag_list)} mags: {mag_list}"
    )
MAG = mag_list[MAG_INDEX]
# Physical pixel size at the chosen magnification
pSize = Pixel_size(voxel_size[0] * MAG.x, voxel_size[1] * MAG.y, voxel_size[2] * MAG.z, MAG=MAG, unit="nm")
print(pSize)

with wk.webknossos_context(token=AUTH_TOKEN, timeout=WK_TIMEOUT):
    img_data = img_layer.get_mag(pSize.MAG).read()
print(img_data.shape)

# NOTE(review): assumes read() returns (C, X, Y, Z); swapping axes -1 and -3
# then gives (C, Z, Y, X) — confirm against the webknossos version in use.
img_dask = da.from_array(np.swapaxes(img_data, -1, -3), chunks=(1, 1, 512, 512))

fig, ax = plt.subplots()
ax.imshow(img_dask[0, 0])
ax.set_title("Image (channel 0, first slice)")
plt.show()


In [None]:
with wk.webknossos_context(token=AUTH_TOKEN):
    lbl_data = lbl_layers[myelin_idx].get_mag(pSize.MAG).read()
unique_lbls = np.unique(lbl_data)
print(unique_lbls)

# Downcast only when every label id fits: uint8 holds 0..255, so larger ids
# would silently wrap around (e.g. 300 -> 44).
uint8_max = np.iinfo(np.uint8).max
if unique_lbls.min() >= 0 and unique_lbls.max() <= uint8_max:
    lbl_data = lbl_data.astype(np.uint8)

# Same axis order as the image array
lbl_dask = da.from_array(np.swapaxes(lbl_data, -1, -3), chunks=(1, 5, 512, 512))

plt.imshow(lbl_dask[0][0])
plt.show()

# a string describing the dimension ordering
dimension_order = "CZYX"

# Bounding box of all labelled pixels in the first slice.
# np.nonzero returns (row indices, column indices). In this notebook "x" names
# the ROW axis and "y" the COLUMN axis of the 2-D slice; later cells slice
# [x_min:x_max, y_min:y_max]. The max values are EXCLUSIVE (last index + 1) so
# those slices include the last labelled row/column.
first_lbl_slice = lbl_dask[0, 0].compute()
rows, cols = np.nonzero(first_lbl_slice)
if rows.size == 0:
    raise ValueError("The first slice of the myelin layer has no labels; cannot compute a bounding box.")
x_min, x_max = int(rows.min()), int(rows.max()) + 1
y_min, y_max = int(cols.min()), int(cols.max()) + 1

block_size = 160  # size of the resulting training images in pixels (both axes)

from utils import make_block_bb

# Presumably enlarges the box so that both sides are a multiple of block_size.
# NOTE(review): the result is not applied to x_min/x_max/y_min/y_max — confirm
# whether later cells should use it.
large_bbox = make_block_bb(x_min, y_min, x_max, y_max, block_size)

# WEBKNOSSOS bounding boxes are ordered (x, y, z). Assuming the slice above is
# (Y, X), the column range is wk-x and the row range is wk-y.
# NOTE(review): coordinates are in pSize.MAG voxels, not mag 1 — convert
# before using this box for mag-1 reads.
wk_bbox = wk.BoundingBox(topleft=(y_min, x_min, 0), size=(y_max - y_min, x_max - x_min, 1))

In [None]:

# Overlay the myelin labels (semi-transparent red) on the image inside the
# labelled bounding box; rows are x_min:x_max, columns y_min:y_max.
img_crop = img_dask[0, 0, x_min:x_max, y_min:y_max]
lbl_crop = lbl_dask[0, 0, x_min:x_max, y_min:y_max].compute()
# Labelled pixels get alpha 0.5; the background stays fully transparent.
alphas = np.where(lbl_crop != 0, 0.5, 0.0)

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(img_crop)
ax.imshow(lbl_crop, alpha=alphas, cmap=plt.cm.Reds)
plt.show()

In [None]:
# Fill every individual myelin annotation: for each labelled object, read its
# myelin mask at the finest mag, fill the enclosed area and paint it into
# `label_img`.
import myelin_morphometrics as morpho
from skimage.morphology import convex_hull_image
from skimage.segmentation import active_contour
from skimage.measure import find_contours
from webknossos_utils import skibbox2wkbbox
from scipy import ndimage
import numpy
from IPython.display import clear_output

# While debugging, stop after the object whose row index exceeds this value.
MAX_DEBUG_INDEX = 20
# Value painted into `label_img` for every filled object
FILL_VALUE = 124


def _fill_with_convex_hull(bw):
    """Return the boolean mask `bw` with every pixel of its convex hull set."""
    return np.logical_or(bw, convex_hull_image(bw))


# One row per labelled object of the first slice (label id, bbox, centroid)
properties = ['label', 'bbox', 'centroid']
label_img = lbl_dask[0, 0, :, :].compute()
reg_table = pd.DataFrame(regionprops_table(label_image=label_img, properties=properties))
fig = plt.figure(figsize=(10, 10))

# Finest-mag image over the labelled bounding box, starting at its top-left
# corner (the minimum of both index ranges).
# NOTE(review): x_min/x_max are row and y_min/y_max column indices at
# pSize.MAG (bounding-box cell), but they are used here as (x, y) offsets into
# the finest mag — confirm axis order and mag scaling.
with wk.webknossos_context(token=AUTH_TOKEN, timeout=WK_TIMEOUT):
    img_data_ha = img_layer.get_finest_mag().read(
        absolute_offset=(x_min, y_min, 0),
        size=(x_max - x_min, y_max - y_min, 1),
    )

print(img_data_ha.shape)
img_dask_ha = da.from_array(np.swapaxes(img_data_ha, -1, -3), chunks=(1, 1, 512, 512))
plt.imshow(img_dask_ha[0][0])
plt.show()

for index, row in reg_table.iterrows():
    obj_idx = row['label']

    bbox = skibbox2wkbbox(row.to_dict(), pSize)
    print(f"bbox size: {bbox.size}")

    myelin_lbl = lbl_layers[myelin_idx].get_finest_mag().read(absolute_offset=bbox.topleft, size=bbox.size).squeeze()

    myelin_bw = morpho.get_BW_from_lbl(myelin_lbl, obj_idx)
    # clean myelin label map, in case other neurons are close by
    myelin_lbl[np.logical_not(myelin_bw)] = 0
    myelin_bw_fill = ndimage.binary_fill_holes(myelin_bw)
    if not myelin_bw_fill.any():
        # center_of_mass of an empty mask is NaN and cannot be used as an index
        print(f"object {obj_idx}: empty mask, skipped")
        continue

    # If the pixel at the centre of mass is set, hole filling closed the
    # object; otherwise (e.g. an open ring) fall back to its convex hull.
    com = ndimage.center_of_mass(myelin_bw_fill)
    expected_filled = myelin_bw_fill[int(com[0]), int(com[1])]
    if not expected_filled:
        myelin_bw_fill = _fill_with_convex_hull(myelin_bw_fill)

    # NOTE(review): the filled area is labelled with the object's own id —
    # confirm this is the intended value.
    myelin_lbl[np.nonzero(myelin_bw_fill)] = int(obj_idx)
    print(f"label size: {myelin_lbl.shape}")

    # NOTE(review): myelin_bw_fill is in finest-mag coordinates local to
    # `bbox`, whereas label_img covers the whole slice at pSize.MAG; these
    # indices are neither offset nor rescaled — confirm before using label_img.
    row_inds, col_inds = np.nonzero(myelin_bw_fill)
    label_img[row_inds, col_inds] = FILL_VALUE
    fig = plt.figure(figsize=(12, 12))
    plt.imshow(img_dask_ha[0][0])
    plt.show()
    print(index)
    if index > MAX_DEBUG_INDEX:
        break

    print(f"label image size: {label_img.shape}")

In [None]:
import numpy
import matplotlib.colors

# Area of interest: the label image (after the fill loop) cropped to the
# labelled bounding box; rows are x_min:x_max, columns y_min:y_max.
# Defined here so the cell works on a fresh kernel.
# NOTE(review): confirm the filled label_img, not the raw lbl_dask labels, is
# the intended source.
aoi = label_img[x_min:x_max, y_min:y_max]
# Labelled pixels get alpha 0.5; the background stays fully transparent.
alphas = numpy.where(aoi != 0, 0.5, 0.0)

fig, ax = plt.subplots(figsize=(100, 100))
ax.imshow(aoi, alpha=alphas, cmap=plt.cm.Reds);


In [None]:
import numpy

# Image crop matching `aoi`, with a red grid marking the block_size x block_size
# tiles that the next cells cut out.
aoii = img_dask[0, 0, x_min:x_max, y_min:y_max]

fig, ax = plt.subplots(figsize=(100, 100))
ax.imshow(aoii)
ax.imshow(aoi, alpha=alphas, cmap=plt.cm.Reds)

# imshow puts array columns on the horizontal axis and rows on the vertical
# axis. Ticks are relative to the crop, so they use its extent rather than
# the absolute x_max / y_max.
n_rows, n_cols = aoii.shape
ax.set_xticks(numpy.arange(0, n_cols, block_size))
ax.set_yticks(numpy.arange(0, n_rows, block_size))
ax.grid(color='r', lw=4)
plt.show()

In [None]:

import numpy


def _to_tiles(arr, tile):
    """Split a 2-D array into tile x tile blocks, zero-padding the bottom/right edges.

    Returns an array of shape (n_tile_rows, n_tile_cols, tile, tile).
    """
    arr = numpy.asarray(arr)
    # pad so both sides are whole multiples of `tile` (no-op when already aligned)
    pad_rows = -arr.shape[0] % tile
    pad_cols = -arr.shape[1] % tile
    arr = numpy.pad(arr, ((0, pad_rows), (0, pad_cols)))
    n_tile_rows = arr.shape[0] // tile
    n_tile_cols = arr.shape[1] // tile
    return arr.reshape(n_tile_rows, tile, n_tile_cols, tile).swapaxes(1, 2)


chunks_lbl = _to_tiles(aoi, block_size)
chunks_img = _to_tiles(aoii, block_size)
alphas = _to_tiles(alphas, block_size)

# number of tiles minus one along each axis
x_div = chunks_lbl.shape[0] - 1
y_div = chunks_lbl.shape[1] - 1


In [None]:

# Show every tile that contains at least one labelled pixel: image with the
# labels overlaid, each tile in its own 5x5 figure.
n_tile_rows, n_tile_cols = chunks_lbl.shape[:2]
for x in range(n_tile_rows):
    for y in range(n_tile_cols):
        if numpy.max(chunks_lbl[x][y]) > 0:
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(chunks_img[x][y])
            ax.imshow(chunks_lbl[x][y], interpolation='nearest', alpha=alphas[x][y], cmap=plt.cm.Reds)
            ax.set_title(f"tile ({x}, {y})")
            plt.show()