In [1]:
# Make the notebook cells wider than the default by overriding the
# --jp-notebook-max-width CSS variable for the whole page.
from IPython.display import display, HTML
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

In [2]:
class spike():
    def __init__(self, hyperPar):
        """Read the shared hyper-parameters and create the Brian2 network container.

        Each value is fetched with ``info_hyperparams(name, hyperPar[, unit])`` and
        stored on ``self`` under the same name.  ``self.sqrt_grp_size`` and
        ``self.debug`` (plus ``self.N_hidden`` when debugging) are read here but
        never assigned, so they must already exist on the instance.

        Args:
            hyperPar: hyper-parameter container handed to ``info_hyperparams``.
        """
        def fetch(key, *unit):
            # look the value up and keep it on self under the hyper-parameter's own name
            setattr(self, key, info_hyperparams(key, hyperPar, *unit))

        self.num_classes = input_param.num_classes  # number of classes for classification, i.e. 10 for MNIST data
        fetch('switch_norm')  # normalize weight every run
        fetch('n_syn')  # number of synapses per pre-post pair. now it's a dummy variable.
        fetch('stdp_type')  # stdp type
        # neurons per hidden layer: (group side length)^2 neurons for each class
        self.N_e = [int(math.pow(side, 2)*input_param.num_classes) for side in self.sqrt_grp_size]
        fetch('max_delay_input2e', brian2.ms)  # max axonic delay
        fetch('max_delay_efe', brian2.ms)  # max axonic delay
        fetch('max_dendritic_delay', brian2.ms)  # max dendritic delay
        fetch('tau_adpt', brian2.ms)  # adaptive threshold time constant
        fetch('delta_vt', brian2.mV)  # increment for adaptive threshold
        fetch('tau_membrane_exci', brian2.ms)
        fetch('tau_ge', brian2.ms)
        fetch('tau_gi', brian2.ms)
        fetch('sigma_noise', brian2.mV)  # sigma for Ornstein-Uhlenbeck noise
        fetch('gmax_input2e')  # for syn between input and hidden layers
        fetch('gmax_efe')  # for syn between different hidden layers in forward direction
        fetch('norm_scale_S_input2e')  # average weight for S_input2e
        fetch('norm_scale_S_efe')  # average weight for S_efe
        fetch('penalty_input2e')
        fetch('penalty_efe')
        fetch('dW_e2e')
        fetch('w_sat_scale')  # range->[0,1], determines how fast the weight saturates
        fetch('w_sat_shift')  # range->[0,1], shifts the drop-off part of the tanh function to higher or lower value
        fetch('vt_sat_scale')  # range->[0,1], same role as w_sat_scale but for how fast vt saturates
        fetch('vt_sat_shift')  # range->[0,1], shifts the drop-off part of the tanh function to higher or lower value

        # neuron parameters
        for voltage_key in ('v_thres_exci', 'v_reversal_e_exci', 'v_reversal_i_exci',
                            'v_rest_exci', 'v_reset_exci'):
            fetch(voltage_key, brian2.mV)
        fetch('refrac_time_exci', brian2.ms)

        # network to collect all information before run
        self.col = brian2.Network()

        if self.debug == True:
            # (printed label, attribute name); attributes are looked up one at a
            # time so output appears in the same order as the values are read
            debug_dump = (
                ('self.switch_norm: ', 'switch_norm'),
                ('self.N_hidden: ', 'N_hidden'),
                ('self.N_e: ', 'N_e'),
                ('self.max_delay_input2e: ', 'max_delay_input2e'),
                ('self.max_delay_efe: ', 'max_delay_efe'),
                ('self.n_syn: ', 'n_syn'),
                ('self.tau_adpt: ', 'tau_adpt'),
                ('self.delta_vt: ', 'delta_vt'),
                ('self.vt_sat_scale: ', 'vt_sat_scale'),
                ('self.vt_sat_shift: ', 'vt_sat_shift'),
                ('self.tau_membrane_exci: ', 'tau_membrane_exci'),
                ('self.tau_ge: ', 'tau_ge'),
                ('self.tau_gi: ', 'tau_gi'),
                ('self.sigma_noise: ', 'sigma_noise'),
                ('self.gmax_input2e: ', 'gmax_input2e'),
                ('self.gmax_efe: ', 'gmax_efe'),
                ('self.norm_scale_S_input2e: ', 'norm_scale_S_input2e'),
                ('self.norm_scale_S_efe: ', 'norm_scale_S_efe'),
                ('self.stdp_type:', 'stdp_type'),
                ('self.penalty_input2e: ', 'penalty_input2e'),
                ('self.penalty_efe: ', 'penalty_efe'),
                ('v_thres_exci: ', 'v_thres_exci'),
                ('v_reversal_e_exci: ', 'v_reversal_e_exci'),
                ('v_reversal_i_exci: ', 'v_reversal_i_exci'),
                ('v_rest_exci: ', 'v_rest_exci'),
                ('v_reset_exci: ', 'v_reset_exci'),
                ('refrac_time_exci: ', 'refrac_time_exci'),
                ('self.dW_e2e: ', 'dW_e2e'),
                ('self.w_sat_scale: ', 'w_sat_scale'),
                ('self.w_sat_shift: ', 'w_sat_shift'),
            )
            for label, attr in debug_dump:
                print(label, getattr(self, attr))
                
    # common equations used for neuron, synapse and STDP groups
    def set_equations(self, hyperParams):        
        """Define the Brian2 model strings shared by neuron, synapse and STDP groups.

        Always sets ``self.eqs_neuron`` (neuron model) and ``self.reset`` (reset
        statements).  Depending on ``self.stdp_type`` it additionally sets
        ``self.eqs_syn``, ``self.STDP_pre_action`` (dict with the keys
        ``'pre_nonplastic'`` and ``'pre_plastic'``) and ``self.STDP_post_action``:

        * ``stdp_type == 0``: pair-based rule; reads ``taupre``, ``taupost``,
          ``d_Apre`` and ``d_Apost_scale`` from ``hyperParams``.
        * ``stdp_type == 1``: triplet rule; reads ``nu_pre_ee``, ``nu_post_ee``,
          ``tc_pre_ee``, ``tc_post_1_ee`` and ``tc_post_2_ee`` from ``hyperParams``.

        When ``self.task == 'train'`` a shared ``penalty`` parameter is appended to
        the synapse model and the post-spike update is scaled by a tanh saturation
        term and by ``(1 - penalty)`` whenever ``group_post != label_post``.

        NOTE(review): for any other ``stdp_type`` none of the synapse attributes
        are assigned and no error is raised here -- confirm callers never rely on
        them in that case.

        Args:
            hyperParams: mapping indexed with the string keys listed above.
        """
        # neuron equation
        self.eqs_neuron = '''
            dv/dt = (ge*(Ee-v) + gi*(Ei-v)+ El - v)/tau_m + sigma_noise*xi*tau_m**-0.5: volt (unless refractory)
            dge/dt = -ge/tau_ge : 1 (unless refractory)
            dgi/dt = -gi/tau_gi : 1 (unless refractory)
            dvt/dt = (v_thres-vt)/tau_adpt : volt (unless refractory) # adapt threshold
            delta_vt: volt (constant)
            vt_sat_scale: 1 (constant)
            vt_sat_shift: 1 (constant)
            tau_adpt: second (constant)
            v_rest: volt (constant)
            v_thres: volt (constant)
            v_reset: volt (constant)
            refrac_time: second (constant)
            x: 1
            y: 1
            x_max: 1 (constant)
            y_max: 1 (constant)
            group: integer # which digit group does it belong to
            label: integer # image label; can't be (shared) type for some reason
            stimulus_idx : integer  # Index of the stimulus to show
            stimulus_strength : 1  # Factor to scale up stimulus
            tau_m: second (constant)
            tau_ge: second (constant)
            tau_gi: second (constant)
            Ee: volt (constant)
            Ei: volt (constant)
            El: volt (constant)
            sigma_noise: volt (constant) # random noise
            rank: integer (constant) # rank in distributed training. Needed for S_e2e.connect() to avoid all-to-all connect in testing when test_option=add_more_neuron*
        '''
            
        # reset statements: v goes back to v_reset and the adaptive threshold vt is
        # raised by delta_vt scaled with a tanh term of vt, vt_sat_shift and vt_sat_scale
        self.reset = '''
            v = v_reset
            vt += delta_vt*(0.5-0.5*tanh((-2*(vt-v_thres*(vt_sat_shift-0.5))/v_thres+1)/vt_sat_scale)) # gradually approach max 
        '''
    
        if self.debug==True:
            print(' train with adaptive threshold ....')
            print('eqs_neuron: ', self.eqs_neuron)

        # STDP equations of different types
        if self.stdp_type == 0: #synapse equation from Song, Miller and Abbott (2000) and Song and Abbott (2001)
            if self.debug==True:
                print(' using synapse equation from Song, Miller and Abbott (2000) and Song and Abbott (2001) ----- ')
            
            self.taupre = hyperParams['taupre']*brian2.ms
            self.taupost = hyperParams['taupost']*brian2.ms
            self.d_Apre = hyperParams['d_Apre']
            # depression increment: opposite sign of d_Apre, scaled by the ratio of
            # the two time constants and by d_Apost_scale
            self.d_Apost = -self.d_Apre * self.taupre / self.taupost * hyperParams['d_Apost_scale']
            
            self.eqs_syn = '''
                w : 1
                dApre/dt = -Apre / taupre : 1 (event-driven)
                dApost/dt = -Apost / taupost : 1 (event-driven)
                gmax: 1 (constant)
                max_delay: second (constant)
                max_dendritic_delay: second (constant)
                normalized_ave: 1 (shared)
                w_sat_scale: 1 (constant)
                w_sat_shift: 1 (constant)
            '''
            
            # Don't add the penalty to the pre action: it would trim branches when label!=group, because every pre-spike reduces the weight.
            # When label==group, don't trim either, because the post-neuron may not spike if its adaptive threshold is too high. Just keep it as is.
            self.STDP_pre_action = {'pre_nonplastic': 'ge += w',
                                    'pre_plastic': ''' ge += w
                                                       Apre += d_Apre
                                                       w = clip(w + Apost, 0, gmax)'''}
            # default (non-training) post action: plain potentiation clipped to [0, gmax]
            self.STDP_post_action = '''Apost += d_Apost
                                       w = clip(w + Apre, 0, gmax)'''
            if self.task == 'train':
                # training replaces the post action with a saturating, penalty-aware version
                self.eqs_syn += '\n penalty: 1 (shared)'
                self.STDP_post_action = '''Apost += d_Apost
                                           w = clip(w + Apre*(0.5-0.5*tanh((2*(w+gmax*(w_sat_shift-0.5))/gmax-1)/w_sat_scale))*(1- penalty*int(group_post!=label_post)), 0, gmax)'''
        elif self.stdp_type ==1: 
            # minimum model of triplet equation for visual cortex, i.e. A2+ = A3- = 0, from Pfister and Gerstner, 2006 as implemented in Diehl and Cook (2015) and his github
            # the actual parameter used in Diehl is different from the original paper. Can be further tuned if needed
            if self.debug==True:    
                print('---- using triplet equation from Pfister and Gerstner, 2006 as implemented in Diehl and Cook (2015) and his github ----- ')
                
            self.nu_pre_ee =  hyperParams['nu_pre_ee']      # A2- of the minimum triplet model
            self.nu_post_ee = hyperParams['nu_post_ee']      # A3+ of the minimum triplet model
            self.tc_pre_ee = hyperParams['tc_pre_ee']*brian2.ms
            self.tc_post_1_ee = hyperParams['tc_post_1_ee']*brian2.ms
            self.tc_post_2_ee = hyperParams['tc_post_2_ee']*brian2.ms
                    
            self.eqs_syn = '''
                w : 1
                post2before                            : 1
                dpre/dt   =   -pre/(tc_pre_ee)         : 1 (event-driven)
                dpost1/dt  = -post1/(tc_post_1_ee)     : 1 (event-driven)
                dpost2/dt  = -post2/(tc_post_2_ee)     : 1 (event-driven)
                gmax: 1 (constant) 
                max_delay: second (constant)
                max_dendritic_delay: second (constant)
                normalized_ave: 1 (shared)
                w_sat_scale: 1 (constant)
                w_sat_shift: 1 (constant)
            '''
            # Don't add the penalty to the pre action: it would trim branches when label!=group, because every pre-spike reduces the weight.
            # When label==group, don't trim either, because the post-neuron may not spike if its adaptive threshold is too high. Just keep it as is.
            self.STDP_pre_action = {'pre_nonplastic': 'ge += w',
                                    'pre_plastic': '''ge += w 
                                                      pre = 1. 
                                                      w = clip(w - nu_pre_ee * post1, 0, gmax)'''}
            # default (non-training) post action of the triplet rule
            self.STDP_post_action = '''post2before = post2 
                                       w = clip(w + nu_post_ee * pre * post2before, 0, gmax) 
                                       post1 = 1. 
                                       post2 = 1.'''        
            if self.task == 'train':
                # training replaces the post action with a saturating, penalty-aware version
                self.eqs_syn += '\n penalty: 1 (shared)'
                self.STDP_post_action = '''post2before = post2 
                                           w = clip(w + nu_post_ee * pre * post2before *(0.5-0.5*tanh((2*(w+gmax*(w_sat_shift-0.5))/gmax-1)/w_sat_scale))*(1-penalty*int(group_post!=label_post)), 0, gmax) 
                                           post1 = 1. 
                                           post2 = 1.'''
            
    # add one set of regular neurons
    def add_neurons(self, ih, n, name):
        """Create one excitatory ``brian2.NeuronGroup`` for hidden layer ``ih``.

        If no saved file ``rank<rank>.<data_format>`` is found under
        ``<dir_neuron_group>/neuron_groups_exci_<ih>`` the group is initialised
        from the hyper-parameters stored on ``self``; otherwise every state
        variable is restored from the saved data.  With a ``test_option``
        containing ``'add_more_neuron'`` the group holds ``n*num_workers``
        neurons and the saved data of all workers is concatenated.

        Args:
            ih: index of the hidden layer (selects per-layer hyper-parameters).
            n: number of neurons per worker.
            name: Brian2 object name of the group.

        Returns:
            The populated ``brian2.NeuronGroup``.

        Exits (``sys.exit``) when ``add_more_neuron`` is requested outside a
        test-like task, or when saved neuron data is missing or empty.
        """
        # identity check is the reliable way to test for None
        more_neurons = self.test_option is not None and 'add_more_neuron' in self.test_option
        if more_neurons:
            n = n*self.num_workers
            if self.task not in ('test', 'advatk', 'adv_make'):
                sys.exit(f'!!!! self.task={self.task}, need to be doing testing when with test_option: {self.test_option} !!!!')
            if self.debug == True:
                print(f' ---  add_neuron with n = {n} for test option: {self.test_option} -----')

        neurons = brian2.NeuronGroup(n, self.eqs_neuron, threshold='v>vt', reset=self.reset, refractory='refrac_time', method='euler', name=name)

        neuron_dir_path = os.path.join(self.dir_neuron_group, f'neuron_groups_exci_{ih}')
        rank = 0 if self.test_option == 'use_rank0_info' else self.rank

        if find_file(f'rank{rank}.{input_param.data_format}', neuron_dir_path) is None:
            # ---- fresh initialisation from hyper-parameters ----
            neurons.ge = 0
            neurons.gi = 0
            # NOTE(review): this random start value is computed before v_rest and
            # v_thres are assigned and is overwritten by `neurons.v = ...` below, so
            # it has no effect on v.  It is kept only because evaluating rand()
            # presumably advances the random stream that later calls depend on.
            neurons.v = 'v_rest + rand()*(v_thres - v_rest)'
            neurons.tau_m = self.tau_membrane_exci[ih]
            neurons.tau_ge = self.tau_ge[ih]
            neurons.tau_gi = self.tau_gi[ih]
            neurons.sigma_noise = self.sigma_noise
            neurons.tau_adpt = self.tau_adpt[ih]
            neurons.v_rest = self.v_rest_exci
            neurons.v_thres = self.v_thres_exci
            neurons.v_reset = self.v_reset_exci
            neurons.refrac_time = self.refrac_time_exci
            neurons.Ee = self.v_reversal_e_exci
            neurons.Ei = self.v_reversal_i_exci
            neurons.El = self.v_rest_exci
            neurons.delta_vt = self.delta_vt[ih]
            neurons.vt_sat_scale = self.vt_sat_scale
            neurons.vt_sat_shift = self.vt_sat_shift
            neurons.vt = self.v_thres_exci  # initial value of vt
            neurons.v = self.v_rest_exci  # initial value of v

            neurons.group, neurons.x, neurons.y, neurons.x_max, neurons.y_max = assign_element_of_array_to_random_groups(n, input_param.num_classes)
            # now scale X/Y to match input layer size, in which case, the features of earlier layers and input layer are just
            # directly added on top of each other, virtually forming a single layer with more features, as the input to the current hidden layer
            seg_size = input_param.input_rows/self.sqrt_grp_size[ih]  # divide input layer into segments of this size for each group
            neurons.x = (neurons.x+0.5)*seg_size
            neurons.y = (neurons.y+0.5)*seg_size
            neurons.rank = rank
            return neurons

        # ---- restore from saved data ----
        if self.debug == True:
            print('loading excitatory neurons: ', neuron_dir_path)
        if more_neurons:
            list_dict = [read_data(neuron_dir_path, ir) for ir in range(self.num_workers)]
            # `None in list_dict` would compare each loaded object with `==`,
            # which is ambiguous for array-like data; test identity instead
            if any(loaded is None for loaded in list_dict):
                sys.exit(f'!!!! neuron_exci at {neuron_dir_path} is None !!!!')
            neuron_exci = concatenate_dict_arrays(list_dict)
        else:
            neuron_exci = read_data(neuron_dir_path, rank)
            if neuron_exci is None:
                # report the missing data instead of failing later with a TypeError
                sys.exit(f'!!!! neuron_exci at {neuron_dir_path} is None !!!!')

        if self.debug == True and self.task != 'train':
            write_data(pd.DataFrame.from_dict(neuron_exci), neuron_dir_path, -1)

        if len(neuron_exci['ge']) == 0:
            sys.exit(f'!!!! neuron does not exist for rank {rank} and layer {ih}  !!!')

        # excitatory neuron information: (state variable, unit applied to the saved
        # values), restored in this order; None means the saved value is used as is
        volt, second = brian2.volt, brian2.second
        saved_fields = (
            ('ge', None), ('gi', None), ('v', volt),
            ('tau_m', second), ('tau_ge', second), ('tau_gi', second),
            ('sigma_noise', volt), ('tau_adpt', second),
            ('v_rest', volt), ('v_thres', volt), ('v_reset', volt),
            ('refrac_time', second),
            ('Ee', volt), ('Ei', volt), ('El', volt),
            ('delta_vt', volt), ('vt_sat_scale', None), ('vt_sat_shift', None),
            ('vt', volt),
            ('group', None), ('x', None), ('y', None),
            ('x_max', None), ('y_max', None), ('rank', None),
        )
        for field, unit in saved_fields:
            values = neuron_exci[field]
            if unit is not None:
                values = values*unit
            getattr(neurons, field)[:] = values

        return neurons

    ##### the following are templates for multi-inheritance 
    def build_neurons(self):
        pass
    def build_synapses(self):
        pass
    def collect(self):
        pass 
    def set_monitors(self):
        pass
    def save_neurons(self):
        pass
    def save_synapses(self):
        pass
    def set_syn_active_status(self):
        pass
    def save_neuron_coord(self):
        pass
    def plot(self):
        pass

In [3]:
class input_layer(spike): 
    def __init__(self, hyperParams):
        """Initialise the parent class(es) with ``hyperParams`` and announce it when debugging."""
        super().__init__(hyperParams)
        debugging = self.debug == True
        if debugging:
            print('--- initializing input_layer------')
        
    def build_neurons(self):
        """Create the input neuron group and lay its neurons out on the image grid.

        The group has ``input_param.Num_input_neuron`` neurons that spike when
        ``rand()<rate*dt``.  Neuron ``i`` is treated as pixel ``i`` of a row-major
        ``rows x cols`` image: ``x`` is its column and ``y`` counts rows from the
        bottom (``rows`` for the first image row).
        """
        super().build_neurons()  # first parent of models, need this super() to call the same function in the 2nd parents of the model.
        # input neuron with Poisson spikes #
        if self.debug==True:
            print(' building input neuron ...')

        self.input_neurons = brian2.NeuronGroup(input_param.Num_input_neuron,
                                         '''x: 1
                                            y: 1
                                            x_max: 1 (constant)
                                            y_max: 1 (constant)
                                            switch_stimulus: boolean
                                            correct_identification: boolean
                                            higher_than_min_spike_rate: boolean
                                            rate: Hz
                                            label: integer
                                            start_t : second  # Start time of the current trial
                                            stimulus_idx : integer  # Index of the stimulus to show
                                            stimulus_strength : 1  # Factor to scale up stimulus
                                            repetitions : integer  # Number of times the stimulus has been presented
                                            tot_repetitions : integer  # sum(repetitions)
                                            ''', threshold='rand()<rate*dt', name='input_neurons',
                                            namespace = {'rows':input_param.input_rows, 'cols':input_param.input_cols})

        # Row-major layout: column = i % cols, row = i // cols.  The row index was
        # previously computed with `rows` and the two maxima were swapped, which
        # is only correct for square inputs; results for square inputs are unchanged.
        self.input_neurons.x = 'i%cols'
        self.input_neurons.y = 'rows - i//cols' # // is a floored division
        self.input_neurons.x_max = 'cols-1'
        self.input_neurons.y_max = 'rows-1'
     
    # build synapse
    def build_synapses(self):
        """Delegate to the next class in the MRO; the input layer adds no synapses itself."""
        # call the 2nd parent of the model
        return super().build_synapses()
        
    # input and output neuron related containers are common to all models
    def collect(self):
        """Add the input spike monitor and the input neuron group to ``self.col``."""
        super().collect()
        network = self.col
        network.add(self.input_spikemon)
        network.add(self.input_neurons)
    
    def set_monitors(self):
        """Create the input spike monitor and set its active flag from ``self.activate_input_spikemon``."""
        super().set_monitors()
        if self.debug==True:
            print('setup input layer monitors ...')
        # variables of the input neurons recorded with every spike
        recorded_variables = [
            'stimulus_idx',
            'label',
            'stimulus_strength',
            'repetitions',
            'tot_repetitions',
            'correct_identification',
            'higher_than_min_spike_rate',
        ]
        self.input_spikemon = brian2.SpikeMonitor(self.input_neurons, recorded_variables, name='input_spike_mon')
        self.input_spikemon.active = self.activate_input_spikemon
    
    # plot input spikes
    def plot(self):
        """Plot the input spikes and the variables recorded with them.

        Figure 1: spike count per input neuron at its (x, y) position, plus a
        raster of neuron index against time.  Figure 2: six stacked panels of
        the recorded variables against spike time, with the stimulus strength
        on a twin axis of the stimulus-index panel.
        """
        super().plot()

        monitor = self.input_spikemon
        spike_times_ms = monitor.t/brian2.ms

        _, (count_ax, raster_ax) = brian2.plt.subplots(1, 2, figsize=(12, 3))
        count_ax.scatter(self.input_neurons.x, self.input_neurons.y, c=monitor.count)
        count_ax.set(xlabel='x', ylabel='y')
        raster_ax.plot(spike_times_ms, monitor.i, marker='.', linestyle='', markersize=0.7)
        raster_ax.set(ylabel='nid', xlabel='Time (ms)')

        # one panel per recorded variable: (monitor attribute, y label, colour)
        panels = (
            ('repetitions', 'repetitions', 'C0'),
            ('tot_repetitions', 'tot_repetitions', 'C1'),
            ('label', 'input label', 'C2'),
            ('stimulus_idx', 'stimulus index', 'C3'),
            ('correct_identification', 'correct_identification', 'C5'),
            ('higher_than_min_spike_rate', 'higher_than_min_spike_rate', 'C6'),
        )
        _, panel_axes = brian2.plt.subplots(len(panels), 1, sharex=True, figsize=(10, 15))
        for ax, (attribute, ylabel, colour) in zip(panel_axes, panels):
            ax.plot(spike_times_ms, getattr(monitor, attribute), color=colour)
            ax.set(ylabel=ylabel)
            if attribute == 'stimulus_idx':
                # stimulus strength shares this panel on a second y axis
                strength_ax = ax.twinx()
                strength_ax.plot(spike_times_ms, monitor.stimulus_strength, color='C4')
                strength_ax.set(ylabel='stimulus strength')



In [4]:
class hidden_layer(spike):
    def __init__(self, hyperPar):
        """Initialise the parent class(es) with ``hyperPar`` and announce it when debugging."""
        super().__init__(hyperPar)
        debugging = self.debug == True
        if debugging:
            print('initializing hidden_layer------')
    
    # build the excitatory neuron groups of every hidden layer and the shared spike counter
    def build_neurons(self):
        """Create one excitatory group per hidden layer plus the spike-counter group.

        ``self.exci_neurons`` receives the groups built by ``add_neurons``.
        ``self.spike_counter`` is a single-neuron group with one integer
        variable for the total, one per class group and one per hidden layer.
        """
        super().build_neurons()  # 2nd parent of models, need this super() to call the same function in the 3rd parents of the model.
        if self.debug==True:
            print(' building neurons in hidden layers ...')

        self.exci_neurons = []
        for ih in range(self.N_hidden):
            layer_group = super().add_neurons(ih, self.N_e[ih], f'exci_neurons{ih}')
            self.exci_neurons.append(layer_group)

        # Used to keep track of the total number of spikes in all hidden layers
        counter_lines = ['spike_counter : integer']  # total number of spikes in all layers
        # n_spikes in each group for all layers
        counter_lines.extend(f'spike_counter_grp{ig} : integer' for ig in range(input_param.num_classes))
        # every layer has a spike counter
        counter_lines.extend(f'spike_counter_layer{ih} : integer' for ih in range(self.N_hidden))
        spike_counter_attribute = ''.join(line + '\n' for line in counter_lines)

        if self.debug == True:
            print('spike_counter_attribute --> \n', spike_counter_attribute)

        self.spike_counter = brian2.NeuronGroup(1, spike_counter_attribute, name='spike_counter')
    
    # save the neurons. 
    def save_neurons(self):
        """Write the state variables of every hidden layer's neurons via ``write_data``.

        One DataFrame per layer is written to
        ``<dir_neuron_group>/neuron_groups_exci_<ih>`` for ``self.rank``; each
        column holds one state variable of the layer's neuron group.
        """
        super().save_neurons()
        if self.debug==True:
            print(' saving the neuron information in: ', self.dir_neuron_group)

        # Saved state variables, in column order.  The dict literal used before
        # listed 'v' twice; a single entry yields the same columns.
        saved_fields = (
            'ge', 'gi', 'v', 'tau_m', 'tau_ge', 'tau_gi', 'sigma_noise', 'tau_adpt',
            'v_rest', 'v_thres', 'v_reset', 'refrac_time', 'Ee', 'Ei', 'El',
            'delta_vt', 'vt_sat_scale', 'vt_sat_shift', 'vt',
            'group', 'x', 'y', 'x_max', 'y_max', 'rank',
        )
        for ih in range(self.N_hidden):
            neurons = self.exci_neurons[ih]
            df = pd.DataFrame({field: getattr(neurons, field)[:] for field in saved_fields})
            write_data(df, os.path.join(self.dir_neuron_group, f'neuron_groups_exci_{ih}'), self.rank)
                        
    # build synapse
    def build_synapses(self):
        """Create the spike-counting synapses and, outside training, the within-layer synapses.

        ``self.counter_synapse[ih]`` connects layer ``ih`` to ``self.spike_counter``
        and increments the total, per-layer and per-group counters on every
        pre-synaptic spike.  When ``self.task != 'train'``, ``self.S_e2e[ih]``
        connects neurons of the same layer that share position and rank but
        belong to different groups, adding ``dW_e2e`` to ``gi`` on each spike.
        """
        super().build_synapses()

        # dummy synapses to keep track of the total number of spikes in all hidden layers
        self.counter_synapse = []
        for ih in range(self.N_hidden):
            layer_counter = f'spike_counter_layer{ih}'
            action_lines = ['spike_counter += 1', f'{layer_counter}+=1']
            # n_spikes in each group for all layers
            action_lines.extend(f'spike_counter_grp{ig}+=1*int({ig}==group_pre)'
                                for ig in range(input_param.num_classes))

            if self.task == 'train':
                # While the layer has fired fewer than n_spike_sample_min spikes, use
                # -fixed_delta_vt so the threshold is lowered; delta_vt=0 would leave it unchanged.
                action_lines.append(f'delta_vt_pre = fixed_delta_vt*int({layer_counter}>=n_spike_sample_min)'
                                    f' - fixed_delta_vt*int({layer_counter}<n_spike_sample_min)')

            counter_syn_action = ''.join(line + '\n' for line in action_lines)
            if self.debug==True:
                print('counter_syn_action-->\n', counter_syn_action)

            counter_namespace = {'n_spike_sample_min': input_param.n_spike_sample_min,
                                 'fixed_delta_vt': self.delta_vt[ih]}
            counter = brian2.Synapses(self.exci_neurons[ih], self.spike_counter, on_pre=counter_syn_action,
                                      name=f'counter_synapse{ih}', namespace=counter_namespace)
            self.counter_synapse.append(counter)
            counter.connect()

        # syn within individual hidden layer
        if self.task == 'train':
            return
        self.S_e2e = [None]*self.N_hidden  # excitatory to excitatory between different group w/o STDP
        for ih in range(self.N_hidden):
            layer = self.exci_neurons[ih]
            lateral = brian2.Synapses(layer, layer, 'dW_e2e: 1(constant)',
                                      on_pre='gi += dW_e2e', name=f's_e2e{ih}')
            self.S_e2e[ih] = lateral
            lateral.connect(condition='group_pre!=group_post and x_pre==x_post and y_pre==y_post and rank_pre==rank_post')
            lateral.dW_e2e = self.dW_e2e[ih]
            
    # save the synapses. 
    def save_synapses(self):
        """When debugging a non-training task, write the (i, j) index pairs of each ``S_e2e``."""
        super().save_synapses()
        if not (self.debug==True and self.task!='train'):
            return
        for ih in range(self.N_hidden):
            synapses = self.S_e2e[ih]
            df = pd.DataFrame({'i': synapses.i[:], 'j': synapses.j[:]})
            write_data(df, os.path.join(self.dir_syn_in, f'S_e2e_hid{ih}'), self.rank)
        
    # set up the state and spike monitors of every hidden layer
    def set_monitors(self):
        """Create a state monitor and a spike monitor for every hidden layer.

        The state monitor records all variables of at most the first nine
        neurons of the layer and is active only when ``self.debug`` is set; the
        spike monitor records ``label``, ``group``, ``stimulus_idx`` and
        ``stimulus_strength`` and is always active.
        """
        super().set_monitors()
        if self.debug==True:
            print('setup hidden layer monitors...')
        self.exci_statemon = []
        self.exci_spikemon = []
        max_recorded = 9  # upper bound on the neurons recorded per layer
        for ih in range(self.N_hidden):
            layer = self.exci_neurons[ih]
            # Clamp to the layer size: a fixed 0..8 list requests indices that
            # do not exist when a layer has fewer than nine neurons, e.g. N_neuron = [400, 6].
            record_list = list(range(min(max_recorded, len(layer))))
            self.exci_statemon.append(brian2.StateMonitor(layer, True, record=record_list))
            self.exci_statemon[ih].active = self.debug

            self.exci_spikemon.append(brian2.SpikeMonitor(layer, ['label', 'group', 'stimulus_idx', 'stimulus_strength'], name='exci_spikemon'+str(ih)))
            self.exci_spikemon[ih].active = True
        
    # input and output neuron related containers are common to all models
    def collect(self):
        """Add the hidden-layer neurons, monitors, counter and synapses to ``self.col``."""
        super().collect()
        network = self.col
        network.add(self.exci_neurons)
        network.add(self.exci_spikemon)
        network.add(self.exci_statemon)
        network.add(self.spike_counter)
        network.add(self.counter_synapse)
        # the within-layer synapses only exist outside training
        if not self.task == 'train':
            network.add(self.S_e2e)
        
    # count the spikes of the last hidden layer per class group, using its spike monitor
    def count_last_hlayer_spikes_from_spikemon(self):
        """Count the spikes of the last hidden layer per class group.

        Returns:
            list of ``input_param.num_classes`` ints; entry ``g`` is the number of
            recorded spikes whose ``group`` value is ``g``.
        """
        count = [0]*input_param.num_classes  # spike counts in each digit
        # Fetch the recorded values once: calling all_values() inside the loop
        # rebuilt the whole mapping on every iteration.
        groups_per_neuron = self.exci_spikemon[self.N_hidden-1].all_values()['group']
        for i in range(len(groups_per_neuron)):  # loop over each neuron in the last h-layer
            for g in groups_per_neuron[i]:
                count[g] += 1

        if self.debug==True:
            print('number of input spikes : ', self.input_spikemon.num_spikes)
            for ih in range(self.N_hidden):
                print('number of exci spikes for layer'+str(ih)+': ', self.exci_spikemon[ih].num_spikes)

        return count
                            
    # set active = True/False for pre and post in syn for training and testing
    def set_syn_active_status(self):
        """Delegate to the next class in the MRO; nothing is toggled for the hidden layers here."""
        return super().set_syn_active_status()
                                
    # what to plot. 
    def what_to_plot(self, sub, state_mon):
        """Plot the v, ge, delta_vt and vt traces of one state monitor on four axes.

        Time is shown in ms and the voltage-like traces in mV, matching the
        axis labels; every trace is labelled with its recorded neuron index so
        the legends have entries.

        Args:
            sub: sequence of at least four matplotlib axes (one subplot row).
            state_mon: state monitor exposing ``t``, ``record``, ``v``, ``ge``,
                ``delta_vt`` and ``vt``.
        """
        t_ms = state_mon.t/brian2.ms
        neuron_labels = [f'neuron {idx}' for idx in state_mon.record]

        def draw(ax, values, ylabel, xlim=None):
            # one line per recorded neuron, labelled so that legend() has entries
            lines = ax.plot(t_ms, values)
            for line, text in zip(lines, neuron_labels):
                line.set_label(text)
            ax.set(xlabel='Time (ms)', ylabel=ylabel)
            if xlim is not None:
                ax.set_xlim(xlim)
            if lines:
                ax.legend(loc="upper left")

        draw(sub[0], state_mon.v.T/brian2.mV, 'v(mV)', xlim=[0, 2000])
        draw(sub[1], state_mon.ge.T, 'ge', xlim=[0, 2000])
        draw(sub[2], state_mon.delta_vt.T/brian2.mV, 'delta_vt (mV)')
        draw(sub[3], state_mon.vt.T/brian2.mV, 'vt (mV)')

        
    # plot the result, slow and mostly for debugging purpose
    def plot(self):
        """Plot spikes and recorded state of the hidden layers (slow, mostly for debugging).

        Produces a raster of input and hidden-layer spikes, the input label seen
        by the last hidden layer when training, an overlay of ``-100*v`` and
        ``ge`` for all layers, and one row of ``what_to_plot`` panels per layer.
        """
        super().plot()
        # i vs. t for input layer and all hidden layers
        brian2.plt.figure(figsize=(15, 10))
        brian2.plot(self.input_spikemon.t/brian2.ms, self.input_spikemon.i, label='input layer', marker='.', linestyle='', markersize=1)
        for ih in range(self.N_hidden):
            brian2.plot(self.exci_spikemon[ih].t/brian2.ms, self.exci_spikemon[ih].i, label='layer'+str(ih), marker='o', linestyle='', fillstyle='none', markersize=5)
        brian2.plt.legend(loc="upper left")

        if self.task == 'train' and self.N_hidden > 0:
            # Name the last layer explicitly instead of relying on the loop
            # variable left over from the loop above.
            last_spikemon = self.exci_spikemon[self.N_hidden-1]
            brian2.plt.figure(figsize=(5, 5))
            brian2.plot(last_spikemon.t/brian2.ms, last_spikemon.label, label='input label in exci_neuron', color='C1')
            brian2.plt.legend(loc="upper left")

        # -100*v and ge vs. t for all layers
        brian2.plt.figure(figsize=(15, 15))
        for ih in range(self.N_hidden):
            statemon = self.exci_statemon[ih]
            t_ms = statemon.t/brian2.ms
            v_lines = brian2.plot(t_ms, statemon.v.T*(-100))  # x-100 to make it visible
            ge_lines = brian2.plot(t_ms, statemon.ge.T)
            # label one trace per bundle so the legend below is not empty
            for line in v_lines[:1]:
                line.set_label(f'layer{ih}: -100*v')
            for line in ge_lines[:1]:
                line.set_label(f'layer{ih}: ge')
        brian2.plt.xlabel('Time (ms)')  # the x data is t/ms
        brian2.plt.ylabel('-100*v (V) / ge')
        brian2.plt.legend(loc="upper left")

        # V vs. t & ge vs. t for hidden layer
        if self.N_hidden > 0:
            # squeeze=False always returns a 2-D axes array, so a single hidden
            # layer needs no special case; the height grows with the row count.
            fig, axs = brian2.plt.subplots(self.N_hidden, 4, figsize=(12, 3*self.N_hidden), squeeze=False)
            for ih in range(self.N_hidden):
                self.what_to_plot(axs[ih], self.exci_statemon[ih])