# About

Implement gradient descent algorithms to process the data directly. The results should align with MATLAB given the parameters are the same. For publishing, the MATLAB script and implementation should be primarily cited.

# Library

In [1]:
from __future__ import division, print_function

%matplotlib inline
# Toggle on/off
# %matplotlib notebook

import os
import numpy as np
import pandas as pd
import scipy.io as sio
from scipy import optimize
import scipy.integrate as integrate
from scipy import stats
from scipy import special
from scipy.spatial import distance
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.transforms as tsfm
import matplotlib.colors as clr
from tqdm.notebook import tqdm
import math
from math import pi


from lib import *

from IPython.display import clear_output

# Directories

In [2]:
foldername = os.path.join(os.getcwd(), 'data', 'arrays', 'results1')
filename_params = os.path.join(foldername, 'parameters.mat') # All parameters and raw data
filename_arrays = os.path.join(foldername, 'arrays.mat') # All plottable arrays

# Import data

### Parameters

In [3]:
dictPar = sio.loadmat(filename_params)

# Raw data
W_raw = dictPar['W_raw']
dist = dictPar['dist']
W = dictPar['W']

# Parameters
N = dictPar['N'][0,0]
vel0 = dictPar['vel0'][0,0]
r0 = np.reshape(dictPar['r0'], -1)
kappa = dictPar['kappa'][0,0]
gamma0 = dictPar['gamma0'][0,0]
eta = dictPar['eta'][0,0]
numIters = dictPar['numIters'][0,0]

# Injury parameters
vel_range = dictPar['vel_range']
velInj = dictPar['velInj']
beta = dictPar['beta'][0,0]
injIndex = dictPar['injIndex'][0,0]
injTime = dictPar['injTime'][0,0]

# Other
tau0 = dist / vel0
iters = np.arange(numIters)

### Processed arrays

In [None]:
dictArr = sio.loadmat(filename_arrays)

# Stability
stab = np.reshape(dictArr['stab'], -1) # Stability (lowerst real eig part) over iters
# ODE solved rates arrays?

# Objective
objective = np.reshape(dictArr['objective'], -1) # Objective over iters

# Heatmaps of gammas

# Histograms of delays, velocities

# Figures

## Objective function

In [None]:
fig, ax = plt.subplots(1, figsize=(12,4), dpi=80)
iters = np.arange(1,numIters+1)
ax.plot(iters, objective)
ax.axvline(injTime * numIters, color='red', zorder=0)
ax.set_xlim(left=0, right=numIters)

## Connectivity and initial matrices

Left (blue) heatmap of connectivity weights (processed), and right (red) heatmap of connection delays

In [None]:
# Positions [left, bottom, width, height] 
bbox0 = [0.05, 0.1, 0.35, 0.80]
bbox1 = [0.43, 0.1, 0.03, 0.80]
bbox2 = [0.55, 0.1, 0.35, 0.80]
bbox3 = [0.93, 0.1, 0.03, 0.80]

fig, ax = plt.subplots(4, figsize=(14,6), dpi=80)

ax[0].set_position(bbox0)
ax[1].set_position(bbox1)
ax[2].set_position(bbox2)
ax[3].set_position(bbox3)

cs1 = ax[0].imshow(W, cmap='Blues')
cs2 = ax[2].imshow(tau0, cmap='Reds')

# Colour bars
numticks1 = 10
numticks2 = 10
ticks1 = np.arange(0, np.max(W), np.round(np.max(W) / numticks1, decimals=1)) # Connectivity strength
ticks2 = np.arange(0, np.max(tau0), np.round(np.max(tau0) / numticks2, decimals=0)) # Delays

fig.colorbar(cs1, cax=ax[1], ticks=ticks1)
fig.colorbar(cs2, cax=ax[3], ticks=ticks2)

# Coincidences over time

Heatmap of coincidence factors at specified timestamps, initially at time = 0s (left), before injury (middle), and post-injury (right).

# Statistics

## Display parameters

### Table 1: Main parameters

In [6]:
var_name1 = ['N', r'Initial velocity $v_0$', 
            r'Scaling factor $\kappa$', 
            r'Myelination rate $\eta$', 
            r'Baseline firing rate $r_i^0$',
            r'Coincidence normalizer $\gamma$']

var_value1 = [N, vel0, kappa, eta, r0[0], gamma0]

var_name1 = np.array(var_name1)
var_value1 = np.array(var_value1)

table1 = pd.DataFrame({'Variable' : var_name1, 'Value': var_value1})

# Display
table1.style

Unnamed: 0,Variable,Value
0,N,68.0
1,Initial velocity $v_0$,1.0
2,Scaling factor $\kappa$,900.0
3,Myelination rate $\eta$,100.0
4,Baseline firing rate $r_i^0$,0.1
5,Coincidence normalizer $\gamma$,1.0


### Table 2: Injury parameters

In [7]:
var_name2 = [r'Total iterations',
             r'$v_0$ uniform sample range',
             r'Rate of injury $\beta$',
             r'Injury index',
             r'Injury time']
var_value2 = [numIters, vel_range, beta, injIndex, injTime*numIters]
table2 = pd.DataFrame({'Variable' : var_name2, 'Value': var_value2})

# Display
table2.style

Unnamed: 0,Variable,Value
0,Total iterations,300
1,$v_0$ uniform sample range,"[0.5, 2.0]"
2,Rate of injury $\beta$,0.05
3,Injury index,0.1
4,Injury time,180


### Table 3: Processed statistics

In [None]:
var_name3 = [r'Number of connections',
             r'Total number of connections',
             r'Mean active connectivity weight',
             r'Maximum objective',
             r'Rate of injury $\beta$',
             r'Injury index',
             r'Injury time']
var_value3 = [np.count_nonzero(W != 0.0), W.size, np.mean(W[W != 0]), np.max(gradObj), beta, injIndex, injTime*numIters]
table3 = pd.DataFrame({'Stat' : var_name3, 'Value': var_value3})

# Display
table3.style

### Set up arrays to plot

In [8]:
# Sample sizes
numTauInj = min(100, int(injIndex * N**2))
numTauNonInj = min(100, int((1-injIndex) * N**2))

# Sample indices to be plotted
injInds = np.where((isInj == 1) * (W != 0))
nonInjInds = np.where((isInj == 0) * (W != 0))

injSample = np.random.choice(injInds[0].size, numTauInj)
nonInjSample = np.random.choice(nonInjInds[0].size, numTauNonInj)

injInds_i = injInds[0][injSample]
injInds_j = injInds[1][injSample]
nonInjInds_i = nonInjInds[0][nonInjSample]
nonInjInds_j = nonInjInds[1][nonInjSample]

### Statistics

In [13]:
num_conns = np.count_nonzero(W != 0.0)
print(f'Number of active connections: {num_conns} out of {W.size}')

Number of active connections: 3726 out of 4624


In [14]:
np.max(gradObj), np.mean(W), np.mean(gamma)

(0.0035603749100393055, 0.03925114746600026, 0.4729467021539464)

In [15]:
s = np.linalg.norm(W)
t = np.max(W)
s, t
np.mean(W[W != 0]), np.mean(W) # Show statistics regarding mean connections + coincidence in the paper as a table.
# Non-negative connections? Number of connections?

(0.04871103217466055, 0.03925114746600026)