In [1]:
# These commands make the notebook cells use the full page width instead of the default maximum width.
from IPython.display import HTML, display

_FULL_WIDTH_CSS = "<style>:root { --jp-notebook-max-width: 100% !important; }</style>"
display(HTML(_FULL_WIDTH_CSS))

In [None]:
class Spike_MNIST_Nlayer(input_layer, hidden_layer):
    def __init__(self, task, idx_start_train, idx_start, idx_end, simulation_duration, epoch, previous_seg_name, sqrt_grp_size, test_option, hyperParams, debug, activate_input_spikemon, root_out):
        """Configure, build and compile an N-layer spiking MNIST model for one sample segment.

        The model has ``1 + len(sqrt_grp_size)`` layers: the input layer plus one hidden layer
        per entry of ``sqrt_grp_size``. After ``manage_directories`` has prepared the
        input/output paths, the parent classes are initialised with ``hyperParams`` and the
        network is built (equations, neurons, synapses, monitors, active pathways), collected
        and compiled via ``prep_run``. If ``manage_directories`` returns False (``hyperParams``
        is None), nothing is built and ``self.model_success`` stays False.

        Note: with ``test_option == 'add_more_syn'`` the caller's ``hyperParams['n_syn']`` is
        overwritten in place with the number of workers.
        """
        if debug == True:
            separator = '-------------------------------------------------'
            print(separator)
            print('..........', task, '..............')
            print(separator)

        # rank of this worker and number of workers
        self.rank, self.num_workers = get_rank_and_num_workers()

        # network shape: one entry of sqrt_grp_size per hidden layer
        self.sqrt_grp_size = sqrt_grp_size
        self.N_hidden = len(sqrt_grp_size)
        self.model_name = f'Spike_MNIST_{1 + self.N_hidden}_layer'  # the extra 1 is the input layer

        # run options
        self.debug = debug
        self.activate_input_spikemon = activate_input_spikemon
        self.test_option = test_option
        self.root_out = root_out
        if self.test_option == 'add_more_syn':  # keep the same number of neurons but add more synapses
            print('--- doing test with option: add_more_syn ---')
            hyperParams['n_syn'] = self.num_workers

        # which part of the data / training this instance handles
        self.simulation_duration = simulation_duration
        self.idx_start_train = idx_start_train
        self.idx_start = idx_start
        self.idx_end = idx_end
        self.epoch = epoch
        self.task = task
        self.previous_seg_name = previous_seg_name  # previous segment, for consecutive training

        self.model_success = self.manage_directories(hyperParams)  # creates all the directories
        if self.model_success != True:
            # no model could be set up: do not build or compile anything
            return

        # the parent classes are only initialised once the directories are in place
        super().__init__(hyperParams)
        self.device_setting()  # the device must be chosen before any Brian2 object is created
        self.set_equations(hyperParams)
        for build_step in (self.build_neurons,
                           self.build_synapses,
                           self.set_monitors,
                           self.set_syn_active_status,
                           self.collect,
                           self.prep_run):
            build_step()
    
    # set the device before starting anything to build the model
    def device_setting(self):
        """Select (or re-activate) the Brian2 standalone device for this worker, then seed Brian2.

        Each rank builds into its own directory ``<device_name><rank>``. If a C++/CUDA
        standalone device is already active, it is re-initialised and re-activated instead of
        being set again. Exits the process when ``input_param.device_name`` is neither
        'cpp_standalone' nor 'cuda_standalone'.
        """
        device_name = input_param.device_name
        if device_name not in ('cpp_standalone', 'cuda_standalone'):
            sys.exit("!!!!! the option is not one of the following: 'cpp_standalone', 'cuda_standalone' !!!!!")

        build_dir = device_name + str(self.rank)
        if isinstance(brian2.get_device(), (CPPStandaloneDevice, CUDAStandaloneDevice)):
            # a standalone device is already in use: reset it and point it at this rank's directory
            brian2.device.reinit()
            brian2.device.activate(directory=build_dir, build_on_run=False)
        else:
            brian2.set_device(device_name, directory=build_dir, build_on_run=False)
        # limit parallel make jobs: CUDA compilation consumes too much memory without this limit
        brian2.prefs['devices.cpp_standalone.extra_make_args_unix'] = ['-j8']

        # A fixed seed is needed to fix the number of defined synapses, see
        # https://brian.discourse.group/t/cant-modify-synaptic-group-attribute-using-run-args-in-standalone-mode/1219/13
        brian2.seed(0)
        
    # compile the model before run
    def prep_run(self):
        """Queue the simulation run on the standalone device and build the project without executing it.

        The run namespace holds the STDP constants selected by ``self.stdp_type`` (0 or 1). In
        debug mode the run reports progress as text and is profiled; otherwise Brian2 logging
        is reduced to errors. The project is built in ``<device_name><rank>`` with
        ``run=False``.
        """
        # STDP constants passed to the run, per stdp_type
        if self.stdp_type == 0:
            namespace_names = ('d_Apre', 'd_Apost', 'taupre', 'taupost', 'tau_adpt')
        elif self.stdp_type == 1:
            namespace_names = ('nu_pre_ee', 'nu_post_ee', 'tc_pre_ee', 'tc_post_1_ee', 'tc_post_2_ee')
        else:
            namespace_names = None
        if namespace_names is not None:
            self.run_var_namespace = {name: getattr(self, name) for name in namespace_names}

        if self.debug == True:
            run_options = {'report': 'text', 'profile': self.debug}
        else:
            brian2.BrianLogger.log_level_error()  # don't show INFO and WARN level messages
            run_options = {'report': None}
        self.col.run(self.simulation_duration, namespace=self.run_var_namespace, **run_options)

        brian2.device.build(run=False, directory=input_param.device_name + str(self.rank))

        
    # restructure the data for model use
    def manage_data(self, input_data):
        """Flatten the input images into one 2-D array and collect the labels into a 1-D array.

        Args:
            input_data: indexable with 'img' (a sequence of images; the product of the first two
                dimensions of the first image gives the flattened length) and 'label'.

        Returns:
            tuple: ``(img_array, label_array)`` of shapes ``(n_samples, height*width)`` and
            ``(n_samples,)``, using the dtype of the first image and the first label.
        """
        images = input_data['img']
        labels = input_data['label']
        n_samples = len(images)
        first_image = images[0]
        n_pixels = first_image.shape[0] * first_image.shape[1]

        img_array = np.empty((n_samples, n_pixels), dtype=first_image.dtype)
        label_array = np.empty(n_samples, dtype=labels[0].dtype)
        for row in range(n_samples):
            img_array[row] = images[row].reshape(-1)
            label_array[row] = labels[row]

        if self.debug:
            print(' label: ', ', '.join(map(str, label_array)))

        return img_array, label_array
        
    # define input and output directories for synapse, spike and adaptive threshold
    def manage_directories(self, hyperParams):
        """Define the input/output paths of this segment and create the directories that are needed.

        Sets ``epoch_name``, ``seg_name``, ``trial_dir``, ``dir_neuron_group``,
        ``dir_syn_out`` (trained synapses of this segment), ``dir_syn_final`` (a link that
        rank 0 points at the latest ``dir_syn_out`` during training) and ``dir_syn_in``
        (where synapses are loaded from; empty for the very first training segment). When
        ``self.root_out`` is True it also opens the ROOT output file of this rank and gets or
        creates its ``nt`` TNtuple.

        Returns:
            bool: False when ``hyperParams`` is None (the model must not be built), else True.
        """
        # names encode the epoch and the absolute sample range of the current segment
        self.epoch_name = 'epoch_'+str(self.epoch)
        self.seg_name = 'seg_'+str(self.idx_start_train+self.idx_start)+'_'+str(self.idx_start_train+self.idx_end)
        self.trial_dir = str(os.getcwd())  # different ray tune trials run in different directories

        # store analysis root output file
        if self.root_out == True:
            self.root_file_dir = os.path.join(self.trial_dir, "root_file", self.epoch_name, self.seg_name)
            # exist_ok: every rank creates the same directories, possibly at the same time, so an
            # isdir() check followed by makedirs() can race and raise FileExistsError.
            os.makedirs(self.root_file_dir, exist_ok=True)

            self.root_file = ROOT.TFile(os.path.join(self.root_file_dir, f"eff_{self.task}_{self.seg_name}_rank{self.rank}.root"), "update")
            self.nt = self.root_file.Get("nt")
            if not self.nt:
                self.nt = ROOT.TNtuple("nt",  "", "sidx:label:digit:rate:match:n_max:rank")

        # directory to store neuron groups
        self.dir_neuron_group = os.path.join(self.trial_dir, 'neuron_group')
        os.makedirs(self.dir_neuron_group, exist_ok=True)

        syn_out_name = 'syn_out'
        syn_model_dir = os.path.join(self.trial_dir, syn_out_name, self.model_name)

        # directory to store the synapses trained in this segment
        self.dir_syn_out = os.path.join(syn_model_dir, self.epoch_name, self.seg_name)
        if self.task == 'train':
            os.makedirs(self.dir_syn_out, exist_ok=True)

        # link to the last trained synapses
        self.dir_syn_final = os.path.join(syn_model_dir, 'final')

        self.dir_syn_in = str()
        if self.task == 'train':
            if self.idx_start == 0 and self.epoch != 0:
                # first segment of a later epoch: load from the last segment of the previous epoch
                self.dir_syn_in = os.path.join(syn_model_dir, 'epoch_'+str(self.epoch-1), self.previous_seg_name)
            elif self.idx_start > 0:
                # later segment of the same epoch: load from the previous segment
                self.dir_syn_in = os.path.join(syn_model_dir, self.epoch_name, self.previous_seg_name)
            if len(self.dir_syn_in) > 0:
                os.makedirs(self.dir_syn_in, exist_ok=True)

            # rank 0 (re)points the "final" link at this segment's output directory
            if self.rank == 0:
                final_file = os.path.join(self.trial_dir, self.dir_syn_final)
                # lexists() is also True for a dangling link (exists() is not), so a stale link
                # is replaced instead of making the link creation fail. os.symlink/os.unlink
                # avoid building shell commands from paths.
                if os.path.lexists(final_file):
                    os.unlink(final_file)
                os.symlink(self.dir_syn_out, final_file)
                if self.debug == True:
                    print(' link the final trained synapses: ', f'ln -s {self.dir_syn_out} {final_file}')

        else:  # testing or validating: use the last trained synapses
            self.dir_syn_in = os.path.join(self.trial_dir, self.dir_syn_final)

        # without hyper-parameters the model cannot be built
        return hyperParams is not None
        
    def build_neurons(self):
        """Build the neuron groups; delegates entirely to the parent implementation."""
        super().build_neurons()
                   
    def build_synapses(self):
        """Build the parent's synapses, then the input->hidden and hidden->hidden groups, in that order.

        The auxiliary-layer builders are not called from here.
        """
        super().build_synapses()
        for build_step in (self.build_synapses_between_input_hidden,
                           self.build_synapses_between_hidden):
            build_step()
        self.build_synapses_between_hidden()
        
        
    # synapse between input-hidden
    def build_synapses_between_input_hidden(self):
        """Create the input -> hidden-layer synapse groups ``self.S_input2e`` (one per hidden layer).

        On the very first training segment (``idx_start == 0``, ``epoch == 0``,
        ``task == 'train'``), every input/excitatory pair is connected with ``n_syn`` synapses
        and random weights and delays are drawn. Otherwise, the trained synapses are loaded with
        ``load_synapses_between_input_hidden``.
        """
        self.S_input2e = [
            brian2.Synapses(self.input_neurons, self.exci_neurons[ih], self.eqs_syn,
                            on_pre=self.STDP_pre_action, on_post=self.STDP_post_action,
                            name=f's_input2e{ih}')
            for ih in range(self.N_hidden)
        ]

        from_scratch = self.idx_start == 0 and self.epoch == 0 and self.task == 'train'
        if not from_scratch:
            self.load_synapses_between_input_hidden()
            return

        if self.debug == True:
            print(' building (from scratch) synapses between input and hidden layers ...')

        for ih, syn in enumerate(self.S_input2e):
            syn.connect(p=1, n=self.n_syn, skip_if_invalid=True)
            syn.gmax = self.gmax_input2e[ih]
            syn.max_delay = self.max_delay_input2e[ih]
            syn.max_dendritic_delay = self.max_dendritic_delay
            syn.w = 'rand() * gmax'
            syn.pre_plastic.delay = 'rand()* max_delay'
            syn.post.delay = 'rand()* max_dendritic_delay'  # dendritic delay
            syn.penalty = self.penalty_input2e[ih]
            syn.w_sat_scale = self.w_sat_scale
            syn.w_sat_shift = self.w_sat_shift

    # synapse between input-auxiliary
    def build_synapses_between_input_aux(self):
        """Create the input -> auxiliary-layer synapse groups ``self.S_input2aux`` (one per hidden layer).

        On the very first training segment (``idx_start == 0``, ``epoch == 0``,
        ``task == 'train'``), input neuron -> ``self.aux_neurons[ih]`` connections are drawn
        with probability ``exp(-d**2 / (2*sigma**2))``. Here ``d`` is the (x, y) distance between
        the pre- and post-neuron and ``sigma = self.Receptive_field[ih]``. Weights and delays
        are random and weight saturation is disabled (``w_sat_scale = w_sat_shift = 0``).
        Otherwise, the trained synapses are loaded with ``load_synapses_between_input_aux``.
        """
        self.S_input2aux = []
        for ih in range(self.N_hidden):
            self.S_input2aux.append(brian2.Synapses(self.input_neurons, self.aux_neurons[ih], self.eqs_syn, on_pre=self.STDP_pre_action,
                                                    on_post=self.STDP_post_action, name='s_input2aux'+str(ih)))

        if self.idx_start == 0 and self.epoch == 0 and self.task == 'train':
            if self.debug == True:
                print(' building (from scratch) synapses between input and auxiliary layers ...')

            # Gaussian connection probability; sigma is provided per layer through `namespace`
            prob = 'exp(-((x_pre-x_post)**2 + (y_pre-y_post)**2)/(2*(sigma)**2))'
            for ih in range(len(self.S_input2aux)):
                self.S_input2aux[ih].connect(p=prob, n=self.n_syn, skip_if_invalid=True, namespace={'sigma': self.Receptive_field[ih]})
                self.S_input2aux[ih].gmax = self.gmax_input2aux[ih]
                self.S_input2aux[ih].max_delay = self.max_delay_input2aux[ih]
                self.S_input2aux[ih].max_dendritic_delay = self.max_dendritic_delay
                self.S_input2aux[ih].w = 'rand() * gmax'
                self.S_input2aux[ih].pre_plastic.delay = 'rand()* max_delay'
                self.S_input2aux[ih].post.delay = 'rand()* max_dendritic_delay'  # dendritic delay
                self.S_input2aux[ih].penalty = self.penalty_input2aux[ih]
                self.S_input2aux[ih].w_sat_scale = 0
                self.S_input2aux[ih].w_sat_shift = 0
        else:
            self.load_synapses_between_input_aux()


    # read trained synapse data from disk (merging several ranks for some test options)
    def compile_read_syn_data(self, syn_path, ih_pre = None, ih_post = None):
        """Read the trained synapse data stored under ``syn_path``.

        Which rank(s) are read depends on ``self.test_option``:

        * ``None``: this worker's own rank.
        * ``'use_rank0_info'``: rank 0.
        * ``'add_more_syn'``: ranks ``0 .. n_syn-1``, concatenated (task must be test/advatk/adv_make).
        * any option containing ``'add_more_neuron'``: every worker's data, concatenated, with
          neuron indices shifted by ``rank * self.N_e[layer]`` so the ranks don't overlap
          (task must be test/advatk/adv_make).

        Args:
            syn_path: path passed to ``read_data``; its existence is checked first.
            ih_pre: pre-synaptic hidden-layer index, used only by ``add_more_neuron`` for
                non-``S_input2e`` paths.
            ih_post: post-synaptic hidden-layer index, used only by ``add_more_neuron``.

        Returns:
            The data returned by ``read_data`` (or the ``concatenate_dict_arrays`` merge of
            several ranks). Returns None if ``syn_path`` does not exist or any required rank's
            data is missing.

        Exits the process (``sys.exit``) for an unknown option or a disallowed task.
        """
        if not os.path.exists(syn_path):
            return None

        eval_tasks = ('test', 'advatk', 'adv_make')
        if self.test_option is None:
            trained_syn = read_data(syn_path, self.rank)
        elif self.test_option == 'use_rank0_info':
            trained_syn = read_data(syn_path, 0)
        elif self.test_option == "add_more_syn":  # add more syns between a pair of pre- and post-neurons
            if self.task not in eval_tasks:
                sys.exit(f'!!!! self.test_option != None in task: {self.task}, but it is required to be in task: test or advatk or adv_make !!!')

            list_dict = [read_data(syn_path, i) for i in range(self.n_syn)]
            trained_syn = self._concat_syn_dicts(list_dict, syn_path)
        elif "add_more_neuron" in self.test_option:
            if self.task not in eval_tasks:
                sys.exit(f'!!!! self.test_option != None in task: {self.task}, but it is required to be in task: test or advatk  or adv_make !!!')

            list_dict = []
            for ir in range(self.num_workers):
                syn = read_data(syn_path, ir)
                if not syn:
                    # data of this rank is missing/empty: the merged result would be incomplete
                    list_dict.append(None)
                    break
                # change the index of pre-post neurons by i+rank*self.N_e[ih]
                if 'S_input2e' in syn_path:  # the number of input (pre) neurons is fixed
                    syn['j'] = syn['j'] + ir*self.N_e[ih_post]
                else:  # both pre and post neurons are increased
                    syn['i'] = syn['i'] + ir*self.N_e[ih_pre]
                    syn['j'] = syn['j'] + ir*self.N_e[ih_post]
                list_dict.append(syn)

            trained_syn = self._concat_syn_dicts(list_dict, syn_path)
        else:
            sys.exit(f' !!!!  test option: {self.test_option} does not exist !!!!')

        # debug dump of the data actually used outside training (nothing to dump if missing)
        if self.debug == True and self.task != 'train' and trained_syn is not None:
            write_data(pd.DataFrame.from_dict(trained_syn), syn_path, -1)

        return trained_syn

    def _concat_syn_dicts(self, list_dict, syn_path):
        """Concatenate per-rank synapse dicts with ``concatenate_dict_arrays``, or return None if any is missing."""
        if any(d is None for d in list_dict):
            if self.debug == True:
                print(f'!!!! trained_syn at {syn_path} is None !!!!')
            return None
        return concatenate_dict_arrays(list_dict)
            
    # load synapses
    def load_synapses_between_input_hidden(self):
        """Connect every ``self.S_input2e[ih]`` from the trained data stored in ``self.dir_syn_in``.

        The data is read from ``S_input2e{ih}`` and restores indices, weights and dendritic
        delays. In training, ``gmax``, penalty, weight saturation and the plastic pathway delay
        are set. Otherwise, only the non-plastic pathway delay is set. Groups whose data cannot
        be read get ``connect(False)`` (no synapses).
        """
        for ih, syn in enumerate(self.S_input2e):
            path = os.path.join(self.dir_syn_in, f'S_input2e{ih}')
            data = self.compile_read_syn_data(path, ih_post=ih)

            if self.debug == True:
                print(f'--loading trained synapses between input and hiddenlayer {ih}: ', path)

            if data is None:
                syn.connect(False)
                continue

            syn.connect(i=data['i'], j=data['j'])
            syn.w[:] = data['w']
            syn.post.delay[:] = data['delay_dendritic'] * brian2.second
            if self.task != 'train':
                syn.pre_nonplastic.delay[:] = data['delay'] * brian2.second
                continue

            syn.gmax = self.gmax_input2e[ih]  # needed so weights are not clipped in the pre-action
            syn.pre_plastic.delay[:] = data['delay'] * brian2.second
            syn.penalty = self.penalty_input2e[ih]
            syn.w_sat_scale = self.w_sat_scale
            syn.w_sat_shift = self.w_sat_shift
            
    # load synapses
    def load_synapses_between_input_aux(self):
        """Connect every ``self.S_input2aux[ih]`` from the trained data stored in ``self.dir_syn_in``.

        The data is read from ``S_input2aux{ih}``. In training, the input->aux parameters are
        applied (``gmax_input2aux``, ``penalty_input2aux``), with weight saturation disabled
        (``w_sat_scale = w_sat_shift = 0``). These are the same values that
        ``build_synapses_between_input_aux`` uses when building from scratch. Groups whose data
        cannot be read get ``connect(False)`` (no synapses).
        """
        for ih in range(len(self.S_input2aux)):
            trained_S_input2aux_path = os.path.join(self.dir_syn_in, 'S_input2aux'+str(ih))
            trained_S_input2aux = self.compile_read_syn_data(trained_S_input2aux_path, ih_post = ih)

            if self.debug == True:
                print(f'--loading trained synapses between input and auxiliary layer {ih}: ', trained_S_input2aux_path)

            if trained_S_input2aux is not None:
                self.S_input2aux[ih].connect(i=trained_S_input2aux['i'], j=trained_S_input2aux['j'])
                self.S_input2aux[ih].w[:] = trained_S_input2aux['w']
                self.S_input2aux[ih].post.delay[:] = trained_S_input2aux['delay_dendritic']*brian2.second
                if self.task == 'train':
                    # the aux gmax (not the excitatory one) must be set to avoid clipping in the pre-action
                    self.S_input2aux[ih].gmax = self.gmax_input2aux[ih]
                    self.S_input2aux[ih].pre_plastic.delay[:] = trained_S_input2aux['delay']*brian2.second
                    self.S_input2aux[ih].penalty = self.penalty_input2aux[ih]
                    self.S_input2aux[ih].w_sat_scale = 0
                    self.S_input2aux[ih].w_sat_shift = 0
                else:
                    self.S_input2aux[ih].pre_nonplastic.delay[:] = trained_S_input2aux['delay']*brian2.second
            else:
                self.S_input2aux[ih].connect(False)
    
    def build_synapses_between_hidden(self):
        """Create the hidden -> hidden excitatory synapse groups ``self.S_efe``.

        ``self.S_efe[ih-1][jh]`` connects excitatory layer ``jh`` to layer ``ih`` for every
        ``jh < ih``, so each hidden layer receives input from all earlier hidden layers. On the
        very first training segment, neurons are connected when their ``group`` and ``rank``
        match, with random weights and delays. Otherwise, the groups are loaded with
        ``load_synapses_between_hidden``.
        """
        self.S_efe = [
            [
                brian2.Synapses(self.exci_neurons[jh], self.exci_neurons[ih], self.eqs_syn,
                                on_pre=self.STDP_pre_action, on_post=self.STDP_post_action,
                                name='s_efe' + str(jh) + "to" + str(ih))
                for jh in range(ih)
            ]
            for ih in range(1, self.N_hidden)
        ]

        if not (self.idx_start == 0 and self.epoch == 0 and self.task == 'train'):
            # synapses have been created in a previous round
            self.load_synapses_between_hidden()
            return

        if self.debug == True:
            print(' building (from scratch) synapses between hidden layer ...')
        for i, row in enumerate(self.S_efe):
            for syn in row:
                syn.connect(condition='group_pre==group_post and rank_pre==rank_post')
                syn.gmax = self.gmax_efe[i]
                syn.max_delay = self.max_delay_efe[i]
                syn.max_dendritic_delay = self.max_dendritic_delay
                syn.w = 'rand() * gmax'
                syn.pre_plastic.delay = 'rand()* max_delay'
                syn.post.delay = 'rand()* max_dendritic_delay'
                syn.penalty = self.penalty_efe[i]
                syn.w_sat_scale = self.w_sat_scale
                syn.w_sat_shift = self.w_sat_shift

    def build_synapses_between_aux(self):
        """Create the auxiliary -> auxiliary synapse groups ``self.S_auxfaux``.

        ``self.S_auxfaux[ih-1][jh]`` connects auxiliary layer ``jh`` to layer ``ih`` for every
        ``jh < ih``. On the very first training segment, neurons are connected when their
        ``group`` and ``rank`` match, with random weights and delays and weight saturation
        disabled (``w_sat_scale = w_sat_shift = 0``). Otherwise, the groups are loaded with
        ``load_synapses_between_aux``.
        """
        self.S_auxfaux = []  # current auxiliary layer to later auxiliary layers
        for ih in range(1, self.N_hidden):
            deep_connection = []  # the ih layer is connected to all earlier layers
            for jh in range(ih):
                deep_connection.append(brian2.Synapses(self.aux_neurons[jh], self.aux_neurons[ih], self.eqs_syn, on_pre=self.STDP_pre_action,
                                                       on_post=self.STDP_post_action, name='s_auxfaux'+str(jh)+"to"+str(ih)))
            self.S_auxfaux.append(deep_connection)

        if self.idx_start == 0 and self.epoch == 0 and self.task == 'train':
            if self.debug == True:
                print(' building (from scratch) synapses between auxiliary layer ...')
            for i in range(len(self.S_auxfaux)):
                # iterate over this group's own row (S_efe need not exist or match here)
                for j in range(len(self.S_auxfaux[i])):
                    self.S_auxfaux[i][j].connect(condition='group_pre==group_post and rank_pre==rank_post')
                    self.S_auxfaux[i][j].gmax = self.gmax_auxfaux[i]
                    self.S_auxfaux[i][j].max_delay = self.max_delay_auxfaux[i]
                    self.S_auxfaux[i][j].max_dendritic_delay = self.max_dendritic_delay
                    self.S_auxfaux[i][j].w = 'rand() * gmax'
                    self.S_auxfaux[i][j].pre_plastic.delay = 'rand()* max_delay'
                    self.S_auxfaux[i][j].post.delay = 'rand()* max_dendritic_delay'
                    self.S_auxfaux[i][j].penalty = self.penalty_auxfaux[i]
                    self.S_auxfaux[i][j].w_sat_scale = 0
                    self.S_auxfaux[i][j].w_sat_shift = 0
        else:  # synapses have been created in a previous round
            self.load_synapses_between_aux()

    # ## load connections among hidden layers
    def load_synapses_between_hidden(self):
        """Load the trained hidden -> hidden synapses ``self.S_efe`` from ``self.dir_syn_in``.

        ``self.S_efe[i][j]`` (hidden layer ``j`` -> hidden layer ``i+1``) is restored from
        ``S_efe{j}to{i+1}``, the name written by ``save_synapses``. This is the loader that
        ``build_synapses_between_hidden`` calls, so it must operate on ``S_efe``. Groups whose
        data cannot be read get ``connect(False)`` (no synapses).
        """
        for i in range(len(self.S_efe)):
            for j in range(len(self.S_efe[i])):
                trained_S_efe_path = os.path.join(self.dir_syn_in, 'S_efe'+str(j)+"to"+str(i+1))
                trained_S_efe = self.compile_read_syn_data(trained_S_efe_path, ih_pre=j, ih_post=i+1)

                if self.debug == True:
                    print(f'-- loading trained synapses between hidden layers {j} and {i+1}: ', trained_S_efe_path)

                if trained_S_efe is not None:
                    self.S_efe[i][j].connect(i=trained_S_efe['i'], j=trained_S_efe['j'])
                    self.S_efe[i][j].w[:] = trained_S_efe['w']
                    self.S_efe[i][j].post.delay[:] = trained_S_efe['delay_dendritic']*brian2.second
                    if self.task == 'train':
                        self.S_efe[i][j].gmax = self.gmax_efe[i]  # needed to avoid weights being clipped in the pre-action
                        self.S_efe[i][j].pre_plastic.delay[:] = trained_S_efe['delay']*brian2.second
                        self.S_efe[i][j].penalty = self.penalty_efe[i]
                        self.S_efe[i][j].w_sat_scale = self.w_sat_scale
                        self.S_efe[i][j].w_sat_shift = self.w_sat_shift
                    else:
                        self.S_efe[i][j].pre_nonplastic.delay[:] = trained_S_efe['delay']*brian2.second
                else:
                    self.S_efe[i][j].connect(False)
                    
    # ## load connections among auxiliary layers
    def load_synapses_between_aux(self):
        """Load the trained auxiliary -> auxiliary synapses ``self.S_auxfaux`` from ``self.dir_syn_in``.

        ``self.S_auxfaux[i][j]`` (aux layer ``j`` -> aux layer ``i+1``) is restored from
        ``S_auxfaux{j}to{i+1}``. This is the loader that ``build_synapses_between_aux`` calls,
        so it must operate on ``S_auxfaux``. In training, weight saturation stays disabled
        (``w_sat_scale = w_sat_shift = 0``) as when building from scratch. Groups whose data
        cannot be read get ``connect(False)`` (no synapses).
        """
        for i in range(len(self.S_auxfaux)):
            for j in range(len(self.S_auxfaux[i])):
                trained_S_auxfaux_path = os.path.join(self.dir_syn_in, 'S_auxfaux'+str(j)+"to"+str(i+1))
                trained_S_auxfaux = self.compile_read_syn_data(trained_S_auxfaux_path, ih_pre=j, ih_post=i+1)

                if self.debug == True:
                    print(f'-- loading trained synapses between auxiliary layers {j} and {i+1}: ', trained_S_auxfaux_path)

                if trained_S_auxfaux is not None:
                    self.S_auxfaux[i][j].connect(i=trained_S_auxfaux['i'], j=trained_S_auxfaux['j'])
                    self.S_auxfaux[i][j].w[:] = trained_S_auxfaux['w']
                    self.S_auxfaux[i][j].post.delay[:] = trained_S_auxfaux['delay_dendritic']*brian2.second
                    if self.task == 'train':
                        self.S_auxfaux[i][j].gmax = self.gmax_auxfaux[i]  # needed to avoid weights being clipped in the pre-action
                        self.S_auxfaux[i][j].pre_plastic.delay[:] = trained_S_auxfaux['delay']*brian2.second
                        self.S_auxfaux[i][j].penalty = self.penalty_auxfaux[i]
                        self.S_auxfaux[i][j].w_sat_scale = 0
                        self.S_auxfaux[i][j].w_sat_shift = 0
                    else:
                        self.S_auxfaux[i][j].pre_nonplastic.delay[:] = trained_S_auxfaux['delay']*brian2.second
                else:
                    self.S_auxfaux[i][j].connect(False)
                    
                    
    # monitors
    def set_monitors(self):
        """Set up the parent's monitors and, for debug training runs, a penalty monitor.

        The extra monitor is ``self.S_input2e_mon``: a StateMonitor recording ``penalty`` of
        synapse 0 of ``self.S_input2e[0]``. It is only created when ``debug`` is True and
        ``task == 'train'``.
        """
        super().set_monitors()
        if self.debug == True and self.task == 'train':
            self.S_input2e_mon = brian2.StateMonitor(self.S_input2e[0], ['penalty'], record=[0])
    
    # need to collect all objects so that Brian2 know what to run. The magic option doesn't work with self.
    def collect(self):
        """Add this class's synapse groups (and the debug monitor) to the Brian2 collection ``self.col``.

        The objects are added explicitly because Brian2's "magic" object collection does not
        find objects stored as attributes of ``self``. The auxiliary groups are added only when
        ``self.switch_aux`` is True.
        """
        super().collect()
        self.col.add(self.S_input2e)
        self.col.add(self.S_efe)
        if self.switch_aux == True:
            # the input->aux groups are stored as `S_input2aux` by build_synapses_between_input_aux;
            # no `S_inputfaux` attribute is assigned in the visible code of this class
            self.col.add(self.S_input2aux)
            self.col.add(self.S_auxfaux)
        if self.task == 'train' and self.debug == True:  # the monitor only exists in debug training
            self.col.add(self.S_input2e_mon)
                    
    # save the neurons. 
    def save_neurons(self):
        """Save the neuron groups; delegates entirely to the parent implementation."""
        super().save_neurons()
                
    # save the trained synapses. 
    def save_synapses(self):
        """Write this rank's trained input->hidden and hidden->hidden synapses to ``self.dir_syn_out``.

        After the parent's synapses are saved, each group is written with ``write_data`` as a
        DataFrame. The columns are ``i``, ``j``, ``w``, ``delay`` (plastic pre-pathway delay) and
        ``delay_dendritic`` (post-pathway delay). The file names are ``S_input2e{ih}`` and
        ``S_efe{j}to{i+1}``.
        """
        super().save_synapses()

        if self.debug == True:
            print('saving to: ', self.dir_syn_out)

        def to_frame(syn):
            return pd.DataFrame({'i': syn.i[:],
                                 'j': syn.j[:],
                                 'w': syn.w[:],
                                 'delay': syn.pre_plastic.delay[:],
                                 'delay_dendritic': syn.post.delay[:]})

        for ih, syn in enumerate(self.S_input2e):
            write_data(to_frame(syn), os.path.join(self.dir_syn_out, f'S_input2e{ih}'), self.rank)

        for i, row in enumerate(self.S_efe):
            for j, syn in enumerate(row):
                write_data(to_frame(syn), os.path.join(self.dir_syn_out, f'S_efe{j}to{i + 1}'), self.rank)
            
    # plot the result, slow and mostly for debugging purpose
    def plot(self):
        """Draw the parent's debug plots plus, in training, the penalty trace of ``S_input2e[0]``.

        Slow and mostly meant for debugging. The penalty trace needs ``self.S_input2e_mon``,
        which ``set_monitors`` only creates for debug training runs; without it the extra plot
        is skipped instead of raising AttributeError.
        """
        super().plot()
        # (A former loop here only drew random neuron indices that were never used.)
        if self.task == 'train' and hasattr(self, 'S_input2e_mon'):
            brian2.plt.figure(figsize=(10, 10))
            brian2.plot(self.S_input2e_mon.t/brian2.ms, self.S_input2e_mon.penalty.T)
            brian2.plt.xlabel('Time (ms)')
            brian2.plt.ylabel('S_input2e penalty')
                
    # set active = True/False for pre and post in syn for training and testing
    def set_syn_active_status(self):
        """Enable the plastic pathways for training and the fixed-weight pathway otherwise.

        This applies to every group in ``S_input2e`` and ``S_efe``, after the parent has set
        its own groups. In training, ``pre_plastic`` and ``post`` are active and
        ``pre_nonplastic`` is inactive; for any other task it is the reverse.
        """
        super().set_syn_active_status()

        training = self.task == 'train'
        groups = list(self.S_input2e) + [syn for row in self.S_efe for syn in row]
        for syn in groups:
            syn.pre_nonplastic.active = not training  # fixed weight, for testing only
            syn.pre_plastic.active = training         # weight changing, for training
            syn.post.active = training                # weight changing, for training

    # normalize weight after every sample
    def syn_weight_normalization(self):
        """Normalize the weights of ``S_input2e`` and ``S_efe`` (done after every sample).

        Each group is passed to ``get_matrix_from_synapse_and_normalize`` together with its
        ``gmax`` and normalisation scale.

        Returns:
            tuple: ``(S_input2e_w, S_efe_w)``. ``S_input2e_w`` is a list with one result per
            input->hidden group; ``S_efe_w`` is a nested list mirroring ``self.S_efe``.
        """
        if self.debug == True:
            print('...rank: ', self.rank, ' normalizing weight of S_input2e')
        input_weights = [
            get_matrix_from_synapse_and_normalize(syn, self.gmax_input2e[ih], self.norm_scale_S_input2e[ih])
            for ih, syn in enumerate(self.S_input2e)
        ]

        if self.debug == True:
            print('...rank: ', self.rank, ' normalizing weight of syn between hidden layers')
        hidden_weights = [
            [get_matrix_from_synapse_and_normalize(syn, self.gmax_efe[i], self.norm_scale_S_efe[i]) for syn in row]
            for i, row in enumerate(self.S_efe)
        ]

        return input_weights, hidden_weights
        
    # check if the last run for the sample meet the requirements
    def check_sample_status(self, img, label, stimulus_strength):
        """Decide whether the current sample is finished and update the repetition state.

        ``img`` is not used. A sample is finished when every hidden layer fired at least
        ``input_param.n_spike_sample_min`` spikes, or once ``self.repetitions`` reaches
        ``input_param.max_repetitions``. In training, the spike-counter group with index
        ``label`` must additionally be the unique maximum. While the sample is not finished,
        ``self.repetitions`` is incremented and the strength grows by
        ``input_param.d_scale_img``. When it switches, the counter is reset to 0 and the
        strength to ``input_param.scale_img``. In training, the fake first sample
        (``label == -1``) is always accepted with the strength unchanged.

        Returns:
            tuple: ``(switch_stimulus, stimulus_strength)``.
        """
        layer_counts = [getattr(self.spike_counter, f'spike_counter_layer{ih}') for ih in range(self.N_hidden)]
        higher_than_min_spike_rate = all(count >= input_param.n_spike_sample_min for count in layer_counts)

        if self.task != 'train':
            switch_stimulus = higher_than_min_spike_rate or (self.repetitions >= input_param.max_repetitions)
        else:
            if label == -1:  # the fake first event, used to read attributes before the real run
                return True, stimulus_strength

            counts_grps = [getattr(self.spike_counter, f'spike_counter_grp{ig}') for ig in range(input_param.num_classes)]
            rate_correct_group = counts_grps[label]  # spike count of the group whose id equals the label
            # correct only if the label's group is the maximum and no other group ties with it
            correct_identification = (rate_correct_group == max(counts_grps)) and (counts_grps.count(rate_correct_group) == 1)
            switch_stimulus = higher_than_min_spike_rate and correct_identification or (self.repetitions >= input_param.max_repetitions)

        keep_sample = int(not switch_stimulus)
        self.repetitions = keep_sample * (self.repetitions + 1)
        stimulus_strength = input_param.scale_img * int(switch_stimulus) + (stimulus_strength + input_param.d_scale_img) * keep_sample

        if self.debug==True:
            message = f'''
            higher_than_min_spike_rate: {higher_than_min_spike_rate} 
            repetitions: {self.repetitions}
            '''
            if self.task == 'train': 
                message += f'''
                rate_correct_group: {rate_correct_group}
                rate of_all_grps: {counts_grps}
                correct_identification: {correct_identification}
                '''
            print(message)

        return switch_stimulus, stimulus_strength
        
    # run the model 
    def run(self, input_data):
        """Present every sample of ``input_data`` to the network and score the last hidden layer.

        A fake stimulus (zero image, label -1) is run first. Its outputs are read back
        so they can be passed to the following runs through ``run_args``:

        * the hidden-layer thresholds ``vt``;
        * when training, the synaptic weights.

        Each real sample is re-run with the strength returned by
        ``check_sample_status`` until that method accepts it. The spike counts of the
        last hidden layer are then recorded.

        Args:
            input_data: forwarded to ``self.manage_data``, which returns
                ``(img_array, label_array)``.

        Returns:
            ``None`` if the model was not created successfully. Otherwise a tuple
            ``(list_counts_for_digit_last_layer, eff_last_layer, eff_mult_match_last_layer)``:

            * ``list_counts_for_digit_last_layer`` holds one ``{'label', 'count'}``
              dict per sample.
            * ``eff_last_layer`` is the fraction of samples whose label entry equals
              the maximum count.
            * ``eff_mult_match_last_layer`` is the fraction of samples that matched
              while sharing that maximum with another digit.

            Both fractions are 0.0 for empty input.
        """
        # if a model is not successfully created, don't run
        if not self.model_success:
            if self.debug:
                print('!!! model_success == False !!! ')
            return

        img_array, label_array = self.manage_data(input_data)
        data_in = [[img, label] for img, label in zip(img_array, label_array)]
        # n_spikes for each digit in the last h-layer, one entry per real sample
        list_counts_for_digit_last_layer = [None] * len(data_in)
        # fake sample run first so that one can normalize syn before running any new samples
        data_in.insert(0, [0, -1])

        # state read back after every brian2 run and passed to the next one via run_args
        vt = [None] * self.N_hidden
        S_input2e_w, S_efe_w = None, None
        for stimulus_idx, (img, label) in enumerate(tqdm_ray.tqdm(data_in)):
            switch_stimulus = False
            stimulus_strength = input_param.scale_img  # starting strength
            self.repetitions = 0  # num of repetitions for one sample
            while not switch_stimulus:
                scaled_img = img * stimulus_strength
                arguments = {
                    self.input_neurons.rate: scaled_img * brian2.Hz,
                    self.input_neurons.label: label,
                    self.input_neurons.stimulus_idx: stimulus_idx,
                    self.input_neurons.stimulus_strength: stimulus_strength,
                }
                for ih in range(self.N_hidden):
                    arguments.update({
                        self.exci_neurons[ih].stimulus_idx: stimulus_idx,
                        self.exci_neurons[ih].stimulus_strength: stimulus_strength,
                        self.exci_neurons[ih].label: label,
                    })

                if label != -1:  # skip the fake stimulus
                    # restore the thresholds of every hidden layer, not just the last one
                    for ih in range(self.N_hidden):
                        arguments[self.exci_neurons[ih].vt] = vt[ih][:]

                    if self.task == 'train':
                        for ih, syn in enumerate(self.S_input2e):
                            arguments[syn.w] = S_input2e_w[ih][:]
                        for i, syn_row in enumerate(self.S_efe):
                            for j, syn in enumerate(syn_row):
                                arguments[syn.w] = S_efe_w[i][j][:]

                brian2.device.run(run_args=arguments, with_output=False)  # process the sample

                for ih in range(self.N_hidden):
                    vt[ih] = self.exci_neurons[ih].vt[:]

                if self.task == 'train':
                    if self.switch_norm:
                        S_input2e_w, S_efe_w = self.syn_weight_normalization()
                    else:
                        # Still carry the current weights forward (same pattern as vt).
                        # Otherwise the weight names used above are unbound on the next run.
                        S_input2e_w = [syn.w[:] for syn in self.S_input2e]
                        S_efe_w = [[syn.w[:] for syn in syn_row] for syn_row in self.S_efe]

                switch_stimulus, stimulus_strength = self.check_sample_status(img, label, stimulus_strength)
                if self.debug:
                    print(f'stimulus_idx: {stimulus_idx}, switch_stimulus: {switch_stimulus}, stimulus_strength: {stimulus_strength}, label: {label}')

            # now the running of the stimulus meet the requirement, count the spikes for each digit
            if stimulus_idx != 0:  # the first event is fake
                list_counts_for_digit_last_layer[stimulus_idx - 1] = {'label': label, 'count': self.count_last_hlayer_spikes_from_spikemon()}

        # save synapse, incl. connection, weight, pre and post index for training
        if self.task == 'train':
            self.save_synapses()
            self.save_neurons()

        if self.debug:
            print(brian2.profiling_summary(self.col))
            print(self.col.scheduling_summary())

        n_samples = len(list_counts_for_digit_last_layer)
        eff_last_layer = 0
        eff_mult_match_last_layer = 0
        for sidx, label_count in enumerate(list_counts_for_digit_last_layer):
            counts = label_count['count']
            label = label_count['label']
            max_value = max(counts)
            match = (counts[label] == max_value)
            n_max = counts.count(max_value)  # how many digits share the maximum
            if match:
                eff_last_layer += 1
                if n_max > 1:
                    eff_mult_match_last_layer += 1

            if self.root_out:  # output ROOT files
                for digit, n_spikes in enumerate(counts):
                    self.nt.Fill(sidx, label, digit, n_spikes, float(match), n_max, self.rank)

        # guard against empty input, which would otherwise divide by zero before the
        # ROOT file is closed
        eff_last_layer = eff_last_layer / n_samples if n_samples else 0.0
        eff_mult_match_last_layer = eff_mult_match_last_layer / n_samples if n_samples else 0.0

        if self.root_out:
            self.nt.Write()
            self.root_file.Close()

        return list_counts_for_digit_last_layer, eff_last_layer, eff_mult_match_last_layer