In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np #can switch this out for pytorch at somepoint - notation is identical
import skimage.io as iio
from PIL import Image

from byotrack import Video, VideoTransformConfig
from byotrack.video.transforms import ChannelSelect, ChannelAvg, ScaleAndNormalize


# NOTE(review): hardcoded absolute local paths — adjust for your machine (consider a configurable data directory).
icy_path = "/home/noah/Documents/icy-2.4.2.0-all"  # Icy installation directory, passed to IcyEMHTLinker
tifpath = '/home/noah/Desktop/cellsegtest/segTestNew/shortStack_adjusted' # directory containing the sequence of .tif frames

# Playback rate (frames per second) for the first preview cell only; it sets the waitKey delay.
# The later playback cells overwrite it with their own value.
fps = 7

In [None]:
# Reader function for TIF sequences: stacks the frames, adds a trailing channel axis and
# normalises each frame with the StarDist-recommended csbdeep normaliser
# (doesn't seem to perform that well).
from csbdeep.utils import normalize as csbdeepNormaliser


def Read_Data_TIFseq(vid_path, pattern="*.tif"):
    """Load a directory of TIF frames as a per-frame normalised (T, H, W, 1) array.

    Args:
        vid_path: Directory containing the frames.
        pattern: Glob pattern, relative to ``vid_path``, that selects the frames.
            Defaults to ``"*.tif"`` (the original behaviour); pass e.g. ``"*.tiff"``
            for other extensions.

    Returns:
        np.ndarray with one entry per frame. Each frame is normalised independently
        by ``csbdeepNormaliser`` called with its default arguments.
        NOTE(review): nothing here clips the result, so values may fall outside [0, 1].

    Raises:
        FileNotFoundError: If no file in ``vid_path`` matches ``pattern``.
    """
    collection = iio.ImageCollection(os.path.join(vid_path, pattern))
    if len(collection) == 0:
        # Fail early with a clear message rather than relying on how
        # concatenate() behaves on an empty collection.
        raise FileNotFoundError(f"No frames matching {pattern!r} found in {vid_path!r}")
    vid = collection.concatenate()  # concatenate to one numpy array, frames along axis 0
    # Add a trailing singleton channel axis (assumes single-channel frames, as before).
    vid = vid.reshape(vid.shape[0], vid.shape[1], vid.shape[2], 1)
    return np.asarray([csbdeepNormaliser(frame) for frame in vid])


video = Read_Data_TIFseq(tifpath)

In [None]:
# Visualisation of the whole (normalised) video with OpenCV.
# Press Q to quit; closing the window also stops playback.

try:
    for i, frame in enumerate(video):
        # Display the resulting frame
        cv2.imshow('Frame', frame)
        cv2.setWindowTitle('Frame', f'Frame {i} / {len(video)}')

        # Press Q on keyboard to exit
        key = cv2.waitKey(1000 // fps) & 0xFF
        if key == ord('q'):
            break

        # Stop once the user has closed the window
        if cv2.getWindowProperty("Frame", cv2.WND_PROP_VISIBLE) < 1:
            break
except Exception as e:
    # Report the failure once and stop, instead of printing it and recreating the window on every frame.
    print(e)
finally:
    # Close all the frames, even if displaying failed
    cv2.destroyAllWindows()
# Warning with opencv QObject::moveToThread can be fixed: https://stackoverflow.com/questions/52337870/python-opencv-error-current-thread-is-not-the-objects-thread

In [None]:
import numpy as np
import torch

from byotrack.implementation.detector.stardist import StarDistDetector
from byotrack.implementation.detector.wavelet import WaveletDetector

In [None]:
# Build the StarDist detector from a trained model directory.
# NOTE(review): absolute local path — adjust per machine.
model_path = "/home/noah/Desktop/STARDIST_CONFOCAL/NEWEST TdT MODELS/10X_IMAGES_ONLY"
detector = StarDistDetector(
    model_path,
    batch_size=5,  # TODO: figure out batch size
)

In [None]:

# Interactive parameter test: shows the StarDist segmentation mask of one test frame and lets
# the user tune the probability and overlap (NMS) thresholds with trackbars.
# TODO: add ROI functionality so plotting shows the mask over the image (better evaluation).
# This will need edits to both the StarDist detector and the detector class .detect() function
# to include polygon data, plus ImageJ ROI code.
vidCopy = video[0:50]  # test batch: first 50 frames
scale = 1  # window size multiplier

window_name = 'Paramater Test - Segmentation   (Press Q to Quit)'

frameID = 0  # current frame; read by the threshold callbacks, updated from the 'Frame' trackbar
h, w = video[frameID].shape[0:2]

# Initial mask: segment the first test frame with the detector's current thresholds.
detection_zero = detector.run([vidCopy[frameID]])
mask_glob = (detection_zero[0].segmentation.numpy() != 0).astype(np.uint8) * 255


def update_frame(x):
    """Trackbar callback: re-segment test frame ``x`` and store its binary mask in ``mask_glob``."""
    global mask_glob
    detections = detector.detect(vidCopy[x][None, ...])
    mask_glob = (detections[0].segmentation.numpy() != 0).astype(np.uint8) * 255


def update_probability_threshold(x):
    """Trackbar callback: set the detector probability threshold to (x + 1) / 100, then refresh the mask."""
    detector.prob_threshold = (x + 1) / 100
    update_frame(frameID)


def update_overlap_threshold(x):
    """Trackbar callback: set the detector overlap (NMS) threshold to (x + 1) / 100, then refresh the mask."""
    detector.nms_threshold = (x + 1) / 100
    update_frame(frameID)


try:
    # Create and rescale the window; cv2.resizeWindow takes (width, height).
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(window_name, w * scale, h * scale)

    cv2.createTrackbar('Frame', window_name, 0, len(vidCopy) - 1, update_frame)
    cv2.createTrackbar('Probability Threshold', window_name, 0, 99, update_probability_threshold)
    cv2.createTrackbar('Overlap Threshold', window_name, 0, 99, update_overlap_threshold)

    while True:
        frameID = cv2.getTrackbarPos('Frame', window_name)
        cv2.imshow(window_name, mask_glob)

        # Press q to terminate the loop
        if (cv2.waitKey(5) & 0xFF) == ord('q'):
            break
        # Also stop once the window has been closed
        if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
            break
except Exception as e:
    print(e)
finally:
    cv2.destroyAllWindows()

# Report the thresholds the detector actually holds. The trackbars map position p to (p + 1) / 100,
# and only once moved, so reading raw trackbar positions would be off by 0.01 (or wrong if untouched).
print('Prob: ', getattr(detector, 'prob_threshold', None))
print('nms: ', getattr(detector, 'nms_threshold', None))

In [None]:
# Run detections on every frame of the normalised video.
# Uses whatever thresholds the detector currently holds, including any values set
# through the trackbars of the interactive parameter-test cell.
detections_sequence = detector.run(video)

In [None]:
from byotrack import Track
from byotrack.implementation.linker.icy_emht import IcyEMHTLinker

In [None]:
# EMHT linker, presumably backed by the Icy installation found at icy_path.
linker = IcyEMHTLinker(icy_path)
linker.motion = linker.Motion.BROWNIAN  # Already by default

In [None]:
# Link the per-frame detections into tracks (an iterable of byotrack Track objects).
tracks = linker.run(video, detections_sequence)

In [None]:
# Visualize tracks existence in time: one row per frame, one column per track,
# lit where the track has a (non-NaN) position in that frame.

# Transform into tensor
tracks_tensor = Track.tensorize(tracks)
print(tracks_tensor.shape)  # N_frame x N_track x D

mask = ~torch.isnan(tracks_tensor).any(dim=2)

fig, ax = plt.subplots(figsize=(24, 16), dpi=100)
# aspect="auto" fills the figure; with the default equal aspect, a mask whose frame and
# track counts differ a lot renders as a thin sliver.
ax.imshow(mask.numpy(), aspect="auto", interpolation="nearest")
ax.set(title="Track existence over time (linker output)", xlabel="Track id", ylabel="Frame")
plt.show()

In [None]:
# Interactive playback of the video with the linked tracks overlaid.
# Keys: space = play/pause, w / x = previous / next frame (when paused),
# c = toggle detection masks (blue channel), q = quit. The video is drawn in the green channel.
fps = 20
running = False
display_detections = False

frame_id = 0

try:
    while True:
        # Advance one frame per iteration while playing, wrapping at the end
        # instead of indexing past the last frame.
        frame_id = (frame_id + running) % len(video)

        # Nothing guarantees the normalised frames lie in [0, 1]; clip before the uint8
        # cast so out-of-range intensities saturate instead of wrapping around.
        frame = (np.clip(video[frame_id], 0, 1) * 255).astype(np.uint8)
        if display_detections and frame_id < len(detections_sequence):
            mask = (detections_sequence[frame_id].segmentation.numpy() != 0).astype(np.uint8) * 255
            frame = np.concatenate((mask[..., None], frame, np.zeros_like(frame)), axis=2)
        else:
            frame = np.concatenate((np.zeros_like(frame), frame, np.zeros_like(frame)), axis=2)

        # Add tracklets
        color = (0, 0, 255)  # Red (OpenCV uses BGR order)
        for track in tracks:
            point = track[frame_id]
            if torch.isnan(point).any():
                continue  # track has no position in this frame

            x, y = point.round().to(torch.int).tolist()
            cv2.circle(frame, (x, y), 5, color)
            cv2.putText(frame, str(track.identifier % 10), (x + 4, y - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color)

        # Display the resulting frame
        cv2.imshow('Frame', frame)
        cv2.setWindowTitle('Frame', f'Frame {frame_id} / {len(video)}')

        # Press Q on keyboard to exit
        key = cv2.waitKey(1000 // fps) & 0xFF

        if key == ord('q'):
            break

        if cv2.getWindowProperty("Frame", cv2.WND_PROP_VISIBLE) < 1:
            break

        if key == ord(" "):
            running = not running

        if not running and key == ord("w"):  # Prev
            frame_id = (frame_id - 1) % len(video)

        if not running and key == ord("x"):  # Next
            frame_id = (frame_id + 1) % len(video)

        if key == ord("c"):
            display_detections = not display_detections
finally:
    # Close all the frames, even if rendering failed
    cv2.destroyAllWindows()

In [None]:
from byotrack.implementation.refiner.cleaner import Cleaner
from byotrack.implementation.refiner.stitching import EMC2Stitcher

In [None]:
# Clean the raw tracks before stitching.
# NOTE(review): presumably min_length drops tracks shorter than 5 frames and max_dist splits
# tracks at jumps larger than 3.5 pixels — confirm against byotrack's Cleaner documentation.
cleaner = Cleaner(min_length=5, max_dist=3.5)
tracks_clean = cleaner.run(video, tracks)

In [None]:
# Visualize existence in time of the cleaned tracks: one row per frame, one column per track.
tracks_tensor = Track.tensorize(tracks_clean)
print(tracks_tensor.shape)  # N_frame x N_track x D

mask = ~torch.isnan(tracks_tensor).any(dim=2)

fig, ax = plt.subplots(figsize=(24, 16), dpi=100)
# aspect="auto" fills the figure instead of rendering a thin sliver when
# the frame and track counts differ a lot.
ax.imshow(mask.numpy(), aspect="auto", interpolation="nearest")
ax.set(title="Track existence over time (after cleaning)", xlabel="Track id", ylabel="Frame")
plt.show()

In [None]:
# Stitch track fragments using the EMC2 distance; don't link tracks if they are too far apart.
# NOTE(review): eta is presumably that EMC distance threshold in pixels. It is 10.0 here, but a
# previous version of this comment stated 5 pixels — confirm which value is intended.
stitcher = EMC2Stitcher(eta=10.0)
tracks_stitched = stitcher.run(video, tracks_clean)

In [None]:
# Visualize existence in time of the stitched tracks: one row per frame, one column per track,
# lit where the track has a (non-NaN) position in that frame.

# Transform into tensor
tracks_tensor = Track.tensorize(tracks_stitched)
print(tracks_tensor.shape)  # N_frame x N_track x D

mask = ~torch.isnan(tracks_tensor).any(dim=2)

fig, ax = plt.subplots(figsize=(24, 16), dpi=100)
# aspect="auto" fills the figure instead of rendering a thin sliver when
# the frame and track counts differ a lot.
ax.imshow(mask.numpy(), aspect="auto", interpolation="nearest")
ax.set(title="Track existence over time (after stitching)", xlabel="Track id", ylabel="Frame")
plt.show()

In [None]:
# Interactive playback of the video with the stitched tracks overlaid.
# Keys: space = play/pause, w / x = previous / next frame (when paused),
# c = toggle detection masks (blue channel), q = quit. The video is drawn in the green channel.
fps = 20
running = False
display_detections = False

frame_id = 0

try:
    while True:
        # Advance one frame per iteration while playing, wrapping at the end
        # instead of indexing past the last frame.
        frame_id = (frame_id + running) % len(video)

        # Nothing guarantees the normalised frames lie in [0, 1]; clip before the uint8
        # cast so out-of-range intensities saturate instead of wrapping around.
        frame = (np.clip(video[frame_id], 0, 1) * 255).astype(np.uint8)
        if display_detections and frame_id < len(detections_sequence):
            mask = (detections_sequence[frame_id].segmentation.numpy() != 0).astype(np.uint8) * 255
            frame = np.concatenate((mask[..., None], frame, np.zeros_like(frame)), axis=2)
        else:
            frame = np.concatenate((np.zeros_like(frame), frame, np.zeros_like(frame)), axis=2)

        # Add tracklets
        color = (0, 0, 255)  # Red (OpenCV uses BGR order)
        for track in tracks_stitched:
            point = track[frame_id]
            if torch.isnan(point).any():
                continue  # track has no position in this frame

            x, y = point.round().to(torch.int).tolist()
            cv2.circle(frame, (x, y), 5, color)
            cv2.putText(frame, str(track.identifier % 10), (x + 4, y - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color)

        # Display the resulting frame
        cv2.imshow('Frame', frame)
        cv2.setWindowTitle('Frame', f'Frame {frame_id} / {len(video)}')

        # Press Q on keyboard to exit
        key = cv2.waitKey(1000 // fps) & 0xFF

        if key == ord('q'):
            break

        if cv2.getWindowProperty("Frame", cv2.WND_PROP_VISIBLE) < 1:
            break

        if key == ord(" "):
            running = not running

        if not running and key == ord("w"):  # Prev
            frame_id = (frame_id - 1) % len(video)

        if not running and key == ord("x"):  # Next
            frame_id = (frame_id + 1) % len(video)

        if key == ord("c"):
            display_detections = not display_detections
finally:
    # Close all the frames, even if rendering failed
    cv2.destroyAllWindows()

In [None]:
# Convert the stitched tracks to a dataframe (and save it as csv in a later cell).
# Track.tensorize returns N_frame x N_track x D, so the dataframe has one row per frame and one
# column per track; each cell holds that track's D coordinates as a Python list (NaNs where the
# track has no position in that frame).

import pandas as pd
tensorpoints = Track.tensorize(tracks_stitched)
df = pd.DataFrame(data = tensorpoints.tolist()) # rows = videoFrame, columns = trackIndex

In [None]:
# Preview the first rows (frames) of the track table.
df.head()

In [None]:
# Save the dataframe for future use.
# NOTE: each cell holds a coordinate list, which CSV stores as a string and which pandas does
# not parse back into lists; parquet or the .npy export may be easier to reload.
savepath = 'testOutputs/dataframe.csv'
# to_csv fails if the output directory does not exist yet, so create it first.
os.makedirs(os.path.dirname(savepath) or ".", exist_ok=True)
df.to_csv(savepath)

In [None]:
# Save the track tensor (N_frame x N_track x D) as a numpy array.
savepath = 'testOutputs/tracks.npy'
# np.save fails if the output directory does not exist yet, so create it first.
os.makedirs(os.path.dirname(savepath) or ".", exist_ok=True)
array_points = tensorpoints.numpy()
np.save(savepath, array_points)