In [None]:
%matplotlib widget

from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from ec_code.file_utils import get_dataset_location
from ec_code.phy_tools.utilities import bouts_from_twitches, nanzscore
from ec_code.phy_tools.utilities.plotting import exp_plot, get_xy
from ec_code.phy_tools.utilities.spikes_detection import raster_on_evts
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Global seaborn look for every figure in this notebook.
sns.set(style="ticks", palette="deep")
from numba import jit

# Colours of the active seaborn palette, kept for reuse in later cells.
cols = sns.color_palette()

In [None]:
# Resolve the "ephys" dataset folder and load the pooled recordings stored in it.
data_folder = get_dataset_location("ephys")
pooled_data_file = data_folder / "all_pooled_data.hdf5"
data_collection = fl.load(pooled_data_file)

In [None]:
# Peri-event window passed to `raster_on_evts` as `pre_int` / `post_int`.
pre_int_sec = 10
post_int_sec = 10

# Dedicated, seeded generator so that the shuffled control raster is identical
# on every Restart & Run All (the global np.random state is left untouched).
SHUFFLE_SEED = 0
shuffle_rng = np.random.RandomState(SHUFFLE_SEED)

# One dict per entry of `data_collection`, mapping each session key either to
# None (no data / could not be processed) or to a dict of rasters.
raster_collection = []

for data in data_collection:
    raster = dict()
    for session in data.keys():
        raster[session] = None
        session_data = data[session]
        if session_data is None:
            continue

        spikes = session_data["spk_idxs"]
        twitches = session_data["twc_idxs"]
        # NOTE(review): `fn` is presumably the acquisition sampling rate of
        # this session - confirm against `raster_on_evts`.
        fn = session_data["fn"]
        # Same window and sampling rate for every raster of this session. The
        # twitch-triggered rasters previously hard-coded fn=8333.3 instead of
        # using the value stored with the session.
        window = dict(fn=fn, pre_int=pre_int_sec, post_int=post_int_sec)

        try:
            # `offs` is not used below; only bout onsets trigger the rasters.
            ons, offs = bouts_from_twitches(twitches, sort=True)

            # Spikes and twitches aligned on bout onsets.
            raster_spk = raster_on_evts(spikes, ons, **window)
            raster_twc = raster_on_evts(twitches, ons, **window)

            # Control: spikes aligned on as many random indexes as there are
            # bouts, drawn between 0 and the last spike index.
            shuffled_ons = shuffle_rng.randint(0, spikes[-1], len(ons))
            raster_spk_shuf = raster_on_evts(spikes, shuffled_ons, **window)

            # Spikes and twitches aligned on every single twitch.
            raster_spk_twc = raster_on_evts(spikes, twitches, **window)
            raster_twc_twc = raster_on_evts(twitches, twitches, **window)

            raster[session] = dict(
                spk=raster_spk,
                spk_shuf=raster_spk_shuf,
                spk_twc=raster_spk_twc,
                twc_twc=raster_twc_twc,
                twc=raster_twc,
                name=session_data["name"],
            )
        except TypeError as err:
            # Best effort: the session stays None, but say so instead of
            # dropping it silently.
            print("Skipping session {!r}: {}".format(session, err))

    raster_collection.append(raster)

In [None]:
from ec_code.phy_tools.utilities.spikes_stats import get_moments, single_plot

# Make the plot for one cell

In [None]:
def cell_response_panel(raster):
    """Draw a 3x3 figure summarising one cell's responses in each session.

    Rows follow the sessions "exp022", "blanks" and "lag". The first two
    columns call `single_plot` on the "spk" raster (events: "twc") with
    (pre, post, step) of (10, 10, 0.5) and (2, 2, 0.1); the third column uses
    the "spk_twc" raster (events: "twc_twc") with (0.15, 0.15, 0.01). Panels
    of sessions whose entry is None are hidden.

    Parameters
    ----------
    raster : dict
        Maps each session name to None or to a dict holding the keys "spk",
        "twc", "spk_twc", "twc_twc" and "name".

    Returns
    -------
    None
        The figure is created through pyplot and is not returned.
    """
    sessions = ["exp022", "blanks", "lag"]
    fig, ax = plt.subplots(3, 3, figsize=(9, 9), gridspec_kw=dict(hspace=0.2))

    # Columns 0 and 1: bout-triggered responses at two time scales.
    bout_windows = [(10, 10, 0.5), (2, 2, 0.1)]
    title = None
    label_row = 0
    for col, (pre_s, post_s, bin_s) in enumerate(bout_windows):
        for row, session in enumerate(sessions):
            panel = ax[row][col]
            entry = raster[session]
            if entry is None:
                panel.set_visible(False)
                continue
            single_plot(
                entry["spk"], pre_s, post_s, bin_s, ax=panel, events=entry["twc"]
            )
            # The last available session provides the figure title and is the
            # row that keeps its x tick labels.
            title = entry["name"]
            label_row = row

    # Column 2: twitch-triggered responses on a short window.
    for row, session in enumerate(sessions):
        panel = ax[row][2]
        entry = raster[session]
        if entry is None:
            panel.set_visible(False)
            continue
        single_plot(
            entry["spk_twc"], 0.15, 0.15, 0.01, ax=panel, events=entry["twc_twc"]
        )

    xlabels = ["Time from bout (s)", "Time from bout (s)", "Time from twitch (s)"]
    for row, row_axes in enumerate(ax):
        row_axes[0].set_ylabel("{}\n(Spikes/s)".format(sessions[row]), labelpad=30)
        for panel, xlabel in zip(row_axes, xlabels):
            if row == label_row:
                panel.set_xlabel(xlabel)
            else:
                panel.set_xticklabels([])

    plt.suptitle(title)
    # Legend is requested on pyplot's current axes.
    plt.legend()
    sns.despine()

In [None]:
# Preview the response panel for one example cell (entry 28 of the collection).
example_raster = raster_collection[28]
cell_response_panel(example_raster)

# Dirty legend

In [None]:
# Small stand-alone figure whose only purpose is to carry the legend entries:
# the plotted artists are placeholders providing a colour and a label.
f = plt.figure(figsize=(3, 3))
legend_palette = sns.color_palette()
plt.plot(1, label="tail", color=legend_palette[0], alpha=0.4)
plt.plot(1, label="spikes", color=legend_palette[2])
plt.axhspan(0, 0, label="99% conf.", color="k", alpha=0.1, zorder=-500)
plt.legend()
sns.despine()

In [None]:
# Save the legend figure on the current user's Desktop; the home directory is
# resolved at run time instead of being hard-coded to one machine.
f.savefig(str(Path.home() / "Desktop" / "legend.pdf"))

## Export PDF with all cells

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

In [None]:
# Write one page per cell to a PDF on the current user's Desktop; the home
# directory is resolved at run time instead of being hard-coded.
all_cells_pdf = Path.home() / "Desktop" / "all_cells_motor_resp.pdf"
with PdfPages(str(all_cells_pdf)) as pdf:
    for raster in raster_collection:
        cell_response_panel(raster)
        # `cell_response_panel` does not return its figure, so grab it once
        # and name it explicitly when saving and closing.
        panel_fig = plt.gcf()
        pdf.savefig(panel_fig)  # adds the figure as a new page
        plt.close(panel_fig)  # free the figure before drawing the next cell

# Colorplot with all cells:

In [None]:
pre_int_sec = 10
post_int_sec = 10
CONF_INT_COEF = 1.96

# Keep only the cells for which an "exp022" raster is available, then split
# them into the spike rasters and the twitch rasters.
exp022_rasters = [
    cell["exp022"] for cell in raster_collection if cell["exp022"] is not None
]
pooled_spk = [session["spk"] for session in exp022_rasters]
pooled_bouts = [session["twc"] for session in exp022_rasters]

In [None]:
def _zscored_histograms(rasters, sampling_rates, bin_edges):
    """Return one z-scored count histogram per raster, stacked row-wise.

    Column 0 of every raster is divided by that raster's own sampling rate
    before being binned with `bin_edges`; the counts are then passed through
    `nanzscore`.
    """
    return np.array(
        [
            nanzscore(np.histogram(events[:, 0].flatten() / rate, bin_edges)[0])
            for events, rate in zip(rasters, sampling_rates)
        ]
    )


# Sampling rate of every cell kept in `pooled_spk` / `pooled_bouts` (same
# filter, same order). Previously a single `fn`, left over from the last
# iteration of the raster-building loop, was applied to all cells.
pooled_fn = [
    data["exp022"]["fn"]
    for data, raster in zip(data_collection, raster_collection)
    if raster["exp022"] is not None
]
assert len(pooled_fn) == len(pooled_spk) == len(pooled_bouts)

# Long time scale. np.arange excludes its end point, so the last bin edge is
# post_int_sec - step.
pre_int_sec = 10
post_int_sec = 10
step = 0.5
hst_arr = np.arange(-pre_int_sec, post_int_sec, step)
histograms_10s = _zscored_histograms(pooled_spk, pooled_fn, hst_arr)
histograms_10s_twc = _zscored_histograms(pooled_bouts, pooled_fn, hst_arr)
# histograms_10s[np.abs(histograms_10s) < CONF_INT_COEF] = 0


# Short time scale, same procedure with finer bins.
pre_int_sec = 2
post_int_sec = 2
step = 0.1
hst_arr = np.arange(-pre_int_sec, post_int_sec, step)
histograms_2s = _zscored_histograms(pooled_spk, pooled_fn, hst_arr)
histograms_2s_twc = _zscored_histograms(pooled_bouts, pooled_fn, hst_arr)
# histograms_2s[np.abs(histograms_2s) < CONF_INT_COEF] = 0

In [None]:
def _sorted_response_panel(fig, left, spk_hist, twc_hist, bins_per_sec, pre_sec, sort_bin):
    """Add one heatmap of per-cell histograms with the mean twitch trace on top.

    Parameters
    ----------
    fig : matplotlib Figure the axes are added to.
    left : left edge of both axes, in figure coordinates.
    spk_hist : 2D array (one row per cell) shown as the heatmap; rows are
        ordered by their value in column `sort_bin`.
    twc_hist : 2D array whose mean over rows is drawn as a step trace.
    bins_per_sec, pre_sec : used to turn a bin index into the tick label
        ``index / bins_per_sec - pre_sec``.
    sort_bin : column of `spk_hist` used for sorting the rows.

    Returns
    -------
    (heatmap_axes, trace_axes)
    """
    im_ax = fig.add_axes((left, 0.3, 0.35, 0.4))
    im_ax.imshow(
        spk_hist[np.argsort(spk_hist[:, sort_bin]), :],
        cmap="RdBu_r",
        vmin=-3,
        vmax=3,
        aspect="auto",
    )
    # Derive the labels from the tick positions so that both always have the
    # same length (a mismatch raises ValueError in recent matplotlib).
    tick_pos = np.arange(0, spk_hist.shape[1], 10)
    im_ax.set_xticks(tick_pos)
    im_ax.set_xticklabels(tick_pos / bins_per_sec - pre_sec)
    im_ax.set_xlabel("Time from bout (s)")

    bt_ax = fig.add_axes((left, 0.71, 0.35, 0.15))
    bt_ax.step(np.arange(twc_hist.shape[1]), np.mean(twc_hist, 0), alpha=0.4)
    # x range taken from the trace that is actually drawn in these axes.
    bt_ax.set_xlim(0, twc_hist.shape[1])
    bt_ax.axis("off")
    return im_ax, bt_ax


f = plt.figure()
# Histogram column used to order the cells; each panel is sorted on its own
# histogram, so the row order can differ between the two panels.
i_sort = 20

# Left: 10 s window, 2 bins per second (0.5 s bins).
ax10_im, ax10_bt = _sorted_response_panel(
    f, 0.1, histograms_10s, histograms_10s_twc, 2, 10, i_sort
)
ax10_im.set_ylabel("Cell (sorted on motor resp.)")

# Right: 2 s window, 10 bins per second (0.1 s bins); no cell ticks.
ax2_im, ax2_bt = _sorted_response_panel(
    f, 0.5, histograms_2s, histograms_2s_twc, 10, 2, i_sort
)
ax2_im.set_yticks([])

sns.despine(left=True, top=True, bottom=True)

In [None]:
# Save the summary figure on the current user's Desktop; the home directory is
# resolved at run time instead of being hard-coded to one machine.
f.savefig(str(Path.home() / "Desktop" / "summary_ephys.pdf"))

In [None]:
def cell_response_panel(raster):
    """Draw a 3x3 figure summarising one cell's responses in each session.

    NOTE(review): this redefines `cell_response_panel` from an earlier cell and
    shadows it from here on; unlike that version it does not call
    `plt.legend()`.

    Rows follow the sessions "exp022", "blanks" and "lag". The first two
    columns call `single_plot` on the "spk" raster (events: "twc") with
    (pre, post, step) of (10, 10, 0.5) and (2, 2, 0.1); the third column uses
    the "spk_twc" raster (events: "twc_twc") with (0.15, 0.15, 0.01). Panels
    of sessions whose entry is None are hidden.

    Parameters
    ----------
    raster : dict
        Maps each session name to None or to a dict holding the keys "spk",
        "twc", "spk_twc", "twc_twc" and "name".

    Returns
    -------
    None
        The figure is created through pyplot and is not returned.
    """
    sessions = ["exp022", "blanks", "lag"]
    fig, ax = plt.subplots(3, 3, figsize=(9, 9), gridspec_kw=dict(hspace=0.2))

    # Columns 0 and 1: bout-triggered responses at two time scales.
    bout_windows = [(10, 10, 0.5), (2, 2, 0.1)]
    title = None
    label_row = 0
    for col, (pre_s, post_s, bin_s) in enumerate(bout_windows):
        for row, session in enumerate(sessions):
            panel = ax[row][col]
            entry = raster[session]
            if entry is None:
                panel.set_visible(False)
                continue
            single_plot(
                entry["spk"], pre_s, post_s, bin_s, ax=panel, events=entry["twc"]
            )
            # The last available session provides the figure title and is the
            # row that keeps its x tick labels.
            title = entry["name"]
            label_row = row

    # Column 2: twitch-triggered responses on a short window.
    for row, session in enumerate(sessions):
        panel = ax[row][2]
        entry = raster[session]
        if entry is None:
            panel.set_visible(False)
            continue
        single_plot(
            entry["spk_twc"], 0.15, 0.15, 0.01, ax=panel, events=entry["twc_twc"]
        )

    xlabels = ["Time from bout (s)", "Time from bout (s)", "Time from twitch (s)"]
    for row, row_axes in enumerate(ax):
        row_axes[0].set_ylabel("{}\n(Spikes/s)".format(sessions[row]), labelpad=30)
        for panel, xlabel in zip(row_axes, xlabels):
            if row == label_row:
                panel.set_xlabel(xlabel)
            else:
                panel.set_xticklabels([])

    plt.suptitle(title)
    sns.despine()