In [1]:
using Turing
using Distributions
using LinearAlgebra
using StatsFuns
using FastGaussQuadrature
using Bijectors
using HiddenMarkovModels
using DataFrames
using CSV
using Dates
using StatsAPI
using DensityInterface
using StatsPlots
using ReverseDiff
using Random
using MCMCChains

In [2]:
# --- load ---
# Read the full trial table; the rest of this cell narrows it to one mouse.
df = CSV.read("../data/mouse_df.csv", DataFrame)

# --- pick mouse with most trials ---
trial_counts = combine(groupby(df, :name), nrow => :trial_count)
sort!(trial_counts, :trial_count, rev=true)
moi = trial_counts[1, :name]
mouse_df = df[df.name .== moi, :]

# --- drop omissions ---
# `names(df)` returns column names as Strings, so the membership test must use
# a String; comparing against the Symbol `:outcome` is always false and would
# silently skip this filter.
if "outcome" in names(mouse_df)
    # `isequal` keeps the predicate a plain Bool even if an outcome is `missing`.
    mouse_df = filter(:outcome => x -> !isequal(x, "omission"), mouse_df)
end

# --- drop invalid RTs (e.g., "NAN") ---
# A row is invalid when its RT is `missing` or prints (case-insensitively) as "NAN".
valid_rt_mask = .!(ismissing.(mouse_df.rt) .| (uppercase.(string.(mouse_df.rt)) .== "NAN"))
mouse_df = mouse_df[valid_rt_mask, :]

# coerce RTs to Float64
RTs = Float64.(mouse_df.rt)
choices = mouse_df.choice
outcomes = mouse_df.outcome
correct_side = mouse_df.correct_side

# coerce choices and outcomes to 1/0, and correct_side to 1/-1
choices = Vector{Int}(map(x -> x == "right" ? 1 : 0, choices))
outcomes = Vector{Int}(map(x -> x == "correct" ? 1 : 0, outcomes))
correct_side = Vector{Int}(map(x -> x == "right" ? 1 : -1, correct_side))

12311-element Vector{Int64}:
 -1
  1
 -1
  1
  1
 -1
 -1
 -1
  1
  1
  ⋮
  1
 -1
  1
 -1
 -1
 -1
 -1
  1
 -1

In [3]:
# Hierarchical AR(p) model with subject-specific lag coefficients and noise scales.
#
# Arguments
# - `df`, `p`: the trial table and the autoregressive order. `df` is only used
#   to build the default value of `processed_data`.
# - `processed_data`: output of `prepare_ar_data(df, p)`. It is a defaulted
#   argument so the preprocessing runs once when the model is constructed,
#   instead of on every log-density evaluation inside the model body.
# - `y`: the response vector. It must be a model argument: Turing treats a
#   `~` statement whose left-hand side is not a model argument as a latent
#   parameter to sample, which would leave the data out of the likelihood.
#
# Calling `hierarchical_ar_p(df, p)` keeps working exactly as before.
@model function hierarchical_ar_p(
    df::DataFrame,
    p::Int,
    processed_data=prepare_ar_data(df, p),
    y=processed_data.y,
)
    n_subjects = length(processed_data.subjects)
    X = processed_data.X                              # AR design matrix, one row per observation
    subject_indices = processed_data.subject_indices  # subject index of each observation

    # Group level priors/hyperpriors
    β_μ ~ MvNormal(zeros(p), 1.0 * I)
    β_σ ~ filldist(Exponential(0.5), p)
    σ_μ ~ Exponential(0.1)
    σ_σ ~ Exponential(0.1)

    # Individual level parameters, non-centered parameterization.
    # `filldist` over an MvNormal yields a p × n_subjects matrix: column s holds
    # the raw coefficients of subject s, so subjects must be indexed by column
    # (`β_raw[s]` would linear-index a single scalar).
    β_raw ~ filldist(MvNormal(zeros(p), I), n_subjects)
    β = [β_μ .+ β_σ .* β_raw[:, s] for s in 1:n_subjects]  # transform to centered

    σ ~ filldist(LogNormal(log(σ_μ), σ_σ), n_subjects)

    # Per-observation mean and standard deviation, looked up by subject.
    μ = [dot(X[i, :], β[subject_indices[i]]) for i in eachindex(y)]
    σ_vec = σ[subject_indices]

    y ~ MvNormal(μ, Diagonal(σ_vec .^ 2))
end

"""
    prepare_ar_data(df::DataFrame, p::Int)

Build the response vector and lagged design matrix for an AR(`p`) model of log
reaction times, pooling observations across subjects.

`df` is sorted by its `:name` and `:trial` columns and the natural log of its
`:rt` column is taken. For every subject, each trial `t > p` contributes one
observation: the log-RT at `t` as the response, and the log-RTs at trials
`t-p, …, t-1` (oldest first) as the predictor row. Lags never cross subject
boundaries; subjects with `p` or fewer trials contribute nothing.

# Returns
A `NamedTuple` with fields
- `subjects`: unique subject names, in sorted-table order;
- `y`: `Vector{Float64}` of responses;
- `X`: `Matrix{Float64}` of size `n_obs × p` (`0 × p` when there are no observations);
- `subject_indices`: for each observation, the index into `subjects`.

# Throws
- `ArgumentError` if `p` is negative.
"""
function prepare_ar_data(df::DataFrame, p::Int)
    p >= 0 || throw(ArgumentError("AR order p must be non-negative, got $p"))

    # Sort by subject and trial so that lags follow trial order within a subject.
    df_sorted = sort(df, [:name, :trial])

    subjects = unique(df_sorted.name)
    log_rt = log.(df_sorted.rt)

    y_vals = Float64[]
    X_rows = Vector{Float64}[]
    subj_indices = Int[]

    # Process each subject separately so lagged values stay within one subject.
    for (subj_idx, subj) in enumerate(subjects)
        subj_log_rt = log_rt[df_sorted.name .== subj]

        for t in (p + 1):length(subj_log_rt)
            push!(y_vals, subj_log_rt[t])
            push!(X_rows, subj_log_rt[(t - p):(t - 1)])  # lagged values, oldest first
            push!(subj_indices, subj_idx)
        end
    end

    # Assemble a concrete n_obs × p matrix; this is also well-defined when
    # no observations were collected.
    X = Matrix{Float64}(undef, length(X_rows), p)
    for (i, row) in enumerate(X_rows)
        X[i, :] = row
    end

    return (
        subjects = subjects,
        y = y_vals,
        X = X,
        subject_indices = subj_indices,
    )
end

prepare_ar_data (generic function with 1 method)

In [None]:
# Autoregressive order shared by the data preparation and the model.
p = 3

# Build the AR response/design data once.
# NOTE(review): `processed` is not passed to the model below; the model call
# receives `df` and `p` directly.
processed = prepare_ar_data(df, p)

# Instantiate the model, then draw 1000 samples with NUTS.
ar_model = hierarchical_ar_p(df, p)
chain = sample(ar_model, NUTS(), 1000)

[32mSampling   0%|█                                         |  ETA: N/A[39m
