# SOLUTIONS

In [None]:
;;
#require "pkp"

open Owl

open Gp

module A = Pkp.Balanced_net

open A

Number of input synapses per neuron, `k`, used throughout the notebook (cf. the lecture):

In [None]:
(** Number of input synapses per neuron; shared by every cell below. *)
let k : int = 100

# 1. Poisson neurons → they will provide input to our network

In [None]:
(** Population of [k] Poisson neurons, each created with a rate of 5 Hz. *)
let input =
  let rate_hz = 5.0 in
  Array.init k (fun _ -> poisson rate_hz)

In [None]:
(** Network made of the Poisson input population only, with no connections. *)
let net =
  let neurons = [ input ] in
  let connections = [] in
  { neurons; connections }

In [None]:
(* Run the input-only network for one unit of simulated time. *)
let _ =
  let duration = 1.0 in
  simulate ~duration net

In [None]:
(* Spike raster of the input population, with axes labelled "time [s]" and
   "neurons". *)
let _ =
  (* [spikes] is applied to every input neuron; [raster] turns the result
     into plottable data *)
  let dots = raster (Array.map spikes input) in
  let figure (module P : Plot) =
    let props =
      [ barebone
      ; borders [ `bottom ]
      ; xtics `auto ~o:"nomirror"
      ; offsets [ `bottom (`graph 0.1) ]
      ; xlabel "time [s]"
      ; ylabel "neurons"
      ]
    in
    P.plot (A dots) ~style:"p pt 7 lc 8 ps 0.4" props
  in
  Juplot.draw ~fmt:`svg ~size:(600, 300) figure

In [None]:
(* Empirical firing rate of every input neuron over a long simulation,
   plotted as points with the population mean and its standard error in the
   title. *)
let _ =
  let duration = 100. in
  simulate ~duration net;
  (* spike count divided by simulated time, one value per neuron *)
  let per_neuron =
    Array.map (fun neuron -> float (List.length (spikes neuron)) /. duration) input
  in
  (* convert to owl array type (a single row) for the statistics and the plot *)
  let firing_rates = Mat.of_array per_neuron 1 (-1) in
  let n_neurons = float (Array.length input) in
  let mu = Mat.mean' firing_rates in
  let sem = Mat.std' firing_rates /. sqrt n_neurons in
  let figure (module P : Plot) =
    let props =
      [ barebone
      ; borders [ `left ]
      ; ylabel "firing rate"
      ; ytics `auto ~o:"out nomirror"
      ; title (Printf.sprintf "mean = %.3f ± %.3f" mu sem)
      ]
    in
    P.plot (A firing_rates) ~style:"p pt 7 lc 8 ps 0.5" props
  in
  Juplot.draw ~fmt:`svg ~size:(600, 300) figure

# 2. A single LIF neuron receiving Poisson input

In [None]:
(** [make_simple_net w] builds a network in which every neuron of [input]
    is connected onto one freshly created LIF neuron with weight
    [w /. float k]. Returns [(single, net)]: the LIF neuron and the network
    containing it together with [input]. *)
let make_simple_net w =
  let single = lif () in
  (* each of the [k] synapses carries weight w / k *)
  let per_synapse = w /. float k in
  let synapses = all_to_all_connections ~from:input ~onto:[| single |] ~w:per_synapse in
  single, { neurons = [ input; [| single |] ]; connections = [ synapses ] }

In [None]:
(** [plot_response ~duration x] draws the voltage trace of neuron [x]
    (via [plottable_voltage ~duration]) together with its spikes (via
    [plottable_spikes]) in a single 500x200 SVG panel. *)
let plot_response ~duration x =
  (* gnuplot command adding a filled black rectangle, 0.2 x-axis units wide,
     just below the plot area — presumably a time scale bar *)
  let scale_bar =
    set
      "object 1 rectangle from first 0, graph -0.1 rto first 0.2, graph -0.02 fs \
       solid 1.0 noborder fc rgb 'black' noclip"
  in
  let figure (module P : Plot) =
    let traces = [ plottable_voltage ~duration x; plottable_spikes x ] in
    let props =
      [ barebone; ytics `auto ~o:"out nomirror"; margins [ `bottom 0.2 ]; scale_bar ]
    in
    P.plots traces props
  in
  Juplot.draw ~fmt:`svg ~size:(500, 200) figure

In [None]:
(* Simulate the single-neuron network with total weight 6.0 and plot the
   neuron's response. *)
let () =
  let duration = 1.5 in
  let total_weight = 6.0 in
  let single, net = make_simple_net total_weight in
  simulate ~duration net;
  plot_response ~duration single

In [None]:
(** [firing_rate w] simulates a fresh [make_simple_net w] network for 100
    time units and returns the number of spikes emitted by its single neuron
    divided by that duration. *)
let firing_rate w =
  let duration = 100.0 in
  let single, net = make_simple_net w in
  simulate ~duration net;
  let n_spikes = single |> spikes |> List.length |> float in
  n_spikes /. duration

In [None]:
(* Sweep the input weight over 20 evenly spaced values in [0, 25] and plot
   the resulting firing rate of the single neuron against the weight. *)
let () =
  let open Owl in
  let ws = Mat.linspace 0.0 25.0 20 in
  (* one full simulation per weight value *)
  let rates = Mat.map firing_rate ws in
  let figure (module P : Plot) =
    let axes =
      [ borders [ `bottom; `left ]
      ; xtics (`regular [ 0.; 2. ]) ~o:"out nomirror"
      ; ytics `auto ~o:"out nomirror"
      ; xlabel "weight w"
      ; ylabel "firing rate"
      ]
    in
    P.plot (L [ ws; rates ]) ~style:"lp pt 7 lc 8 ps 0.5" (barebone :: axes)
  in
  Juplot.draw ~fmt:`svg ~size:(300, 200) figure

# 3. Single neuron receiving balanced E and I inputs

In [None]:
(** [make_simple_ei_net w] builds a network where one fresh LIF neuron
    receives two Poisson populations of [k] neurons each (rate 5.0): one
    connected with weight [w /. sqrt (float k)] and the other with the
    opposite weight. Returns [(single, net)]. *)
let make_simple_ei_net w =
  let rate = 5.0 in
  (* per-synapse magnitude: w / sqrt k *)
  let w_exc = w /. sqrt (float k) in
  let single = lif () in
  let poisson_population () = Array.init k (fun _ -> poisson rate) in
  let input_e = poisson_population () in
  let input_i = poisson_population () in
  (* all-to-all projection from a population onto the single neuron *)
  let project ~from ~w = all_to_all_connections ~from ~onto:[| single |] ~w in
  let connections =
    [ project ~from:input_e ~w:w_exc; project ~from:input_i ~w:(-.w_exc) ]
  in
  single, { neurons = [ input_e; input_i; [| single |] ]; connections }

In [None]:
(* Simulate the E/I-driven single neuron with weight parameter 1.5 and plot
   its response. *)
let () =
  let duration = 2.0 in
  let weight = 1.5 in
  let single, net = make_simple_ei_net weight in
  simulate ~duration net;
  plot_response ~duration single

# 4. Full balanced network

In [None]:
(** Connection weights of the full network. Each field is named
    [<target><source>], matching its use in [make_full_net], where every
    value is divided by [sqrt (float k)] before being applied. *)
type weights =
  { ex : float (** from the X population onto the E population *)
  ; ix : float (** from the X population onto the I population *)
  ; ee : float (** from the E population onto itself *)
  ; ei : float (** from the I population onto the E population *)
  ; ie : float (** from the E population onto the I population *)
  ; ii : float (** from the I population onto itself *)
  }

In [None]:
(** Default weight set: the two weights coming from the I population
    ([ei], [ii]) are negative, all others are positive. *)
let default_weights =
  { ex = 1.0
  ; ix = 0.8
  ; ee = 1.0
  ; ei = -2.0
  ; ie = 1.0
  ; ii = -1.8
  }

In [None]:
(** [make_full_net w] builds the full network: two LIF populations (E and
    I) and one Poisson population (X, rate 5.0) of 2000 neurons each, wired
    by six random projections using [k] as the connection-count argument and
    weights taken from [w], each divided by [sqrt (float k)].
    Returns [((popE, popI, popX), net)]. *)
let make_full_net w =
  let n = 2000 in
  (* only neuron 0 of each LIF population is created with [log_voltage] set *)
  let lif_population () = Array.init n (fun i -> lif ~log_voltage:(i = 0) ()) in
  let popE = lif_population () in
  let popI = lif_population () in
  let popX = Array.init n (fun _ -> poisson 5.0) in
  (* random projection whose weight is normalised by sqrt k *)
  let connect ~from ~onto weight =
    random_connections ~from ~onto ~k ~w:(weight /. sqrt (float k))
  in
  let connections =
    [ connect ~from:popE ~onto:popE w.ee
    ; connect ~from:popE ~onto:popI w.ie
    ; connect ~from:popI ~onto:popE w.ei
    ; connect ~from:popI ~onto:popI w.ii
    ; connect ~from:popX ~onto:popE w.ex
    ; connect ~from:popX ~onto:popI w.ix
    ]
  in
  (popE, popI, popX), { neurons = [ popE; popI; popX ]; connections }

In [None]:
(** [plot_network_output ~duration (popE, popI, popX)] draws a four-panel
    400x600 SVG figure. From top to bottom: spike rasters of (at most) the
    first 100 neurons of the X, E and I populations, then the voltage trace
    and spikes of the first E neuron. Every panel uses the time range
    [(0, duration)].

    @raise Invalid_argument if [popE] is empty (its first neuron is needed
    for the voltage panel). *)
let plot_network_output ~duration (popE, popI, popX) =
  (* maximum number of neurons shown per raster *)
  let keep = 100 in
  (* clamp so that populations smaller than [keep] do not make [Array.sub] raise *)
  let first pop = Array.sub pop 0 (min keep (Array.length pop)) in
  let popE = first popE in
  let popI = first popI in
  let popX = first popX in
  let figure (module P : Plot) =
    (* properties shared by all four panels, including the common time axis *)
    let common =
      [ barebone; xrange (0.0, duration); margins [ `left 0.2; `right 0.95 ] ]
    in
    (* [tm] / [bm] are the top / bottom margins positioning the panel *)
    let plot_raster ~tm ~bm (pop, name, color) =
      P.plot
        (* plot the population passed in, not always [popX] *)
        (A (pop |> Array.map spikes |> raster))
        ~style:(Printf.sprintf "p pt 7 lc rgb '%s' ps 0.4" color)
        (common @ [ margins [ `top tm; `bottom bm ]; ylabel name ])
    in
    plot_raster ~tm:0.9 ~bm:0.7 (popX, "X neurons", "black");
    plot_raster ~tm:0.68 ~bm:0.48 (popE, "E neurons", "#e51e10");
    plot_raster ~tm:0.46 ~bm:0.26 (popI, "I neurons", "#56b4e9");
    (* the x-range comes from [common], so this panel stays aligned with the
       rasters for any [duration] *)
    P.plots
      [ plottable_voltage ~duration popE.(0); plottable_spikes popE.(0) ]
      (common
      @ [ margins [ `top 0.24; `bottom 0.1 ]
        ; borders [ `bottom ]
        ; xtics (`regular [ 0.; 1. ])
        ; offsets [ `bottom (`graph 0.1) ]
        ; xlabel "time"
        ; ylabel "V_m"
        ])
  in
  Juplot.draw ~fmt:`svg ~size:(400, 600) figure

In [None]:
(* Build the full network with the default weights, simulate it and plot
   the activity of its three populations. *)
let _ =
  let duration = 2.0 in
  let pops, net = default_weights |> make_full_net in
  simulate ~duration net;
  plot_network_output ~duration pops

In [None]:
(* Fresh full network bound at top level so that the cells below can use
   its populations directly; [net] shadows the earlier input-only network. *)
let (popE, popI, popX), net =
  default_weights |> make_full_net

In [None]:
(* Run the top-level full network for two units of simulated time. *)
let _ =
  let duration = 2. in
  simulate ~duration net

In [None]:
(** [rate ~duration pop] is the mean over the neurons of [pop] of
    (number of spikes / [duration]). *)
let rate ~duration pop =
  (* spike count of each neuron, as floats *)
  let counts = Array.map (fun neuron -> float (List.length (spikes neuron))) pop in
  (* single-row matrix of per-neuron rates *)
  let rates = Mat.(of_array counts 1 (-1) /$ duration) in
  Mat.mean' rates

In [None]:
(* Mean firing rate of the I population over the 2.0 simulated above. *)
popI |> rate ~duration:2.0

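Mathematical calculations show that the population firing rates $r_E$ and $r_I$ should satisfy the following linear system, where $r_X$ is the firing rate of the external input population X: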
$$ w_{EE} \times r_E + w_{EI} \times r_I + w_{EX} \times r_X = 0$$
$$ w_{IE} \times r_E + w_{II} \times r_I + w_{IX} \times r_X = 0$$


In [None]:
(** [w_mat w] is the 2x2 matrix [[ee, ei]; [ie, ii]] built from the
    corresponding fields of [w]. *)
let w_mat { ee; ei; ie; ii; _ } =
  let row_e = [| ee; ei |] in
  let row_i = [| ie; ii |] in
  Mat.of_arrays [| row_e; row_i |]

In [None]:
(** [solve w rx] returns the 2x1 matrix [inv (w_mat w) * (-rx * h)], where
    [h] is the column vector [[w.ex]; [w.ix]]. *)
let solve w rx =
  let drive = Mat.of_arrays [| [| w.ex |]; [| w.ix |] |] in
  (* right-hand side of the linear system: minus the scaled drive *)
  let rhs = Mat.(neg (rx $* drive)) in
  Mat.(inv (w_mat w) *@ rhs)

In [None]:
;;
(* Solve the linear system for the default weights with rx = 6.
   NOTE(review): the networks built by [make_full_net] use [poisson 5.0] for
   the X population — confirm whether 6. is intended here before comparing
   with the simulated rates. *)
solve default_weights 6.