# Parameter Recovery
---
This notebook generates Figure 2 of $\textit{A Hidden Markov Model with Drift Diffusion Model Emissions (HMM-DDM) for the Analysis of Animal Choice Data}$.

In [1]:
# load required libraries
using DriftDiffusionModels
using Random
using Distributions

## Shared Parameters

The code below defines the parameters shared across the "high", "medium", and "low" separation experiments.

In [7]:
# Transition matrix and initial state distribution
A = [0.95 0.05; 0.2 0.8]
π₀ = [0.89, 0.11];

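# Sequence lengths (number of trials) to test; each length is used as a dictionary key below, so values must be unique
t_lengths = [10, 100, 1000, 10000, 100000]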

# Starting conditions
A_guess = [0.5 0.5; 0.5 0.5]
π₀_guess = [0.5, 0.5];

# Initial Guesses
ddm_1_guess = DriftDiffusionModel(
    v = 1.0,
    B = 5.0,
    τ = 0.5,
)

ddm_2_guess = DriftDiffusionModel(
    v = 0.5,
    B = 2.5,
    τ = 0.25,
)

hmm_init = PriorHMM(
    π₀_guess,
    A_guess,
    [ddm_1_guess, ddm_2_guess],
    ones(2, 2) .*  2,
    ones(2)
)

PriorHMM{Float64, DriftDiffusionModel}([0.5, 0.5], [0.5 0.5; 0.5 0.5], DriftDiffusionModel[DriftDiffusionModel(5.0, 1.0, 0.5, 0.5, 1.0), DriftDiffusionModel(2.5, 0.5, 0.5, 0.25, 1.0)], [2.0 2.0; 2.0 2.0], [1.0, 1.0])
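
As a quick sanity check on the shared transition matrix, the long-run fraction of trials spent in each state and the expected dwell time per state follow directly from `A`. The sketch below uses only Julia's `LinearAlgebra` standard library; the names `π_stat` and `dwell` are illustrative and are not part of `DriftDiffusionModels`.

```julia
using LinearAlgebra

# Transition matrix from the shared parameters (rows sum to 1)
A = [0.95 0.05; 0.2 0.8]
K = size(A, 1)

# The stationary distribution solves π'A = π' with sum(π) = 1.
# Stack the normalization constraint under (A' - I) and solve by least squares.
M = [(A' - I); ones(1, K)]
b = [zeros(K); 1.0]
π_stat = M \ b

# The dwell time in state k is geometric with mean 1 / (1 - A[k, k]) trials
dwell = 1 ./ (1 .- diag(A))

println("stationary distribution: ", π_stat)
println("expected dwell times (trials): ", dwell)
```

Because the chain spends most of its time in state 1, the shortest sequences in `t_lengths` will contain only a handful of trials from state 2, which is worth keeping in mind when interpreting recovery at small sample sizes.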

## High Separation
This section defines the HMM-DDM used to assess parameter recovery when the two DDMs are "highly" separated, i.e., the states are very distinct.

In [8]:
# Define a DDM for an "engaged" agent. In these models, v is the drift rate, B is the
# boundary separation, and τ is the non-decision time. We assume the initial state
# (defined as the relative fraction between the correct and wrong choice) is 0.5,
# i.e., unbiased, and that the noise coefficient is σ = 1.0.
engaged_ddm = DriftDiffusionModel(
    v = 0.5,
    B = 3.0,
    τ = 0.3
)

disengaged_ddm = DriftDiffusionModel(
    v = 0.0,
    B = 0.2,
    τ = 0.0
)

# our hmm
high_hmm_true = PriorHMM(
    π₀,
    A,
    [engaged_ddm, disengaged_ddm],
    ones(2, 2) .*  2,
    ones(2)
)

PriorHMM{Float64, DriftDiffusionModel}([0.89, 0.11], [0.95 0.05; 0.2 0.8], DriftDiffusionModel[DriftDiffusionModel(3.0, 0.5, 0.5, 0.3, 1.0), DriftDiffusionModel(0.2, 0.0, 0.5, 0.0, 1.0)], [2.0 2.0; 2.0 2.0], [1.0, 1.0])
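
To make the roles of `v`, `B`, and `τ` concrete, here is a minimal Euler-Maruyama sketch of a single DDM trial in plain Julia. It does not use `DriftDiffusionModels`; the function `simulate_ddm_trial`, the step size `dt`, and the convention that the bounds sit at 0 and `B` with a start at `a₀ * B` are assumptions made for illustration.

```julia
using Random

# Simulate one DDM trial with Euler-Maruyama steps.
# Assumptions (not taken from DriftDiffusionModels): absorbing bounds at 0 and B,
# starting point a₀ * B, and reaction time = decision time + τ.
function simulate_ddm_trial(rng, v, B, τ; a₀ = 0.5, σ = 1.0, dt = 1e-3)
    x = a₀ * B
    t = 0.0
    while 0.0 < x < B
        x += v * dt + σ * sqrt(dt) * randn(rng)
        t += dt
    end
    choice = x >= B ? 1 : -1   # +1 = upper ("correct") bound, -1 = lower bound
    return (rt = t + τ, choice = choice)
end

rng = MersenneTwister(1)
engaged_trials = [simulate_ddm_trial(rng, 0.5, 3.0, 0.3) for _ in 1:2000]
disengaged_trials = [simulate_ddm_trial(rng, 0.0, 0.2, 0.0) for _ in 1:2000]

mean_rt(trials) = sum(tr.rt for tr in trials) / length(trials)
frac_upper(trials) = count(tr -> tr.choice == 1, trials) / length(trials)

println("engaged:    mean RT = ", mean_rt(engaged_trials), ", P(upper) = ", frac_upper(engaged_trials))
println("disengaged: mean RT = ", mean_rt(disengaged_trials), ", P(upper) = ", frac_upper(disengaged_trials))
```

Under these assumptions the engaged state should produce slow, mostly correct responses, while the disengaged state should produce very fast responses at chance, which is what makes the two states easy to tell apart.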

## Medium Separation
This section defines the HMM-DDM used to assess parameter recovery when the two DDMs are somewhat separated, i.e., the states are reasonably distinct.

In [9]:
less_engaged_ddm = DriftDiffusionModel(
    v = 0.2,
    B = 1.0,
    τ = 0.35
)

# our hmm
medium_hmm_true = PriorHMM(
    π₀,
    A,
    [engaged_ddm, less_engaged_ddm],
    ones(2, 2) .*  2,
    ones(2)
)

PriorHMM{Float64, DriftDiffusionModel}([0.89, 0.11], [0.95 0.05; 0.2 0.8], DriftDiffusionModel[DriftDiffusionModel(3.0, 0.5, 0.5, 0.3, 1.0), DriftDiffusionModel(1.0, 0.2, 0.5, 0.35, 1.0)], [2.0 2.0; 2.0 2.0], [1.0, 1.0])

## Low Separation
This section defines the HMM-DDM used to assess parameter recovery when the two DDMs are barely separated, i.e., the states are almost identical.

In [10]:
also_engaged_ddm = DriftDiffusionModel(
    v = 0.4,
    B = 2.8,
    τ = 0.25
)

# our hmm
low_hmm_true = PriorHMM(
    π₀,
    A,
    [engaged_ddm, also_engaged_ddm],
    ones(2, 2) .*  2,
    ones(2)
)

PriorHMM{Float64, DriftDiffusionModel}([0.89, 0.11], [0.95 0.05; 0.2 0.8], DriftDiffusionModel[DriftDiffusionModel(3.0, 0.5, 0.5, 0.3, 1.0), DriftDiffusionModel(2.8, 0.4, 0.5, 0.25, 1.0)], [2.0 2.0; 2.0 2.0], [1.0, 1.0])
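
One way to put numbers on "high", "medium", and "low" separation is to compare the predicted choice probability and mean reaction time of each state. The sketch below uses the standard closed-form results for a DDM with absorbing bounds at 0 and `B` and an unbiased start at `B/2`. That parameterization is an assumption about how `B` is defined, and `p_upper` and `expected_rt` are hypothetical helpers rather than functions from `DriftDiffusionModels`.

```julia
# Closed-form predictions for a DDM with absorbing bounds at 0 and B,
# an unbiased start at B/2, drift v, noise σ, and non-decision time τ.
# This parameterization is assumed, not taken from DriftDiffusionModels.
p_upper(v, B; σ = 1.0) = 1 / (1 + exp(-v * B / σ^2))

function expected_rt(v, B, τ; σ = 1.0)
    # The v → 0 limit of the mean decision time is B^2 / (4σ^2)
    decision_time = iszero(v) ? B^2 / (4 * σ^2) : (B / (2 * v)) * tanh(v * B / (2 * σ^2))
    return decision_time + τ
end

# Parameter values copied from the DDMs defined in this notebook
states = (
    engaged      = (v = 0.5, B = 3.0, τ = 0.3),
    disengaged   = (v = 0.0, B = 0.2, τ = 0.0),
    less_engaged = (v = 0.2, B = 1.0, τ = 0.35),
    also_engaged = (v = 0.4, B = 2.8, τ = 0.25),
)

for (name, s) in pairs(states)
    println(rpad(string(name), 14),
            " P(upper) = ", round(p_upper(s.v, s.B), digits = 3),
            "  mean RT = ", round(expected_rt(s.v, s.B, s.τ), digits = 3))
end
```

The closer a state's two summary values are to those of `engaged`, the harder the two states should be to distinguish from choices and reaction times alone.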

## Generate data
---
In this cell we generate data from each of the "true" HMMs at every sequence length in `t_lengths`.

In [None]:
# look up each "true" HMM by its separation level
true_hmms = Dict(
    "high" => high_hmm_true,
    "medium" => medium_hmm_true,
    "low" => low_hmm_true
)

# create a dict-of-dicts to store the data
data_dict = Dict(model => Dict() for model in keys(true_hmms))

for model in ["high", "medium", "low"]
    for t_length in t_lengths
        # generate data
        data_dict[model][t_length] = rand(true_hmms[model], t_length)
    end
end
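
Conceptually, sampling from an HMM has two steps: draw a hidden state path from the initial distribution and the transition matrix, then draw one emission per trial from the DDM of the active state. The sketch below illustrates only the first step in plain Julia. It is not the `rand` method of `DriftDiffusionModels`, and `sample_categorical` and `sample_states` are hypothetical helpers.

```julia
using Random

# Draw an index from a discrete distribution p (entries sum to 1)
function sample_categorical(rng, p)
    u = rand(rng)
    c = 0.0
    for (k, pk) in enumerate(p)
        c += pk
        u < c && return k
    end
    return length(p)   # guard against floating-point round-off
end

# Hypothetical helper: sample a hidden state path of length T
function sample_states(rng, π₀, A, T)
    z = Vector{Int}(undef, T)
    z[1] = sample_categorical(rng, π₀)
    for t in 2:T
        z[t] = sample_categorical(rng, view(A, z[t - 1], :))
    end
    return z
end

rng = MersenneTwister(42)
z = sample_states(rng, [0.89, 0.11], [0.95 0.05; 0.2 0.8], 10_000)
println("fraction of trials in state 1: ", count(==(1), z) / length(z))
```

Passing an explicit random number generator, as done here, is one way to make the simulated state paths reproducible across runs.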

In [None]:
# create a dict-of-dicts to store the fitted models
learned_model_dict = Dict(
    "high" => Dict(),
    "medium" => Dict(),
    "low" => Dict()
)

for model in ["high", "medium", "low"]
    for t_length in t_lengths
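        # fit the model to the data; deepcopy keeps every fit starting from the same
        # initial guess in case baum_welch updates its first argument in place
        learned_model_dict[model][t_length] = baum_welch(deepcopy(hmm_init), data_dict[model][t_length])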
    end
end
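
When comparing fitted parameters with the true ones, note that the fitted states need not come out in the same order as the true states (label switching), so the states should be aligned before any error is computed. The sketch below aligns a fitted transition matrix to the true one by trying every state permutation. `A_fit` is made-up example data rather than output from `baum_welch`, and `all_permutations` and `best_permutation` are hypothetical helpers.

```julia
# All permutations of 1:K (fine for the small K used here)
function all_permutations(K)
    K == 1 && return [[1]]
    perms = Vector{Vector{Int}}()
    for p in all_permutations(K - 1), i in 1:K
        # insert K at position i of each shorter permutation
        push!(perms, vcat(p[1:i-1], K, p[i:end]))
    end
    return perms
end

# Hypothetical helper: find the relabeling of fitted states that best matches A_true
function best_permutation(A_true, A_fit)
    K = size(A_true, 1)
    best_p, best_err = collect(1:K), Inf
    for p in all_permutations(K)
        err = sum(abs.(A_true .- A_fit[p, p]))   # total absolute error after relabeling
        if err < best_err
            best_p, best_err = p, err
        end
    end
    return best_p, best_err
end

A_true = [0.95 0.05; 0.2 0.8]
# Made-up fitted matrix with the two states swapped (illustration only)
A_fit = [0.78 0.22; 0.06 0.94]

p, err = best_permutation(A_true, A_fit)
println("best state order: ", p, ", total absolute error: ", err)
```

The same permutation `p` would then be applied to the fitted DDM parameters and initial state distribution before comparing them with the true values.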