# Plots for spike-timing distributions in figure 2

Produces the spike-timing difference distributions (STDDs) shown in Fig. 2c of https://arxiv.org/abs/2503.02642.

In [None]:
from typing import Union, Optional, Callable, Iterator, List, Tuple
from pathlib import Path

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from stddc.distr import calc_dt_distr, rect_PSP
from neuralsampling.network import rect_kernel, NeuralSampler, logistic
from neuralsampling.stdp_functions import get_first_order_stds_2nrn

In [None]:
# Load the shared matplotlib style sheet (path is resolved relative to the
# notebook's working directory).
plt.style.use("../mystyle.mpl")

In [None]:
# Destination for the exported figure files; created on first run.
FIG_DIR = Path("..") / "figs"
FIG_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# True  -> (re-)run the simulations and save the results to .npz files;
# False -> load the previously saved .npz files instead.
RUNSIMULATION: bool = True

# Plot the spike-timing difference distributions (STDD)

In [None]:
def cm_to_inch(vals: Tuple[int | float] | List[int | float]) -> Tuple[int | float]:
    CM_PER_INCH = 2.54
    return tuple(map(lambda x: x / CM_PER_INCH, vals))

## excitatory

In [None]:
# Excitatory, asymmetric coupling (w_21 > w_12 > 0); both neurons share the
# same bias. t_ref is passed on as tau_ref (and tau_syn) below.
w_12, w_21 = 0.4, 0.8
b_1 = b_2 = -0.5
t_ref = 50

In [None]:
# Window parameter handed to calc_dt_distr; the plotting cell also uses it to
# slice the returned distributions.
t_max = 2 * t_ref + 2

if RUNSIMULATION:
    # Arguments shared by both runs -- only the weights differ.
    common_kwargs = dict(
        tau_ref=t_ref,
        t_max=t_max,
        b_1=b_1,
        b_2=b_2,
        tau_syn=t_ref,
    )
    # Comparison run: both weights replaced by the mean of the asymmetric pair.
    w_sym = (w_12 + w_21) / 2
    stdd_exc_1 = calc_dt_distr(rect_PSP, w_12=w_12, w_21=w_21, **common_kwargs)
    stdd_exc_2 = calc_dt_distr(rect_PSP, w_12=w_sym, w_21=w_sym, **common_kwargs)
    # np.savez appends ".npz", matching the file name loaded below.
    np.savez("stdd_exc", stdd_1=stdd_exc_1, stdd_2=stdd_exc_2)
else:
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been read into memory.
    with np.load("stdd_exc.npz") as data:
        stdd_exc_1 = data["stdd_1"]
        stdd_exc_2 = data["stdd_2"]

## inhibitory

In [None]:
# Inhibitory, asymmetric coupling (w_21 < w_12 < 0); both neurons share the
# same bias. t_ref is passed on as tau_ref (and tau_syn) below.
w_12, w_21 = -0.4, -0.8
b_1 = b_2 = -0.5
t_ref = 50

In [None]:
# Window parameter handed to calc_dt_distr; the plotting cell also uses it to
# slice the returned distributions.
t_max = 2 * t_ref + 2

if RUNSIMULATION:
    # Arguments shared by both runs -- only the weights differ.
    common_kwargs = dict(
        tau_ref=t_ref,
        t_max=t_max,
        b_1=b_1,
        b_2=b_2,
        tau_syn=t_ref,
    )
    # Comparison run: both weights replaced by the mean of the asymmetric pair.
    w_sym = (w_12 + w_21) / 2
    stdd_inh_1 = calc_dt_distr(rect_PSP, w_12=w_12, w_21=w_21, **common_kwargs)
    stdd_inh_2 = calc_dt_distr(rect_PSP, w_12=w_sym, w_21=w_sym, **common_kwargs)
    # np.savez appends ".npz", matching the file name loaded below.
    np.savez("stdd_inh", stdd_1=stdd_inh_1, stdd_2=stdd_inh_2)
else:
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been read into memory.
    with np.load("stdd_inh.npz") as data:
        stdd_inh_1 = data["stdd_1"]
        stdd_inh_2 = data["stdd_2"]

In [None]:
FIGWIDTH = 2.3

fig, ax = plt.subplots(
    3,
    1,
    figsize=(FIGWIDTH, 3.5),
    gridspec_kw={
        "top": 0.96,
        "left": 0.12,
        "right": 0.90,
        "bottom": 0.07,
        "height_ratios": [2, 1.6, 1.2],
        "hspace": 0.3,
    },
)

# Colours for the right (x >= 0) and left (x <= 0) halves of each distribution.
cright = "#88409c"
cright_light = "#8c6ab1"
cleft = "#3787c0"
cleft_light = "#abd0e6"
LINEWIDTH = 1.2


def _plot_stdd_pair(axis, stdd_asym, stdd_sym, t_ref, t_max):
    """Draw one pair of spike-timing difference distributions on ``axis``.

    ``stdd_asym`` (unequal weights) is drawn as coloured, filled curves and
    ``stdd_sym`` (equal weights) is overlaid as dashed black lines. Each array
    is split into a right half plotted over x = 0 .. 2*t_ref - 1 and a left
    half plotted over x = -2*t_ref + 1 .. 0.
    """
    x_right = np.arange(0, t_ref * 2)
    x_left = np.arange(-t_ref * 2 + 1, 1)
    # NOTE(review): these slices depend on the array layout returned by
    # calc_dt_distr; each slice must have length 2 * t_ref -- confirm there.
    right = slice(t_max, -1)
    left = slice(1, t_max - 1)

    y = stdd_asym[right]
    axis.plot(x_right, y, color=cright, lw=LINEWIDTH)
    axis.fill_between(x_right, y, color=cright, alpha=0.5)

    y = stdd_asym[left]
    axis.plot(x_left, y, color=cleft, lw=LINEWIDTH)
    axis.fill_between(x_left, y, color=cleft_light)

    axis.plot(x_right, stdd_sym[right], "--k", lw=LINEWIDTH)
    axis.plot(x_left, stdd_sym[left], "--k", lw=LINEWIDTH)


def _style_stdd_axis(axis, weight_label):
    """Apply the common STDD-panel styling to ``axis``.

    Adds the y-label, hides the y ticks and the left/right/top spines, and
    writes ``weight_label`` vertically just to the right of the axes.
    """
    # Raw string: "\D" is not a valid Python escape sequence.
    # The trailing spaces are presumably there to nudge the label's position.
    axis.set_ylabel(r"$p(\Delta t_{ji})$        ")
    for side in ("right", "left", "top"):
        axis.spines[side].set_visible(False)
    axis.set_yticks([])
    axis.text(
        1.1,
        -0.05,
        weight_label,
        transform=axis.transAxes,
        ha="right",
        va="bottom",
        rotation="vertical",
    )


#######################
### PLOT EXCITATORY ###
#######################

_plot_stdd_pair(ax[0], stdd_exc_1, stdd_exc_2, t_ref, t_max)
# Same tick positions as the panel below, but without labels.
ax[0].set_xticks([-t_ref, 0.0, t_ref], labels=[""] * 3)
_style_stdd_axis(ax[0], r"$W_{ji} > W_{ij} > 0$")

LEGEND_MARGIN = 0.7
ax[0].text(0.55, LEGEND_MARGIN, "causal", transform=ax[0].transAxes, ha="left")
ax[0].text(0.45, LEGEND_MARGIN, "anti-causal", transform=ax[0].transAxes, ha="right")
# Stretch the y-range so the data stay below the causal/anti-causal labels,
# which sit at LEGEND_MARGIN in axes coordinates.
_, upper = ax[0].get_ylim()
upper /= LEGEND_MARGIN
ax[0].set_ylim(0, upper)

# Figure-level legend explaining solid (unequal) vs dashed (equal) weights.
legend_elems = [
    mpl.lines.Line2D(
        [0], [0], lw=LINEWIDTH, color="gray", label=r"$w_{ji} \neq w_{ij}$"
    ),
    mpl.lines.Line2D(
        [0],
        [0],
        lw=LINEWIDTH,
        color="black",
        linestyle="--",
        label=r"$w_{ji} = w_{ij}$",
    ),
]
fig.legend(
    handles=legend_elems,
    loc="upper center",
    ncols=2,
)

#######################
### PLOT INHIBITORY ###
#######################

_plot_stdd_pair(ax[1], stdd_inh_1, stdd_inh_2, t_ref, t_max)
_style_stdd_axis(ax[1], r"$W_{ji} < W_{ij} < 0$")
# NOTE(review): unlike ax[0], ax[1] keeps its autoscaled y-limits (bottom not
# pinned to 0) -- confirm that is intended.
ax[1].set_xticks(
    [-t_ref, 0.0, t_ref], labels=[r"$-\tau_\mathrm{ref}$", "0", r"$\tau_\mathrm{ref}$"]
)

########################
### PLOT STDP WINDOW ###
########################

# Exponential window: positive for dt < 0, negative for dt > 0.
dtl = np.linspace(-2 * t_ref, 0)
dtr = np.linspace(0, 2 * t_ref)
stdpl = np.exp(dtl / t_ref)
stdpr = -np.exp(-dtr / t_ref)
ax[2].plot(dtl, stdpl, c="gray")
ax[2].plot(dtr, stdpr, c="gray")
# Cross-shaped axes through the centre of the panel.
ax[2].spines["left"].set_position("center")
ax[2].spines["right"].set_visible(False)
ax[2].spines["bottom"].set_position("center")
ax[2].spines["top"].set_visible(False)
ax[2].set_yticks([])
ax[2].set_xticks([-t_ref, 0.0, t_ref], labels=[""] * 3)
ax[2].set_xlabel(r"$\Delta t_{ji} = t_j - t_i$", labelpad=20)

ax[2].text(
    -0.05,
    0.5,
    r"$\Delta W_{ji}(\Delta t)$",
    transform=ax[2].transAxes,
    ha="center",
    rotation="vertical",
    rotation_mode="anchor",
    va="baseline",
)

for ext in ("svg", "pdf", "png"):
    fig.savefig(FIG_DIR / f"stdd.{ext}")