In [1]:
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import scipy
import arviz as az
import pymc.backends as pmbackends
import sys
import matplotlib.pyplot as plt

# Raise the interpreter recursion limit only moderately. Presumably this was
# raised because PyTensor graph handling recursed deeply -- NOTE(review): confirm
# 10_000 is enough. An extreme limit such as 1_000_000 lets runaway recursion
# overflow the C stack and crash the kernel (segfault) instead of raising a
# catchable RecursionError.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 10_000))

In [2]:
# Load the per-trial data: `u_t` is the input sequence the filter is run over,
# `resp` is the observed response on each trial (scored by a Bernoulli likelihood).
DATA_DIR = "demo_files"
# ndmin=1 keeps a single-trial file as a 1-D array instead of a 0-d scalar.
u_t = np.loadtxt(f"{DATA_DIR}/u_t.txt", dtype=int, ndmin=1)
resp = np.loadtxt(f"{DATA_DIR}/resp.txt", dtype=int, ndmin=1)

# The model scans over u_t and scores resp trial by trial, so both must be
# 1-D, equally long, and 0/1 valued.
assert u_t.ndim == 1 and resp.ndim == 1, "expected one value per trial"
assert u_t.shape == resp.shape, f"u_t has {u_t.size} trials but resp has {resp.size}"
assert np.isin(u_t, (0, 1)).all(), "u_t must be binary (0/1)"
assert np.isin(resp, (0, 1)).all(), "resp must be binary (0/1)"

In [4]:
with pm.Model() as model:
    # Priors over the free parameters that are sampled.
    omega2 = pm.Normal('omega2', mu=-3, sigma=16, initval=-3.7)
    omega3 = pm.Normal('omega3', mu=-6, sigma=16, initval=-5)
    kappa = pm.HalfNormal('kappa', sigma=16, initval=0.2)

    def hgf_update(u, mu_hats_prev, pi_prev, omega2, omega3, kappa):
        """Advance the three-level binary HGF by one trial.

        The argument order is the one ``pytensor.scan`` uses: the current
        element of each sequence, then the previous value of each recurrent
        output, then the non-sequences.

        Parameters
        ----------
        u : scalar tensor
            Input observed on this trial (0/1).
        mu_hats_prev : length-3 tensor
            Previous means; slots 1 and 2 are mu2 and mu3. Slot 0 holds the
            previous trial's level-1 prediction and is not read here.
        pi_prev : length-3 tensor
            Previous precisions; slots 1 and 2 are pi2 and pi3. Slot 0 is
            carried through unchanged.
        omega2, omega3, kappa : scalar tensors
            Model parameters; constant across trials.

        Returns
        -------
        mu_hats_new, pi_new : length-3 tensors
            Updated state, fed back into the next trial.
        mu_hat1 : scalar tensor
            Level-1 prediction made *before* seeing ``u``; used as the
            response probability.
        """
        mu2_prev = mu_hats_prev[1]
        mu3_prev = mu_hats_prev[2]
        pi2_prev = pi_prev[1]
        pi3_prev = pi_prev[2]

        # Prediction: level 1 is the sigmoid of the level-2 mean.
        mu_hat2 = mu2_prev
        mu_hat1 = 1 / (1 + pt.exp(-mu_hat2))
        # Level-1 prediction error.
        da1 = u - mu_hat1

        # Level 2: predicted precision, then posterior precision and mean.
        v2 = pt.exp(kappa * mu3_prev + omega2)
        pi_hat2 = 1 / (1 / pi2_prev + v2)
        # 1/pi_hat1 == mu_hat1 * (1 - mu_hat1). Writing it without the reciprocal
        # avoids 1/0 (and NaN gradients) when the sigmoid saturates at 0 or 1.
        pi2 = pi_hat2 + mu_hat1 * (1 - mu_hat1)
        mu2 = mu_hat2 + da1 / pi2
        # Volatility prediction error.
        da2 = (1 / pi2 + (mu2 - mu_hat2) ** 2) * pi_hat2 - 1

        # Level 3 (volatility).
        mu_hat3 = mu3_prev
        pi_hat3 = 1 / (1 / pi3_prev + pt.exp(omega3))
        w2 = v2 * pi_hat2
        # (2 * w2 - 1) is the r2 weighting term of the level-3 precision update.
        pi3 = pi_hat3 + 0.5 * kappa ** 2 * w2 * (w2 + (2 * w2 - 1) * da2)
        mu3 = mu_hat3 + 0.5 * kappa * w2 * da2 / pi3

        mu_hats_new = pt.stack([mu_hat1, mu2, mu3])
        pi_new = pt.stack([pi_prev[0], pi2, pi3])
        return mu_hats_new, pi_new, mu_hat1

    # Initial state: mu_hats = [0, 0, 1], pi = [0, 10, 1]. float64 so the
    # recurrent outputs keep the same dtype on every scan step.
    mu_hats0 = pt.as_tensor_variable(np.array([0.0, 0.0, 1.0], dtype="float64"))
    pi0 = pt.as_tensor_variable(np.array([0.0, 10.0, 1.0], dtype="float64"))

    assert len(u_t) == len(resp), "u_t and resp must have one entry per trial"

    # Run the filter over every trial. The per-trial prediction mu_hat1 is not
    # fed back, hence `None` in outputs_info. The number of steps comes from the
    # length of the data rather than a hard-coded constant.
    (mu_hats, pi, mu_hat1), _scan_updates = pytensor.scan(
        fn=hgf_update,
        sequences=[pt.as_tensor_variable(u_t)],
        outputs_info=[mu_hats0, pi0, None],
        non_sequences=[omega2, omega3, kappa],
    )

    # Response model: one Bernoulli per trial with probability equal to the
    # filter's pre-observation prediction mu_hat1 (shape matches resp).
    action_prob = pm.Bernoulli('action_prob', p=mu_hat1, observed=resp)

In [None]:
# Fixed seed so re-running the notebook reproduces the same posterior draws.
SEED = 2023

with model:
    try:
        trace = pm.sample(draws=500, tune=500, progressbar=True, chains=4, random_seed=SEED)
    except pm.SamplingError as error:
        # Show the failure and PyMC's model diagnostics, then re-raise:
        # swallowing the error would leave `trace` undefined and make the
        # plotting cell fail later with an unrelated NameError.
        print(error)
        model.debug()
        raise
       

In [None]:
# Trace plots of the posterior draws (saved to disk) plus a numeric summary.
# `pm.traceplot` was removed in PyMC >= 4, where plotting is delegated to
# ArviZ, so call ArviZ directly.
az.plot_trace(trace)
plt.tight_layout()
plt.savefig('trace_plt.png')

model_summary = az.summary(trace)
model_summary