# Multiple Ring Neurons Experiment

## Imports

In [90]:
import os
from datetime import datetime as dt

import imageio
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.integrate import solve_ivp
from scipy.special import expit
from scipy.stats import norm

In [72]:
# Tableau palette in matching order: short names (the 'tab:' prefix stripped) for labels,
# and the corresponding color values for plotting.
COLOR_NAMES = [name.replace('tab:', '') for name in mcolors.TABLEAU_COLORS]
COLORS = list(mcolors.TABLEAU_COLORS.values())

# Utilities

In [80]:
def create_frame(ax, xs, ys, pen_label, pen_color, t, detect_intersec, intersec_pts) -> np.ndarray:
    '''
    Draw one animation frame of a doodle onto ``ax`` and optionally mark self-intersections.

    The pen path up to timestep ``t`` is drawn as a line, with the current pen position
    highlighted. When ``detect_intersec`` is true (and t >= 3), the newest segment
    (point t-1 -> point t) is tested against every earlier, non-adjacent segment. Any
    crossings found are appended to ``intersec_pts``, and all accumulated crossings are
    drawn in red.

    :param ax: matplotlib Axes to draw on
    :param xs list: x-coordinates of the pen path, one entry per timestep (list or array)
    :param ys list: y-coordinates of the pen path, same length as ``xs``
    :param pen_label: legend label for the path
    :param pen_color: matplotlib color for the path and the pen marker
    :param t int: the current timestep in the doodle
    :param detect_intersec: whether to look for new self-intersections at this step
    :param intersec_pts: (k, 2) array of intersections found at earlier steps
    :return: ``intersec_pts`` extended with any intersections found at this step, shape (k', 2)

    TODO file naming
    '''
    assert len(xs) == len(ys), "xs and ys shape doesn't match!"
    # Accept plain lists as well as arrays; the intersection math needs numpy arithmetic.
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)

    ax.plot(xs[:t+1], ys[:t+1], color=pen_color, alpha=0.5, label=pen_label)
    ax.scatter(xs[t], ys[t], color=pen_color, alpha=0.8, marker='o')

    # The newest segment can only cross a non-adjacent earlier segment once t >= 3.
    if detect_intersec and t >= 3:
        new_intersec = _new_segment_intersections(xs, ys, t)
        intersec_pts = np.concatenate((intersec_pts, new_intersec), axis=0)
        ax.scatter(intersec_pts[:, 0], intersec_pts[:, 1], color='red', marker='o')

    ax.set_xlim([-2, 2])
    ax.set_xlabel('x', fontsize=14)
    ax.set_ylim([-2, 2])
    ax.set_ylabel('y', fontsize=14)
    ax.set_title(f'Step {t}', fontsize=14)
    ax.legend()

    return intersec_pts


def _new_segment_intersections(xs, ys, t):
    '''Return the (k, 2) crossing points of segment t-1 -> t with segments i -> i+1, i <= t-3.'''
    # point 1 (time t-1) and point 2 (time t) bound the newest segment
    x1, y1 = xs[t-1], ys[t-1]
    x2, y2 = xs[t], ys[t]

    # points 3 (times 0..t-3) and points 4 (times 1..t-2) bound the earlier segments.
    # NOTE: the segment t-2 -> t-1 is skipped because it always touches point 1.
    x3s, y3s = xs[:t-2], ys[:t-2]
    x4s, y4s = xs[1:t-1], ys[1:t-1]

    # https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection#Given_two_points_on_each_line_segment
    numerator_t = ((x1 - x3s) * (y3s - y4s)) - ((y1 - y3s) * (x3s - x4s))
    numerator_u = ((x1 - x3s) * (y1 - y2)) - ((y1 - y3s) * (x1 - x2))
    denom = ((x1 - x2) * (y3s - y4s)) - ((y1 - y2) * (x3s - x4s))

    # Parallel/collinear pairs have denom == 0; they are masked out below, so the
    # resulting inf/nan values (and their warnings) are irrelevant.
    with np.errstate(divide='ignore', invalid='ignore'):
        seg_t = numerator_t / denom
        seg_u = numerator_u / denom

    # the segments cross only where both parameters lie in [0, 1]
    hit = (denom != 0.) & (0 <= seg_t) & (seg_t <= 1) & (0 <= seg_u) & (seg_u <= 1)
    seg_t = seg_t[hit]

    # One (x, y) row per crossing; column_stack keeps each x paired with its own y.
    return np.column_stack((x1 + seg_t * (x2 - x1), y1 + seg_t * (y2 - y1)))

# Model Definition

## Equations

### Equations

Activation: 

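$\frac{dv_i}{dt}=\frac{1}{\tau}\left(-\lambda u_i v_i + I_i' \left(1 - \eta \sum\limits_{\substack{j \in R \\ j \ne i}} z_j\right)\right)$, where $R$ is the set of all units in the ring, so each unit is inhibited by the summed output of every other unit.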

Input Depletion:

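$\frac{dI_i'}{dt} = -\varphi I_i z_i$, where $I_i$ is the initial (undepleted) input, so the depletion rate does not shrink as $I_i'$ is consumed.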

Deactivation: 

$\frac{du_i}{dt}=-\rho u_i + \gamma I_i' \frac{z_i}{c_i + \epsilon}$, where $\epsilon$ is a small constant that keeps the denominator nonzero when $c_i = 0$.

Output: 

$z_i=f(v_i)=\frac{1}{1+e^{-\beta (v_i - \mu)}};$

In [75]:
def sigmoid(v, beta, mu):
    '''
    Logistic output function z = f(v) = 1 / (1 + exp(-beta * (v - mu))).

    Evaluated with scipy.special.expit, which is numerically stable for any argument.
    A naive ``e ** x`` form raises OverflowError for Python-float ``v`` far below ``mu``
    and emits overflow warnings for arrays.

    :param v: activation(s); a scalar or a numpy array
    :param beta: steepness (gain) of the sigmoid
    :param mu: threshold at which the output equals 0.5
    :return: output(s) in [0, 1], same shape as ``v``
    '''
    return expit(beta * (v - mu))

## Neurons

In [76]:
NUM_UNITS = 36
# Preferred direction of each unit, evenly spaced around the ring; kept as a list of degrees for easy plotting.
directions_deg = [k * 360 / NUM_UNITS for k in range(NUM_UNITS)]
# The same directions in radians, as a numpy array of shape (NUM_UNITS,).
directions_rad = np.deg2rad(directions_deg)
# Unit heading vector (cos, sin) per direction, shape (NUM_UNITS, 2); used later to turn unit outputs into pen motion.
headings = np.column_stack((np.cos(directions_rad), np.sin(directions_rad)))

# Driver Code

In [81]:
def doodle(t, state, p):
    '''
    Right-hand side of the ring-neuron ODE system, in the form expected by solve_ivp().

    solve_ivp() needs a flat 1-d state vector, so the three per-unit vectors are packed
    end to end: state = [v, u, I'], each of length p['N']. They are unpacked here, their
    time derivatives are computed separately, and the derivatives are packed back in the
    same order.

    :param t: current time; not used by these equations, but required by solve_ivp()
    :param state: flat array of length 3 * p['N'] holding [v, u, I']
    :param p: parameter dict; reads 'N', 'beta', 'mu', 'eta', 'tau', 'lambda', 'rho',
              'gamma', 'c', 'epsilon', 'phi' and 'I'
    :return: flat array of length 3 * p['N'] holding [dv/dt, du/dt, dI'/dt]
    '''
    n = p['N']
    v = state[0:n]
    u = state[n:2*n]
    I_prime = state[2*n:]

    assert set([v.shape[0], u.shape[0], I_prime.shape[0]]) == set([n]), f"State's shapes don't match! {v.shape, u.shape, I_prime.shape}"

    z = sigmoid(v, p['beta'], p['mu'])

    # Lateral inhibition: each unit is inhibited by the summed output of all *other* units.
    # sum_{j != i} z_j == sum(z) - z_i, which avoids building an N x N matrix on every call.
    others_output = z.sum() - z
    inhibition_vec = 1 - (p['eta'] * others_output)

    # calculate dv/dt, du/dt, dI'/dt
    dv = (1 / p['tau']) * ((-1 * p['lambda'] * u * v) + (I_prime * inhibition_vec))
    du = (-1 * p['rho'] * u) + (p['gamma'] * I_prime * z / (p['c'] + p['epsilon']))
    # Depletion is driven by the initial input p['I'] (not the current I'), as in dI'/dt = -phi * I * z.
    dI_prime = -1 * p['phi'] * p['I'] * z

    # pack the derivatives back in the same [v, u, I'] order
    return np.concatenate((dv, du, dI_prime))

### Plot all variables

In [84]:
# Run one simulation of the ring network, save a plot of the selected state variables,
# then render the pen path produced by the unit outputs as an animated GIF.
# Everything is written under output/<timestamp>/. Paths are built with os.path.join so
# they are valid on any OS (hard-coded '\\' separators only work on Windows).
dt_string = str(dt.now()).replace(':', '').replace('.','')
print(f'Datetime string: {dt_string}')
folder_name = os.path.join('output', dt_string)
img_folder = os.path.join(folder_name, 'img')
# creates 'output' if needed; still fails if this run's folder already exists
os.makedirs(folder_name)

# which state variables to draw in the summary plot
plot_v = False
plot_u = False
plot_I_prime = False
plot_z = True

t_max = 40
t_steps = 400
t = np.linspace(0, t_max, t_steps)

# random per-unit external input I and per-unit constant c of the deactivation equation
I = np.random.rand(NUM_UNITS)
c = np.random.rand(NUM_UNITS)

params = {
    'N': NUM_UNITS,
    'tau': 1.0,
    'lambda': 20,
    'eta': 1.2,
    'I': I,
    'rho': 0.1,
    'gamma': 0.1,
    'c': c,
    'epsilon': 0.00001,
    'beta': 50.0,
    'mu': 0.1,
    'phi': 0.5,
    'alpha': 0.9
}

params_df = pd.DataFrame(params, columns=list(params.keys()))

# initial state: v = 0, u = 0, I' = I
v = np.zeros(params['N'])
u = np.zeros(params['N'])

state = np.array((v, u, I)).reshape(3*params['N'])
result = solve_ivp(fun=lambda t, state: doodle(t, state, params), t_span=(min(t), max(t)), t_eval=t, y0=state)
# result.y holds one row per state component and one column per time in t
v_series = result.y[:params['N'],]
z_series = sigmoid(v_series, params['beta'], params['mu'])
u_series = result.y[params['N']:2*params['N'],]
I_prime_series = result.y[2*params['N']:,]

fig, axs = plt.subplots()

# draw units in order of increasing input I
for i in np.argsort(I):
    color = COLORS[i % len(COLORS)]
    if plot_v: plt.plot(t, v_series[i], label=f'v_{i}', c=color, linestyle='dashed')
    if plot_u: plt.plot(t, u_series[i], label=f'u_{i}', c=color, linestyle='dotted')
    if plot_I_prime: plt.plot(t, I_prime_series[i], label=f"I'_{i}", c=color, linestyle='dashdot')
    if plot_z: plt.plot(t, z_series[i], label=f'z_{i}', c=color, linestyle='solid')
# zero reference line, drawn once rather than once per unit
plt.axhline(y=0.0, c="black", linewidth=0.05)

plt.ylim([0, 1])
plt.xlabel('t')

fig.savefig(os.path.join(folder_name, f'plot_{dt_string}'))
plt.close(fig)

# draw output
# raw pen direction per timestep: output-weighted sum of unit headings, shape (t_steps, 2)
dir_series = z_series.T @ headings # does not include momentum

# kernel [0, alpha, alpha^2, ..., alpha^(t_steps-1)]
alphas = np.cumprod([params['alpha']] * (t_steps - 1))
alphas = np.array([0] + list(alphas))

# The momentum recurrence m_n = alpha * m_{n-1} + (1 - alpha) * d_n (with m_{-1} = 0)
# unrolls to (1 - alpha) * sum_{k >= 0} alpha^k * d_{n-k}, i.e. the current term plus a convolution.
# convolution(N, M) gives a result of n + m - 1 elements. we only need the first t_steps
dir_series_with_momentum_x = (1 - params['alpha']) * dir_series[:, 0] + (1 - params['alpha']) * np.convolve(dir_series[:, 0], alphas)[:t_steps]
dir_series_with_momentum_y = (1 - params['alpha']) * dir_series[:, 1] + (1 - params['alpha']) * np.convolve(dir_series[:, 1], alphas)[:t_steps]

# scale x and y distances by 1/10 to keep drawings on the page
xs_with_momentum = (1 / 10) * dir_series_with_momentum_x
ys_with_momentum = (1 / 10) * dir_series_with_momentum_y

# pen position is the running sum of per-step displacements
x_series_with_momentum = np.cumsum(xs_with_momentum)
y_series_with_momentum = np.cumsum(ys_with_momentum)

# create GIF
os.makedirs(img_folder, exist_ok=True)

frames = []
intersections = np.empty((0, 2))

# `step` (not `t`) so the time array `t` above is not overwritten
for step in range(t_steps):
    f, axs = plt.subplots()

    intersections = create_frame(axs, x_series_with_momentum, y_series_with_momentum, pen_color='black', pen_label='testing', t=step, detect_intersec=True, intersec_pts=intersections)

    frame_path = os.path.join(img_folder, f'img_{step}.png')
    f.savefig(frame_path)
    plt.close(f)
    frames.append(imageio.v2.imread(frame_path))

# NOTE(review): whether `duration` means seconds or milliseconds per frame depends on the
# installed imageio version / GIF plugin — confirm t_max is the intended value.
imageio.mimsave(os.path.join(folder_name, f"GIF_{dt_string}.gif"), frames, duration=t_max)

Datetime string: 2023-12-01 122856174661
