In [1]:
import pandas as pd
import numpy as np
from scipy.stats import gamma
from scipy.optimize import minimize
import matplotlib.pyplot as plt

In [2]:
# 定义模型参数
class ModelParams:
    def __init__(self, k, beta):
        self.k = k
        self.beta = beta

# 定义类别中心点
def get_centers(k, condition):
    """根据 k 值和 condition 确定类别的中心点"""
    if condition == 1:
        centers = {
            1: ([0.25, 0.5, 0.5, 0.5], [0.75, 0.5, 0.5, 0.5]),  # x1 = 0.5 分割
            2: ([0.5, 0.25, 0.5, 0.5], [0.5, 0.75, 0.5, 0.5]),  # x2 = 0.5 分割
            3: ([0.5, 0.5, 0.25, 0.5], [0.5, 0.5, 0.75, 0.5]),  # x3 = 0.5 分割
            4: ([0.5, 0.5, 0.5, 0.25], [0.5, 0.5, 0.5, 0.75])   # x4 = 0.5 分割
        }
        return centers[int(k)]
    else:
        centers = {
            # 两个超平面分割
            1: ([0.25, 0.25, 0.5, 0.5], [0.25, 0.75, 0.5, 0.5], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),  # x1=0.5, x2=0.5
            2: ([0.25, 0.5, 0.25, 0.5], [0.25, 0.5, 0.75, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),  # x1=0.5, x3=0.5
            3: ([0.25, 0.5, 0.5, 0.25], [0.25, 0.5, 0.5, 0.75], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),  # x1=0.5, x4=0.5
            4: ([0.5, 0.25, 0.25, 0.5], [0.5, 0.25, 0.75, 0.5], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),  # x2=0.5, x3=0.5
            5: ([0.5, 0.25, 0.5, 0.25], [0.5, 0.25, 0.5, 0.75], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),  # x2=0.5, x4=0.5
            6: ([0.5, 0.5, 0.25, 0.25], [0.5, 0.5, 0.25, 0.75], [0.5, 0.5, 0.75, 0.25], [0.5, 0.5, 0.75, 0.75]),  # x3=0.5, x4=0.5
            # 三个超平面分割 (24种)
            7: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),   # M:x1, N1:x2, N2:x3
            8: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),   # M:x1, N1:x2, N2:x4
            9: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x3, N2:x4
            10: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x2, N2:x4
            11: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.75, 0.5, 0.25]), # M:x1, N1:x3, N2:x2
            12: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x1, N1:x3, N2:x4
            13: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.25, 0.75, 0.5]), # M:x1, N1:x4, N2:x3
            14: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),   # M:x1, N1:x4, N2:x2
            15: ([0.25, 0.5, 0.25, 0.25], [0.25, 0.5, 0.25, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.75, 0.25, 0.5]), # M:x1, N1:x4, N2:x2
            16: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.75, 0.25]), # M:x1, N1:x2, N2:x3
            17: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            18: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x3, N2:x4
            19: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x4, N2:x3
            20: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.25, 0.75]),  # M:x2, N1:x3, N2:x4
            21: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.25]),  # M:x2, N1:x4, N2:x3
            22: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            23: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            24: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x3, N1:x1, N2:x2
            25: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x2
            26: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            27: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            28: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            29: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x4, N1:x2, N2:x3
            30: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x4, N1:x1, N2:x3
        }
        return centers[int(k)]

# 定义似然函数
def likelihood(params, data, condition):
    k, beta = params.k, params.beta
    
    x = data[['feature1', 'feature2', 'feature3', 'feature4']].values
    c = data['choice'].values
    r = data['feedback'].values

    # 计算到类别中心的距离
    centers = get_centers(k, condition)
    distances = np.array([np.linalg.norm(x - np.array(center), axis=1) for center in centers])

    # 计算选择概率
    probs = np.exp(-beta * distances)
    probs /= np.sum(probs, axis=0, keepdims=True)
    p_c = probs[c - 1, np.arange(len(c))]
    
    return np.where(r == 1, p_c, 1 - p_c)

# 定义先验分布
def prior(params, condition):
    K = len(get_centers(1, condition))
    k_prior = 1/K if 1 <= params.k <= K else 0
    beta_prior = np.exp(-params.beta) if params.beta > 0 else 0
    return k_prior * beta_prior

# 定义后验分布
def posterior(k, data):
    condition = data['condition'].iloc[0]
    params = ModelParams(k=k, beta=1)  # beta 固定为 1
    log_prior = np.log(prior(params, condition))
    log_likelihood = np.sum(np.log(likelihood(params, data, condition)))
    return log_prior + log_likelihood  # 返回对数后验概率

# 优化后验分布
def fit_model(data):
    condition = data['condition'].iloc[0]
    K = len(get_centers(1, condition))
    
    # 对所有可能的 k 值计算后验概率
    posteriors = [posterior(k, data) for k in range(1, K+1)]
    
    # 找到最大后验概率对应的 k 值
    best_k = np.argmax(posteriors) + 1
    
    return ModelParams(k=best_k, beta=1), posteriors

# 拟合模型
def fit_model_for_steps(data):
    num_trials = len(data)
    step_results = []
    
    for step in range(1, num_trials + 1):
        trial_data = data.iloc[:step]
        fitted_params, posteriors = fit_model(trial_data)
        step_results.append({
            'k': fitted_params.k,
            'beta': fitted_params.beta,
            'posteriors': posteriors
        })
    return step_results


In [3]:
# 读取CSV文件
data = pd.read_csv('Task2.csv')

# 按被试编号分组并拟合模型
results = {}
for iSub, subject_data in data.groupby('iSub'):
    try:
        step_results = fit_model_for_steps(subject_data)
        results[iSub] = step_results
    except Exception as e:
        print(f"Error fitting model for subject {iSub}: {str(e)}")
        continue

# 绘制参数变化图
def plot_params_over_trials(subject_id, step_results):
    num_steps = len(step_results)
    k_values = [result['k'] for result in step_results]
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
    
    # Plot k values
    ax1.plot(range(1, num_steps + 1), k_values, marker='o')
    ax1.set_title(f'Subject {subject_id} - k value over trials')
    ax1.set_xlabel('Number of trials')
    ax1.set_ylabel('k value')
    ax1.grid(True)
    
    # Plot posterior probabilities for the final step
    final_posteriors = step_results[-1]['posteriors']
    ax2.bar(range(1, len(final_posteriors) + 1), np.exp(final_posteriors))
    ax2.set_title(f'Subject {subject_id} - Posterior probabilities (final step)')
    ax2.set_xlabel('k value')
    ax2.set_ylabel('Posterior probability')
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig(f'subject_{subject_id}_params_over_trials.png')
    plt.close()

# 为每个被试绘制参数变化图
for iSub, step_results in results.items():
    plot_params_over_trials(iSub, step_results)

# 打印每个被试的最终拟合结果
for iSub, step_results in results.items():
    final_result = step_results[-1]
    print(f"Subject {iSub}:")
    print(f"  Final Fitted k: {final_result['k']}")
    print(f"  Beta (fixed): {final_result['beta']}")
    print(f"  Final log posterior: {max(final_result['posteriors'])}")
    print()

Subject 1:
  Final Fitted k: 1
  Beta (fixed): 1
  Final log posterior: -77.08097428971425

Subject 4:
  Final Fitted k: 1
  Beta (fixed): 1
  Final log posterior: -77.95848884404784

Subject 6:
  Final Fitted k: 1
  Beta (fixed): 1
  Final log posterior: -215.11356605694493

Subject 11:
  Final Fitted k: 2
  Beta (fixed): 1
  Final log posterior: -207.47136419793424

Subject 21:
  Final Fitted k: 2
  Beta (fixed): 1
  Final log posterior: -691.0371353160598

Subject 26:
  Final Fitted k: 1
  Beta (fixed): 1
  Final log posterior: -270.4022088033614

Subject 27:
  Final Fitted k: 1
  Beta (fixed): 1
  Final log posterior: -578.812307912142

