# SETUP

In [None]:
;;
#require "pkp"
open Owl
open Gp

In [None]:
module A = Pkp.Balanced_net (* module alias to simplify outputs *)

open A

In [None]:
(* Number of excitatory (or, separately, inhibitory) input synapses that each
   neuron receives; cf. lecture. *)
let k : int = 100

# 1. Poisson neurons → they will provide input to our network

☝ Create a population of `n_neurons` Poisson neurons firing at 5 Hz -- cf. `poisson` function.

In [None]:
(* Exercise: a neuron array of [n_neurons] Poisson neurons firing at 5 Hz,
   e.g. built with [Array.init] and [poisson], which takes the firing rate
   (presumably in Hz -- compare [poisson 10.] at the end of this notebook). *)
let input = ... 

☝ Now create a network (a simple record -- cf. `network` type), and simulate it for `duration = 3.0` s.

In [None]:
(* Exercise: a [network] record whose only population is [input] and which
   has no connections; the record has fields [neurons] (a list of neuron
   arrays) and [connections], as in the record built at the end of this
   notebook. *)
let net = ...

In [None]:
(* Run the input-only network for 3 s of simulated time. *)
let _ = simulate net ~duration:3.0

☝ Inspect the spiking behaviour of this input population by plotting a spike raster:

In [None]:
let _ =
  (* Spike raster of the whole Poisson input population. *)
  let input_raster = raster (Array.map spikes input) in
  Juplot.draw ~fmt:`svg ~size:(600, 300) (fun (module P : Plot) ->
      P.plot (A input_raster) ~style:"p pt 7 lc 8 ps 0.4"
        [ barebone
        ; borders [ `bottom ]
        ; xtics `auto
        ; set "offsets graph 0.1, 0, 0, 0"
        ])

☝ Check that the neurons fire on average `r_x` (here 5) spikes per second, as they should.

In [None]:
let _ =
  let firing_rates =
    ...
    (* fill this in; should be a `float array` of [n_neurons] firing rates;
       NB: you will need to simulate for much longer than 1 s,
       so you can accurately estimate each neuron's firing rate *)
  in
  (* The standard error must use the number of rates actually measured; this
     need not equal the size of the global [input] population (e.g. when a
     fresh, longer-simulated population is used to compute the rates). *)
  let n_rates = Array.length firing_rates in
  (* convert to owl array type, do some stats and plot *)
  let firing_rates = Mat.of_array firing_rates 1 (-1) in
  let mu = Mat.mean' firing_rates in
  let sem = Mat.std' firing_rates /. sqrt (float n_rates) in
  let figure (module P : Plot) =
    P.plot
      (A firing_rates)
      ~style:"p pt 7 lc 8 ps 0.5"
      [ barebone
      ; borders [ `left ]
      ; ylabel "firing rate"
      ; ytics `auto ~o:"out nomirror"
      ; title (Printf.sprintf "mean = %.3f ± %.3f" mu sem)
      ]
  in
  Juplot.draw ~fmt:`svg ~size:(600, 300) figure

# 2. A single LIF neuron receiving Poisson input

☝ Now, create a function of type `float → neuron * network`, which takes a weight parameter `w` and:
   1. creates a single LIF neuron (function `lif`);
   2. creates connections from `input` (the Poisson neurons above) to the LIF neuron, with weight `w/k` (function `all_to_all_connections`);
   3. returns the pair `(lif_neuron, network)`, of type `A.neuron * A.network`.

In [None]:
(* Exercise: return [(neuron, net)] where [neuron] is a fresh [lif ()] and
   [net] contains both [input] and [neuron], with [input] connected onto
   [neuron] through [all_to_all_connections] at weight [w /. float k]. *)
let make_simple_net w = 
  ... (* fill this in! *)

Here is a helper function to plot the activity (voltage timecourse + spikes) of the LIF neuron:

In [None]:
(* Draw the voltage trace and the spikes of neuron [x] over [duration]. *)
let plot_response ~duration x =
  (* gnuplot filled rectangle of width 0.2 (x-axis units) just below the axis;
     presumably a 0.2 s scale bar if time is in seconds -- confirm. *)
  let scale_bar =
    "object 1 rectangle from first 0, graph -0.1 rto first 0.2, graph -0.02 \
     fs solid 1.0 noborder fc rgb 'black' noclip"
  in
  let draw_into (module P : Plot) =
    let layers = [ plottable_voltage ~duration x; plottable_spikes x ] in
    P.plots layers
      [ barebone
      ; ytics `auto ~o:"out nomirror"
      ; margins [ `bottom 0.2 ]
      ; set scale_bar
      ]
  in
  Juplot.draw ~fmt:`svg ~size:(500, 200) draw_into

☝ Using the above function, together with your `make_simple_net` and `simulate ~duration:1.0` functions, explore the behaviour of this mini network. Start with `w=5.0` and increase it until you find that the LIF neuron's voltage goes above threshold.

In [None]:
(* Exercise: for a given [w], build [make_simple_net w], run
   [simulate ~duration:1.0] on the network and pass the LIF neuron to
   [plot_response ~duration:1.0]; start at w = 5.0 and increase it. *)
...

☝ How does the firing rate of the LIF neuron depend on `w`? Plot this dependence.

In [None]:
(* Exercise: sweep [w], estimate the LIF firing rate for each value
   (e.g. spike count divided by the simulated duration) and plot rate vs. w. *)
...

(* Suggested plot properties for a firing-rate-vs-w figure: only the bottom
   and left borders, outward ticks without mirroring, and labelled axes. *)
let props =
  let axis_ticks =
    [ xtics `auto ~o:"out nomirror"; ytics `auto ~o:"out nomirror" ]
  in
  let axis_labels = [ xlabel "weight w * n_{neurons}"; ylabel "firing rate" ] in
  (barebone :: borders [ `bottom; `left ] :: axis_ticks) @ axis_labels

# 3. Single neuron receiving balanced E and I inputs

Now we are going to simulate a single neuron, receiving:
1. excitatory input from `k` Poisson neurons (5 Hz), with weight $+w/\sqrt{K}$, and
2. inhibitory input from `k` Poisson neurons (5 Hz; another, independent set), with weight $-w/\sqrt{K}$,

where $K$ denotes the number of inputs of each type, i.e. `k`.

☝ Begin by writing a function of type `float → neuron * network` (similar to `make_simple_net` above) that sets up the whole network given the parameter `w`.

In [None]:
(* Exercise: like [make_simple_net], but the LIF neuron receives two
   independent groups of [k] Poisson neurons at 5 Hz, one with weight
   [w /. sqrt (float k)] and the other with weight [-. w /. sqrt (float k)]. *)
let make_simple_ei_net w = ...

☝ Now, repeat the analysis of the previous section:
1. Plot the voltage+spike response of your LIF neuron, and explore the effect of `w`
2. Plot the firing rate of the LIF neuron as a function of `w`.

You might want to reuse some of your previous code. 

In [None]:
(* Exercise: repeat the response plots and the rate-vs-w sweep of section 2,
   this time with [make_simple_ei_net]. *)
...

# 4. Full balanced network

We are now ready to simulate the full network.

To begin with, let's define a custom record type to hold all our weight parameters (this will come in handy later):

In [None]:
(* Network weights, named as in the lecture slides: field [ab] is the weight
   from population b onto population a, e.g. [ex] is from x to e and [ie] is
   from e to i. *)
type weights = {
  ex : float;  (* X -> E *)
  ix : float;  (* X -> I *)
  ee : float;  (* E -> E *)
  ei : float;  (* I -> E *)
  ie : float;  (* E -> I *)
  ii : float;  (* I -> I *)
}

This is a good set of default parameters you might want to use later:

In [None]:
(* A reasonable default parameter set; the projections out of the inhibitory
   population (ei and ii) carry negative weights. *)
let default_weights =
  { ee = 1.0; ei = -2.0; ie = 1.0; ii = -1.8; ex = 1.0; ix = 0.8 }

☝ Now, write a function of type `weights → neuron array * neuron array * neuron array * network` which:
1. creates a population of $N=1000$ Poisson neurons (5 Hz rate) -- call this `popX`
2. creates a population of $N$ (excitatory) LIF neurons -- call this `popE`
3. creates a population of $N$ (inhibitory) LIF neurons -- call this `popI`
4. sets up random connections as discussed in the lecture: each neuron in each of the X, E and I populations makes connections onto `k` randomly chosen neurons in both `popE` and `popI`. You will want to use the `random_connections` function provided in `module A`. Set the connection weights appropriately, and don't forget the $1/\sqrt{K}$ factor!
5. returns the tuple `(popE, popI, popX, network)`

In [None]:
(** [make_full_net ?n ?k ?rx weights] builds the full balanced network and
    returns [(popE, popI, popX, net)]:
    - [popX]: [n] Poisson neurons firing at rate [rx] (default 5 Hz);
    - [popE]: [n] excitatory LIF neurons;
    - [popI]: [n] inhibitory LIF neurons.

    All six projections (X, E and I onto E and onto I) are built with
    [random_connections ~k], using the matching field of [weights] scaled by
    [1 /. sqrt k]. Field [ab] is the weight from b onto a, e.g. [weights.ie]
    is used for the E -> I projection.

    @param n population size (default 1000)
    @param k connectivity parameter passed to [random_connections]
             (default 100) *)
let make_full_net ?(n = 1000) ?(k = 100) ?(rx = 5.) weights =
  let popE = Array.init n (fun _ -> lif ()) in
  let popI = Array.init n (fun _ -> lif ()) in
  let popX = Array.init n (fun _ -> poisson rx) in
  (* 1/sqrt(K) weight scaling, cf. lecture *)
  let normalized w = w /. sqrt (float k) in
  let connect ~from ~onto w =
    random_connections ~from ~onto ~k ~w:(normalized w)
  in
  let connections =
    [ connect ~from:popE ~onto:popE weights.ee
    ; connect ~from:popE ~onto:popI weights.ie
    ; connect ~from:popI ~onto:popE weights.ei
    ; connect ~from:popI ~onto:popI weights.ii
    ; connect ~from:popX ~onto:popE weights.ex
    ; connect ~from:popX ~onto:popI weights.ix
    ]
  in
  let net = { neurons = [ popE; popI; popX ]; connections } in
  popE, popI, popX, net

In [None]:
(* Stacked spike rasters (top to bottom: X, E, I) of the first 100 neurons of
   each population. *)
let plot_network_output ~exc ~inh ~ext =
  let keep = 100 in
  let figure (module P : Plot) =
    (* One raster panel: the first [keep] neurons of [pop], drawn with gnuplot
       line colour [lc] in the canvas band between [top] and [bottom]. *)
    let panel pop ~lc ~top ~bottom =
      let shown = Array.sub pop 0 keep in
      P.plot
        (A (raster (Array.map spikes shown)))
        ~style:(Printf.sprintf "p pt 7 lc %d ps 0.4" lc)
        [ barebone
        ; margins [ `left 0.2; `right 0.95; `top top; `bottom bottom ]
        ]
    in
    panel ext ~lc:8 ~top:0.9 ~bottom:0.7;
    panel exc ~lc:7 ~top:0.68 ~bottom:0.48;
    panel inh ~lc:3 ~top:0.46 ~bottom:0.26
  in
  Juplot.draw ~fmt:`svg ~size:(300, 600) figure

In [None]:
(* Build the full network with the default weights, simulate it for 1 s and
   plot the X, E and I spike rasters. *)
let _ =
  let exc, inh, ext, net = make_full_net default_weights in
  simulate ~duration:1.0 net;
  plot_network_output ~exc ~inh ~ext

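Mathematical (mean-field) calculations show that, in the balanced state (large $K$), the firing rates $r_E$ and $r_I$ should solve
$$ w_{EE}\, r_E + w_{EI}\, r_I + w_{EX}\, r_X = 0 $$
$$ w_{IE}\, r_E + w_{II}\, r_I + w_{IX}\, r_X = 0 $$
where $w_{ab}$ is the weight from population $b$ onto population $a$ (the fields of the `weights` record). For example, with $w_{EE}=1.0$, $w_{EI}=-2$, $w_{EX}=1.0$, $w_{IE}=1.5$, $w_{II}=-2$ and $w_{IX}=0.9$:
$$ 1.0 \times r_E -2 \times r_I + 1.0 \times r_X = 0$$
$$ 1.5 \times r_E -2 \times r_I + 0.9 \times r_X = 0 $$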

---

In [None]:
(* Minimal sanity-check network: 1000 Poisson neurons at rate 10 projecting
   onto a single LIF neuron with weight 0.005. *)
let input = Array.init 1000 (fun _ -> poisson 10.)

let output = [| lif () |]

let connections =
  let from, onto = input, output in
  [ random_connections ~from ~onto ~k:1 ~w:0.005 ]

let net = { connections; neurons = [ input; output ] }

(* simulated time, in seconds *)
let duration = 1.0

let _ = simulate ~duration net

(* simplest possible analysis: spike count of the LIF neuron divided by the
   simulated time *)
let output_rate =
  let n_spikes = List.length (spikes output.(0)) in
  float n_spikes /. duration