In [1]:
import numpy as np
from skimage import data, io
from skimage.filters import threshold_otsu
import os
import imageio
from matplotlib import pyplot as plt
import pandas as pd
import napari
import tifffile
from scipy import spatial
import pickle

from sklearn.preprocessing import normalize, PolynomialFeatures
from sklearn import linear_model

In [2]:
from threadpoolctl import threadpool_limits
import ray
# num_cpus=48
# ray.init(num_cpus=num_cpus, ignore_reinit_error=True)

In [3]:
io_directory = '/mnt/ampa_data01/tmurakami/220305_SMA_nuc_middlehuman/vessel_analysis'

In [4]:
asma_downscale = np.load(os.path.join(io_directory,'vessel.npy'))
dim = asma_downscale.ndim
print(asma_downscale.shape)

(450, 819, 546)


### If meninge mask are required, make it using labkit

In [5]:
# mask_meninge = 1 - tifffile.imread(os.path.join(io_directory,'meninge.tif'))

In [5]:
# Prepare mask. This time, use mask prepared with ClearMap.
mask = np.load(os.path.join(io_directory,'binary_final.npy'))#np.load('/scratch2/Share/tmurakami/220121_human_sma_5mm_1st/vessel/binary_final.npy')
# mask = mask * mask_meninge
mask_1d = mask.flatten()

# Extract the vectors and positions using mask.
points_position_array = np.array(np.where(mask)).T
points_position = points_position_array.tolist()

In [6]:
# manually design your guide vector
# To Do: automation of the vector detection.
guide_coordinate1 = np.array([237.,145.,405.])
guide_coordinate2 = np.array([201.,181.,373.])
guide_vector = guide_coordinate2 - guide_coordinate1
guide_vector = guide_vector / np.linalg.norm(guide_vector)

In [7]:
# Prepare skeleton. Use the skeleton prepared with ClearMap.
skeleton = np.load(os.path.join(io_directory,'skeleton.npy'))
# skeleton = skeleton * mask_meninge
skeleton_1d = skeleton.flatten()

# Extract the position of the skeleton for the vector field analysis.
skeleton_position_array = np.array(np.where(skeleton)).T
skeleton_position = skeleton_position_array.tolist()

In [8]:
def align_vector_sign(vectors, guide_vector=None):
    '''
    Align the sign of the vector by refering a guide vector. If the dot product of the vector and the guide vector is negative, the sign of vector is flipped.
    Highly encourage to make a guide vector before align sign.
    '''
    if (guide_vector is None):
        guide_vector = normalize(np.median(vectors,axis=0)[:,np.newaxis],axis=0).ravel()
    aligned_vectors = np.where(
        np.repeat(np.expand_dims(np.matmul(vectors, guide_vector) >= 0, axis=1), guide_vector.size, axis=1),
        vectors,
        -vectors
    )
    return aligned_vectors

In [9]:
# Vector field analysis on skeleton using neighbors.
kdtree = spatial.KDTree(skeleton_position_array)
k = 27 # Number of neighbors.
mean_skeleton_vectors = []

for point, point_position in enumerate(skeleton_position):
    # Extract vectors from k-nearest neighbors.
    d, neighbors = kdtree.query(point_position,k)
    neighbors = neighbors[d!=0]
    vectors_from_neighbors = normalize(skeleton_position_array[neighbors,:]-point_position,axis=1) # Normalize to equalize the weights
    mean_vector = np.mean(align_vector_sign(vectors_from_neighbors,guide_vector),axis=0) # Use arithmetic mean.
    mean_vector = normalize(mean_vector[:,np.newaxis],axis=0).ravel()
    mean_skeleton_vectors.append(mean_vector)
mean_skeleton_vectors = np.array(mean_skeleton_vectors)

In [10]:
# Expansion of vector field to binarized image.
kdtree = spatial.KDTree(skeleton_position_array)
k = 1 # Number of neighbors in skeleton.
point_vectors = []

for point, point_position in enumerate(points_position):
    _, neighbors = kdtree.query(point_position,k)
    neighbor_vector = mean_skeleton_vectors[neighbors]
#     if k > 1:
#         mean_neighbor_vectors =np.mean(neighbor_vector,axis=0)
#         neighbor_vector = normalize(mean_neighbor_vectors[:,np.newaxis],axis=0).ravel()
    point_vectors.append(neighbor_vector)
point_vectors = np.array(point_vectors)

In [None]:
'''Start denoising'''

In [11]:
@ray.remote
def get_neighbor_vectors(point_position, kdtree, vectors, radius):
    if not isinstance(point_position, np.ndarray):
        point_position = np.array(point_position)
    neighbors = kdtree.query_ball_point(point_position,radius)
    neighbor_vectors = vectors[neighbors,:]
    return neighbor_vectors

@ray.remote
def get_median_vector(vectors):
    # Ideally, the medoid vector should be calculated, but it is resource demanding. Instead, calculate the median in each dimension and normalize to a unit vector.
    """
    vectors: ndarray
    """
    median_vector = normalize(np.median(vectors,axis=0)[:,np.newaxis],axis=0).ravel()
    return median_vector

@ray.remote
def get_point_vector(point, vectors):
    return vectors[point]

@ray.remote
def single_thread_align_vector_sign(vectors, guide_vector):
    '''
    Align the sign of the vector by refering a guide vector. If the dot product of the vector and the guide vector is negative, the sign of vector is flipped.
    '''
    with threadpool_limits(limits=1, user_api='blas'):
        aligned_vectors = np.where(
            np.repeat(np.expand_dims(np.matmul(vectors, guide_vector) >= 0, axis=1), guide_vector.size, axis=1),
            vectors,
            -vectors
        )
    return aligned_vectors

@ray.remote
def select_point_in_dot_product_space(point_vector, neighbor_vectors, median_vector, k=10):
    with threadpool_limits(limits=1, user_api='blas'):
        # Calculate the dot product
        dot_product = np.matmul(neighbor_vectors, median_vector) # This is done in parallel otherwise stated.
        dot_product_of_point = np.matmul(point_vector,median_vector)
    if k>dot_product.size:
        selection = False
    else:
        # Find k neighbors in dot product space
        dot_product_neighbors = dot_product[np.argsort(np.abs(dot_product-dot_product_of_point))][0:k]
        # Calculate the null density if the density is even distribution.
        null_density = dot_product.size*(dot_product_neighbors.max()-dot_product_neighbors.min())
        selection = (k>null_density)
    # Second selection using otsu thresholding
    if selection:
        thresh = threshold_otsu(dot_product)
        selection = (dot_product_of_point>thresh)
    return selection

@ray.remote
def dot_product_vectors(vector1, vector2):
    with threadpool_limits(limits=1, user_api='blas'):
        dot_product = np.matmul(vector1, vector2)
    return dot_product

In [12]:
%%time
kdtree = spatial.KDTree(points_position_array)
radius = 42 # pixel unit. 500 / voxel micrometer works well. diameter in real scale: 2 * radius * voxelsize. 
k = 10 # Number of neighbor in dot product space. Note this is not a number of neighbor in 3D image space.
keeping = []
# dot_p = []

kdtree_id = ray.put(kdtree)
vectors_id = ray.put(point_vectors)

for point, point_position in enumerate(points_position):
    point_vector = get_point_vector.remote(point, vectors_id)
    # Get vectors in neighbor points.
    neighbor_vectors = get_neighbor_vectors.remote(point_position, kdtree_id, vectors_id, radius)
    # Make representitive vector
    median_vector = get_median_vector.remote(neighbor_vectors)
    neighbor_vectors = single_thread_align_vector_sign.remote(neighbor_vectors, median_vector) # Fix the sign of vectors.

    keeping.append(select_point_in_dot_product_space.remote(point_vector, neighbor_vectors, median_vector, k))
    # dot_p.append(dot_product_vectors.remote(point_vector,median_vector))

keeping = ray.get(keeping)
extract_idx = np.where(mask_1d)[0][keeping]

[2m[36m(get_median_vector pid=605365)[0m 
[2m[36m(get_neighbor_vectors pid=605298)[0m 
[2m[36m(get_neighbor_vectors pid=605415)[0m 
[2m[36m(get_point_vector pid=605362)[0m 
[2m[36m(get_median_vector pid=605452)[0m 
[2m[36m(select_point_in_dot_product_space pid=605358)[0m 
[2m[36m(select_point_in_dot_product_space pid=605433)[0m 
[2m[36m(get_median_vector pid=605355)[0m 
[2m[36m(get_point_vector pid=605350)[0m 
[2m[36m(get_neighbor_vectors pid=605388)[0m 
[2m[36m(get_point_vector pid=605380)[0m 
[2m[36m(get_neighbor_vectors pid=605357)[0m 
[2m[36m(get_neighbor_vectors pid=605356)[0m 
[2m[36m(get_median_vector pid=605457)[0m 
[2m[36m(get_neighbor_vectors pid=605378)[0m 
[2m[36m(get_point_vector pid=605423)[0m 
[2m[36m(get_median_vector pid=605343)[0m 
[2m[36m(get_median_vector pid=605410)[0m 
[2m[36m(get_neighbor_vectors pid=605299)[0m 
[2m[36m(get_median_vector pid=605457)[0m 
[2m[36m(get_median_vector pid=605354)[0m 
[2m[

In [15]:
if False: # True to save images and variables for later use.
    vec_img = np.zeros(asma_downscale.shape+(dim,)).astype(np.float32)
    extracted = points_position_array[keeping]
    vec_img[tuple(extracted.T)] = point_vectors[keeping,:]
    
    # export extracted vetors as image
    tifffile.imwrite(os.path.join(io_directory,'local_vector.tif'),
                 np.moveaxis(vec_img,-1,1).astype(np.float32),
                 imagej=True,
                 metadata={'spacing': 12, 'unit': 'um', 'axes': 'ZCYX'})
    
    # save variables as .npy
    np.save(os.path.join(io_directory,'extract_idx.npy'), extract_idx)
    np.save(os.path.join(io_directory,'point_vectors.npy'), point_vectors)
    np.save(os.path.join(io_directory,'keeping.npy'), np.asarray(keeping))
    

In [13]:
ray.shutdown()

In [19]:
# Fit to the nth polynomial
degree = 5
idx = np.array(np.unravel_index(extract_idx,asma_downscale.shape)).T
vec = point_vectors[keeping,:]
poly = PolynomialFeatures(degree=degree) # Overfitting may happen at the edge?
idx_ = poly.fit_transform(idx)

clf = linear_model.LinearRegression(fit_intercept=False) # False
clf.fit(idx_,vec)# Fit the model
clf.degree = degree# save information for polynomial degree for later use

LinearRegression(fit_intercept=False)

In [18]:
filename = '/mnt/ampa_data01/tmurakami/220305_SMA_nuc_middlehuman/vessel_analysis/model.pkl'
pickle.dump(clf, open(filename, 'wb'))

In [None]:
# Export the vector field as image if it is required.
if False: # honestly, the interpretation is difficult and does not help much.
    all_coord = np.indices(asma_downscale.shape)
    all_coord = np.stack([all_coord[i,:,:,:].flatten() for i in range(dim)], axis=1).astype(int)
    all_coord_ = poly.fit_transform(all_coord)

    fit_img = clf.predict(all_coord_)
    fit_img = fit_img.reshape(asma_downscale.shape+(dim,))
    vector_field_tif = os.path.join(io_directory,'vector_field_interpolation.tif')

    # export extracted vetors as image
    tifffile.imwrite(vector_field_tif,
             np.moveaxis(fit_img,-1,1).astype(np.float32),
             imagej=True,
             metadata={'spacing': 10, 'unit': 'um', 'axes': 'ZCYX'})
    del(all_coord)
    del(all_coord_)
    del(fit_img)

In [23]:
np.asarray(keeping).sum()

218925