# Plot the phase plane diagram and SAL weight evolution (fig. 2d and e)

In [None]:
from pathlib import Path

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.contrib.itertools import product

from neuralsampling.network import sim_poisson_neurons, rect_kernel
from neuralsampling.stdp_functions import pairbased_stdp, exp_kernel
from stddc.distr import calc_dt_distr, rect_PSP

In [None]:
# Define the plotting style: apply the matplotlib style sheet located one
# directory above the working directory.
mpl.style.use("../mystyle.mpl")

In [None]:
# Output directory for all saved figures; created (including missing parents)
# if it does not exist yet.
FIG_DIR = Path("../figs")
FIG_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# (Re-)run the simulations (True) or load the cached simulation result files
# (False). Loading requires the *.npy files written by an earlier run to be
# present in the working directory.
RUNSIMULATION = True

## Plot the weight evolution of weight symmetrization (fig. 2d)

The two-neuron system is run twice:
1. SAL is applied to both weights such that they converge to their common mean.
2. SAL is applied to one weight only, which then converges to the other weight.

In [None]:
# Base figure width; passed to ``figsize`` (which matplotlib reads in inches)
# by the plotting cells below.
FIGWIDTH = 2.5

In [None]:
# Seed NumPy's global random generator so repeated runs give the same result.
# NOTE(review): presumably sim_poisson_neurons draws from this global
# generator — confirm, otherwise the seed has no effect on the simulation.
np.random.seed(45)

def sal_epoch(w, b, params):
    """Run one simulation epoch and return the rescaled STDP result.

    The network is simulated with ``sim_poisson_neurons`` for
    ``params["t_max"]`` using ``rect_kernel``, biases ``b`` and weights ``w``.
    The resulting spikes are passed to ``pairbased_stdp`` with an
    ``exp_kernel`` window whose arguments are ``(-1.0, 1.0, t_ref, t_ref)``.

    Parameters
    ----------
    w, b
        Weights and biases handed to the simulator unchanged.
    params
        Mapping providing the keys ``"t_max"`` and ``"t_ref"``.

    Returns
    -------
    The output of ``pairbased_stdp`` multiplied by ``t_ref / t_max``.
    """
    t_max = params["t_max"]
    t_ref = params["t_ref"]

    spike_trains = sim_poisson_neurons(t_max, rect_kernel, b, w, t_ref, t_ref)

    # NOTE(review): the trailing arguments 2 and 10 are passed through as in
    # the original call; their meaning is defined by pairbased_stdp — confirm.
    window_args = (-1.0, 1.0, t_ref, t_ref)
    raw_update = pairbased_stdp(exp_kernel, window_args, spike_trains, 2, 10)

    return raw_update / t_max * t_ref


def run_sal(w, b, params, lr, n_epochs):
    """SAL is applied to both weights here.

    Starting from ``w``, repeats ``n_epochs`` times: compute the update with
    ``sal_epoch`` for the current weights and add ``lr`` times that update.

    Returns
    -------
    Array stacking the initial weights and the weights after every epoch
    (``n_epochs + 1`` entries along the first axis).
    """
    current = w
    history = [current]
    for _ in tqdm(range(n_epochs)):
        current = current + lr * sal_epoch(current, b, params)
        history.append(current)
    return np.array(history)

def run_sal2(w, b, params, lr, n_epochs):
    """SAL is applied to one weight only.

    Same loop as the two-sided variant, except that entry ``[0, 1]`` of every
    update is set to zero before it is applied, so ``w[0, 1]`` keeps its
    initial value throughout.

    Returns
    -------
    Array stacking the initial weights and the weights after every epoch
    (``n_epochs + 1`` entries along the first axis).
    """
    current = w
    history = [current]
    for _ in tqdm(range(n_epochs)):
        update = sal_epoch(current, b, params)
        update[0, 1] = 0.  # freeze this weight
        current = current + lr * update
        history.append(current)
    return np.array(history)

In [None]:
## Run first simulation (SAL affects both weights)

# Initial weights (asymmetric off-diagonal entries), biases and timing
# parameters of the two-neuron system.
w = np.array([[0.0, 1.5], [0.5, 0.0]])
b = np.array([-0.5, -0.2])
params = {"t_max": 1500, "t_ref": 25}

# Learning rate and number of SAL epochs.
lr, n_epochs = 0.03, 1500

# Either load the cached weight history or simulate and cache it.
if not RUNSIMULATION:
    all_w = np.load("weight_evo.npy")
else:
    all_w = run_sal(w, b, params, lr, n_epochs)
    np.save("weight_evo", all_w)

In [None]:
## Run second simulation (SAL affects only one weight)

# Same initial weights, biases and timing parameters as in the first run.
w = np.array([[0.0, 1.5], [0.5, 0.0]])
b = np.array([-0.5, -0.2])
params = {"t_max": 1500, "t_ref": 25}

# Learning rate and number of SAL epochs.
lr, n_epochs = 0.03, 1500

# Either load the cached weight history or simulate and cache it.
if not RUNSIMULATION:
    all_w2 = np.load("weight_evo2.npy")
else:
    all_w2 = run_sal2(w, b, params, lr, n_epochs)
    np.save("weight_evo2", all_w2)

In [None]:
## Plot the time evolution of the weights in both cases.

# Time axis with one entry per stored weight matrix. It uses `params` and
# `n_epochs` as last assigned (both simulation cells set identical values).
# NOTE(review): the factor 0.01 presumably converts refractory periods into
# seconds (i.e. t_ref corresponds to 10 ms) — confirm.
times = np.linspace(0, 1, all_w.shape[0]) * params["t_max"] / params["t_ref"] * n_epochs * 0.01 
fig, ax = plt.subplots(1, 1, figsize=(FIGWIDTH, FIGWIDTH / 1.8), layout="constrained")
# Plot only every 10th stored sample.
steps = np.s_[::10]
# First run (SAL on both weights): red tones, no legend entries.
ax.plot(times[steps], all_w[steps, 1, 0], color="darkred")
ax.plot(times[steps], all_w[steps, 0, 1], color="red", linestyle="--")
# Second run (SAL on one weight only): gray tones; these lines carry the
# legend labels (solid: entry [1, 0], dashed: entry [0, 1]).
ax.plot(times[steps], all_w2[steps, 1, 0], color="dimgray", label=r"$W_{ji}$")
ax.plot(times[steps], all_w2[steps, 0, 1], color="darkgray", linestyle="--", label=r"$W_{ij}$")
ax.set_xlabel(r"$t$ [s]")
# NOTE(review): the weight axis is labelled with the unit [s], while the phase
# plane plot labels the weights without a unit — confirm which is intended.
ax.set_ylabel(r"$W$ [s]")
# ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.)
fig.legend(loc="outside upper center", ncols=2)
# Save the figure in all three formats.
fig.savefig(FIG_DIR / "sal_evol.png", bbox_inches='tight')
fig.savefig(FIG_DIR / "sal_evol.pdf", bbox_inches='tight')
fig.savefig(FIG_DIR / "sal_evol.svg", bbox_inches='tight')

## Plot the phase plane diagram in fig. 2e

In [None]:
def exp_kernel_2(dt, a_plus, a_minus, tau_plus, tau_minus):
    """Exponential STDP window for SAL (scalar version).

    Parameters
    ----------
    dt
        Scalar time difference at which the window is evaluated.
    a_plus, tau_plus
        Amplitude and decay constant used for ``dt > 0``.
    a_minus, tau_minus
        Amplitude and decay constant used for ``dt < 0``.

    Returns
    -------
    ``a_plus * exp(-dt / tau_plus)`` for positive ``dt``,
    ``a_minus * exp(dt / tau_minus)`` for negative ``dt`` and ``0.0``
    otherwise (in particular at ``dt == 0``).
    """
    if dt > 0.0:
        return a_plus * np.exp(-dt / tau_plus)
    if dt < 0.0:
        return a_minus * np.exp(dt / tau_minus)
    return 0.0


# Element-wise version of exp_kernel_2 for arrays of time differences.
# The output type is pinned to float: without ``otypes`` np.vectorize infers
# the dtype from the first element and refuses size-0 inputs.
vexp_kernel = np.vectorize(exp_kernel_2, otypes=[float])


def plot_ppd(fig, ax, range_weights, dw_12, dw_21, color):
    """Plot a phase plane diagram of the weight updates onto ``ax``.

    Three layers are drawn:

    * the update magnitude ``sqrt(dw_12**2 + dw_21**2)``, normalised to its
      maximum, as a filled contour map with a colorbar attached to ``fig``;
    * an arrow field with components ``(dw_12, dw_21)`` on every second grid
      point;
    * the zero contour of ``dw_21`` in ``color``.

    Parameters
    ----------
    fig, ax
        Matplotlib figure and axes to draw on.
    range_weights
        1-dim array of weight values, used for both axes.
    dw_12, dw_21
        2-dim arrays of weight updates on the grid. Matplotlib reads the
        first array axis as the y direction and the second as the x
        direction.
    color
        Color of the ``dw_21`` zero contour.

    Returns
    -------
    The same ``fig`` and ``ax`` that were passed in.
    """
    ax.grid()
    magnitude = np.sqrt(dw_12**2 + dw_21**2)

    # Normalise to the largest update. An all-zero field is left as is,
    # because 0 / 0 would turn the whole map into NaN.
    peak = np.max(magnitude)
    rel_magnitude = magnitude if peak == 0 else magnitude / peak

    ax.set_aspect("equal")
    cm = plt.get_cmap("GnBu")

    im = ax.contourf(
        range_weights,
        range_weights,
        rel_magnitude,
        15,
        alpha=0.9,
        cmap=cm,
    )
    fig.colorbar(
        im,
        label=r"$\sqrt{\Delta W_{ij}^2 + \Delta W_{ji}^2}$",
        ax=ax,
        ticks=[0.0, 0.5, 1.0],
    )

    # Arrow field, thinned to every second grid point in both directions.
    ax.quiver(
        range_weights[::2],
        range_weights[::2],
        dw_12[::2, ::2],
        dw_21[::2, ::2],
        pivot="mid",
    )

    # ax.contour(range_weights, range_weights, dw_12, [0.], colors='red')
    ax.contour(range_weights, range_weights, dw_21, [0.0], colors=color)

    ax.set_xlabel(r"$W_{ij}$")
    ax.set_ylabel(r"$W_{ji}$")
    # ax.plot([range_weights[0], range_weights[-1]],
    #         [range_weights[0], range_weights[-1]], 'k--')

    # Use the same tick positions on the x axis as on the y axis.
    ax.set_xticks(ax.get_yticks())

    return fig, ax

In [None]:
# Create the STDDs on a grid of weight pairs.
# NOTE(review): "STDD" presumably stands for the spike-time-difference
# distribution returned by calc_dt_distr — confirm.

TREF = 25
TMAX = 2 * TREF

# 17 weight values from -2 to 2 in steps of 0.25 (the stop value 2.1 makes
# sure that 2.0 itself is included).
w_range = np.arange(-2, 2.1, 0.25)
n_range = len(w_range)


if RUNSIMULATION:
    # One distribution of length 2 * TMAX - 1 per weight pair; the first grid
    # axis indexes W_12, the second W_21.
    stdd_grid = np.empty((n_range, n_range, 2 * TMAX - 1))
    for i, j in product(range(n_range), range(n_range)):
        print("W_12 = ", w_range[i], ", W_21 = ", w_range[j])
        stdd = calc_dt_distr(
            rect_PSP, TREF, TMAX, w_range[i], w_range[j], 0.0, 0.0, TREF
        )
        stdd_grid[i, j] = stdd
    # Cache the result so it can be reloaded with RUNSIMULATION = False.
    np.save("stdd_grid", stdd_grid)
else:
    stdd_grid = np.load("stdd_grid.npy")

In [None]:
# Apply the SAL rule: weight every distribution on the grid with the STDP
# window and sum over the time-difference axis.
ts = np.arange(-TMAX + 1, TMAX, dtype=float)
stdp_kernel = vexp_kernel(ts, -1.0, 1.0, TREF, TREF)

# The window broadcasts along the last axis of the grid; the second update
# uses the time-reversed window.
sal_12 = (stdd_grid * stdp_kernel).sum(axis=2)
sal_21 = (stdd_grid * stdp_kernel[::-1]).sum(axis=2)

In [None]:
## Plot the figure:

fig, ax = plt.subplots(1, 1, figsize=(FIGWIDTH, FIGWIDTH * 0.7), layout="constrained")

# Plot the PPD itself.
# NOTE(review): sal_12 / sal_21 carry W_12 on their first axis (see the grid
# loop), and plot_ppd hands them to contourf/quiver, which draw the first
# axis along y — confirm that this matches the axis labels.
fig, ax = plot_ppd(fig, ax, w_range, sal_12, sal_21, "blue")

# Add the traces of the simulations above: entry [0, 1] on the x axis, entry
# [1, 0] on the y axis. The single-point plots mark the starting weights.
# Both trajectories share the same label text; no legend is drawn here.
ax.plot(all_w[:, 0, 1], all_w[:, 1, 0], color="darkred", linestyle="--", label="example trajectory in d)")
ax.plot(all_w[:1, 0, 1], all_w[:1, 1, 0], marker="o", color="darkred")
ax.plot(all_w2[:, 0, 1], all_w2[:, 1, 0], color="dimgray", linestyle="--", label="example trajectory in d)")
ax.plot(all_w2[:1, 0, 1], all_w2[:1, 1, 0], marker="s", ms=3, color="dimgray")

# Save the figure in all three formats.
fig.savefig(FIG_DIR / "sal_ppd.pdf", bbox_inches='tight')
fig.savefig(FIG_DIR / "sal_ppd.svg", bbox_inches='tight')
fig.savefig(FIG_DIR / "sal_ppd.png", bbox_inches='tight')