In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Belief-space model parameters
class ModelParams:
    """Plain holder for the two parameters of the categorization model.

    Attributes:
        k: selector for which set of category centers ``get_centers`` returns
           (converted with ``int()`` there).
        beta: scale applied to distances in the choice rule
              ``exp(-beta * distance)`` used by ``generate_choice``.
    """

    def __init__(self, k, beta):
        self.k, self.beta = k, beta

# Category-center tables, built once at import time instead of on every
# get_centers() call (generate_choice calls it once per trial).
# Stored as tuples so the shared tables cannot be mutated by callers.

# Condition 1: a single hyperplane splits the space into two categories.
_ONE_SPLIT_CENTERS = {
    1: ((0.25, 0.5, 0.5, 0.5), (0.75, 0.5, 0.5, 0.5)),  # split at x1 = 0.5
    2: ((0.5, 0.25, 0.5, 0.5), (0.5, 0.75, 0.5, 0.5)),  # split at x2 = 0.5
    3: ((0.5, 0.5, 0.25, 0.5), (0.5, 0.5, 0.75, 0.5)),  # split at x3 = 0.5
    4: ((0.5, 0.5, 0.5, 0.25), (0.5, 0.5, 0.5, 0.75)),  # split at x4 = 0.5
}

# Any other condition: four categories.
# NOTE(review): several entries below look inconsistent with their comments or
# duplicate one another (e.g. k=17 and k=22 are identical, and some
# non-splitting coordinates sit at 0.25 rather than 0.5). Values are kept
# exactly as originally written -- confirm against the experimental design.
_MULTI_SPLIT_CENTERS = {
    # Two hyperplanes
    1: ((0.25, 0.25, 0.5, 0.5), (0.25, 0.75, 0.5, 0.5), (0.75, 0.25, 0.5, 0.5), (0.75, 0.75, 0.5, 0.5)),  # x1=0.5, x2=0.5
    2: ((0.25, 0.5, 0.25, 0.5), (0.25, 0.5, 0.75, 0.5), (0.75, 0.5, 0.25, 0.5), (0.75, 0.5, 0.75, 0.5)),  # x1=0.5, x3=0.5
    3: ((0.25, 0.5, 0.5, 0.25), (0.25, 0.5, 0.5, 0.75), (0.75, 0.5, 0.5, 0.25), (0.75, 0.5, 0.5, 0.75)),  # x1=0.5, x4=0.5
    4: ((0.5, 0.25, 0.25, 0.5), (0.5, 0.25, 0.75, 0.5), (0.5, 0.75, 0.25, 0.5), (0.5, 0.75, 0.75, 0.5)),  # x2=0.5, x3=0.5
    5: ((0.5, 0.25, 0.5, 0.25), (0.5, 0.25, 0.5, 0.75), (0.5, 0.75, 0.5, 0.25), (0.5, 0.75, 0.5, 0.75)),  # x2=0.5, x4=0.5
    6: ((0.5, 0.5, 0.25, 0.25), (0.5, 0.5, 0.25, 0.75), (0.5, 0.5, 0.75, 0.25), (0.5, 0.5, 0.75, 0.75)),  # x3=0.5, x4=0.5
    # Three hyperplanes (24 variants): M = first split, N1/N2 = second split on each side
    7: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.75, 0.25, 0.5), (0.75, 0.5, 0.25, 0.5), (0.75, 0.5, 0.75, 0.5)),     # M:x1, N1:x2, N2:x3
    8: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.75, 0.25, 0.5), (0.75, 0.5, 0.5, 0.25), (0.75, 0.5, 0.5, 0.75)),     # M:x1, N1:x2, N2:x4
    9: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.25, 0.75, 0.5), (0.75, 0.25, 0.5, 0.25), (0.75, 0.25, 0.5, 0.75)),   # M:x1, N1:x3, N2:x4
    10: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.75, 0.25, 0.5), (0.75, 0.25, 0.5, 0.25), (0.75, 0.25, 0.5, 0.75)),  # M:x1, N1:x2, N2:x4
    11: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.25, 0.75, 0.5), (0.75, 0.25, 0.5, 0.25), (0.75, 0.75, 0.5, 0.25)),  # M:x1, N1:x3, N2:x2
    12: ((0.25, 0.25, 0.25, 0.5), (0.25, 0.25, 0.75, 0.5), (0.75, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.75)),  # M:x1, N1:x3, N2:x4
    13: ((0.25, 0.25, 0.5, 0.25), (0.25, 0.25, 0.5, 0.75), (0.75, 0.25, 0.25, 0.5), (0.75, 0.25, 0.75, 0.5)),  # M:x1, N1:x4, N2:x3
    14: ((0.25, 0.25, 0.5, 0.25), (0.25, 0.25, 0.5, 0.75), (0.75, 0.25, 0.5, 0.5), (0.75, 0.75, 0.5, 0.5)),    # M:x1, N1:x4, N2:x2
    15: ((0.25, 0.5, 0.25, 0.25), (0.25, 0.5, 0.25, 0.75), (0.75, 0.25, 0.25, 0.5), (0.75, 0.75, 0.25, 0.5)),  # M:x1, N1:x4, N2:x2
    16: ((0.25, 0.25, 0.5, 0.25), (0.25, 0.75, 0.5, 0.25), (0.75, 0.5, 0.25, 0.25), (0.75, 0.5, 0.75, 0.25)),  # M:x1, N1:x2, N2:x3
    17: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.75, 0.5, 0.25), (0.5, 0.75, 0.5, 0.75)),    # M:x2, N1:x3, N2:x4
    18: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.75, 0.25, 0.5), (0.5, 0.75, 0.75, 0.5)),    # M:x2, N1:x3, N2:x4
    19: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.25, 0.75), (0.5, 0.75, 0.25, 0.5), (0.5, 0.75, 0.75, 0.5)),    # M:x2, N1:x4, N2:x3
    20: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.75, 0.25, 0.5), (0.5, 0.75, 0.25, 0.75)),   # M:x2, N1:x3, N2:x4
    21: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.25, 0.75), (0.5, 0.75, 0.25, 0.5), (0.5, 0.75, 0.75, 0.25)),   # M:x2, N1:x4, N2:x3
    22: ((0.5, 0.25, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.75, 0.5, 0.25), (0.5, 0.75, 0.5, 0.75)),    # M:x2, N1:x3, N2:x4
    23: ((0.25, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.25, 0.75, 0.75)),  # M:x3, N1:x1, N2:x4
    24: ((0.25, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.25), (0.5, 0.25, 0.75, 0.25), (0.5, 0.75, 0.75, 0.25)),  # M:x3, N1:x1, N2:x2
    25: ((0.25, 0.25, 0.5, 0.25), (0.75, 0.25, 0.5, 0.25), (0.5, 0.25, 0.25, 0.75), (0.5, 0.75, 0.25, 0.75)),  # M:x3, N1:x1, N2:x2
    26: ((0.25, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.25), (0.5, 0.25, 0.25, 0.75), (0.5, 0.75, 0.25, 0.75)),  # M:x3, N1:x1, N2:x4
    27: ((0.25, 0.25, 0.5, 0.25), (0.75, 0.25, 0.5, 0.25), (0.5, 0.75, 0.25, 0.25), (0.5, 0.75, 0.25, 0.75)),  # M:x3, N1:x1, N2:x4
    28: ((0.25, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.25), (0.5, 0.25, 0.25, 0.75), (0.5, 0.25, 0.75, 0.75)),  # M:x3, N1:x1, N2:x4
    29: ((0.25, 0.25, 0.5, 0.25), (0.25, 0.75, 0.5, 0.25), (0.75, 0.5, 0.25, 0.25), (0.75, 0.5, 0.25, 0.75)),  # M:x4, N1:x2, N2:x3
    30: ((0.25, 0.25, 0.5, 0.25), (0.75, 0.25, 0.5, 0.25), (0.5, 0.75, 0.25, 0.25), (0.5, 0.75, 0.75, 0.25)),  # M:x4, N1:x1, N2:x3
}


def get_centers(k, condition):
    """Return the category centers for center-set ``k`` under ``condition``.

    Args:
        k: center-set index; converted with ``int()``. Valid values are 1-4
            when ``condition == 1`` and 1-30 otherwise.
        condition: experimental condition. ``1`` selects the two-category
            (single-split) table; any other value selects the four-category
            table.

    Returns:
        A tuple of centers, each a fresh 4-element list of floats: two
        centers for condition 1, four otherwise. The lists are new on every
        call, so mutating them does not affect later calls.

    Raises:
        KeyError: if ``k`` has no entry in the selected table.

    NOTE(review): condition 1 yields only two centers, so the model can only
    choose categories 1 or 2; the printed notebook output shows condition-1
    subjects with categories 1-4 -- confirm this is intended.
    """
    table = _ONE_SPLIT_CENTERS if condition == 1 else _MULTI_SPLIT_CENTERS
    try:
        centers = table[int(k)]
    except KeyError:
        raise KeyError(
            f"no category centers for k={k!r} under condition={condition!r}; "
            f"valid k values are {min(table)}..{max(table)}"
        ) from None
    # Hand out copies so callers cannot corrupt the shared module-level tables.
    return tuple(list(center) for center in centers)

# Generate the model's choice for a single trial
def generate_choice(params, features, condition, rng=None):
    """Sample the model's category choice for one trial.

    Each category's choice probability is a softmax over negative Euclidean
    distances: P(c) is proportional to exp(-beta * ||features - center_c||),
    where the centers come from ``get_centers(params.k, condition)``.

    Args:
        params: object exposing ``k`` (center-set index for ``get_centers``)
            and ``beta`` (distance scale in the softmax).
        features: the trial's feature vector; converted to a float array and
            broadcast against each center (centers in ``get_centers`` have
            4 coordinates).
        condition: experimental condition, forwarded to ``get_centers``.
        rng: optional random source with a ``choice(a, p=...)`` method, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state, as before.

    Returns:
        The chosen category as a 1-based integer.
    """
    k, beta = params.k, params.beta
    centers = np.asarray(get_centers(k, condition), dtype=float)
    # Coerce to float: pandas rows can arrive as object-dtype arrays.
    features = np.asarray(features, dtype=float)

    # Distance from the stimulus to every center in one vectorized call.
    distances = np.linalg.norm(centers - features, axis=1)

    # Numerically stable softmax: subtracting the max logit leaves the
    # normalized probabilities unchanged but keeps exp() from underflowing to
    # all zeros (NaN after normalization) for large beta or distances.
    logits = -beta * distances
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    if rng is None:
        rng = np.random
    # +1 maps the 0-based index onto the 1-based category labels.
    return rng.choice(len(centers), p=probs) + 1

# Process the whole dataset
def process_data(data, params,
                 feature_cols=('feature1', 'feature2', 'feature3', 'feature4')):
    """Add model-simulated choices and feedback to a copy of the trial data.

    Args:
        data: trial-level DataFrame with the feature columns plus
            ``condition`` and ``category`` columns. It is not modified.
        params: model parameters forwarded to ``generate_choice``.
        feature_cols: names of the feature columns, in the order matching the
            category-center coordinates. Defaults to ``feature1``..``feature4``.

    Returns:
        A copy of ``data`` with two extra columns:
        ``model_choice`` (the category sampled by ``generate_choice``) and
        ``model_feedback`` (1 where ``model_choice`` equals ``category``,
        else 0).
    """
    processed_data = data.copy()

    # Extract the features once as a float matrix instead of building a
    # Series per row through DataFrame.apply(axis=1). That path is slow, and
    # on an empty frame apply returns a DataFrame rather than a Series, which
    # cannot be assigned to a single column.
    features = processed_data[list(feature_cols)].to_numpy(dtype=float)
    conditions = processed_data['condition'].to_numpy()
    processed_data['model_choice'] = [
        generate_choice(params, row_features, condition)
        for row_features, condition in zip(features, conditions)
    ]

    # 1 when the model's choice matches the trial's category label.
    processed_data['model_feedback'] = (
        processed_data['model_choice'] == processed_data['category']
    ).astype(int)

    return processed_data

# Summarize the processed data per subject
def analyze_processed_data(data):
    """Summarize model vs. recorded behaviour for each subject.

    Groups ``data`` by ``iSub`` and, for every subject, records:

    * ``condition``: the first ``condition`` value in that subject's rows;
    * ``model_choice_distribution`` / ``actual_choice_distribution``:
      value counts of ``model_choice`` / ``category``, sorted by label;
    * ``model_average_feedback`` / ``actual_average_feedback``:
      means of ``model_feedback`` / ``feedback``.

    NOTE(review): the "actual choice" counts come from the ``category``
    column. Confirm whether that column holds the participant's response or
    the true category.

    Returns:
        dict mapping each ``iSub`` value to the summary dict above.
    """
    def _sorted_counts(series):
        return series.value_counts().sort_index()

    summary = {}
    for subject_id, rows in data.groupby('iSub'):
        summary[subject_id] = {
            'condition': rows['condition'].iloc[0],
            'model_choice_distribution': _sorted_counts(rows['model_choice']),
            'actual_choice_distribution': _sorted_counts(rows['category']),
            'model_average_feedback': rows['model_feedback'].mean(),
            'actual_average_feedback': rows['feedback'].mean(),
        }
    return summary

In [5]:
# Load the trial-level data from CSV
data = pd.read_csv('Task2.csv')

# Model parameters (fixed values here; change as needed)
params = ModelParams(k=1, beta=5)

# Simulate model choices, then summarize them per subject
processed_data = process_data(data, params)
analysis_results = analyze_processed_data(processed_data)

# Print a per-subject report; the trailing "" yields the blank separator line
for iSub, result in analysis_results.items():
    report = "\n".join([
        f"Subject {iSub}:",
        f"  Condition: {result['condition']}",
        f"  Model choice distribution: {result['model_choice_distribution'].to_dict()}",
        f"  Actual choice distribution: {result['actual_choice_distribution'].to_dict()}",
        f"  Model average feedback: {result['model_average_feedback']:.2f}",
        f"  Actual average feedback: {result['actual_average_feedback']:.2f}",
        "",
    ])
    print(report)

# Plot model vs. actual choice distributions side by side
def plot_choice_distribution_comparison(analysis_results,
                                        filename='choice_distribution_comparison.png'):
    """Save a figure comparing model and actual choice counts per subject.

    The figure has one row per subject: the left panel shows the
    ``model_choice_distribution`` counts and the right panel shows the
    ``actual_choice_distribution`` counts.

    Args:
        analysis_results: mapping of subject id to a result dict with keys
            ``condition``, ``model_choice_distribution`` and
            ``actual_choice_distribution`` (the latter two pandas Series of
            counts indexed by choice), as built by ``analyze_processed_data``.
        filename: output image path. Defaults to the previously hard-coded
            name.

    Raises:
        ValueError: if ``analysis_results`` is empty (nothing to plot).
    """
    num_subjects = len(analysis_results)
    if num_subjects == 0:
        raise ValueError("analysis_results is empty; nothing to plot")

    # squeeze=False keeps `axes` 2-D even for a single subject. Otherwise
    # plt.subplots(1, 2) returns a 1-D array and axes[i, 0] raises IndexError.
    fig, axes = plt.subplots(num_subjects, 2, figsize=(15, 5 * num_subjects),
                             squeeze=False)

    panels = (
        ('model_choice_distribution', 'Model Choices'),
        ('actual_choice_distribution', 'Actual Choices'),
    )
    for row_axes, (iSub, result) in zip(axes, analysis_results.items()):
        for ax, (key, label) in zip(row_axes, panels):
            distribution = result[key]
            ax.bar(distribution.index, distribution.values)
            ax.set_title(f"Subject {iSub} - {label} (Condition {result['condition']})")
            ax.set_xlabel('Choice')
            ax.set_ylabel('Frequency')

    fig.tight_layout()
    fig.savefig(filename)
    # Close exactly the figure created here so repeated calls don't leak figures.
    plt.close(fig)

# Draw and save the model-vs-actual choice comparison figure, then report completion
plot_choice_distribution_comparison(analysis_results)

done_message = "数据处理和分析完成。选择分布对比图已保存为 'choice_distribution_comparison.png'。"
print(done_message)

Subject 1:
  Condition: 1
  Model choice distribution: {1: 55, 2: 73}
  Actual choice distribution: {1: 32, 2: 32, 3: 32, 4: 32}
  Model average feedback: 0.27
  Actual average feedback: 0.87

Subject 4:
  Condition: 1
  Model choice distribution: {1: 61, 2: 67}
  Actual choice distribution: {1: 32, 2: 32, 3: 32, 4: 32}
  Model average feedback: 0.20
  Actual average feedback: 0.94

Subject 6:
  Condition: 3
  Model choice distribution: {1: 38, 2: 49, 3: 46, 4: 59}
  Actual choice distribution: {1: 43, 2: 44, 3: 48, 4: 57}
  Model average feedback: 0.45
  Actual average feedback: 0.90

Subject 11:
  Condition: 2
  Model choice distribution: {1: 71, 2: 51, 3: 67, 4: 67}
  Actual choice distribution: {1: 65, 2: 61, 3: 64, 4: 66}
  Model average feedback: 0.54
  Actual average feedback: 0.55

Subject 21:
  Condition: 3
  Model choice distribution: {1: 209, 2: 186, 3: 178, 4: 178}
  Actual choice distribution: {1: 185, 2: 196, 3: 181, 4: 189}
  Model average feedback: 0.48
  Actual average