In [175]:
from collections import defaultdict

import dill
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.style
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

mpl.rcParams["figure.facecolor"] = "w"
mpl.rcParams["figure.dpi"] = 150
mpl.rcParams["savefig.dpi"] = 600
mpl.rcParams["savefig.transparent"] = True
mpl.rcParams["font.size"] = 15
mpl.rcParams["font.family"] = "sans-serif"
mpl.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
mpl.rcParams["axes.titlesize"] = "xx-large"  # medium, large, x-large, xx-large

mpl.style.use("seaborn-deep")

from itertools import groupby


def encode_list(s_list):
    return [[len(list(group)), key[0]] for key, group in groupby(s_list)]


import sys

sys.path.append("../")
import analysis.utils.trx_utils as trx_utils

filename = "/Genomics/ayroleslab2/scott/long-timescale-behavior/data/organized_tracks/20220217-lts-cam1/cam1_20220217_0through190_cam1_20220217_0through190_1-tracked.analysis.h5"
import h5py
import numpy as np

with h5py.File(filename, "r") as f:
    dset_names = list(f.keys())
    node_names = [n.decode() for n in f["node_names"][:]]

In [176]:
z_vals_file = "/Genomics/ayroleslab2/scott/git/lts-manuscript/analysis/20230118-mmpy-lts-day1-headprobinterp-missingness-pchip5-fillnanmedian-tss64k-tsp32k-removegt6missing/TSNE/20230119_sigma2_7_minregions20_zVals_wShed_groups_finalsave.mat"
f = h5py.File(z_vals_file, "r")
z_val_names_dset = f["zValNames"]
references = [
    f[z_val_names_dset[dset_idx][0]] for dset_idx in range(z_val_names_dset.shape[0])
]
z_val_names = ["".join(chr(i) for i in obj[:]) for obj in references]
z_lens = [l[0] for l in f["zValLens"][:]]

d = {}
for z_val_name_idx in range(len(z_val_names)):
    d[z_val_names[z_val_name_idx]] = [
        np.sum(z_lens[:(z_val_name_idx)]),
        np.sum(z_lens[: (z_val_name_idx + 1)]),
    ]
d[z_val_names[0]][0] = 0

In [191]:
FOCAL_REGIONS_DICT = {
    "foreleg_grooming": 2,
    "locomotion": 7,
    "hindleg_grooming": 18,
}

f = h5py.File(z_vals_file, "r")
groups_dict = {}
dset = f["groups"]
for name, idx in FOCAL_REGIONS_DICT.items():
    ref_group = f[dset[0][idx]]
    groups_dict[name] = ref_group[:][
        :,
        np.where(
            ref_group[0, :]
            == z_val_names.index("20220217-lts-cam2_day1_24hourvars-0-pcaModes")
        )[0],
    ]

In [192]:
source_file = "/Genomics/ayroleslab2/scott/git/lts-manuscript/analysis/20220217-lts-cam1_day1_24hourvars.h5"
import importlib

importlib.reload(trx_utils)

# videofile = "/Genomics/ayroleslab2/scott/long-timescale-behavior/data/organized_videos/20220217-lts-cam1/20220217-lts-cam1-0000.mp4"
# min_length = 25
# f = h5py.File(
#     "/Genomics/ayroleslab2/scott/git/lts-manuscript/analysis/20230118-mmpy-lts-day1-headprobinterp-missingness-pchip5-fillnanmedian-tss64k-tsp32k-removegt6missing/TSNE/20230119_sigma2_7_minregions20_zVals_wShed_groups_finalsave.mat"
# )

with h5py.File(source_file, "r") as f:
    print(f.keys())
    vels = f["vels"][:]
    locations = f["tracks"][:].T

<KeysViewHDF5 ['tracks', 'vels']>


In [218]:
locations_ego = trx_utils.normalize_to_egocentric(
    locations[:, :, :, 0],
    ctr_ind=node_names.index("thorax"),
    fwd_ind=node_names.index("head"),
)

In [230]:
trimmed_dict = {}
for name, group in groups_dict.items():
    print(name, group.shape)
    lengths = group[2, :] - group[1, :]
    cutoff = np.quantile(lengths, 0.90)
    focal_groups = np.where(lengths > cutoff)[0]
    trimmed_dict[name] = group[:, focal_groups]
    np.random.shuffle(trimmed_dict[name].T)
    print(f"Number of focal groups: {len(focal_groups)}")

foreleg_grooming (3, 2225)
Number of focal groups: 216
locomotion (3, 580)
Number of focal groups: 58
hindleg_grooming (3, 4656)
Number of focal groups: 458


In [232]:
%%capture
import pathlib
import time

base_figure_path = pathlib.Path("figures/")
base_figure_path.mkdir(exist_ok=True, parents=True)

name_node_map = {
    "foreleg_grooming": ["forelegL"],
    "locomotion": ["midlegL"],
    "hindleg_grooming": ["hindlegL"],
}

color_map = {"None": "red", "linear": "blue", "pchip": "green"}

WIN_SIZE = 10


def plot_focal_coordinates_from_group(info, locations):
    for name, group in info.items():
        print(name, group.shape)
        for i in range(group.shape[1]):
            start = group[1, i]
            end = group[2, i]
            print(f"Plotting {name} {i} from {start} to {end} ({end-start} frames)")
            fig, ax = plt.subplots(2, 1, figsize=(12, 8))
            print("Created subplots...")
            for interp_method in ["None", "linear", "pchip"]:
                tmp_locations = locations[start:end, :, :].copy()
                if interp_method != "None":
                    mask = np.arange(tmp_locations.shape[0])
                    mask = mask[
                        ((len(mask) // 2) - (WIN_SIZE // 2)) : (
                            (len(mask) // 2) + (WIN_SIZE // 2)
                        )
                    ]
                    tmp_locations[mask, :, :] = np.nan
                    tmp_locations = trx_utils.fill_missing(
                        tmp_locations, kind=interp_method
                    )
                if interp_method == "None":
                    lt = "solid"
                else:
                    lt = "--"
                for node in name_node_map[name]:
                    ax[0].plot(
                        tmp_locations[:, node_names.index(node), 0],
                        label=interp_method,
                        color=color_map[interp_method],
                        alpha=0.5,
                        linewidth=2,
                        linestyle=lt,
                    )
                    ax[1].plot(
                        tmp_locations[:, node_names.index(node), 1],
                        label=interp_method,
                        color=color_map[interp_method],
                        alpha=0.5,
                        linewidth=2,
                        linestyle=lt,
                    )
            ax[0].set_ylim(-100, 100)
            ax[0].set_ylabel("x coordinate (px)")
            ax[1].set_ylim(-50, 50)
            ax[1].set_xlabel("frame")
            ax[1].set_ylabel("y coordinate (px)")
            ax[0].axvline(
                x=((tmp_locations.shape[0] // 2) - (WIN_SIZE // 2)),
                color="black",
                alpha=0.5,
                linewidth=2,
                linestyle="--",
            )
            ax[0].axvline(
                x=(tmp_locations.shape[0] // 2) + (WIN_SIZE // 2),
                color="black",
                alpha=0.5,
                linewidth=2,
                linestyle="--",
            )
            ax[1].axvline(
                x=((tmp_locations.shape[0] // 2) - (WIN_SIZE // 2)),
                color="black",
                alpha=0.5,
                linewidth=2,
                linestyle="--",
            )
            ax[1].axvline(
                x=(tmp_locations.shape[0] // 2) + (WIN_SIZE // 2),
                color="black",
                alpha=0.5,
                linewidth=2,
                linestyle="--",
            )
            plt.legend()
            fig.suptitle(
                f"{name} {i} from {start} to {end} ({end-start} frames)", fontsize=12
            )
            print("Saving figure...")
            output_path = base_figure_path / f"examples/{name}/"
            output_path.mkdir(exist_ok=True, parents=True)
            plt.savefig(
                output_path / f"example-fly0-{i:04d}.png", transparent=False, dpi=300
            )
            plt.close()


plot_focal_coordinates_from_group(trimmed_dict, locations_ego)

14:37:28 ERROR: Error interpolating: `x` must contain at least 2 elements.
14:37:47 ERROR: Error interpolating: `x` must contain at least 2 elements.


KeyboardInterrupt: 

In [None]:
running_list = defaultdict(lambda: defaultdict(dict))
for fly_idx in tqdm(range(4)):
    fly_id_mm = fly_idx
    fly_id_trx = fly_idx
    rle_list = encode_list(
        f["watershedRegions"][
            d[f"20220217-lts-cam1_day1_24hourvars-{fly_id_mm}-pcaModes"][0] : d[
                f"20220217-lts-cam1_day1_24hourvars-{fly_id_mm}-pcaModes"
            ][1]
        ]
    )
    dict_rle = {"number": [p[1] for p in rle_list], "length": [p[0] for p in rle_list]}
    df = pd.DataFrame(dict_rle)
    # Get the endasd

    df["end"] = np.cumsum(df.length)
    # Get the start
    df["start"] = df["end"] - df.length

    for region in range(0, np.unique(f["watershedRegions"][:]).shape[0]):
        running_list[fly_idx][region] = list()
        try:
            subset = df[(df.number == region)]
            for section in subset.iterrows():
                start = section[1]["start"]
                try:
                    running_list[fly_idx][region].extend(
                        vels[fly_idx, start:end].tolist()
                    )
                except:
                    print(f"Failed to append {start},{end}")
        except:
            print(f"Failed to find velocities in {region}")
output = pd.DataFrame(columns=["fly_idx", "region", "mean_velocity"])
for idx in tqdm(range(len(running_list))):
    for region in range(len(running_list[idx])):
        output.loc[len(output.index)] = [
            idx,
            region,
            np.mean(running_list[idx][region]),
        ]
        print(
            f"{idx},{region} completed with {len(running_list[idx][region])} examples"
        )

(4, 14, 8636544)


100%|██████████| 4/4 [08:05<00:00, 121.36s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

0,0 completed with 3 examples
0,1 completed with 0 examples
0,2 completed with 0 examples
0,3 completed with 0 examples
0,4 completed with 0 examples
0,5 completed with 0 examples
0,6 completed with 0 examples
0,7 completed with 0 examples
0,8 completed with 0 examples


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


0,9 completed with 3 examples
0,10 completed with 0 examples


 25%|██▌       | 1/4 [00:08<00:25,  8.62s/it]

0,11 completed with 4 examples
0,12 completed with 0 examples
0,13 completed with 4 examples
0,14 completed with 0 examples
0,15 completed with 0 examples
0,16 completed with 0 examples
0,17 completed with 0 examples
0,18 completed with 0 examples
0,19 completed with 0 examples
0,20 completed with 0 examples
1,0 completed with 4 examples
1,1 completed with 0 examples
1,2 completed with 0 examples
1,3 completed with 0 examples
1,4 completed with 0 examples
1,5 completed with 0 examples
1,6 completed with 0 examples
1,7 completed with 0 examples
1,8 completed with 0 examples
1,9 completed with 0 examples
1,10 completed with 0 examples
1,11 completed with 0 examples
1,12 completed with 6 examples


 50%|█████     | 2/4 [00:15<00:15,  7.87s/it]

1,13 completed with 4 examples
1,14 completed with 0 examples
1,15 completed with 0 examples
1,16 completed with 0 examples
1,17 completed with 0 examples
1,18 completed with 0 examples
1,19 completed with 0 examples
1,20 completed with 0 examples
2,0 completed with 3 examples
2,1 completed with 0 examples
2,2 completed with 0 examples
2,3 completed with 0 examples
2,4 completed with 0 examples
2,5 completed with 0 examples
2,6 completed with 0 examples
2,7 completed with 0 examples
2,8 completed with 0 examples
2,9 completed with 6 examples
2,10 completed with 0 examples
2,11 completed with 2 examples
2,12 completed with 0 examples


 75%|███████▌  | 3/4 [00:23<00:07,  7.63s/it]

2,13 completed with 3 examples
2,14 completed with 0 examples
2,15 completed with 0 examples
2,16 completed with 0 examples
2,17 completed with 0 examples
2,18 completed with 0 examples
2,19 completed with 0 examples
2,20 completed with 0 examples
3,0 completed with 5 examples
3,1 completed with 0 examples
3,2 completed with 0 examples
3,3 completed with 0 examples
3,4 completed with 0 examples
3,5 completed with 0 examples
3,6 completed with 0 examples
3,7 completed with 0 examples
3,8 completed with 0 examples


100%|██████████| 4/4 [00:30<00:00,  7.62s/it]

3,9 completed with 3 examples
3,10 completed with 0 examples
3,11 completed with 0 examples
3,12 completed with 0 examples
3,13 completed with 6 examples
3,14 completed with 0 examples
3,15 completed with 0 examples
3,16 completed with 0 examples
3,17 completed with 0 examples
3,18 completed with 0 examples
3,19 completed with 0 examples
3,20 completed with 0 examples





In [None]:
# output.to_csv("wtf.csv")

# import numpy as np
# import pandas as pd
# import statsmodels.api as sm
# import statsmodels.formula.api as smf
# from statsmodels.tools.sm_exceptions import ConvergenceWarning

# cleaned_output = output[~np.isnan(output["mean_velocity"])]
# # md = smf.mixedlm("mean_velocity ~ 1 + fly_idx ", cleaned_output, groups=cleaned_output['region'])
# # mdf = md.fit(method=["lbfgs"])
# model = smf.ols(formula="mean_velocity ~ C(region)", data=cleaned_output)
# model_fit = model.fit()
# print(model_fit.summary())
# # anova_table = sm.stats.anova_lm(model_fit, typ=2)
# # print(anova_table)

# plt.figure()
# cleaned_output.groupby("region").mean()["mean_velocity"].plot(kind="bar")
# plt.savefig("test.png")

In [None]:
f = h5py.File(
    "/Genomics/ayroleslab2/scott/git/lts-manuscript/analysis/20230118-mmpy-lts-day1-headprobinterp-missingness-pchip5-fillnanmedian-tss64k-tsp32k-removegt6missing/TSNE/20230119_sigma2_7_minregions20_zVals_wShed_groups_finalsave.mat"
)