In [None]:
%matplotlib widget
from itertools import product
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm
# import matplotlib.gridspec as gridspec
# from scipy.optimize import curve_fit

from lotr import LotrExperiment, DATASET_LOCATION

# from lotr.plotting import despine, add_scalebar, get_circle_xy, color_stack, add_cbar, dark_col
from lotr.pca import pca_and_phase, fit_phase_neurons
from lotr.rpca_calculation import get_normalized_coords, reorient_pcs, match_rpc_and_neuron_phases
from lotr.utils import reduce_to_pi
from lotr.pca import fictive_heading_and_fit 
from lotr.utils import zscore, get_vect_angle

def match_rpc_and_neuron_phases(rpc_phases, neuron_phases, coefs=(-1, 1), shift_step=0.1):
    """Find the sign flip and offset that map neuron phases onto rPC phases.

    Grid-searches ``coef`` over ``coefs`` and ``shift`` over
    ``np.arange(-pi, pi, shift_step)``. Returns the pair that minimizes the
    summed *circular* distance between ``coef * neuron_phases + shift`` and
    ``rpc_phases``.

    NOTE: this local definition shadows the ``match_rpc_and_neuron_phases``
    imported from ``lotr.rpca_calculation`` in the import cell.

    Parameters
    ----------
    rpc_phases : array-like of float
        Phase (radians) of each neuron in the rotated PC plane.
    neuron_phases : array-like of float
        Preferred phase (radians) of each neuron fitted over the network
        trajectory. Must have the same length as ``rpc_phases``.
    coefs : sequence of float, optional
        Candidate multiplicative coefficients. The default ``(-1, 1)`` only
        allows for a possible flip of direction.
    shift_step : float, optional
        Spacing of the shift grid, in radians (default 0.1).

    Returns
    -------
    tuple
        ``(coef, shift)`` with the lowest residual. Ties are resolved in
        favour of the first candidate in ``itertools.product(coefs, shifts)``
        order.
    """
    rpc_phases = np.asarray(rpc_phases)
    neuron_phases = np.asarray(neuron_phases)
    shifts = np.arange(-np.pi, np.pi, shift_step)
    candidates = list(product(coefs, shifts))

    residuals = np.empty(len(candidates))
    for i, (coef, shift) in enumerate(candidates):
        diff = neuron_phases * coef + shift - rpc_phases
        # Wrap the difference to [-pi, pi). Without this, angles just either
        # side of the +/-pi seam would count as ~2*pi apart instead of close.
        residuals[i] = np.sum(np.abs(np.mod(diff + np.pi, 2 * np.pi) - np.pi))

    return candidates[int(np.argmin(residuals))]

In [None]:
# Collect every experiment folder matching "*[0-9]_f*" (one level below the
# dataset root) that contains a "selected.h5" file, in sorted order.
master_path = Path(DATASET_LOCATION)
file_list = sorted(
    h5_file.parent for h5_file in master_path.glob("*/*[0-9]_f*/selected.h5")
)

In [None]:
def load_and_plot(path, index=None):
    """Load one experiment, relate its network phase to rPC space, and plot a summary.

    Parameters
    ----------
    path : pathlib.Path or str
        Experiment folder passed to ``LotrExperiment``.
    index : int, optional
        If given, prefixed to the figure title (e.g. the position of the
        experiment in a batch loop). This replaces the former implicit read of
        a global ``i``, which raised ``NameError`` when no loop defined it.

    Returns
    -------
    matplotlib.figure.Figure
        The layout is:

        - row 0: ``tail_sum`` from the behavior log;
        - row 2: activity with neurons sorted by rPC phase;
        - row 3: z-scored network phase, fictive trajectory and projected angle;
        - row 4: population-PCA scatter (left) and rPC scatter (right).

    Notes
    -----
    Prints the circular residual of the fitted ``(coef, shift)`` mapping,
    followed by the mapping itself.
    """
    path = Path(path)
    exp = LotrExperiment(path)
    bouts_df = exp.bouts_df

    traces = exp.traces[:, exp.hdn_indexes]

    # PCA in the population dimension -> network phase over time.
    pcaed, phase, _, _ = pca_and_phase(traces)

    # Preferred phase of each neuron along that trajectory.
    neuron_phases, _ = fit_phase_neurons(traces, phase, disable_bar=True)

    # PCA in the time dimension; center the projections on the fitted circle.
    pcaed_t, _, _, circle_params_t = pca_and_phase(traces.T)
    cpc_scores = pcaed_t[:, :2] - circle_params_t[:2]

    # Rotate the centered PCs using the normalized coordinates of the neurons.
    w_coords = get_normalized_coords(exp.coords[exp.hdn_indexes, :])
    rpc_scores = reorient_pcs(cpc_scores, w_coords)

    # Phase of each neuron from its position in the rotated PC plane.
    rpc_phases = np.angle(rpc_scores[:, 0] + 1j * rpc_scores[:, 1])

    # (coef, shift) such that coef * neuron_phases + shift best matches rpc_phases.
    min_params = match_rpc_and_neuron_phases(rpc_phases, neuron_phases)
    coef, shift = min_params[0], min_params[1]

    # NOTE(review): fn=5 is hard-coded here, while exp.fn is used for the time
    # axis below. Confirm that the two agree.
    fict_traj, _ = fictive_heading_and_fit(np.unwrap(phase), bouts_df, fn=5, min_bias=0.05)

    # Activity-weighted sum of each neuron's rPC position, giving one 2D vector
    # per time point, and its angle.
    norm_activity = get_normalized_coords(traces.T).T
    avg_vects = np.einsum("ij,ik->jk", norm_activity.T, rpc_scores[:, :2])
    angles = get_vect_angle(avg_vects.T)

    # Residual of the fit, evaluated in the direction the fit was made
    # (neuron -> rPC) and on the circle, so that +pi and -pi count as equal.
    mapped_diff = coef * neuron_phases + shift - rpc_phases
    residual = np.sum(np.abs(np.mod(mapped_diff + np.pi, 2 * np.pi) - np.pi))
    print(residual, min_params)

    x = np.arange(len(phase)) / exp.fn
    t_behavior_end = exp.behavior_log.t.values[-1]
    # Use the imaging duration so the heatmap lines up with the traces in row 3,
    # which share this x axis and are plotted against `x`.
    t_imaging_end = len(phase) / exp.fn

    fig = plt.figure(figsize=(8, 8))
    gs = gridspec.GridSpec(5, 2, figure=fig)

    ax_behavior = fig.add_subplot(gs[0, :])
    ax_behavior.plot(exp.behavior_log.t.values[::3], exp.behavior_log.tail_sum.values[::3])
    ax_behavior.set_xlim(0, t_behavior_end)

    ax_activity = fig.add_subplot(gs[2, :], sharex=ax_behavior)
    ax_activity.imshow(
        traces[::4, np.argsort(rpc_phases)].T,
        extent=(0, t_imaging_end, 0, len(rpc_phases)),
        aspect="auto",
        cmap="gray_r",
    )

    ax_plot = fig.add_subplot(gs[3, :], sharex=ax_behavior)
    axs_phase = fig.add_subplot(gs[4, 0])
    axs_circle = fig.add_subplot(gs[4, 1])

    axs_circle.scatter(rpc_scores[:, 0], rpc_scores[:, 1], c=neuron_phases, cmap="twilight")
    axs_circle.axis("equal")
    axs_phase.scatter(pcaed[:, 0], pcaed[:, 1], c=angles, cmap="twilight", s=8)
    axs_phase.axis("equal")

    ax_plot.plot(x, -zscore(np.unwrap(phase)))
    ax_plot.plot(x, zscore(fict_traj))
    ax_plot.plot(x, -zscore(np.unwrap(angles)) * coef)

    title = f"{path.name}, {coef}"
    if index is not None:
        title = f"{index}. {title}"
    fig.suptitle(title)

    return fig

In [None]:
    
# Set EXPORT_PDF to True to render every experiment in `file_list` into one
# multipage PDF. Otherwise only the single experiment below is plotted.
EXPORT_PDF = False

if EXPORT_PDF:
    with PdfPages("multipage_pdf.pdf") as pdf:
        for i, path in enumerate(tqdm(file_list)):
            fig = load_and_plot(path)
            pdf.savefig(fig)
            # Release each figure once saved so memory does not grow with the loop.
            plt.close(fig)
else:
    # Other experiments used for inspection:
    #   master_path / "210314_f1" / "210314_f1_natmov"
    #   master_path / "210926_f0" / "210926_f0_gainmod"
    path = master_path / "210924_f1" / "210924_f1b_gainmod"
    fig = load_and_plot(path)