In [1]:
%%capture
!pip install pystan
!pip install pandas

In [2]:
import nest_asyncio
nest_asyncio.apply()

In [3]:
import stan
import pandas as pd
import numpy as np

In [4]:
def preprocess_data(data):
    """Filter a trial table down to the rows used for RLWM model fitting.

    A row is kept only when none of the columns ``choice``, ``key``,
    ``cor`` and ``rew`` equals ``-1`` and its ``rt`` value is strictly
    greater than 0.15.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial table with the columns ``choice``, ``key``, ``cor``,
        ``rew`` and ``rt``.

    Returns
    -------
    pandas.DataFrame
        The rows that pass every check, with their original index labels.
    """
    flagged_columns = ['choice', 'key', 'cor', 'rew']
    # True for rows in which every flagged column differs from -1.
    not_flagged = (data[flagged_columns] != -1).all(axis=1)
    above_rt_threshold = data['rt'] > 0.15  # Match MATLAB rt threshold
    return data[not_flagged & above_rt_threshold]

In [5]:
def select(pr):
    """Draw a random index according to the weights in ``pr``.

    Index ``i`` is returned with probability ``pr[i] / sum(pr)``. The
    uniform draw is scaled by the total weight so that a distribution whose
    sum is not exactly 1 (e.g. because of floating-point rounding) can never
    produce an index past the end of ``pr``.

    Parameters
    ----------
    pr : array-like of float
        Non-negative weights; they should sum to (approximately) 1.

    Returns
    -------
    numpy integer
        An index in ``range(len(pr))``.

    Raises
    ------
    ValueError
        If ``pr`` is empty or its sum is not strictly positive.
    """
    cumulative = np.cumsum(pr)
    if cumulative.size == 0:
        raise ValueError("pr must contain at least one probability")

    total = cumulative[-1]
    if not total > 0:  # also rejects NaN totals
        raise ValueError("pr must have a strictly positive sum")

    draw = np.random.random() * total
    # side='right' selects the first bin whose upper edge is strictly above
    # the draw, so zero-width (zero-probability) bins are never chosen.
    index = np.searchsorted(cumulative, draw, side='right')
    # Guard against draw == total after rounding.
    return np.minimum(index, cumulative.size - 1)

In [6]:
def compute_delay(stimuli, rewards):
    """Compute, per trial, the number of trials since the stimulus was last rewarded.

    For trial ``t`` showing stimulus ``s``, the delay is ``t - t_prev`` where
    ``t_prev`` is the most recent earlier trial on which ``s`` was shown and
    the reward equalled 1. Trials with no such earlier trial get ``NaN``.

    Parameters
    ----------
    stimuli : array-like, 1-D
        Stimulus identifier of each trial.
    rewards : array-like, 1-D
        Reward of each trial; a value equal to 1 counts as correct.

    Returns
    -------
    numpy.ndarray of float
        Delay for each trial, same length as ``stimuli``.

    Raises
    ------
    ValueError
        If ``stimuli`` and ``rewards`` do not have the same shape.
    """
    # Coerce so that plain lists / Series behave like arrays (comparing a
    # list to a scalar would otherwise silently match nothing).
    stimuli = np.asarray(stimuli)
    rewards = np.asarray(rewards)
    if stimuli.shape != rewards.shape:
        raise ValueError("stimuli and rewards must have the same shape")

    delay = np.full(stimuli.shape, np.nan, dtype=float)
    # Single forward pass: remember the latest rewarded trial per stimulus.
    last_correct = {}
    for t, (stim, rew) in enumerate(zip(stimuli.tolist(), rewards.tolist())):
        if stim in last_correct:
            delay[t] = t - last_correct[stim]
        # Update only after reading, so the current trial never counts as
        # its own "previous correct".
        if rew == 1:
            last_correct[stim] = t
    return delay

In [7]:
def compute_pcor(stimuli, rewards):
    """Count, per trial, the rewards previously accumulated for that trial's stimulus.

    Parameters
    ----------
    stimuli : numpy.ndarray
        Stimulus identifier of each trial.
    rewards : numpy.ndarray
        Reward of each trial.

    Returns
    -------
    numpy.ndarray of float
        For each trial, the sum of ``rewards`` over strictly earlier trials
        that showed the same stimulus.
    """
    prior_correct = np.zeros_like(stimuli, dtype=float)
    for stim_id in np.unique(stimuli):
        trial_idx = np.where(stimuli == stim_id)[0]
        outcomes = rewards[trial_idx]
        # The running total includes the current trial; subtracting the
        # current outcome leaves only what was earned beforehand.
        running_total = np.cumsum(outcomes)
        prior_correct[trial_idx] = running_total - outcomes
    return prior_correct

In [8]:
# Stan program source (kept as a Python string) for a mixture of a
# reinforcement-learning (Q) policy and a working-memory (WM) policy.
#
# data block:       N trials, S stimuli, A actions, plus per-trial stimulus,
#                   action, reward (0/1), set size and iteration number.
# parameters block: alpha_bg, beta_bg, beta_wm, w_0, epsilon, C, forget,
#                   stick, each with hard lower/upper bounds.
# model block:
#   * Q and WM are both initialised to 1/A for every (stimulus, action).
#   * The WM mixture weight on a trial is w_0 * min(1, C / set_size).
#   * The choice likelihood is only added when iterations[t] > 1.
#   * Each policy is a softmax over its values plus stick * one-hot(previous
#     action); epsilon mixes each policy with a uniform 1/A lapse.
#   * After each trial Q[s, a] moves toward the reward at rate alpha_bg;
#     all of WM decays toward 1/A at rate `forget`, and then WM[s, a] is
#     set to the reward (WM + (r - WM) == r).
#
# NOTE(review): six of the eight priors are narrow normals centred on fixed
# values (e.g. w_0 ~ normal(0.81, 0.02)); with priors this tight the
# posterior may largely reproduce the prior means - confirm this is intended.
rl_wm_model = """
data {
    int<lower=1> N;                     // Number of trials
    int<lower=1> S;                     // Number of stimuli
    int<lower=1> A;                     // Number of actions (3)
    array[N] int<lower=1, upper=S> stimuli;
    array[N] int<lower=1, upper=A> actions;
    array[N] int<lower=0, upper=1> rewards;
    array[N] int<lower=1> set_sizes;    // Set size for each trial
    array[N] int<lower=1> iterations;    // Trial iteration number
}

parameters {
    real<lower=0, upper=1> alpha_bg;    // BG learning rate (α_BG in paper)
    real<lower=0, upper=100> beta_bg;   // BG inverse temperature (β_BG in paper)
    real<lower=0, upper=100> beta_wm;   // WM inverse temperature (β_WM in paper)
    real<lower=0, upper=1> w_0;         // Initial WM weight (w_0 in paper)
    real<lower=0, upper=1> epsilon;     // Lapse rate (ϵ in paper)
    real<lower=0, upper=6> C;           // WM capacity (C in paper)
    real<lower=0, upper=1> forget;      // WM decay (not in paper results)
    real<lower=0, upper=1> stick;       // Motor perseveration
}

model {
    // Priors matching reported values
    alpha_bg ~ normal(0.16, 0.03);
    beta_bg ~ normal(26.6, 3.3);
    beta_wm ~ normal(45.0, 4.5);
    w_0 ~ normal(0.81, 0.02);
    epsilon ~ normal(0.23, 0.02);
    C ~ normal(3.7, 0.14);
    forget ~ beta(2, 2);
    stick ~ beta(2, 2);

    // Initialize value matrices
    matrix[S, A] Q;
    matrix[S, A] WM;
    vector[A] prev_choice = rep_vector(0, A);

    // Set initial values
    for (s in 1:S) {
        for (a in 1:A) {
            Q[s, a] = 1.0 / A;
            WM[s, a] = 1.0 / A;
        }
    }

    // Trial by trial learning
    for (t in 1:N) {
        int s = stimuli[t];
        int a = actions[t];
        real w = w_0 * fmin(1.0, C / set_sizes[t]);
        vector[A] policy;

        // Only model choice if not first iteration
        if (iterations[t] > 1) {
            // Combine RL and WM policies with stick
            vector[A] Q_policy = softmax(beta_bg * (to_vector(Q[s]) + stick * prev_choice));
            vector[A] WM_policy = softmax(beta_wm * (to_vector(WM[s]) + stick * prev_choice));

            // Add lapse and combine policies
            policy = w * ((1 - epsilon) * WM_policy + epsilon / A) +
                    (1 - w) * ((1 - epsilon) * Q_policy + epsilon / A);

            // Choice likelihood
            actions[t] ~ categorical(policy);
        }

        // Update values after observing reward
        real r = rewards[t];

        // RL update
        Q[s, a] = Q[s, a] + alpha_bg * (r - Q[s, a]);

        // WM update with decay
        WM = WM + forget * (rep_matrix(1.0 / A, S, A) - WM);
        WM[s, a] = WM[s, a] + (r - WM[s, a]);

        // Update previous choice
        prev_choice = rep_vector(0, A);
        prev_choice[a] = 1;
    }
}
"""

In [9]:
def fit_rlwm_model(data, num_chains=4, num_samples=2000, random_seed=None):
    """Fit the RLWM Stan model (``rl_wm_model``) to one table of trials.

    Columns are read by position (0-based): 2 = set size, 4 = stimulus,
    7 = iteration, 9 = action, 11 = reward. Each is cast to ``int`` before
    being handed to Stan.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial table; must contain at least one row and at least 12 columns.
    num_chains : int, default 4
        Number of chains passed to the sampler.
    num_samples : int, default 2000
        Number of draws per chain passed to the sampler.
    random_seed : int or None, default None
        Seed forwarded to ``stan.build`` so that a fit can be reproduced.
        ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    The fit object returned by the sampler.

    Raises
    ------
    ValueError
        If ``data`` has no rows (the Stan program requires ``N >= 1``).
    """
    if len(data) == 0:
        raise ValueError("cannot fit the RLWM model to an empty data set")

    def _int_column(position):
        # Positional access: the column layout is assumed to match MATLAB.
        return data.iloc[:, position].values.astype(int)

    # Prepare data for Stan - using column indices matching MATLAB
    stan_data = {
        'N': len(data),
        'S': int(data.iloc[:, 4].max()),  # Convert to int for Stan
        'A': 3,  # Number of possible actions
        'stimuli': _int_column(4),
        'actions': _int_column(9),
        'rewards': _int_column(11),
        'set_sizes': _int_column(2),
        'iterations': _int_column(7)
    }

    # Starting point for every chain (values matching the paper). Note that
    # all chains therefore start from the same location.
    initial_values = {
        'alpha_bg': 0.16,
        'beta_bg': 26.6,
        'beta_wm': 45.0,
        'w_0': 0.81,
        'epsilon': 0.23,
        'C': 3.7,
        'forget': 0.3,
        'stick': 0.3
    }

    # Build and sample with initial values matching paper
    rl_wm_sm = stan.build(rl_wm_model, data=stan_data, random_seed=random_seed)
    fit = rl_wm_sm.sample(
        num_chains=num_chains,
        num_samples=num_samples,
        init=[dict(initial_values) for _ in range(num_chains)]
    )

    return fit

In [10]:
def analyze_results(fit):
    """Summarise a fit by the mean of each model parameter.

    Parameters
    ----------
    fit : mapping
        Object that supports ``fit[name]`` for every parameter name below
        and returns something with a ``.mean()`` method.

    Returns
    -------
    dict
        Parameter name -> mean, in the order ``alpha_bg``, ``beta_bg``,
        ``beta_wm``, ``w_0``, ``epsilon``, ``C``, ``forget``, ``stick``.
    """
    parameter_names = (
        'alpha_bg', 'beta_bg', 'beta_wm', 'w_0',
        'epsilon', 'C', 'forget', 'stick',
    )
    return {name: fit[name].mean() for name in parameter_names}

In [12]:
if __name__ == "__main__":
    # Load the raw trial table and keep only the usable trials.
    data = pd.read_csv('expe_data.csv')
    df_clean = preprocess_data(data)

    # Fit model for each subject-block pair; one result row per pair.
    results = []
    grouped = df_clean.groupby(['subno', 'block'])
    for i, ((subject, block), group_data) in enumerate(grouped, start=1):
        print(f"Fitting subject {subject}, block {block}")

        fit = fit_rlwm_model(group_data)
        params = analyze_results(fit)

        row = {'subject': subject, 'block': block}
        row.update(params)
        results.append(row)

        # Only the first two subject-block pairs are fitted.
        if i == 2:
            break

    results_df = pd.DataFrame(results)
    # results_df.to_csv('rlwm_results.csv', index=False)

# df = pd.read_csv('/content/rlwm_results.csv')


Fitting subject 1, block 1
Building...



Building: 40.1s, done.Messages from stanc:
    0.02 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    0.02 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    45.0 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    26.6 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    0.03 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    bound in its declaration. These hard constraints are not recommended, for
    two reasons: (a) Except when there are logical or physical constraints,
    it is very unusual for you to be sure that a parameter will fall inside a
    specified range, and (b) The infinite gr

Fitting subject 1, block 2
Building...



Building: found in cache, done.Messages from stanc:
    0.02 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    0.02 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    45.0 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    26.6 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    0.03 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
    bound in its declaration. These hard constraints are not recommended, for
    two reasons: (a) Except when there are logical or physical constraints,
    it is very unusual for you to be sure that a parameter will fall inside a
    specified range, and (b) The in

In [13]:
# Preview the first rows of the per-(subject, block) parameter summary table.
results_df.head()

Unnamed: 0,subject,block,alpha_bg,beta_bg,beta_wm,w_0,epsilon,C,forget,stick
0,1,1,0.161235,26.620413,44.944144,0.808295,0.23131,3.702278,0.217121,0.199772
1,1,2,0.163086,26.61891,45.198363,0.810843,0.222568,3.700077,0.201742,0.121771
