# Plot the results of the spiking sampling network experiments (figs 4 and 5)

In [None]:
import sys
from pathlib import Path
from typing import Union, Optional, Callable, Iterator, List, Tuple
from collections import defaultdict
import warnings

import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.patches import Rectangle
import yaml

In [None]:
FIG_DIR = Path("../../figs/")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Every results directory shares the same "../../" root as FIG_DIR and the
# style file. The baseline and SAL runs are loaded side by side, so they must
# resolve against the same working directory.
RESULTS_DIR = Path("../../results/ssn/")

path_init = RESULTS_DIR / "syn_noise"
path_init_sal = RESULTS_DIR / "syn_noise_sal"

path_stdp = RESULTS_DIR / "plast_noise"
path_stdp_sal = RESULTS_DIR / "plast_noise_sal"

In [None]:
# Apply the project's matplotlib style sheet to every figure created below.
plt.style.use("../../mystyle.mpl")

In [None]:
# Toggle for the second figure: True also renders fig 5; False lets the guard
# cell before the fig 5 section halt the notebook right after fig 4.
PRINT_SECOND: bool = True

## Define functions

In [None]:
def data_loader(
    fname: Callable[[int], str], it: Iterator, keys: Optional[List[str]] = None
) -> dict[str, npt.NDArray]:
    """Load and stack arrays from a series of ``.npz`` files.

    The first readable file in ``it`` defines the reference shape of every
    array it contains. Every index is then loaded in turn. A file that cannot
    be read, lacks a requested key, or holds an array whose shape differs from
    the reference is replaced by NaN-filled arrays of the reference shape.
    This is done for *every* requested key, so all outputs keep the same
    length.

    Args:
        fname: Maps an index from ``it`` to the path of an ``.npz`` file.
        it: Indices to load. Any iterable works, including a one-shot
            iterator; it is materialised once because it is traversed twice.
        keys: Array names to collect. Defaults to all arrays found in the
            reference file.

    Returns:
        Mapping from key to an array of shape ``(n_indices, *reference_shape)``.

    Raises:
        RuntimeError: If none of the files can be loaded.
        KeyError: If a requested key is not present in the reference file.
    """
    # `it` is traversed twice (shape discovery + loading). Materialise it once
    # so a one-shot iterator is not exhausted by the first pass.
    indices = list(it)

    # Reference shapes come from the first file that can be read completely.
    data_shapes = {}
    for i in indices:
        try:
            with np.load(fname(i)) as d:
                data_shapes = {key: d[key].shape for key in d.files}
        except Exception:  # best effort: unreadable file, try the next one
            continue
        break
    # raise Exception if not a single valid file was found!
    if not data_shapes:
        raise RuntimeError("No valid data files found in the given range!")

    keys = list(keys) if keys is not None else list(data_shapes)
    missing = [key for key in keys if key not in data_shapes]
    if missing:
        raise KeyError(f"Keys {missing} not present in the reference data file")

    data = defaultdict(list)
    for i in indices:
        path = fname(i)
        try:
            with np.load(path) as d:
                print(f"Load {path}")
                entry = {key: d[key] for key in keys}
            for key, arr in entry.items():
                if arr.shape != data_shapes[key]:
                    raise ValueError(
                        f"{key!r} has shape {arr.shape}, expected {data_shapes[key]}"
                    )
        except Exception as err:
            # Keep every output aligned with `indices`: substitute NaNs for
            # *all* requested keys, not only the one that failed.
            warnings.warn(f"No data found in {path}: {err}")
            entry = {key: np.full(data_shapes[key], np.nan) for key in keys}
        for key in keys:
            data[key].append(entry[key])

    return {key: np.array(values) for key, values in data.items()}


def keep_validation(
    dat: dict[str, npt.NDArray], val: int, tmax: Optional[int] = None
) -> dict[str, npt.NDArray]:
    """Subsample every array along axis 1, keeping every ``val``-th entry.

    Axis 0 indexes the stacked files/seeds (as produced by ``data_loader``);
    axis 1 is presumably the epoch axis.

    Args:
        dat: Mapping from key to an array with at least two dimensions.
        val: Step between kept entries along axis 1.
        tmax: Exclusive stop index along axis 1. The default ``None`` keeps
            the full range, including the last entry. The former default of
            ``-1`` silently dropped the last entry, so the result no longer
            matched ``np.arange(0, num_epochs, val)``. An explicit ``-1`` still
            follows normal slice semantics.

    Returns:
        A new dict with the same keys and subsampled arrays (views).
    """
    return {key: arr[:, :tmax:val] for key, arr in dat.items()}


def calc_stats(
    dat: dict[str, npt.NDArray],
    quantiles: Tuple[float, ...] = (0.25, 0.5, 0.75),
) -> dict[str, dict[float, npt.NDArray]]:
    """Compute NaN-ignoring quantiles across axis 0 for every array in ``dat``.

    Args:
        dat: Mapping from key to an array whose axis 0 indexes repetitions
            (e.g. seeds, as stacked by ``data_loader``).
        quantiles: Quantile levels to compute. The default gives the lower
            quartile, median and upper quartile used by the plotting helpers.

    Returns:
        ``{key: {q: np.nanquantile(dat[key], q, axis=0) for q in quantiles}}``.
    """
    res = {}
    for key, arr in dat.items():
        # One call for all levels; the result is stacked along a new axis 0.
        values = np.nanquantile(arr, quantiles, axis=0)
        res[key] = dict(zip(quantiles, values))
    return res


def calc_diffs(
    dat1: dict[str, npt.NDArray],
    dat2: dict[str, npt.NDArray],
    key: str = "dkls",
    quantiles: Tuple[float, ...] = (0.25, 0.5, 0.75),
) -> dict[str, dict[float, npt.NDArray]]:
    """Quantiles (across axis 0, NaN-ignoring) of the element-wise difference
    ``dat1[key] - dat2[key]``.

    The difference is taken per row before computing quantiles, so rows of
    the two inputs are paired by index (e.g. the same seed in both runs).

    Args:
        dat1: Minuend arrays, keyed by metric name.
        dat2: Subtrahend arrays with the same shapes as ``dat1``.
        key: Which metric to compare.
        quantiles: Quantile levels to compute (default: quartiles + median).

    Returns:
        ``{key: {q: quantile_q(dat1[key] - dat2[key])}}``.
    """
    diffs = dat1[key] - dat2[key]
    values = np.nanquantile(diffs, quantiles, axis=0)
    return {key: dict(zip(quantiles, values))}


def plot_epochs(
    ax,
    dat: dict[float, npt.NDArray],
    epochs: npt.NDArray,
    yscale: str = "log",
    label="",
    **kwargs,
):
    """Plot the median curve of ``dat`` over ``epochs`` with its IQR band.

    Args:
        ax: Matplotlib axes to draw on.
        dat: Quantile summary with keys 0.25, 0.5 and 0.75 (as from
            ``calc_stats``).
        epochs: x-values for the curve.
        yscale: Scale applied to the y-axis afterwards.
        label: Legend label of the median line.
        **kwargs: Forwarded to ``ax.plot`` for the median line.

    Returns:
        The same ``ax``.
    """
    median_line, *_ = ax.plot(epochs, dat[0.5], label=label, **kwargs)
    # The band reuses the line's colour so both read as one series.
    band_color = median_line.get_color()
    ax.fill_between(epochs, y1=dat[0.25], y2=dat[0.75], color=band_color, alpha=0.3)
    ax.set_yscale(yscale)
    return ax


def cm_to_inch(vals: Tuple[int | float] | List[int | float]) -> Tuple[int | float]:
    CM_PER_INCH = 2.54
    return tuple(map(lambda x: x / CM_PER_INCH, vals))


def ebar(
    ax,
    dat: dict[float, npt.NDArray],
    pos: float,
    color: str,
    label: Optional[str] = None,
    **kwargs,
):
    """Draw the final-entry median of ``dat`` at x = ``pos`` with an IQR error bar.

    Args:
        ax: Matplotlib axes to draw on.
        dat: Quantile summary with keys 0.25, 0.5 and 0.75; only the last
            element of each array is used.
        pos: x-position of the point.
        color: Colour of marker and error bar.
        label: Optional legend label.
        **kwargs: Forwarded to ``ax.errorbar`` (e.g. ``marker``).

    Returns:
        The same ``ax``.
    """
    q25, median, q75 = (dat[q][-1] for q in (0.25, 0.5, 0.75))
    # For a 2xN ``yerr``, matplotlib expects row 0 = distance below the point
    # and row 1 = distance above it.
    lower_err = median - q25
    upper_err = q75 - median
    ax.errorbar(
        [pos],
        [median],
        yerr=np.array([[lower_err], [upper_err]]),
        capsize=4,
        color=color,
        label=label,
        **kwargs,
    )
    return ax
    

def plot_baseline(ax, dat, color, kwargs_line=None, kwargs_fill=None):
    """Draw the final-entry median of ``dat`` as a horizontal line with its IQR band.

    The band spans the axes' x-limits at call time, so call this after the
    other data have been plotted.

    Args:
        ax: Matplotlib axes to draw on.
        dat: Quantile summary with keys 0.25, 0.5 and 0.75; only the last
            element of each array is used.
        color: Colour of line and band.
        kwargs_line: Extra keyword arguments for ``ax.axhline`` (default none).
        kwargs_fill: Extra keyword arguments for ``ax.fill_between`` (default none).
    """
    # None defaults instead of shared mutable `{}` defaults.
    kwargs_line = {} if kwargs_line is None else kwargs_line
    kwargs_fill = {} if kwargs_fill is None else kwargs_fill

    ax.axhline(y=dat[0.5][-1], color=color, zorder=-1, **kwargs_line)
    x_lo, x_hi = ax.get_xlim()
    x = np.array([x_lo, x_hi])
    ax.fill_between(
        x, y1=dat[0.25][-1], y2=dat[0.75][-1], color=color, alpha=0.3, **kwargs_fill
    )
    # Restore the captured limits; adding the band may otherwise widen the
    # autoscaled x-range.
    ax.set_xlim(x[0], x[1])


def plot_roundmarker(ax, x, y, char, color):
    """Draw an unfilled circle at ``(x, y)`` with ``char`` written inside it.

    Used to tag curves and points with a short identifier ("1", "2", ...) so
    that corresponding elements can be matched across panels.

    Args:
        ax: Matplotlib axes to draw on.
        x, y: Position in data coordinates.
        char: Text placed in the circle (converted with ``str``).
        color: Colour of circle and text.
    """
    # An unfilled "o" marker; equivalent to the former
    # MarkerStyle("o", fillstyle="none").scaled(1.) (scaling by 1 is a no-op).
    ax.plot([x], [y], marker="o", fillstyle="none", color=color, markersize=8)
    ax.text(x, y, f"{char}", color=color, ha="center", va="center", fontsize="x-small")


def axis_breaker(ax, y_pos, breaker_width=10, breaker_dist=0.02, whitespace_size=5):
    """Draw a "broken axis" indicator on both vertical edges of ``ax``.

    At x = 0 and x = 1 (axes coordinates), two slanted strokes are placed at
    ``y_pos - breaker_dist/2`` and ``y_pos + breaker_dist/2``. A white dot at
    ``y_pos`` (drawn beneath the strokes) blanks out the axis line between
    them. Everything is drawn unclipped so it can sit on the spines.

    Args:
        ax: Matplotlib axes to decorate.
        y_pos: Vertical centre of the break, in axes coordinates.
        breaker_width: Marker size of the slanted strokes.
        breaker_dist: Vertical gap between the two strokes (axes coordinates).
        whitespace_size: Marker size of the white masking dot.
    """
    half_slant = .5
    slash_style = {
        "marker": [(-1, -half_slant), (1, half_slant)],
        "markersize": breaker_width,
        "linestyle": "none",
        "color": 'k',
        "mec": 'k',
        "mew": 1,
        "clip_on": False,
    }
    # Lower stroke first (on top), then upper stroke.
    for offset, z in ((-breaker_dist / 2, 120), (breaker_dist / 2, 110)):
        ax.plot([0., 1], [y_pos + offset] * 2, transform=ax.transAxes, zorder=z, **slash_style)
    ax.plot(
        [0., 1],
        [y_pos] * 2,
        transform=ax.transAxes,
        marker="o",
        color='white',
        clip_on=False,
        linestyle="none",
        ms=whitespace_size,
        zorder=100,
    )

## Synaptic noise scenario (fig 4)

In [None]:
# load a parameter file:
# Only the no-SAL runs' exp.yaml is read; its num_epochs / val_step are used
# for both conditions. NOTE(review): assumes the SAL runs used the same values.
with open(path_init / "exp.yaml", "r") as f:
    params = yaml.safe_load(f)

# noise levels on weight initialization --> see change_params.py
NOISE_LEVELS = [0.0, 0.2, 0.4, 0.6, 0.8]
# number of seeds per noise level:
NUM_SEEDS = 20

# Result files are named exp.0000.npz, exp.0001.npz, ...; the loop below
# assumes runs i*NUM_SEEDS .. (i+1)*NUM_SEEDS-1 belong to NOISE_LEVELS[i].
fname_nosal = lambda i: path_init / f"exp.{i:04d}.npz"
fname_sal = lambda i: path_init_sal / f"exp.{i:04d}.npz"

# One entry per noise level, each {key: {quantile: array}} from calc_stats.
stats_sal = []
stats_nosal = []
# Per-level quantiles of per-seed differences (from calc_diffs).
# NOTE(review): not used by the plotting cells below.
dkl_diffs = []
asym_diffs = []


for i, _ in enumerate(NOISE_LEVELS):
    chunk = range(i * NUM_SEEDS, (i + 1) * NUM_SEEDS)
    # "dkls" / "all_asym" are plotted below as D_KL and Var(W_ij - W_ji).
    dat_sal = data_loader(fname_sal, chunk, ["dkls", "all_asym"])
    # dat_sal = keep_validation(dat_sal, params["val_step"])
    stats_sal.append(calc_stats(dat_sal))

    dat_nosal = data_loader(fname_nosal, chunk, ["dkls", "all_asym"])
    # dat_nosal = keep_validation(dat_nosal, params["val_step"])
    stats_nosal.append(calc_stats(dat_nosal))

    # Note the operand order: SAL - no-SAL for D_KL, but no-SAL - SAL for the
    # asymmetry variance.
    dkl_diffs.append(calc_diffs(dat_sal, dat_nosal, key="dkls"))
    asym_diffs.append(calc_diffs(dat_nosal, dat_sal, key="all_asym"))

In [None]:
# x-axis of the time-course panels: one point every val_step epochs.
# NOTE(review): assumes each stored metric has exactly len(EPOCHS) entries.
EPOCHS = np.arange(0, params["num_epochs"], params["val_step"])
# Index into NOISE_LEVELS of the noisy condition shown in the left column
# (NOISE_LEVELS[4] == 0.8).
NOISE_ID = 4

# 2x2 grid: left column = metrics over wake-sleep cycles, right column =
# final values vs. noise level; top row = Var(W_ij - W_ji) ("all_asym"),
# bottom row = D_KL ("dkls"). Panels in a column share their x-axis.
fig, ax = plt.subplots(
    2,
    2,
    sharex="col",
    figsize=cm_to_inch((16, 8)),
    squeeze=False,
    # layout="constrained"
    # Manual margins; "top" < 1 presumably leaves room for the figure legend.
    gridspec_kw={"top": 0.8,
                 "wspace": 0.4,
                 "right": 0.98,
                 "bottom": 0.0},
)

# Time courses (median over seeds with IQR band) of the weight asymmetry (top)
# and D_KL (bottom) for the noise-free baseline, noisy init without SAL and
# noisy init with SAL. No colours are passed, so the three curves take
# consecutive colours from the axes' colour cycle.
# Raw string: the mathtext label contains backslashes (\sigma, \mathrm) that
# are invalid escape sequences in an ordinary string literal.
BASELINE_LABEL = r"baseline $(\sigma^\mathrm{noise}_\mathrm{init} = 0,\sigma^\mathrm{noise}_\mathrm{stdp} = 0)$"

ax[0, 0] = plot_epochs(
    ax[0, 0], stats_nosal[0]["all_asym"], EPOCHS, label=BASELINE_LABEL, yscale="linear",
)
ax[0, 0] = plot_epochs(
    ax[0, 0], stats_nosal[NOISE_ID]["all_asym"], EPOCHS, label="noise w/o SAL", yscale="linear",
)
ax[0, 0] = plot_epochs(
    ax[0, 0], stats_sal[NOISE_ID]["all_asym"], EPOCHS, label="noise with SAL", yscale="linear",
)

ax[1, 0] = plot_epochs(
    ax[1, 0], stats_nosal[0]["dkls"], EPOCHS, label=BASELINE_LABEL
)
ax[1, 0] = plot_epochs(
    ax[1, 0], stats_nosal[NOISE_ID]["dkls"], EPOCHS, label="noise w/o SAL"
)
ax[1, 0] = plot_epochs(
    ax[1, 0], stats_sal[NOISE_ID]["dkls"], EPOCHS, label="noise with SAL"
)

# make the broken axis indicator:
# Drawn on both top-row panels at 18% of the axes height.
# NOTE(review): presumably marks where the symlog y-axis of these panels
# switches from its linear to its logarithmic region — confirm.
axis_breaker(ax[0, 0], 0.18, breaker_dist=0.04, breaker_width=6, whitespace_size=2.5)
axis_breaker(ax[0, 1], 0.18, breaker_dist=0.04, breaker_width=6, whitespace_size=2.5)

# Extend the x-range past the last epoch (shared by the left column) to make
# room for the round markers placed at LEFT / RIGHT below.
ax[1, 0].set_xlim(right=EPOCHS[-1] * 1.28)

LEFT = EPOCHS[-1] * 1.09
RIGHT = EPOCHS[-1] * 1.21
# Tag each curve's final median with its number (1 = baseline, 2 = noise w/o
# SAL, 3 = noise with SAL). Colours C0-C2 assume the curves took the first
# three cycle colours.
plot_roundmarker(ax[0, 0], LEFT, stats_nosal[0]["all_asym"][0.5][-1], char="1", color="C0")
plot_roundmarker(ax[0, 0], LEFT, stats_nosal[NOISE_ID]["all_asym"][0.5][-1], char="2", color="C1")
plot_roundmarker(ax[0, 0], LEFT, stats_sal[NOISE_ID]["all_asym"][0.5][-1], char="3", color="C2")

plot_roundmarker(ax[1, 0], LEFT, stats_nosal[0]["dkls"][0.5][-1], char="1", color="C0")
plot_roundmarker(ax[1, 0], LEFT, stats_nosal[NOISE_ID]["dkls"][0.5][-1], char="2", color="C1")
# Marker 3 sits further right here, presumably to avoid overlapping marker 2.
plot_roundmarker(ax[1, 0], RIGHT, stats_sal[NOISE_ID]["dkls"][0.5][-1], char="3", color="C2")


# Horizontal offset so the SAL (+DELTA) and no-SAL (-DELTA) bars at the same
# noise level do not overlap.
DELTA = 0.01
for i, _ in enumerate(NOISE_LEVELS):
    # VAR
    ebar(ax[0, 1], stats_sal[i]["all_asym"], NOISE_LEVELS[i] + DELTA, color="C2", marker="x")
    ebar(ax[0, 1], stats_nosal[i]["all_asym"], NOISE_LEVELS[i] - DELTA, color="C1", marker="x")
    # DKLs
    ebar(ax[1, 1], stats_sal[i]["dkls"], NOISE_LEVELS[i] + DELTA, color="C2", marker="x")
    ebar(ax[1, 1], stats_nosal[i]["dkls"], NOISE_LEVELS[i] - DELTA, color="C1", marker="x")

# plot baseline
# (must come after the error bars: plot_baseline spans the current x-limits)
plot_baseline(ax[0, 1], stats_nosal[0]["all_asym"], "C0", kwargs_line={"linestyle": 'solid'})
plot_baseline(ax[1, 1], stats_nosal[0]["dkls"], "C0", kwargs_line={"linestyle": 'solid'})

# Hand-placed tags (data coordinates) matching the curve numbers in the left
# column; tuned for this data set and must be moved if the results change.
plot_roundmarker(ax[0, 1], 0.0, 5.e-5, char="1", color="C0")
plot_roundmarker(ax[0, 1], 0.75, 0.2, char="2", color="C1")
plot_roundmarker(ax[0, 1], 0.75, 1.e-3, char="3", color="C2")

plot_roundmarker(ax[1, 1], 0.0, 6.e-3, char="1", color="C0")
plot_roundmarker(ax[1, 1], 0.72, 1.5e-2, char="2", color="C1")
plot_roundmarker(ax[1, 1], 0.78, 6.e-3, char="3", color="C2")


# Shared y-limits for both Var(W_ij - W_ji) panels (symlog scale).
VAR_YLIM = (-0.5e-4, 3)
ax[1, 0].set_xlabel("wake-sleep cycles")
ax[1, 0].ticklabel_format(style='sci', scilimits=(-3,4), axis='x')
ax[0, 0].set_ylabel(r"$\mathrm{Var}\, (W_{ij} - W_{ji})$")
# symlog: linear within +-1e-4 (so values near zero remain visible), log beyond.
ax[0, 0].set_yscale("symlog", linthresh=1e-4)
ax[0, 0].set_ylim(*VAR_YLIM)
ax[1, 0].set_ylabel(r"$D_\mathrm{KL}(p \| p^*)$")
ax[1, 0].set_ylim(0.1e-2, 2)
# Annotate which init-noise level the left-column curves show.
ax[0, 0].text(
    0.05,
    0.2,
    r"$\sigma^\mathrm{noise}_\mathrm{init}$ =" + f" {NOISE_LEVELS[NOISE_ID]}",
    transform=ax[0, 0].transAxes,
    ha="left",
    va="top",
    size=8,
    bbox=dict(facecolor=(1,1,1,0.7),
              edgecolor='black',
              boxstyle='round'),
)
# NOTE(review): this text is attached to ax[1, 0] but positioned in ax[0, 0]'s
# axes coordinates (y=-0.275 lies below the top-left panel); the position
# appears hand-tuned — confirm this is intended.
ax[1, 0].text(
    0.575,
    -0.275,
    r"$\sigma^\mathrm{noise}_\mathrm{init}$ =" + f" {NOISE_LEVELS[NOISE_ID]}",
    transform=ax[0, 0].transAxes,
    ha="left",
    va="top",
    size=8,
    bbox=dict(facecolor=(1,1,1,0.7),
              edgecolor='black',
              boxstyle='round'),
)

ax[1, 1].set_yscale("log")
# Hide minor-tick labels on the log axis.
ax[1, 1].yaxis.set_minor_formatter(mpl.ticker.NullFormatter())
ax[1, 1].margins(y=0.2)
# ax[0, 1].grid()
ax[1, 1].set_xlabel(r"$\sigma^\mathrm{noise}_\mathrm{init}$")
ax[1, 1].set_ylabel(r"$D_\mathrm{KL}(p \| p^*)$")
ax[0, 1].set_ylabel(r"$\mathrm{Var}\, (W_{ij} - W_{ji})$")
# One x tick per simulated noise level (shared with ax[1, 1] via sharex="col").
ax[0, 1].set_xticks(NOISE_LEVELS)
ax[0, 1].set_yscale("symlog", linthresh=1e-4)
ax[0, 1].set_ylim(*VAR_YLIM)
ax[1, 1].set_ylim([4e-3, 2.1e-2])
# ax[1, 1].get_yaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
# ax[1, 1].ticklabel_format(style='sci', axis='y', scilimits=(0,0))
# ax[1, 1].set_yticks([5e-3, 1e-2])

# Single figure-level legend built from the bottom-left panel's curves.
handles, labels = ax[1, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.98), ncol=3)
# NOTE(review): tight_layout may override the manual gridspec_kw margins.
plt.tight_layout()

fig.savefig(FIG_DIR / "ssn_initnoise.png", bbox_inches='tight')
fig.savefig(FIG_DIR / "ssn_initnoise.pdf", bbox_inches='tight')
fig.savefig(FIG_DIR / "ssn_initnoise.svg", bbox_inches='tight')

In [None]:
# Stop "Run All" here unless fig 5 was requested. An explicit raise is used
# instead of `assert`, which Python strips when run with -O.
if not PRINT_SECOND:
    raise RuntimeError("PRINT_SECOND is False: stopping after fig 4.")

## Plasticity noise scenario (fig 5)

In [None]:
# load a parameter file:
# (num_epochs / val_step are taken from the no-SAL runs' exp.yaml and used
#  for both conditions.)
with open(path_stdp / "exp.yaml", "r") as f:
    params = yaml.safe_load(f)

# STDP (plasticity) noise levels, sigma^noise_STDP --> see change_params.py
NOISE_LEVELS = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
# number of seeds per noise level:
NUM_SEEDS = 20

# Same file layout as fig 4: runs i*NUM_SEEDS .. (i+1)*NUM_SEEDS-1 belong to
# NOISE_LEVELS[i].
fname_nosal = lambda i: path_stdp / f"exp.{i:04d}.npz"
fname_sal = lambda i: path_stdp_sal / f"exp.{i:04d}.npz"

# One entry per noise level, each {key: {quantile: array}} from calc_stats.
stats_sal = []
stats_nosal = []

for i, _ in enumerate(NOISE_LEVELS):
    chunk = range(i * NUM_SEEDS, (i + 1) * NUM_SEEDS)
    dat_sal = data_loader(fname_sal, chunk, ["dkls", "all_asym"])
    # dat_sal = keep_validation(dat_sal, params["val_step"], tmax=-1)
    stats_sal.append(calc_stats(dat_sal))

    dat_nosal = data_loader(fname_nosal, chunk, ["dkls", "all_asym"])
    # dat_nosal = keep_validation(dat_nosal, params["val_step"], tmax=-1)
    stats_nosal.append(calc_stats(dat_nosal))

In [None]:
# x-axis of the time-course panels: one point every val_step epochs.
EPOCHS = np.arange(0, params["num_epochs"], params["val_step"])
# Index into NOISE_LEVELS of the noisy condition shown in the left column
# (NOISE_LEVELS[5] == 0.05, the largest STDP noise).
NOISE_ID = 5

# Same 2x2 layout as fig 4 (left: time courses, right: final values vs. noise
# level; top: Var(W_ij - W_ji), bottom: D_KL).
fig, ax = plt.subplots(
    2,
    2,
    sharex="col",
    figsize=cm_to_inch((16, 8)),
    squeeze=False,
    # layout="constrained"
    gridspec_kw={"top": 0.8,
                 "wspace": 0.4,
                 "right": 0.98, "bottom": 0.0},
)

# Time courses (median over seeds with IQR band, dashed) for the baseline
# (purple), STDP noise without SAL (C1) and with SAL (C2). The bottom-left
# panel's labels feed the figure legend, so only its baseline label carries
# the full parameter description.
# Raw string: the mathtext label contains backslashes (\sigma, \mathrm) that
# are invalid escape sequences in an ordinary string literal.
ax[0, 0] = plot_epochs(
    ax[0, 0], stats_nosal[0]["all_asym"], EPOCHS, label="baseline", yscale="linear", ls="--", c='purple'
)
ax[0, 0] = plot_epochs(
    ax[0, 0], stats_nosal[NOISE_ID]["all_asym"], EPOCHS, label="noise w/o SAL", yscale="linear", ls="--", c='C1'
)
ax[0, 0] = plot_epochs(
    ax[0, 0], stats_sal[NOISE_ID]["all_asym"], EPOCHS, label="noise with SAL", yscale="linear", ls="--", c='C2'
)

ax[1, 0] = plot_epochs(
    ax[1, 0], stats_nosal[0]["dkls"], EPOCHS, label=r"baseline $(\sigma^\mathrm{noise}_\mathrm{init} = 0.2, \sigma^\mathrm{noise}_\mathrm{stdp} = 0)$", ls="--", c='purple'
)
ax[1, 0] = plot_epochs(
    ax[1, 0], stats_nosal[NOISE_ID]["dkls"], EPOCHS, label="noise w/o SAL", ls="--", c='C1'
)
ax[1, 0] = plot_epochs(
    ax[1, 0], stats_sal[NOISE_ID]["dkls"], EPOCHS, label="noise with SAL", ls="--", c='C2'
)

# Extend the (shared) x-range past the last epoch for the markers at LEFT.
ax[1, 0].set_xlim(right=EPOCHS[-1] * 1.2)
LEFT = EPOCHS[-1] * 1.1
# Tag each curve's final median with its number (1 = baseline, 2 = noise w/o
# SAL, 3 = noise with SAL), matching the colours passed to plot_epochs.
plot_roundmarker(ax[0, 0], LEFT, stats_nosal[0]["all_asym"][0.5][-1], char="1", color="purple")
plot_roundmarker(ax[0, 0], LEFT, stats_nosal[NOISE_ID]["all_asym"][0.5][-1], char="2", color="C1")
plot_roundmarker(ax[0, 0], LEFT, stats_sal[NOISE_ID]["all_asym"][0.5][-1], char="3", color="C2")

plot_roundmarker(ax[1, 0], LEFT, stats_nosal[0]["dkls"][0.5][-1], char="1", color="purple")
plot_roundmarker(ax[1, 0], LEFT, stats_nosal[NOISE_ID]["dkls"][0.5][-1], char="2", color="C1")
plot_roundmarker(ax[1, 0], LEFT, stats_sal[NOISE_ID]["dkls"][0.5][-1], char="3", color="C2")


# Horizontal offset so SAL (+DELTA) and no-SAL (-DELTA) bars at the same
# noise level do not overlap (smaller than in fig 4: levels are 0.01 apart).
DELTA = 0.001
for i, _ in enumerate(NOISE_LEVELS):
    # VAR
    ebar(ax[0, 1], stats_sal[i]["all_asym"], NOISE_LEVELS[i] + DELTA, color="C2", marker="x")
    ebar(ax[0, 1], stats_nosal[i]["all_asym"], NOISE_LEVELS[i] - DELTA, color="C1", marker="x")
    # DKLs
    ebar(ax[1, 1], stats_sal[i]["dkls"], NOISE_LEVELS[i] + DELTA, color="C2", marker="x")
    ebar(ax[1, 1], stats_nosal[i]["dkls"], NOISE_LEVELS[i] - DELTA, color="C1", marker="x")

# plot baseline
# (after the error bars, since plot_baseline spans the current x-limits)
plot_baseline(ax[0, 1], stats_nosal[0]["all_asym"], "purple", kwargs_line={"linestyle": '--'})
plot_baseline(ax[1, 1], stats_nosal[0]["dkls"], "purple", kwargs_line={"linestyle": '--'})

# Hand-placed tags (data coordinates); tuned for this data set.
plot_roundmarker(ax[0, 1], 0.0, 0.25, char="1", color="purple")
plot_roundmarker(ax[0, 1], 0.05, 0.5, char="2", color="C1")
plot_roundmarker(ax[0, 1], 0.05, 5.e-4, char="3", color="C2")

plot_roundmarker(ax[1, 1], 0.0, 8.e-3, char="1", color="purple")
plot_roundmarker(ax[1, 1], 0.046, 7e-2, char="2", color="C1")
plot_roundmarker(ax[1, 1], 0.048, 1e-2, char="3", color="C2")

ax[1, 0].set_xlabel("wake-sleep cycles")
ax[1, 0].ticklabel_format(style='sci', scilimits=(-3,4), axis='x')
ax[0, 0].set_ylabel(r"$\mathrm{Var}\, (W_{ij} - W_{ji})$")
ax[1, 0].set_ylabel(r"$D_\mathrm{KL}(p \| p^*)$")
# ax[0, 0].set_yscale("symlog", linthresh=0.001)
# Unlike fig 4, a plain log scale is used for the asymmetry variance here.
ax[0, 0].set_yscale("log")
# ax[0, 0].set_ylim(-0.0002, 8)
# Annotate which STDP noise level the left-column curves show.
ax[0, 0].text(
    0.05,
    0.92,
    r"$\sigma^\mathrm{noise}_\mathrm{STDP}$ =" + f" {NOISE_LEVELS[NOISE_ID]}",
    transform=ax[0, 0].transAxes,
    ha="left",
    va="top",
    size=8,
    bbox=dict(facecolor=(1,1,1,0.7),
              edgecolor='black',
              boxstyle='round'),
)
# NOTE(review): attached to ax[1, 0] but positioned in ax[0, 0]'s axes
# coordinates (y=-0.275 is below the top-left panel); hand-tuned — confirm.
ax[1, 0].text(
    0.5,
    -0.275,
    r"$\sigma^\mathrm{noise}_\mathrm{STDP}$ =" + f" {NOISE_LEVELS[NOISE_ID]}",
    transform=ax[0, 0].transAxes,
    ha="left",
    va="top",
    size=8,
    bbox=dict(facecolor=(1,1,1,0.7),
              edgecolor='black',
              boxstyle='round'),
)

ax[1, 1].set_yscale("log")
# ax[0, 1].grid()
ax[1, 1].set_xlabel(r"$\sigma^\mathrm{noise}_\mathrm{STDP}$")
ax[1, 1].set_ylabel(r"$D_\mathrm{KL}(p \| p^*)$")
ax[0, 1].set_ylabel(r"$\mathrm{Var}\, (W_{ij} - W_{ji})$")

# ax[1, 1].xaxis.set_major_locator(mpl.ticker.MultipleLocator(1))
# One x tick per simulated noise level.
ax[1, 1].set_xticks(NOISE_LEVELS)


# ax[0, 1].set_ylim(1e-3, 2e-2)
# Create a custom formatter
def format_func(value, tick_number):
    """Format a tick value, labelling only every other tick.

    Intended for ``matplotlib.ticker.FuncFormatter``: ``value`` is the tick
    value and ``tick_number`` its position index. Even positions get the
    value with two decimals; odd positions get an empty label. matplotlib may
    call a formatter with ``tick_number=None`` (the default ``pos``, e.g. for
    the interactive cursor read-out). Such values are always formatted
    instead of raising TypeError on ``None % 2``.
    """
    if tick_number is None or tick_number % 2 == 0:  # Change 2 to n for every nth label
        return f"{value:.2f}"
    return ""


# Label only every other noise-level tick. The right column shares its x-axis
# (sharex="col"), so this presumably also governs ax[1, 1]'s tick labels.
ax[0, 1].xaxis.set_major_formatter(mpl.ticker.FuncFormatter(format_func))
# ax[0, 1].set_yscale("symlog", linthresh=0.001)
ax[0, 1].set_yscale("log")
# ax[0, 1].set_ylim(-0.0005, 4)

# Single figure-level legend built from the bottom-left panel's curves.
handles, labels = ax[1, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.98), ncol=3)
# NOTE(review): tight_layout may override the manual gridspec_kw margins.
plt.tight_layout()

fig.savefig(FIG_DIR / "ssn_stdpnoise.png", bbox_inches='tight')
fig.savefig(FIG_DIR / "ssn_stdpnoise.pdf", bbox_inches='tight')
fig.savefig(FIG_DIR / "ssn_stdpnoise.svg", bbox_inches='tight')