In [1]:
import numpy as np
import csv
from NNengine import Run_NN, Multilayers

# Training Neural Network
### A very simple example of how the NN is trained on a single data set.

## Load the simulation training data

In [2]:
# Read every CSV row into memory and close the file immediately. Storing the rows as a
# list (rather than a live csv.DictReader, which is a one-shot iterator) lets the
# following cell be re-run without finding an exhausted reader.
# newline='' is what the csv module expects for files it parses.
with open('gt_1a.csv', 'r', newline='') as fn:
    f = list(csv.DictReader(fn))

In [3]:
# Pull the 'stim' and 'voltage' columns out of the CSV rows as float arrays.
# Materialise the rows once: `f` may be a one-shot csv.DictReader, so it must not be
# iterated more than once.
rows = list(f)
stim = np.asarray([float(row['stim']) for row in rows])
voltage = np.asarray([float(row['voltage']) for row in rows])

# Isolate the window where the stimulus is present (non-zero). `stop` is an
# exclusive slice bound (one past the last non-zero sample), so the final
# stimulated sample is included in the window.
stim_on = np.flatnonzero(stim != 0)
assert stim_on.size > 0, "no non-zero stimulus samples found in the data"
start, stop = stim_on[0], stim_on[-1] + 1
test_stim = stim[start:stop]
test_voltage = voltage[start:stop]

# NN for a square wave
#### In this case the pulse has the same amplitude at every sample, so we test our NN with a list of length 1 containing that value. We rescale the stimulus with our activation function, $s \mapsto 1/(1+e^{s})$, so that it can be compared against the NN's output. The voltage trace is min–max normalized to $[0, 1]$ to improve NN accuracy.

In [4]:
# Re-derive everything from the raw arrays instead of rebinding test_stim from itself,
# so re-running this cell cannot apply the rescaling twice.
pulse = stim[start:stop]
assert len(set(pulse)) == 1, "expected a constant-amplitude (square-wave) pulse"

# A square wave is fully described by one amplitude, so the target is a length-1 list
# holding that value, rescaled with s -> 1/(1 + e^s) (inverted later via log(1/p - 1)).
test_stim = [1 / (1 + np.exp(s)) for s in pulse[:1]]

# Min-max normalise the voltage trace over the stimulus window to [0, 1].
test_voltage = voltage[start:stop]
v_min, v_max = test_voltage.min(), test_voltage.max()
assert v_max > v_min, "voltage trace is flat; min-max normalisation would divide by zero"
test_voltage = (test_voltage - v_min) / (v_max - v_min)

In [5]:
# Input layer: one node per normalised voltage sample in the stimulus window.
nin = len(test_voltage)

# Layer widths: N_HIDDEN nodes in a single hidden layer, then (presumably, per
# Multilayers) an output layer with one node per stimulus value to predict --
# 1 for a square wave, matching the single value in stim_pred.
N_HIDDEN = 300
nouts = [N_HIDDEN, len(test_stim)]

### Initialize the NN object and train it

In [6]:
# Training stopping criteria: at most ITER_LIM iterations; judging by the loss trace,
# training also ends early once the loss drops below ERROR_THRESH.
# Tighten these (more iterations, smaller threshold) to improve accuracy.
ITER_LIM = 50
ERROR_THRESH = 1e-3

# NOTE(review): if Multilayers initialises its weights randomly, seed the RNG before
# this cell so the loss trace is reproducible -- confirm in NNengine.
NN = Multilayers(nin, nouts)
stim_pred, loss = Run_NN(
    NN, test_stim, test_voltage,
    iter_lim=ITER_LIM,
    error_thresh=ERROR_THRESH,
)

loss: 0.025488307559917968
loss: 0.022383437199705918
loss: 0.0194417930082718
loss: 0.01674761625143805
loss: 0.014309637130990548
loss: 0.012129799876530463
loss: 0.010203612217631983
loss: 0.008520919525591223
loss: 0.007066972651243584
loss: 0.0058236570724440025
loss: 0.004770749470874841
loss: 0.003887087073695712
loss: 0.003151567591072604
loss: 0.002543934583211312
loss: 0.0020453370510335467
loss: 0.0016386781601623911
loss: 0.001308784496089472
loss: 0.0010424347308837438
loss: 0.0008282870691499974


In [7]:
# Predicted stimulus, still in the rescaled space used for test_stim.
stim_pred

array([0.15888845])

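#### Mapping the output back to the original stimulus scale, we can see that it is in the ballpark of our original stimulus amplitude, 1.9. To improve accuracy, the stopping criteria would have to be adjusted: increase `iter_lim` and reduce `error_thresh` in `Run_NN()`. In the run shown above, training appears to have stopped after 19 iterations because the loss fell below `error_thresh` (0.001), so reducing `error_thresh` is what lets training continue.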

In [8]:
# Map the prediction back to the original stimulus scale by inverting the rescale
# s -> p = 1/(1 + e^s), i.e. s = log(1/p - 1).
# Clip p into the open interval (0, 1) first: this keeps the log finite even if the
# network outputs exactly 0 or a value >= 1. It replaces adding a small epsilon to the
# denominator, which biased the result and did not guard the p >= 1 case.
EPS = 1e-8
pred_clipped = np.clip(stim_pred, EPS, 1 - EPS)
rescaled = np.log(1 / pred_clipped - 1)
rescaled

array([1.66652182])