# Comparing the SSN with a nonlinear Hawkes network: steady-state rates at a fixed input h, then for varying h

In [None]:
# Installation
using Revise
using LinearAlgebra,Statistics,StatsBase,Distributions
using Plots,NamedColors ; theme(:default)
using FFTW
using ProgressMeter
using Random
Random.seed!(0)
using HawkesSimulator; const global H = HawkesSimulator

In [None]:
# Initializing variables
# one E, one I population
v_rest = -70                # resting potential (mV); rate_powerlaw is zero at or below it
k = 0.3                     # gain of the power-law transfer function k * [v - v_rest]_+^n
n = 2                       # exponent of the power-law transfer function
tau = [20.0, 10.0]*1E-3     # membrane time constants (s), [E, I]
# SSN weights: row = target population (E, I), column = source population (E, I)
w = [1.25 -0.65
     1.2 -0.5]
w_hawkes = w .* (k^(1/n))   # the same weights rescaled by k^(1/n) for the Hawkes network
time_step = 0.1*1E-2        # Euler integration step (s)
tau_noise = 50*1E-3         # time constant (s) of the Ornstein-Uhlenbeck input noise (used by dnoise!)

In [None]:
"""
    rate_powerlaw(v)

Supralinear (rectified power-law) transfer function of the SSN:
`k * max(v - v_rest, 0)^n`, using the globals `k`, `n` and `v_rest`.
Returns zero for `v <= v_rest`.
"""
function rate_powerlaw(v::Real)
    excess = v - v_rest
    # `zero(excess)` (rather than the literal `0`) keeps the rectified value the same
    # type as `excess`, so the function is type-stable.
    return k * max(excess, zero(excess))^n
end

"""
    plot_powerlaw(vmin = -10, vmax = 79)

Plot `rate_powerlaw` for voltages from `vmin` to `vmax` (mV) in 1 mV steps.
The legend spells out the transfer function actually used, including the `v_rest` offset.
"""
function plot_powerlaw(vmin::Real = -10, vmax::Real = 79)
    v_arr = collect(float(vmin):1.0:float(vmax))
    rates = rate_powerlaw.(v_arr)
    return plot(v_arr, rates, xlabel="Voltage (mV)", ylabel="Rate (Hz)",
                label="rate = $(k)*([V - V_rest]+)^$(n)", legend=:bottomright, fmt=:png)
end

plot_powerlaw()

In [None]:
# Counting process N(t) of one spike train: the j-th spike time is drawn at height j.
function plot_count(points)
    counts = collect(axes(points, 1))
    fig = plot(xlabel="time (s)", ylabel="count")
    return plot!(fig, points, counts, label = "N(t)")
end

# Counting processes N_E(t) and N_I(t) of two spike trains, overlaid on the same axes.
function plot_count(points_E, points_I)
    fig = plot(xlabel="time (s)", ylabel="count")
    plot!(fig, points_E, collect(axes(points_E, 1)), label = "N_E(t)")
    return plot!(fig, points_I, collect(axes(points_I, 1)), label = "N_I(t)")
end

# Empirical counting processes for the first `count` spikes of each train, together with
# the expected counts E[N(t)] = rate * t for constant rates `remean` (E) and `rimean` (I).
function plot_count(points_E, points_I, remean, rimean, count)
    t_E = points_E[1:count]
    t_I = points_I[1:count]
    spike_number = collect(1:count)
    expected_E = Float64[remean * t for t in t_E]
    expected_I = Float64[rimean * t for t in t_I]
    fig = plot(xlabel="time (s)", ylabel="Spike Count", legend=:bottomright, fmt=:png)
    plot!(fig, t_E, spike_number, label = "N_E(t)", color="blue")
    plot!(fig, t_I, spike_number, label = "N_I(t)", color="red")
    plot!(fig, t_E, expected_E, label="E[N_E(t)]", color="dark blue")
    return plot!(fig, t_I, expected_I, label="E[N_I(t)]", color="dark red")
end

In [None]:
"""
    onedmat(x)

Wrap the scalar `x` in a 1×1 matrix. It is used below to pass scalar weights to `H.InputNetwork`.
"""
onedmat(x::Real) = fill(x, 1, 1)

In [None]:
# Euler Method
"""
    dv_ssn(v, h_i, noise_i, tau_i, w_arr, rate_arr)

Forward-Euler voltage increment of one SSN population over one `time_step`:

    Δv = (v_rest - v + h_i + noise_i + Σⱼ w_arr[j] * rate_arr[j]) * time_step / tau_i

`w_arr` holds the weights onto this population from every population, and `rate_arr`
holds the matching presynaptic rates (equal lengths, else `DimensionMismatch`).
Uses the globals `v_rest` and `time_step`.
"""
function dv_ssn(v::Real, h_i::Real, noise_i::Real, tau_i::Real,
                w_arr::AbstractVector{<:Real}, rate_arr::AbstractVector{<:Real})
    # The input-noise sample enters the drive additively, exactly like the external input h_i.
    drive = v_rest - v + h_i + noise_i + dot(w_arr, rate_arr)
    return drive * time_step / tau_i
end

In [None]:
"""
    simulate_ssn!(h, num_steps, t_arr, v_excite, v_inhibit, rate)

Noise-free forward-Euler integration of the two-population (E, I) SSN with constant
input `h` for `num_steps` steps. Both voltages start at `v_rest`. Spikes are ignored;
only rates are compared.

Fills in place:
- `v_excite`, `v_inhibit`: the voltages.
- `rate[:, 1:2]`: the E and I rates, where `rate[i, :]` is the rate at `v[i]`.
- `t_arr[2:end]`: times accumulated from `t_arr[1]`.
"""
function simulate_ssn!(h, num_steps, t_arr, v_excite, v_inhibit, rate) # ignoring spikes, only comparing rates
    w_E = [w[1, 1], w[1, 2]]   # weights onto E from (E, I)
    w_I = [w[2, 1], w[2, 2]]   # weights onto I from (E, I)
    no_noise = 0.0             # this simulation is deterministic

    v_excite[1] = v_rest
    v_inhibit[1] = v_rest
    rate[1, 1] = rate_powerlaw(v_excite[1])
    rate[1, 2] = rate_powerlaw(v_inhibit[1])

    for i in 2:num_steps
        r_prev = rate[i-1, :]   # copy (a Vector) of the rates driving this step
        v_excite[i] = v_excite[i-1] + dv_ssn(v_excite[i-1], h, no_noise, tau[1], w_E, r_prev)
        v_inhibit[i] = v_inhibit[i-1] + dv_ssn(v_inhibit[i-1], h, no_noise, tau[2], w_I, r_prev)
        # Keep rate[i, :] aligned with v[i] (same convention as simulate_ssn_with_spikes!).
        rate[i, 1] = rate_powerlaw(v_excite[i])
        rate[i, 2] = rate_powerlaw(v_inhibit[i])
        t_arr[i] = t_arr[i-1] + time_step
    end
    return nothing
end

In [None]:
# Simulation

"""
    simulate_hawkes!(networks, num_spikes; clear_every = 1_000)

Reset every network, then advance them together through `num_spikes` calls to
`H.dynamics_step!`, starting from time 0. Every `clear_every` steps, `H.clear_trains!`
is called on each network's `postpops`.

Returns the final simulation time.

NOTE(review): `clear_trains!` presumably bounds memory by moving the current trains into
`trains_history`, which callers read afterwards. Confirm in HawkesSimulator.
"""
function simulate_hawkes!(networks, num_spikes; clear_every::Integer = 1_000)
    clear_every > 0 || throw(ArgumentError("clear_every must be positive, got $clear_every"))
    t_now = 0.0
    H.reset!.(networks) # clear spike trains etc
    for spike in 1:num_spikes
        t_now = H.dynamics_step!(t_now, networks)
        if spike % clear_every == 0
            for net in networks   # works for any number of networks, in order
                H.clear_trains!(net.postpops)
            end
        end
    end
    return t_now
end

In [None]:
"""
    single_h_ssn_simulation(; h = 5.0, t_max = 1)

Run the noise-free SSN (`simulate_ssn!`) with constant input `h` (mV) for `t_max` seconds.

Returns `(t_arr, v_excite, v_inhibit, rate)`, where `rate` is a `num_steps × 2` matrix
with columns (E, I).
"""
function single_h_ssn_simulation(; h = 5.0, t_max = 1)
    num_steps = Int(ceil(t_max / time_step))
    t_arr = zeros(num_steps)
    v_excite = zeros(num_steps)
    v_inhibit = zeros(num_steps)
    rate = zeros(num_steps, 2)
    # float(): dv_ssn may be restricted to Float64 inputs
    simulate_ssn!(float(h), num_steps, t_arr, v_excite, v_inhibit, rate)
    return t_arr, v_excite, v_inhibit, rate
end

In [None]:
"""
    single_h_nonlinear_hawkes_simulation(; h = 5.0, n_spikes = 500_000, burn_in = 2000)

Simulate the two-population (E, I) nonlinear Hawkes network and estimate its rates.

- Nonlinearity: `H.NLRectifiedQuadratic`, with exponential kernels tau_E = 6.0 and tau_I = 2.0.
- Baseline: `k^(1/n) * h` for both populations.
- Weights: `w_hawkes`.
- Length: `n_spikes` simulation steps.

Rates come from `H.numerical_rate` on each population's `trains_history[1]`, discarding
the first `burn_in - 1` spikes as transient.

Returns `(rate, train_E, train_I, network_E, network_I)` with `rate = [rate_E, rate_I]`.
"""
function single_h_nonlinear_hawkes_simulation(; h = 5.0, n_spikes = 500_000, burn_in = 2000)
    tau_E = 6.0
    tau_I = 2.0
    pop_E = H.PopulationExp(tau_E, H.NLRectifiedQuadratic())
    pop_I = H.PopulationExp(tau_I, H.NLRectifiedQuadratic())
    baseline = (k^(1/n)) * h
    # separate baseline vectors per population, as each PopulationState owns its own
    popstate_E = H.PopulationState(pop_E, [baseline])
    popstate_I = H.PopulationState(pop_I, [baseline])
    network_E = H.InputNetwork(popstate_E, [popstate_E, popstate_I],
                               [onedmat(w_hawkes[1, 1]), onedmat(w_hawkes[1, 2])])
    network_I = H.InputNetwork(popstate_I, [popstate_E, popstate_I],
                               [onedmat(w_hawkes[2, 1]), onedmat(w_hawkes[2, 2])])
    simulate_hawkes!([network_E, network_I], n_spikes)
    train_E = popstate_E.trains_history[1]
    train_I = popstate_I.trains_history[1]
    rate = [H.numerical_rate(train_E[burn_in:end]), H.numerical_rate(train_I[burn_in:end])]
    return rate, train_E, train_I, network_E, network_I
end

In [None]:
# Noise-free SSN run with single_h_ssn_simulation's built-in h = 5.
# Ending the cell with println() (which returns nothing) keeps the notebook from displaying the arrays.
t_arr, v_excite, v_inhibit, rate = single_h_ssn_simulation()
println()

In [None]:
# Nonlinear Hawkes run with the same h = 5 (set inside the function); returns rates, spike trains and networks.
# Ending the cell with println() (which returns nothing) keeps the notebook from displaying the results.
rate_hawkes,train_hawkes_E, train_hawkes_I, network_E, network_I = single_h_nonlinear_hawkes_simulation()
println()

In [None]:
# Final SSN voltages [E, I] (mV), final SSN rates [E, I] (Hz), then the Hawkes rate estimates [E, I]
println([v_excite[end], v_inhibit[end]])
println([rate[end,1], rate[end,2]])
println(rate_hawkes)
# expected: the SSN steady-state rates should match the mean spiking rates of the nonlinear Hawkes network

In [None]:
# plot for Rate vs time
# SSN rates over time, with the (time-independent) Hawkes rate estimates as dashed reference lines.
# `nsteps` replaces a global named `length`, which would shadow `Base.length` for later cells.
nsteps = size(t_arr, 1)
plt = plot(xlabel="time (s)", ylabel="Rate (Hz)", fmt = :png, legend=:bottomright)
plot!(plt, t_arr, rate[:, 1], label = "R_Linear_Excite", color="blue")
plot!(plt, t_arr, rate[:, 2], label = "R_Linear_Inhibit", color="red")
r_excite_hawkes = fill(rate_hawkes[1], nsteps)
r_inhibit_hawkes = fill(rate_hawkes[2], nsteps)
plot!(plt, t_arr, r_excite_hawkes, label = "R_Hawkes_Excite", linestyle=:dash, linewidth=2, color="dark blue")
plot!(plt, t_arr, r_inhibit_hawkes, label = "R_Hawkes_Inhibit", linestyle=:dash, linewidth=2, color="dark red")

In [None]:
# Extending the simulation for multiple h values

In [None]:
"""
    run_ssn_multiple_h!(num_h, h_arr, rates_ssn)

For each input strength `h_arr[j]` (`j = 1:num_h`), run the noise-free SSN for 1 s and
store the final (E, I) rates in row `j` of `rates_ssn`. Work buffers are allocated once
and reused across h values.
"""
function run_ssn_multiple_h!(num_h, h_arr, rates_ssn)
    t_max = 1
    num_steps = Int(ceil(t_max/time_step))
    t_arr, v_excite, v_inhibit = zeros(num_steps), zeros(num_steps), zeros(num_steps)
    rate = zeros(num_steps, 2)
    @showprogress 1.0 "Running SSN multiple h..." for j in 1:num_h
        simulate_ssn!(h_arr[j], num_steps, t_arr, v_excite, v_inhibit, rate)
        rates_ssn[j, :] .= rate[end, :]
    end
end

In [None]:
"""
    run_hawkes_multiple_h!(num_h, h_arr, rates_hawkes, pop_E, pop_I)

For each `h_arr[j]`, build a fresh E/I nonlinear Hawkes network from the population
models `pop_E`/`pop_I`, with baselines `k^(1/n) * h` and weights `w_hawkes`. Simulate
500_000 steps and store the (E, I) rate estimates in row `j` of `rates_hawkes`. The
estimates use spikes from index 2000 on.
"""
function run_hawkes_multiple_h!(num_h, h_arr, rates_hawkes, pop_E, pop_I)
    n_spikes = 500_000
    @showprogress 1.0 "Running Hawkes multiple h..." for j in 1:num_h
        h_scaled = (k^(1/n))*(h_arr[j])
        state_E = H.PopulationState(pop_E, [h_scaled])
        state_I = H.PopulationState(pop_I, [h_scaled])
        net_E = H.InputNetwork(state_E, [state_E, state_I], [onedmat(w_hawkes[1,1]), onedmat(w_hawkes[1,2])])
        net_I = H.InputNetwork(state_I, [state_E, state_I], [onedmat(w_hawkes[2,1]), onedmat(w_hawkes[2,2])])
        simulate_hawkes!([net_E, net_I], n_spikes)
        train_E = state_E.trains_history[1]
        train_I = state_I.trains_history[1]
        rates_hawkes[j, :] = [H.numerical_rate(train_E[2000:end]), H.numerical_rate(train_I[2000:end])]
    end
end

In [None]:
# Sweep of input strengths h = h_min, h_min + h_step, ..., h_max (mV)
h_max = 70.0
h_step = 1      # named h_step (not `step`) so the global does not shadow Base.step
h_min = 1.0
num_h = floor(Int, (h_max - h_min) / h_step) + 1   # number of grid points in [h_min, h_max]
h_arr = zeros(num_h)
h_arr[1] = h_min
rates_ssn = zeros(num_h,2)
rates_hawkes_nonlinear = zeros(num_h,2)

"""
    compare()

Fill the global `h_arr` with the h grid (starting from `h_arr[1]`, increments of `h_step`),
then run the Hawkes sweep (rectified-quadratic, tau_E = 6.0, tau_I = 2.0) and the SSN sweep.
The results go into the globals `rates_hawkes_nonlinear` and `rates_ssn`.
"""
function compare()
    for i in 2:num_h
        h_arr[i] = h_arr[i-1] + h_step
    end
    tau_E = 6.0
    tau_I = 2.0
    pop_E = H.PopulationExp(tau_E, H.NLRectifiedQuadratic())
    pop_I = H.PopulationExp(tau_I, H.NLRectifiedQuadratic())
    run_hawkes_multiple_h!(num_h, h_arr, rates_hawkes_nonlinear, pop_E, pop_I)
    run_ssn_multiple_h!(num_h, h_arr, rates_ssn)
end

In [None]:
# Run both sweeps; fills the globals h_arr, rates_hawkes_nonlinear and rates_ssn.
compare()

Plots of steady-state rate vs. input h

In [None]:
# Steady-state excitatory rate vs. input h: SSN vs. nonlinear Hawkes estimate.
# Whole columns are plotted directly; no global `length` (which would shadow Base.length).
plt = plot(xlabel="h (mV)", ylabel="Rate (Hz)", fmt = :png, legend=:bottomright)
plot!(plt, h_arr, rates_ssn[:, 1], label = "R_SSN_Excite", color="blue")
plot!(plt, h_arr, rates_hawkes_nonlinear[:, 1], label = "R_Hawkes_Excite", color="dark blue")

In [None]:
# Steady-state inhibitory rate vs. input h: SSN vs. nonlinear Hawkes estimate.
# Whole columns are plotted directly; no global `length` (which would shadow Base.length).

plt = plot(xlabel="h (mV)", ylabel="Rate (Hz)", fmt = :png, legend=:bottomright)
plot!(plt, h_arr, rates_ssn[:, 2], label = "R_SSN_Inhibit", color="red")
plot!(plt, h_arr, rates_hawkes_nonlinear[:, 2], label = "R_Hawkes_Inhibit", color="dark red")

In [None]:
# Absolute difference |Hawkes - SSN| of the steady-state rates at every h (including the first).
y_e = abs.(rates_hawkes_nonlinear[:, 1] .- rates_ssn[:, 1])
y_i = abs.(rates_hawkes_nonlinear[:, 2] .- rates_ssn[:, 2])
plt = plot(xlabel = "h (mV)", ylabel="absolute error", fmt = :png, legend=:bottomright)
plot!(plt, h_arr, y_e, label="rateE", color="blue")
plot!(plt, h_arr, y_i, label="rateI", color="red")

# Hawkes rates above were estimated from n_spikes = 500_000 events per h

In [None]:
# Symmetric relative difference 2|a - b| / (a + b) between the Hawkes and SSN rates.
# E and I use the same definition. The value is 0 where both rates are 0, avoiding 0/0 = NaN.
sym_relerr(a, b) = iszero(a + b) ? zero(float(a + b)) : 2 * abs(a - b) / (a + b)
y_e = sym_relerr.(rates_hawkes_nonlinear[:, 1], rates_ssn[:, 1])
y_i = sym_relerr.(rates_hawkes_nonlinear[:, 2], rates_ssn[:, 2])
plt = plot(xlabel = "h (mV)", ylabel="Relative error", fmt = :png, legend=:topleft)
plot!(plt, h_arr, y_e, label="rateE", color="blue")
plot!(plt, h_arr, y_i, label="rateI", color="red")

# Hawkes rates above were estimated from n_spikes = 500_000 events per h

Modifying the code above to also extract spike trains

In [None]:
"""
    spiking(probability)

Bernoulli draw: return `true` with probability `probability`.
It is always `true` for `probability >= 1` and always `false` for `probability <= 0`.
"""
spiking(probability) = rand() < probability

In [None]:
# modelling noise as a Multivariate Ornstein-Uhlenbeck process
"""
    dnoise!(noise, wiener, cov_noise)

Return the Euler–Maruyama increment over one `time_step` of the Ornstein–Uhlenbeck process

    tau_noise * dη = -η dt + sqrt(2 * tau_noise * Σ) dW,     Σ = cov_noise,

i.e. `(-η * time_step + sqrt(2 * tau_noise * Σ) * dW) / tau_noise`. Its stationary
covariance is Σ.

`wiener` is overwritten with the Wiener increment dW ~ N(0, time_step * I) drawn for this
step; this is the mutation the `!` refers to. `noise` is only read, so the caller adds the
returned increment itself.
"""
function dnoise!(noise::AbstractVector{<:Real}, wiener::AbstractVector{<:Real},
                 cov_noise::AbstractMatrix{<:Real})
    # A fresh Wiener increment: standard normal draws scaled by sqrt(dt). A narrower
    # distribution would shrink the stationary variance below Σ.
    wiener .= sqrt(time_step) .* randn(size(noise, 1))
    dn = -noise .* time_step .+ sqrt(2 * tau_noise * cov_noise) * wiener
    return dn ./ tau_noise
end

In [None]:
"""
    simulate_ssn_with_spikes!(h, num_steps, t_arr, v_excite, v_inhibit, rate,
                              spike_trainE, spike_trainI, cov_noise)

Forward-Euler simulation of the E/I SSN with constant input `h`, plus input noise of
covariance `cov_noise` advanced with `dnoise!`. Voltages start at `v_rest`.

Each step:
1. Update the voltages and the rates (`rate[j, :]` is the rate at `v[j]`).
2. Let each population spike at `t_arr[j]` with probability `rate * time_step`.
3. Push the spike times onto `spike_trainE` / `spike_trainI`.

Fills `t_arr`, `v_excite`, `v_inhibit` and `rate` (rows = steps, columns = E, I) in place.
"""
function simulate_ssn_with_spikes!(h, num_steps, t_arr, v_excite, v_inhibit, rate, spike_trainE, spike_trainI, cov_noise)
    v_excite[1] = v_rest
    v_inhibit[1] = v_rest
    rate[1,:] = [rate_powerlaw(v_excite[1]), rate_powerlaw(v_inhibit[1])]

    noise = zeros(2)    # current noise value, [E, I]
    wiener = zeros(2)   # buffer handed to dnoise!
    w_onto_E = [w[1,1], w[1,2]]
    w_onto_I = [w[2,1], w[2,2]]
    @showprogress 1.0 "Running SSN..." for j in 2:num_steps
        t_now = t_arr[j-1] + time_step
        t_arr[j] = t_now
        noise .+= dnoise!(noise, wiener, cov_noise)

        r_prev = rate[j-1, :]
        v_excite[j] = v_excite[j-1] + dv_ssn(v_excite[j-1], h, noise[1], tau[1], w_onto_E, r_prev)
        v_inhibit[j] = v_inhibit[j-1] + dv_ssn(v_inhibit[j-1], h, noise[2], tau[2], w_onto_I, r_prev)

        r_E = rate_powerlaw(v_excite[j])
        r_I = rate_powerlaw(v_inhibit[j])
        rate[j, 1] = r_E
        rate[j, 2] = r_I

        # E is drawn before I, so the random-number stream is consumed in a fixed order
        spiking(r_E * time_step) && push!(spike_trainE, t_now)
        spiking(r_I * time_step) && push!(spike_trainI, t_now)
    end
end

In [None]:
"""
    get_ssn_spikes_multiple_h!(num_h, h_arr, rates_ssn, spikes_E, spikes_I)

For each `h_arr[j]`, run the noisy spiking SSN for 1 s with diagonal noise covariance
`k * h` on both populations. Then:
- store the final (E, I) rates in row `j` of `rates_ssn`;
- append that run's E and I spike-time vectors to `spikes_E` and `spikes_I`.
"""
function get_ssn_spikes_multiple_h!(num_h, h_arr, rates_ssn, spikes_E, spikes_I)
    t_max = 1
    num_steps = Int(ceil(t_max/time_step))
    t_arr = zeros(num_steps)
    v_excite = zeros(num_steps)
    v_inhibit = zeros(num_steps)
    rate = zeros(num_steps, 2)
    @showprogress 1.0 "Running SSN multiple h..." for j in 1:num_h
        h = h_arr[j]
        noise_variance = k*h
        cov_noise = [noise_variance 0.0
                     0.0 noise_variance]
        train_E = Float64[]
        train_I = Float64[]
        simulate_ssn_with_spikes!(h, num_steps, t_arr, v_excite, v_inhibit, rate, train_E, train_I, cov_noise)
        rates_ssn[j, :] .= rate[end, :]
        push!(spikes_E, train_E)
        push!(spikes_I, train_I)
    end
end

In [None]:
"""
    get_hawkes_spikes_multiple_h!(num_h, h_arr, rates_hawkes, pop_E, pop_I, spikes_E, spikes_I)

For each `h_arr[j]`, build and simulate a fresh E/I nonlinear Hawkes network for 500_000
steps (baselines `k^(1/n) * h`, weights `w_hawkes`). Then:
- append copies of the E and I spike trains (`trains_history[1]`) to `spikes_E` and `spikes_I`;
- store the rate estimates, from spike index 2000 on, in row `j` of `rates_hawkes`.
"""
function get_hawkes_spikes_multiple_h!(num_h, h_arr, rates_hawkes, pop_E, pop_I, spikes_E, spikes_I)
    n_spikes = 500_000
    @showprogress 1.0 "Running Hawkes multiple h..." for j in 1:num_h
        h_scaled = (k^(1/n))*(h_arr[j])
        state_E = H.PopulationState(pop_E, [h_scaled])
        state_I = H.PopulationState(pop_I, [h_scaled])
        net_E = H.InputNetwork(state_E, [state_E, state_I], [onedmat(w_hawkes[1,1]), onedmat(w_hawkes[1,2])])
        net_I = H.InputNetwork(state_I, [state_E, state_I], [onedmat(w_hawkes[2,1]), onedmat(w_hawkes[2,2])])
        simulate_hawkes!([net_E, net_I], n_spikes)
        train_E = state_E.trains_history[1]
        train_I = state_I.trains_history[1]
        push!(spikes_E, train_E[:])
        push!(spikes_I, train_I[:])
        rates_hawkes[j, :] = [H.numerical_rate(train_E[2000:end]), H.numerical_rate(train_I[2000:end])]
    end
end

In [None]:
# Two input strengths (mV) for which spike trains are collected and compared
num_h = 2
h_arr = [9.0, 28.0]

rates_ssn = zeros(num_h, 2)
rates_hawkes_nonlinear = zeros(num_h, 2)
# One spike-time vector per h value. compare2() appends to these, so re-run this cell to reset them.
spikes_hawkesE = Vector{Float64}[]
spikes_hawkesI = Vector{Float64}[]
spikes_ssnE = Vector{Float64}[]
spikes_ssnI = Vector{Float64}[]

"""
    compare2()

Collect Hawkes spike trains and rates (tau_E = 4.0, tau_I = 2.0, rectified-quadratic),
then noisy-SSN spike trains and rates, for every h in the global `h_arr`. Results are
written into the global rate matrices and spike vectors above.
"""
function compare2()
    pop_E = H.PopulationExp(4.0, H.NLRectifiedQuadratic())   # tau_E = 4.0
    pop_I = H.PopulationExp(2.0, H.NLRectifiedQuadratic())   # tau_I = 2.0
    get_hawkes_spikes_multiple_h!(num_h, h_arr, rates_hawkes_nonlinear, pop_E, pop_I, spikes_hawkesE, spikes_hawkesI)
    get_ssn_spikes_multiple_h!(num_h, h_arr, rates_ssn, spikes_ssnE, spikes_ssnI)
end

In [None]:
compare2()

In [None]:
# Spike counts [E, I] for h = h_arr[1]: Hawkes first, then SSN.
# size(v, 1) counts elements (sizeof would report bytes) and is unaffected by any global `length`.
println([size(spikes_hawkesE[1], 1), size(spikes_hawkesI[1], 1)])
println([size(spikes_ssnE[1], 1), size(spikes_ssnI[1], 1)])

In [None]:
# Spike counts [E, I] for h = h_arr[2]: Hawkes first, then SSN.
# size(v, 1) counts elements (sizeof would report bytes) and is unaffected by any global `length`.
println([size(spikes_hawkesE[2], 1), size(spikes_hawkesI[2], 1)])
println([size(spikes_ssnE[2], 1), size(spikes_ssnI[2], 1)])

In [None]:
# Hawkes rate estimates, one per line, in the order (h1: E, I), (h2: E, I)
for row in 1:2, col in 1:2
    println(rates_hawkes_nonlinear[row, col])
end

In [None]:
"""
    rasterplot(spikes_E, spikes_I, tlims = (2000., 2005.); h = 9)

Raster plot of the E spikes (black, row 2) and I spikes (blue, row 1) whose times lie
strictly inside `tlims`. `h` (mV) is only used in the x-axis label.
"""
function rasterplot(spikes_E, spikes_I, tlims = (2000.,2005.); h = 9)
    in_window(t) = tlims[1] < t < tlims[2]
    trainE = filter(in_window, spikes_E)
    trainI = filter(in_window, spikes_I)
    plt = plot()
    scatter!(plt, trainE, fill(2, size(trainE, 1)), markersize=35, markercolor=:black, markershape=:vline, leg=false)
    scatter!(plt, trainI, fill(1, size(trainI, 1)), markersize=35, markercolor=:blue, markershape=:vline, leg=false)
    plot!(plt, ylims=(0,3), xlabel="time (s), h = $(h) mV", fmt=:png)
end
# spikes_hawkes*[1] were simulated with h = h_arr[1]
rasterplot(spikes_hawkesE[1], spikes_hawkesI[1]; h = h_arr[1])

In [None]:
"""
    rasterplot(spikes_E, spikes_I, tlims = (2000., 2005.); h = 28)

Raster plot of the E spikes (black, row 2) and I spikes (blue, row 1) whose times lie
strictly inside `tlims`. `h` (mV) is only used in the x-axis label.
This definition replaces any earlier `rasterplot` method with the same positional signature.
"""
function rasterplot(spikes_E, spikes_I, tlims = (2000.,2005.); h = 28)
    in_window(t) = tlims[1] < t < tlims[2]
    trainE = filter(in_window, spikes_E)
    trainI = filter(in_window, spikes_I)
    plt = plot()
    scatter!(plt, trainE, fill(2, size(trainE, 1)), markersize=35, markercolor=:black, markershape=:vline, leg=false)
    scatter!(plt, trainI, fill(1, size(trainI, 1)), markersize=35, markercolor=:blue, markershape=:vline, leg=false)
    plot!(plt, ylims=(0,3), xlabel="time (s), h = $(h) mV", fmt=:png)
end

# The h = 28 mV trains are the second entries (h_arr[2]); index 1 holds the h = 9 mV run.
rasterplot(spikes_hawkesE[2], spikes_hawkesI[2]; h = h_arr[2])