In [None]:
%matplotlib widget
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import gridspec
# import matplotlib.gridspec as gridspec
# from scipy.optimize import curve_fit
from bouter.utilities import predictive_tail_fill, crop

from lotr import LotrExperiment, DATASET_LOCATION

# from lotr.plotting import despine, add_scalebar, get_circle_xy, color_stack, add_cbar, dark_col
from lotr.pca import pca_and_phase, fit_phase_neurons
from lotr.rpca_calculation import get_normalized_coords, reorient_pcs, match_rpc_and_neuron_phases
from lotr.plotting import dark_col
from lotr.utils import zscore, get_vect_angle, reduce_to_pi
from lotr.behavior import get_fictive_heading

from itertools import product
def match_rpc_and_neuron_phases(rpc_phases, neuron_phases):
    """Find the sign flip and offset that best map neuron phases onto rPC phases.

    Local re-definition: it shadows the ``match_rpc_and_neuron_phases`` imported
    from ``lotr.rpca_calculation`` in the import cell.

    Grid search over ``coef`` in {-1, 1} and ``shift`` in
    ``np.arange(-np.pi, np.pi, 0.1)``, minimising the summed absolute *circular*
    difference between ``neuron_phases * coef + shift`` and ``rpc_phases``.

    Parameters
    ----------
    rpc_phases : array-like of float
        Phase (rad) of each neuron from its position in rotated-PC space.
    neuron_phases : array-like of float
        Preferred phase (rad) of each neuron, same length and order as
        ``rpc_phases``.

    Returns
    -------
    tuple
        ``(coef, shift)`` with the lowest residual; on ties the first grid entry
        (coef -1 before 1, smaller shift first) wins.
    """
    rpc_phases = np.asarray(rpc_phases, dtype=float)
    neuron_phases = np.asarray(neuron_phases, dtype=float)

    shifts = np.arange(-np.pi, np.pi, 0.1)
    coefs = [-1, 1]
    params_list = list(product(coefs, shifts))
    residuals = np.zeros(len(params_list))
    for i, (coef, shift) in enumerate(params_list):
        # Wrap the difference onto (-pi, pi]: phases on either side of the +/-pi
        # seam are close, not ~2*pi apart as a plain subtraction would report.
        circ_diff = np.angle(np.exp(1j * (neuron_phases * coef + shift - rpc_phases)))
        residuals[i] = np.sum(np.abs(circ_diff))

    return params_list[np.argmin(residuals)]

In [None]:
# Collect, in sorted order, every experiment folder that contains a "selected.h5" file
master_path = Path(DATASET_LOCATION)
file_list = sorted(
    h5_file.parent for h5_file in master_path.glob("*/*[0-9]_f*/selected.h5")
)

In [None]:
# This cell used to repeat the experiment-loading loop of the next cell; its result was
# discarded immediately because the next cell resets `pooled_bouts = []` and reloads
# every experiment. The pooled bouts table is built (and concatenated) in the next cell.

In [None]:
# Load each experiment and stack all of their bout tables into a single DataFrame
bouts_per_exp = []
for path in file_list:
    exp = LotrExperiment(path)
    bouts_per_exp.append(exp.bouts_df)

pooled_bouts = pd.concat(bouts_per_exp, axis=0)

In [None]:
# Bout "bias" vs. "bias_total"; colour flags bouts with more than 4 peaks in total
# (positive + negative peaks). Grey lines mark the zero axes.
grey = (0.4,) * 3
total_peaks = pooled_bouts["n_pos_peaks"] + pooled_bouts["n_neg_peaks"]

plt.figure(figsize=(4, 4))
plt.scatter(pooled_bouts["bias"], pooled_bouts["bias_total"],
            c=total_peaks > 4, s=5, lw=0)
plt.axvline(0, c=grey)
plt.axhline(0, c=grey)

In [None]:
# Density histogram of bout bias (0.05-wide bins over [-2, 2)), log-scaled y axis
bias_bins = np.arange(-2, 2, 0.05)
plt.figure(figsize=(4, 3))
plt.hist(pooled_bouts["bias"], bias_bins, alpha=0.2, density=True)
plt.yscale("log")

In [None]:
def load_and_phase(path):
    """Load one experiment and return its network phase and rPC-space angle traces.

    Parameters
    ----------
    path : Path
        Experiment folder (an entry of ``file_list``).

    Returns
    -------
    phase : array
        Phase returned by ``pca_and_phase`` on the traces of the ``hdn_indexes``
        neurons.
    angles : array
        ``get_vect_angle`` of the activity-weighted sum of the neurons'
        rotated-PC scores, one value per frame (assuming traces are
        frames x neurons, as the ``[:, exp.hdn_indexes]`` indexing suggests).
    """
    exp = LotrExperiment(path)
    traces = exp.traces[:, exp.hdn_indexes]

    # PCA in population dim: only the phase is needed here
    _, phase, _, _ = pca_and_phase(traces)

    # Compute PCA in time, fit circle and center projections:
    pcaed_t, _, _, circle_params_t = pca_and_phase(traces.T)
    cpc_scores = pcaed_t[:, :2] - circle_params_t[:2]

    # Rotate PCs using the normalized coordinates of the same neurons
    w_coords = get_normalized_coords(exp.coords[exp.hdn_indexes, :])
    rpc_scores = reorient_pcs(cpc_scores, w_coords)

    # Weight each neuron's rPC score by its normalized activity and sum over
    # neurons (einsum contracts the shared neuron axis "i").
    norm_activity = get_normalized_coords(traces.T).T
    avg_vects = np.einsum("ij,ik->jk", norm_activity.T, rpc_scores[:, :2])

    angles = get_vect_angle(avg_vects.T)
    return phase, angles

In [None]:
# Network phase and rPC angle for every experiment. load_and_phase loads the
# experiment itself, so no separate LotrExperiment is constructed here.
pooled_phases = []
pooled_angles = []
for path in tqdm(file_list):
    phase, angle = load_and_phase(path)
    pooled_phases.append(phase)
    pooled_angles.append(angle)

In [None]:
# Fictive heading of every experiment, computed from its bouts with min_bias=0
pooled_headings = []
for path in file_list:
    exp = LotrExperiment(path)
    heading = get_fictive_heading(exp.n_pts, exp.bouts_df, min_bias=0)
    pooled_headings.append(heading)

In [None]:
from numba import njit, prange

@njit(parallel=True)
def quantify_corr_with_heading(phase, fictive_heading, wnd_pts=500):
    """Sliding-window Pearson correlation between ``phase`` and ``fictive_heading``.

    Each window spans ``2 * wnd_pts`` samples: entry ``i`` correlates
    ``phase[i:i + 2*wnd_pts]`` with the same slice of ``fictive_heading``, so it
    is centred on sample ``i + wnd_pts``.

    Parameters
    ----------
    phase, fictive_heading : 1D float arrays
        Signals to correlate; assumed to have the same length.
    wnd_pts : int
        Half-width of the window, in samples.

    Returns
    -------
    1D float array with one correlation per full window
    (``len(phase) - 2*wnd_pts + 1`` entries, including the last full window),
    or an empty array when the signal is shorter than one window.
    """
    win_len = 2 * wnd_pts
    n_windows = max(len(phase) - win_len + 1, 0)

    correlations = np.zeros(n_windows)
    for i in prange(n_windows):
        correlations[i] = np.corrcoef(phase[i:i + win_len],
                                      fictive_heading[i:i + win_len])[0, 1]

    return correlations

In [None]:
# Sliding-window correlation between each experiment's network phase / rPC angle and
# its fictive heading, for a range of window sizes. One summary record is stored per
# (experiment, window size) with the stats of both signals side by side, and one
# figure per experiment shows the windowed correlations at PLOT_WND.
PLOT_WND = 850  # window half-width (frames) shown in the per-experiment figures
SIGNAL_COLORS = {"phase": (0.8, 0.2, 0.3), "angle": (0.2, 0.8, 0.3)}

vals_list = []
bad = [28, 26, 14, 2, 6, 19, 23, 25]  # indexes (file_list order) of experiments flagged as bad
# The loop variable is `exp_path`, not `exp`, so that the LotrExperiment left in
# `exp` by earlier cells is not overwritten by a Path.
for i, (phase, angle, head, exp_path) in tqdm(
    enumerate(zip(pooled_phases, pooled_angles, pooled_headings, file_list)),
    total=len(file_list),
):
    f, ax = plt.subplots(2, 1, figsize=(8, 4), sharex=True)
    ax[1].plot(zscore(head), label="heading dir.")
    for wnd in range(50, 2000, 200):
        record = {"wnd": wnd, "bad": i in bad, "exp_id": exp_path.name}
        for data, k in zip([phase, angle], ["phase", "angle"]):
            c = SIGNAL_COLORS[k]
            corrs = quantify_corr_with_heading(-np.unwrap(data), head, wnd_pts=wnd)
            record.update({k + "_mn": np.nanmean(corrs),
                           k + "_md": np.nanmedian(corrs),
                           k + "_q1": np.nanpercentile(corrs, 25),
                           k + "_q3": np.nanpercentile(corrs, 75)})

            if wnd == PLOT_WND:
                # Flip sign so the mean windowed correlation is shown as positive;
                # nanmean so that NaN windows do not turn the whole trace into NaN.
                sgn = np.sign(np.nanmean(corrs))
                ax[1].plot(zscore(-np.unwrap(data)) * sgn, label=k, c=c)
                # Window i spans [i, i + 2*wnd): plot it at its centre.
                ax[0].plot(np.arange(len(corrs)) + wnd, corrs * sgn, c=c, label=k + " corr")
                # Whole-recording correlation of the same (negated) signal, with the
                # same sign flip as the windowed trace, so the two are comparable.
                ax[0].axhline(np.corrcoef(-np.unwrap(data), head)[0, 1] * sgn,
                              c=dark_col(c), label="tot. corr.")
                ax[0].axhline(np.nanmean(corrs) * sgn, linestyle="dashed", c=c, lw=1)
        vals_list.append(record)

    ax[0].set_ylabel("Correlation in 200s wnd")
    ax[1].set_xlabel("Time (frame n.)")
    for l in [-1, 0, 1]:
        ax[0].axhline(l, lw=0.5)
    ax[1].legend(frameon=False, fontsize=8)
    ax[0].legend(frameon=False, fontsize=8)
    f.suptitle(f"{i}, {exp_path.name}")

vals_df = pd.DataFrame(vals_list)

In [None]:
# Quick check: correlation of the raw (wrapped, not unwrapped) `data` left over from
# the last iteration of the correlation loop with `head`.
corr_mat = np.corrcoef(data, head)
corr_mat[0, 1]

In [None]:
# Mean windowed correlation vs. window size, one line per experiment (red = flagged bad),
# each sign-flipped so its value at the largest window is positive; the thick line is
# the median across experiments. vals_df has no "mn" column: its stats are prefixed
# by signal ("phase_mn", "angle_mn", ...), so choose which one to show here.
METRIC = "phase_mn"

plt.figure()
all_lines = []
for fid in vals_df.exp_id.unique():
    sel = vals_df[vals_df.exp_id == fid]
    # groupby().first() skips NaN, so rows carrying only the other signal's stats are ignored;
    # the result is sorted by window size.
    per_wnd = sel.groupby("wnd")[METRIC].first()
    c = "r" if sel["bad"].iloc[0] else "k"
    sgn = np.sign(per_wnd.iloc[-1])
    plt.plot(per_wnd.index, per_wnd.values * sgn, c=c)
    all_lines.append(per_wnd.values * sgn)

all_lines = np.array(all_lines)

plt.plot(per_wnd.index, np.nanmedian(all_lines, 0), lw=5)

In [None]:
# Overlay of the leftovers from the final iteration of the correlation loop:
# z-scored unwrapped phase, z-scored heading, and the last windowed correlation
# (`corrs`), shifted by its window half-width `wnd`.
plt.figure()
for trace in (zscore(np.unwrap(phase)), zscore(head)):
    plt.plot(trace)
plt.plot(wnd + np.arange(len(corrs)), corrs)

In [None]:
from scipy.signal import correlate

# Full cross-correlation of the z-scored heading with the rPC angle, the negated
# phase, and itself (autocorrelation, as reference), using the leftovers of the
# last experiment processed by the correlation loop.
head_z = zscore(head)
plt.figure()
plt.plot(correlate(zscore(angle), head_z, mode="full"))
plt.plot(correlate(zscore(-phase), head_z, mode="full"))
plt.plot(correlate(head_z, head_z, mode="full"))

In [None]:
# Unscaled unwrapped phase and fictive heading (leftovers of the correlation loop)
plt.figure()
unwrapped_phase = np.unwrap(phase)
plt.plot(unwrapped_phase)
plt.plot(head)

In [None]:
# Distribution of fictive heading per experiment, wrapped onto [-pi, pi).
# NOTE: uses `pooled_headings` (fictive headings computed above); the previously
# referenced `fictive_trajectories` is not defined in any cell of this notebook.
plt.figure()
for f in pooled_headings:
    # mod(f + pi, 2pi) - pi wraps without shifting: a heading of 0 stays at 0
    plt.hist(np.mod(f + np.pi, 2 * np.pi) - np.pi, np.arange(-np.pi, np.pi, 0.3), alpha=0.2)

In [None]:
# Placeholder figure. The commented-out line below plotted every element of
# `fictive_trajectories`, a name not defined in any cell of this notebook.
plt.figure()
# [plt.plot(f) for f in fictive_trajectories]

In [None]:
# Plot of the mean of `counts` along axis 0.
# NOTE(review): `counts` is not defined in any cell above, so this raises NameError
# on a fresh kernel — presumably it came from a since-deleted cell; restore or remove.
plt.figure()
plt.plot(counts.mean(0))

In [None]:
def load_and_plot(path):
    """Load one experiment, compute its phase / rPC quantities and draw a summary figure.

    The figure has four rows: the behaviour-log tail sum; the neuron traces sorted
    by rPC phase; the z-scored unwrapped phase, fictive trajectory and ``angles``
    over time; and two scatter panels (population-PCA projection coloured by
    ``angles``, rPC scores coloured by the fitted neuron phases). Also prints the
    residual of the rPC-to-neuron phase match and the matched ``(coef, shift)``.

    Parameters
    ----------
    path : Path
        Experiment folder (an entry of ``file_list``).

    NOTE(review): relies on module-level state. The global ``i`` (set by the
    calling loop) is used for the title and the ``bad`` check, and
    ``fictive_trajectory_and_fit`` is not imported or defined in any cell of this
    notebook, so that call raises NameError unless it is provided elsewhere.
    """
    # Experiment indexes flagged as bad. NOTE(review): these differ from the
    # module-level `bad` list of the correlation analysis — confirm the indexing.
    bad = [3, 17, 33, 36, 38, 19]
    
    exp = LotrExperiment(path)
    bouts_df = exp.bouts_df

    traces = exp.traces[:, exp.hdn_indexes]

    # Compute PCA in population dim
    pcaed, phase, pca, circle_params = pca_and_phase(traces)

    # Compute preferred phase of each neuron:
    neuron_phases, _ = fit_phase_neurons(traces, phase, disable_bar=True)

    # Compute PCA in time, fit circle and center projections:
    pcaed_t, phase_t, pca_t, circle_params_t = pca_and_phase(traces.T)
    cpc_scores = pcaed_t[:, :2] - circle_params_t[:2]

    coords = exp.coords[exp.hdn_indexes, :]
    w_coords = get_normalized_coords(coords)

    # rotate pcs:
    rpc_scores = reorient_pcs(cpc_scores, w_coords)

    # We can now calculate a phase for each neuron from their position in this rotated space:
    rpc_phases = np.angle(rpc_scores[:, 0] + 1j * rpc_scores[:, 1])

    # (coef, shift) mapping neuron phases onto rPC phases; only printed and used in the title
    min_params = match_rpc_and_neuron_phases(rpc_phases, neuron_phases)
    
    # `params` is not used below
    fict_traj, params = fictive_trajectory_and_fit(np.unwrap(phase), bouts_df, fn=5, min_bias=0.05)

    # Activity-weighted sum of the neurons' rPC scores (einsum contracts the neuron axis "i")
    norm_activity = get_normalized_coords(traces.T).T
    avg_vects = np.einsum("ij,ik->jk", norm_activity.T, rpc_scores[:, :2])

    angles = get_vect_angle(avg_vects.T) #[:, 0], pcaed_t[:, 1]

    #plt.figure()
    #plt.scatter(reduce_to_pi(min_params[0]*rpc_phases + min_params[1]), neuron_phases)
    # NOTE(review): this residual uses a plain (non-circular) difference, so phases
    # across the +/-pi seam count as ~2*pi apart.
    print(np.sum(np.abs(reduce_to_pi(min_params[0]*rpc_phases + min_params[1])-neuron_phases)),
         min_params)
    
    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(4, 2, figure=fig)
    
    # Row 0: tail sum (every 3rd sample) over the behaviour-log time axis
    beh_plot = fig.add_subplot(gs[0, :])
    beh_plot.plot(exp.behavior_log.t.values[::3], exp.behavior_log.tail_sum.values[::3])
    beh_plot.set_xlim(0, exp.behavior_log.t.values[-1])
    # Row 1: every 4th frame of the traces, neurons sorted by rPC phase, stretched
    # over the behaviour-log time range
    beh_traces = fig.add_subplot(gs[1, :], sharex=beh_plot)
    beh_traces.imshow(traces[::4, np.argsort(rpc_phases)].T, extent=(0, exp.behavior_log.t.values[-1],
                                                                  0, len(rpc_phases)), 
                      aspect="auto", cmap="gray_r")
    
    ax_plot = fig.add_subplot(gs[2, :], sharex=beh_plot)
    axs_phase = fig.add_subplot(gs[3, 0])
    axs_circle = fig.add_subplot(gs[3, 1])
    
    axs_circle.scatter(rpc_scores[:, 0], rpc_scores[:, 1], c=neuron_phases, cmap="twilight")
    axs_circle.axis("equal")
    axs_phase.scatter(pcaed[:, 0], pcaed[:, 1], c=angles, cmap="twilight", s=8)
    axs_phase.axis("equal")
    
    # Row 2: time axis in seconds (frame index / exp.fn)
    x = np.arange(len(phase)) / exp.fn
    ax_plot.plot(x, -zscore(np.unwrap(phase)))
    ax_plot.plot(x, zscore(fict_traj))
    ax_plot.plot(x, -zscore(np.unwrap(angles)))
    # `i` is the module-level loop index of the calling cell
    title = f"{i}. {path.name}, {min_params[0]}"
    if i in bad:
        title = title + "bad"
    plt.suptitle(title)

In [None]:


# One summary figure per experiment. load_and_plot reads the global `i` for its
# title and bad-flag check, so the loop variable must keep that name.
for i, path in enumerate(tqdm(file_list)):
    load_and_plot(path)

In [None]:
# Neurons projected on the first two PCs of the time-wise PCA.
# NOTE(review): pcaed_t is a local of load_and_phase / load_and_plot and is not
# defined at notebook level, so this only works with leftover kernel state.
# (The stray `.scatter(...)` line that followed was a SyntaxError; the population-PCA
# scatter coloured by phase is drawn in the next cell.)
plt.figure(figsize=(3, 3))
plt.scatter(pcaed_t[:, 0], pcaed_t[:, 1])

In [None]:
# Population-PCA projection coloured by phase.
# NOTE(review): pcaed is only a local of load_and_phase / load_and_plot; this
# relies on leftover kernel state.
plt.figure(figsize=(3, 3))
pc1, pc2 = pcaed[:, 0], pcaed[:, 1]
plt.scatter(pc1, pc2, c=phase, cmap="twilight")

In [None]:
# Raw phase against the negated, z-scored fictive trajectory.
# NOTE(review): fict_traj is only a local inside load_and_plot; this cell relies
# on leftover kernel state and raises NameError on a fresh kernel.
plt.figure()
plt.plot(phase)
neg_fict_traj = -zscore(fict_traj)
plt.plot(neg_fict_traj)

In [None]:

# Neurons in rotated-PC space coloured by their activity at three example frames,
# with a line from the origin along that frame's value of `angles`.
# NOTE(review): rpc_scores, traces and angles are locals of load_and_plot and are
# not defined at notebook level; this cell relies on leftover kernel state.
example_frames = [500, 2000, 7000]
line_len = 100
plt.figure(figsize=(7, 2.5))
for n, i in enumerate(example_frames):
    plt.subplot(1, 3, n + 1)
    plt.scatter(rpc_scores[:, 0], rpc_scores[:, 1], c=traces[i, :])
    theta = angles[i]
    plt.plot([0, np.cos(theta) * line_len], [0, np.sin(theta) * line_len])
    plt.axis("equal")

In [None]:
# Behaviour log of the experiment currently bound to `exp`.
# NOTE(review): `exp` is leftover kernel state; make sure it is a LotrExperiment
# here — any loop that uses `exp` as a loop variable over `file_list` rebinds it
# to a Path, and this line then fails.
beh_df = exp.behavior_log

In [None]:
# Tail sum over time from the behaviour log in `beh_df`
plt.figure()
plt.plot(beh_df["t"], beh_df["tail_sum"])

In [None]:
# Fill the tail-angle columns with predictive_tail_fill and recompute the tail sum
# (last two segments minus first two). Work on a copy so the experiment's own
# behavior_log is not modified in place: re-running this cell then never re-fills
# already-filled data, and earlier cells' views of the log stay unchanged.
# (predictive_tail_fill is imported in the import cell.)
tail_cols = [f"theta_0{i}" for i in range(9)]
beh_df = exp.behavior_log.copy()
theta_mat = beh_df.loc[:, tail_cols].values
beh_df.loc[:, tail_cols] = predictive_tail_fill(theta_mat)

beh_df["tail_sum"] = (beh_df["theta_07"] + beh_df["theta_08"]) - (
    beh_df["theta_00"] + beh_df["theta_01"]
)