In [1]:
import os
import sys
import glob
import time
import numpy
import pandas
import tadpose

# mostly plotting
import ipywidgets
import seaborn as sns
from tqdm.auto import tqdm
from matplotlib import cm, colors
from matplotlib import pyplot as plt

# umap and wavelets
import umap  # ImportError -> pip install umap-learn
import pywt  # ImportError -> pip install PyWavelets

from skimage import filters
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.ndimage import gaussian_filter1d

In [2]:
from tadpose import plot

In [3]:
%load_ext autoreload
%autoreload 2

### 1. Select some videos 

In [4]:
# Folder holding the wild-type stage-59 recordings. Kept in a single constant so
# the notebook can be pointed at another machine/share by editing one line.
VIDEO_DIR = "B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos"

# glob returns matches in arbitrary (filesystem dependent) order. Sort them so
# that positional indexing further down (e.g. all_vids[2], Tadpoles[0]) refers
# to the same recording on every run.
all_vids = sorted(glob.glob(f"{VIDEO_DIR}/*st59.mp4"))
all_vids

['B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take1_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take2_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take3_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take4_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take5_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take6_oursNOGFP_st59.mp4',
 'B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\\Tad1_Take7_oursNOGFP_st59.mp4']

### 2. Helper functions

In [5]:
def get_tadpole(video_fn):
    """
    Build a tadpole object for one video and attach an aligner to it.

    The tadpole is created with ``tadpose.Tadpole.from_sleap(video_fn)``. The
    attached ``TadpoleAligner`` is configured from two body part names and the
    location each of them should be aligned to; scaling is switched off.

    Parameters
    ----------
    video_fn : str
        Path of the video file, passed on to ``tadpose.Tadpole.from_sleap``.

    Returns
    -------
    The tadpole object with its ``aligner`` attribute set.
    """
    # body part name -> location it is aligned to
    alignment_targets = {
        "tail_stem": numpy.array([0, 0.0]),
        "heart": numpy.array([0, 1.0]),
    }

    tad = tadpose.Tadpole.from_sleap(video_fn)
    tad.aligner = tadpose.alignment.TadpoleAligner(alignment_targets, scale=False)
    return tad

In [None]:
# Load one example recording (entry 2 of all_vids) for interactive exploration
tad = get_tadpole(all_vids[2])
# Display the body part names available on this tadpole
tad.bodyparts

In [None]:
# Speed of the "heart" body part of the example tadpole
heart_speed = tadpose.analysis.speeds(tad, parts=("heart",))

# Smoothed version of the same signal (win=61 — presumably the window length in frames; confirm in tadpose.utils.smooth)
sm = tadpose.utils.smooth(heart_speed.to_numpy(), win=61)
# Switch matplotlib to the Qt5 backend for the figure below
%matplotlib qt5
# Plot the raw speed, then add the smoothed curve via pyplot's current axes
heart_speed.plot()
plt.plot(sm)

In [6]:
fps = 60


def get_spectrograms(tadpole, scales, wavlet="morl", skel=None):
    """
    Compute a wavelet based spectrogram from tadpole body parts

    Every coordinate signal of every selected body part is smoothed and then
    transformed with a continuous wavelet transform (``pywt.cwt``). The
    magnitudes of the complex coefficients of all signals are stacked side by
    side.

    Parameters
    ----------
    tadpole :
        Object providing ``bodyparts`` and ``ego_locs(parts=...)``, e.g. the
        result of ``get_tadpole``.
    scales : array-like
        Wavelet scales handed to ``pywt.cwt``.
    wavlet : str
        Name of the wavelet handed to ``pywt.cwt`` (default "morl"). The
        spelling of the parameter name is kept for backward compatibility.
    skel : list of str, optional
        Body parts to include. If None, all of ``tadpole.bodyparts`` except
        "tail_stem" and "heart" are used.

    Returns
    -------
    numpy.ndarray
        2D array with one row per frame. Columns are grouped per coordinate
        signal, ``len(scales)`` columns each.

    Notes
    -----
    Reads the module level ``fps`` for the sampling period given to
    ``pywt.cwt``.
    """

    if skel is None:
        skel = tadpole.bodyparts.copy()
        skel.remove("tail_stem")
        skel.remove("heart")

    # get aligned locations for body parts in skel. Note 'skel' needs to be tuple (not list)
    locs = tadpole.ego_locs(parts=tuple(skel))

    # n == number of frames
    n = len(locs)

    # flatten to one column per coordinate signal and smooth along time
    locs_smooth = tadpose.utils.smooth_gaussian(locs.reshape(n, -1), sigma=3, deriv=0)

    # create spectrogram for each coordinate (x or y) of each body part
    spectrogram = []
    for sig in locs_smooth.reshape(n, -1).T:
        # 'coef' is the complex spectrogram with shape (len(scales), n); the
        # frequencies returned alongside are not needed here
        coef, _ = pywt.cwt(sig, scales, wavlet, sampling_period=1 / fps)
        spectrogram.append(numpy.abs(coef).T)

    return numpy.concatenate(spectrogram, axis=1)

### 3. Choose scales for wavelet transform

In [10]:
N = 25
wavelet = "morl"

fps = 60
# create N=25 dyadically spaced scales, 25 is what they used in motionmapper
Fc = pywt.central_frequency(wavelet)
fps = 60
sp = 1 / fps
# scales = Fc / (numpy.arange(1, 30) * sp)
if wavelet == "morl":
    scales = numpy.power(2, numpy.linspace(1, 8, N))  # <- dyadic
elif wavelet == "mexh":
    scales = numpy.power(2, numpy.linspace(-0.4, 4, N))  # <- dyadic

frequencies = pywt.scale2frequency(wavelet, scales) / sp

# plot which scale correspond to which freq.
%matplotlib widget
f, ax = plt.subplots()
ax.plot(scales, frequencies, "b.")
ax.set_xlabel("Input scales for wavelet transform")
ax.set_ylabel(f"Corresponding frequency at movie fps of {fps}")
print(f"Scales range from {frequencies.min()} to {frequencies.max()} Hz")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Scales range from 0.1904296875 to 24.375 Hz


### 4. Compute spectrograms for all tadpoles

In [11]:
# Body parts whose motion is analysed (reduced body part list: tail only)
tail_parts = [
    "tail_1",
    "tail_2",
    "tail_3",
    "tail_4",
    "tail_tip",
]


# One tadpole object and one spectrogram per video, in the order of all_vids
Tadpoles = []
Spectrograms = []
for video_fn in all_vids:
    print(video_fn)
    tadpole = get_tadpole(video_fn)
    Tadpoles.append(tadpole)
    # Hand over the wavelet that 'scales' was chosen for; otherwise the
    # function default ("morl") would be used even when 'wavelet' is changed.
    # Use e.g. skel=["left_leg", ...] for a different reduced body part list.
    spec = get_spectrograms(tadpole, scales, wavlet=wavelet, skel=tail_parts)
    Spectrograms.append(spec)

# Merge all spectrograms into single matrix (frames of all videos stacked row-wise)
all_spectrograms = numpy.concatenate(Spectrograms)
print("all_spectrograms.shape", all_spectrograms.shape)

B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take1_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take2_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take3_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take4_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take5_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take6_oursNOGFP_st59.mp4
B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos\Tad1_Take7_oursNOGFP_st59.mp4
all_spectrograms.shape (75217, 250)


In [12]:
# show a single spectrogram (the one of the first video)
%matplotlib qt5
f, ax = plt.subplots()
# rows of the spectrogram array are frames, columns are the stacked spectrogram
# values; aspect="auto" lets the image fill the axes instead of using square pixels
ax.imshow(Spectrograms[0], aspect="auto")
ax.set_ylabel("Time (frames)")
ax.set_xlabel("Spectrogram values of skels")



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 0, 'Spectrogram values of skels')

### 5. Reduce dimensions of spectrograms with PCA

In [13]:
# The data to reduce (frames of all videos stacked row-wise)
X = all_spectrograms.copy()

# get means and std per column
Xmeans = X.mean(0)
Xstds = X.std(0)

# A constant column has std == 0; dividing by it would produce NaN/inf and make
# the PCA fail. Such a column carries no information, so leave it at 0 after
# centering by dividing by 1 instead.
Xstds[Xstds == 0] = 1.0

# z-score data for pca
Xzs = (X - Xmeans) / Xstds

# compute global PCA such that 95% of variance is explained
pca = PCA(n_components=0.95)
Xpca = pca.fit_transform(Xzs)

print(f"PCA reduced dimension from {X.shape} to {Xpca.shape}")

# transform each single spectrogram with the global PCA (same z-scoring as above)
Xpcas = [pca.transform((spec - Xmeans) / Xstds) for spec in Spectrograms]

PCA reduced dimension from (75217, 250) to (75217, 75)


In [24]:
%matplotlib qt5
# Scatter of the first two PCA columns of all frames: column 1 on the
# horizontal axis, column 0 on the vertical axis
plt.plot(Xpca[:, 1], Xpca[:, 0], ".", alpha=0.1)



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x2280f369850>]

### 6. UMAP

In [25]:
# UMAP is stochastic; fix the seed so the embedding (and everything derived
# from it below) is reproducible after a kernel restart.
UMAP_SEED = 42

# global umap
mapper = umap.UMAP(n_components=2, random_state=UMAP_SEED)
Xumap = mapper.fit_transform(Xpca)

# Per-video embeddings. The rows of Xpca are the rows of all Xpcas stacked in
# order, so the global embedding is simply cut back into its videos. This keeps
# the per-video points identical to the ones the density map is built from
# (mapper.transform on the training data is slow and does not reproduce the
# fitted coordinates exactly).
video_ends = numpy.cumsum([len(xpca) for xpca in Xpcas])
Xumaps = numpy.split(Xumap, video_ends[:-1])

In [26]:
# Create map
hist, bxe, bye = numpy.histogram2d(
    Xumap[:, 0], Xumap[:, 1], bins=(100, 100), density=True
)

# Smooth the map by sigma (i. e. quick and dirty kde=kernel density estimation)
hist_kde = filters.gaussian(hist, sigma=1.0, preserve_range=True)

# show
%matplotlib widget
plt.imshow(hist_kde ** 1 / 30, cmap="magma")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x2285143f370>

### 7. Interactive skeleton / map viewer

In [79]:
%matplotlib widget


def show_skleton_viewer(
    tadpole, Map, Xumap, video_shape=(800, 500), track_idx=0, map_extent=None
):
    """
    Interactive viewer to visualize MotionMapper result

    The left panel shows the ego-aligned video frame with the body part
    locations overlaid. The right panel shows the density map with a red marker
    at the embedding position of the current frame. A slider selects the frame;
    clicking into the map jumps to the frame whose embedding is closest to the
    clicked position.

    Parameters
    ----------
    tadpole :
        Object providing ``ego_locs(track_idx=...)`` and ``ego_image(...)``.
    Map : 2D array
        Density map indexed as ``Map[x_bin, y_bin]``, i.e. in the layout
        returned by ``numpy.histogram2d(Xumap[:, 0], Xumap[:, 1])``.
    Xumap : numpy.ndarray, shape (n_frames, 2)
        Embedding coordinates of the frames of this tadpole.
    video_shape : tuple
        ``(height, width)`` requested from ``tadpole.ego_image``.
    track_idx : int
        Track index handed to ``tadpole.ego_locs``.
    map_extent : tuple, optional
        ``(x_min, x_max, y_min, y_max)`` covered by ``Map`` in embedding
        coordinates, e.g. the outer histogram bin edges. If None, the range of
        ``Xumap`` is used, which is only exact when ``Map`` was computed from
        this very ``Xumap``.

    Returns
    -------
    ipywidgets.VBox
        Container holding the figure canvas and the frame slider.
    """
    slider = ipywidgets.IntSlider(
        description="Time (frame)",
        value=0,
        min=0,
        max=Xumap.shape[0] - 1,
        continuous_update=True,
        style={"min_width": 300, "max_width": 300},
    )
    plt.ioff()

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    ax_video, ax_map = axs

    # needed for scaling the map to the right corresponding values
    if map_extent is None:
        map_extent = (
            Xumap[:, 0].min(),
            Xumap[:, 0].max(),
            Xumap[:, 1].min(),
            Xumap[:, 1].max(),
        )

    # Show Map (on the right). A histogram2d result has x along its first
    # (row) axis, so it is transposed and drawn with the origin at the bottom
    # to line up with the embedding coordinates used for the marker and for
    # the click lookup.
    ax_map.imshow(numpy.asarray(Map).T, extent=map_extent, origin="lower")
    ax_map.set_title("Click on map to jump to closest frame")

    # get aligned locations for later
    aligned_locations = tadpole.ego_locs(track_idx=track_idx)

    # imshow tadpole on the left, centered around the coordinate origin
    gray = tadpole.ego_image(
        frame=0, dest_height=video_shape[0], dest_width=video_shape[1], rgb=True
    )
    im = ax_video.imshow(
        gray,
        "gray",
        extent=(
            -gray.shape[1] // 2,
            gray.shape[1] // 2,
            -gray.shape[0] // 2,
            gray.shape[0] // 2,
        ),
    )

    # define callback for clicking into the Map
    def map_click(event):
        if event.inaxes is ax_map:
            clicked = numpy.array([event.xdata, event.ydata])
            closest_frame = numpy.argmin(numpy.square(Xumap - clicked).sum(1))
            # moving the slider triggers update_widgets below
            slider.value = int(closest_frame)

    fig.canvas.mpl_connect("button_press_event", map_click)

    # show initial point in map (from frame==0)
    umap_points = ax_map.plot(Xumap[0, 0], Xumap[0, 1], ".", color="red")[0]

    # plot body part locations (on the left). Same sign convention as in
    # update_widgets, so the overlay does not flip between initial draw and
    # the first update.
    x_points, y_points = aligned_locations[0].T
    points = ax_video.plot(x_points, y_points, "b.", alpha=0.1)[0]

    # define callback
    def update_widgets(change):
        # the new frame (from slider)
        frame = change.new

        # get new tadpole image and show it
        gray = tadpole.ego_image(
            frame=frame,
            dest_height=video_shape[0],
            dest_width=video_shape[1],
            rgb=True,
        )
        im.set_data(gray)

        # update body part overlay
        x_points, y_points = aligned_locations[frame].T
        points.set_data(x_points, y_points)

        # update Map point on the right; Line2D data must be sequences, a bare
        # scalar is rejected by recent matplotlib versions
        umap_points.set_data([Xumap[frame, 0]], [Xumap[frame, 1]])

        # make sure everything is drawn
        fig.canvas.draw()
        fig.canvas.flush_events()

    # connect callback; changing the value once forces a first update
    slider.observe(update_widgets, names="value")
    slider.value = 1

    return ipywidgets.VBox([fig.canvas, slider])


# Open the viewer for the first video. The map extent comes from the histogram
# bin edges so the map, the marker and the click lookup share one coordinate system.
show_skleton_viewer(
    Tadpoles[0],
    hist_kde,
    Xumaps[0],
    map_extent=(bxe[0], bxe[-1], bye[0], bye[-1]),
)

VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…

In [74]:
%matplotlib qt5
# hist_kde is drawn as-is, so its first array axis (rows) runs along the
# vertical plot axis. The extent therefore lists the Xumap[:, 1] range for the
# horizontal axis and the Xumap[:, 0] range for the vertical axis.
plt.imshow(
    hist_kde,
    extent=(Xumap[:, 1].min(), Xumap[:, 1].max(), Xumap[:, 0].min(), Xumap[:, 0].max()),
    aspect="auto",
    origin="lower",
)
# Overlay the embedded frames with the two columns swapped (column 1 horizontal,
# column 0 vertical) to match the axis layout of the image above.
plt.plot(*Xumap[:, ::-1].T, "r.", alpha=0.01)



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x1fa03b23ca0>]

In [57]:
# Scratch check: smallest value in the first column of the embedding.
# NOTE(review): the stored output of this cell looks like a shape tuple, so it
# appears to stem from a different expression — re-run or remove this cell.
Xumap[:, 0].min()

(75217, 2)