In [1]:
import pandas as pd
import numpy as np
from scipy.stats import gamma
from scipy.optimize import minimize
import matplotlib.pyplot as plt

In [7]:
# 定义模型参数
class ModelParams:
    def __init__(self, k, beta):
        self.k = k
        self.beta = beta

# 定义类别中心点
def get_centers(k, condition):
    """根据 k 值和 condition 确定类别的中心点"""
    if condition == 1:
        centers = {
            1: ([0.25, 0.5, 0.5, 0.5], [0.75, 0.5, 0.5, 0.5]),  # x1 = 0.5 分割
            2: ([0.5, 0.25, 0.5, 0.5], [0.5, 0.75, 0.5, 0.5]),  # x2 = 0.5 分割
            3: ([0.5, 0.5, 0.25, 0.5], [0.5, 0.5, 0.75, 0.5]),  # x3 = 0.5 分割
            4: ([0.5, 0.5, 0.5, 0.25], [0.5, 0.5, 0.5, 0.75])   # x4 = 0.5 分割
        }
        return centers[int(k)]
    else:
        centers = {
            # 两个超平面分割
            1: ([0.25, 0.25, 0.5, 0.5], [0.25, 0.75, 0.5, 0.5], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),  # x1=0.5, x2=0.5
            2: ([0.25, 0.5, 0.25, 0.5], [0.25, 0.5, 0.75, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),  # x1=0.5, x3=0.5
            3: ([0.25, 0.5, 0.5, 0.25], [0.25, 0.5, 0.5, 0.75], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),  # x1=0.5, x4=0.5
            4: ([0.5, 0.25, 0.25, 0.5], [0.5, 0.25, 0.75, 0.5], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),  # x2=0.5, x3=0.5
            5: ([0.5, 0.25, 0.5, 0.25], [0.5, 0.25, 0.5, 0.75], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),  # x2=0.5, x4=0.5
            6: ([0.5, 0.5, 0.25, 0.25], [0.5, 0.5, 0.25, 0.75], [0.5, 0.5, 0.75, 0.25], [0.5, 0.5, 0.75, 0.75]),  # x3=0.5, x4=0.5
            # 三个超平面分割 (24种)
            7: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),   # M:x1, N1:x2, N2:x3
            8: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),   # M:x1, N1:x2, N2:x4
            9: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x3, N2:x4
            10: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x2, N2:x4
            11: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.75, 0.5, 0.25]), # M:x1, N1:x3, N2:x2
            12: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x1, N1:x3, N2:x4
            13: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.25, 0.75, 0.5]), # M:x1, N1:x4, N2:x3
            14: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),   # M:x1, N1:x4, N2:x2
            15: ([0.25, 0.5, 0.25, 0.25], [0.25, 0.5, 0.25, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.75, 0.25, 0.5]), # M:x1, N1:x4, N2:x2
            16: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.75, 0.25]), # M:x1, N1:x2, N2:x3
            17: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            18: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x3, N2:x4
            19: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x4, N2:x3
            20: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.25, 0.75]),  # M:x2, N1:x3, N2:x4
            21: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.25]),  # M:x2, N1:x4, N2:x3
            22: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            23: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            24: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x3, N1:x1, N2:x2
            25: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x2
            26: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            27: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            28: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            29: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x4, N1:x2, N2:x3
            30: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x4, N1:x1, N2:x3
        }
        return centers[int(k)]

# 定义似然函数
def likelihood(params, data, condition):
    k, beta = params.k, params.beta
    
    x = data[['feature1', 'feature2', 'feature3', 'feature4']].values
    c = data['choice'].values
    r = data['feedback'].values

    # 计算到类别中心的距离
    centers = get_centers(k, condition)
    distances = np.array([np.linalg.norm(x - np.array(center), axis=1) for center in centers])

    # 计算选择概率
    probs = np.exp(-beta * distances)
    probs /= np.sum(probs, axis=0, keepdims=True)
    p_c = probs[c - 1, np.arange(len(c))]
    
    return np.where(r == 1, p_c, 1 - p_c)

# 定义先验分布
def prior(params, condition):
    K = len(get_centers(1, condition))
    k_prior = 1/K if 1 <= params.k <= K else 0
    beta_prior = np.exp(-params.beta) if params.beta > 0 else 0
    return k_prior * beta_prior

# 定义后验分布
def posterior(params, data):
    condition = data['condition'].iloc[0]
    log_prior = np.log(prior(params, condition))
    log_likelihood = np.sum(np.log(likelihood(params, data, condition)))
    return -(log_prior + log_likelihood)  # 负对数后验概率

# 优化后验分布
def fit_model(data):
    iterations = []
    objective_values = []
    
    def objective(params):
        k, beta = params
        obj_value = posterior(ModelParams(k=int(k), beta=beta), data)
        
        # 引入随机噪声，以鼓励探索不同的k值
        obj_value += np.random.normal(0, 0.01)  
        
        iterations.append((k, beta))
        objective_values.append(obj_value)
        return obj_value
    
    condition = data['condition'].iloc[0]  # 获取 condition
    K = len(get_centers(1, condition))     # 计算 K
    
    initial_k = np.random.randint(1, K + 1)  # 使用随机初始k值
    initial_params = [initial_k, 1.0]  # beta 初始值为1.0
    bounds = [(1, K), (0, 10)]  # k 的范围是 [1, K]，beta 的范围是 [0, 10]
    
    result = minimize(
        objective,
        initial_params,
        method='L-BFGS-B',
        bounds=bounds
    )
    
    return ModelParams(k=int(round(result.x[0])), beta=result.x[1]), iterations, objective_values

# 拟合模型
def fit_model_for_steps(data):
    num_trials = len(data)
    step_results = []
    
    # 逐步增加试次进行拟合
    for step in range(1, num_trials + 1):
        trial_data = data.iloc[:step]  # 只用前 step 个试次的数据
        fitted_params, iterations, objective_values = fit_model(trial_data)
        step_results.append({
            'k': fitted_params.k,
            'beta': fitted_params.beta,
            'iterations': iterations,
            'objective_values': objective_values
        })
    return step_results


In [21]:
# 读取CSV文件
data = pd.read_csv('Task2.csv')

# 按被试编号分组并拟合模型
results = {}
for iSub, subject_data in data.groupby('iSub'):
    try:
        step_results = fit_model_for_steps(subject_data)
        results[iSub] = step_results
    except Exception as e:
        print(f"Error fitting model for subject {iSub}: {str(e)}")
        continue

# 绘制参数变化图
def plot_params_over_trials(subject_id, step_results):
    num_steps = len(step_results)
    k_values = [result['k'] for result in step_results]
    beta_values = [result['beta'] for result in step_results]
    objective_values = [result['objective_values'][-1] for result in step_results]  # 每一步的最终目标函数值
    
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15))
    
    # Plot k values
    ax1.plot(range(1, num_steps + 1), k_values, marker='o')
    ax1.set_title(f'Subject {subject_id} - k value over trials')
    ax1.set_xlabel('Number of trials')
    ax1.set_ylabel('k value')
    ax1.grid(True)
    
    # Plot beta values
    ax2.plot(range(1, num_steps + 1), beta_values, marker='o')
    ax2.set_title(f'Subject {subject_id} - beta value over trials')
    ax2.set_xlabel('Number of trials')
    ax2.set_ylabel('beta value')
    ax2.grid(True)
    
    # Plot objective function values
    ax3.plot(range(1, num_steps + 1), objective_values, marker='o')
    ax3.set_title(f'Subject {subject_id} - Objective function over trials')
    ax3.set_xlabel('Number of trials')
    ax3.set_ylabel('Negative log posterior')
    ax3.set_yscale('log')  # 使用对数刻度
    ax3.grid(True)
    
    plt.tight_layout()
    plt.savefig(f'subject_{subject_id}_params_over_trials.png')
    plt.close()

# 为每个被试绘制参数变化图
for iSub, step_results in results.items():
    plot_params_over_trials(iSub, step_results)

# 打印每个被试的最终拟合结果
for iSub, step_results in results.items():
    final_result = step_results[-1]  # 取最后一次拟合的结果
    print(f"Subject {iSub}:")
    print(f"  Final Fitted k: {final_result['k']}")
    print(f"  Final Fitted beta: {final_result['beta']}")
    print(f"  Number of iterations: {len(final_result['iterations'])}")
    print(f"  Final objective value: {final_result['objective_values'][-1]}")
    print()

  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  df = fun(x) - f0
  log_prior = np.log(prior(params, condition))
  log_prior = np.log(prior(params, condition))
  

Subject 1:
  Final Fitted k: 2
  Final Fitted beta: 1.0
  Number of iterations: 9
  Final objective value: 91.19589834343725

Subject 4:
  Final Fitted k: 2
  Final Fitted beta: 1.0
  Number of iterations: 9
  Final objective value: 91.20574528852654

Subject 6:
  Final Fitted k: 2
  Final Fitted beta: 1.0
  Number of iterations: 9
  Final objective value: 215.64882909586578

Subject 11:
  Final Fitted k: 1
  Final Fitted beta: 1.0
  Number of iterations: 15
  Final objective value: 210.48708341805465

Subject 21:
  Final Fitted k: 4
  Final Fitted beta: 1.000016552444491
  Number of iterations: 72
  Final objective value: 725.8714877903125

Subject 26:
  Final Fitted k: 3
  Final Fitted beta: 1.0050276670791407
  Number of iterations: 57
  Final objective value: 272.4952536768067

Subject 27:
  Final Fitted k: 1
  Final Fitted beta: 1.0000070435379882
  Number of iterations: 96
  Final objective value: 578.8245431960031

