# UTOPIA example notebook

In [None]:
import enum
import math

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import tqdm

import byotrack
import byotrack.icy.io
from byotrack.implementation.refiner.interpolater import ForwardBackwardInterpolater
import byotrack.visualize

import schya.detrending
import schya.extraction
import schya.filtering
import schya.spike
import schya.visualize

TEST = False  # Set to True to keep only the first 100 frames of each video (quick experiments)

## Loading videos

In [None]:
# Placeholder paths: replace them with the paths to your own GCaMP and TdTomato videos
gcamp_path = "path/to/gcamp_video"
tdtomato_path = "path/to/tdtomato_video"

gcamp_video = byotrack.Video(gcamp_path)
tdtomato_video = byotrack.Video(tdtomato_path)

if TEST:
    # Keep only the first 100 frames of both videos to experiment quickly
    gcamp_video = gcamp_video[:100]
    tdtomato_video = tdtomato_video[:100]

# Normalize both videos into [0, 1] with the same quantile-based settings.
# A separate config object is built for each video so they never share state.
normalization_kwargs = dict(aggregate=True, normalize=True, q_min=0.02, q_max=0.995, smooth_clip=1.0)
gcamp_video.set_transform(byotrack.VideoTransformConfig(**normalization_kwargs))
tdtomato_video.set_transform(byotrack.VideoTransformConfig(**normalization_kwargs))

In [None]:
# Display both channels of the video together
# Use w/x to move in time (or space to run/pause the video)
# Use v to display none, green, red or both channels
two_color_visualizer = schya.visualize.TwoColorInteractiveVisualizer((tdtomato_video, gcamp_video))
two_color_visualizer.run()

## Track TdTomato

### Loading saved tracks

Avoids re-running the tracking pipeline each time.

In [None]:
# Reload tracks previously saved in the byotrack format (see the "Save tracks" cell)
saved_tracks_path = "tracks.pt"
tracks = byotrack.Track.load(saved_tracks_path)

In [None]:
# Alternatively, reload tracks stored in the Icy xml format
icy_tracks_path = "tracks_tdTomato.xml"
tracks = byotrack.icy.io.load_tracks(icy_tracks_path)

### Tracking pipeline

In [None]:
from byotrack.implementation.detector.stardist import StarDistDetector
from byotrack.implementation.linker.icy_emht import IcyEMHTLinker
from byotrack.implementation.refiner.cleaner import Cleaner
from byotrack.implementation.refiner.stitching.emc2 import EMC2Stitcher

#### Detections

In [None]:
# Folder of the trained StarDist model (placeholder: replace with your own)
model_path = "path/to/model_folder"

detector = StarDistDetector(
    model_path,
    batch_size=1,
)

In [None]:
# Run the StarDist detector on the TdTomato video
detections_sequence = detector.run(
    tdtomato_video
)

In [None]:
# Display the detections with opencv
# Use w/x to move in time (or space to run/pause the video)
# Use v to switch on/off the display of the video
# Use d to switch the detection display mode (None, mask, segmentation)
detections_visualizer = byotrack.visualize.InteractiveVisualizer(tdtomato_video, detections_sequence)
detections_visualizer.run()

#### Build tracklets

In [None]:
# Link the detections into tracklets with Icy's EMHT (requires the Icy jar)
icy_path = "path/to/icy.jar"  # Placeholder: replace with your own Icy installation

linker = IcyEMHTLinker(icy_path)
linker.motion = linker.Motion.BROWNIAN  # Use the Brownian motion model
tracklets = linker.run(tdtomato_video, detections_sequence)

In [None]:
# Life span of each tracklet produced by the linker
byotrack.visualize.display_lifetime(tracklets)

#### Tracklet stitching

In [None]:
# Clean the tracklets, then stitch them together with EMC2
cleaner = Cleaner(min_length=5, max_dist=3.5)
cleaned_tracks = cleaner.run(tdtomato_video, tracklets)

# Tracks are never linked when their EMC distance exceeds 5 (pixels)
stitcher = EMC2Stitcher(eta=5.0)
tracks = stitcher.run(tdtomato_video, cleaned_tracks)

In [None]:
# Life span of the tracks after cleaning and stitching
byotrack.visualize.display_lifetime(tracks)

In [None]:
# Save the tracks in the byotrack format
# They can be reloaded later with byotrack.Track.load("tracks.pt")
tracks_output_path = "tracks.pt"
byotrack.Track.save(tracks, tracks_output_path)

### Tracks visualization

You can export the tracks to the Icy format and visualize them in Icy, or use our own visualization tools (or build new ones in Python).

In [None]:
# Export the tracks to the Icy xml format.
# Holes in the tracks have to be filled before saving: a forward/backward
# interpolation with constant values takes care of it.
interpolater = ForwardBackwardInterpolater(method="constant", full=False)
filled_tracks = interpolater.run(tdtomato_video, tracks)
byotrack.icy.io.save_tracks(filled_tracks, "test.xml")

In [None]:
# Life span of the current tracks
byotrack.visualize.display_lifetime(tracks)

In [None]:
# Display the tracks with opencv
# Use w/x to move in time (or space to run/pause the video)
# Use v (resp. t) to switch on/off the display of the video (resp. tracks)
#
# The detections_sequence object can also be given to the visualizer;
# then use d to switch the detection display mode (None, mask, segmentation)
tracks_visualizer = byotrack.visualize.InteractiveVisualizer(tdtomato_video, tracks=tracks)
tracks_visualizer.run()

## Calcium signal extraction

### Select long tracks and complete them

In [None]:
from byotrack.implementation.refiner.interpolater import ForwardBackwardInterpolater

# Keep only long enough tracks: a track must be longer than 80% of the video length
min_track_length = 0.80 * len(tdtomato_video)
valid_tracks = [len(track) > min_track_length for track in tracks]

# Interpolate using all the tracks, and filter afterwards
interpolater = ForwardBackwardInterpolater(method="tps", full=True, alpha=10.0)
interpolated_tracks = interpolater.run(tdtomato_video, tracks)
final_tracks = [track for index, track in enumerate(interpolated_tracks) if valid_tracks[index]]

print(f"Kept {len(final_tracks)} valid tracks from {len(tracks)} tracks")

### TdTomato control intensities extraction

In [None]:
# Extract the control intensities from the TdTomato video along the final tracks.
# NOTE(review): 9 is presumably the ROI size — confirm in schya.extraction.extract_intensities_from_roi
ctrl_roi_size = 9
ctrl_intensities = schya.extraction.extract_intensities_from_roi(tdtomato_video, final_tracks, ctrl_roi_size)

### GCaMP sub-ROI tracking and intensity extraction

In [None]:
# Track a sub-ROI in the GCaMP video around each final track and extract its intensities
sub_roi_size = 25  # Roi size (25 x 25)
max_relative_motion = 4  # Max relative motion between two consecutive frames

sub_roi_extractor = schya.extraction.SubRoiExtractor(
    gcamp_video,
    byotrack.Track.tensorize(final_tracks).numpy(),
    sub_roi_size,
    max_relative_motion,
)
raw_intensities, calcium_positions = sub_roi_extractor.compute()

In [None]:
# Build the calcium tracks for visualization: one track per final track, following the
# sub-ROI positions found in the GCaMP video and keeping the original track identifier.
# NOTE(review): the leading 0 is presumably the start frame of the track — confirm in byotrack.Track
calcium_tracks = [
    byotrack.Track(0, torch.tensor(calcium_positions[:, k]).to(torch.float32), track.identifier)
    for k, track in enumerate(final_tracks)
]

In [None]:
# Visualize the TdTomato tracks together with the calcium (GCaMP sub-ROI) tracks
# Use w/x to move in time (or space to run/pause the video)
# Use t to display none, green, red or both tracks
# Use v to display none, green, red or both channels
vis = schya.visualize.TwoColorInteractiveVisualizer(
    (tdtomato_video, gcamp_video),
    tracks=final_tracks,
    calcium_tracks=calcium_tracks,
)

vis.scale = 1  # Increase/decrease the size of the display
vis._display_video = 3  # Start on the GCaMP channel (private attribute of the visualizer)

vis.run()

### Filtering noise signals

In [None]:
# Detrend using frequency filtering only (see Detrending), as ICA creates some artefacts.
# The 1 / 200 cutoff drops periods larger than 200 frames.
detrended = schya.detrending.high_pass_filter(raw_intensities, 1 / 200)

# Then test the Gaussian (noise) hypothesis on each detrended signal: when it is
# rejected with a p-value below thresh, the signal is considered a real signal (not noise).
# Lower values of thresh => more signals classified as noise
thresh = 1e-5
is_noise = schya.filtering.is_noise(detrended, thresh)

print(f"Found {is_noise.sum()} noise signals")

In [None]:
# Switch to an interactive (non-inline) matplotlib backend (you may have to run this cell several times)
%matplotlib

In [None]:
# Interactive display of the noise filtering.
# Enables manual correction of the automatic filtering: clicking on a subplot flips
# the noise status of the displayed signal (is_noise is modified in place).

title = """Batch no {batch_id}/{MAX_BATCH}

Please use w/x to decrease/increase the batch id of signals displayed
Click on a signal to correct the filtering (Red signals are dropped, blue ones are kept)
"""


WIDTH = 4
HEIGHT = 4
BATCH_SIZE = WIDTH * HEIGHT
MAX_BATCH = math.ceil(len(raw_intensities) / BATCH_SIZE)

batch_id = 0

# squeeze=False: axs is always a 2D grid, whatever HEIGHT and WIDTH are
fig, axs = plt.subplots(HEIGHT, WIDTH, sharex="col", sharey="row", figsize=(20, 20), squeeze=False)
colors = ("b", "r")  # Kept signals in blue, rejected (noise) signals in red


def _signal_index(position):
    """Return the index of the signal shown at grid `position` (row-major) for the current batch.

    Wraps around the number of signals, so the last batch is always full.
    """
    return (batch_id * BATCH_SIZE + position) % len(is_noise)


def _draw_signal(ax, k):
    """Draw signal k on ax: raw intensity (shifted up by 0.4) and detrended intensity."""
    ax.clear()
    ax.set_title("Rejected" if is_noise[k] else "Kept")
    ax.plot(raw_intensities[k] + 0.4, label="Raw intensity")
    ax.plot(detrended[k], color=colors[int(is_noise[k])], label="Detrended intensity")
    ax.legend()


def plot():
    """Redraw the whole grid for the current batch."""
    # Batches are displayed 1-indexed so that the last one reads MAX_BATCH/MAX_BATCH
    fig.suptitle(title.format(batch_id=batch_id + 1, MAX_BATCH=MAX_BATCH))
    for position, ax in enumerate(axs.flat):
        _draw_signal(ax, _signal_index(position))


def on_click(event):
    """Switch the noise status of the signal whose subplot was clicked."""
    for position, ax in enumerate(axs.flat):
        if ax is event.inaxes:
            k = _signal_index(position)
            print(f"Manual switch of track {k}")
            is_noise[k] = not is_noise[k]

            # Redraw every subplot showing signal k: the last batch wraps around
            # and may display the same signal twice
            for other_position, other_ax in enumerate(axs.flat):
                if _signal_index(other_position) == k:
                    _draw_signal(other_ax, k)
            fig.canvas.draw_idle()
            return


def on_press(event):
    """Change the batch id: w goes to the previous batch, x to the next one."""
    global batch_id

    # Compare against a tuple: `event.key in "wx"` also matched "" and raised TypeError when key is None
    if event.key in ("w", "x"):
        step = 1 if event.key == "x" else -1
        batch_id = (batch_id + step) % MAX_BATCH
        plot()
        fig.canvas.draw_idle()


plot()

fig.canvas.mpl_connect("key_press_event", on_press)
fig.canvas.mpl_connect("button_press_event", on_click)

plt.show()

In [None]:
# Switch back to the inline matplotlib backend
%matplotlib inline

### Detrending & Smoothing

In [None]:
# ICA-based correction of the kept (non-noise) GCaMP intensities using the TdTomato control intensities.
# NOTE(review): the meaning of 0.5 and 10 is defined by schya.detrending.ica_decorr — confirm there
kept = ~is_noise
corrected_intensities = schya.detrending.ica_decorr(raw_intensities[kept], ctrl_intensities[kept], 0.5, 10)

In [None]:
# Additional independent detrending to remove the remaining baseline:
# components with a period larger than 100 frames are dropped
baseline_period = 100
detrended_intensities = schya.detrending.high_pass_filter(corrected_intensities, 1 / baseline_period)

In [None]:
# Rolling-average smoothing: window_size controls how much the signals are smoothed
window_size = 5
calcium_signals = schya.detrending.smooth(
    detrended_intensities,
    window_size,
)

In [None]:
# Example of a particular neuron + 16 others
# You can select another n_id or batch_id (both index the kept, non-noise signals)

n_id = 50
batch_id = 0

# Boolean-mask indexing builds a new array: do it once rather than once per curve
kept_raw = raw_intensities[~is_noise]
kept_ctrl = ctrl_intensities[~is_noise]

plt.plot(kept_raw[n_id] * 10 + 3, label="Raw")
plt.plot(kept_ctrl[n_id] * 10 + 3, label="Ctrl")
plt.plot(corrected_intensities[n_id], label="ICA corrected")
plt.plot(detrended_intensities[n_id] - 5, label="Detrended")
plt.plot(calcium_signals[n_id] - 5, label="Final smoothed & detrended signal")
plt.legend(loc="upper right", bbox_to_anchor=(1.65, 1))
plt.show()


fig, axs = plt.subplots(4, 4, sharex="col", sharey="row", figsize=(20, 20))

# Same curves for the 16 neurons of batch batch_id (wrapping around past the last neuron)
for position, ax in enumerate(axs.flat):
    k = (batch_id * axs.size + position) % len(corrected_intensities)
    ax.plot(kept_raw[k] * 10 + 3)
    ax.plot(kept_ctrl[k] * 10 + 3)
    ax.plot(corrected_intensities[k])
    ax.plot(detrended_intensities[k] - 5)
    ax.plot(calcium_signals[k] - 5)

plt.show()

### Spike extraction

In [None]:
# Run foopsi on every calcium signal: it returns the calcium signal reconstruction and the spikes
foopsi_outputs = schya.spike.foopsi_all(calcium_signals)
calcium_reconstruction, spikes = foopsi_outputs

In [None]:
# Clusterize spikes.
# std sets how far the kernel looks for neighboring spikes to aggregate with the current one
# (other values can be tried; 5 yields pretty good results)
cluster_std = 5.0
true_spikes = schya.spike.clusterize_spikes(spikes, std=cluster_std)

In [None]:
# Check how well the clustering worked on neuron n_id
# (clustered spikes are shifted up by 0.25 for readability)
ax = plt.gca()
ax.plot(spikes[n_id], label="Spikes")
ax.plot(true_spikes[n_id] + 0.25, label="Clustered")
ax.legend()
plt.show()

In [None]:
# Visualize the same neuron as before + 16 others
# You can select another batch_id, or set n_id here to look at another neuron

batch_id = 0

plt.plot(calcium_signals[n_id], label="Calcium signal")
plt.plot(calcium_reconstruction[n_id], label="Reconstruction")
plt.plot(spikes[n_id] * 7, label="All spikes")
plt.plot(true_spikes[n_id] * 7, label="Clustered spikes")
plt.legend()
plt.show()


fig, axs = plt.subplots(4, 4, sharex="col", sharey="row", figsize=(20, 20))

# Same curves for the 16 neurons of batch batch_id (wrapping around past the last neuron)
for position, ax in enumerate(axs.flat):
    k = (batch_id * axs.size + position) % len(calcium_signals)
    ax.plot(calcium_signals[k])
    ax.plot(calcium_reconstruction[k])
    ax.plot(spikes[k] * 7)
    ax.plot(true_spikes[k] * 7)

plt.show()

In [None]:
# Filter low spikes to keep only meaningful ones


def _plot_raster(binary_spikes, raster_title):
    """Show a raster plot of binary spikes (frames on x, neurons on y)."""
    plt.title(raster_title)
    plt.xlabel("Frames")
    plt.ylabel("Neurons")
    plt.eventplot(schya.spike.to_raster_pos(binary_spikes))
    plt.show()


# First plot all the clustered spikes
binarized_spikes = true_spikes > 0
_plot_raster(binarized_spikes, "All spikes")

# Then plot only the spikes kept by binarize_ratio
# NOTE(review): see schya.spike.binarize_ratio for the exact meaning of the 0.1 ratio
binarized_spikes = schya.spike.binarize_ratio(true_spikes, 0.1)
_plot_raster(binarized_spikes, "Kept spikes")

In [None]:
# Visualize the same neuron as before + 16 others, with the selected spikes
# You can select another batch_id, or set n_id here to look at another neuron

batch_id = 0

# Boolean-mask indexing builds a new array: do it once rather than once per curve
kept_raw = raw_intensities[~is_noise]

plt.plot(kept_raw[n_id] * 10 + 3, label="Raw intensities (*10)")
plt.plot(calcium_signals[n_id], label="Calcium signal")
plt.plot(binarized_spikes[n_id] - 1, label="Selected spikes")
plt.legend()
plt.show()


fig, axs = plt.subplots(4, 4, sharex="col", sharey="row", figsize=(20, 20))

# Same curves for the 16 neurons of batch batch_id (wrapping around past the last neuron)
for position, ax in enumerate(axs.flat):
    k = (batch_id * axs.size + position) % len(calcium_signals)
    ax.plot(kept_raw[k] * 10 + 3)
    ax.plot(calcium_signals[k])
    ax.plot(binarized_spikes[k] - 1)

plt.show()

In [None]:
# Visualize a particular neuron (n_id indexes the kept, non-noise signals)

n_id = 4

# Frames of the selected spikes, with a lag of 2 frames to be on the maximum of intensity.
# np.flatnonzero works for any video length (the former np.arange(1000) only worked for 1000 frames).
spiking_frames = np.flatnonzero(binarized_spikes[n_id] > 0) + 2
print("Spiking frame:", spiking_frames)


plt.plot(raw_intensities[~is_noise][n_id] * 10 + 3, label="Raw intensities (*10)")
plt.plot(calcium_signals[n_id], label="Calcium signal")
plt.plot(binarized_spikes[n_id] - 1, label="Selected spikes")
plt.plot(spikes[n_id] * 7, label="All spikes")
plt.legend()
plt.show()

# Interactive visualization of this neuron in the video
# Use w/x to move in time (or space to run/pause the video)
# Use t to display none, green, red or both tracks
# Use v to display none, green, red or both channels

# Map n_id (index among kept signals) back to its index in final_tracks / calcium_tracks
k = np.flatnonzero(~is_noise)[n_id]

vis = schya.visualize.TwoColorInteractiveVisualizer(
    (tdtomato_video, gcamp_video),
    tracks=final_tracks[k : k + 1],
    calcium_tracks=calcium_tracks[k : k + 1],
)

vis.scale = 1  # Increase/decrease the size of the display
vis._display_video = 3  # GCaMP

vis.run()