# Code for Figures 4 and S4 from "Balanced Excitation and Inhibition are Required for High-Capacity, Noise-Robust Neuronal Selectivity"

## Ran Rubin, Larry Abbott and Haim Sompolinsky

Code by Ran Rubin.


In [None]:
%pylab nbagg

# Balanced vs. Unbalanced Solutions With Spiking Neurons

We wish to use the weights learned for the perceptron and test their performance on a spiking neuron.

We define the neuron's input-output as follows:
* Convert static input $x_i$ to rate: $r_i=Ax_i$. We use $A=30\mathrm{Hz}$ and  $\bar{x}_\mathrm{exc}=\sigma_\mathrm{exc}=1$.
* Input spike trains are drawn randomly from a Poisson process with rate $r_i$ for a duration of $T=200\mathrm{ms}$. 
* The neuronal dynamics are those of a standard LIF neuron with current-based, exponentially decaying synaptic inputs. We use a membrane time constant of $\tau_m=30\mathrm{ms}$, and a synaptic time constant of $\tau_s=10\mathrm{ms}$.

## Functions to Generate Perceptron Input and desired Output

In [None]:
#Distribution of excitatory input activities:
#exponential samples drawn with numpy's default scale, returned as an array of shape `size`
Pex = lambda size: np.random.exponential(size=size)

#Distribution of inhibitory input activities
class Pinh(object):
    '''Callable sampler for the inhibitory input activities.

    Every call draws gamma-distributed samples with shape parameter k,
    multiplies them by A and shifts them by the constant offset dxbar.
    '''
    def __init__(self,A,k,dxbar):
        self.A,self.k,self.dxbar=A,k,dxbar

    def __call__(self,size):
        '''Returns an array of shape `size` holding A*gamma(k)+dxbar samples.'''
        gamma_draws=np.random.gamma(self.k,size=size)
        return self.A*gamma_draws+self.dxbar


def gen_random_patterns(N,f_out,g_ex,alpha,Pex,Pinh):
    '''Generates input patterns and desired output.
    Input parameters:
    -----------------
    N - number of inputs
    f_out - probability of an input pattern to belong to the 'plus' category
    g_ex - fraction of excitatory inputs
    alpha - Number of patterns to create as a fraction of N
    Pex - A callable object to create excitatory activity
    Pinh - A callable object to create inhibitory activity
    
    Pex and Pinh must be callable with signature P(size=(n,P))
    and return an n by P array of input activities, where n is the
    number of excitatory (resp. inhibitory) inputs.
    
    Returns:
    --------
    X - an NxP array of input activity patterns
        (excitatory rows first, inhibitory rows last)
    y - a Px1 array of desired labels (+-1) with an avg. number of +1 of P*f_out
    P - the number of patterns created, int(alpha*N)
    g - an Nx1 array of the input afferent type: 
        +1 for an excitatory input and -1 for an inhibitory input.    
    '''
    #This leads to phi=lambda=sqrt(k)
    P=int(alpha*N)
    N_ex=int(g_ex*N)
    N_inh=N-N_ex
    X_ex=Pex(size=(N_ex,P))
    X_inh=Pinh(size=(N_inh,P))
    X=np.vstack((X_ex,X_inh))
    #Each pattern is labelled +1 with probability f_out and -1 otherwise.
    #The builtin float is used because the np.float alias was removed from NumPy.
    y=2.*((np.random.rand(P,1)<f_out).astype(float))-1.
    g=np.ones((N,1))
    g[N_ex:]=-1.
    return X,y,P,g


## Functions to Generate Perceptron Solutions

We formulate the problem as a conic programming problem and use a standard solver (CVXOPT, http://cvxopt.org/) to find the solutions. 

In [None]:
#sys is imported explicitly instead of relying on %pylab to leak it into the namespace
import sys
#make the local solver module (max_kappa) importable
sys.path.append('../QuadProg/')
from max_kappa import sign_constrained_perceptron_max_kappa_out, \
                      sign_constrained_perceptron_max_kappa_in

# IpyParallel

We use IPython Parallel (https://github.com/ipython/ipyparallel) to parallelize the simulations on a cluster. 
However, the same code will run just fine locally on a multi-core computer. One just needs to set up the cluster first.

In [None]:
from ipyparallel import Client
import ipyparallel as parallel

ipyp_profile='default'

In [None]:
#Connecting a client to the cluster (which you should set up before)
c = Client(profile=ipyp_profile)
#view on all engines: used below to push data and execute code on every engine
dview=c[:]
#load-balanced view: used below to map the pattern presentations over the engines
lbview=c.load_balanced_view()
print(len(lbview)) #number of connected engines

## Remote imports

In [None]:
import os
cwd=os.getcwd()
dview['cwd']=cwd

In [None]:
%%px
import numpy as np
from numpy import *

# For the LIF simulation I use my own C++ extension library, which implements a
# highly efficient event-based simulation of the LIF neuron developed by Robert Gutig 
# and myself.
# The library was written to implement learning in spiking neurons (see my previous work) 
# and contains several dynamic and learning models. 
# However, here we only use it to generate spiking inputs from rate inputs and to simulate the LIF dynamics.
import sys
#sys.path.append('path to where the SpikingTempotron package resides')
sys.path.append(cwd)

import SpikingTempotron as ST

## Random seeds

In [None]:
#Setting a different random seed for each engine.
#Note that reusing the same seeds will not guarantee exactly the same results
#since we do not explicitly control which job will be sent to which engine
import time
#Base seeds: the local kernel is seeded with numpy_seed itself,
#the engines are handed base+1, base+2, ... (one value per connected engine)
numpy_seed=23549264
cpp_seed=98982967
np.random.seed(numpy_seed)
#scatter distributes each seed array over the engines; every engine then
#seeds itself with the first entry of the piece it received
dview.scatter('numpy_seed',numpy_seed+arange(1,len(lbview)+1))
dview.scatter('cpp_seed',cpp_seed+arange(1,len(lbview)+1))
r1=dview.execute('np.random.seed(numpy_seed[0])')
r2=dview.execute('ST.setRNGenSeed(int(cpp_seed[0]))') #this is the RNG inside the C++ extension library
#poll once per second until both remote seeding calls have finished
while not(r1.done() and r2.done()):
    time.sleep(1)
#A (True,True) output implies success
(array(r1.status)=='ok').all(),(array(r2.status)=='ok').all()

## Local parameters

In [None]:
N=1000 # Number of input afferents
g_ex=0.8 # Fraction of excitatory inputs
f_out=0.1 # Fraction of 'plus' patterns
#Parameters for inhibitory activity distribution (used as Pinh(A_inh,k,dxbar) in run_sim)
k=2. 
dxbar=.0
A_inh=sqrt(2.)

alpha=1. # Load (run_sim generates int(alpha*N) patterns)
P=1000 # Number of patterns (this value is pushed to the engines; run_sim uses its own local P)
Gamma = 1.5 # Maximal norm of weights

## Remote defs

In [None]:
#sending N and P to all engines
dview['N']=N
dview['P']=P

In [None]:
%%px 
#Neuron's parameters:
#time in sec
tau_m = 0.03 #membrane time constant
tau_s = 0.01 #synaptic time constant
T = 0.2 #duration of each pattern presentation
LIFtheta = 1.

#PSP_Kernel (Double exponential kernel to be used inside the LIF neuron)
K = ST.PSP_Kernel(tau_m,tau_s)
K.T = T

#Units of x:
A = 30 #in Hz

#noise parameters
#Parameters for low noise conditions
#sigma_noise_w is zero here, so set_weights gives the noise afferent a zero weight
N_noise = 1
sigma_noise_w = 0./sqrt(N_noise)

#An LIF neuron that will test the max kappa_in solution 
#(N learned inputs plus N_noise noise inputs)
tempo_in = ST.SpikingTempotron(N+N_noise,1.,0.,LIFtheta)
tempo_in.K = K

#An LIF neuron that will test the max kappa_out solution
tempo_out = ST.SpikingTempotron(N+N_noise,1.,0.,LIFtheta)
tempo_out.K = K

#A function to set the weights of the LIF neurons
def set_weights():
    '''
    Loads the weight vectors into both LIF neurons.
    Each neuron gets its learned weights (w_in / w_out, read from the
    engine's globals) followed by N_noise freshly drawn Gaussian weights
    scaled by sigma_noise_w.
    '''
    def _with_noise(w_learned):
        return hstack([w_learned,sigma_noise_w*random.randn(N_noise)])
    tempo_in.w = _with_noise(w_in)
    tempo_out.w = _with_noise(w_out)

## Functions to activate LIF neurons with input pattern $\mu$

In [None]:
#These functions are defined locally but will be sent to remote engines.
#Thus they refer to global varaibles that are defined on remote engines.

def activate_in(mu):
    '''
    Presents input pattern mu to the max kappa_in LIF neuron.
    Returns the number of output spikes.

    Relies on globals that live on the remote engines
    (X, A, T, N, N_noise, K, tempo_in and the ST module).
    '''
    # Draw Poisson input spike trains; the rates are the pattern's activities scaled by A
    rates = X[:,mu]*A
    input_spikes = ST.CPPgeneratePoissonSpikeTrian(rates,T,N,N_noise)
    # Pre-calculate exponents of input spike times
    input_spikes.calcExponents(K)
    # Load the spike train into the max kappa_in neuron and reset it
    tempo_in.setPattern(input_spikes)
    tempo_in.restart()
    # Actual event-based LIF simulation
    tempo_in.activate_no_teacher()
    # crossings() is an array with the output spike times of this activation
    output_spike_times = tempo_in.crossings()
    return len(output_spike_times)
def activate_out(mu):
    '''
    Presents input pattern mu to the max kappa_out LIF neuron.
    Returns the number of output spikes.

    Relies on globals that live on the remote engines
    (X, A, T, N, N_noise, K, tempo_out and the ST module).
    '''
    # Draw Poisson input spike trains; the rates are the pattern's activities scaled by A
    rates = X[:,mu]*A
    input_spikes = ST.CPPgeneratePoissonSpikeTrian(rates,T,N,N_noise)
    # Pre-calculate exponents of input spike times
    input_spikes.calcExponents(K)
    # Load the spike train into the max kappa_out neuron and reset it
    tempo_out.setPattern(input_spikes)
    tempo_out.restart()
    # Actual event-based LIF simulation
    tempo_out.activate_no_teacher()
    # crossings() is an array with the output spike times of this activation
    output_spike_times = tempo_out.crossings()
    return len(output_spike_times)

## Functions to perform simulation

In [None]:
#Suppress detailed output of optimization
import cvxopt
cvxopt.solvers.options['show_progress']=False
import time

def get_sol(X,y,g,Gamma):
    '''
    Computes the max kappa_in and the max kappa_out solutions.

    Input parameters:
    -----------------
    X - an NxP input patterns array
    y - a P vector of desired labels (+-1)
    g - an N vector of the input afferent type:
        +1 for an excitatory input and -1 for an inhibitory input.
    Gamma - Max norm of weight vector

    Returns:
    --------
    w_in - max kappa_in solution
    w_out - max kappa_out solution

    Both are returned as flat arrays of 'physical' weights, i.e. the
    solver's weights divided by the threshold it returns.
    The convergence flag returned by the solvers is not checked here.
    '''
    def physical_weights(solution):
        w,theta,_tau,_sol,_converged=solution
        return (w/theta).flatten()

    #max kappa_out solution first, then max kappa_in
    w_out=physical_weights(sign_constrained_perceptron_max_kappa_out(X,y,g,Gamma))
    w_in=physical_weights(sign_constrained_perceptron_max_kappa_in(X,y,g,Gamma))

    return w_in,w_out
    
def run_sim(nsim=1,n_rep=100):
    '''
    Runs nsim simulations on the IPyCluster.

    Input parameters:
    -----------------
    nsim - number of independent pattern sets (and solutions) to simulate
    n_rep - number of times to present each pattern
    
    Uses global variables defined in 'Local Parameters'

    Returns:
    --------
    n_spikes_out - list with one n_rep x P array of output spike counts per
                   simulation, for the max kappa_out solution
    n_spikes_in - the same for the max kappa_in solution
    y_vec - list with the desired labels of each simulation's patterns
    '''
    n_spikes_out=[]
    n_spikes_in=[]
    y_vec=[]
    
    #Generate patterns
    X,y,P,g = gen_random_patterns(N,f_out,g_ex,alpha,Pex,Pinh(A_inh,k,dxbar))  
    #save desired labels of patterns to list y_vec
    y_vec.append(y)
    #find solutions
    w_in,w_out=get_sol(X,y,g,Gamma)
    
    #Send data to engines and wait for completion
    dview.push({'X':X,'y':y,'w_in':w_in,'w_out':w_out},block=True)
    
    #Sets the weights of remote LIF neurons and waits for completion
    dview.execute('set_weights()',block=True)
     
    for sim in range(nsim):
        tt=time.time()
        #send jobs to engines (one map over all patterns per repetition)
        #testing the max kappa_out solution
        amr_out=[lbview.map(activate_out,range(P),chunksize=500) for n in range(n_rep)]
        #testing the max kappa_in solution
        amr_in=[lbview.map(activate_in,range(P),chunksize=500) for n in range(n_rep)]
        
        #while the engines test the current solutions, generate the next ones locally
        #(they are only pushed after the current results have been gathered)
        if sim+1 < nsim:
            #Generate patterns
            X,y,P,g = gen_random_patterns(N,f_out,g_ex,alpha,Pex,Pinh(A_inh,k,dxbar))  
            #save desired labels of patterns to list y_vec
            y_vec.append(y)
            #find solutions
            w_in,w_out=get_sol(X,y,g,Gamma)
        
        #gather results
        #each result is an n_rep x P array of the number of output spikes 
        #for each pattern presentation
        n_spikes_out.append(array([a.result() for a in amr_out]))
        n_spikes_in.append(array([a.result() for a in amr_in]))
        
        #send new patterns to engines
        if sim+1 < nsim:
            dview.push({'X':X,'y':y,'w_in':w_in,'w_out':w_out},block=True)
            dview.execute('set_weights()',block=True)
        print("Sim {} complete in {} sec".format(sim+1,time.time()-tt))
    #return results
    return n_spikes_out,n_spikes_in,y_vec

# Perform no output noise condition simulation

In [None]:
#nsim=10 independent pattern sets, each pattern presented n_rep=500 times
n_spikes_out,n_spikes_in,y_vec=run_sim(10,500)

In [None]:
#Save results to a dictionary
res={'low_output_noise':{'n_spikes_in':n_spikes_in,\
                         'n_spikes_out':n_spikes_out,\
                         'y_vec':y_vec}
    }

# Perform high output noise condition simulation

## Remote defs

In [None]:
%%px 
#noise parameters for high output noise condition
N_noise = 30000
sigma_noise_w = 2./sqrt(N_noise)

#New LIF Neurons for the high output noise condition
#(re-created because the number of inputs, N+N_noise, has changed)
tempo_in = ST.SpikingTempotron(N+N_noise,1.,0.,LIFtheta)
tempo_in.K = K


tempo_out = ST.SpikingTempotron(N+N_noise,1.,0.,LIFtheta)
tempo_out.K = K

## Simulations

In [None]:
#nsim=40 independent pattern sets, each pattern presented n_rep=100 times
n_spikes_out,n_spikes_in,y_vec=run_sim(40,100)

In [None]:
#save result to dict
res.update({'high_output_noise':{'n_spikes_in':n_spikes_in,\
                                 'n_spikes_out':n_spikes_out,\
                                 'y_vec':y_vec}
    })

In [None]:
#Save res dict to disk in your favorite way
##....

In [None]:
#shut down engines if you wish
c.shutdown()

# Plotting Figures From Results

In [None]:
#Load res dict 
#....

In [None]:
def format_ax(ax):
    '''Hides the top and right spines of ax and switches off the
    top and right ticks (both major and minor).'''
    for side in ('top','right'):
        ax.spines[side].set_visible(0)
    ax.tick_params(which='both',top=False,right=False)
    

In [None]:
mpl.style.use('classic')

# Figure 4

In [None]:
#Figure 4: spike-count histograms of 'plus' and 'minus' patterns for the
#max kappa_out and max kappa_in solutions, together with the resulting
#classification curves; top row without output noise, bottom row with high output noise.
#The removed `find` and `normed=` APIs are replaced by flatnonzero and density=,
#which give the same indices and the same normalisation.
close()
rcParams['font.size']=9
title_pos=[-0.35,1.0]
title_fs=10
lbl_fs=8

f=figure(figsize=[4.5,4.5])

from matplotlib import gridspec

gs=gridspec.GridSpec(10,10)

#########################################################
################## low output noise #####################
#########################################################
#brings n_spikes_in, n_spikes_out and y_vec of this condition into the global namespace
globals().update(res['low_output_noise'])

#concatenate the simulations along the pattern axis: n_rep x (nsim*P)
arr_n_spikes_out=hstack(n_spikes_out[:])
arr_n_spikes_in=hstack(n_spikes_in[:])

#indexes for 'plus' and 'minus' patterns (column indexes into the arrays above)
plus_ind = flatnonzero(vstack(y_vec[:])==1)
minus_ind = flatnonzero(vstack(y_vec[:])==-1)
#########################################################
ax=f.add_subplot(gs[1:5,1:4])
format_ax(ax)
ax.set_title('(a)',fontsize=title_fs,position=title_pos)

#unit-width bins centred on the integer spike counts 0..58
hp_out=hist(arr_n_spikes_out[:,plus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`plus' patterns")
hm_out=hist(arr_n_spikes_out[:,minus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`minus' patterns",ec='r')

xlabel('Spike Count',fontsize=lbl_fs)
ylabel('Probability',fontsize=lbl_fs)
legend(loc=(.9,-0.77),fontsize=6)
#title(r'Max. $\kappa_\mathrm{out}$ solution')
xlim(-1,30)
xticks([0,10,20,30])
ylim(0,0.5)
yticks([0,0.2,0.4])
####################################################################################33
ax=f.add_subplot(gs[1:5,4:7])
format_ax(ax)
ax.set_title('(b)',fontsize=title_fs,position=title_pos)

hp_in=hist(arr_n_spikes_in[:,plus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`plus' patterns")
hm_in=hist(arr_n_spikes_in[:,minus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`minus' patterns",ec='r')

xlabel('Spike Count',fontsize=lbl_fs)
ylabel('Probability',fontsize=lbl_fs)
#legend(loc='upper right',fontsize=9)
#title(r'Max. $\kappa_\mathrm{in}$ solution')
xlim(-1,16)
xticks([0,5,10,15])
ylim(0,0.5)
yticks([0,0.2,0.4])
#######################################################################################333
ax=f.add_subplot(gs[1:5,7:])
format_ax(ax)
ax.set_title('(c)',fontsize=title_fs,position=title_pos)

#For a spike-count threshold at bin i: cumsum(hm[0])[i] is the fraction of 'minus'
#presentations with at most i spikes and 1-cumsum(hp[0])[i] the fraction of 'plus'
#presentations with more than i spikes (the bins have unit width)
plot(1-cumsum(hp_out[0]),cumsum(hm_out[0]),lw=2,color='k',label=r'Max. $\kappa_\mathrm{out}$')
plot(1-cumsum(hp_in[0]),cumsum(hm_in[0]),'--o',color='0.5',lw=2,label=r'Max. $\kappa_\mathrm{in}$',mec='none')
ylim(0,1.05)
yticks([0,0.5,1])
xlim(0,1.05)
xticks([0,0.5,1])

xlabel("'plus' fraction\ncorrect",fontsize=lbl_fs)
ylabel("'minus' fraction\ncorrect",fontsize=lbl_fs)
#ax.set_aspect(1)
legend(loc=(0.075,-0.77),fontsize=6,numpoints=1)
#########################################################
################## high output noise #####################
#########################################################

#replaces n_spikes_in, n_spikes_out and y_vec by those of the high output noise condition
globals().update(res['high_output_noise'])

arr_n_spikes_out=hstack(n_spikes_out[:])
arr_n_spikes_in=hstack(n_spikes_in[:])

plus_ind = flatnonzero(vstack(y_vec[:])==1)
minus_ind = flatnonzero(vstack(y_vec[:])==-1)
#########################################################################################
ax=f.add_subplot(gs[6:,1:4])
format_ax(ax)
ax.set_title('(d)',fontsize=title_fs,position=title_pos)

hp_out=hist(arr_n_spikes_out[:,plus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="'plus' patterns")
hm_out=hist(arr_n_spikes_out[:,minus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="'minus' patterns",ec='r')

xlabel('Spike Count',fontsize=lbl_fs)
ylabel('Probability',fontsize=lbl_fs)
#legend(loc='upper right',fontsize=9)
#title(r'Max. $\kappa_\mathrm{out}$ solution')
xlim(-1,30)
xticks([0,10,20,30])
ylim(0,0.5)
yticks([0,0.2,0.4])
###################################################################################333
ax=f.add_subplot(gs[6:,4:7])
format_ax(ax)
ax.set_title('(e)',fontsize=title_fs,position=title_pos)


hp_in=hist(arr_n_spikes_in[:,plus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`plus' patterns")
hm_in=hist(arr_n_spikes_in[:,minus_ind].flatten(),arange(60)-0.5,density=True,\
           histtype='step',label="`minus' patterns",ec='r')

xlabel('Spike Count',fontsize=lbl_fs)
ylabel('Probability',fontsize=lbl_fs)
#legend(loc='upper right',fontsize=9)
#title(r'Max. $\kappa_\mathrm{in}$ solution')
xlim(-1,16)
xticks([0,5,10,15])
ylim(0,0.5)
yticks([0,0.2,0.4])

######################################################################################33
ax=f.add_subplot(gs[6:,7:])
format_ax(ax)
ax.set_title('(f)',fontsize=title_fs,position=title_pos)

#same classification curves as in panel (c), for the high output noise condition
plot(1-cumsum(hp_out[0]),cumsum(hm_out[0]),lw=2,color='k')
plot(1-cumsum(hp_in[0]),cumsum(hm_in[0]),'--o',color='0.5',lw=2,mec='none')
ylim(0,1.05)
yticks([0,0.5,1])
xlim(0,1.05)
xticks([0,0.5,1])
xlabel("'plus' fraction\ncorrect",fontsize=lbl_fs)
ylabel("'minus' fraction\ncorrect",fontsize=lbl_fs)
#ax.set_aspect(1)

###########################################################################
#row and column headings
f.text(0.00,0.85,"No Output Noise",fontsize=10,rotation='vertical')
f.text(0.00,0.375,"High Output Noise",fontsize=10,rotation='vertical')
f.text(0.1,0.92,"Balanced\n"+r"max. $\kappa_\mathrm{out}$ Solution",\
       fontsize=10,multialignment='center')
f.text(0.425,0.92,"Unbalanced\n"+r"max. $\kappa_\mathrm{in}$ Solution",\
       fontsize=10,multialignment='center')

tight_layout()

# Figure S4 

In [None]:
#Figure S4: distribution over patterns of the mean spike count (averaged over
#the repeated presentations of each pattern) in the low output noise condition.
#The removed `find` and `normed=` APIs are replaced by flatnonzero and density=,
#which give the same indices and the same normalisation.

#brings n_spikes_in, n_spikes_out and y_vec of this condition into the global namespace
globals().update(res['low_output_noise'])

#concatenate the simulations along the pattern axis: n_rep x (nsim*P)
arr_n_spikes_out=hstack(n_spikes_out[:])
arr_n_spikes_in=hstack(n_spikes_in[:])

#column indexes of the 'plus' and 'minus' patterns
plus_ind = flatnonzero(vstack(y_vec[:])==1)
minus_ind = flatnonzero(vstack(y_vec[:])==-1)

close()
mpl.style.use('classic')
rcParams['font.size']=8

f=figure(1,figsize=[3.42,2.1])
title_pos=[-0.35,1.]
title_fs=9
latex_fs=12
ax=subplot(121)
format_ax(ax)

#bin edges with a width of 0.5 spikes
r=arange(-0.25,60,0.5)

#mean over axis 0 averages each pattern's spike count over its repetitions
hist(mean(arr_n_spikes_out[:,plus_ind],0),r,density=True,histtype='step',label="'plus' pat.")
hist(mean(arr_n_spikes_out[:,minus_ind],0),r,density=True,histtype='step',label="'minus' pat.",ec='r')
xlabel('Mean Spike Count')
ylabel('Probability Density')
legend(loc='upper right',fontsize=6)
title('Balanced\n'+r'max. $\kappa_\mathrm{out}$ solution',fontsize=title_fs)
xlim(-1,20)
ylim(0,.7)
yticks([0,0.3,0.6])
#ylim(0,0.55)

ax=subplot(122)
format_ax(ax)

hist(mean(arr_n_spikes_in[:,plus_ind],0),r,density=True,histtype='step',label="'plus' pat.",lw=1)
hist(mean(arr_n_spikes_in[:,minus_ind],0),r,density=True,histtype='step',label="'minus' pat.",ec='r',lw=1)
xlabel('Mean Spike Count')
#label spelling fixed (was 'Probabilty Density')
ylabel('Probability Density')
legend(loc='upper right',fontsize=6)
title('Unbalanced\n'+r'max. $\kappa_\mathrm{in}$ solution',fontsize=title_fs)
xlim(-1,10)
ylim(0,2.2)
yticks([0,1,2.])
tight_layout()
