In [None]:
# The autoreload extension allows you to tweak the code in the imported modules
# and rerun cells to reflect the changes.
%reload_ext autoreload
%autoreload 2
# this allows us to output the dictionary structure
%reload_ext ipy_dict_hierarchy
# matplotlib settings to look decent in the notebook
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# the helpers are in the `/ana/` folder.
import sys
sys.path.append("./../")
from ana import ana_helper as ah
from ana import plot_helper as ph
from ana import paper_plots as pp
from ana import ndim_helper as nh

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Notebook logging: timestamped records with aligned level and logger-name columns.
import logging

_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)-12s | %(message)s"
_LOG_DATEFMT = "%y-%m-%d %H:%M"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)

# logger used throughout this notebook; start verbose, cells lower it as needed
log = logging.getLogger("notebook")
log.setLevel(logging.DEBUG)

In [None]:
# plot a connectivity overview of a single directed-network simulation run
h5f = ah.prepare_file(f"{pp.p_sim}/lif/raw_directed/stim=0_k=1_kin=30_rate=80.0_stimrate=30.0_rep=000.hdf5")

# Keep the figure handle as `fig` rather than `_`: assigning to `_` at top
# level shadows IPython's last-output variable for the rest of the session.
fig, ax = plt.subplots()
ph.plot_connectivity_layout(h5f, ax=ax)
fig.savefig(f"{pp.p_fo}/sm_rev_directed_connectivity.png", transparent=True, dpi=900)


In [None]:
# Load the processed n-dimensional results of the directed-network LIF runs.
# NOTE(review): hard-coded relative path, whereas the raw files above are
# addressed via `pp.p_sim` -- confirm both point at the same simulation folder.
ndim = nh.load_ndim_h5f(
    "./../dat/simulations/lif/processed/ndim_directed.hdf5"
)

In [None]:
# Inspect the ragged burst core delays (presumably one variable-length array per
# stim_mods / stim_rate / repetition entry -- TODO confirm against nh.load_ndim_h5f).
ndim["rag_burst_core_delays"]

In [None]:
# Inspect the ragged burst sequences: for each stim_mods / stim_rate selection
# there is one integer-encoded module sequence per burst; these are compared
# against `ah.seq_to_int(...)` encodings in `propagation_status` below.
ndim["rag_burst_seqs"]

In [None]:
# silence the per-repetition debug output of the loop below
log.setLevel(logging.INFO)
# module-level style switch of paper_plots; presumably read by
# `pp.errorsticks` when drawing caps -- TODO confirm
pp._error_bar_cap_style = "round"

# Two stacked panels sharing the stimulus-rate x-axis:
# axes[0] (1/3 of the height) shows event counts,
# axes[1] (2/3 of the height) shows the propagation probability.
fig, axes = plt.subplots(
    nrows=2,
    figsize=(3, 4),
    sharex=True,
    gridspec_kw=dict(height_ratios=[1, 2]),
)


def propagation_status(seq_as_int):
    """Classify one integer-encoded burst sequence by how it propagated.

    Parameters
    ----------
    seq_as_int :
        A burst sequence encoded with ``ah.seq_to_int``.

    Returns
    -------
    int
        ``1`` if the burst went from module 0 to module 1 (successful
        propagation), ``0`` if it stayed in module 0 (failed propagation),
        ``-1`` for any other sequence, e.g. a burst that started in the
        wrong module.
    """
    # Per the original notes these encode to 12 and 1 respectively;
    # presumably seq_to_int concatenates (module + 1) digits -- TODO confirm.
    propagated = ah.seq_to_int([0, 1])
    stalled = ah.seq_to_int([0])

    if seq_as_int == propagated:
        return 1
    if seq_as_int == stalled:
        return 0
    return -1


def _mean_and_sem_bounds(values):
    """Return ``(mean, mean + sem, mean - sem)`` over the finite entries of ``values``.

    The standard error is ``np.std`` with numpy's default ``ddof=0`` (as the
    figure used before) divided by the square root of the number of finite
    entries. NaN entries (repetitions without a defined value) are skipped;
    if nothing finite remains, three NaNs are returned.
    """
    finite = np.asarray(values, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return np.nan, np.nan, np.nan
    mean = np.mean(finite)
    sem = np.std(finite) / np.sqrt(finite.size)
    return mean, mean + sem, mean - sem


# One color per stimulated module; the 0.5 * tdx x-offset keeps the error
# sticks of different targets from lying on top of each other.
for tdx, target in enumerate(ndim["rag_burst_seqs"]["stim_mods"].values):
    nd_seqs = ndim["rag_burst_seqs"].sel(stim_mods=target)

    # Dicts keyed by stimulus rate. NOTE: as before, `y_l` holds mean + SEM and
    # `y_h` holds mean - SEM; both only feed the `thin` error bar below.
    y_l = dict()
    y_m = dict()
    y_h = dict()

    y_l_count = dict()
    y_m_count = dict()
    y_h_count = dict()

    for s in nd_seqs["stim_rate"].values:
        # each rep has an array of different length
        reps = nd_seqs.sel(stim_rate=s).squeeze()

        rep_level_prob = []
        rep_level_counts = []
        aggregated_success = 0
        aggregated_failure = 0
        aggregated_irrelevant = 0

        # for each rep, get all sequences, one for each burst
        for rdx, sequences in enumerate(reps):
            # Classify every burst: propagated (1), did not propagate (0), or
            # started accidentally in the wrong module (-1).
            prepped = np.array(
                [propagation_status(seq) for seq in sequences.values[()]]
            )
            n_success = int(np.count_nonzero(prepped == 1))
            n_failure = int(np.count_nonzero(prepped == 0))
            n_irrelevant = int(np.count_nonzero(prepped == -1))

            aggregated_success += n_success
            aggregated_failure += n_failure
            aggregated_irrelevant += n_irrelevant

            log.debug(
                f"rep {rdx} | irrelevant: {n_irrelevant} | success: {n_success} |"
                f" failure: {n_failure}"
            )

            # Only bursts that started in module 0 enter the probability. A rep
            # without any such burst has no defined probability: record NaN
            # instead of raising ZeroDivisionError and aborting the whole cell.
            total = n_success + n_failure
            rep_level_prob.append(n_success / total if total > 0 else np.nan)
            rep_level_counts.append(total)

        log.debug(
            f"stim_rate {s} | total irrelevant: {aggregated_irrelevant} |"
            f" success: {aggregated_success} | failure: {aggregated_failure}"
        )

        y_m[s], y_l[s], y_h[s] = _mean_and_sem_bounds(rep_level_prob)
        y_m_count[s], y_l_count[s], y_h_count[s] = _mean_and_sem_bounds(
            rep_level_counts
        )

    pp.errorsticks(
        center=np.array(list(y_m.keys())) + 0.5 * tdx,
        mid=y_m.values(),
        thin=np.array([list(y_l.values()), list(y_h.values())]).T,
        ax=axes[1],
        color=f"C{tdx}",
        label=f"targeting module: {target.decode()}",
    )

    # number of bursts, we recorded for 5 minutes
    # (leading "_" in the label keeps this series out of the legend)
    pp.errorsticks(
        center=np.array(list(y_m_count.keys())) + 0.5 * tdx,
        mid=y_m_count.values(),
        thin=np.array([list(y_l_count.values()), list(y_h_count.values())]).T,
        ax=axes[0],
        color=f"C{tdx}",
        label=f"_targeting module: {target.decode()}",
    )

axes[1].set_ylim(0, 1)
axes[1].set_xlabel("Stimulus rate")
axes[1].set_ylabel("Probability to propagate")
axes[1].legend()

sns.despine(ax=axes[0], bottom=True)
axes[0].tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
axes[0].set_ylabel("Number of events\nin module 0 (5 min)")

axes[0].grid(axis="y", color="0.9", linewidth=0.5)
axes[1].grid(axis="y", color="0.9", linewidth=0.5)

In [None]:
log.setLevel(logging.INFO)
pp._error_bar_cap_style = "round"

# Build a long-format table `df` with one row per burst: stimulated module,
# stimulus rate, resources of module 0 ("source") and module 1 ("target"),
# and the propagation status from `propagation_status`
# (1 = propagated, 0 = did not propagate, -1 = started in the wrong module).
# pandas is already imported in the setup cell.
data = pd.concat(
    [
        ndim["rag_burst_seqs"].to_dataframe(name="sequences"),
        ndim["rag_resources_mod_0"].to_dataframe(name="source_resources"),
        ndim["rag_resources_mod_1"].to_dataframe(name="target_resources"),
    ],
    axis=1,
).reset_index()
data["stim_mods"] = data["stim_mods"].apply(lambda x: x.decode())

# one DataFrame per condition row; concatenated once at the end
rows_to_add = []

for index, row in data.iterrows():
    seqs = row["sequences"][:]

    # dtype=int keeps `status` integer even for a condition without bursts
    # (an empty np.array would otherwise default to float)
    prepped = np.array([propagation_status(seq) for seq in seqs], dtype=int)

    # expand the per-condition row into one row per burst
    num_bursts = len(seqs)
    rows_to_add.append(
        pd.DataFrame(
            {
                "stim_mods": [row["stim_mods"]] * num_bursts,
                "stim_rate": [row["stim_rate"]] * num_bursts,
                "target_resources": row["target_resources"],
                "source_resources": row["source_resources"],
                "status": prepped,
            }
        )
    )

# every per-condition frame is indexed from 0; ignore_index avoids a result
# with duplicate index labels
df = pd.concat(rows_to_add, ignore_index=True)


In [None]:
# Inspect bursts that did not propagate (status 0, i.e. stayed in module 0)
# for the `stim_mods == '0'` condition (presumably: module 0 was targeted).
df.query("status == 0 and stim_mods == '0'")

In [None]:
# Keep only bursts that started in module 0: `status` is then a 0/1 success
# flag, and its mean per hexagon is the propagation probability.
this_df = df.query("status != -1")

# figure size given in cm (5 x 3), converted to inches
fig, ax = plt.subplots(figsize=(5 / 2.54, 3 / 2.54))

hb = ax.hexbin(
    x=this_df["source_resources"],
    y=this_df["target_resources"],
    C=this_df["status"],
    # fraction of propagating bursts among those falling into the hexagon
    reduce_C_function=lambda arr: float(np.mean(arr)),
    gridsize=15,
    # hide sparsely populated hexagons (see matplotlib docs for the exact
    # threshold semantics of `mincnt` with `C`)
    mincnt=4,
    marginals=False,
    linewidths=0.1,
    cmap="Blues",
    vmin=0,
    vmax=1,
)

cb = fig.colorbar(hb, ax=ax)
cb.set_label("Probability\nto propagate")

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel("Resources (left module)")
ax.set_ylabel("Resources\n(right module)")

fig.savefig(f"{pp.p_fo}/sm_rev_propagation_probability.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Resources at bursts that did not start in module 0 (status == -1),
# colored by the stimulated module.
ax = sns.scatterplot(
    data=df.query("status == -1"),
    x="source_resources",
    y="target_resources",
    hue="stim_mods",
    marker="+",
    s=5,
    # "+" is an unfilled marker drawn only by its stroke: the previous lw=0
    # made every point invisible, so use a thin but visible line width.
    linewidth=0.5,
    alpha=0.2,
)
ax.set_xlabel("Resources (left module)")
ax.set_ylabel("Resources\n(right module)")
