In [None]:
;;
#require "pkp"

In [None]:
open Owl

open Gp

# Random dot kinematograms

In [None]:
(** A single dot of the random-dot kinematogram, as [(kind, lifetime, (x, y))].
    - [kind] carries the dot's per-frame displacement [(dx, dy)]:
      [`coherent] dots share the signal displacement, [`random] dots each
      have their own direction.
    - [lifetime] counts the remaining updates before the dot is respawned.
    - [(x, y)] is the dot's current position. *)
type dot =
  [ `coherent of float * float | `random of float * float ]
  * int
  * (float * float)

In [None]:
(** [norm (dx, dy)] is the Euclidean length of the vector [(dx, dy)]. *)
let norm (dx, dy) =
  let squared_length = (dx *. dx) +. (dy *. dy) in
  sqrt squared_length

In [None]:
(** [is_within_disk ?radius p] is [true] when point [p] lies strictly inside
    the origin-centred disk of the given [radius] (default [1.]). *)
let is_within_disk ?(radius = 1.) point = radius > norm point

In [None]:
(** Per-frame displacement of speed [v] pointing left (negative x). *)
let left_motion v = (-.v, 0.)

(** Per-frame displacement of speed [v] pointing right (positive x). *)
let right_motion v = (v, 0.)

(** Per-frame displacement of speed [v] in a random direction; the angle is
    drawn with [Random.float Const.pi2]. *)
let random_motion v =
  let angle = Random.float Const.pi2 in
  (v *. cos angle, v *. sin angle)

In [None]:
(** [plottable dots] keeps only the dots strictly inside the unit disk (the
    visible aperture) and packs their positions as rows [[| x; y |]] of a
    matrix, in the same order as [dots].
    NOTE(review): if no dot is inside the aperture the row array is empty;
    confirm that [Mat.of_arrays] accepts that. *)
let plottable (dots : dot list) =
  let rows =
    List.filter_map
      (fun ((_, _, (x, y)) : dot) ->
        if is_within_disk (x, y) then Some [| x; y |] else None)
      dots
  in
  Mat.of_arrays (Array.of_list rows)

In [None]:
(** [sample_position ()] draws a position uniformly inside the origin-centred
    disk of radius 2 by rejection sampling. Candidates are drawn in the
    square [[-2, 2] x [-2, 2]] and discarded until one falls inside the disk.
    The disk is larger than the unit-disk aperture, so dots can drift into
    view from outside. *)
let sample_position () =
  let radius = 2. in
  let rec draw () =
    let x, y =
      (radius *. (Random.float 2. -. 1.), radius *. (Random.float 2. -. 1.))
    in
    if is_within_disk ~radius (x, y) then (x, y) else draw ()
  in
  draw ()

In [None]:
(** [sample_lifetime ()] is a random dot lifetime in frames: one plus the
    truncated value of an exponential draw with [~lambda:1.]. It is always at
    least 1. *)
let sample_lifetime () =
  Owl_stats.exponential_rvs ~lambda:1. |> int_of_float |> succ

In [None]:
(** [fresh_dot typ] spawns a new dot of the same kind as [typ]. The new dot
    gets a fresh position from [sample_position] and a fresh lifetime from
    [sample_lifetime].
    - A [`coherent v] dot keeps its displacement [v] unchanged.
    - A [`random v] dot keeps its speed [norm v] but gets a new random
      direction.
    The random draws are made in that order: position, lifetime, direction. *)
let fresh_dot typ =
  let a, b = sample_position () in
  let lt = sample_lifetime () in
  let typ =
    match typ with
    | `random v -> `random (random_motion (norm v))
    | `coherent v -> `coherent v
  in
  (typ, lt, (a, b) : dot)

In [None]:
(** [update dot] advances [dot] by one frame.
    - If its lifetime has reached 0, the dot is replaced by [fresh_dot] of the
      same kind.
    - Otherwise the dot moves by its displacement and its lifetime drops by
      one. Both kinds move the same way; they differ only in how
      [fresh_dot] respawns them. *)
let update ((kind, life, (x, y)) : dot) : dot =
  match life with
  | 0 -> fresh_dot kind
  | _ ->
    let (`random (dx, dy) | `coherent (dx, dy)) = kind in
    (kind, life - 1, (x +. dx, y +. dy))

In [None]:
(** [rdm_trial ~c ~motion n_steps] runs one animated random-dot-motion trial
    in the notebook:
    1. A fixation cross is shown for one second.
    2. [n_steps] frames of 300 dots are drawn; only dots inside the unit disk
       are visible.
    3. The fixation cross is shown again.

    Each dot is [`coherent] (displacement [motion]) with probability [c].
    Otherwise it is [`random], with the same speed [norm motion] in a random
    direction. A dot whose lifetime expires is respawned with the same kind.
    NOTE(review): there is no explicit delay between frames, so the frame
    rate depends on how quickly [Juplot.draw] renders. *)
let rdm_trial ~c ~motion n_steps =
  (* Presumably an initially empty display cell, redrawn in place through
     [~display_id]. *)
  let ph = Jupyter_notebook.display "text/html" "" in
  let size = (300, 300) in
  let props =
    [ barebone
    ; margins [ `top 0.9; `right 0.9; `bottom 0.1; `left 0.1 ]
    ; xrange (-1.2, 1.2)
    ; yrange (-1.2, 1.2)
    ]
  in
  let show fig = Juplot.draw ~display_id:ph ~size fig in
  (* Plotting the constant -2 puts nothing inside the y-range, leaving only
     the cross drawn by the two arrows. *)
  let fixation (module P : Plot) =
    let cross =
      [ set "arrow 1 from first -0.1, first 0 to first 0.1, first 0 nohead lc 8"
      ; set "arrow 2 from first 0, first -0.1 to first 0, 0.1 nohead lc 8"
      ]
    in
    P.plot (S "-2") (props @ cross)
  in
  let frame dots (module P : Plot) =
    P.plot (A (plottable dots)) ~style:"p pt 7 lc 8 ps 0.3" props
  in
  let spawn _ =
    let kind =
      if Random.float 1. < c
      then `coherent motion
      else `random (random_motion (norm motion))
    in
    fresh_dot kind
  in
  let rec animate k dots =
    if k >= n_steps
    then ()
    else begin
      show (frame dots);
      (* The order of the dots is irrelevant for plotting, so the cheaper
         order-reversing map is fine. *)
      animate (k + 1) (List.rev_map update dots)
    end
  in
  show fixation;
  Unix.sleepf 1.0;
  animate 0 (List.init 300 spawn);
  show fixation

In [None]:
(* Run one trial with 10% coherence in a randomly chosen direction. The true
   direction is the cell's value, so it is revealed only after the
   animation. *)
let _ =
  let direction = if Random.bool () then `left else `right in
  let velocity =
    match direction with
    | `left -> left_motion 0.05
    | `right -> right_motion 0.05
  in
  rdm_trial ~c:0.1 ~motion:velocity 20;
  direction

# Evidence accumulation

In [None]:
(** [drift_diffusion ~mu ~sigma] simulates one drift-diffusion trial using
    time steps of [dt = 1E-3]. The decision variable starts at 0 and is
    updated as [x + dt*mu + sqrt(dt)*N(0, sigma)] until it leaves [[-1, 1]].

    Returns [(decision, good_decision, t, trace)], where:
    - [decision] is [`left] if the variable exceeded +1 and [`right] if it
      fell below -1;
    - [good_decision] is [`left] when [mu > 0.] and [`right] otherwise;
    - [t] is the number of steps taken times [dt];
    - [trace] is a column matrix of the variable at every step before the
      crossing (the crossing value itself is not included).

    NOTE(review): with [mu = 0.] and [sigma = 0.] the variable stays at 0 and
    this never terminates. *)
let drift_diffusion ~mu ~sigma =
  let dt = 1E-3 in
  let good_decision = if mu > 0. then `left else `right in
  let noise_scale = sqrt dt in
  let as_column history =
    Mat.of_array (Array.of_list (List.rev history)) (-1) 1
  in
  (* Tail-recursive so long trials do not grow the stack. *)
  let rec step n history x =
    let crossed_up = x > 1.
    and crossed_down = x < -1. in
    if crossed_up || crossed_down
    then
      ( (if crossed_up then `left else `right)
      , good_decision
      , dt *. float n
      , as_column history )
    else
      step
        (n + 1)
        (x :: history)
        (x +. (dt *. mu) +. (noise_scale *. Owl_stats.gaussian_rvs ~mu:0. ~sigma))
  in
  step 0 [] 0.

In [None]:
(** [plot_trials x] draws the decision-variable time course of each trial in
    [x], where each trial is a [drift_diffusion] result
    [(decision, good_decision, t, timecourse)]. Trials are ranked by [t],
    fastest first.
    - The line colour's sign encodes the correct side ([`left] positive,
      [`right] negative).
    - Its magnitude [n - rank] puts the fastest trial at the extreme of the
      colour range.
    Dashed lines mark the bounds at +1 and -1.
    The x-axis scale 0.001 in [~using] must match [dt] in [drift_diffusion]. *)
let plot_trials x =
  let by_decision_time (_, _, ta, _) (_, _, tb, _) = compare ta tb in
  let trials = List.sort by_decision_time x in
  let n = List.length trials in
  let palette_value rank good_decision =
    match good_decision with
    | `left -> n - rank
    | `right -> rank - n
  in
  let trace rank (_, good_decision, _, timecourse) =
    item
      (A timecourse)
      ~using:"(0.001*$0):1"
      ~style:
        (Printf.sprintf "l lc palette cb %i" (palette_value rank good_decision))
  in
  let fig (module P : Plot) =
    P.plots
      (List.mapi trace trials)
      [ barebone
      ; borders [ `left ]
      ; xlabel "time (s)"
      ; ylabel "decision variable"
      ; set "arrow 1 from graph 0, first 1 to graph 1, first 1 nohead lc 8 dt 2 front"
      ; set "arrow 2 from graph 0, first -1 to graph 1, first -1 nohead lc 8 dt 2 front"
      ; cbrange (-.float n, float n)
      ; yrange (-1.1, 1.1)
      ; xtics (`regular [ 0.; 0.5 ])
      ]
  in
  Juplot.draw ~size:(400, 300) fig

### A few sample trials

In [None]:
(* Ten example trials with drift +1 or -1, chosen at random, and noise 0.5. *)
let _ =
  let drift = 1.0 in
  let one_trial _ =
    let mu = if Random.bool () then drift else -.drift in
    drift_diffusion ~mu ~sigma:0.5
  in
  plot_trials (List.init 10 one_trial)

### Psychometric curve

In [None]:
(** [pct_correct ~mu ~sigma] estimates the probability of a correct decision
    from 5000 simulated trials. The result is a fraction in [[0, 1]], not a
    percentage. In each trial the drift sign is chosen at random with
    magnitude [mu]. *)
let pct_correct ~mu ~sigma =
  let one_trial _ =
    let signed_mu = if Random.bool () then mu else -.mu in
    let chosen, correct, _, _ = drift_diffusion ~mu:signed_mu ~sigma in
    if chosen = correct then 1. else 0.
  in
  Stats.mean (Array.init 5000 one_trial)

In [None]:
(* Psychometric curve: the fraction of correct decisions (from [pct_correct])
   plotted against drift magnitude. The 10 drift values come from
   [Mat.logspace ~base:2. (-6.) 0. 10]; the noise is fixed at [sigma = 0.5]. *)
let _ =
  let c = Mat.logspace ~base:2. (-6.) 0. 10 in
  let pc = (Pkp.Misc.with_indicator Mat.map) (fun c -> pct_correct ~mu:c ~sigma:0.5) c in
  let fig (module P : Plot) =
    P.plot
      (L [ c; pc ])
      ~style:"lp pt 7 lc 8"
      (default_props
      @ [ set "log x"
        ; xlabel "evidence strength"
          (* [pct_correct] returns a fraction in [0, 1] (matching the 0.5-1
             y-range), so the axis is not labelled as a percentage. *)
        ; ylabel "fraction correct"
        ; yrange (0.5, 1.)
        ; ytics (`regular [ 0.5; 0.1 ])
        ])
  in
  Juplot.draw ~size:(300, 200) fig

### Distribution of reaction times

In [None]:
(** Decision times of 10000 simulated trials with [mu = 2.] and
    [sigma = 0.2], stored as a 1 x 10000 matrix. *)
let reaction_times =
  let decision_time _ =
    let _, _, t, _ = drift_diffusion ~mu:2. ~sigma:0.2 in
    t
  in
  Mat.init 1 10000 decision_time

In [None]:
(* Histogram (50 bins) of the simulated reaction times. Gp is already opened
   at the top of the notebook, so no local open is needed here. *)
let _ =
  let style = "boxes fs solid 0.5 lc 8" in
  let props =
    [ barebone; borders [ `bottom ]; xtics `auto; xlabel "reaction time" ]
  in
  let fig (module P : Plot) =
    P.plot (A (Pkp.Misc.hist ~n_bins:50 reaction_times)) ~style props
  in
  Juplot.draw ~size:(400, 200) fig