In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.patches import Rectangle
from pathlib import Path
from scipy.ndimage import gaussian_filter1d

from utils import set_params
from utils import load_pickle, extract_used_data
from utils import mergeAB
from utils.config import Params

import ipywidgets as widgets
from IPython.display import display, clear_output

In [2]:
def align_track(data: dict, params: "Params"):
    """Smooth each condition's firing along the track, average it, and bin the zones.

    For every index in ``params.total_index_grid`` whose ``data["simple_firing"]``
    entry is not None, the array is Gaussian-smoothed along axis 2
    (``sigma=params.gaussian_sigma`` samples, ``mode="nearest"``, ``truncate=3.0``)
    and then averaged over axis 1.

    Presumably axis 0 = neurons, axis 1 = trials, axis 2 = position bins. This is
    inferred from how the result is used, not enforced here (TODO confirm).

    Zone boundaries in ``data["zones"]`` are divided by ``params.space_unit``,
    widened by one bin on each side and clipped to ``[0, params.len_track]``.

    Adds these keys to ``data`` in place (returns None):
        "aligned_firing": object array (shape of ``data["simple_firing"]``)
            holding the axis-1 mean of each smoothed entry, or None.
        "aligned_firing_std": same layout, holding the standard error of that
            mean (population std over axis 1 divided by sqrt(len(axis 1))).
        "aligned_zones_id": int array of ``[start, end]`` bin indices per zone.
    """
    simple_firing = data["simple_firing"]
    aligned_firing = np.empty_like(simple_firing, dtype=object)
    aligned_firing_std = np.empty_like(simple_firing, dtype=object)

    for index in params.total_index_grid:
        fr = simple_firing[index]
        if fr is None:
            continue

        fr = np.asanyarray(fr)
        if not np.issubdtype(fr.dtype, np.floating):
            # gaussian_filter1d returns an array of the *input* dtype, so
            # smoothing integer counts would silently truncate the result.
            fr = fr.astype(np.float64)

        fr = gaussian_filter1d(fr, sigma=params.gaussian_sigma, axis=2,
                               mode="nearest", truncate=3.0)

        n_rep = fr.shape[1]  # presumably the number of trials — TODO confirm
        aligned_firing[index] = np.mean(fr, axis=1)
        # Stored under the "std" key but this is the standard error of the mean.
        aligned_firing_std[index] = np.std(fr, axis=1) / np.sqrt(n_rep)

    zones = data["zones"]
    aligned_zones_id = np.empty_like(zones, dtype=int)
    for i, zone in enumerate(zones):
        # Convert positions to bin indices, pad by one bin on each side so the
        # zone edges are fully covered, then clip to the track.
        start = int(zone[0] / params.space_unit) - 1
        end = int(zone[1] / params.space_unit) + 1
        aligned_zones_id[i] = [max(start, 0), min(end, params.len_track)]

    data["aligned_firing"] = aligned_firing
    data["aligned_zones_id"] = aligned_zones_id
    data["aligned_firing_std"] = aligned_firing_std

In [None]:
def plot_neuron_id(
    data: dict,
    params: Params,
    rule1: str,
    rule2: str,
    id: int,
    alpha=0.9,
    lw=1.5,
):
    """Plot one neuron's mean firing along the track for two rules, side by side.

    Each panel covers one rule. For every index returned by
    ``params.ana_index_grid(ana_tt=[f"{rule}*"], ana_bt=["correct"])`` with
    non-None data, it draws row ``id`` of ``data["aligned_firing"][idx]``.
    The curve has a shaded band of +/- the matching row of
    ``data["aligned_firing_std"][idx]``. Each curve is labelled with its trial
    type ``params.tt[idx[0]]`` and each panel is titled with its rule, so panels
    and curves can be told apart even though the n-th pattern and position
    conditions share a colour.

    Zones 0, 2 and 4 of ``data["aligned_zones_id"]`` (if present) are marked
    with dashed boundaries and a coloured strip along the top of the axes.

    Args:
        data: dict holding the "aligned_firing", "aligned_firing_std" and
            (optionally) "aligned_zones_id" entries written by ``align_track``.
        params: uses ``ana_index_grid``, ``tt``, ``space_unit`` and
            ``len_pos_average``.
        rule1: trial-type prefix for the left panel.
        rule2: trial-type prefix for the right panel.
        id: row of the firing arrays to plot. The name shadows the ``id``
            builtin but is kept because callers pass it by keyword.
        alpha: opacity of the mean curves.
        lw: line width of the mean curves.
    """
    # The n-th pattern and position trial types deliberately share a colour.
    color_map = {
        "pattern_1": "#1f77b4",
        "pattern_2": "#ff7f0e",
        "pattern_3": "#d62728",
        "position_1": "#1f77b4",
        "position_2": "#ff7f0e",
        "position_3": "#d62728",
    }

    # Only these zone indices are drawn; other zones are skipped.
    zones_color = {
        0: "#1f77b4",
        2: "#ff7f0e",
        4: "#d62728",
    }

    zones = data.get("aligned_zones_id", [])
    h_frac = 0.05  # height of the zone strip as a fraction of the axes height

    def plot_one_rule(ax, rule: str):
        indices = params.ana_index_grid(ana_tt=[f"{rule}*"], ana_bt=["correct"])

        for idx in indices:
            fr = data["aligned_firing"][idx]
            if fr is None:
                continue
            fr_std = data["aligned_firing_std"][idx]

            neuron_fr = fr[id, :]
            neuron_fr_std = fr_std[id, :]
            X = np.arange(neuron_fr.shape[0])
            tt = params.tt[idx[0]]
            c = color_map.get(tt, "#000000")

            ax.plot(X, neuron_fr, color=c, alpha=alpha, linewidth=lw, label=tt)
            ax.fill_between(X, neuron_fr - neuron_fr_std, neuron_fr + neuron_fr_std,
                            color=c, alpha=0.2)

        for zi, row in enumerate(zones):
            if zi not in zones_color or len(row) < 2:
                continue

            start_x, end_x = float(row[0]), float(row[1])

            ax.axvline(start_x, linestyle="--", color="k", linewidth=0.8)
            ax.axvline(end_x, linestyle="--", color="k", linewidth=0.8)

            # x in data coordinates, y in axes coordinates: a strip hugging the top edge.
            rect = Rectangle(
                (start_x, 1.0 - h_frac),
                end_x - start_x,
                h_frac,
                transform=ax.get_xaxis_transform(),
                facecolor=zones_color[zi],
                edgecolor="none",
                clip_on=False,
                zorder=6,
            )
            ax.add_patch(rect)

        ax.set_title(rule, fontsize=12)
        ax.set_xlabel("Position (cm)", fontsize=12)
        ax.set_ylabel("Firing rate (Hz)", fontsize=12)
        ax.tick_params(axis='both', labelsize=12)
        # Tick labels convert bin index to cm. NOTE(review): align_track converts
        # zones using space_unit only; this also multiplies by len_pos_average, so
        # the two agree only when len_pos_average == 1 — confirm intended units.
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(
            lambda x, pos: f'{x*params.space_unit*params.len_pos_average:g}'))
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        # Only add a legend when at least one labelled curve was drawn (avoids a warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(frameon=False, fontsize=10)

    fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharey=True)
    plot_one_rule(axes[0], rule1)
    plot_one_rule(axes[1], rule2)

    fig.suptitle(f"Example neuron #{id}", fontsize=16, y=1.05)
    plt.tight_layout()
    plt.show()


def browse_neurons(
    data: dict,
    params: Params,
    rule1: str,
    rule2: str,
    id_min=0,
    id_max=None,
):
    """Show an interactive neuron browser (Prev/Next buttons + slider) in Jupyter.

    The neuron count is the first-axis length of the first non-None
    ``data["aligned_firing"]`` entry among the "correct" conditions of
    ``rule1``. Changing the slider redraws ``plot_neuron_id`` for the selected
    id inside an ``ipywidgets.Output``.

    Args:
        data: dict holding "aligned_firing" (and whatever ``plot_neuron_id`` reads).
        params: provides ``ana_index_grid``; also forwarded to ``plot_neuron_id``.
        rule1: trial-type prefix for the left panel.
        rule2: trial-type prefix for the right panel.
        id_min: first selectable neuron id.
        id_max: last selectable neuron id. Defaults to the last neuron and is
            clamped to it, so the slider never reaches a non-existent row.

    Raises:
        ValueError: if no firing data exists for ``rule1``, or if the resulting
            id range is empty or does not start at a valid id.
    """
    n_neurons = None
    for index in params.ana_index_grid(ana_tt=[f"{rule1}*"], ana_bt=["correct"]):
        fr = data["aligned_firing"][index]
        if fr is not None:
            n_neurons = fr.shape[0]
            break
    if n_neurons is None:
        raise ValueError("No valid fr found for given rule")

    last_id = n_neurons - 1
    id_max = last_id if id_max is None else min(id_max, last_id)
    # Out-of-range ids would otherwise fail inside the Output widget (hidden
    # IndexError) or, for negative ids, silently wrap to other neurons.
    if not 0 <= id_min <= id_max:
        raise ValueError(
            f"Invalid neuron id range [{id_min}, {id_max}] for {n_neurons} neurons"
        )

    slider = widgets.IntSlider(
        value=id_min, min=id_min, max=id_max, step=1,
        description='Neuron id', continuous_update=False
    )
    btn_prev = widgets.Button(description='◀ Prev', layout=widgets.Layout(width='80px'))
    btn_next = widgets.Button(description='Next ▶', layout=widgets.Layout(width='80px'))
    out = widgets.Output()

    def render(_=None):
        # Replace the previous figure rather than stacking a new one below it.
        with out:
            clear_output(wait=True)
            plot_neuron_id(data, params, rule1, rule2, id=slider.value)

    def on_prev(_):
        slider.value = max(slider.min, slider.value - 1)

    def on_next(_):
        slider.value = min(slider.max, slider.value + 1)

    btn_prev.on_click(on_prev)
    btn_next.on_click(on_next)
    slider.observe(render, names="value")

    display(widgets.HBox([btn_prev, slider, btn_next]), out)
    render()


In [4]:
# Analysis configuration. gaussian_sigma is passed by align_track straight to
# gaussian_filter1d, i.e. in samples along the position axis — assuming
# set_params stores it unchanged (TODO confirm).
params = set_params(
    gaussian_sigma=50,
    len_pos_average=1,
    tt_preset="merge",
    bt_preset="basic",
)

In [5]:
# Folder containing the .pkl files loaded in the next cell; a relative path, so
# it resolves against the kernel's current working directory.
entry_dir = "../../data/flexible_shift/"

In [6]:
path = Path(entry_dir)

# Load every session except RC02. Sorting makes the session order — and hence
# the neuron ids in the merged arrays — reproducible; Path.glob yields files in
# filesystem-dependent order.
# NOTE(review): load_pickle presumably unpickles, which can execute arbitrary
# code — only point entry_dir at trusted data.
data_list = []
for pkl_file in sorted(path.glob("*.pkl")):
    if "RC02" in pkl_file.name:
        continue
    session = load_pickle(pkl_file)
    session = extract_used_data(session)
    mergeAB(session)
    align_track(session, params)
    data_list.append(session)

if not data_list:
    raise FileNotFoundError(f"No usable .pkl files found in {path.resolve()}")

# Pool neurons across sessions by stacking each condition's arrays along axis 0.
# Every condition index is merged (not just the first column). A condition is
# merged only if *every* session has it, so a given neuron id refers to the same
# neuron in all merged conditions; otherwise it stays None.
reference = data_list[0]
firing_keys = ("aligned_firing", "aligned_firing_std")
data = {key: np.empty_like(reference[key], dtype=object) for key in firing_keys}
partial = []
for idx in np.ndindex(reference["aligned_firing"].shape):
    available = [session["aligned_firing"][idx] is not None for session in data_list]
    if not all(available):
        if any(available):
            partial.append(idx)
        continue
    for key in firing_keys:
        data[key][idx] = np.concatenate([session[key][idx] for session in data_list], axis=0)

if partial:
    print(f"Left {len(partial)} condition(s) unmerged because some sessions lack them: {partial}")

# Zone geometry is taken from the first session (assumed identical across
# sessions — TODO verify).
for key in ("zones", "aligned_zones_id"):
    if key in reference:
        data[key] = reference[key]

In [None]:
# Interactive browser: left panel = "pattern" trial types, right panel = "position".
browse_neurons(data, params, rule1="pattern", rule2="position")