In [None]:
%matplotlib widget

In [None]:
import lotr.plotting as pltltr
import matplotlib.pyplot as plt
import numpy as np
from numba import njit

In [None]:
def get_cut_sinewave(n_pts, pos, width):
    """Return a length-`n_pts` array containing a single positive half-sine bump.

    The bump is one half period of a sine (``sin`` sampled on ``[0, pi]``)
    spanning `width` consecutive samples; everything else is zero. It is built
    around the array centre and then circularly shifted so that it sits
    (approximately) on index `pos`, wrapping around the array ends.

    Parameters
    ----------
    n_pts : int
        Length of the output array.
    pos : int
        Target position of the bump (wrapped circularly).
    width : int
        Number of samples covered by the bump; must be even.

    Returns
    -------
    np.ndarray
        Float array of shape ``(n_pts,)``.
    """
    assert width % 2 == 0, "width must be even!"
    half = width // 2
    center = n_pts // 2

    bump = np.sin(np.linspace(0, np.pi, width))
    profile = np.zeros(n_pts)
    profile[center - half : center + half] = bump

    # Circularly shift the bump from the array centre to the requested position.
    shift = pos - center
    return np.roll(profile, shift)

## Initialize weights

In [None]:
def get_weigths_matrix(n_units, width):
    """Build a ring connectivity matrix with purely inhibitory (<= 0) weights.

    Row ``i`` is a half-sine bump of ``2 * width`` units placed at unit
    ``i - n_units // 2`` (with circular wrap-around), i.e. roughly opposite to
    unit ``i`` on the ring. The matrix is then symmetrised, scaled so that its
    entries sum to 1, and negated, so the returned entries sum to about -1.

    Parameters
    ----------
    n_units : int
        Number of units on the ring.
    width : int
        Half the width, in units, of each row's inhibitory bump.

    Returns
    -------
    np.ndarray
        ``(n_units, n_units)`` weight matrix.
    """
    offset = n_units // 2
    connectivity = np.array(
        [get_cut_sinewave(n_units, row - offset, 2 * width) for row in range(n_units)]
    ).reshape(n_units, n_units)

    # Average with the transpose so the matrix is exactly symmetric.
    connectivity = 0.5 * (connectivity + connectivity.T)

    # Scale to unit total and flip the sign: all connections are inhibitory.
    return -connectivity / connectivity.sum()

In [None]:
N_UNITS = 100
IN_WIDTH_RAD = 5 * 2 * np.pi / 100  # width of the contralateral inhibition, in radians (not used below)

width = 20  # half-width, in units, of each unit's inhibitory bump

weigths = get_weigths_matrix(N_UNITS, width)

# All weights are <= 0, so the colour scale runs from the most negative weight to 0.
plot_clims = np.abs(weigths.min())
f, w_ax = plt.subplots(
    figsize=(2, 2), gridspec_kw=dict(left=0.2, bottom=0.2, top=0.7, right=0.7)
)
wplot = w_ax.imshow(weigths, aspect="auto", cmap="Blues_r", vmin=-plot_clims, vmax=0)

cbar = pltltr.add_cbar(
    wplot,
    w_ax,
    (0.55, 1.2, 0.5, 0.1),
    orientation="horizontal",
    ticks=[-plot_clims, 0],  # "min" tick at the actual lower colour limit
    ticklabels=["min", 0],
    labelsize=6,
    title="Weight",
    titlesize=8,
)
w_ax.set(ylabel="Unit n.", xlabel="Unit n.", xticks=[0, 50], yticks=[0, 50])
cbar.ax.xaxis.set_ticks_position("top")

# Save before showing: with the inline backend plt.show() closes the figure,
# so a later plt.savefig would write an empty one. Use the project's save
# helper instead of a machine-specific absolute path.
pltltr.savefig("connmat")
plt.show()

## Initialization of the network

In [None]:
def initialize_network(
    n_steps, n_neurons, bump_width, bump_idx, bump_amp, base_activation
):
    """Allocate the activity traces and write an initial bump into column 0.

    Parameters
    ----------
    n_steps : int
        Number of simulation time steps (columns of the output).
    n_neurons : int
        Number of units on the ring (rows of the output).
    bump_width : int
        Width of the initial bump, in units; must be even.
    bump_idx : int
        Unit on which the initial bump is placed.
    bump_amp : float
        Peak amplitude of the bump, added on top of the baseline.
    base_activation : float
        Activity given to every unit at time 0.

    Returns
    -------
    np.ndarray
        ``(n_neurons, n_steps)`` array; column 0 holds the initial state and
        all other columns are zero.
    """
    traces = np.zeros((n_neurons, n_steps))
    traces[:, 0] = base_activation

    # Put a bump somewhere:
    traces[:, 0] += bump_amp * get_cut_sinewave(n_neurons, bump_idx, bump_width)

    return traces


SIMULATION_STEPS = 2000
BASELINE_ACTIVATION = 0.1  # small amount of activation for all neurons
BUMP_WIDTH_RAD = 10 * 2 * np.pi / 100  # width of the bump, in radians
BUMP_POSITION_RAD = 25 * 2 * np.pi / 100  # initial position of the bump, in radians
BUMP_PEAK_AMP = 1  # amplitude of the bump

# Round instead of truncating: float arithmetic can land just below an integer,
# and int() would then give e.g. an odd (invalid) bump width.
bump_width = int(round((BUMP_WIDTH_RAD / (2 * np.pi)) * N_UNITS))
bump_idx = int(round((BUMP_POSITION_RAD / (2 * np.pi)) * N_UNITS))

traces = initialize_network(
    SIMULATION_STEPS, N_UNITS, bump_width, bump_idx, BUMP_PEAK_AMP, BASELINE_ACTIVATION
)

f, ax = plt.subplots()
ax.plot(traces[:, 0])
ax.set(ylabel="Activity at time=0", xlabel="Neuron number")
plt.show()

## Simulate the network
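Model simulation. We model activity with the following assumptions:
 - there is a constant drive $E$ that reflects either intrinsic depolarization or some fixed global excitation (`FIXED_EXCITATION`);
 - network activity is normalized to a fixed total (`TOTAL_ACTIVATION`) at every time step, so that it does not diverge.

At every step, each unit is updated as $r_i(t) = \max\left(0,\ r_i(t-1) + c \sum_j W_{ij}\, r_j(t-1) + E\right)$, where $c$ is `WEIGTHS_COEF` and $W$ is the connectivity matrix. Activities are then rescaled so that $\sum_i r_i(t)$ equals `TOTAL_ACTIVATION`.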


First, we observe that the width of the bump is always $\pi$, regardless of the width of the weight matrix.

In [None]:
@njit
def evolve_traces(traces, weigths, fixed_excitation, weigths_coef, total_activation):
    """Simulate the ring network in place, one column of `traces` per time step.

    For every step ``t >= 1``, each unit is updated as::

        r_i(t) = max(0, r_i(t-1) + weigths_coef * sum_j W_ij r_j(t-1) + fixed_excitation)

    and the population is then rescaled so that its activities sum to
    ``total_activation``.

    Parameters
    ----------
    traces : np.ndarray
        ``(n_units, n_steps)`` float array. Column 0 is the initial state and is
        left untouched; columns ``1 .. n_steps-1`` are overwritten.
    weigths : np.ndarray
        ``(n_units, n_units)`` connectivity matrix.
    fixed_excitation : float
        Constant drive added to every unit at every step.
    weigths_coef : float
        Gain applied to the recurrent input.
    total_activation : float
        Value the summed population activity is normalised to at every step.
    """
    # Use the array shape and the argument rather than module globals: numba
    # freezes globals at compile time, so later changes to them are ignored.
    n_units = traces.shape[0]
    for t in range(1, traces.shape[1]):
        for i in range(n_units):
            traces[i, t] = (
                traces[i, t - 1]
                + weigths_coef * np.sum(traces[:, t - 1] * weigths[i, :])
                + fixed_excitation
            )

            # Rectify: activity cannot go below zero.
            if traces[i, t] < 0:
                traces[i, t] = 0

        # Skip normalisation if every unit was rectified to zero; dividing by
        # zero would fill the rest of the simulation with NaNs.
        total = np.sum(traces[:, t])
        if total > 0:
            traces[:, t] = total_activation * traces[:, t] / total

In [None]:
# Parameters for the convergence simulation in the next cell.
FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION = 1, 20, 500  # drive, recurrent gain, normalised total
NOISE_SIGMA, NOISE_TIMESCALE = 0.1, 10  # noise settings (not used by evolve_traces)

SIMULATION_STEPS = 70  # number of time steps (columns of `traces`)

In [None]:
plt.close("all")

# Results are consistent for any seed; fix it so the figures are reproducible.
np.random.seed(561)
traces = np.zeros((N_UNITS, SIMULATION_STEPS))
traces[:, 0] = np.abs(np.random.randn(N_UNITS))  # random, non-negative initial state

width = 20
weigths = get_weigths_matrix(N_UNITS, width)
evolve_traces(traces, weigths, FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION)

f, axs = plt.subplots(
    1,
    2,
    figsize=(6, 2),
    gridspec_kw=dict(bottom=0.25, left=0.2, top=0.7, width_ratios=[0.3, 1], hspace=0.6),
    sharey=True,
)
w_ax, traces_ax = axs

# Weights are <= 0 and activities >= 0, so pin the colour limits at 0 and
# place the "min"/"max" colour-bar ticks at the actual data limits.
wplot = w_ax.imshow(weigths, aspect="auto", cmap="Blues_r", vmin=weigths.min(), vmax=0)
tplot = traces_ax.imshow(traces, aspect="auto", cmap="gray", vmin=0)
traces_ax.set(xlabel="Pseudotime")
w_ax.set(xlabel="Unit n.", ylabel="Unit n.", xticks=[0, 50], yticks=[0, 50])

cbar = pltltr.add_cbar(
    wplot,
    w_ax,
    (0.55, 1.2, 0.5, 0.1),
    orientation="horizontal",
    ticks=[weigths.min(), 0],
    ticklabels=["min", 0],
    labelsize=6,
    title="Weight",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

cbar = pltltr.add_cbar(
    tplot,
    traces_ax,
    (0.88, 1.2, 0.15, 0.1),
    orientation="horizontal",
    ticks=[0, traces.max()],
    ticklabels=[0, "max"],
    labelsize=6,
    title=r"$\Delta F/F$",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

# Use the project's save helper instead of a machine-specific absolute path.
pltltr.savefig("convergence")

# Different connectivity matrices

In [None]:
# Parameters for the connectivity-width comparison in the next cells.
FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION = 1, 20, 500  # drive, recurrent gain, normalised total
NOISE_SIGMA, NOISE_TIMESCALE = 0.0, 10  # noise settings (not used by evolve_traces)

SIMULATION_STEPS = 500  # number of time steps (columns of `traces`)

In [None]:
plt.close("all")

# Results are consistent for any seed; fix it so the figures are reproducible.
np.random.seed(561)
traces = np.zeros((N_UNITS, SIMULATION_STEPS))
traces[:, 0] = np.abs(np.random.randn(N_UNITS))  # random, non-negative initial state

widths = [5, 10, 20, 40]
f, axs = plt.subplots(
    len(widths),
    2,
    figsize=(4, 5),
    gridspec_kw=dict(width_ratios=[0.3, 1], hspace=0.6),
    sharey=True,
)

for i, width in enumerate(widths):
    w_ax, traces_ax = axs[i, :]
    weigths = get_weigths_matrix(N_UNITS, width)
    # evolve_traces only overwrites columns 1..end, so every width is simulated
    # from the same initial state traces[:, 0].
    evolve_traces(traces, weigths, FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION)

    w_ax.imshow(weigths, aspect="auto", cmap="Blues_r", vmin=weigths.min(), vmax=0)
    traces_ax.imshow(traces, aspect="auto", cmap="gray", vmin=0)

    if i < len(widths) - 1:
        traces_ax.set(xticklabels=[])
        w_ax.set(xticklabels=[])
    else:
        traces_ax.set(xlabel="Pseudotime")
        w_ax.set(xlabel="Unit n.")
    w_ax.set(ylabel="Unit n.")
    traces_ax.set(title=f"Width: {width} units")

# The colour bars are drawn above the first row, so build them from the
# first-row images and put the "min"/"max" ticks at those images' limits.
wplot = axs[0, 0].images[0]
tplot = axs[0, 1].images[0]

cbar = pltltr.add_cbar(
    wplot,
    axs[0, 0],
    (0.55, 1.2, 0.5, 0.1),
    orientation="horizontal",
    ticks=[wplot.get_clim()[0], 0],
    ticklabels=["min", 0],
    labelsize=6,
    title="Weight",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

cbar = pltltr.add_cbar(
    tplot,
    axs[0, 1],
    (0.88, 1.2, 0.15, 0.1),
    orientation="horizontal",
    ticks=[0, tplot.get_clim()[1]],
    ticklabels=[0, "max"],
    labelsize=6,
    title=r"$\Delta F/F$",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

pltltr.savefig("weight_effect")

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim (same parameters,
# same "weight_effect" output), so running it only overwrites that figure.
# TODO: delete it, or change its parameters if a different figure was intended.
plt.close("all")

# Results are consistent for any seed; fix it so the figures are reproducible.
np.random.seed(561)
traces = np.zeros((N_UNITS, SIMULATION_STEPS))
traces[:, 0] = np.abs(np.random.randn(N_UNITS))  # random, non-negative initial state

widths = [5, 10, 20, 40]
f, axs = plt.subplots(
    len(widths),
    2,
    figsize=(4, 5),
    gridspec_kw=dict(width_ratios=[0.3, 1], hspace=0.6),
    sharey=True,
)

for i, width in enumerate(widths):
    w_ax, traces_ax = axs[i, :]
    weigths = get_weigths_matrix(N_UNITS, width)
    # evolve_traces only overwrites columns 1..end, so every width is simulated
    # from the same initial state traces[:, 0].
    evolve_traces(traces, weigths, FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION)

    w_ax.imshow(weigths, aspect="auto", cmap="Blues_r", vmin=weigths.min(), vmax=0)
    traces_ax.imshow(traces, aspect="auto", cmap="gray", vmin=0)

    if i < len(widths) - 1:
        traces_ax.set(xticklabels=[])
        w_ax.set(xticklabels=[])
    else:
        traces_ax.set(xlabel="Pseudotime")
        w_ax.set(xlabel="Unit n.")
    w_ax.set(ylabel="Unit n.")
    traces_ax.set(title=f"Width: {width} units")

# The colour bars are drawn above the first row, so build them from the
# first-row images and put the "min"/"max" ticks at those images' limits.
wplot = axs[0, 0].images[0]
tplot = axs[0, 1].images[0]

cbar = pltltr.add_cbar(
    wplot,
    axs[0, 0],
    (0.55, 1.2, 0.5, 0.1),
    orientation="horizontal",
    ticks=[wplot.get_clim()[0], 0],
    ticklabels=["min", 0],
    labelsize=6,
    title="Weight",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

cbar = pltltr.add_cbar(
    tplot,
    axs[0, 1],
    (0.88, 1.2, 0.15, 0.1),
    orientation="horizontal",
    ticks=[0, tplot.get_clim()[1]],
    ticklabels=[0, "max"],
    labelsize=6,
    title=r"$\Delta F/F$",
    titlesize=8,
)
cbar.ax.xaxis.set_ticks_position("top")

pltltr.savefig("weight_effect")

## Evolve with perturbation

In [None]:
@njit
def evolve_traces_perturb(
    traces, weigths, perturb, fixed_excitation, weigths_coef, total_activation
):
    """Simulate the ring network in place with an additional external input.

    Same dynamics as ``evolve_traces``, except that ``perturb[i, t]`` is added
    to unit ``i``'s input at step ``t``::

        r_i(t) = max(0, r_i(t-1) + weigths_coef * sum_j W_ij r_j(t-1)
                        + fixed_excitation + perturb[i, t])

    followed by rescaling the population so its activities sum to
    ``total_activation``.

    Parameters
    ----------
    traces : np.ndarray
        ``(n_units, n_steps)`` float array. Column 0 is the initial state and is
        left untouched; columns ``1 .. n_steps-1`` are overwritten.
    weigths : np.ndarray
        ``(n_units, n_units)`` connectivity matrix.
    perturb : np.ndarray
        External input indexed as ``perturb[unit, step]``; must be at least as
        large as `traces` in both dimensions.
    fixed_excitation : float
        Constant drive added to every unit at every step.
    weigths_coef : float
        Gain applied to the recurrent input.
    total_activation : float
        Value the summed population activity is normalised to at every step.
    """
    # Use the array shape and the argument rather than module globals: numba
    # freezes globals at compile time, so later changes to them are ignored.
    n_units = traces.shape[0]
    for t in range(1, traces.shape[1]):
        for i in range(n_units):
            traces[i, t] = (
                traces[i, t - 1]
                + weigths_coef * np.sum(traces[:, t - 1] * weigths[i, :])
                + fixed_excitation
                + perturb[i, t]
            )

            # Rectify: activity cannot go below zero.
            if traces[i, t] < 0:
                traces[i, t] = 0

        # Skip normalisation if every unit was rectified to zero; dividing by
        # zero would fill the rest of the simulation with NaNs.
        total = np.sum(traces[:, t])
        if total > 0:
            traces[:, t] = total_activation * traces[:, t] / total

In [None]:
T_START = 50  # time step at which the perturbation starts
SIMULATION_STEPS = 120
W_PULSE_UNITS = 30  # number of consecutive units receiving the perturbation
PULSE_AMP = 3  # input added to each perturbed unit at each perturbed step
positions = 35, 50, 65  # first unit receiving the perturbation
durations = 1, 4, 8  # perturbation durations, in time steps

# The connectivity is identical for every condition: build it once.
weigths = get_weigths_matrix(N_UNITS, 20)

f, axs = plt.subplots(
    len(positions),
    len(durations),
    figsize=(5, 4),
    gridspec_kw=dict(bottom=0.2, right=0.85),
    sharey=True,
    sharex=True,
)
for i, position in enumerate(positions):
    for j, duration in enumerate(durations):
        perturb = np.zeros((N_UNITS, SIMULATION_STEPS))
        perturb[
            position : position + W_PULSE_UNITS,
            T_START : T_START + duration,
        ] = PULSE_AMP

        # Same seed for every condition, so the bump starts at the same place.
        np.random.seed(561)
        traces = np.zeros((N_UNITS, SIMULATION_STEPS))
        traces[:, 0] = np.abs(np.random.randn(N_UNITS))  # random, non-negative initial state

        evolve_traces_perturb(
            traces, weigths, perturb, FIXED_EXCITATION, WEIGTHS_COEF, TOTAL_ACTIVATION
        )

        ax = axs[i, j]
        ax.imshow(traces, cmap="gray", aspect="auto", vmin=0)
        # Outline the perturbed region. The contour level must lie between 0 and
        # PULSE_AMP, otherwise no contour is drawn.
        ax.contour(perturb, colors=["C2"], levels=[PULSE_AMP / 2], linewidths=1)

        if j == 0:
            ax.set(ylabel=f"Dist: {position} units")
        if i == len(positions) - 1:
            ax.set(xlabel="Pseudotime")
        if i == 0:
            ax.set_title(f"Duration: {duration} s")

axs[0, 0].text(T_START - 20, positions[0] + W_PULSE_UNITS + 10, "perturbation", c="C2")

# Build the colour bar from the image of the axes it is attached to, with the
# "max" tick at that image's actual upper colour limit.
im = axs[0, -1].images[0]
cbar = pltltr.add_cbar(
    im,
    axs[0, -1],
    (1.25, 0.6, 0.06, 0.5),
    orientation="vertical",
    ticks=[0, im.get_clim()[1]],
    ticklabels=[0, "max"],
    labelsize=6,
    title=r"$\Delta F/F$",
    titlesize=8,
)

pltltr.savefig("dur_dist_effect")