In [320]:
# train to detect a single input spike train
import math
import cmath
import numpy as np
from numpy import linalg

In [352]:
# Each training pattern is a vector of spike times in [0, T] (one entry per
# input neuron); the different vectors are noisy versions of the same pattern.
T = 1                      # length of the time window
V_target = 1.2             # detector output requested for examples of the pattern
w = math.pi / (2 * T)      # angular frequency: maps times in [0, T] onto phases in [0, pi/2]
regularizer = 0.01         # ridge strength used by regressionWithRegularizer


def addRandomPatterns(complexPattern, numNoisy):
    """Return the patterns followed by `numNoisy` random counter-example patterns.

    Each counter-example has the same length as the first pattern. It is built
    from uniformly random spike times in [0, T) and phase-encoded with
    numbersToPhases, so it uses the same representation as the real patterns.

    Parameters
    ----------
    complexPattern : list
        Phase-encoded patterns (one sequence per example). Not modified.
    numNoisy : int
        Number of random counter-examples to append.

    Returns
    -------
    list
        A new list: the original patterns followed by the counter-examples.

    Raises
    ------
    ValueError
        If counter-examples are requested but `complexPattern` is empty, since
        the pattern length cannot be inferred.
    """
    augmented = list(complexPattern)  # copy so the caller's list is not mutated
    if numNoisy <= 0:
        return augmented
    if not complexPattern:
        raise ValueError("cannot infer pattern length from an empty pattern list")
    patternLength = len(complexPattern[0])
    for _ in range(numNoisy):
        # Random spike times, encoded like the real data (not raw reals in [0, 1)).
        randomTimes = np.random.random(patternLength) * T
        augmented.append(numbersToPhases(randomTimes))
    return augmented

def numbersToPhases(numbersVec):
    """Encode spike times as complex phasors.

    Each time t becomes exp(i*w*t). Entries with t < 0 are multiplied by 0,
    so they are encoded as 0 (presumably meaning "no spike").

    Returns a list with one complex number per input entry.
    """
    phasors = []
    for spikeTime in numbersVec:
        spikeMask = int(spikeTime >= 0)
        phasors.append(spikeMask * cmath.exp(spikeTime * w * 1j))
    return phasors

def phasesToNumbers(complexVec):
    """Convert complex weights back into delays in [0, T).

    The phase of each entry is wrapped into [0, w*T). This equals [0, pi/2),
    the phase span numbersToPhases uses for one time window [0, T]. The
    wrapped phase is then divided by w, so a weight with phase -w*t decodes to
    the delay T - t.

    Parameters
    ----------
    complexVec : iterable of complex
        Complex weights (e.g. the regression result).

    Returns
    -------
    list of float
        One delay per entry, in the same time units as T.
    """
    # Wrap by the phase span of a full window. Wrapping by `w` alone only
    # coincides with this when T == 1.
    phaseWindow = w * T
    return [(cmath.phase(c) % phaseWindow) / w for c in complexVec]

def regressionWithPseudoInv(dataMtx, ouptutVec):
    """Least-squares weights via the classical (Moore-Penrose) pseudoinverse.

    Returns pinv(dataMtx) @ ouptutVec, i.e. the minimum-norm least-squares
    solution of dataMtx @ W ~= ouptutVec. Complex data is supported.
    """
    return linalg.pinv(dataMtx) @ ouptutVec

def regressionWithRegularizer(dataMtx, ouptutVec, reg=None):
    """Ridge (Tikhonov-regularized) least-squares weights.

    Solves (A^H A + reg * I) W = A^H y, where A^H is the conjugate transpose
    of `dataMtx`. The penalty on large weights prevents weight explosion.

    Parameters
    ----------
    dataMtx : array_like, shape (examples, neurons)
        One row per example; may be complex.
    ouptutVec : array_like, shape (examples,)
        Target output for each example.
    reg : float, optional
        Regularization strength; defaults to the module-level `regularizer`.

    Returns
    -------
    numpy.ndarray, shape (neurons,)
        The regularized weight vector.
    """
    if reg is None:
        reg = regularizer
    dataMtx = np.asarray(dataMtx)
    # Conjugate transpose: the plain transpose gives wrong normal equations
    # for complex (phase-encoded) data.
    dataMtxH = dataMtx.conj().T
    # Size the identity from the data instead of a global neuron count.
    # `+ reg*I` (not minus) keeps the system positive definite.
    gram = dataMtxH @ dataMtx + reg * np.eye(dataMtx.shape[1])
    # Solving is numerically more stable than forming the explicit inverse.
    return linalg.solve(gram, dataMtxH @ np.asarray(ouptutVec))
        
def recoverDelaysOnePattern(patternsToDetect, numNoisePatterns=None):
    """Fit complex detector weights for one spike pattern and decode them.

    Parameters
    ----------
    patternsToDetect : list of sequences of float
        Noisy spike-time examples of the pattern to detect (one row per example).
    numNoisePatterns : int, optional
        Number of random counter-examples (target output 0) to add. Defaults
        to the module-level `numberNoisePatterns`.

    Returns
    -------
    (delays, weights) : tuple of lists
        Per-neuron delays decoded from the weight phases, and the weight
        magnitudes (interpreted as neuron reliability).
    """
    if numNoisePatterns is None:
        numNoisePatterns = numberNoisePatterns
    countPatterns = len(patternsToDetect)

    # Encode spikes as complex numbers
    complexPattern = [numbersToPhases(singlePattern) for singlePattern in patternsToDetect]

    # Add random examples which do NOT correspond to the pattern
    complexPatternWithCounterExamples = addRandomPatterns(complexPattern, numNoisePatterns)

    # Real examples should produce V_target; counter-examples should produce 0.
    # Sized from the actual data so it always matches the number of rows.
    detectPattern = [V_target if n < countPatterns else 0.0
                     for n in range(len(complexPatternWithCounterExamples))]

    # To numpy format (`np`; the bare name `numpy` is never imported)
    complexData = np.array(complexPatternWithCounterExamples)
    outputPattern = np.array(detectPattern)

    # Obtain the complex weights
    resultW = regressionWithPseudoInv(complexData, outputPattern)
    # Complex weights -> delays (phase) and weights (magnitude = neuron reliability)
    delays = phasesToNumbers(resultW)
    weights = [abs(weight) for weight in resultW]

    return (delays, weights)

In [355]:
# Toy-example configuration
numberNeurons = 10         # input neurons (entries per spike pattern)
# WARNING (author's note): numberNeurons < numberExamples => very high weights
# (but delays ok); degenerated inverse
numberExamples = 100       # noisy examples of the target pattern
numberNoisePatterns = 0    # random counter-examples; e.g. 10*numberExamples

noSpikeProb = 0.05         # probability that a neuron's spike is dropped
jitter = 0.05*T            # upper bound of the (positive) spike-time jitter


def noisedSpike(cleanTime):
    """Return a noisy copy of one spike time.

    With probability `noSpikeProb` the spike is dropped and a negative time is
    returned, which numbersToPhases encodes as 0 (no input). Otherwise the
    spike is shifted by a uniform jitter in [0, jitter) and clipped to [0, T].
    """
    if np.random.rand() < noSpikeProb:
        # Dropped spike: a negative time means "no spike". Returning 0 would
        # encode a real spike at time 0 (phase 1) rather than a missing one.
        return -1.0
    # NOTE(review): the jitter only delays spikes (one-sided); confirm intended.
    return min(max(0, cleanTime + np.random.rand() * jitter), T)

def noisedPattern(patternDelays):
    """Apply noisedSpike independently to every spike time; returns a list."""
    return list(map(noisedSpike, patternDelays))

def generateInputSpikeTrains(patternDelays):
    """Draw `numberExamples` independent noisy realisations of a pattern."""
    trains = []
    for _ in range(numberExamples):
        trains.append(noisedPattern(patternDelays))
    return trains

In [356]:
# Fixed seed so the experiment (pattern + noise) is reproducible on re-run.
np.random.seed(0)

# Draw a random target pattern and noisy training examples of it.
patternDelays = np.random.rand(numberNeurons)*T
complexDelays = numbersToPhases(patternDelays)

patternsToDetect = generateInputSpikeTrains(patternDelays)

recoveredDelays, weights = recoverDelaysOnePattern(patternsToDetect)

# Ideally each recovered delay is T - spike time, so each sum should be ~T.
totalLatency = [rec + pat for (rec, pat) in zip(recoveredDelays, patternDelays)]

# Performance of the delay recovery (messages now match what is printed:
# a mean for the latencies, a sum for the weights).
meanLatency = sum(totalLatency)/len(totalLatency)
print('The delays + spike timing should average to ' + str(T) + ', and they average to: ' + str(meanLatency))
print('The weights should add up to ' + str(V_target) + ' and they add up to ' + str(sum(weights)))

The delays + spike timing should average to 1, and they add up to: 1.0517369308573552
The weights should average to 1.2 and they add up to 1.2114245086411495


1.4142135623730951


0.10000000000000026

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])