In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage.io as iio
from PIL import Image
import torch

import byotrack
from byotrack.implementation.detector.stardist import StarDistDetector
from byotrack.implementation.linker.icy_emht import IcyEMHTLinker
from byotrack.implementation.refiner.cleaner import Cleaner
from byotrack.implementation.refiner.stitching.emc2 import EMC2Stitcher
from byotrack.implementation.refiner.interpolater import ForwardBackwardInterpolater
from byotrack import Track

# Input paths. Set the ICY_PATH / TIF_PATH environment variables to override them
# instead of editing this cell (the defaults are the original machine-specific paths).
icy_path = os.environ.get('ICY_PATH', '/home/noah/Documents/icy-2.4.2.0-all/icy.jar')  # Icy jar used by IcyEMHTLinker
tifpath = os.environ.get('TIF_PATH', './ExampleData/shortStack_adjusted/')  # folder containing the sequence of .tif frames


In [None]:
def Read_Data_TIFseq(vid_path):
    # positions = (pd.read_csv(csv_path,usecols=['TrackID','t','x','y'])).values
    vid = iio.ImageCollection(vid_path + '/*.tif').concatenate() #concatonate to numpyarray
    # red_vid = iio.ImageCollection(red_vid_path + '/*.tif')
    vid = vid.reshape(vid.shape[0], vid.shape[1], vid.shape[2], 1)
    #vid = np.asarray([csbdeepNormaliser(frame) for frame in vid])
    return vid

video = Read_Data_TIFseq(tifpath)

# Normalise intensities to [0, 1] using robust percentiles (0.5 % and 99.9 %), so a few
# extreme pixels do not dominate the contrast.
mini = np.quantile(video, 0.005)
maxi = np.quantile(video, 0.999)

# Convert to float *before* clipping. The previous in-place `np.clip(video, mini, maxi, video)`
# wrote float results back into the TIFF dtype (typically integer). NumPy deprecates that
# unsafe cast: it warns or errors depending on the version.
video = np.clip(video.astype(np.float64), mini, maxi)
# Guard against a constant video (maxi == mini), which would otherwise divide by zero.
video = (video - mini) / max(maxi - mini, np.finfo(np.float64).eps)

In [None]:
# Load the per-frame detections saved by an earlier detection run.
# NOTE: allow_pickle=True unpickles arbitrary Python objects - only load files you created yourself.
detections_file = '/home/noah/Documents/NoahT2022/CodeRepos/Utopia/ExampleData/shortStack_adjusted/detections.npy'
detections_sequence = np.load(detections_file, allow_pickle=True)

# NOTE(review): positions appear to be stored as (row, col), i.e. (y, x), not (x, y).
# The display loop further down unpacks each point as `i, j` and draws at `(j, i)`,
# because OpenCV expects (x, y). Swap the axes at plotting time rather than rewriting the detections.


In [None]:
# Sanity check: reverse the first frame's detection positions along dim 1
# (presumably the coordinate axis) to inspect the (row, col) vs (x, y) ordering.
first_frame_positions = detections_sequence[0].position
flip = first_frame_positions.flip(1)
print(flip)

In [None]:
# Scratch check: numpy arrays have no `.flip()` method (that is torch's Tensor.flip),
# so `arr.flip()` raised AttributeError. Use np.flip instead. arange rather than ones,
# so the reversal is actually visible.
arr = np.arange(2.0)
arr_flipped = np.flip(arr)
arr_flipped

In [None]:
# Link the per-frame detections into tracklets with Icy's EMHT linker, using a Brownian motion model.
# NOTE(review): run() receives the video as well as the detections. Check byotrack's docs for
# whether this linker actually reads the frames.
linker = IcyEMHTLinker(icy_path)
linker.motion = linker.Motion.BROWNIAN

tracklets = linker.run(video, detections_sequence)

In [None]:
def display_lifetime(tracks):
    # Transform into tensor
    tracks_tensor = byotrack.Track.tensorize(tracks)
    print(tracks_tensor.shape)  # N_frame x N_track x D

    mask =  ~ torch.isnan(tracks_tensor).any(dim=2)

    plt.figure(figsize=(24, 16), dpi=100)
    plt.xlabel("Track id")
    plt.ylabel("Frame")
    plt.imshow(mask)
    plt.show()
    
# Lifetimes of the raw tracklets, straight out of the linker.
display_lifetime(tracks=tracklets)

In [None]:
# Post-process the raw tracklets: clean them, then stitch fragments that belong together.
# NOTE(review): see byotrack's Cleaner docs for the exact meaning of min_length / max_dist.
cleaner = Cleaner(min_length=5, max_dist=3.5)
stitcher = EMC2Stitcher(eta=5.0)  # never stitch tracks whose EMC distance exceeds 5 pixels

tracks_clean = cleaner.run(video, tracklets)
tracks_stitched = stitcher.run(video, tracks_clean)

# Lifetimes after stitching, for comparison with the raw-tracklet plot.
display_lifetime(tracks_stitched)

In [None]:
# Filter tracks and interpolate.

# Keep only tracks that span at least MIN_COVERAGE of the video's frames.
# `>=` matches the stated "at least 80%" (the previous `>` dropped tracks exactly at the limit).
MIN_COVERAGE = 0.80
n_frames = len(video)
valid_tracks = [len(track) >= MIN_COVERAGE * n_frames for track in tracks_stitched]

interpolater = ForwardBackwardInterpolater(method="tps", full=True, alpha=10.0)
# Interpolate using all tracks, and filter afterwards.
interpolated_tracks = interpolater.run(video, tracks_stitched)

# The filter is positional, so it relies on run() returning one track per input, in order.
# TODO confirm against byotrack's docs.
assert len(interpolated_tracks) == len(valid_tracks), "interpolater changed the number of tracks"
final_tracks = [track for track, keep in zip(interpolated_tracks, valid_tracks) if keep]

In [None]:
# Interactive viewer: scrub through the frames with a trackbar, with the final tracks drawn on top.
# Press Q (with the window focused) to quit.

vidCopy = video  # an alias, not a copy: frames are only read below, never modified
scale = 1        # window magnification factor

frameID = 0
h, w = vidCopy[frameID].shape[0:2]

display_detections = True  # overlay the detection segmentation mask in the blue channel

window_name = 'Display Tracks   (Press Q to Quit)'


def _render_frame(frame_id):
    """Return the BGR uint8 image for ``frame_id``, with the detection mask and track markers drawn on it."""
    # Assumes the video was normalised to [0, 1]; frames carry a trailing channel axis of size 1.
    gray = (vidCopy[frame_id] * 255).astype(np.uint8)
    blank = np.zeros_like(gray)

    # Bound by the number of detection frames; the previous check compared against len(final_tracks).
    if display_detections and frame_id < len(detections_sequence):
        mask = (detections_sequence[frame_id].segmentation.numpy() != 0).astype(np.uint8) * 255
        image = np.concatenate((mask[..., None], gray, blank), axis=2)  # B = mask, G = video
    else:
        image = np.concatenate((blank, gray, blank), axis=2)

    color = (0, 0, 255)  # red in OpenCV's BGR order
    for track in final_tracks:
        point = track[frame_id]
        if torch.isnan(point).any():  # track not defined on this frame
            continue
        i, j = point.round().to(torch.int).tolist()
        # Points are (row, col); OpenCV drawing expects (x, y) = (col, row).
        cv2.circle(image, (j, i), 5, color)
        cv2.putText(image, str(track.identifier % 100), (j + 4, i - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color)
    return image


def update_frame(x):
    """Trackbar callback (required by OpenCV; unused because the loop polls the position)."""
    pass


try:
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    # resizeWindow takes (width, height); the previous code passed (h, w).
    cv2.resizeWindow(window_name, w * scale, h * scale)
    cv2.createTrackbar('Frame', window_name, 0, len(vidCopy) - 1, update_frame)
except Exception as e:
    print(e)

while True:
    try:
        frameID = cv2.getTrackbarPos('Frame', window_name)
        frame = _render_frame(frameID)
        cv2.imshow(window_name, frame)

        # Mask to the low byte: on some platforms waitKey returns extra modifier bits.
        if (cv2.waitKey(5) & 0xFF) == ord('q'):
            break
    except Exception as e:
        # Best effort: report the error (e.g. the window was closed) and leave the loop.
        print(e)
        break

cv2.destroyAllWindows()



In [None]:
# TODO: edit tracks interactively (OpenCV mouse-click callbacks)
# TODO: visualise track lifetimes interactively

In [None]:
# Save the final tracks as one dense array (Track.tensorize layout: N_frame x N_track x D).
# np.save appends ".npy" to the path because it has no extension.
save_path_numpy = '/home/noah/Documents/NoahT2022/CodeRepos/Utopia/ExampleData/shortStack_adjusted/tracks'
tensorpoints = Track.tensorize(final_tracks)
tracks_array = tensorpoints.numpy()
np.save(save_path_numpy, tracks_array, allow_pickle=True)