In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc3 as pm
import theano.tensor as tt

In [4]:
# 定义模型参数
class ModelParams:
    def __init__(self, k, beta):
        self.k = k
        self.beta = beta

# 定义类别中心点
def get_centers(k, condition):
    """根据 k 值和 condition 确定类别的中心点"""
    if condition == 1:
        centers = {
            1: ([0.25, 0.5, 0.5, 0.5], [0.75, 0.5, 0.5, 0.5]),  # x1 = 0.5 分割
            2: ([0.5, 0.25, 0.5, 0.5], [0.5, 0.75, 0.5, 0.5]),  # x2 = 0.5 分割
            3: ([0.5, 0.5, 0.25, 0.5], [0.5, 0.5, 0.75, 0.5]),  # x3 = 0.5 分割
            4: ([0.5, 0.5, 0.5, 0.25], [0.5, 0.5, 0.5, 0.75])   # x4 = 0.5 分割
        }
        return centers[int(k)]
    else:
        centers = {
            # 两个超平面分割
            1: ([0.25, 0.25, 0.5, 0.5], [0.25, 0.75, 0.5, 0.5], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),  # x1=0.5, x2=0.5
            2: ([0.25, 0.5, 0.25, 0.5], [0.25, 0.5, 0.75, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),  # x1=0.5, x3=0.5
            3: ([0.25, 0.5, 0.5, 0.25], [0.25, 0.5, 0.5, 0.75], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),  # x1=0.5, x4=0.5
            4: ([0.5, 0.25, 0.25, 0.5], [0.5, 0.25, 0.75, 0.5], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),  # x2=0.5, x3=0.5
            5: ([0.5, 0.25, 0.5, 0.25], [0.5, 0.25, 0.5, 0.75], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),  # x2=0.5, x4=0.5
            6: ([0.5, 0.5, 0.25, 0.25], [0.5, 0.5, 0.25, 0.75], [0.5, 0.5, 0.75, 0.25], [0.5, 0.5, 0.75, 0.75]),  # x3=0.5, x4=0.5
            # 三个超平面分割 (24种)
            7: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),   # M:x1, N1:x2, N2:x3
            8: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),   # M:x1, N1:x2, N2:x4
            9: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x3, N2:x4
            10: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x2, N2:x4
            11: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.75, 0.5, 0.25]), # M:x1, N1:x3, N2:x2
            12: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x1, N1:x3, N2:x4
            13: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.25, 0.75, 0.5]), # M:x1, N1:x4, N2:x3
            14: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),   # M:x1, N1:x4, N2:x2
            15: ([0.25, 0.5, 0.25, 0.25], [0.25, 0.5, 0.25, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.75, 0.25, 0.5]), # M:x1, N1:x4, N2:x2
            16: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.75, 0.25]), # M:x1, N1:x2, N2:x3
            17: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            18: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x3, N2:x4
            19: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x4, N2:x3
            20: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.25, 0.75]),  # M:x2, N1:x3, N2:x4
            21: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.25]),  # M:x2, N1:x4, N2:x3
            22: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            23: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            24: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x3, N1:x1, N2:x2
            25: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x2
            26: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            27: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            28: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            29: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x4, N1:x2, N2:x3
            30: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x4, N1:x1, N2:x3
        }
        return centers[int(k)]

# 定义似然函数
def likelihood(k, beta, data, condition):
    centers = get_centers(k, condition)
    choices = data['choice'].values
    stimuli = data['stimulus'].values
    
    distances = np.abs(stimuli[:, np.newaxis] - centers)
    min_distances = np.min(distances, axis=1)
    
    probs = tt.exp(-beta * distances) / tt.sum(tt.exp(-beta * distances), axis=1)
    return probs[np.arange(len(choices)), choices]

# 定义先验分布
def prior(params, condition):
    K = len(get_centers(1, condition))
    k_prior = 1/K if 1 <= params.k <= K else 0
    beta_prior = np.exp(-params.beta) if params.beta > 0 else 0
    return k_prior * beta_prior

# 定义后验分布
def posterior(k, beta, data):
    condition = data['condition'].iloc[0]
    params = ModelParams(k=k, beta=beta)
    log_prior = np.log(prior(params, condition))
    log_likelihood = np.sum(np.log(likelihood(params, data, condition)))
    return log_prior + log_likelihood  # 返回对数后验概率

# 优化后验分布
def fit_model_mcmc(data, num_samples=2000, num_chains=4):
    condition = data['condition'].iloc[0]
    K = len(get_centers(1, condition))
    
    with pm.Model() as model:
        # 先验
        k = pm.DiscreteUniform('k', lower=1, upper=K)
        beta = pm.Uniform('beta', lower=0, upper=10)
        
        # 似然
        likelihood_vals = likelihood(k, beta, data, condition)
        pm.Potential('likelihood', tt.sum(tt.log(likelihood_vals)))
        
        # 采样
        trace = pm.sample(num_samples, chains=num_chains, return_inferencedata=False)
    
    return trace

# 拟合模型
def fit_model_for_steps(data):
    num_trials = len(data)
    step_results = []
    
    for step in range(1, num_trials + 1):
        trial_data = data.iloc[:step]
        trace = fit_model_mcmc(trial_data)
        
        k_mean = np.mean(trace['k'])
        beta_mean = np.mean(trace['beta'])
        
        step_results.append({
            'k': k_mean,
            'beta': beta_mean,
            'trace': trace
        })
    
    return step_results


In [6]:
# 读取CSV文件
data = pd.read_csv('Task2.csv')

# 按被试编号分组并拟合模型
results = {}
for iSub, subject_data in data.groupby('iSub'):
    try:
        step_results = fit_model_for_steps(subject_data)
        results[iSub] = step_results
    except Exception as e:
        print(f"Error fitting model for subject {iSub}: {str(e)}")
        continue

# 绘制参数变化图
def plot_params_over_trials(subject_id, step_results):
    num_steps = len(step_results)
    k_values = [result['k'] for result in step_results]
    beta_values = [result['beta'] for result in step_results]
    
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15))
    
    # Plot k values
    ax1.plot(range(1, num_steps + 1), k_values, marker='o')
    ax1.set_title(f'MCMC: Subject {subject_id} - k value over trials')
    ax1.set_xlabel('Number of trials')
    ax1.set_ylabel('k value')
    ax1.grid(True)
    
    # Plot beta values
    ax2.plot(range(1, num_steps + 1), beta_values, marker='o')
    ax2.set_title(f'MCMC: Subject {subject_id} - beta value over trials')
    ax2.set_xlabel('Number of trials')
    ax2.set_ylabel('beta value')
    ax2.grid(True)
    
    # Plot posterior distributions for the final step
    final_trace = step_results[-1]['trace']
    pm.plot_posterior(final_trace, var_names=['k', 'beta'], ax=ax3)
    ax3.set_title(f'MCMC: Subject {subject_id} - Posterior distributions (final step)')
    
    plt.tight_layout()
    plt.savefig(f'MCMC: subject_{subject_id}_params_over_trials.png')
    plt.close()

# 为每个被试绘制参数变化图
for iSub, step_results in results.items():
    plot_params_over_trials(iSub, step_results)

# 打印每个被试的最终拟合结果
for iSub, step_results in results.items():
    final_result = step_results[-1]
    print(f"Subject {iSub}:")
    print(f"  Final Fitted k (mean): {final_result['k']}")
    print(f"  Final Fitted beta (mean): {final_result['beta']}")
    print(f"  95% Credible Interval for k: {pm.stats.hpd(final_result['trace']['k'])}")
    print(f"  95% Credible Interval for beta: {pm.stats.hpd(final_result['trace']['beta'])}")
    print()

Error fitting model for subject 1: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 4: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 6: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 11: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 21: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 26: __trunc__ returned non-Integral (type TensorVariable)
Error fitting model for subject 27: __trunc__ returned non-Integral (type TensorVariable)
