In [None]:
# Reload imported modules (e.g. the src.dt_helpers helpers imported below) before
# executing each cell, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ortho_group
from scipy.fft import fft, fftfreq
from scipy.signal import find_peaks
from scipy.signal.windows import blackman
from scipy.interpolate import interp1d

from src.dt_helpers.network_dynamics import run_dynamics_step
from src.dt_helpers.plotters import visualize_network_activity
from src.dt_helpers.network_dynamics import compute_individual_neuron_frequency

#### Define network parameters

In [None]:
# Three-unit circuit (see legend labels below): 0 = Exc-input, 1 = Exc-hidden,
# 2 = Inh-hidden. Only unit 0 receives external drive.
network_size = 3
network_input = np.array([1, 0, 0])

# weights_ij: connection from neuron i to neuron j (1-indexed).
weights_12 = 1
weights_23 = 1 / 4
weights_31 = 1
weights_32 = 1 / 4

# Rows are target neurons, columns are source neurons (presumably the simulator
# computes W @ r -- confirm in run_dynamics_step). Projections from the
# inhibitory unit (column 2) enter with a negative sign.
#   Exci.  Exci.   Inhi.
#       0      0    -w31
#     w12      0    -w32
#       0    w23       0
network_weights = np.zeros((network_size, network_size))
network_weights[1, 0] = weights_12   # 1 -> 2, excitatory
network_weights[2, 1] = weights_23   # 2 -> 3, excitatory
network_weights[0, 2] = -weights_31  # 3 -> 1, inhibitory
network_weights[1, 2] = -weights_32  # 3 -> 2, inhibitory

# Unit time constants and integration step, both in seconds (per the _s suffix).
taus_s = np.ones(network_size, dtype=int)
dt_s = 0.0001

#### Run linear network

In [None]:
# Integrate the 3-unit circuit with an identity activation for 10 s, starting at rest.
simulation_time_s = 10
num_time_steps = int(simulation_time_s / dt_s)

# Everything except the evolving state is the same on every step.
step_kwargs = dict(
    network_input=network_input,
    weights=network_weights,
    taus=taus_s,
    dt=dt_s,
    noise=0,
    activations=lambda x: x,
)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(current_rates=current_rates, **step_kwargs)
    history_rates.append(current_rates)
# One row per integration step.
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title="",
    legend_labels=["Exc-input", "Exc-hidden", "Inh-hidden"],
)

#### Run ReLU network

In [None]:
# Same circuit as above, but with a rectified-linear activation.
simulation_time_s = 10
num_time_steps = int(simulation_time_s / dt_s)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(
        current_rates=current_rates,
        network_input=network_input,
        weights=network_weights,
        taus=taus_s,
        dt=dt_s,
        noise=0,
        # Return a fresh array. A third positional argument to np.maximum is the
        # ufunc `out` buffer, so np.maximum(x, 0, x) would overwrite the caller's
        # array in place.
        activations=lambda x: np.maximum(x, 0),
    )
    history_rates.append(current_rates)
# One row per integration step.
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title="",
    legend_labels=["Exc-input", "Exc-hidden", "Inh-hidden"],
)

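#### Design oscillatory network

Neuron 1 gets a small self-decay ($-0.01$). Neurons 2 and 3 are coupled through the block $\begin{pmatrix} -a & b \\ -b & -a \end{pmatrix}$, whose eigenvalues are $-a \pm i b$. For linear dynamics $\dot r = W r$, this is a rotation at angular frequency $b$ damped at rate $a$.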

In [None]:
# Neuron 1 (index 0) decays slowly on its own. Neurons 2-3 form the damped
# rotation block [[-a, b], [-b, -a]], whose eigenvalues are -a +/- i*b.
a = 0.01  # damping of the oscillating pair
b = 10  # rotation rate of the oscillating pair
network_weights_designed = np.zeros((3, 3))
network_weights_designed[0, 0] = -0.01
network_weights_designed[1:, 1:] = [[-a, b], [-b, -a]]

Linear network

In [None]:
# Drive only neuron 2 (index 1) with a unit step input and compare its simulated
# trajectory with the closed-form solution of the unrotated linear design.
simulation_time_s = 2 * np.pi
num_time_steps = int(simulation_time_s / dt_s)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(
        current_rates=current_rates,
        network_input=np.array([0, 1, 0]),  # unit input to neuron 2 only
        weights=network_weights_designed + np.eye(network_size),
        taus=taus_s,
        dt=dt_s,
        noise=0,
        activations=lambda x: x,
    )
    history_rates.append(current_rates)
history_rates = np.array(history_rates)

fig, ax = visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title="",
    legend_labels=["1", "2", "3"],
    show_plot=False,
)
# history_rates[k] is the state after k + 1 integration steps, hence the +1.
time_points = (np.arange(num_time_steps) + 1) * dt_s

# Closed form for neuron 2. This assumes run_dynamics_step integrates
# tau * dr/dt = -r + (W + I) r + input (the reason the identity is added above),
# with tau = 1, i.e. dx/dt = -a x + b y + 1, dy/dt = -b x - a y, x(0) = y(0) = 0:
#   x(t) = [a - exp(-a t) * (a cos(b t) - b sin(b t))] / (a^2 + b^2)
# This satisfies x(0) = 0 and x'(0) = 1, and tends to a / (a^2 + b^2).
asymptote = a / (a**2 + b**2)
analytical_rate = (
    a
    - np.exp(-a * time_points) * (a * np.cos(b * time_points) - b * np.sin(b * time_points))
) / (a**2 + b**2)
ax.plot(
    time_points,
    analytical_rate,
    "k--",
    lw=5,
    alpha=0.3,
    zorder=1,
    label="analytical (neuron 2) & asymptote",
)
ax.plot(time_points, np.full(num_time_steps, asymptote), "k--", lw=5, alpha=0.3, zorder=1)
ax.legend()
plt.show()

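Rotate the network weights with a random orthogonal matrix $Q$, keeping the linear activation. The similarity transform $W \to Q W Q^\top$ preserves the eigenvalues, so the oscillation frequency is unchanged, but the oscillation is now spread across all neurons.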

In [None]:
# Rebuild the unrotated design from scratch, so re-running this cell never
# rotates an already-rotated matrix.
a = 0.01  # damping of the oscillating pair
b = 10  # rotation rate of the oscillating pair
network_weights_designed = np.zeros((3, 3))
network_weights_designed[0, 0] = -0.01
network_weights_designed[1:, 1:] = [[-a, b], [-b, -a]]

# Seeded random orthogonal Q. The similarity transform Q W Q^T keeps the
# eigenvalues but mixes the modes across all neurons.
rot_matrix = ortho_group(dim=network_size, seed=31).rvs()
network_weights_designed = (rot_matrix @ network_weights_designed) @ rot_matrix.T

In [None]:
# One run of the rotated linear network at unit input intensity.
input_intensity = 1

simulation_time_s = 2 * np.pi
num_time_steps = int(simulation_time_s / dt_s)

# Step arguments that do not change during the run.
step_kwargs = dict(
    network_input=network_input * input_intensity,
    weights=network_weights_designed + np.eye(network_size),
    taus=taus_s,
    dt=dt_s,
    noise=0,
    activations=lambda x: x,
)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(current_rates=current_rates, **step_kwargs)
    history_rates.append(current_rates)
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title=f"Input intensity: {input_intensity}",
    legend_labels=["1", "2", "3"],
    show_plot=True,
)

In [None]:
# Sweep the input intensity and record the frequency reported by
# compute_individual_neuron_frequency for each neuron of the rotated *linear* network.
input_intensities = np.linspace(1, 12, 30)

simulation_time_s = 2 * np.pi
num_time_steps = int(simulation_time_s / dt_s)
# The effective weight matrix does not depend on the loop variables; build it once.
effective_weights = network_weights_designed + np.eye(network_size)

# Collect one small frame per intensity and concatenate once at the end.
# Growing a DataFrame with pd.concat inside the loop copies quadratically.
frequency_frames = []
for input_intensity in input_intensities:
    current_rates = np.zeros(network_size)
    history_rates = []
    for _ in range(num_time_steps):
        current_rates = run_dynamics_step(
            current_rates=current_rates,
            network_input=network_input * input_intensity,
            weights=effective_weights,
            taus=taus_s,
            dt=dt_s,
            noise=0,
            activations=lambda x: x,
        )
        history_rates.append(current_rates)
    history_rates = np.array(history_rates)
    temp_indi_freq = compute_individual_neuron_frequency(
        history_rates,
        dt=dt_s,
        num_steps=num_time_steps,
    )
    # One row per neuron; the Series index becomes the "index" column after reset_index.
    frequency_frames.append(
        pd.Series(temp_indi_freq, name="frequency")
        .to_frame()
        .assign(input_intensity=input_intensity)
    )
individual_neuron_frequencies = pd.concat(frequency_frames, axis=0).reset_index()

In [None]:
# Frequency of each neuron (colour and line width encode the neuron index)
# versus input intensity.
freq_ax = sns.lineplot(
    data=individual_neuron_frequencies,
    x="input_intensity",
    y="frequency",
    hue="index",
    size="index",
    palette="tab10",
    sizes=(2, 10),
)
freq_ax.set_ylim(bottom=0)
sns.despine(offset=1, trim=True)
plt.show()

Non-linear network

In [None]:
# One run of the rotated network with a rectified activation that saturates at 10.
input_intensity = 1

simulation_time_s = 2 * np.pi
num_time_steps = int(simulation_time_s / dt_s)

# Step arguments that do not change during the run.
step_kwargs = dict(
    network_input=network_input * input_intensity,
    weights=network_weights_designed + np.eye(network_size),
    taus=taus_s,
    dt=dt_s,
    noise=0,
    activations=lambda x: np.maximum(np.tanh(x / 10) * 10, 0),
)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(current_rates=current_rates, **step_kwargs)
    history_rates.append(current_rates)
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title=f"Input intensity: {input_intensity}",
    legend_labels=["1", "2", "3"],
    show_plot=True,
)

Does the oscillation frequency change with input strength?

In [None]:
# Repeat the input-intensity sweep for the rectified, saturating (non-linear) network.
input_intensities = np.linspace(1, 12, 30)

simulation_time_s = 2 * np.pi
num_time_steps = int(simulation_time_s / dt_s)
# The effective weight matrix does not depend on the loop variables; build it once.
effective_weights = network_weights_designed + np.eye(network_size)

# Collect one small frame per intensity and concatenate once at the end.
# Growing a DataFrame with pd.concat inside the loop copies quadratically.
frequency_frames = []
for input_intensity in input_intensities:
    current_rates = np.zeros(network_size)
    history_rates = []
    for _ in range(num_time_steps):
        current_rates = run_dynamics_step(
            current_rates=current_rates,
            network_input=network_input * input_intensity,
            weights=effective_weights,
            taus=taus_s,
            dt=dt_s,
            noise=0,
            activations=lambda x: np.maximum(np.tanh(x / 10) * 10, 0),
        )
        history_rates.append(current_rates)
    history_rates = np.array(history_rates)
    temp_indi_freq = compute_individual_neuron_frequency(
        history_rates,
        dt=dt_s,
        num_steps=num_time_steps,
        use_interpolation=True,
    )
    # One row per neuron; the Series index becomes the "index" column after reset_index.
    frequency_frames.append(
        pd.Series(temp_indi_freq, name="frequency")
        .to_frame()
        .assign(input_intensity=input_intensity)
    )
individual_neuron_frequencies = pd.concat(frequency_frames, axis=0).reset_index()

In [None]:
# Frequency of each neuron (colour and line width encode the neuron index)
# versus input intensity, non-linear network.
freq_ax = sns.lineplot(
    data=individual_neuron_frequencies,
    x="input_intensity",
    y="frequency",
    hue="index",
    size="index",
    palette="tab10",
    sizes=(2, 10),
)
freq_ax.set_ylim(bottom=0)
sns.despine(offset=1, trim=True)
plt.show()

### Multi-period networks

In [None]:
# Five-unit design: one slowly decaying mode plus two damped rotations,
# with rates f1 (slow pair) and f2 (fast pair) and damping d1, d2.
network_size = 5
network_input = np.array([1, 0, 0, 0, 0])

d1 = 0.0001
d2 = 0.001
f1 = 1
f2 = 10
network_weights_designed = np.zeros((network_size, network_size))
network_weights_designed[0, 0] = -0.1
network_weights_designed[1:3, 1:3] = [[-d1, f1], [-f1, -d1]]
network_weights_designed[3:5, 3:5] = [[-d2, f2], [-f2, -d2]]

# Seeded random orthogonal mixing; Q W Q^T has the same eigenvalues as W.
rot_matrix = ortho_group(dim=network_size, seed=4).rvs()
network_weights_designed = (rot_matrix @ network_weights_designed) @ rot_matrix.T

# Unit time constants and integration step, both in seconds (per the _s suffix).
taus_s = np.ones(network_size, dtype=int)
dt_s = 0.0001

#### Linear network

In [None]:
# One run of the rotated five-unit linear network over two periods of the slow pair.
input_intensity = 1

simulation_time_s = 4 * np.pi
num_time_steps = int(simulation_time_s / dt_s)

# Step arguments that do not change during the run. An alternative drive that
# targets the oscillator pairs directly would be np.array([0, f2, 0, f1, 0]).
step_kwargs = dict(
    network_input=network_input * input_intensity,
    weights=network_weights_designed + np.eye(network_size),
    taus=taus_s,
    dt=dt_s,
    noise=0,
    activations=lambda x: x,
)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(current_rates=current_rates, **step_kwargs)
    history_rates.append(current_rates)
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title=f"Input intensity: {input_intensity}",
    legend_labels=np.arange(network_size),
    show_plot=True,
)

#### Non-linear network

Activation: $f(x) = \max\left(10 \tanh(x / 10),\, 0\right)$, i.e. rectified and saturating at 10.

In [None]:
# One run of the rotated five-unit network with the rectified, saturating activation.
input_intensity = 1.5

simulation_time_s = 6 * np.pi
num_time_steps = int(simulation_time_s / dt_s)

# Step arguments that do not change during the run. An alternative drive that
# targets the oscillator pairs directly would be np.array([0, f1, 0, f2, 0]).
step_kwargs = dict(
    network_input=network_input * input_intensity,
    weights=network_weights_designed + np.eye(network_size),
    taus=taus_s,
    dt=dt_s,
    noise=0,
    activations=lambda x: np.maximum(np.tanh(x / 10) * 10, 0),
)
current_rates = np.zeros(network_size)
history_rates = []
for _ in range(num_time_steps):
    current_rates = run_dynamics_step(current_rates=current_rates, **step_kwargs)
    history_rates.append(current_rates)
history_rates = np.array(history_rates)

visualize_network_activity(
    network_activity=history_rates,
    duration=simulation_time_s,
    dt=dt_s,
    num_steps=num_time_steps,
    title=f"Input intensity: {input_intensity}",
    legend_labels=np.arange(network_size),
    show_plot=True,
)

In [None]:
# Magnitude spectrum of neuron 4's rate (Blackman-windowed to limit leakage),
# keeping only the non-negative-frequency half.
weights = blackman(num_time_steps)
xs_f = fftfreq(num_time_steps, dt_s)[: num_time_steps // 2]
ys = history_rates[:, 4]
ys_fft = np.abs(fft(ys * weights)[: num_time_steps // 2])

# Cubic interpolation on a 20x finer frequency grid for sub-bin peak locations.
interp_fun = interp1d(xs_f, ys_fft, kind="cubic")
xs_f_interp = np.linspace(0, xs_f[-1], 20 * len(xs_f))
ys_fft2 = interp_fun(xs_f_interp)

# Aliases read by the plotting cell below.
xs_f_temp, xs_f_temp2 = xs_f, xs_f_interp

# Local maxima (indices) of the raw and interpolated spectra.
peaks, peaks2 = (find_peaks(spectrum)[0] for spectrum in (ys_fft, ys_fft2))

In [None]:
# Compare the raw spectrum with its cubic interpolation and refine the dominant
# frequency: take the strongest raw-FFT peak, then the interpolated peak closest to it.
# Uses the spectra and peak indices computed in the previous cell.
assert len(peaks) > 0 and len(peaks2) > 0, "no spectral peaks found; check the simulated signal"

coarse_peak = peaks[ys_fft[peaks].argmax()]  # strongest peak on the raw FFT grid
fine_peak = peaks2[ys_fft2[peaks2].argmax()]  # strongest peak on the interpolated grid
coarse_freq = xs_f_temp[coarse_peak]
# Interpolated peak nearest the strongest raw peak; computed once and reused below.
refined_peak = peaks2[np.abs(xs_f_interp[peaks2] - coarse_freq).argmin()]
temp_freq = xs_f_temp2[refined_peak]

fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(xs_f_temp, ys_fft, lw=3, zorder=0, label="raw |FFT|")
ax.scatter(xs_f_temp[coarse_peak], ys_fft[coarse_peak], color="green", s=100, label="strongest raw peak")
ax.scatter(xs_f_temp[peaks], ys_fft[peaks], color="red", s=40, label="raw peaks")
ax.plot(xs_f_temp2, ys_fft2, zorder=0, label="cubic interpolation")
ax.scatter(xs_f_temp2[peaks2], ys_fft2[peaks2], color="black", s=20, label="interpolated peaks")
ax.scatter(xs_f_temp2[fine_peak], ys_fft2[fine_peak], color="blue", label="strongest interpolated peak")
ax.scatter(xs_f_temp, ys_fft, marker="x", s=100, color="y", label="FFT bins")
ax.scatter(temp_freq, ys_fft2[refined_peak], marker=3, s=100, c="orange", label="refined estimate")
ax.set_yscale("log")
ax.set_xlim(0.3, 1.3)
ax.set(
    xlabel="Frequency (Hz)",
    ylabel="|FFT| (Blackman-windowed)",
    title=f"Neuron 4 spectrum; refined peak at {temp_freq:.4f} Hz",
)
ax.legend()
plt.show()

#### Frequency analysis as function of input strength

In [None]:
# Sweep the input intensity for the five-unit non-linear network and record the
# frequency reported for each neuron.
input_intensities = np.linspace(1, 12, 30)

simulation_time_s = 6 * np.pi
num_time_steps = int(simulation_time_s / dt_s)
# The effective weight matrix does not depend on the loop variables; build it once.
effective_weights = network_weights_designed + np.eye(network_size)

# Collect one small frame per intensity and concatenate once at the end.
# Growing a DataFrame with pd.concat inside the loop copies quadratically.
frequency_frames = []
for input_intensity in input_intensities:
    current_rates = np.zeros(network_size)
    history_rates = []
    for _ in range(num_time_steps):
        current_rates = run_dynamics_step(
            current_rates=current_rates,
            network_input=network_input * input_intensity,
            weights=effective_weights,
            taus=taus_s,
            dt=dt_s,
            noise=0,
            activations=lambda x: np.maximum(np.tanh(x / 10) * 10, 0),
        )
        history_rates.append(current_rates)
    history_rates = np.array(history_rates)
    temp_indi_freq = compute_individual_neuron_frequency(
        history_rates,
        dt=dt_s,
        num_steps=num_time_steps,
        use_interpolation=True,
    )
    # One row per neuron; the Series index becomes the "index" column after reset_index.
    frequency_frames.append(
        pd.Series(temp_indi_freq, name="frequency")
        .to_frame()
        .assign(input_intensity=input_intensity)
    )
individual_neuron_frequencies = pd.concat(frequency_frames, axis=0)
# Drop entries where the helper returned a dict instead of a frequency value
# (presumably its "no clean frequency" output -- TODO confirm in network_dynamics).
has_frequency = individual_neuron_frequencies["frequency"].apply(
    lambda value: not isinstance(value, dict)
)
individual_neuron_frequencies = individual_neuron_frequencies[has_frequency].reset_index()

In [None]:
# Frequency of each neuron (colour and line width encode the neuron index)
# versus input intensity, five-unit network.
freq_ax = sns.lineplot(
    data=individual_neuron_frequencies,
    x="input_intensity",
    y="frequency",
    hue="index",
    size="index",
    palette="tab10",
    sizes=(2, 10),
)
sns.despine(offset=1, trim=True)
plt.show()

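## Simulate solutions of linear difference equations to gauge periodicity

The powers $\lambda^k$ solve $x_{k+1} = \lambda x_k$. The phase $\arg\lambda$ sets the period ($2\pi / \arg\lambda$ steps), and $|\lambda|$ sets growth or decay.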

In [None]:
# lam**k solves the scalar recurrence x_{k+1} = lam * x_k. The argument of lam
# sets the period in steps; the modulus sets growth or decay.
lam1 = np.sqrt(3) / 2 + 1j / 2  # |lam1| = 1, arg = pi/6 -> period 12 steps
lam2 = 0.5 + 1j / 2  # |lam2| = 1/sqrt(2), arg = pi/4 -> period 8 steps
ks = np.arange(0, 500, 0.1)
powers1 = np.power(lam1, ks)
powers2 = np.power(lam2, ks)

for lam in (lam1, lam2):
    print(np.abs(lam))

# Divide out lam2's geometric decay so only its rotation is visible.
lam2_decay = np.power(np.abs(lam2), ks)

fig, axes = plt.subplots(2, 1, figsize=(20, 4), sharex=True)
for row_ax, part in zip(axes, ("real", "imag")):
    row_ax.plot(ks, getattr(powers1, part))
    row_ax.plot(ks, getattr(powers2, part) / lam2_decay)
    sns.despine(offset=5, trim=True, ax=row_ax)
plt.show()