In [797]:
#Problem 2 Retinal Ganglion Vertical Line Detector Network
#Notes: 
#using LIF neuronal representation for ganglion cells, 
#supervised hebbian learning (SHL)
#9x9 images for tractability
import numpy as np
from sklearn import preprocessing
import math

In [798]:
#LIF neuron
class LIF():
    def __init__(self, imageSeg, onCenter):
        '''
        Param imageSeg: image segment number (for a 9x9 image, segment 0 is the 3x3 upper left corner and segment 8 is
        the bottom right 3x3 corner)
        Param onCenter: boolean indicating whether or not the ganglion is on-center off surround; the cell is implicitly
        off-center on-surround if the value is 'False'
        '''
        self.r = 1
        self.c = 10
        self.tau = self.r * self.c
        self.v = -65
        self.rest = -65
        self.i = 10
        self.thresh = -60
        self.imageSegXStart = (imageSeg * 3) % 9
        self.imageSegXEnd = self.imageSegXStart + 2
        if(imageSeg <= 2):
            self.imageSegYStart = 0
            self.imageSegYEnd = 2
        elif(imageSeg > 2 and imageSeg <= 5):
            self.imageSegYStart = 3
            self.imageSegYEnd = 5
        else:
            self.imageSegYStart = 6
            self.imageSegYEnd = 8
        self.onCenter = onCenter
        self.spikeTime = -1
def simulate(ganglion, v, rest, thresh, tau, r, i):
    time = 0
    while(ganglion.spikeTime == -1):
        v += (1/tau) * (-(v - rest) + r * i)    # Membrane potential equation
        if v >= thresh:         #Spike is generated
            ganglion.spikeTime = time
            v = rest
        else:                   #Spike is not generated
            time += 1
    return time

In [799]:
#Images we will use for both training and testing, pixels part of the line will have value 1,
#the rest of the pixels will be 0
#vertical line positive images (positives)
positives = list()
negatives = list()
print("Formatted print of vertical line positive images\n")
for a in range(9):
    vlp_a = list()
    for i in range(9):
        vlp_i = [int(j==a) for j in range(9)]
        print(vlp_i)
        vlp_a.append(vlp_i)
    positives.append(vlp_a)
    print()
#horizontal line images (negatives)
print("Formatted print of horizontal line negative images\n")
for a in range(9):
    hln_a = list()
    for i in range(9):
        hln_i = [int(i==a) for j in range(9)]
        print(hln_i)
        hln_a.append(hln_i)
    negatives.append(hln_a)
    print()
#diagonal line images (negatives)
print("Formatted print of forward slash diagonal line negative images")
for a in range(17):
    dln_a = list()
    for i in range(9):
        dln_i = [int((i+j)==a) for j in range(9)]
        print(dln_i)
        dln_a.append(dln_i)
    negatives.append(dln_a)
    print()
print("Formatted print of back slash diagonal line negative images")
for a in range(17):
    dln_a = list()
    for i in range(9):
        dln_i = [int((8+i-j)==a) for j in range(9)]
        print(dln_i)
        dln_a.append(dln_i)
    negatives.append(dln_a)
    print()
    
custom_negatives = list()
cn_1 = [[0, 0, 1, 0, 0, 0, 0, 0, 0]
       ,[0, 0, 1, 0, 0, 0, 0, 0, 0]
       ,[0, 0, 0, 1, 0, 0, 0, 0, 0]
       ,[0, 0, 0, 1, 0, 0, 0, 0 ,0]
       ,[0, 0, 0, 0, 1, 0, 0, 0, 0]
       ,[0, 0, 0, 0, 1, 0 ,0 ,0 ,0]
       ,[0, 0, 0, 0, 0, 1, 0, 0, 0]
       ,[0, 0, 0, 0, 0, 1, 0, 0, 0]
       ,[0, 0, 0, 0, 0, 0, 1, 0, 0]]
custom_negatives.append(cn_1)
cn_2 = list()
for i in range(9):
    #reverse image
    cn_2.append(cn_1[8-i])
custom_negatives.append(cn_2)
cn_3 = [[0, 0, 0, 1, 0, 0, 0, 0, 0]
       ,[0, 0, 0, 1, 0, 0, 0, 0, 0]
       ,[0, 0, 0, 1, 0, 0, 0, 0, 0]
       ,[0, 0, 0, 0, 1, 0, 0, 0, 0]
       ,[0, 0, 0, 0, 1, 0, 0, 0, 0]
       ,[0, 0, 0, 0, 1, 0, 0, 0, 0]
       ,[0, 0, 0, 0, 0, 1, 0, 0, 0]
       ,[0, 0, 0, 0, 0, 1, 0, 0, 0]
       ,[0, 0, 0, 0, 0, 1, 0, 0, 0]]
custom_negatives.append(cn_3)
cn_4 = list()
for i in range(9):
    cn_4.append(cn_3[8-i])
custom_negatives.append(cn_4)

Formatted print of vertical line positive images

[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0]

[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0]

[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0]

[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 0, 0, 0, 0

In [800]:
#ganglions in the first layer layered sequentially row-wise
ganglions = list()
for i in range(9):
    if(i%2 == 0):
        ganglions.append(LIF(i, True))
    else:
        ganglions.append(LIF(i, False))

i = 0
for ganglion in ganglions:
    print('Ganglion ' + str(i) + ' Information:')
    if(ganglion.onCenter):
        print('on center, off surround')
    else:
        print('off center, on surround')
    print('image segment x start: ' + str(ganglion.imageSegXStart))
    print('image segment x end: ' + str(ganglion.imageSegXEnd))
    print('image segment y start: ' + str(ganglion.imageSegYStart))
    print('image segment y end: ' + str(ganglion.imageSegYEnd))
    print()
    i += 1

Ganglion 0 Information:
on center, off surround
image segment x start: 0
image segment x end: 2
image segment y start: 0
image segment y end: 2

Ganglion 1 Information:
off center, on surround
image segment x start: 3
image segment x end: 5
image segment y start: 0
image segment y end: 2

Ganglion 2 Information:
on center, off surround
image segment x start: 6
image segment x end: 8
image segment y start: 0
image segment y end: 2

Ganglion 3 Information:
off center, on surround
image segment x start: 0
image segment x end: 2
image segment y start: 3
image segment y end: 5

Ganglion 4 Information:
on center, off surround
image segment x start: 3
image segment x end: 5
image segment y start: 3
image segment y end: 5

Ganglion 5 Information:
off center, on surround
image segment x start: 6
image segment x end: 8
image segment y start: 3
image segment y end: 5

Ganglion 6 Information:
on center, off surround
image segment x start: 0
image segment x end: 2
image segment y start: 6
image seg

In [801]:
#learning cycle that implements parallel SHL referenced in COMPARISON OF SUPERVISED LEARNING METHODS FOR SPIKE TIME
#CODING IN SPIKING NEURAL NETWORKS Section 2.6 (Kasinski et al.)
def learn_cycle(image, eta, td, ganglions, w0):
    '''
    First integrates spatial excitability of a ganglion into the current based on image used, then starts the learning cycle
    which consists of simulating the ganglions until they spike, then updating weights using the parallel algorithm. This method
    is to be used iteratively to update weights of all synapses in parallel. 
    
    NOTE: The learning cycles converge weights to an optimal weights vector notated 'wd' in the paper cited above, the parallel 
    weight algorithm does not converge to the desired postsynaptic firing time
    
    Param image: training image to use, one learning cycle corresponds to one image used
    Param eta: learning rate
    Param td: desired spike time of postsynaptic neuron
    Param ganglions: retinal ganglion cells
    Param w0: random initial asymmetric weights vector
    '''
    #integrate currents into ganglion cells
    integrateCurrents(image, ganglions)
    #now that currents have been initialized for each ganglion, simulate them and get spike times
    i=0
    spike_times = list()
    for ganglion in ganglions:
        simulate(ganglion, ganglion.v, ganglion.rest, ganglion.thresh, ganglion.tau, ganglion.r, ganglion.i)
        spike_times.append(ganglion.spikeTime)
        #update ganglion weight
        w0[i][0] += eta*(td - ganglion.spikeTime)
        i+=1
    #normalize weight vector
    w0 = preprocessing.normalize(w0, axis=0)
    #reset current of ganglions
    resetCurrent(ganglions)
    return w0
def integrateCurrents(image, ganglions):
    '''
    Integrates currents into ganglion cells
    Param ganglions: ganglion cells
    '''
    #integrate currents into ganglion cells
    for ganglion in ganglions:
        if(ganglion.onCenter):
            ganglion.i += 5 * image[int((ganglion.imageSegYStart + ganglion.imageSegYEnd)/2)][int((ganglion.imageSegXStart  + ganglion.imageSegXEnd)/2)]
        else:
            subArea = image[ganglion.imageSegYStart:ganglion.imageSegYEnd + 1]
            for i in range(len(subArea)):
                jStart = ganglion.imageSegXStart
                middle = int(ganglion.imageSegXStart + ganglion.imageSegXEnd/2)
                jEnd = ganglion.imageSegXEnd
                while(jStart <= jEnd):
                    if(jStart != middle):
                        ganglion.i += 5 * subArea[i][jStart]
                    jStart += 1
                    
def resetCurrent(ganglions):
    '''
    Convenience function to reset the current of ganglion cells after learning cycles or input simulations
    '''
    for ganglion in ganglions:
        ganglion.i = 10

def simulateNetwork(ganglions, image, weights):
    '''
    Simulates the network given an image and returns spiking time of the output neuron
    Param ganglions: retinal ganglion cells
    Param image: validation image to be used
    Param weights: synaptic efficacies of ganglions to output neuron
    '''
    integrateCurrents(image, ganglions)
    #now that currents have been initialized for each ganglion, simulate them and get spike times
    i=0
    spike_times = list()
    for ganglion in ganglions:
        simulate(ganglion, ganglion.v, ganglion.rest, ganglion.thresh, ganglion.tau, ganglion.r, ganglion.i)
        spike_times.append(ganglion.spikeTime)
    #output neuron will integrate inputs from ganglions as a weighted average of current
    output_neuron = LIF(0, None) #LIF neuron as output (not a retinal ganglion cell)
    currents = list()
    i=0
    for ganglion in ganglions:
        currents.append(4 * ganglion.i * weights[i][0]) #4 is a scalar here to bump up the current so that output neuron fires quicker
        i+=1
    avg_i = np.average(currents)
    output_neuron.i = avg_i
    time_step = spike_times[np.argmax(spike_times, axis=0)]
    time = simulate(output_neuron, output_neuron.v, output_neuron.rest, output_neuron.thresh, output_neuron.tau,
                   output_neuron.r, output_neuron.i)
    resetCurrent(ganglions)
    return time_step + time

In [802]:
#initialize weights
weights = np.random.rand(9, 1)
print(weights)

[[0.94070315]
 [0.94044255]
 [0.89185725]
 [0.91808665]
 [0.96818267]
 [0.0894343 ]
 [0.24688604]
 [0.78864783]
 [0.16825261]]


In [803]:
#run 1000 learning cycles
import random
for i in range(1000):
    weights = learn_cycle(random.choice(positives), 0.3, 11, ganglions, weights)
weights

array([[0.45958799],
       [0.28724249],
       [0.28724249],
       [0.28724249],
       [0.28724249],
       [0.28724249],
       [0.45958799],
       [0.28724249],
       [0.28724249]])

In [804]:
#aggregate spike time for positives
agg_p = list()
for i in range(9):
    agg_p.append(simulateNetwork(ganglions, positives[i], weights))
agg_p = np.average(agg_p, axis=0)
print(agg_p)

9.11111111111111


In [805]:
#aggregate spike times for negatives
agg_n = list()
for i in range(len(negatives)):
    agg_n.append(simulateNetwork(ganglions, negatives[i], weights))
agg_n = np.average(agg_n, axis=0)
print(agg_n)

9.604651162790697


In [811]:
#model evaluation 100 cases (positive + negative)
accuracy = 0
for i in range(100):
    choice = np.random.randint(0, 2)
    if(choice == 0):
        result = simulateNetwork(ganglions, random.choice(negatives), weights)
        accuracy += int(abs(result - agg_n) < abs(result - agg_p))
    else:
        result = simulateNetwork(ganglions, random.choice(positives), weights)
        accuracy += int(abs(result - agg_p) < abs(result - agg_n))
        
print('Accuracy of model: ' + '~' + str((accuracy/100) * 100) + '%')

Accuracy of model: ~82.0%


In [853]:
#model evaluation 10 cases (positive)
accuracy = 0
for i in range(9):
    result = simulateNetwork(ganglions, random.choice(positives), weights)
    if(abs(result - agg_n) < abs(result - agg_p)):
        print('Not a vertical line')
    else:
        print('vertical line')
        accuracy+=1
        
print('Accuracy of model: ' + '~' + str((accuracy/10) * 100) + '%')

vertical line
vertical line
vertical line
vertical line
vertical line
vertical line
vertical line
vertical line
vertical line
Accuracy of model: ~90.0%
