# In [None]:
import numpy as np
import scipy.special
    
def _is_empty(value):
    """Return True when *value* is None or has no elements (MATLAB ``isempty``)."""
    return value is None or np.size(value) == 0


def _decay(t):
    """Exponential kernel exp(-t); the MATLAB source named this ``gamma``."""
    return np.exp(-np.asarray(t, dtype=float))


def _first_column(hm):
    """Return history marks as a 1-D int array (first column if 2-D)."""
    hm = np.asarray(hm, dtype=int)
    return hm[:, 0] if hm.ndim > 1 else hm.ravel()


def _wrand(w):
    """Draw one 1-based index with probability proportional to weights *w*."""
    w = np.ravel(w)
    edges = np.concatenate(([0.0], np.cumsum(w)))
    return int(np.sum(edges < np.random.rand() * np.sum(w)))


def genCorrelatedSpikeTrains2(spike_opts=None, varargin=None):
    """Generate a set of correlated spike trains from a Hawkes process.

    Python port of the MATLAB ``genCorrelatedSpikeTrains2`` (2017 - Adam
    Charles). Spikes come either from a discrete-time approximation of a
    Hawkes process (default) or from a continuous-time marked point process
    (``markpointproc``). Neurons are simulated in independent batches, each
    containing a share of the soma neurons and of the background neurons.

    Parameters
    ----------
    spike_opts : object
        Options object; passed through ``check_spike_opts`` and then read via
        the attributes ``dt``, ``nt``, ``K`` (number of soma neurons), ``N_bg``
        (number of background neurons), ``rate`` (scale of the gamma-drawn
        base rates), ``burst_mean`` (forwarded to ``sampSmallWorldMat``) and,
        optionally, ``selfact`` (self-excitation gain, default 1.26).
    varargin : sequence, optional
        Optional arguments in MATLAB ``varargin`` order:
        ``[batch_sz, n_locs, discrete_flag, verbose]``. Missing or empty
        entries fall back to ``K``, ``[]``, ``True`` and ``False``.

    Returns
    -------
    S : types.SimpleNamespace
        ``S.soma`` and ``S.bg`` hold the ``binSpikeTrains`` output for the soma
        and the background populations.
    varargout : list
        ``[net_params, ev]``: ``net_params.A`` / ``net_params.mu`` are the
        per-batch connectivity matrices and base rates; ``ev.evt`` / ``ev.evm``
        are the per-batch event times and 1-based event marks.

    Raises
    ------
    ValueError
        If ``K`` or the batch size is not positive.
    """
    from types import SimpleNamespace

    varargin = [] if varargin is None else list(varargin)

    # ---- Input parsing ---------------------------------------------------
    spike_opts = check_spike_opts(spike_opts)
    dt = spike_opts.dt
    nt = int(spike_opts.nt)
    tmax = dt * nt
    N_node = int(spike_opts.K)
    N_bg = int(spike_opts.N_bg)

    def _opt(idx, default):
        # MATLAB varargin{idx+1}; fall back when absent or empty.
        if len(varargin) > idx and not _is_empty(varargin[idx]):
            return varargin[idx]
        return default

    batch_sz = int(_opt(0, N_node))
    n_locs = varargin[1] if len(varargin) > 1 else []
    discrete_flag = _opt(2, True)
    verbose = _opt(3, False)
    if N_node < 1 or batch_sz < 1:
        raise ValueError('spike_opts.K and the batch size must be positive')

    selfAct = getattr(spike_opts, 'selfact', None)
    if _is_empty(selfAct):
        selfAct = 1.26

    batch_num = int(np.ceil(N_node / batch_sz))
    batch_bg = int(np.ceil(N_bg / batch_num))

    # ---- Sample connectivity for the Hawkes process ----------------------
    ascale = 4
    bscale = 2
    if verbose:
        print('Generating small world connectivity...', end='')

    n_soma = []  # number of soma neurons in each batch
    MU, A, B = [], [], []
    for kk in range(batch_num):
        N_now = min(batch_sz, N_node - kk * batch_sz)
        # Clamp: the last batches can run out of background neurons.
        N_nowB = max(0, min(batch_bg, N_bg - kk * batch_bg))
        n_tot = N_now + N_nowB
        # Both simulation modes use the same small-world connectivity.
        A_k = sampSmallWorldMat(np.array([N_now, N_nowB]), 10, 0.3, 0.9,
                                spike_opts.burst_mean, n_locs)
        # Normalize so the mean column sum (MATLAB mean(sum(A))) is ascale.
        A_k = ascale * A_k / np.mean(np.sum(A_k, axis=0))
        n_soma.append(N_now)
        A.append(A_k)
        # MATLAB gamrnd(a, b) and np.random.gamma(a, b) are both shape/scale.
        MU.append(np.random.gamma(1.0, spike_opts.rate, size=(n_tot, 1)))
        B.append(np.random.gamma(3.0, bscale, size=(n_tot, 1)))
        if verbose:
            print('.', end='')
    if verbose:
        print('done.')

    # Self-excitation on the diagonal of the first batch. The MATLAB source
    # used B{kk} (the *last* batch) here, whose length differs from A{1}'s
    # whenever batch sizes differ; B[0] is the matching vector.
    # NOTE(review): assumes sampSmallWorldMat returns a dense square array,
    # and only batch 0 gets this override, as in the source -- confirm intent.
    np.fill_diagonal(A[0], selfAct * np.ravel(B[0]))

    # ---- Set up and run the point process --------------------------------
    gammalen = 10
    alpha = 3

    def rectfun(z):
        # Softplus log(1 + exp(alpha*z)), computed without overflow.
        return np.logaddexp(0.0, alpha * z)

    def genCIF(summ, suma):
        # Total intensity: base rate plus decaying contributions of history.
        suma = np.ravel(suma)

        def cif(tcurr, ht, hm):
            marks = _first_column(hm)
            return summ + np.dot(suma[marks - 1], _decay(tcurr - np.ravel(ht)))
        return cif

    def genMKF(m, a):
        # Mark sampler: choose the firing neuron proportional to its intensity.
        m = np.ravel(m)

        def mkf(t, ht, hm):
            marks = _first_column(hm)
            return _wrand(m + a[:, marks - 1] @ _decay(t - np.ravel(ht)))
        return mkf

    evt, evm = [], []
    if verbose:
        print('Simulating Hawkes process...', end='')
    for kk in range(batch_num):
        mu_k = MU[kk]
        if discrete_flag:
            # Per-neuron decay rates, drawn per batch so they match mu_k's
            # length (the source drew one soma-sized vector for all batches).
            extSc = np.maximum(0.3, 1 + 0.3 * np.random.randn(mu_k.shape[0], 1))
            inbSc = extSc / 2
            ext_decay = _decay(extSc * dt)  # loop-invariant decay factors
            inb_decay = _decay(inbSc * dt)
            zt = np.zeros(mu_k.shape)       # trace driven through A
            yt = np.zeros(mu_k.shape) + 5   # trace driven through B, starts at 5
            times, marks = [], []
            for tt in range(nt):
                # Which neurons fire in bin tt.
                p_fire = 1 - np.exp(-rectfun(zt - yt + 1) * mu_k * dt)
                xt = np.random.rand(*mu_k.shape) < p_fire
                zt = ext_decay * zt + A[kk] @ xt
                yt = inb_decay * yt + B[kk] * xt
                # Spread spike times uniformly within the bin; marks 1-based.
                times.append((tt + np.random.rand(int(np.count_nonzero(xt)))) * dt)
                marks.append(np.flatnonzero(xt) + 1)
            evt.append(np.concatenate(times) if times else np.zeros(0))
            evm.append(np.concatenate(marks) if marks else np.zeros(0, dtype=int))
        else:
            CIF = genCIF(np.sum(mu_k), np.sum(A[kk], axis=0))
            MKF = genMKF(mu_k, A[kk])
            evt_k, evm_k = markpointproc(CIF, [], MKF, tmax, np.inf, 1, gammalen)
            evt.append(np.asarray(evt_k).ravel())
            evm.append(np.asarray(evm_k).ravel().astype(int))
        if verbose:
            print('.', end='')
    if verbose:
        print('done.')

    # ---- Concatenate all events ------------------------------------------
    if verbose:
        print('Concatenating events...', end='')
    soma_t, soma_m, bg_t, bg_m = [], [], [], []
    for kk in range(batch_num):
        N_now = n_soma[kk]
        t_k = np.asarray(evt[kk]).ravel()
        m_k = np.asarray(evm[kk]).ravel()
        is_soma = m_k <= N_now
        soma_t.append(t_k[is_soma])
        # Global soma index: batches hold batch_sz soma neurons each (the
        # source offset by batch_num, mislabelling multi-batch runs).
        soma_m.append(kk * batch_sz + m_k[is_soma])
        bg_t.append(t_k[~is_soma])
        bg_m.append(kk * batch_bg + m_k[~is_soma] - N_now)
        if verbose:
            print('.', end='')
    evt_cell = np.concatenate(soma_t)
    evm_cell = np.concatenate(soma_m)
    evt_bg = np.concatenate(bg_t)
    evm_bg = np.concatenate(bg_m)
    if verbose:
        print('done')

    # ---- Partition events into bins --------------------------------------
    if verbose:
        print('Turning events to bin counts...', end='')
    S = SimpleNamespace(
        soma=binSpikeTrains(evt_cell, evm_cell, N_node, dt, nt),
        bg=binSpikeTrains(evt_bg, evm_bg, N_bg, dt, nt),
    )
    if verbose:
        print('done.')

    # ---- Output parsing --------------------------------------------------
    net_params = SimpleNamespace(A=A, mu=MU)
    ev = SimpleNamespace(evt=evt, evm=evm)
    return S, [net_params, ev]