In [None]:
import os
import sys
import glob
import time
import numpy
import pandas
import tadpose

# mostly plotting
import ipywidgets
import seaborn as sns
from tqdm.auto import tqdm
from matplotlib import cm, colors
from matplotlib import pyplot as plt

# umap and wavelets
import umap # ImportError -> pip install umap-learn
import pywt # ImportError -> pip install PyWavelets

from skimage import filters
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.ndimage import gaussian_filter1d

### 1. Select some videos 

In [None]:
# Folder containing the videos to analyse. Set the TADPOLE_VIDEO_DIR environment
# variable to run the notebook on a machine where this share is mounted elsewhere.
VIDEO_DIR = os.environ.get(
    "TADPOLE_VIDEO_DIR",
    "B:/fs3-bifexchgrp/BIF_StaffSci/Christoph/sweengrp/Mara SLEAP/WT videos/WT-Frog-Videos",
)

# glob returns files in filesystem-dependent order; sort so that video indices used
# later (e.g. Tadpoles[0], Spectrograms[3]) refer to the same file on every run.
all_vids = sorted(glob.glob(os.path.join(VIDEO_DIR, "*st59.mp4")))
all_vids

### 2. Helper functions

In [None]:
def get_tadpole(video_fn):
    """
    Build a ``tadpose.Tadpole`` from the SLEAP tracking of ``video_fn`` and
    attach an aligner to it.

    The aligner is configured with two reference body parts and the locations
    they are aligned to: 'tail_stem' -> (0, 0) and 'heart' -> (0, 1), with
    ``scale=False``.

    Returns the tadpole with its ``aligner`` attribute set.
    """
    alignment_targets = {
        'tail_stem': numpy.array([0, 0.]),
        'heart': numpy.array([0, 1.]),
    }
    result = tadpose.Tadpole.from_sleap(video_fn)
    result.aligner = tadpose.alignment.TadpoleAligner(alignment_targets, scale=False)
    return result

In [None]:
fps = 60

def get_spectrograms(tadpole, scales, skel=None):
    """
    Compute a wavelet-based spectrogram from the aligned tadpole body parts.

    Every coordinate trace (x or y of each body part in ``skel``) is z-scored and
    transformed with the complex Morlet wavelet 'cmor1.5-1.0'. This is the same
    wavelet used to convert ``scales`` to frequencies in the scale-selection cell,
    so those frequencies describe these columns. The coefficient magnitudes of all
    traces are concatenated column-wise.

    Parameters
    ----------
    tadpole : tadpose.Tadpole
        Tadpole with an aligner attached (see ``get_tadpole``).
    scales : array-like
        Wavelet scales passed to ``pywt.cwt``.
    skel : sequence of str, optional
        Body parts to use. Defaults to all ``tadpole.bodyparts`` except the
        alignment reference parts 'tail_stem' and 'heart'.

    Returns
    -------
    numpy.ndarray
        Shape (n_frames, n_traces * len(scales)). Each trace contributes
        len(scales) consecutive columns.
    """
    if skel is None:
        # The alignment reference parts are excluded; after alignment their
        # coordinates are presumably (near) constant and carry no movement signal.
        anchors = ("tail_stem", "heart")
        skel = [part for part in tadpole.bodyparts if part not in anchors]

    # get aligned locations for body parts in skel. Note: ego_locs needs a tuple (not a list)
    locs = tadpole.ego_locs(parts=tuple(skel))
    n_frames = len(locs)

    spectrogram = []
    # one column per flattened coordinate trace of a frame
    for sig in locs.reshape(n_frames, -1).T:
        sig_std = sig.std()
        if sig_std == 0:
            # A constant trace would give 0/0 = NaN when z-scored and poison the
            # downstream PCA; its centred signal is all zeros, so use that.
            sig_zscore = numpy.zeros_like(sig, dtype=float)
        else:
            sig_zscore = (sig - sig.mean()) / sig_std

        # complex wavelet coefficients; the returned frequencies are not needed here
        coef, _freqs = pywt.cwt(sig_zscore, scales, 'cmor1.5-1.0', sampling_period=1/fps)
        spectrogram.append(numpy.abs(coef).T)

    return numpy.concatenate(spectrogram, axis=1)

### 3. Choose scales for wavelet transform

In [None]:
%matplotlib widget

N = 25

# N = 25 dyadically spaced scales (2**1 ... 2**6); 25 is what MotionMapper uses.
# A naive linear spacing, numpy.linspace(1, 60, N), is an alternative but is
# probably not as good as dyadic spacing (not verified).
scales = numpy.power(2, numpy.linspace(1, 6, N))

# Map the chosen scales to frequencies in Hz. This must use the same wavelet as
# get_spectrograms ('cmor1.5-1.0'). scale2frequency returns frequencies normalised
# to the sampling rate (cycles per sample), so multiply by the frame rate.
frequencies = pywt.scale2frequency('cmor1.5-1.0', scales) * fps

# plot which scale corresponds to which frequency
f, ax = plt.subplots()
ax.plot(scales, frequencies, "b.")
ax.set_xlabel("Input scales for wavelet transform")
ax.set_ylabel(f"Corresponding frequency at movie fps of {fps}")
ax.set_title("Wavelet scale to frequency mapping")
print(f"Frequencies range from {frequencies.min():.2f} to {frequencies.max():.2f} Hz")


### 4. Compute spectrograms for all tadpoles

In [None]:
# Fail early with a clear message instead of numpy's "need at least one array to concatenate"
if not all_vids:
    raise ValueError("No videos selected: 'all_vids' is empty, check the glob pattern in section 1.")

Tadpoles     = []
Spectrograms = []
for video_fn in tqdm(all_vids, desc="spectrograms"):
    # tqdm.write prints without breaking the progress bar
    tqdm.write(video_fn)
    tadpole = get_tadpole(video_fn)
    Tadpoles.append(tadpole)
    # pass e.g. skel=["left_leg", ...] to restrict the spectrogram to a subset of body parts
    spec = get_spectrograms(tadpole, scales, skel=None)
    Spectrograms.append(spec)

# Merge all spectrograms into a single matrix: frames of all videos stacked row-wise
all_spectrograms = numpy.concatenate(Spectrograms)
print("all_spectrograms.shape", all_spectrograms.shape)

In [None]:
# show a single spectrogram
%matplotlib widget
# index of the video to show: the 4th one, or the last one if there are fewer videos
spec_idx = min(3, len(Spectrograms) - 1)

f, ax = plt.subplots()
# aspect="auto" stretches the (frames x features) image to fill the axes; a fixed
# aspect ratio only looks right for one particular recording length
ax.imshow(Spectrograms[spec_idx], aspect="auto")
ax.set_title(os.path.basename(all_vids[spec_idx]))
ax.set_ylabel("Time (frames)")
ax.set_xlabel("Spectrogram values of skels");

### 5. Reduce dimensions of spectrograms with PCA

In [None]:
# The data to reduce: one row per frame (all videos stacked), one column per spectrogram feature
X = all_spectrograms.copy()

# Per-feature mean and standard deviation
Xmeans = X.mean(0)
Xstds  = X.std(0)
# A constant feature has std 0 and would become NaN/inf after z-scoring, which makes
# PCA fail. Dividing such columns by 1 instead leaves them as all-zero columns.
Xstds[Xstds == 0] = 1.0

# z-score data for PCA
Xzs = (X - Xmeans) / Xstds

# Global PCA keeping as many components as needed to explain 95% of the variance
pca  = PCA(n_components=0.95)
Xpca = pca.fit_transform(Xzs)

print(f"PCA reduced dimension from {X.shape} to {Xpca.shape}")

# Project each video's spectrogram with the same global normalisation and PCA
Xpcas = [pca.transform((spec - Xmeans) / Xstds) for spec in Spectrograms]

### 6. Apply embedding (TSNE or UMAP) and create map

In [None]:
# t-SNE embedding of the PCA-reduced frames. It takes very long on many frames,
# which is why the UMAP embedding below is the one used for the interactive viewer.
# A fixed random_state makes the embedding reproducible across runs.
Xtsne = TSNE(n_components=2, n_jobs=6, random_state=0).fit_transform(Xpca)

# Create map: 100 x 100 histogram of the embedding (hist[x_bin, y_bin])
hist, bxe, bye = numpy.histogram2d(Xtsne[:, 0], Xtsne[:, 1], bins=(100, 100), density=True)

# Smooth the map with a Gaussian of sigma = 0.5 bins (quick-and-dirty KDE)
hist_kde = filters.gaussian(hist, sigma=0.5, preserve_range=True)
# show
%matplotlib widget
# Use a new figure: plt.imshow would draw into the current figure, which with
# %matplotlib widget may still be a previous cell's figure.
f, ax = plt.subplots()
# histogram2d returns hist[x_bin, y_bin]; transpose and use origin="lower" so t-SNE
# dimension 1 runs horizontally and dimension 2 vertically, in embedding coordinates.
ax.imshow(hist_kde.T, origin="lower", cmap="magma", extent=(bxe[0], bxe[-1], bye[0], bye[-1]))
ax.set_title("t-SNE map (smoothed density)")
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2");

#### UMAP

In [None]:
# Global UMAP embedding of all PCA-reduced frames. A fixed random_state makes the
# map reproducible (umap-learn then runs without parallelism; set it to None for
# faster but non-reproducible runs).
mapper = umap.UMAP(n_components=2, random_state=42)
Xumap = mapper.fit_transform(Xpca)

# Per-video embeddings. Xpca stacks the videos' frames in the same order as Xpcas,
# so slicing Xumap gives each video's points exactly as they appear in the global
# map; mapper.transform would only approximately re-embed these training points.
Xumaps = numpy.split(Xumap, numpy.cumsum([len(xpca) for xpca in Xpcas])[:-1])

In [None]:
# Density map of the UMAP embedding: 100 bins along each axis, hist[x_bin, y_bin]
hist, bxe, bye = numpy.histogram2d(
    Xumap[:, 0],
    Xumap[:, 1],
    bins=100,
    density=True,
)

# Blur the histogram with a Gaussian of sigma = 1 bin, a quick-and-dirty
# kernel density estimate
hist_kde = filters.gaussian(hist, sigma=1.0, preserve_range=True)

# show
%matplotlib widget
# Use a new figure: plt.imshow would draw into the current figure, which with
# %matplotlib widget may still be a previous cell's figure.
f, ax = plt.subplots()
# histogram2d returns hist[x_bin, y_bin]; transpose and use origin="lower" so UMAP
# dimension 1 runs horizontally and dimension 2 vertically, in embedding coordinates.
ax.imshow(hist_kde.T, origin="lower", cmap="magma", extent=(bxe[0], bxe[-1], bye[0], bye[-1]))
ax.set_title("UMAP map (smoothed density)")
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2");

### 7. Interactive skeleton / map viewer

In [None]:
%matplotlib widget

def show_skleton_viewer(tadpole, Map, Xumap, video_shape=(800, 500), track_idx=0, extent=None):
    """
    Interactive viewer to visualize the MotionMapper result.

    Left: the aligned (egocentric) video frame of ``tadpole`` with its body-part
    locations overlaid. Right: the behaviour map ``Map`` with the current frame's
    embedding point in red. Clicking on the map jumps to the frame whose embedding
    point is closest to the click.

    Parameters
    ----------
    tadpole : tadpose.Tadpole
        Tadpole with an aligner attached (see get_tadpole).
    Map : 2D array
        Density map as returned by ``numpy.histogram2d`` (optionally smoothed),
        i.e. indexed ``Map[x_bin, y_bin]``. It is transposed internally for display.
    Xumap : numpy.ndarray, shape (n_frames, 2)
        Embedding coordinates of this tadpole's frames, one row per frame.
    video_shape : (int, int)
        (height, width) of the rendered egocentric image.
    track_idx : int
        Track whose aligned body-part locations are overlaid.
    extent : (x0, x1, y0, y1), optional
        Embedding coordinates covered by ``Map``, e.g. the histogram bin edges
        ``(bxe[0], bxe[-1], bye[0], bye[-1])``. Defaults to the bounding box of
        ``Xumap``, which is only correct when ``Map`` was built from exactly these points.

    Returns
    -------
    ipywidgets.VBox
        The figure canvas above the frame slider.
    """
    slider = ipywidgets.IntSlider(
        description="Time (frame)",
        value=0,
        min=0,
        max=Xumap.shape[0]-1, continuous_update=True,
        style={'min_width': 300, 'max_width': 300}
    )
    # keep the figure from being displayed on its own; it is shown inside the VBox
    plt.ioff()

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))

    # embedding coordinates covered by the map
    if extent is None:
        extent = (Xumap[:, 0].min(), Xumap[:, 0].max(), Xumap[:, 1].min(), Xumap[:, 1].max())

    # Show Map (on the right). histogram2d indexes Map[x_bin, y_bin]; transpose and
    # use origin="lower" so the density lines up with the (x, y) points drawn on top.
    axs[1].imshow(numpy.asarray(Map).T, extent=extent, origin="lower")
    axs[1].set_title("Click on map to jump to closest frame")

    # aligned body-part locations of the selected track, indexed by frame below
    aligned_locations = tadpole.ego_locs(track_idx=track_idx)

    # egocentric image of frame 0 on the left, centred on the origin
    # NOTE(review): ego_image is not given track_idx (unlike ego_locs) — confirm it
    # renders the same track for multi-animal videos.
    gray = tadpole.ego_image(frame=0, dest_height=video_shape[0], dest_width=video_shape[1], rgb=True)
    im = axs[0].imshow(gray, "gray", extent=(-gray.shape[1]//2, gray.shape[1]//2, -gray.shape[0]//2, gray.shape[0]//2))

    def map_click(event):
        """Jump the slider to the frame whose embedding point is closest to the click."""
        if event.inaxes is axs[1]:
            click = numpy.array([event.xdata, event.ydata])
            slider.value = int(numpy.argmin(numpy.square(Xumap - click).sum(1)))
    fig.canvas.mpl_connect('button_press_event', map_click)

    # red marker for the current frame on the map (frame 0 initially)
    umap_points = axs[1].plot(Xumap[0, 0], Xumap[0, 1], ".", color="red")[0]

    # body-part overlay on the left; x is negated, presumably to match the
    # orientation of ego_image — TODO confirm
    x_points, y_points = aligned_locations[0].T
    points = axs[0].plot(-x_points, y_points, "b.")[0]

    def update_widgets(change):
        """Redraw image, body-part overlay and map marker for the new slider frame."""
        frame = change.new

        # new tadpole image
        gray = tadpole.ego_image(frame, dest_height=video_shape[0], dest_width=video_shape[1], rgb=True)
        im.set_data(gray)

        # update body part overlay
        x_points, y_points = aligned_locations[frame].T
        points.set_data(-x_points, y_points)

        # update map marker; Line2D data must be sequences (scalars are deprecated)
        umap_points.set_data([Xumap[frame, 0]], [Xumap[frame, 1]])

        # make sure everything is drawn
        fig.canvas.draw()
        fig.canvas.flush_events()

    slider.observe(update_widgets, names='value')
    # changing the value fires update_widgets once so the figure is drawn before display
    slider.value = 1

    return ipywidgets.VBox([fig.canvas, slider])

# Show the viewer for the first video. The histogram bin edges place the map in the
# coordinates of the global embedding it was computed from.
show_skleton_viewer(Tadpoles[0], hist_kde, Xumaps[0], extent=(bxe[0], bxe[-1], bye[0], bye[-1]))
