In [None]:
#  This program fits the nback model with Stan
import numpy as np
import stan
import pickle
from time import time
import matplotlib.pyplot as plt

In [None]:
# Load the per-trial observations (one integer per trial in each file).
choice = np.loadtxt("demo_files/action.txt", dtype=int)
rewards = np.loadtxt("demo_files/rewards.txt", dtype=int)

# Take the trial count from the data itself: the Stan data block sizes both
# arrays by T, so a hard-coded T only works for exactly that many rows.
T = len(choice)
if len(rewards) != T:
    raise ValueError(
        f"action.txt has {T} trials but rewards.txt has {len(rewards)}"
    )

# Initial belief over the two hidden states (uniform).
init_Ps = np.full(2, 0.5)

# Pack the data for Stan. Every variable declared in the Stan `data` block
# must be present here, including `mini`.
data_dict = {
    # Small constant declared as `real mini` in the Stan data block ("to avoid
    # p = 0 or 1-p = 0"). The model as written never reads it, so the exact
    # value is a placeholder -- TODO confirm the intended value.
    'mini': 1e-6,
    'T': T,
    # Action taken on each trial (the category association); the Stan data
    # block constrains it to 1 or 2. Originally noted as the replacement for
    # the older u_t/resp naming.
    'choice': choice.tolist(),
    # Outcome on each trial, expected to be +1 or -1.
    'reward': rewards.tolist(),
    # .tolist() yields plain Python ints/floats, which serialise cleanly.
    'init_Ps': init_Ps.tolist(),
}


In [None]:
# Stan program: a two-state hidden Markov model over (action, reward)
# observations. The belief over the two states is propagated with a symmetric
# switch probability `gamma`, used to predict the choice on each trial, and
# then updated by Bayes rule with emission probabilities governed by `c`
# (rewarded trials) and `d` (unrewarded trials).
hmm = '''
data {
    real mini;                               // meant to avoid p = 0 or 1 - p = 0 (declared but not used below)
    int<lower=0> T;                          // number of trials
    array[T] int<lower=-1, upper=1> reward;  // outcome per trial: +1 or -1
    array[T] int<lower=1, upper=2> choice;   // action per trial: 1 or 2
    vector[2] init_Ps;                       // initial prob/belief of the two states
}

transformed data {
    // The emission model below only defines reward = 1 and reward = -1.
    // The integer bounds cannot exclude 0, so check it once here instead of
    // silently reusing stale (or NaN) emission probabilities in the model.
    for (t in 1:T) {
        if (reward[t] == 0) {
            reject("reward[", t, "] is 0; the model only defines rewards of -1 and 1");
        }
    }
}

parameters {
    // Raw parameters live on the unconstrained real line. The ranges of the
    // model parameters come from the Phi_approx transforms below; bounding
    // the raw values as well would squeeze gamma, c and d into a narrow band.
    real gamma_pr;
    real c_pr;
    real d_pr;
}

transformed parameters {
    real gamma;
    real c;
    real d;

    // Phi_approx is a fast approximation of the standard normal CDF; it maps
    // the real line onto (0, 1).
    gamma = Phi_approx(gamma_pr);      // range (0, 1)
    c = Phi_approx(c_pr) * 0.5 + 0.5;  // range (0.5, 1)
    d = Phi_approx(d_pr) * 0.5 + 0.5;  // range (0.5, 1)
}

model {
    vector[2] Ps;       // prob of the states, 1 - yellow-cat(blue-dog), 2 - yellow-dog(blue-cat)
    real P_O_S1;        // p(O|S1) - O = {A,R} given yellow-cat
    real P_O_S2;        // p(O|S2) - O = {A,R} given yellow-dog
    real prior1;        // belief in state 1 before the observation is absorbed

    // priors on the raw parameters
    gamma_pr ~ normal(0, 1);
    c_pr ~ normal(0.5, 1);
    d_pr ~ normal(0.5, 1);

    for (t in 1:T) {
        // State update through the transition matrix: S[t-1] -> S[t],
        // applied before the observation of trial t.
        if (t == 1) {
            Ps = init_Ps;
        } else {
            Ps[1] = Ps[1] * (1 - gamma) + Ps[2] * gamma;
            Ps[2] = 1 - Ps[1];
        }

        // The choice is drawn according to the current state belief.
        choice[t] ~ categorical(Ps);

        // Emission probabilities p(O|S1), p(O|S2), where O combines the
        // A(ction) and the R(eward): the probability of actually observing
        // this outcome under each state.
        if (reward[t] == 1) {
            P_O_S1 = 0.5 * ((choice[t] == 1) ? c : (1 - c));
            P_O_S2 = 0.5 * ((choice[t] == 2) ? c : (1 - c));
        } else {
            // reward[t] == -1 (0 was ruled out in transformed data)
            P_O_S1 = 0.5 * ((choice[t] == 1) ? (1 - d) : d);
            P_O_S2 = 0.5 * ((choice[t] == 2) ? (1 - d) : d);
        }

        // State belief update (Bayes rule)
        prior1 = Ps[1];
        Ps[1] = (P_O_S1 * prior1) / (P_O_S1 * prior1 + P_O_S2 * (1 - prior1));
        Ps[2] = 1 - Ps[1];
    }
}
'''

In [None]:
# Compile the Stan program held in `hmm` and bind `data_dict` to it (there is
# no model file involved; the program is passed as a string). A fixed
# random_seed makes the sampling results reproducible across kernel restarts.
sm = stan.build(hmm, data=data_dict, random_seed=2023)

In [None]:

# --- Sampler configuration -------------------------------------------------
watch_list = ['gamma', 'c', 'd']  # parameters of interest
save = True                       # pickle the fitting result after sampling?
nIter = 10_000     # recorded draws; passed to sm.sample as num_samples, separately from num_warmup
nWarmup = 10_000   # burn-in (warmup) draws; passed to sm.sample as num_warmup
nChain = 4         # number of chains; pick according to available CPU cores (4 seems good)

In [None]:
# Draw posterior samples and report how long it took (wall-clock seconds).
start = time()
fit = sm.sample(num_chains=nChain, num_warmup=nWarmup, num_samples=nIter)
t = time() - start
print(f'{t} sec')

# Optionally persist the fit so it can be reloaded without re-sampling.
if save:
    with open('test1.pkl', 'wb') as out:
        pickle.dump(fit, out)

In [None]:
# Reload the fit written by the sampling cell.
# NOTE: pickle.load can execute arbitrary code -- only open files you created.
with open('test1.pkl', 'rb') as f:
    fit = pickle.load(f)

# Extract the samples for the desired parameters
gamma_samples = fit['gamma']
c_samples = fit['c']
d_samples = fit['d']

# Plot the parameter samples, one panel per parameter.
fig, ax = plt.subplots(3, 1, figsize=(8, 12))
panels = zip(ax, ('gamma', 'c', 'd'), (gamma_samples, c_samples, d_samples))
for axis, label, draws in panels:
    # PyStan returns a scalar parameter with shape (1, n_draws). Plotting that
    # directly draws n_draws one-point lines (an empty-looking panel), so
    # flatten to a single 1-D series first. Already-1-D input is unaffected.
    axis.plot(np.ravel(draws))
    axis.set_ylabel(label)
ax[-1].set_xlabel('draw (all chains pooled)')

# Adjust the layout to avoid overlapping labels
fig.tight_layout()

# Save the plot as an image file
fig.savefig('parameter_samples.png')