In [20]:
# Brian-setup
from brian2 import BrianLogger, start_scope, NeuronGroup, Synapses, network_operation, run, StateMonitor, SpikeMonitor, stop, rand
from brian2.units import ms, volt, mV, amp, mA, siemens, Hz, ohm
from brian2 import figure, grid, axhline, xlabel, ylabel, plot, legend

from brian2 import prefs, BrianLogger
prefs.logging.file_log_level = 'WARNING'
BrianLogger.file_handler.setLevel(prefs.logging.file_log_level)

# Scientific libs:
import numpy as np
%matplotlib inline

In [2]:
def _draw_random_wiring(n_pre, pre_offset, n, n_excit, n_out=10):
    """Draw `n_out` random postsynaptic targets for every neuron of one presynaptic group.

    Targets are drawn uniformly (with replacement) from the whole population of `n`
    neurons. Global indices ``0..n_excit-1`` are excitatory and ``n_excit..n-1`` are
    inhibitory. A neuron is never wired to itself.

    Parameters
    ----------
    n_pre : int
        Size of the presynaptic group.
    pre_offset : int
        Global index of the group's first neuron (0 for excitatory, n_excit for inhibitory).
    n : int
        Total number of neurons.
    n_excit : int
        Number of excitatory neurons.
    n_out : int
        Number of outgoing connections per presynaptic neuron.

    Returns
    -------
    ((i_to_exc, j_to_exc), (i_to_inh, j_to_inh))
        Local presynaptic / postsynaptic index lists for the connections that end in
        the excitatory group and in the inhibitory group, respectively.

    Raises
    ------
    ValueError
        If n < 2 (a neuron would have no valid non-self target).
    """
    if n < 2:
        raise ValueError('need at least 2 neurons to draw non-self connections, got n={}'.format(n))
    i_to_exc, j_to_exc, i_to_inh, j_to_inh = [], [], [], []
    for ind in range(n_pre):
        own_global_index = pre_offset + ind
        for _ in range(n_out):
            target_index = int(rand() * n)
            # Redraw self-connections. Compare GLOBAL indices so this is also correct
            # for the inhibitory group, whose local index differs from its global one.
            while target_index == own_global_index:
                target_index = int(rand() * n)
            if target_index < n_excit:
                i_to_exc.append(ind)
                j_to_exc.append(target_index)
            else:
                i_to_inh.append(ind)
                j_to_inh.append(target_index - n_excit)
    return (i_to_exc, j_to_exc), (i_to_inh, j_to_inh)


def _connect_pairs(synapses, pre_indices, post_indices):
    """Create all (pre, post) synapses with a single `connect` call; no-op for an empty list."""
    if pre_indices:
        synapses.connect(i=np.asarray(pre_indices), j=np.asarray(post_indices))


def perform_network_firing_rate_test(a, b, ws=10, n=500, tau_g=5*ms, I_ext=0.1*mA, verbose=True, do_plot=True,
                                    debug=False):
    """Simulate a random excitatory/inhibitory Izhikevich-type network and measure its firing rates.

    The network has ``int(0.8*n)`` excitatory and ``n - int(0.8*n)`` inhibitory neurons.
    Every neuron projects to 10 randomly drawn neurons, with no self-connections.
    Once per ``tau`` (1 ms), one uniformly drawn neuron receives an external current
    pulse of amplitude `I_ext`. After a 1 s warm-up, spikes are counted for 8 s.

    Parameters
    ----------
    a, b : float
        Recovery-variable parameters used in ``du/dt = a * (b*v - u)/tau``.
    ws : float
        Synaptic weight scale; a synapse contributes ``ws*mV * g`` to its target's current.
    n : int
        Total number of neurons.
    tau_g : Quantity (time)
        Decay time constant of the synaptic conductances.
    I_ext : Quantity (current)
        Amplitude of the external pulse given to one random neuron per ``tau``.
    verbose : bool
        Currently unused; kept for call compatibility.
    do_plot : bool
        Record membrane potentials and plot them (skipped when the run looks chaotic).
    debug : bool
        Print group sizes, spike counts and counts of recently active synapses.

    Returns
    -------
    [rate_excit, rate_inhib] : list of two Quantities (Hz)
        Spike count of each group during the 8 s recording phase, divided by the
        group size and the recording duration (i.e. the mean per-neuron rate).
    """
    start_scope()

    tau = 1*ms
    # mfarad = 10^-3* s / ohm = ms/ohm  => I_tot / mfarad = I_tot * ohm/ms
    # I_ext_e / I_ext_i carry the external stimulus. It must NOT be written into
    # I_tot_e_e / I_tot_e_i: those are targets of (summed) synaptic variables, which
    # Brian recomputes from the synapses every time step, discarding the stimulus.
    eqs_exc = '''
    dv/dt = ((0.04*v**2)/mV + 5*v + 140*mV - u + I_tot_e*ohm) / tau : volt
    du/dt = a * (b*v - u)/tau : volt
    c : volt
    d : volt

    I_tot_e = I_tot_e_e + I_tot_i_e + I_ext_e : amp

    I_tot_e_e : amp
    I_tot_i_e : amp
    I_ext_e : amp

    oper_ctr : 1
    '''

    c_inh = -65*mV
    d_inh = 2*mV
    eqs_inh = '''
    dv/dt = ((0.04*v**2)/mV + 5*v + 140*mV - u + I_tot_i*ohm) / tau : volt
    du/dt = a * (b*v - u)/tau : volt

    I_tot_i = I_tot_e_i + I_tot_i_i + I_ext_i : amp

    I_tot_e_i : amp
    I_tot_i_i : amp
    I_ext_i : amp
    '''

    n_excit = int(0.8 * n)
    n_inhib = n - n_excit
    if debug:
        print('n_excit', n_excit)
        print('n_inhib', n_inhib)

    G_excit = NeuronGroup(n_excit, eqs_exc, threshold='v>30*mV',
                          reset='''v=c; u=u+d;''',
                          method='euler')
    G_excit.c = '-65*mV + 15*(rand()**2)*mV'
    G_excit.d = '8*mV - 6*(rand()**2)*mV'
    G_excit.v = G_excit.c

    G_inhib = NeuronGroup(n_inhib, eqs_inh, threshold='v>30*mV',
                          reset='''v=c_inh; u=u+d_inh;''',
                          method='euler')
    G_inhib.v = c_inh

    # Counter incremented by the network operation; checked after each run.
    G_excit.oper_ctr[0] = 0

    # Synapses: 4 groups (e->e, e->i, i->e, i->i). A presynaptic spike sets the
    # conductance to +1 S (excitatory) or -1 S (inhibitory); it then decays with tau_g.
    eqs_syn_e_e = '''
    dg1/dt = -g1/(tau_g) : siemens (clock-driven)
    I_tot_e_e_post = ws*mV * g1 : amp (summed)
    '''
    eqs_syn_e_i = '''
    dg2/dt = -g2/(tau_g) : siemens (clock-driven)
    I_tot_e_i_post = ws*mV * g2 : amp (summed)
    '''
    eqs_syn_i_e = '''
    dg3/dt = -g3/(tau_g) : siemens (clock-driven)
    I_tot_i_e_post = ws*mV * g3 : amp (summed)
    '''
    eqs_syn_i_i = '''
    dg4/dt = -g4/(tau_g) : siemens (clock-driven)
    I_tot_i_i_post = ws*mV * g4 : amp (summed)
    '''

    S1 = Synapses(G_excit, G_excit, eqs_syn_e_e, on_pre='g1 = 1*siemens', method='euler')
    S2 = Synapses(G_excit, G_inhib, eqs_syn_e_i, on_pre='g2 = 1*siemens', method='euler')
    S3 = Synapses(G_inhib, G_excit, eqs_syn_i_e, on_pre='g3 = -1*siemens', method='euler')
    S4 = Synapses(G_inhib, G_inhib, eqs_syn_i_i, on_pre='g4 = -1*siemens', method='euler')

    # Paper method: hard-wire every neuron to 10 random targets
    # (alternative: probabilistic S.connect(condition='i!=j', p=p_conn)).
    (i_ee, j_ee), (i_ei, j_ei) = _draw_random_wiring(n_excit, 0, n, n_excit)
    (i_ie, j_ie), (i_ii, j_ii) = _draw_random_wiring(n_inhib, n_excit, n, n_excit)
    # One connect() call per synapse group instead of one call per synapse.
    _connect_pairs(S1, i_ee, j_ee)
    _connect_pairs(S2, i_ei, j_ei)
    _connect_pairs(S3, i_ie, j_ie)
    _connect_pairs(S4, i_ii, j_ii)

    # Once per tau: an external pulse of I_ext to one neuron drawn uniformly from all n.
    @network_operation(dt=tau)
    def excite_one_random_neuron():
        # Clear the previous pulse so the stimulus lasts exactly one tau.
        G_excit.I_ext_e = 0*amp
        G_inhib.I_ext_i = 0*amp
        idx = int(rand() * n)  # 0..n-1 inclusive
        if idx < n_excit:
            G_excit.I_ext_e[idx] = I_ext
        else:
            G_inhib.I_ext_i[idx - n_excit] = I_ext

        G_excit.oper_ctr[0] += 1

    def assert_network_op_post_run(t):
        # The network operation runs once per tau, so after t*tau of simulated
        # time it must have run exactly t times.
        assert G_excit.oper_ctr[0] == t

        if debug:
            # Conductances live on the synapses (g1..g4), not on the neuron groups.
            threshold = 0.0001*siemens
            g_active = ([g for g in S1.g1[:] if g > threshold]
                        + [g for g in S2.g2[:] if g > threshold]
                        + [g for g in S3.g3[:] if g < -threshold]
                        + [g for g in S4.g4[:] if g < -threshold])
            print('t='+str(t)+', # recently active synapses: ', len(g_active))
            if len(g_active) < 10:
                print(g_active)

    # Warm-up: run for 1 second without recording.
    t1 = 1000
    run(t1*tau)
    assert_network_op_post_run(t1)

    # Start recording. Membrane traces of every neuron at every dt are large,
    # so they are only recorded when a plot was requested.
    statemon_excit = statemon_inhib = None
    if do_plot:
        statemon_excit = StateMonitor(G_excit, 'v', record=True)
        statemon_inhib = StateMonitor(G_inhib, 'v', record=True)
    spikemon_excit = SpikeMonitor(G_excit)
    spikemon_inhib = SpikeMonitor(G_inhib)

    t2 = 8000
    run(t2*tau)
    assert_network_op_post_run(t1+t2)

    # Mean per-neuron rate of each group: normalise by that group's own size.
    t_record = t2*tau
    avg_neuron_firing_rate_excit = spikemon_excit.num_spikes/(n_excit*t_record)
    avg_neuron_firing_rate_inhib = spikemon_inhib.num_spikes/(n_inhib*t_record)

    chaotic_behaviour = avg_neuron_firing_rate_excit > 1000.*Hz or avg_neuron_firing_rate_inhib > 1000.*Hz
    if do_plot and not chaotic_behaviour:
        figure(figsize=(9, 4))
        grid(True)
        axhline(30, ls='-', c='lightgray', lw=3)
        plot(statemon_excit.t/ms, statemon_excit.v.T/mV, '-')
        plot(statemon_inhib.t/ms, statemon_inhib.v.T/mV, '-')
        xlabel('Time (ms)')
        ylabel('v (mV)')

    if debug:
        print('#spikes; exc: {}, inh: {}'.format(spikemon_excit.num_spikes, spikemon_inhib.num_spikes))

    stop()

    return [avg_neuron_firing_rate_excit, avg_neuron_firing_rate_inhib]

In [4]:
def avg_network_firing_rate_over_N_runs(a, b, ws=10, n=500, tau_g=5*ms, I_ext=0.1*mA, verbose=False, do_plot=False, N=30,
                                       debug=False, log_fname_postfix=''):
    """Repeat `perform_network_firing_rate_test` N times and summarise the firing rates.

    Every step (header, each run, final summary) is appended to the log file through
    `write_to_logfile`. The summary line has the ``mean_e: .. std_e: .. mean_i: .. std_i ..``
    layout that `read_from_logfile` parses.

    Parameters
    ----------
    a, b, ws, n, tau_g, I_ext, do_plot, debug :
        Forwarded unchanged to `perform_network_firing_rate_test`.
    verbose : bool
        Echo the header and per-run lines to stdout; also forwarded to each run.
    N : int
        Number of independent simulation runs.
    log_fname_postfix : str
        Optional postfix for the log file name.

    Returns
    -------
    [mean_e, std_e, mean_i, std_i]
        Mean and standard deviation over the runs of the excitatory and inhibitory rates.
    """
    parameters = [a, b, ws, n, tau_g, I_ext]

    header = 'Running experiment over N={} runs for parametrisation: a={}, b={}, s={}, n={}, tau_g={}. with I={}'.format(
        N, a, b, ws, n, tau_g, I_ext)
    if verbose:
        print(header)
    write_to_logfile(parameters, header, log_fname_postfix)

    excit_rates = []
    inhib_rates = []
    for run_idx in range(N):
        rate_e, rate_i = perform_network_firing_rate_test(a, b, ws=ws, n=n, tau_g=tau_g, I_ext=I_ext,
                                                          verbose=verbose, do_plot=do_plot, debug=debug)
        excit_rates.append(rate_e)
        inhib_rates.append(rate_i)

        run_line = 'run #{}, cur_avg_e_rate: {}, cur_avg_i: {}'.format(run_idx, rate_e, rate_i)
        write_to_logfile(parameters, run_line, log_fname_postfix)
        if verbose:
            print(run_line)

    summary = [np.mean(excit_rates), np.std(excit_rates), np.mean(inhib_rates), np.std(inhib_rates)]

    summary_line = 'mean_e: {} std_e: {} mean_i: {} std_i {}'.format(*summary)
    write_to_logfile(parameters, summary_line, log_fname_postfix)
    print('parameters: {}, log_str: {}'.format(parameters, summary_line))
    return summary

In [None]:
# ================================= ANALYSIS: =========================================

In [1]:
import datetime as dt
from brian2 import figure, grid, axhline, xlabel, ylabel, plot, legend, errorbar, title

In [2]:
def write_to_logfile(params, log_str, opt_fname_postfix='', version='v6', results_dir='./results'):
    """Append one time-stamped, parameter-tagged line to a results log file.

    The file is ``<results_dir>/<version>[_<opt_fname_postfix>].txt``. The directory is
    created if it does not exist yet. Each written line has the form
    ``[<timestamp>] (a=.., b=.., ws=.., n=.., tau_g=.., I=..) <log_str>``, which is the
    layout `read_from_logfile` splits on ``'] ('``.

    Parameters
    ----------
    params : sequence
        ``[a, b, ws, n, tau_g, I_ext]``; only the first six entries are used.
    log_str : str
        Message appended after the parameter prefix.
    opt_fname_postfix : str, optional
        Appended to the file name as ``_<postfix>`` when non-empty.
    version : str, optional
        Base name of the log file.
    results_dir : str, optional
        Directory holding the log files (default ``./results``).
    """
    import os

    fname = version
    if opt_fname_postfix:
        fname += '_' + opt_fname_postfix
    fname = os.path.join(results_dir, fname + '.txt')

    # Opening in append mode fails if the directory is missing, so create it first.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    prefix = '[{}] (a={}, b={}, ws={}, n={}, tau_g={}, I={})'.format(dt.datetime.now(), *params[:6])
    with open(fname, 'a') as f:
        f.write(prefix + ' ' + log_str + '\n')

In [3]:
def _param_key(param_part):
    """Return the stripped name of one ``'name=value'`` fragment of a parameter string."""
    return param_part.split('=', 1)[0].strip()


def unwrap_var_value_from_param_str(var_str, params_str):
    """Return the numeric value of parameter `var_str` in a ``'a=.., b=.., ...'`` string.

    The name must match exactly; a bare substring match would confuse e.g. ``a``
    with ``tau_g``. A trailing unit is ignored, so ``'tau_g=5. ms'`` yields 5.0.

    Raises
    ------
    ValueError
        If `var_str` is not present or its value does not start with a number.
    """
    for part in params_str.split(','):
        if '=' in part and _param_key(part) == var_str:
            tokens = part.split('=', 1)[1].split()
            if not tokens:
                raise ValueError('parameter {!r} has an empty value in {!r}'.format(var_str, params_str))
            return float(tokens[0])
    raise ValueError('parameter {!r} not found in {!r}'.format(var_str, params_str))


def parser_helper_get_values_for(var_str, parameters):
    """Extract `var_str`'s value from every ``[datetime_str, params_str]`` entry of `parameters`."""
    return [unwrap_var_value_from_param_str(var_str, entry[1]) for entry in parameters]


def get_description_from_str_without(var_str, params_str):
    """Return `params_str` with the entry named exactly `var_str` removed, re-joined by ``', '``."""
    kept = [part.strip() for part in params_str.split(',') if _param_key(part) != var_str]
    return ', '.join(kept)

In [4]:
def read_from_logfile(fname, const_value, const_str='a', read_all=False):
    """Parse the summary lines (those containing ``mean_e``) of a results log.

    A summary line looks like
    ``[<timestamp>] (a=.., b=.., ...) mean_e: X std_e: Y mean_i: Z std_i W``.

    Parameters
    ----------
    fname : str
        Path of the log file.
    const_value : float
        Only lines whose `const_str` parameter equals this value are kept.
    const_str : str
        Name of the parameter compared against `const_value`.
    read_all : bool
        Keep every summary line; the `const_str` filter is then not evaluated.

    Returns
    -------
    [parameters, freq_exc, std_exc, freq_inh, std_inh, avg_rate, avg_std]
        `parameters` holds ``[timestamp_str, parameter_str]`` pairs. `avg_rate` and
        `avg_std` are the element-wise means of the excitatory and inhibitory columns.
    """
    with open(fname) as log_file:
        summary_lines = [line for line in log_file.read().split('\n') if 'mean_e' in line]

    parameters, freq_exc, std_exc, freq_inh, std_inh = [], [], [], [], []

    for line in summary_lines:
        pieces = line.split(' mean_e: ')
        stat_tokens = pieces[1].split(' ')
        head = pieces[0].split('] (')
        timestamp = head[0][1:]      # drop the leading '['
        param_str = head[1][:-1]     # drop the trailing ')'

        # `or` short-circuits: with read_all the parameter string is never parsed.
        if not (read_all or const_value == unwrap_var_value_from_param_str(const_str, param_str)):
            continue

        parameters.append([timestamp, param_str])
        # Token layout: [X, 'std_e:', Y, 'mean_i:', Z, 'std_i', W]
        freq_exc.append(float(stat_tokens[0]))
        std_exc.append(float(stat_tokens[2]))
        freq_inh.append(float(stat_tokens[4]))
        std_inh.append(float(stat_tokens[6]))

    assert len(freq_exc) == len(std_exc) == len(freq_inh) == len(std_inh)

    avg_rate = [(rate_e + rate_i) / 2. for rate_e, rate_i in zip(freq_exc, freq_inh)]
    avg_std = [(dev_e + dev_i) / 2. for dev_e, dev_i in zip(std_exc, std_inh)]

    return [parameters, freq_exc, std_exc, freq_inh, std_inh, avg_rate, avg_std]

In [21]:
def plot_parsed_data(data, var_str='b', const_str='a', const_val=0.02):
    """Plot excitatory, inhibitory and averaged firing rates (with std error bars) against one parameter.

    Parameters
    ----------
    data : list
        The 7-element list returned by `read_from_logfile`.
    var_str : str
        Name of the parameter shown on the x-axis.
    const_str : str
        Name of the parameter held fixed; only used in the x-axis label.
    const_val : float
        Value of the fixed parameter; only used in the x-axis label.
    """
    parameters, freq_exc, std_exc, freq_inh, std_inh, avg_rate, avg_std = data
    x_values = parser_helper_get_values_for(var_str, parameters)

    ax = figure(figsize=(9, 4)).gca()
    ax.set_ylabel('Average neuronal firing rate ($Hz$)')
    ax.set_xlabel('$' + var_str + '$-value, ($' + const_str + '=' + '{:3.2f}'.format(const_val) + '$)')
    ax.grid(True)

    series = [
        (freq_exc, std_exc, 'b--*'),
        (freq_inh, std_inh, 'g-o'),
        (avg_rate, avg_std, 'r--s'),
    ]
    for y_values, y_err, line_fmt in series:
        ax.errorbar(x_values, y_values, yerr=y_err, fmt=line_fmt)

    ax.legend(['Excitatory neurons', 'Inhibitory neurons', 'Population average'])
    subtitle = get_description_from_str_without(var_str, parameters[0][1])
    ax.set_title('effect of $' + var_str + '$ on average network firing rate,\n' + subtitle)

In [24]:
from mpl_toolkits.mplot3d import Axes3D

def plot_3d(data, var1_str, var2_str):
    """Scatter firing rates over two parameters in a 3-D plot.

    Parameters
    ----------
    data : list
        The 7-element list returned by `read_from_logfile`.
    var1_str, var2_str : str
        Names of the parameters used for the x- and y-axis.

    Raises
    ------
    ValueError
        If `data` contains no parsed runs (nothing to plot or describe).
    """
    parameters, freq_exc, std_exc, freq_inh, std_inh, avg_rate, avg_std = data
    if not parameters:
        raise ValueError('plot_3d: no parsed runs to plot')
    vals1 = parser_helper_get_values_for(var1_str, parameters)
    vals2 = parser_helper_get_values_for(var2_str, parameters)

    fig = figure(figsize=(9, 4))
    # Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and removed in 3.6;
    # add_subplot(projection='3d') is the supported way to create 3-D axes.
    ax = fig.add_subplot(projection='3d')
    ax.set_zlabel('Avg. firing rate ($Hz$)')
    ax.set_xlabel('$'+var1_str+'$-value')
    ax.set_ylabel('$'+var2_str+'$-value')
    ax.grid(True)

    ax.scatter(vals1, vals2, freq_exc)
    ax.scatter(vals1, vals2, freq_inh)
    ax.scatter(vals1, vals2, avg_rate)

    ax.legend(['Excitatory neurons', 'Inhibitory neurons', 'Population average'])
    # Remove both plotted parameters from the description shown in the title.
    description = get_description_from_str_without(var1_str, parameters[0][1])
    description = get_description_from_str_without(var2_str, description)
    ax.set_title('effect of $'+var1_str+'$ and $'+var2_str+'$ on average network firing rate,\n' + description)