# Plots for the section on non-rectangular PSPs (fig. 7)

In [None]:
from pathlib import Path
from functools import partial
import copy
import glob
import pickle
import multiprocessing as mp

import numpy as np
import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from tqdm import tqdm
from tqdm.contrib.itertools import product

from stddc import STDDMaker, PPDMaker, alpha_PSP, valpha_PSP, rect_PSP, vrect_PSP, exp_window

In [None]:
import warnings
# Silence all warnings process-wide so the notebook output stays readable.
# NOTE(review): this also hides e.g. matplotlib's FixedFormatter/FixedLocator
# warnings triggered by set_yticklabels below and numpy runtime warnings —
# narrow the filter when debugging.
warnings.filterwarnings('ignore')

In [None]:
# Apply the shared matplotlib style sheet; the path is resolved relative to
# the current working directory.
mpl.style.use("../mystyle.mpl")

In [None]:
# Output directory for the exported figures (created, with parents, if missing).
FIG_DIR = Path("../figs")
FIG_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# True: (re-)run the PPD simulations and pickle each result to
# "<name>.pickle" in the working directory.
# False: load those pickles instead, unless a `ppds` variable already exists
# in this kernel session.
RUNSIMULATION = True

In [None]:
## general functions:

def reshape_list(lst, shape):
    """Split a flat sequence into ``n_rows`` consecutive rows of ``n_cols`` items.

    Args:
        lst: Flat, sliceable sequence with exactly ``n_rows * n_cols`` items.
        shape: ``(n_rows, n_cols)`` of the result.

    Returns:
        A list of ``n_rows`` slices of ``lst``, each ``n_cols`` items long.

    Raises:
        ValueError: If ``len(lst)`` differs from ``n_rows * n_cols``.
    """
    n_rows, n_cols = shape
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if n_rows * n_cols != len(lst):
        raise ValueError(
            f"cannot reshape a list of length {len(lst)} into shape {tuple(shape)}"
        )
    return [lst[i * n_cols : (i + 1) * n_cols] for i in range(n_rows)]

def cm_to_inch(x, y):
    """Convert a ``(width, height)`` pair from centimetres to inches.

    Matplotlib's ``figsize`` expects inches; this lets the figure size be
    written in centimetres.
    """
    cm_per_inch = 2.54
    width_in = x / cm_per_inch
    height_in = y / cm_per_inch
    return width_in, height_in

### PSPs

In [None]:
def plot_psp(fig, ax, f_psp, tau_syn, t_ref=10, n_points=100):
    """Plot a PSP kernel over ``[-t_max / 10, t_max]`` with ``t_max = 2 * t_ref``.

    Args:
        fig: Figure owning ``ax``; returned unchanged.
        ax: Axes to draw on.
        f_psp: PSP function called as ``f_psp(ts, t_ref, t_ref * tau_syn)``
            with an array of sample times ``ts``.
        tau_syn: Synaptic time constant as a fraction of ``t_ref``.
        t_ref: Refractory time setting the time axis (default 10, the value
            this function previously hard-coded in a local that shadowed the
            module-level ``TREF``).
        n_points: Number of sample points.

    Returns:
        ``(fig, ax)``.
    """
    t_max = 2 * t_ref
    ts = np.linspace(-t_max / 10, t_max, n_points)
    ax.plot(ts, f_psp(ts, t_ref, t_ref * tau_syn))
    ax.set_xticks([0.0, t_ref], labels=("0", r"$\tau_\mathrm{ref}$"))
    return fig, ax

### STDDs

In [None]:
def plot_stdd(fig, ax, stddmakers, colors):
    """Draw one STDD curve per maker and reduce the axes to a bare time axis.

    Args:
        fig: Figure owning ``ax``; returned unchanged.
        ax: Axes to draw on.
        stddmakers: Objects exposing ``times``, ``stdd`` and ``t_ref``.
        colors: One line colour per entry of ``stddmakers``; surplus entries
            of either sequence are ignored, as with ``zip``.

    Returns:
        ``(fig, ax)``, matching ``plot_psp``.
    """
    stddmakers = list(stddmakers)
    for sm, c in zip(stddmakers, colors):
        ax.plot(sm.times, sm.stdd, color=c)
    # Keep only the bottom spine and hide the y ticks.
    for side in ("right", "left", "top"):
        ax.spines[side].set_visible(False)
    ax.set_yticks([])
    # Guarded so an empty `stddmakers` no longer fails on an unbound name.
    # NOTE(review): ticks use the last maker's t_ref; assumes all STDDs in one
    # panel share the same refractory time.
    if stddmakers:
        t_ref = stddmakers[-1].t_ref
        ax.set_xticks(
            [-t_ref, 0.0, t_ref],
            labels=[r"$-\tau_\mathrm{ref}$", "0", r"$\tau_\mathrm{ref}$"],
        )
    return fig, ax

### PPDs

In [None]:
def make_ppd_data(sm, b_range, w_range):
    """Compute a PPD for every bias pair ``(b_1, b_2)`` in ``b_range x b_range``.

    Each PPD is built from its own shallow copy of ``sm`` with the biases set.
    The template ``sm`` is therefore never mutated, and no two PPDs share one
    STDDMaker (which ``PPDMaker`` may keep a reference to — not visible here).

    Args:
        sm: Template STDDMaker; ``b_1``/``b_2`` are overridden on each copy.
        b_range: Sequence of bias values.
        w_range: Weight grid passed to ``PPDMaker``.

    Returns:
        Nested list ``ppds`` where ``ppds[i][j]`` is the PPD for
        ``b_1 = b_range[i]`` and ``b_2 = b_range[j]``.
    """
    n_b = len(b_range)
    flat = []
    for i, (b1, b2) in enumerate(product(b_range, b_range)):
        print(f"STEP: {i}, b = ({b1}, {b2})")
        # Work on a copy; the original code set the biases on `sm` itself and
        # left the copy unused.
        sm_ = copy.copy(sm)
        sm_.b_1, sm_.b_2 = b1, b2
        ppd = PPDMaker(sm_, w_range)
        ppd.calc_sal_grid(exp_window, -1., 1., sm_.t_ref, sm_.t_ref)
        flat.append(ppd)
    # product() varies b2 fastest, so each run of n_b consecutive PPDs is a row.
    return [flat[k * n_b:(k + 1) * n_b] for k in range(n_b)]

def save_ppd_data(lst, b_range, fname):
    """Pickle a PPD grid together with its bias values to ``fname``.

    The file stores a dict with keys ``"ppds"`` and ``"b_range"``;
    ``load_ppd_data`` reads it back.
    """
    payload = dict(ppds=lst, b_range=b_range)
    with open(fname, mode="wb") as fh:
        pickle.dump(payload, fh)

def load_ppd_data(fname):
    """Read a file written by ``save_ppd_data`` and return ``(ppds, b_range)``.

    Uses ``pickle``, so only load files this notebook produced itself;
    unpickling untrusted data can execute arbitrary code.
    """
    with open(fname, "rb") as fh:
        stored = pickle.load(fh)
    return stored["ppds"], stored["b_range"]

def plot_boas(fig, ax, lst_ppd, b_range, cmap):
    """Plot the BOA curve of each PPD, coloured by its bias value, plus ``y = x``.

    Args:
        fig: Unused; kept for a uniform plotting signature.
        ax: Axes to draw on.
        lst_ppd: Objects with ``calc_boa(upscale=...)`` returning an array whose
            columns 0/1 are plotted as x/y, and a ``w_range`` attribute.
        b_range: One value per PPD; its position within
            ``[min(b_range), max(b_range)]`` picks the colour.
        cmap: Colormap.
    """
    lo, hi = min(b_range), max(b_range)
    norm = mcolors.Normalize(vmin=lo, vmax=hi)
    curves = [ppd.calc_boa(upscale=100) for ppd in lst_ppd]
    for curve, bias in zip(curves, b_range):
        ax.plot(curve[:, 0], curve[:, 1], color=cmap(norm(bias)))
    # Dashed identity line over the first PPD's weight range.
    diag = lst_ppd[0].w_range
    ax.plot(diag, diag, 'k--')
    ax.set_aspect("equal")

def discrete_cmap(cmap, vals):
    """Sample ``cmap`` at ``vals`` into a ListedColormap and return it with its norm.

    Returns:
        ``(cmap_d, norm)``. ``norm`` maps ``[min(vals), max(vals)]`` onto
        ``[0, 1]``; ``cmap_d`` holds one colour per entry of ``vals``.
    """
    norm = mcolors.Normalize(vmin=min(vals), vmax=max(vals))
    return mcolors.ListedColormap(cmap(norm(vals))), norm

def calc_deviation(lst_ppd, biases):
    """Return the ``len(biases) x len(biases)`` grid of BOA deviations.

    Entry ``[i, j]`` is ``lst_ppd[i][j].deviation_of_boa(...)`` evaluated on
    that PPD's ``calc_boa(upscale=100)``.
    """
    size = len(biases)
    deviations = np.zeros((size, size))
    for i, j in np.ndindex(size, size):
        ppd = lst_ppd[i][j]
        deviations[i, j] = ppd.deviation_of_boa(ppd.calc_boa(upscale=100))
    return deviations

def plot_deviation(fig, ax, dev, biases, cmap, norm):
    """Draw a filled contour map of BOA deviations over the bias grid.

    ``dev[i, j]`` is expected in the layout produced by ``calc_deviation`` on
    ``make_ppd_data`` output, i.e. for ``b_1 = biases[i]`` and
    ``b_2 = biases[j]``. ``contourf(X, Y, Z)`` places ``Z[row, col]`` at
    ``(X[col], Y[row])``. The grid is therefore transposed so that ``b_1`` runs
    along x and ``b_2`` along y. Without the transpose, ``b_1`` would land on
    the y-axis.

    Args:
        fig: Figure owning ``ax``; returned unchanged.
        ax: Axes to draw on.
        dev: Square 2-D array of deviations.
        biases: Bias values labelling both grid axes.
        cmap, norm: Colormap and normalisation passed to ``contourf``.

    Returns:
        ``(fig, ax)``.
    """
    ax.contourf(biases, biases, np.asarray(dev).T, cmap=cmap, norm=norm)
    ax.set_aspect("equal")
    return fig, ax
    

### Create the data

In [None]:
## ppd-data:
# Define some values:
TMAX = 20  # t_max passed to the PPD STDDMakers below
TREF = 10  # t_ref passed to the PPD STDDMakers below
W_RANGE = np.linspace(-1.5, 1.5, 10)  # weight grid passed to PPDMaker
# Bias values; the PPD grid spans B_RANGE x B_RANGE.
B_RANGE = np.array([-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0])
# Alternative, coarser grid. NOTE(review): too short for B_IDX_HI = 7.
# B_RANGE = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
# Indices into B_RANGE for the three example STDDs.
# NOTE(review): B_IDX_MID = 5 selects 0.25, not the centre value 0.0 (index 4);
# LOW/HI are symmetric around index 5 — confirm this is intended.
B_IDX_LOW = 3
B_IDX_MID = 5
B_IDX_HI = 7

# Keys of the three PSP variants; also the basenames of the pickle files.
fnames = ["rect", "alpha_short", "alpha_long"]

In [None]:
# Create STDD-data for the example panels. For each PSP variant there are three
# STDDs, at the bias pairs (low, high), (mid, mid) and (high, low) from B_RANGE.

TMAX_ = 60
TREF_ = 30
W12_STDD = 1.0
W21_STDD = 1.0

def calc_stdd(psp, t_syn, b_1, b_2):
    """Build an STDDMaker for ``psp`` and compute its STDD.

    Args:
        psp: PSP function passed to ``STDDMaker``.
        t_syn: Synaptic time constant as a fraction of ``TREF_``.
        b_1, b_2: Bias values passed to ``STDDMaker``.

    Returns:
        The STDDMaker after ``calc_stdd(fill_middle="smooth")`` has run.
    """
    sm = STDDMaker(psp, TMAX_, TREF_, TREF_ * t_syn, W12_STDD, W21_STDD, b_1, b_2)
    sm.calc_stdd(fill_middle="smooth")
    return sm

# (PSP, tau_syn / tau_ref) per variant. alpha_short uses exactly 1/3, the ratio
# the PPD STDDMakers and the panel title use, instead of the former 0.33.
STDD_VARIANTS = {
    "rect": (rect_PSP, 1.0),
    "alpha_short": (alpha_PSP, 1.0 / 3.0),
    "alpha_long": (alpha_PSP, 0.5),
}
STDD_BIAS_PAIRS = [
    (B_RANGE[B_IDX_LOW], B_RANGE[B_IDX_HI]),
    (B_RANGE[B_IDX_MID], B_RANGE[B_IDX_MID]),
    (B_RANGE[B_IDX_HI], B_RANGE[B_IDX_LOW]),
]

stdds = {
    name: [calc_stdd(psp, t_syn, b_1, b_2) for b_1, b_2 in STDD_BIAS_PAIRS]
    for name, (psp, t_syn) in STDD_VARIANTS.items()
}

In [None]:
# Template STDDMakers for the PPD grids, one per PSP variant in the same order
# as `fnames`. W12/W21 and the biases are 0 here; make_ppd_data overrides the
# biases per grid point.
W12 = 0.0
W21 = 0.0

stddmakers = [
    STDDMaker(rect_PSP, TMAX, TREF, TREF, W12, W21, b_1=0.0, b_2=0.0),
    STDDMaker(alpha_PSP, TMAX, TREF, TREF*1./3., W12, W21, b_1=0.0, b_2=0.0),
    STDDMaker(alpha_PSP, TMAX, TREF, TREF*0.5, W12, W21, b_1=0.0, b_2=0.0),
]

def process_ppd(key, stddmaker, b_range, w_range, fname):
    """Compute the PPD grid for one PSP variant and pickle it to ``fname``.

    Executed in a worker of the pool below. Returns ``(key, ppd_grid)`` so the
    results can be collected into a dict.
    """
    ppd = make_ppd_data(stddmaker, b_range, w_range)
    print(f"WRITE {fname} TO DISK.")
    save_ppd_data(ppd, b_range, fname=fname)
    print(f"{fname} DONE.")
    return key, ppd 

if RUNSIMULATION:
    # One task per PSP variant, run in parallel.
    # NOTE(review): workers must find `process_ppd` in `__main__`. This works
    # with the "fork" start method; under "spawn" (macOS/Windows default) it may
    # fail when run from a notebook.
    with mp.Pool() as pool:
        tasks = [(fname, stddmakers[i], B_RANGE, W_RANGE, fname + ".pickle") for i, fname in enumerate(fnames)]
        results = pool.starmap(process_ppd, tasks)
    
    ppds = dict(results)
    del results
elif "ppds" not in dir():
    # Load from disk only if no earlier cell run has left `ppds` in the namespace.
    ppds = {name: load_ppd_data(name + ".pickle")[0] for name in fnames}

In [None]:
## Calc the deviations of the BOA

# Deviation grid per PSP variant, plus the overall maximum so every deviation
# panel can share one colour scale.
devs = {key: calc_deviation(ppds[key], B_RANGE) for key in fnames}
dev_max = max(np.max(grid) for grid in devs.values())

### Make the plot

In [None]:
# Compact font and legend settings for the small (14 cm x 12 cm) figure below.
# Applied after the style sheet, so these keys override it.
plt.rcParams.update({
    # Label font size
    'axes.labelsize': 10,
    
    # Axis title size
    'axes.titlesize': 10,
    
    # Tick label size
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    
    # Legend properties
    'legend.fontsize': 8,
    'legend.handlelength': 1.,  # Length of the legend handles
    'legend.handleheight': 0.5,  # Height of the legend handles
    'legend.handletextpad': 0.3,  # Padding between handle and text
    'legend.borderpad': 0.2,  # Padding between legend edge and content
    'legend.borderaxespad': 0.5,  # Padding between axes and legend edge
})


In [None]:
## Create the figure
# Four rows (PSP shapes, STDDs, BOAs, BOA deviations) x three PSP variants.
# A fourth, narrow column holds the legend (row 2) or a colorbar (rows 3-4).

# settings etc.
cmap_boa = plt.get_cmap("viridis")
cmap_dev = plt.get_cmap("YlOrRd")

LEFT = 0.1
RIGHT = 0.8
W_SPACE = 0.2
W_RATIOS = [1, 1, 1, 0.3]

# One column per PSP variant: (data key, vectorised PSP, tau_syn / tau_ref).
PANELS = [
    ("rect", vrect_PSP, 1.0),
    ("alpha_short", valpha_PSP, 1.0 / 3.0),
    ("alpha_long", valpha_PSP, 1.0 / 2.0),
]


def _row_gridspec(**kwargs):
    """Return a 1x4 GridSpec with the column layout shared by all figure rows."""
    return GridSpec(1, 4, width_ratios=W_RATIOS, left=LEFT, right=RIGHT, wspace=W_SPACE, **kwargs)


def label_maker(b_1, b_2):
    """Return the legend label ``b_1=<b_1>, b_2=<b_2>`` with LaTeX bias names."""
    return r"$b_1$" + f"={b_1}, " + r"$b_2$" + f"={b_2}"


plt.ion()
fig = plt.figure(figsize=cm_to_inch(14, 12))

# Create 4 subfigures
subfigs = fig.subfigures(4, 1, height_ratios=[0.6, 0.4, 1, 1.2])

ax = []
# row 1: PSP shapes (three panels sharing the y-axis, no fourth column)
gs = _row_gridspec(bottom=0.25, top=0.5)
ax.append([subfigs[0].add_subplot(gs[0, 0])])
ax[0].append(subfigs[0].add_subplot(gs[0, 1], sharey=ax[0][0]))
ax[0].append(subfigs[0].add_subplot(gs[0, 2], sharey=ax[0][1]))
for a in ax[0]:
    a.spines['top'].set_visible(False)
    a.spines['right'].set_visible(False)
# NOTE(review): shared y axes share one tick formatter, so this likely also
# clears the labels of ax[0][0]. Use tick_params(labelleft=False) if ax[0][0]
# should keep its labels.
for a in ax[0][1:]:
    a.set_yticklabels([])
ax[0][0].set_ylabel(r"$\kappa(t)$")
ax[0][0].set_title("rect. PSP\n" + r"$\tau_\mathrm{syn} = \tau_\mathrm{ref}$")
ax[0][1].set_title(r"$\alpha$-PSP (short)"+"\n" + r"$\tau_\mathrm{syn} = 1/3 \cdot \tau_\mathrm{ref}$")
ax[0][2].set_title(r"$\alpha$-PSP (long)"+"\n" + r"$\tau_\mathrm{syn} = 1/2 \cdot \tau_\mathrm{ref}$")

# row 2: STDDs; the fourth column only carries the legend
gs = _row_gridspec(bottom=0.25)
ax.append([subfigs[1].add_subplot(gs[0, i]) for i in range(4)])
ax[1][3].axis('off')
ax[1][0].set_ylabel(r"$p(\Delta t)$")

# row 3: BOAs in the (W_21, W_12) plane; the fourth column is the colorbar
gs = _row_gridspec(bottom=0.15)
ax.append([subfigs[2].add_subplot(gs[0, i]) for i in range(4)])
ax[2][1].set_yticklabels([])
ax[2][2].set_yticklabels([])
ax[2][0].set_ylabel(r"$W_{12}$")
for a in ax[2][:3]:
    a.set_xlabel(r"$W_{21}$")

# row 4: BOA deviations over the (b_1, b_2) plane; the fourth column is the colorbar
gs = _row_gridspec(bottom=0.2)
ax.append([subfigs[3].add_subplot(gs[0, i]) for i in range(4)])
ax[3][1].set_yticklabels([])
ax[3][2].set_yticklabels([])
ax[3][0].set_ylabel(r"$b_2$")
for a in ax[3][:3]:
    a.set_xlabel(r"$b_1$")

#########################################################################
## plot the psp-shapes
for a, (_, psp, tau_syn) in zip(ax[0], PANELS):
    plot_psp(fig, a, psp, tau_syn)

## plot the stdds
cmap_boa_d, norm_boa = discrete_cmap(cmap_boa, B_RANGE * 2.)
# NOTE(review): these colours index the discrete colormap by the B_IDX_* values
# themselves, not by the STDDs' b_2 - b_1 shown on the row-3 colorbar.
# Confirm the intended colour code.
colors = cmap_boa_d([B_IDX_LOW, B_IDX_MID, B_IDX_HI])
for a, (key, _, _) in zip(ax[1], PANELS):
    plot_stdd(fig, a, stdds[key], colors)
# legend
stdd_bias_pairs = [
    (B_RANGE[B_IDX_LOW], B_RANGE[B_IDX_HI]),
    (B_RANGE[B_IDX_MID], B_RANGE[B_IDX_MID]),
    (B_RANGE[B_IDX_HI], B_RANGE[B_IDX_LOW]),
]
labels = [label_maker(b_1, b_2) for b_1, b_2 in stdd_bias_pairs]
legend_elements = [Line2D([0], [0], color=c, label=l) for c, l in zip(colors, labels)]
ax[1][3].legend(handles=legend_elements, loc='center left')


## plot the PPDs
n_b = len(B_RANGE)
for a, (key, _, _) in zip(ax[2], PANELS):
    # ppds[key][i][j] holds b_1 = B_RANGE[i], b_2 = B_RANGE[j]; take j = -i-1.
    plot_boas(fig, a, [ppds[key][i][-i-1] for i in range(n_b)], B_RANGE, cmap_boa)
    a.plot([W12_STDD], [W21_STDD], c="red", marker="o", fillstyle='none')
sm = plt.cm.ScalarMappable(cmap=cmap_boa_d, norm=norm_boa)
cbar = fig.colorbar(sm, cax=ax[2][3])
cbar.set_label("diff. of biases\n"+r"$b_2 - b_1$")
# NOTE(review): plot_boas colours curve i by b_1 = B_RANGE[i] on the
# [min, max] scale of B_RANGE. This colorbar instead shows b_2 - b_1 on the
# scale of 2 * B_RANGE. Check that the two colour codes agree.
legend_elements = [Line2D([0], [0], linestyle='', c="red", marker="o", fillstyle='none', label="STDDs in b")]
ax[2][0].legend(handles=legend_elements, loc="upper left")

## plot the deviations
norm_dev = mcolors.Normalize(vmin=0.0, vmax=dev_max)
square_marker = mpl.markers.MarkerStyle("s", fillstyle="none")
attractor_colors = cmap_boa_d(norm_boa(B_RANGE))
for a, (key, _, _) in zip(ax[3], PANELS):
    plot_deviation(fig, a, devs[key], B_RANGE, cmap_dev, norm_dev)
    a.scatter(B_RANGE, B_RANGE[::-1], c=attractor_colors, marker=square_marker, linewidths=1.0)
sm = plt.cm.ScalarMappable(norm_dev, cmap_dev)
cbar = subfigs[3].colorbar(sm, cax=ax[3][3])
cbar.set_label("av. rel. deviation of\n attractors from diagonal")
legend_elements = [Line2D([0], [0], linestyle='', c=cmap_boa_d(0), marker=square_marker, label="attractors in c")]
ax[3][0].legend(handles=legend_elements, loc="upper right")

fig.savefig(FIG_DIR / "psp_shapes_2.pdf")
fig.savefig(FIG_DIR / "psp_shapes_2.png")
fig.savefig(FIG_DIR / "psp_shapes_2.svg")