# Plot the results of the spiking microcircuit experiments (Fig. 6)

In [None]:
from pathlib import Path
import pickle
import yaml
from collections import defaultdict

from tqdm import tqdm
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

In [None]:
# Optional: render all text with LaTeX in a serif font (needs a LaTeX installation).
# plt.rcParams.update({"text.usetex": True, "font.family": "serif"})
# Apply the shared matplotlib style sheet (relative path, resolved against the
# current working directory).
mpl.style.use("../../mystyle.mpl")

In [None]:
# Output directory for the exported figures; created on first run.
FIG_DIR = Path("../../figs/")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Result folders of the three compared methods (SAL, FA, BP) share one root.
RESULTS_DIR = Path("../../results/microcircuits")
STDPAL_PATH = RESULTS_DIR / "sal"
FA_PATH = RESULTS_DIR / "fa"
BACKPROP_PATH = RESULTS_DIR / "bp"

## Load data


In [None]:
# choose the runs:
num_seeds = 20  # number of runs (seeds) loaded per method

# First run index per method, in the same order as KEYS.
start_ids = [0, 0, 0]
KEYS = ["stdpal", "backprop", "fa"]

# Per-method lists with one entry per loaded run; they are filled by
# `dataloader` and later converted to arrays with the run index on axis 0.
val_loss = defaultdict(list)
w10 = defaultdict(list)
w21 = defaultdict(list)
b12 = defaultdict(list)


def _student_pickle(directory, i):
    """Return ``directory / "student.XXXX.pickle"`` for run ``i`` (4-digit, zero-padded)."""
    return directory / f"student.{i:04d}.pickle"


def fname_stdpal(i):
    """Path of the pickled SAL (``stdpal``) student run ``i``."""
    return _student_pickle(STDPAL_PATH, i)


def fname_backprop(i):
    """Path of the pickled backprop student run ``i``."""
    return _student_pickle(BACKPROP_PATH, i)


def fname_fa(i):
    """Path of the pickled FA student run ``i``."""
    return _student_pickle(FA_PATH, i)

# Recording masks and validation timestamps are taken from the first SAL run only.
# NOTE(review): presumably identical for every run and method — confirm.
with open(fname_stdpal(0), "rb") as f:
    data = pickle.load(f)

layer2_recordings = data["recordings"][2]
val_mask, symm_mask = (
    layer2_recordings[mask_name].flatten().astype(np.bool_)
    for mask_name in ("validation", "symmetrization")
)
val_times = data["validation"]["times"]

fname_yaml = "sal.0000.yaml"
with open(STDPAL_PATH / fname_yaml) as f:
    yaml_data = yaml.safe_load(f)
student_settings = yaml_data["student_simulation_settings"]
# Recorded steps per epoch (training + validation phase); used as the stride
# when `dataloader` samples the weight recordings once per epoch.
epoch_len = student_settings["len_epoch"] + student_settings["len_validation"]


def dataloader(key, fname):
    """Load one pickled run and append its results under ``key``.

    Appends the run's validation-loss record to ``val_loss[key]``. It also
    appends, to ``w10[key]``, ``w21[key]`` and ``b12[key]``, the ``[0, 0]``
    entry of the layer-1 ``w_up``, layer-2 ``w_up`` and layer-1 ``w_down``
    recordings. Each recording is sampled every ``epoch_len`` steps.

    NOTE(review): assumes the recordings are indexed (time, i, j) — confirm.
    The pickles are local result files; do not point this at untrusted data.
    """
    with open(fname, "rb") as fh:
        run = pickle.load(fh)

    val_loss[key].append(run["validation"]["loss"])
    once_per_epoch = np.s_[::epoch_len, 0, 0]
    recordings = run["recordings"]
    for target, layer, field in (
        (w10, 1, "w_up"),
        (w21, 2, "w_up"),
        (b12, 1, "w_down"),
    ):
        target[key].append(recordings[layer][field][once_per_epoch])


# Filename helper per method key.
fname_by_key = {"stdpal": fname_stdpal, "backprop": fname_backprop, "fa": fname_fa}

# Load every run of every method. Method KEYS[j] starts at run index
# start_ids[j], so each method uses its own offset.
for i in tqdm(range(num_seeds)):
    for key, start_id in zip(KEYS, start_ids):
        dataloader(key, fname_by_key[key](i + start_id))

# Stack the per-run lists into arrays with the run index on axis 0.
for key in KEYS:
    for store in (val_loss, w10, w21, b12):
        store[key] = np.array(store[key])

## Compute summary statistics across seeds (mean, std, median, quartiles, min/max, smoothed quartiles)

In [None]:
def smoother(arr, window_len):
    """Moving average over the last axis of ``arr`` ("valid" mode).

    Each output sample is the mean of ``window_len`` consecutive input
    samples. The last axis shrinks from ``n`` to ``n - window_len + 1``, and
    all leading axes (e.g. one row per seed) are kept unchanged.

    No axes are squeezed. A single-seed input of shape ``(1, n)`` therefore
    stays 2-D, and ``np.median(result, axis=0)`` still returns a per-step curve.

    Raises ``ValueError`` (from ``sliding_window_view``) if ``window_len`` is
    larger than the last axis.
    """
    windows = sliding_window_view(arr, window_len, axis=-1)
    return windows.mean(axis=-1)


# For every quantity and method, add "<key>_<stat>" entries that reduce over
# axis 0 (the run/seed axis): raw statistics first, then quartiles of the
# 5-sample moving average (computed along the last axis by `smoother`).
for summary in (val_loss, w10, w21, b12):
    for key in KEYS:
        runs = summary[key]
        summary.update(
            {
                key + "_mean": np.mean(runs, axis=0),
                key + "_std": np.std(runs, axis=0),
                key + "_median": np.median(runs, axis=0),
                key + "_q1": np.percentile(runs, 25, axis=0),
                key + "_q3": np.percentile(runs, 75, axis=0),
                key + "_min": np.min(runs, axis=0),
                key + "_max": np.max(runs, axis=0),
            }
        )
        smoothed = smoother(runs, 5)
        summary.update(
            {
                key + "_smoothed_median": np.median(smoothed, axis=0),
                key + "_smoothed_q1": np.percentile(smoothed, 25, axis=0),
                key + "_smoothed_q3": np.percentile(smoothed, 75, axis=0),
            }
        )

## Define plotting functions

In [None]:
def plot_mean(ax, ts, mean, lower_lim, upper_lim, param_dict, sl=...):
    """Plot a mean trace with a shaded band around it.

    ``lower_lim`` and ``upper_lim`` are *offsets* from the mean. The band spans
    ``mean - lower_lim`` to ``mean + upper_lim``; pass e.g. the standard
    deviation twice for a +/-1 std band. ``param_dict`` must provide
    ``color``, ``label`` and ``alpha``. ``sl`` selects the plotted samples
    (default ``...`` = all). Returns the list of lines created by ``ax.plot``.
    """
    x = ts[sl]
    centre = mean[sl]
    colour = param_dict["color"]
    handles = ax.plot(x, centre, color=colour, label=param_dict["label"])
    ax.fill_between(
        x,
        centre - lower_lim[sl],
        centre + upper_lim[sl],
        color=colour,
        alpha=param_dict["alpha"],
    )
    return handles


def plot_median(ax, ts, median, q1, q3, param_dict, sl=...):
    """Plot a median trace with its interquartile band.

    ``q1`` and ``q3`` are absolute lower and upper band edges, not offsets.
    ``param_dict`` must provide ``color``, ``label`` and ``alpha``. If it also
    contains ``fb_label``, an empty ``fill_between`` is added in
    ``fb_color`` so the band gets its own legend entry. ``sl`` selects the
    plotted samples (default ``...`` = all). Returns the list of lines
    created by ``ax.plot``.
    """
    x = ts[sl]
    colour = param_dict["color"]
    trace = ax.plot(x, median[sl], color=colour, label=param_dict["label"])
    ax.fill_between(x, q1[sl], q3[sl], color=colour, alpha=param_dict["alpha"])
    if "fb_label" not in param_dict:
        return trace
    # Empty fill that only serves as a legend handle for the band.
    ax.fill_between(
        [],
        [],
        color=param_dict["fb_color"],
        alpha=param_dict["alpha"],
        label=param_dict["fb_label"],
    )
    return trace


def plot_val(ax, key_tails, sl):
    """Plot validation-loss summaries of SAL, BP and FA on ``ax``.

    ``key_tails`` names the statistics used as (centre, lower, upper), e.g.
    ``["median", "q1", "q3"]``. Each name is appended to the method key to
    look up entries of the global ``val_loss`` dict. Every method is drawn
    with ``plot_median`` (absolute band edges), and ``sl`` subsamples the
    epochs.
    """
    centre, lower, upper = key_tails[0], key_tails[1], key_tails[2]
    epochs = np.arange(len(val_loss["stdpal_" + centre]))
    # (method key, plot style); "fb_color" has no effect without "fb_label".
    methods = (
        ("stdpal", {"color": "C0", "label": "SAL", "alpha": 0.3}),
        ("backprop", {"color": "C1", "label": "BP", "alpha": 0.3}),
        ("fa", {"color": "C2", "label": "FA", "fb_color": "gray", "alpha": 0.3}),
    )
    for prefix, style in methods:
        plot_median(
            ax,
            epochs,
            val_loss[prefix + "_" + centre],
            val_loss[prefix + "_" + lower],
            val_loss[prefix + "_" + upper],
            style,
            sl=sl,
        )


boxwidth = 0.5  # box width (x-axis data units) used by the distribution plots


def _overlay_boxes(ax, datasets):
    """Overlay box plots (no fliers) of the three methods at x = 0, 1, 2.

    The boxes are drawn twice. The first pass uses ``patch_artist=True`` so the
    boxes can be filled with the method colour (C0/C1/C2) at alpha 0.3. The
    second pass is unfilled (``zorder=1``) and draws plain outlines of the
    same boxes. Presumably this keeps the outlines opaque, since ``set_alpha``
    on a patch also fades its edge.
    """
    filled = ax.boxplot(
        datasets,
        positions=[0, 1, 2],
        showfliers=False,
        patch_artist=True,
        widths=boxwidth,
    )
    for patch, color in zip(filled["boxes"], ["C0", "C1", "C2"]):
        patch.set_facecolor(color)
        patch.set_alpha(0.3)
    ax.boxplot(
        datasets,
        positions=[0, 1, 2],
        showfliers=False,
        zorder=1,
        widths=boxwidth,
    )


def plot_val_distr(ax, data_stdpal, data_bp, data_fa):
    """Swarm plot of per-seed values for SAL/BP/FA with box plots overlaid."""
    datasets = [data_stdpal, data_bp, data_fa]
    sns.swarmplot(data=datasets, ax=ax, size=4)
    _overlay_boxes(ax, datasets)


def plot_w_distr(ax, data_stdpal, data_bp, data_fa):
    """Jittered strip plot of per-seed values for SAL/BP/FA with box plots overlaid."""
    datasets = [data_stdpal, data_bp, data_fa]
    sns.stripplot(data=datasets, ax=ax, size=4, jitter=0.15)
    _overlay_boxes(ax, datasets)

## Plot left panel (fig. 6b)

In [None]:
## 1. Plot the traces
# Fig. 6b. Top: validation loss over epochs (median and interquartile band).
# Bottom: weight trajectories of one example SAL run (index `nrn_id`).

# NOTE(review): `gridspecs` is not used below (the figure uses the
# constrained layout); kept so the name stays defined.
gridspecs = dict(left=0.17, right=0.67, top=0.97, bottom=0.15, hspace=0.03)
legendparams = dict(
    fontsize=10,
    labelspacing=0.2,
    borderpad=0.4,
    handlelength=1.5,
    framealpha=0.5,
)

nrn_id = 11  # row of the stacked run arrays whose weight traces are shown

fig, ax = plt.subplots(
    nrows=2, ncols=1, sharex=False, squeeze=False, figsize=(3, 3.5), layout="constrained"
)
loss_ax, weight_ax = ax[0, 0], ax[1, 0]

loss_ax.set_yscale("log")
plot_val(loss_ax, ["median", "q1", "q3"], sl=np.s_[::4])
loss_ax.legend(**legendparams)
loss_ax.set_ylabel("validation loss")

sl = np.s_[:]
epochs = np.arange(len(w21["stdpal_mean"]))
weight_traces = (
    (w10, "tab:cyan", "-", r"$W_{10}$"),
    (w21, "tab:olive", ":", r"$W_{21}$"),
    (b12, "tab:red", "--", r"$B_{12}$"),
)
for trace_store, trace_colour, trace_ls, trace_label in weight_traces:
    weight_ax.plot(
        epochs[sl],
        trace_store["stdpal"][nrn_id, sl],
        color=trace_colour,
        linestyle=trace_ls,
        label=trace_label,
        lw=2,
    )
weight_ax.axhline(
    2.0, color="darkgray", linestyle="--", label="teacher\n$W_{10}$, $W_{21}$"
)
weight_ax.set_ylabel(r"weight [a.U.]")
left_lim, _ = weight_ax.get_xlim()
weight_ax.set_xlim(left_lim, 50)  # only the first 50 epochs are shown
weight_ax.legend(**legendparams)
weight_ax.set_xlabel("epochs")

for ext in ("png", "pdf", "svg"):
    fig.savefig(FIG_DIR / f"mc_epochs.{ext}", bbox_inches="tight")

## Plot right panel (fig. 6c)

In [None]:
## Plot the final distributions:


def mytitle(ax, text, *, enabled=False):
    """Optionally write a small panel title in the upper-left corner of ``ax``.

    Titles are switched off by default: with ``enabled=False`` the call does
    nothing. With ``enabled=True``, ``text`` is drawn at axes coordinates
    (0.05, 0.9) with font size 9.
    """
    if not enabled:
        return
    ax.text(0.05, 0.9, text, transform=ax.transAxes, fontsize=9.0)


# Fig. 6c: values after the last epoch, one point per run, for SAL/BP/FA.
fig, ax = plt.subplots(
    nrows=2,
    ncols=2,
    sharey=False,
    sharex=True,
    squeeze=False,
    figsize=(3.5, 3.5),
    layout="constrained",
)
method_keys = ("stdpal", "backprop", "fa")  # x positions 0, 1, 2 in every panel

# Top-left: final validation loss.
loss_ax = ax[0, 0]
loss_ax.set_yscale("log")
plot_val_distr(loss_ax, *[val_loss[k][:, -1] for k in method_keys])
loss_ax.set_ylabel("validation loss")
mytitle(loss_ax, "val. loss")

# Remaining panels: final weights, each with the teacher value 2.0 as a
# dashed line. Only the last panel's line carries the legend label.
weight_panels = (
    (ax[0, 1], w10, "hidden weight", r"$W_{10}$", {}),
    (ax[1, 0], w21, "output weight", r"$W_{21}$", {}),
    (ax[1, 1], b12, "topdown weight", r"$B_{12}$", {"label": "teacher weight"}),
)
for panel, store, title, ylabel, line_kwargs in weight_panels:
    mytitle(panel, title)
    plot_w_distr(panel, *[store[k][:, -1] for k in method_keys])
    panel.set_ylabel(ylabel)
    panel.axhline(2.0, color="darkgray", linestyle="--", **line_kwargs)

# Blank the x tick labels (the x axis is shared across panels); the methods
# are identified by colour in the figure legend below.
_ = ax[1, 0].set_xticklabels("")

legend_handles = [mpatches.Patch(color=c) for c in ("C0", "C1", "C2")]
legend_handles.append(mlines.Line2D([], [], color="darkgray", linestyle="--"))
fig.legend(
    legend_handles,
    ["SAL", "BP", "FA", "teacher weight"],
    ncols=4,
    columnspacing=1.5,
    handlelength=1.5,
    handletextpad=0.3,
    loc="outside upper right",
)

for ext in ("png", "pdf", "svg"):
    fig.savefig(FIG_DIR / f"mc_distr.{ext}", bbox_inches="tight")