In [1]:
import collections.abc
# Python 3.10 removed these ABC aliases from the top-level `collections` module.
# Re-expose them so that older dependencies imported below, which still use
# `collections.Iterable` & co., keep working (TODO: identify which package needs it).
for _abc_name in ('Iterable', 'Mapping', 'MutableSet', 'MutableMapping'):
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))

import matplotlib as mpl
import matplotlib.animation
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib ipympl

# Notebook-wide figure defaults: size, font size and grayscale images.
plt.rcParams.update({'figure.figsize': [10, 4], 'font.size': 8})
mpl.rcParams['image.cmap'] = 'gray'
import trackpy as tp
tp.quiet()  # turn off trackpy's progress/logging messages

import numpy as np
import pandas as pd
import csv, json
import pims
from PIL import Image, ImageDraw
import cv2

from scipy.optimize import dual_annealing, linear_sum_assignment
from scipy.spatial import distance_matrix
from tqdm import tqdm
import random

import skimage
from csbdeep.utils import normalize
from stardist.models import StarDist2D
from stardist.data import test_image_nuclei_2d
from stardist.plot import render_label
from csbdeep.utils import normalize
from stardist import random_label_cmap, _draw_polygons, export_imagej_rois
# Seed numpy's global RNG before building the label colormap so that its
# random colours are the same on every run (presumably random_label_cmap draws
# from np.random — TODO confirm).
SEED = 6
np.random.seed(SEED)
lbl_cmap = random_label_cmap()

# StarDist2D with the pretrained "versatile fluorescence" weights.
PRETRAINED_MODEL = '2D_versatile_fluo'
model = StarDist2D.from_pretrained(PRETRAINED_MODEL)
print(model)

2023-05-20 15:15:22.321441: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Found model '2D_versatile_fluo' for 'StarDist2D'.


2023-05-20 15:15:39.829063: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.
StarDist2D(2D_versatile_fluo): YXC → YXC
├─ Directory: None
└─ Config2D(n_dim=2, axes='YXC', n_channel_in=1, n_channel_out=33, train_checkpoint='weights_best.h5', train_checkpoint_last='weights_last.h5', train_checkpoint_epoch='weights_now.h5', n_rays=32, grid=(2, 2), backbone='unet', n_classes=None, unet_n_depth=3, unet_kernel_size=[3, 3], unet_n_filter_base=32, unet_n_conv_per_depth=2, unet_pool=[2, 2], unet_activation='relu', unet_last_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_prefix='', net_conv_after_unet=128, net_input_shape=[None, None, 1], net_mask_shape=[None, None, 1], train_shape_completion=False, train_completion_crop=32, train_patch_size=[256, 256], train_background_reg=0.0001, train_foreground_only=0.9, train_sample_cache=True, train_dist_loss='mae', train_loss_weights=[1, 0.2], train_class_weights=(1

In [None]:
@pims.pipeline
def preprocessing(image, x1, y1, x2, y2, sharpen=False):
    """
    Preprocessing function for the data: grayscale conversion + circular ROI mask.

    Parameters
    ----------
    image : pims.Frame
        Frame of the video (3-channel image, converted with cv2.COLOR_BGR2GRAY).
    x1 : int
        x coordinate of the top left corner of the bounding box of the circular
        ROI (region of interest).
    y1 : int
        y coordinate of the top left corner of the ROI bounding box.
    x2 : int
        x coordinate of the bottom right corner of the ROI bounding box.
    y2 : int
        y coordinate of the bottom right corner of the ROI bounding box.
    sharpen : bool, optional
        If True, apply a 3x3 sharpening kernel to the result. Defaults to False,
        which reproduces the previous output (the sharpened image used to be
        computed and then discarded).

    Returns
    -------
    npImage : np.ndarray
        2-D image with the same height and width as `image`. Pixels that are 0
        after masking (everything outside the ROI) are set to the value found at
        row 200, column 200.
    """
    npImage = np.array(image)
    height, width = npImage.shape[:2]
    # Build the mask with the frame's own size (PIL expects (width, height))
    # instead of the previously hard-coded (920, 960).
    alpha = Image.new('L', (width, height), 0)
    # A 0-360 degree pie slice is the full ellipse inscribed in the bounding box.
    ImageDraw.Draw(alpha).pieslice(((x1, y1), (x2, y2)), 0, 360, fill=255)
    npAlpha = np.array(alpha)
    # NOTE(review): for uint8 frames both operands are uint8, so this product
    # wraps modulo 256: inside the ROI a gray value g becomes (256 - g) % 256
    # (an inverted image), outside it becomes 0 — it is NOT a plain 0/1 mask.
    # Kept unchanged because the cached stacks and StarDist results depend on
    # it; confirm whether the inversion is intended.
    npImage = cv2.cvtColor(npImage, cv2.COLOR_BGR2GRAY)*npAlpha
    # Fill the zeros (the masked-out border) with a background sample, presumably
    # so the ROI edge does not show up as an artificial high-contrast ring.
    npImage[npImage == 0] = npImage[200, 200]
    if sharpen:
        # sharpen image https://en.wikipedia.org/wiki/Kernel_(image_processing)
        kernel = np.array([[0, -1, 0],
                           [-1, 5, -1],
                           [0, -1, 0]])
        npImage = cv2.filter2D(src=npImage, ddepth=-1, kernel=kernel)
    return npImage

In [None]:
# Input video and bounding box (x1, y1, x2, y2) [px] of the circular ROI kept by
# `preprocessing`.
VIDEO_PATH = './data/movie.mp4'
ROI_BBOX = (40, 55, 895, 910)
data = preprocessing(pims.open(VIDEO_PATH), *ROI_BBOX)
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
ax.imshow(data[0], cmap='gray')
ax.set(title='Preprocessed frame 0', xlabel='x [px]', ylabel='y [px]')
plt.show()

In [None]:
# Directory holding every cached artefact of this analysis. The trailing slash
# is required: file names are appended to it with `+`.
RESULTS_SUBDIR = 'sharp_post'
path = f'./stardist_res/{RESULTS_SUBDIR}/'

In [None]:
RECOMPUTE_PREPROCESSING = False  # True: rebuild the stack from `data` instead of loading it
SAVE_PREPROCESSED = False        # True: also write the rebuilt stack to disk
N_PREPROCESS_FRAMES = 30000      # upper bound on the number of frames to preprocess
if RECOMPUTE_PREPROCESSING:
    # Never request more frames than the video has (data[i] would raise).
    n_frames = min(N_PREPROCESS_FRAMES, len(data))
    first_img = data[0]  # evaluate the pipeline once to get shape and dtype
    preprocessed_data = np.zeros((n_frames, *first_img.shape), dtype=first_img.dtype)
    for i in tqdm(range(n_frames)):
        preprocessed_data[i] = data[i]
    if SAVE_PREPROCESSED:
        # NOTE(review): the else-branch loads 'preprocessed_post_merge.npz', not
        # this file, so a freshly saved stack is not picked up automatically.
        np.savez_compressed(path + 'preprocessed_data.npz', data=preprocessed_data)  # --> 15 min
else:
    preprocessed_data = np.load(path + 'preprocessed_post_merge.npz')['data']  # --> 3 min

# Expected number of particles per frame; reference value for the
# detection-count checks further down.
correct_n = 49

In [None]:
## TEST: segment a single preprocessed frame and inspect the StarDist output
frame = -1  # index into preprocessed_data (-1 = last frame)
img = preprocessed_data[frame]
labels_test, dict_test = model.predict_instances(normalize(img), predict_kwargs={'verbose': False})
coord, points, prob = dict_test['coord'], dict_test['points'], dict_test['prob']

# 1) Label image with the predicted centres; points are plotted as (col, row).
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(labels_test)
ax.scatter(points[:, 1], points[:, 0], c='r', s=5)
plt.show()

# 2) Side by side: preprocessed frame vs. StarDist polygons drawn over it.
fig_cmp = plt.figure(figsize=(10, 5))
ax = fig_cmp.add_subplot(1, 2, 1)
ax.imshow(img, cmap='gray')
ax.set(title='Preprocessed Image', xlabel='x', ylabel='y')
ax1 = fig_cmp.add_subplot(1, 2, 2, sharex=ax, sharey=ax)
ax1.imshow(img, cmap='gray')
# _draw_polygons takes no axes argument, so it presumably draws on the current axes.
plt.sca(ax1)
_draw_polygons(coord, points, prob, show_dist=True)
ax1.set(title='Stardist result', xlabel='x', ylabel='y')
plt.tight_layout()
plt.savefig(path + 'stardist_test.png', bbox_inches='tight')
plt.show()

In [None]:
run = False  # True: re-segment every frame with StarDist (slow) and rebuild the caches
if run:
    ## SEGMENT ALL FRAMES AND SAVE THEM IN A NPZ FILE
    ## COMPUTE THE FEATURES AND SAVE THEM IN A DATAFRAME
    nFrames = min(10000, len(data))
    frame_shape = data[0].shape
    # int8 keeps the label stack small but can only hold labels <= 127.
    segm_preload = np.zeros((nFrames, *frame_shape), dtype=np.int8)
    label_max = np.iinfo(segm_preload.dtype).max
    frames, x, y, r, area, prob = [], [], [], [], [], []

    for frame in tqdm(range(nFrames)):
        labels, dict_test = model.predict_instances(normalize(data[frame]), predict_kwargs = {'verbose':False})
        if labels.max() > label_max:
            raise ValueError(f'frame {frame}: label {labels.max()} does not fit in {segm_preload.dtype}')
        segm_preload[frame] = labels
        props = skimage.measure.regionprops_table(labels, properties=('label', 'centroid', 'area'))
        n_obj = len(props['label'])  # one entry per region (len(props) counts columns)
        frames += [frame] * n_obj
        # centroid-0 is the row and centroid-1 the column; downstream cells draw
        # circles at (x, y) in image coordinates, so x must be the column.
        x += list(props['centroid-1'])
        y += list(props['centroid-0'])
        area += list(props['area'])
        # Radius of the circle with the same area (assumed to match the `r` column
        # of the cached df.parquet read below — TODO confirm).
        r += list(np.sqrt(props['area'] / np.pi))
        # NOTE(review): assumes StarDist gives object k the label k + 1, i.e. label L
        # belongs to dict_test['prob'][L - 1]; this keeps prob aligned with the region
        # rows even if a region is missing from the label image. Confirm for the
        # installed stardist version.
        prob += list(dict_test['prob'][props['label'] - 1])

    df = pd.DataFrame({'frame': frames, 'x': x, 'y': y, 'r': r, 'area': area, 'prob': prob})
    # Written where the next cell reads it.
    df.to_parquet(path + 'df.parquet')
    print(df)

    # Save the labeled elements under the file name/key the else-branch loads.
    np.savez_compressed(path + 'segm.npz', labeled_elements=segm_preload)
else:
    labeled_elements = np.load(path + 'segm.npz')['labeled_elements'] #'data'
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(labeled_elements[-1], cmap='gray')
    ax.set(title='Labels of the last segmented frame', xlabel='x', ylabel='y')
    plt.show()

In [None]:
df = pd.read_parquet(path + 'df.parquet')
# Keep only detections with radius in [15, 30] px (presumably the plausible
# particle size range).
df = df.loc[df.r.between(15, 30)]
detections_per_frame = df.groupby('frame').size()
print("frames:", len(detections_per_frame))
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
# Plot against the groupby index so x and y are paired by construction.
ax.plot(detections_per_frame.index, detections_per_frame.values)
ax.axhline(correct_n, color='r')
ax.set(title='Detections per frame (red: expected count)', xlabel='frame', ylabel='detections')
plt.show()

# Keep at most `correct_n` detections per frame: the most probable ones.
df = df.groupby('frame').apply(lambda g: g.nlargest(correct_n, 'prob'))
df = df.reset_index(drop=True)
display(df)

In [None]:
# Frames whose detection count differs from `correct_n`. Counts are looked up by
# frame number, not by position, so gaps in the frame sequence do not shift the
# result. Frames left with no detection at all are counted as 0 and reported.
n_per_frame = (df.groupby('frame').size()
               .reindex(np.arange(df.frame.min(), df.frame.max() + 1), fill_value=0))
err_frames = n_per_frame.index[n_per_frame != correct_n].to_numpy()
print(len(err_frames))

In [None]:
# Overview of the filtered detections.
n_per_frame = df.groupby('frame').size()
fig, ax = plt.subplots(2, 2, figsize=(10, 6))
ax[0, 0].plot(n_per_frame.index, n_per_frame.values)
ax[0, 0].set(title='Detections per frame', xlabel='frame', ylabel='count')
ax[0, 1].plot(df.r, '.')
ax[0, 1].set(title='Radius of each detection', xlabel='row', ylabel='r [px]')
ax[1, 0].hist(df.area, bins=100, density=True)
ax[1, 0].set(title='Area distribution', xlabel='area', ylabel='density')
ax[1, 1].scatter(df.r, df.prob, s=0.1)
ax[1, 1].set(title='Probability vs radius', xlabel='r [px]', ylabel='prob')
plt.tight_layout()
plt.show()

In [None]:
R_CHECK = 24  # [px] frames containing a detection larger than this get a visual check
mmmeh_frames = df.loc[df.r > R_CHECK].frame.unique()
print(mmmeh_frames)
if len(mmmeh_frames) == 0:
    print('No frame with r > {} px'.format(R_CHECK))
else:
    frame = mmmeh_frames[0]
    df_plot = df.loc[df.frame == frame]

    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    # Assumes preprocessed_data[0] is the first frame present in df; int() guards
    # against a float `frame` column.
    ax.imshow(preprocessed_data[int(frame - df.frame.min())])
    for xc, yc, rc in zip(df_plot.x.values, df_plot.y.values, df_plot.r.values):
        ax.add_artist(plt.Circle((xc, yc), rc, color='r', fill=False))
    # Only this figure's axes are labelled (the old `ax1.set` hit a stale axes
    # belonging to an earlier cell's figure).
    ax.set(title='Labeled elements on preprocessed image - frame {}'.format(frame), xlabel='x', ylabel='y')
    plt.show()

In [None]:
#############################################################################################################
#                                         LINK FEATURES WITH TRACKPY                                        #
#############################################################################################################
RELINK = True  # False: load the trajectories cached by a previous run
if RELINK:
    # search_range = 150 px, particles may vanish for up to 3 frames; adaptive_stop
    # lets trackpy retry oversize subnetworks with a shrinking search range.
    t = tp.link_df(df, 150, memory=3, link_strategy='hybrid',
                   neighbor_strategy='KDTree', adaptive_stop=1)
    t = t.sort_values(['frame', 'particle'])

    # CREATE COLOR COLUMN AND SAVE DF
    # One reproducible random hex colour per particle id in [0, n).
    n = max(t.particle)
    print(n)
    random.seed(5)
    hex_digits = '0123456789ABCDEF'
    colors = []
    for _particle_id in range(n):
        colors.append('#' + ''.join(random.choice(hex_digits) for _digit in range(6)))
    # NOTE(review): padding to max(t.particle) + 1 entries always adds exactly one
    # cyan colour here (n == the largest id), so only the highest particle id is
    # drawn cyan. If the intent was "ids beyond the expected count are cyan",
    # n should probably be correct_n — confirm.
    colors += ['#00FFFF'] * (max(t.particle) + 1 - n)
    c = [colors[p] for p in t.particle]
    t['color'] = c
    trajectory = t.copy()
    print(trajectory)
    trajectory.to_parquet(path + 'df_linked.parquet')
else:
    trajectory = pd.read_parquet(path + 'df_linked.parquet')
    print(trajectory)

In [None]:
fig = plt.figure(figsize = (5, 5))
anim_running = True
SAVE_ANIMATION = True  # write the animation to path + 'tracking.mp4'

# Frame range, and the largest number of detections in any single frame, so that
# enough circle artists exist for every frame of the animation.
frame0 = int(trajectory.frame.min())
frame_last = int(trajectory.frame.max())
n_circles = int(trajectory.groupby('frame').size().max())

def onClick(event):
    """Pause / resume the animation on any mouse click in the figure."""
    global anim_running
    if anim_running:
        ani.event_source.stop()
    else:
        ani.event_source.start()
    anim_running = not anim_running

def update_graph(frame):
    """Draw `frame`: move and recolour one circle per detection, hide unused
    circles and swap the background image. Returns the circle artists."""
    # Local name: do not clobber the notebook-level `df`.
    df_frame = trajectory.loc[trajectory.frame == frame, ["x", "y", "color", "r"]]
    for circle, x, y, r, color in zip(graph, df_frame.x.values, df_frame.y.values,
                                      df_frame.r.values, df_frame.color.values):
        circle.center = (x, y)
        circle.radius = r
        # Rows are sorted by particle id, so the i-th row is not always the same
        # particle from frame to frame: the colour must follow the row.
        circle.set_edgecolor(color)
        circle.set_visible(True)
    # Frames with fewer detections than artists: hide the leftovers instead of
    # leaving them frozen at their previous position.
    for circle in graph[len(df_frame):]:
        circle.set_visible(False)
    # Assumes preprocessed_data[0] corresponds to the first tracked frame.
    graph2.set_data(preprocessed_data[frame - frame0])
    title.set_text('Tracking stardist + trackpy - frame = {}'.format(frame))
    return graph

ax = fig.add_subplot(111)
title = ax.set_title('')
ax.set(xlabel = 'X [px]', ylabel = 'Y [px]')
graph = [ax.add_artist(plt.Circle((0, 0), 1, fill = False, linewidth=1, visible=False))
         for _ in range(n_circles)]
graph2 = ax.imshow(preprocessed_data[0])
update_graph(frame0)  # show the first frame in the static figure

fig.canvas.mpl_connect('button_press_event', onClick)
# frame_last + 1: range() excludes its stop value, so the last frame is kept.
ani = matplotlib.animation.FuncAnimation(fig, update_graph, range(frame0, frame_last + 1, 1), interval = 5, blit=False)
if SAVE_ANIMATION:
    writer = matplotlib.animation.FFMpegWriter(fps = 30, metadata = dict(artist='Matteo Scandola'), extra_args=['-vcodec', 'libx264'])
    ani.save(path + 'tracking.mp4', writer=writer, dpi = 300)
plt.close(fig)