In [39]:
import pandas as pd
import numpy as np
from scipy.stats import gamma
from scipy.optimize import minimize
import matplotlib.pyplot as plt

In [41]:
# 定义模型参数
class ModelParams:
    def __init__(self, k, beta):
        self.k = k
        self.beta = beta

# 定义类别中心点
def get_centers(k, condition):
    """根据 k 值和 condition 确定类别的中心点"""
    if condition == 1:
        centers = {
            1: ([0.25, 0.5, 0.5, 0.5], [0.75, 0.5, 0.5, 0.5]),  # x1 = 0.5 分割
            2: ([0.5, 0.25, 0.5, 0.5], [0.5, 0.75, 0.5, 0.5]),  # x2 = 0.5 分割
            3: ([0.5, 0.5, 0.25, 0.5], [0.5, 0.5, 0.75, 0.5]),  # x3 = 0.5 分割
            4: ([0.5, 0.5, 0.5, 0.25], [0.5, 0.5, 0.5, 0.75])   # x4 = 0.5 分割
        }
        return centers[int(k)]
    else:
        centers = {
            # 两个超平面分割
            1: ([0.25, 0.25, 0.5, 0.5], [0.25, 0.75, 0.5, 0.5], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),  # x1=0.5, x2=0.5
            2: ([0.25, 0.5, 0.25, 0.5], [0.25, 0.5, 0.75, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),  # x1=0.5, x3=0.5
            3: ([0.25, 0.5, 0.5, 0.25], [0.25, 0.5, 0.5, 0.75], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),  # x1=0.5, x4=0.5
            4: ([0.5, 0.25, 0.25, 0.5], [0.5, 0.25, 0.75, 0.5], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),  # x2=0.5, x3=0.5
            5: ([0.5, 0.25, 0.5, 0.25], [0.5, 0.25, 0.5, 0.75], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),  # x2=0.5, x4=0.5
            6: ([0.5, 0.5, 0.25, 0.25], [0.5, 0.5, 0.25, 0.75], [0.5, 0.5, 0.75, 0.25], [0.5, 0.5, 0.75, 0.75]),  # x3=0.5, x4=0.5
            # 三个超平面分割 (24种)
            7: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.25, 0.5], [0.75, 0.5, 0.75, 0.5]),   # M:x1, N1:x2, N2:x3
            8: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.5, 0.5, 0.25], [0.75, 0.5, 0.5, 0.75]),   # M:x1, N1:x2, N2:x4
            9: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x3, N2:x4
            10: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.75, 0.25, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.75]), # M:x1, N1:x2, N2:x4
            11: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.25, 0.5, 0.25], [0.75, 0.75, 0.5, 0.25]), # M:x1, N1:x3, N2:x2
            12: ([0.25, 0.25, 0.25, 0.5], [0.25, 0.25, 0.75, 0.5], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x1, N1:x3, N2:x4
            13: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.25, 0.75, 0.5]), # M:x1, N1:x4, N2:x3
            14: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.25, 0.5, 0.75], [0.75, 0.25, 0.5, 0.5], [0.75, 0.75, 0.5, 0.5]),   # M:x1, N1:x4, N2:x2
            15: ([0.25, 0.5, 0.25, 0.25], [0.25, 0.5, 0.25, 0.75], [0.75, 0.25, 0.25, 0.5], [0.75, 0.75, 0.25, 0.5]), # M:x1, N1:x4, N2:x2
            16: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.75, 0.25]), # M:x1, N1:x2, N2:x3
            17: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            18: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x3, N2:x4
            19: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.5]),   # M:x2, N1:x4, N2:x3
            20: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.25, 0.75]),  # M:x2, N1:x3, N2:x4
            21: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.5], [0.5, 0.75, 0.75, 0.25]),  # M:x2, N1:x4, N2:x3
            22: ([0.5, 0.25, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.5, 0.25], [0.5, 0.75, 0.5, 0.75]),   # M:x2, N1:x3, N2:x4
            23: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            24: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.75, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x3, N1:x1, N2:x2
            25: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x2
            26: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            27: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.25, 0.75]), # M:x3, N1:x1, N2:x4
            28: ([0.25, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.25], [0.5, 0.25, 0.25, 0.75], [0.5, 0.25, 0.75, 0.75]), # M:x3, N1:x1, N2:x4
            29: ([0.25, 0.25, 0.5, 0.25], [0.25, 0.75, 0.5, 0.25], [0.75, 0.5, 0.25, 0.25], [0.75, 0.5, 0.25, 0.75]), # M:x4, N1:x2, N2:x3
            30: ([0.25, 0.25, 0.5, 0.25], [0.75, 0.25, 0.5, 0.25], [0.5, 0.75, 0.25, 0.25], [0.5, 0.75, 0.75, 0.25]), # M:x4, N1:x1, N2:x3
        }
        return centers[int(k)]

# 定义似然函数
def likelihood(params, data, condition):
    k, beta = params.k, params.beta
    
    x = data[['feature1', 'feature2', 'feature3', 'feature4']].values
    c = data['choice'].values
    r = data['feedback'].values

    # 计算到类别中心的距离
    centers = get_centers(k, condition)
    distances = np.array([np.linalg.norm(x - np.array(center), axis=1) for center in centers])

    # 计算选择概率
    probs = np.exp(-beta * distances)
    probs /= np.sum(probs, axis=0, keepdims=True)
    p_c = probs[c - 1, np.arange(len(c))]
    
    return np.where(r == 1, p_c, 1 - p_c)

# 定义先验分布
def prior(params, condition):
    K = len(get_centers(1, condition))
    k_prior = 1/K if 1 <= params.k <= K else 0
    beta_prior = np.exp(-params.beta) if params.beta > 0 else 0
    return k_prior * beta_prior

# 定义后验分布
def posterior(params, data):
    condition = data['condition'].iloc[0]
    log_prior = np.log(prior(params, condition))
    log_likelihood = np.sum(np.log(likelihood(params, data, condition)))
    return -(log_prior + log_likelihood)  # 负对数后验概率

# 优化后验分布
def fit_model(data):
    iterations = []
    objective_values = []
    
    def objective(params):
        k, beta = params
        obj_value = posterior(ModelParams(k=int(k), beta=beta), data)
        iterations.append((k, beta))
        objective_values.append(obj_value)
        return obj_value
    
    initial_params = [1, 1.0]  # k 初始值为 2，beta 初始值为 1.0
    condition = data['condition'].iloc[0]
    K = len(get_centers(1, condition))
    bounds = [(1, K), (0, 10)]  # k 的范围是 [1, 4]，beta 的范围是 [0, 10]
    
    result = minimize(
        objective,
        initial_params,
        method='L-BFGS-B',
        bounds=bounds
    )
    
    return ModelParams(k=int(round(result.x[0])), beta=result.x[1]), iterations, objective_values

In [42]:
# 读取CSV文件
data = pd.read_csv('Task2.csv')

# 按被试编号分组并拟合模型
results = {}
for iSub, subject_data in data.groupby('iSub'):
    try:
        fitted_params, iterations, objective_values = fit_model(subject_data)
        results[iSub] = {
            'k': fitted_params.k,
            'beta': fitted_params.beta,
            'condition': subject_data['condition'].iloc[0],
            'iterations': iterations,
            'objective_values': objective_values
        }
    except Exception as e:
        print(f"Error fitting model for subject {iSub}: {str(e)}")
        continue

# 绘制拟合过程图
def plot_fitting_process(subject_id, params):
    iterations = params['iterations']
    k_values, beta_values = zip(*iterations)
    objective_values = params['objective_values']
    
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 15))
    
    # Plot k values
    ax1.plot(range(len(k_values)), k_values, marker='o')
    ax1.set_title(f'Subject {subject_id} - k value convergence')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('k value')
    ax1.grid(True)
    
    # Plot beta values
    ax2.plot(range(len(beta_values)), beta_values, marker='o')
    ax2.set_title(f'Subject {subject_id} - beta value convergence')
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('beta value')
    ax2.grid(True)
    
    # Plot objective function values
    ax3.plot(range(len(objective_values)), objective_values, marker='o')
    ax3.set_title(f'Subject {subject_id} - Objective function convergence')
    ax3.set_xlabel('Iteration')
    ax3.set_ylabel('Negative log posterior')
    ax3.set_yscale('log')  # 使用对数刻度以更好地显示变化
    ax3.grid(True)
    
    plt.tight_layout()
    plt.savefig(f'subject_{subject_id}_fitting_process.png')
    plt.close()

# 为每个被试绘制拟合过程图
for iSub, params in results.items():
    plot_fitting_process(iSub, params)

# 打印结果
for iSub, params in results.items():
    print(f"Subject {iSub}:")
    print(f"  Condition: {params['condition']}")
    print(f"  Fitted k: {params['k']}")
    print(f"  Fitted beta: {params['beta']}")
    print(f"  Number of iterations: {len(params['iterations'])}")
    print(f"  Final objective value: {params['objective_values'][-1]}")
    print()

Subject 1:
  Condition: 1
  Fitted k: 1
  Fitted beta: 10.0
  Number of iterations: 6
  Final objective value: 31.50827173668597

Subject 4:
  Condition: 1
  Fitted k: 1
  Fitted beta: 10.0
  Number of iterations: 6
  Final objective value: 33.07012007317838

Subject 6:
  Condition: 3
  Fitted k: 1
  Fitted beta: 6.899509921430229
  Number of iterations: 24
  Final objective value: 158.10064295160242

Subject 11:
  Condition: 2
  Fitted k: 1
  Fitted beta: 6.474644375878146
  Number of iterations: 24
  Final objective value: 163.58496714298488

Subject 21:
  Condition: 3
  Fitted k: 1
  Fitted beta: 7.020761259903179
  Number of iterations: 24
  Final objective value: 499.0634306583489

Subject 26:
  Condition: 2
  Fitted k: 1
  Fitted beta: 8.201095490888317
  Number of iterations: 21
  Final objective value: 177.4258085417328

Subject 27:
  Condition: 3
  Fitted k: 1
  Fitted beta: 6.95372637188597
  Number of iterations: 24
  Final objective value: 424.4214072711751

