In [None]:
import os
import sys
import glob
import pywt
import numpy
import pandas
import tadpose
from tadpose import utils
from scipy.signal import stft
from tqdm.auto import tqdm, trange

# mostly plotting
import ipywidgets
import seaborn as sns
from tqdm.auto import tqdm
from matplotlib import cm, colors
from matplotlib import pyplot as plt

# clustering and pca
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN, KMeans
from scipy.ndimage import gaussian_filter1d

### Basic definitions
* create a tadpole object
* configure alignment

In [None]:
# Main input required. The SLEAP analysis file is expected to be in the same
# folder, with the ending ".predictions.analysis.h5".
# Exactly one recording is analysed per run; swap the commented line to switch.
# video_fn = "tail_oscillation/Tad2_Take3_oursNOGFP_St57.mp4"
video_fn = "tail_oscillation/Tad2_Take4_oursNOGFP_St57.mp4"

### Create tadpole and aligner

In [None]:
# Build the Tadpole object for the selected video
tadpole = tadpose.Tadpole.from_sleap(video_fn)

# The aligner is defined by two part names and their corresponding
# alignment locations (constructed with scale=False)
alignment_targets = {
    "tail_1": numpy.array([0, 0.0]),
    "heart": numpy.array([0, 1.0]),
}
aligner = tadpose.alignment.TadpoleAligner(alignment_targets, scale=False)
tadpole.aligner = aligner

### Define skeleton of interest

In [None]:
# Tail parts used for the analysis: tail_1 ... tail_4
cluster_columns = tuple(f"tail_{part_idx}" for part_idx in range(1, 5))

### Extract ego-centric locations

In [None]:
# Ego-centric locations of the selected parts, flattened so that the first
# axis is kept and everything else ends up in one row per entry
ego_locs = tadpole.ego_locs(parts=cluster_columns)
n_rows = ego_locs.shape[0]
X = ego_locs.reshape(n_rows, -1)
X.shape

### Do PCA

In [None]:
# Standardise every column of X to zero mean and unit variance.
# Columns without any spread are only centred (divisor 1), so they cannot
# introduce NaN/inf values that would make the PCA fit fail.
X_std = X.std(0)
Xc = (X - X.mean(0)) / numpy.where(X_std > 0, X_std, 1.0)

# PCA with N components
N = 3
pca = PCA(n_components=N)
pca.fit(Xc)

# Xpca contains the projection of every row of Xc onto the N components
Xpca = pca.transform(Xc)

### Plot randomly selected skeletons with color-coded PC

In [None]:
%matplotlib widget

f, ax = plt.subplots()
ax.set_aspect(1.0)

# select PC component to plot
pc = 0

norm = colors.Normalize(vmin=-5, vmax=5, clip=True)
for rand_ind in numpy.random.randint(X.shape[0], size=500):
    points = X[rand_ind, :]
    load = Xpca[rand_ind, pc]
    color = cm.seismic(norm(load))
    p = plt.plot(-points[::2].T, points[1::2].T, ".-", alpha=0.2, color=color)

plt.axis("off")
plt.title("Tracked tail points")

sm = plt.cm.ScalarMappable(cmap="seismic", norm=norm)
cbar = plt.colorbar(
    sm,
    fraction=0.033,
    pad=0.04,
)
cbar.ax.set_ylabel("PC 0")
plt.savefig("01_tracked_tail_pc.png", bbox_inches="tight")

### Smooth a little and compute gradient 

In [None]:
pc0 = Xpca[:, 0:1]

%matplotlib widget
pc_s0 = utils.smooth(pc0, win=5, poly=3, deriv=0)
pc_s1 = utils.smooth(pc0, win=5, poly=3, deriv=1)

### Compute global speed = heart speed

In [None]:
# Speed of the "heart" part, used as the global speed
heart_parts = ("heart",)
heart_speed = tadpose.analysis.speeds(tadpole, parts=heart_parts)

### Interactive skeleton viewer

In [None]:
%matplotlib widget


def show_skleton_viewer(tadpole, X, Xp, video_shape=(800, 400)):
    """
    Interactive viewer to visualize a skeleton from X together with scalar value "Xp".
    Use left/right keys to go through time

    Parameters
    ----------
    tadpole
        Object providing ``ego_image(frame, dest_height=..., dest_width=...,
        rgb=...)``; the returned image is shown behind the skeleton.
    X
        2-D array indexed by frame along the first axis. Within a row, even
        entries are plotted as x coordinates and odd entries as y coordinates.
    Xp
        2-D array indexed by frame along the first axis. Column 0 colors the
        skeleton and is drawn as a trace in the right panel.
    video_shape
        ``(height, width)`` forwarded to ``tadpole.ego_image``.

    Returns
    -------
    ipywidgets.VBox
        Container holding the figure canvas and the frame slider.

    Raises
    ------
    ValueError
        If ``X`` has no rows.
    """
    n_frames = X.shape[0]
    if n_frames == 0:
        raise ValueError("X contains no frames to display")

    x_view = [-200, 200]
    y_view = [-400, 100]

    # which pc-component to visualize
    pc = 0
    # the trace does not depend on the frame, so it is extracted once
    trace = Xp[:, pc]

    # normalize colors for that component
    norm = colors.Normalize(vmin=-5, vmax=5, clip=True)

    slider = ipywidgets.IntSlider(
        description="Time (frame)",
        value=0,
        min=0,
        max=n_frames - 1,
        continuous_update=True,
        style={"min_width": 5000},
    )

    # Interactive mode is switched off only while the figure is created, so
    # the figure is shown solely through the returned VBox; the previous
    # interactive state is restored afterwards and later cells are unaffected.
    with plt.ioff():
        fig, axs = plt.subplots(1, 2, figsize=(9, 5))
    ax_skel, ax_trace = axs

    def draw_frame(frame):
        """Redraw both panels for the given frame index."""
        points = X[frame, :]
        color = cm.seismic(norm(trace[frame]))

        # left panel: ego image with the skeleton on top
        ax_skel.clear()
        gray = tadpole.ego_image(
            frame, dest_height=video_shape[0], dest_width=video_shape[1], rgb=False
        )
        ax_skel.imshow(
            gray,
            cmap="gray",
            # center the image on the origin of the plotted coordinates
            extent=(
                -gray.shape[1] // 2,
                gray.shape[1] // 2,
                -gray.shape[0] // 2,
                gray.shape[0] // 2,
            ),
        )
        ax_skel.set_xticks([])
        ax_skel.set_yticks([])
        ax_skel.plot(points[::2], points[1::2], ".-", alpha=1, color=color)
        ax_skel.set_xlim(x_view[0], x_view[1])
        ax_skel.set_ylim(y_view[0], y_view[1])

        # right panel: +-35 frame window of the trace around the current frame
        tick_frames = [frame - 30, frame, frame + 30]
        ax_trace.clear()
        ax_trace.plot(trace, "-", color="gray")
        ax_trace.plot(frame, trace[frame], ".", color="red")
        ax_trace.set_xlim(frame - 35, frame + 35)
        ax_trace.set_ylabel("PC 0")
        ax_trace.set_xlabel("Time (frames)")
        ax_trace.set_xticks(tick_frames)
        ax_trace.set_xticklabels(list(map(str, tick_frames)))

        fig.canvas.draw()
        fig.canvas.flush_events()
        # To export single frames (e.g. for an animation), save the figure
        # here with fig.savefig(...).

    def update_lines(change):
        """Slider callback: redraw for the newly selected frame."""
        draw_frame(change.new)

    slider.observe(update_lines, names="value")

    # Observers only fire when the value changes, so the initial frame is
    # drawn explicitly (this also works when there is just a single frame).
    draw_frame(slider.value)

    return ipywidgets.VBox([fig.canvas, slider])

# Browse the tail skeletons together with the smoothed first PC
show_skleton_viewer(tadpole, X, Xp=pc_s0)

In [None]:
N = 32
wavelet = "morl"

fps = 60
# create N=25 dyadically spaced scales, 25 is what they used in motionmapper
Fc = pywt.central_frequency(wavelet)
fps = 60
sp = 1 / fps
# scales = Fc / (numpy.arange(1, 30) * sp)
if wavelet == "morl":
    scales = numpy.power(2, numpy.linspace(1, 7, N))  # <- dyadic
elif wavelet == "mexh":
    scales = numpy.power(2, numpy.linspace(-0.4, 4, N))  # <- dyadic

frequencies = pywt.scale2frequency(wavelet, scales) / sp

# plot which scale correspond to which freq.
%matplotlib widget
f, ax = plt.subplots()
ax.plot(scales, frequencies, "b.")
ax.set_xlabel("Input scales for wavelet transform")
ax.set_ylabel(f"Corresponding frequency at movie fps of {fps}")
print(f"Scales range from {frequencies.min()} to {frequencies.max()} Hz")

In [None]:
# The wavelet transform is run on column 0 of pc_s1; the signal is used as is
# (no z-scoring)
sig = pc_s1[:, 0]
cwt_raw, freqs_cwt = pywt.cwt(sig, scales, wavelet, sampling_period=1 / fps)

# keep only the magnitude of the coefficients
coef_cwt = numpy.abs(cwt_raw)

In [None]:
%matplotlib widget

y, x = numpy.mgrid[0 : coef_cwt.shape[0], 0 : coef_cwt.shape[1]]
y = numpy.ones_like(y)
y = (y.T * freqs_cwt).T
plt.pcolor(x, y, (numpy.abs(coef_cwt)))
plt.gca().set_aspect(400.0)
# plt.gca().set_ylim(0, 10)

In [None]:
%matplotlib widget
spectrum, F, t, _ = plt.specgram(
    pc_s0[:, 0],
    Fs=60,
    NFFT=256,
    noverlap=128,
    detrend="linear",
    scale="linear",
    # interpolation="nearest",
)
spectrum.shape
plt.gca().set_ylabel("Tail-beat frequency (Hz)")
plt.gca().set_xlabel("Time (secs)")

In [None]:
%matplotlib widget
f, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].bar(F, height=spectrum.mean(1))
axs[0].set_xlim(0, 25)
axs[0].set_ylabel("Mean squared power spectrum")
axs[0].set_xlabel("Tail-beat frequency (Hz)")
axs[0].set_title("Short-term Fourier transfrom from PC")

axs[1].bar(
    freqs_cwt, height=coef_cwt[:, heart_speed.to_numpy().squeeze() > 0.5].mean(1)
)
axs[1].set_xlim(0, 25)
axs[1].set_ylabel("Morlet wavelet coefficient")
axs[1].set_xlabel("Tail-beat frequency swimming (Hz) ")
axs[1].set_title("Coninous wavelet transfrom from PC gradient")
for ax in axs:
    sns.despine(ax=ax)
plt.savefig("03_tail_beat_freqencies.png", bbox_inches="tight")