## 1. Preprocess ##

In [3]:
import os
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import math
# from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

In [11]:
# Resolve the project root (the parent of the current working directory) and
# make it importable so that `src` and `config` can be found.
project_root = Path(os.getcwd()).parent
sys.path.append(str(project_root))

# Import the preprocessors; reloading the module picks up edits to
# src/preprocess.py without restarting the kernel.
import importlib
import src.preprocess as preprocess
importlib.reload(preprocess)
from src.preprocess import Preprocessor_A, Preprocessor_B

# Instantiate the preprocessors (A: mouse-trajectory tasks, B: trial-level tasks,
# as used by preprocess_main_a / preprocess_main_b below).
preprocessor_a = Preprocessor_A()
preprocessor_b = Preprocessor_B()

# Import configuration values forwarded to preprocessor_a.process()
from config import body_length, features_range, canvas_settings

In [12]:
# Task1a, Task1b, Task3c preprocess
def preprocess_main_a(project_root, taskID, subIDs, features_range, canvas_settings, body_length, preprocessor):
    """Turn raw mouse trajectories into per-subject feature trajectories.

    Used for Task1a, Task1b and Task3c.  The initial feature state handed to
    the preprocessor depends on the task:

    * ``Task1a`` - a single row with all eight features set to 0.5;
    * ``Task1b`` - the rows of the stimulus file whose ``type`` equals 2;
    * any other task (Task3c) - the whole stimulus file.

    The initial state and the subject's mouse trajectory are passed to
    ``preprocessor.process`` and the returned DataFrame is written to
    ``data/processed/{taskID}/{taskID}_{subID}_feature.csv``.

    Parameters:
        project_root (str | Path): project root directory.
        taskID (str): task identifier, e.g. ``'Task1a'``.
        subIDs (iterable): subject IDs to process.
        features_range, canvas_settings, body_length: forwarded unchanged to
            ``preprocessor.process``.
        preprocessor: object with ``process(taskID, feature_init,
            mouse_trajectory, features_range, canvas_settings, body_length)``
            returning a DataFrame.
    """
    raw_dir = Path(project_root) / 'data' / 'raw' / taskID
    out_dir = Path(project_root) / 'data' / 'processed' / taskID
    out_dir.mkdir(parents=True, exist_ok=True)

    default_features = [
        'neck_length', 'head_length', 'leg_length', 'tail_length',
        'neck_angle', 'head_angle', 'leg_angle', 'tail_angle',
    ]

    for subID in subIDs:
        prefix = f'{taskID}_{subID}'

        if taskID == 'Task1a':
            # No stimulus file: every feature starts at 0.5.
            feature_init = pd.DataFrame({name: [0.5] for name in default_features})
        else:
            stimuli = pd.read_csv(raw_dir / f'{prefix}_sti.csv')
            # Task1b stores several row types; only type 2 is the starting state.
            feature_init = stimuli[stimuli['type'] == 2] if taskID == 'Task1b' else stimuli

        mouse = pd.read_csv(raw_dir / f'{prefix}_mouse.csv')
        trajectory = preprocessor.process(
            taskID, feature_init, mouse, features_range, canvas_settings, body_length
        )
        trajectory.to_csv(out_dir / f'{prefix}_feature.csv', index=False)

# Task2, Task3a, Task3b preprocess
def preprocess_main_b(project_root, taskID, subIDs, preprocessor):
    """Combine stimulus and behaviour files of trial-level tasks (Task2, Task3a, Task3b).

    For every subject the stimulus data is loaded:

    * ``Task2`` / ``Task3a`` - from ``{taskID}_{subID}_sti.csv``;
    * ``Task3b`` - by merging ``{taskID}_{subID}_left.csv`` and
      ``{taskID}_{subID}_right.csv`` on ``pairID``.

    It is passed together with ``{taskID}_{subID}_bhv.csv`` to
    ``preprocessor.process``.  The per-subject results get a leading ``iSub``
    column and are concatenated into ``data/processed/{taskID}_processed.csv``.

    Parameters:
        project_root (str | Path): project root directory.
        taskID (str): one of ``'Task2'``, ``'Task3a'``, ``'Task3b'``.
        subIDs (iterable): subject IDs to process (at least one).
        preprocessor: object with ``process(taskID, stimulus_data, behavior_data)``
            returning a DataFrame.

    Raises:
        ValueError: if ``taskID`` is not supported or ``subIDs`` is empty.
    """
    supported_tasks = ('Task2', 'Task3a', 'Task3b')
    if taskID not in supported_tasks:
        # Previously an unknown task left `stimulus_data` unbound and failed later
        # with an unrelated error; fail fast with a clear message instead.
        raise ValueError(
            f"preprocess_main_b does not support taskID={taskID!r}; expected one of {supported_tasks}."
        )
    subIDs = list(subIDs)
    if not subIDs:
        raise ValueError("subIDs must contain at least one subject ID.")

    input_path = Path(project_root) / 'data' / 'raw' / taskID
    output_path = Path(project_root) / 'data' / 'processed'
    os.makedirs(output_path, exist_ok=True)

    all_data = []
    for subID in subIDs:
        if taskID == 'Task3b':
            left_stimulus_data = pd.read_csv(input_path / f'{taskID}_{subID}_left.csv')
            right_stimulus_data = pd.read_csv(input_path / f'{taskID}_{subID}_right.csv')
            # Columns present on both sides (other than pairID) get pandas' default _x/_y suffixes.
            stimulus_data = pd.merge(left_stimulus_data, right_stimulus_data, on=['pairID'])
        else:
            stimulus_data = pd.read_csv(input_path / f'{taskID}_{subID}_sti.csv')

        behavior_data = pd.read_csv(input_path / f'{taskID}_{subID}_bhv.csv')

        combined_data = preprocessor.process(taskID, stimulus_data, behavior_data)
        combined_data.insert(0, 'iSub', subID)
        all_data.append(combined_data)

    processed_data = pd.concat(all_data, ignore_index=True)
    processed_data.to_csv(output_path / f'{taskID}_processed.csv', index=False)

In [7]:
# Task1b, Task3c reconstruct
def preprocess_construct(project_root, taskID, subIDs):
    """Stack the initial and final adjustment states of Task1b / Task3c.

    For each subject:

    1. Read the raw stimulus file ``{taskID}_{subID}_sti.csv``.
       * Task1b: drop ``version``, ``display_height``, ``PairID`` and relabel
         ``type`` 1/2 as ``'target'`` / ``'adjust_init'``.
       * Task3c: every row is labelled ``'adjust_init'``.
    2. Read the feature trajectory ``{taskID}_{subID}_feature.csv`` written by
       ``preprocess_main_a``.
    3. Take the last trajectory row of every trial as the ``'adjust_after'`` state.
    4. Append that state below the stimulus rows.

    All subjects are concatenated, with a leading ``iSub`` column, into
    ``data/processed/{taskID}_processed.csv``.

    Parameters:
        project_root (str | Path): project root directory.
        taskID (str): ``'Task1b'`` or ``'Task3c'``.
        subIDs (iterable): subject IDs to process (at least one).

    Raises:
        ValueError: if ``taskID`` is not supported or ``subIDs`` is empty.
    """
    supported_tasks = ('Task1b', 'Task3c')
    if taskID not in supported_tasks:
        # Previously an unknown task left `stimulus_data` unbound; fail fast instead.
        raise ValueError(
            f"preprocess_construct does not support taskID={taskID!r}; expected one of {supported_tasks}."
        )
    subIDs = list(subIDs)
    if not subIDs:
        raise ValueError("subIDs must contain at least one subject ID.")

    raw_path = Path(project_root) / 'data' / 'raw' / taskID
    processed_path = Path(project_root) / 'data' / 'processed' / taskID
    output_path = Path(project_root) / 'data' / 'processed'
    os.makedirs(output_path, exist_ok=True)

    feature_columns = ['neck_length', 'head_length', 'leg_length', 'tail_length',
                       'neck_angle', 'head_angle', 'leg_angle', 'tail_angle']

    all_data = []
    for subID in subIDs:
        stimulus_data = pd.read_csv(raw_path / f'{taskID}_{subID}_sti.csv')
        if taskID == 'Task1b':
            stimulus_data = stimulus_data.drop(columns=['version', 'display_height', 'PairID'])
            stimulus_data['type'] = stimulus_data['type'].replace({1: 'target', 2: 'adjust_init'})
        else:  # Task3c: the stimulus rows are the initial adjustment states
            stimulus_data.insert(0, 'type', 'adjust_init')

        feature_trajactory = pd.read_csv(processed_path / f'{taskID}_{subID}_feature.csv')
        # Final state of each trial = last recorded row of its trajectory.
        adjust_after = feature_trajactory.groupby('iTrial').last().reset_index()

        new_rows = stimulus_data[stimulus_data['type'] == 'adjust_init'][['iTrial', 'body_ori']].copy()
        new_rows.insert(0, 'type', 'adjust_after')
        new_rows = new_rows.merge(adjust_after[['iTrial'] + feature_columns], on='iTrial', how='left')

        combined_data = pd.concat([stimulus_data, new_rows], ignore_index=True)
        combined_data.insert(0, 'iSub', subID)
        all_data.append(combined_data)

    processed_data = pd.concat(all_data, ignore_index=True)
    processed_data.to_csv(output_path / f'{taskID}_processed.csv', index=False)


In [13]:
# Subjects 1-25 (inclusive).
subIDs = list(range(1, 26))

# Mouse-trajectory tasks; uncomment to re-run.
# preprocess_main_a(project_root, 'Task1a', subIDs, features_range, canvas_settings, body_length, preprocessor_a)
# preprocess_main_a(project_root, 'Task1b', subIDs, features_range, canvas_settings, body_length, preprocessor_a)
# preprocess_main_a(project_root, 'Task3c', subIDs, features_range, canvas_settings, body_length, preprocessor_a)

# Initial/final adjustment reconstruction; reads the feature files written by preprocess_main_a.
# preprocess_construct(project_root, 'Task1b', subIDs)
# preprocess_construct(project_root, 'Task3c', subIDs)

# Trial-level tasks: only Task2 is active in this run.
preprocess_main_b(project_root, 'Task2', subIDs, preprocessor_b)
# preprocess_main_b(project_root, 'Task3a', subIDs, preprocessor_b)

## 2. Perceptual Error Analysis

In [14]:
# 获取项目根目录
project_root = Path(os.getcwd()).parent
sys.path.append(str(project_root))

# 导入处理器
import importlib
import src.error_evaluation as error_evaluation
importlib.reload(error_evaluation)
from src.error_evaluation import Processor

# 初始化预处理器
processor = Processor()

In [None]:
# Compute and plot perceptual errors from the reconstructed Task1b data
# (written by preprocess_construct).
processed_path = Path(project_root) / 'data' / 'processed'
processed_data = pd.read_csv(processed_path / f'Task1b_processed.csv')

error = processor.error_calculation(processed_data)
# NOTE(review): `summary` is neither displayed nor saved in this cell.
summary = processor.error_summary(error)

# Results directory
# NOTE(review): result_path is created but nothing in this cell writes to it;
# presumably `error` / `summary` were meant to be saved here — confirm.
result_path = Path(project_root) / 'results' / 'Raw'
os.makedirs(result_path, exist_ok=True)

# Usage examples:
processor.plot_error(error, "length")  # plot the length errors
processor.plot_error(error, "angle")   # plot the angle errors
# Plot the errors broken down by feature
processor.plot_error_by_feature(error)

## 3. Recording Analysis

In [7]:
# 获取项目根目录
project_root = Path(os.getcwd()).parent
sys.path.append(str(project_root))

# # 导入处理器
# import importlib
# import src.audio_coding as audio_coding
# importlib.reload(audio_coding)
# from src.audio_coding import Processor

# # 初始化预处理器
# processor = Processor()

In [31]:
import pandas as pd
import re

# Define body parts and their corresponding column names
# Chinese body-part keywords -> output column names.  Iteration order matters:
# it fixes the order in which mentioned parts are collected.
BODY_PARTS = {
    '脖子': 'neck_value',
    '头': 'head_value',
    '腿': 'leg_value',
    '尾巴': 'tail_value'
}

# Description keywords -> coded value (long = 3, short = 1, medium = 2).
# Iteration order matters: when a clause contains several keywords, the first
# one in this dict is used.
DESCRIPTIONS = {
    '长': 3,
    '短': 1,
    '中等': 2,
    '适中': 2
}

# Possible modifiers between body parts and descriptions in the "比" pattern.
# NOTE(review): not referenced by extract_values / process_csv.
MODIFIERS = ['比较', '很', '等']

# Columns produced by extract_values, in the order process_csv appends them.
_OUTPUT_COLUMNS = ['invalid', 'noinfo', 'neck_value', 'head_value', 'leg_value', 'tail_value']


def _find_descriptions(clause):
    """Return the DESCRIPTIONS keywords present in *clause*, in DESCRIPTIONS order.

    '长' ("long") only counts when it occurs outside the word '长度' ("length"),
    so "脖子长度很长" yields '长' while "脖子长度" yields nothing.
    """
    found = []
    for desc in DESCRIPTIONS:
        haystack = clause.replace('长度', '') if desc == '长' else clause
        if desc in haystack:
            found.append(desc)
    return found


def extract_values(text):
    """Code one free-text description into per-body-part values.

    The text is split into clauses on Chinese/ASCII commas.

    * Plain clause: in a clause that mentions body parts and a description
      keyword, the first keyword (in DESCRIPTIONS order) is assigned to every
      mentioned part.
    * Comparison ("A比B<desc>"): parts before the comparative '比' get the
      described value, and parts after it get the opposite value (4 - value).
      A '比' that is part of '比较' ("relatively") is not a comparison.
    * "Others" clause: the first clause containing '其他'/'其余' with a
      description assigns it to every part that is still unset.

    Parameters:
        text: the transcript cell.  NaN/None/blank marks the row invalid;
            non-string values are converted with str().

    Returns:
        dict: the result has these keys:

        * ``'invalid'`` (0/1).
        * ``'noinfo'``: 1 when no body part received a value, else 0.
        * ``'neck_value'``, ``'head_value'``, ``'leg_value'``,
          ``'tail_value'``: 1/2/3, or None.
    """
    result = {
        'invalid': 0,
        'noinfo': 0,
        'neck_value': None,
        'head_value': None,
        'leg_value': None,
        'tail_value': None
    }

    if pd.isna(text) or str(text).strip() == '':
        result['invalid'] = 1
        return result

    # str(): pandas may parse a purely numeric transcript cell as a number,
    # which re.split would reject.
    items = re.split(r'[，,]', str(text))

    for item in items:
        item = item.strip('。.？?！!、')
        if not item:
            continue

        mentioned_parts = [part for part in BODY_PARTS if part in item]
        if not mentioned_parts:
            continue

        descriptions_found = _find_descriptions(item)
        if not descriptions_found:
            continue
        desc_value = DESCRIPTIONS[descriptions_found[0]]

        # Mask '比较' with a same-length filler so that a real comparative '比'
        # elsewhere in the clause is still found, at an index aligned with `item`.
        comparison_at = item.replace('比较', '__').find('比')
        if comparison_at == -1:
            for part in mentioned_parts:
                result[BODY_PARTS[part]] = desc_value
            continue

        for part in mentioned_parts:
            position = item.find(part)
            if position < comparison_at:
                # Subject of the comparison: receives the stated description.
                result[BODY_PARTS[part]] = desc_value
            elif position > comparison_at:
                # Object of the comparison: receives the opposite value.
                result[BODY_PARTS[part]] = 4 - desc_value

    # "其他"/"其余" ("the others") + description fills every still-unset part.
    for item in items:
        if '其他' not in item and '其余' not in item:
            continue
        descriptions_found = _find_descriptions(item)
        if descriptions_found:
            desc_value = DESCRIPTIONS[descriptions_found[0]]
            for col in BODY_PARTS.values():
                if result[col] is None:
                    result[col] = desc_value
            break  # only the first usable "others" clause is applied

    result['noinfo'] = int(all(result[col] is None for col in BODY_PARTS.values()))
    return result


def process_csv(input_file, output_file):
    """Add coded body-part columns to a recording CSV and save the result.

    Reads *input_file* as UTF-8, falling back to GBK when it is not valid UTF-8.
    It applies extract_values to the 'text' column and writes the frame to
    *output_file* as UTF-8 with BOM.  The columns in _OUTPUT_COLUMNS are
    appended, or overwritten if they already exist.
    """
    try:
        df = pd.read_csv(input_file, encoding='utf-8')
    except UnicodeDecodeError:
        df = pd.read_csv(input_file, encoding='gbk')

    extracted_data = df['text'].apply(extract_values)
    for col in _OUTPUT_COLUMNS:
        df[col] = extracted_data.apply(lambda record, key=col: record[key])

    df.to_csv(output_file, index=False, encoding='utf-8-sig')

# Example usage:
# process_csv('recording.csv', 'processed_recording.csv')


In [32]:
# Code every raw Task2 recording transcript ('*rec.csv') and write the result,
# under the same file name, to data/processed/Task2/.
input_dir = Path(project_root) / 'data' / 'raw' / 'Task2'
output_dir = Path(project_root) / 'data' / 'processed' / 'Task2'
os.makedirs(output_dir, exist_ok=True)

# sorted(): deterministic processing order regardless of the file system.
for filename in sorted(os.listdir(input_dir)):
    if not filename.endswith('rec.csv'):
        continue
    input_path = os.path.join(input_dir, filename)
    output_path = os.path.join(output_dir, filename)
    # process_csv performs its own read of input_path; a separate read here
    # would only be discarded.
    process_csv(input_path, output_path)

## 4. Plot

In [217]:
import numpy as np
import pandas as pd
import os
import matplotlib.colors as mc
import colorsys
import joblib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D

In [195]:
def read_data(input_rec_csv, input_bhv_csv, input_modelfitting):
    """Merge coded recordings, behaviour and model-fitting centres into one frame.

    Parameters:
        input_rec_csv (str | Path): coded recording CSV (e.g. ``Task2_<sub>_rec.csv``).
            It must contain ``iSession``, ``iTrial`` and ``neck_value``,
            ``head_value``, ``leg_value``, ``tail_value``.
        input_bhv_csv (str | Path): behaviour CSV (e.g. ``Task2_<sub>_bhv.csv``).
            It must contain ``iSession``, ``iTrial`` and ``choice``.
        input_modelfitting (list): one ``(k, center_dict)`` pair per output row.
            ``center_dict`` maps choice index 0-3 to four feature values; a
            missing index yields four ``None``.

    Returns:
        pd.DataFrame: the recording rows left-joined with ``choice``, built as
        follows:

        * The value columns are filled (missing -> 2) and mapped
          1/2/3 -> 0.25/0.5/0.75.
        * They are renamed head/leg/tail/neck -> ``human_feature_1..4``.
        * The 16 columns ``choice_{c}_feature_{f}`` (c, f = 1..4) are
          appended, built from ``input_modelfitting``.

    Raises:
        FileNotFoundError: if either CSV file does not exist.
        ValueError: if a required column is missing.
    """
    # 1. Read the CSV files
    try:
        df_rec = pd.read_csv(input_rec_csv)
        df_bhv = pd.read_csv(input_bhv_csv)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"无法找到文件: {e.filename}") from e

    # 2. Validate the join keys (messages name the actual files, not a fixed example)
    if not {'iSession', 'iTrial'}.issubset(df_rec.columns):
        raise ValueError(f"{input_rec_csv}中必须包含'iSession'和'iTrial'列。")
    if not {'iSession', 'iTrial', 'choice'}.issubset(df_bhv.columns):
        raise ValueError(f"{input_bhv_csv}中必须包含'iSession', 'iTrial'和'choice'列。")

    df = pd.merge(df_rec, df_bhv[['iSession', 'iTrial', 'choice']], on=['iSession', 'iTrial'], how='left')

    # 3. The four coded value columns must exist
    value_columns = ['neck_value', 'head_value', 'leg_value', 'tail_value']
    for col in value_columns:
        if col not in df.columns:
            raise ValueError(f"列 '{col}' 在CSV文件 {input_rec_csv} 中不存在。")

    # 4-5. Uncoded parts count as "medium" (2); map 1/2/3 onto the [0, 1] feature scale
    mapping = {1: 0.25, 2: 0.5, 3: 0.75}
    df[value_columns] = df[value_columns].fillna(2).replace(mapping)

    # 6. Rename to the feature order used by the plots
    rename_mapping = {
        'head_value': 'human_feature_1',
        'leg_value': 'human_feature_2',
        'tail_value': 'human_feature_3',
        'neck_value': 'human_feature_4'
    }
    df1 = df.rename(columns=rename_mapping)

    # 7. One row of model centres per input_modelfitting entry
    columns = [f'choice_{choice}_feature_{feature}'
               for choice in range(1, 5)
               for feature in range(1, 5)]
    rows = []
    for _k, center_dict in input_modelfitting:
        row = []
        for choice_key in range(4):  # center_dict keys 0..3 -> choice 1..4
            row.extend(center_dict.get(choice_key, (None,) * 4))
        rows.append(row)
    df2 = pd.DataFrame(rows, columns=columns)

    # NOTE(review): rows are aligned by position; if the lengths differ,
    # pd.concat pads the shorter side with NaN.
    return pd.concat([df1, df2], axis=1)

In [229]:
def draw_cube(ax):
    """Outline the unit cube [0, 1]^3 on *ax* with twelve grey edges.

    Parameters:
        ax (Axes3D): 3-D axes to draw on; only its ``plot`` method is used.
    """
    # Corners: bottom face (z=0) counter-clockwise, then the same ring at z=1.
    square = ((0, 0), (1, 0), (1, 1), (0, 1))
    corners = [(x, y, z) for z in (0, 1) for (x, y) in square]

    # Bottom ring, top ring, then the four vertical pillars.
    ring = [(i, (i + 1) % 4) for i in range(4)]
    edges = ring + [(a + 4, b + 4) for a, b in ring] + [(i, i + 4) for i in range(4)]

    for a, b in edges:
        p, q = corners[a], corners[b]
        ax.plot(*[[u, v] for u, v in zip(p, q)], color='grey')

def draw_intersection_lines(ax):
    """Draw the three grey dashed lines where the mid-planes feature_i = 0.5 meet.

    Each line passes through (0.5, 0.5, 0.5) and runs parallel to one axis
    across the unit cube.

    Parameters:
        ax (Axes3D): 3-D axes to draw on; only its ``plot`` method is used.
    """
    mid = 0.5
    dashed = {'color': 'grey', 'linestyle': '--', 'linewidth': 1}
    segments = (
        ([mid, mid], [mid, mid], [0, 1]),  # x = y = 0.5, along z
        ([mid, mid], [0, 1], [mid, mid]),  # x = z = 0.5, along y
        ([0, 1], [mid, mid], [mid, mid]),  # y = z = 0.5, along x
    )
    for xs, ys, zs in segments:
        ax.plot(xs, ys, zs, **dashed)

def lighten_color(color, amount=0.5):
    """Return *color* with its HLS lightness L moved towards white.

    The new lightness is ``1 - amount * (1 - L)``:

    * ``amount=1`` leaves the colour unchanged;
    * ``amount=0`` gives white;
    * values between 0 and 1 lighten it;
    * values above 1 darken it (no clamping is applied).

    Parameters:
        color (str | tuple): a Matplotlib colour name or any value accepted by
            ``matplotlib.colors.to_rgb``.
        amount (float): scaling of the distance to white (see above).

    Returns:
        tuple: the resulting (r, g, b) colour.
    """
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named colour (unknown name, hex string, RGB tuple or an
        # unhashable list): let to_rgb parse it directly.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)

def _style_unit_cube_axes(ax, title):
    """Apply the shared axis look: [0, 1] limits, ticks at 0/0.5/1, labels, camera and title."""
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])
    ax.set_zticks([0, 0.5, 1])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_zlim(0, 1)
    ax.set_xlabel('Feature 1')
    ax.set_ylabel('Feature 2')
    ax.set_zlabel('Feature 3')
    ax.view_init(elev=15., azim=30)  # same camera angle for every frame
    # draw_intersection_lines(ax)  # optional mid-plane guide lines (disabled)
    ax.set_title(title)


def _draw_trajectory_panel(ax, xs, ys, zs, choice, color_mapping, target, title):
    """Draw one panel: cube outline, optional target marker, trajectory and latest point."""
    draw_cube(ax)
    lighter_color = lighten_color(color_mapping[choice], amount=0.7)
    if target is not None:
        ax.scatter(*target, color='yellow', s=200, alpha=0.7, edgecolors='k')
    if len(xs) > 1:
        ax.plot(xs, ys, zs, color=lighter_color, linewidth=1)
    ax.scatter(xs[-1], ys[-1], zs[-1], color=color_mapping[choice], s=100, alpha=0.8, edgecolors='w')
    _style_unit_cube_axes(ax, title)


def plot_choice_graph(iSub, iSession, iTrial, choice, features_list, color_mapping, plots_dir, plot_side='both'):
    """Render one frame for category *choice* and save it as a PNG.

    The left panel ("Human") shows the trajectory of ``human_feature_1..3``.
    The right panel ("Bayesian learner") shows the trajectory of
    ``choice_{choice}_feature_1..3``.  In both panels the most recent entry is
    drawn as a solid point and the category's target point as a yellow marker.
    The file is saved as ``{plots_dir}/choice{choice}/{iSub}_{iSession}_{iTrial}_c{choice}.png``.

    Parameters:
        iSub (int | str): subject ID (used in the title and file name).
        iSession (int | str): session number.
        iTrial (int | str): trial number.
        choice (int): category 1-4.
        features_list (list[dict]): snapshots so far for this category, oldest first.
            Each dict holds ``human_feature_1..3`` and ``choice_{choice}_feature_1..3``.
        color_mapping (dict): category -> Matplotlib colour.
        plots_dir (str | Path): root folder of the per-category sub-folders.
        plot_side (str): ``'left'``, ``'right'`` or ``'both'``.

    Raises:
        ValueError: if ``features_list`` is empty (there is no point to draw).
    """
    if not features_list:
        raise ValueError("features_list must contain at least one feature entry.")

    choice_folder = os.path.join(plots_dir, f"choice{choice}")
    os.makedirs(choice_folder, exist_ok=True)

    # Fixed target point of each category in (feature 1, 2, 3) space.
    target_coords = {
        1: (0.25, 0.25, 0.5),
        2: (0.25, 0.75, 0.5),
        3: (0.75, 0.5, 0.25),
        4: (0.75, 0.5, 0.75)
    }
    target = target_coords.get(choice)

    human_x = [feat['human_feature_1'] for feat in features_list]
    human_y = [feat['human_feature_2'] for feat in features_list]
    human_z = [feat['human_feature_3'] for feat in features_list]

    bayesian_x = [feat[f'choice_{choice}_feature_1'] for feat in features_list]
    bayesian_y = [feat[f'choice_{choice}_feature_2'] for feat in features_list]
    bayesian_z = [feat[f'choice_{choice}_feature_3'] for feat in features_list]

    fig = plt.figure(figsize=(12, 6) if plot_side == 'both' else (6, 6))
    try:
        fig.suptitle(f"iSub={iSub}, iSession={iSession}, iTrial={iTrial}, Category={choice}", fontsize=16)
        n_cols = 2 if plot_side == 'both' else 1

        if plot_side in ('left', 'both'):
            ax_left = fig.add_subplot(1, n_cols, 1, projection='3d')
            _draw_trajectory_panel(ax_left, human_x, human_y, human_z,
                                   choice, color_mapping, target, "Human")

        if plot_side in ('right', 'both'):
            ax_right = fig.add_subplot(1, n_cols, n_cols, projection='3d')
            _draw_trajectory_panel(ax_right, bayesian_x, bayesian_y, bayesian_z,
                                   choice, color_mapping, target, "Bayesian learner")

        filename = f"{iSub}_{iSession}_{iTrial}_c{choice}.png"
        fig.savefig(os.path.join(choice_folder, filename))
    finally:
        # Release this figure even if drawing fails, so long runs do not accumulate open figures.
        plt.close(fig)

def process_and_plot(input_rec_csv, input_bhv_csv, input_modelfitting, output_csv, plots_dir, plot_side='both'):
    """Merge the inputs, save the merged CSV and render per-category frames.

    Rows are visited in order.  Each row with a valid ``choice`` first appends a
    snapshot of its human features and that category's model centre to the
    category's history.  It then draws a frame for the chosen category,
    followed by a frame for every other category that already has history.
    Categories without history are reported and skipped.

    Parameters:
        input_rec_csv (str): coded recording CSV (see read_data).
        input_bhv_csv (str): behaviour CSV with ``choice`` (see read_data).
        input_modelfitting (list): ``(k, center_dict)`` pairs, one per row.
        output_csv (str): path of the merged CSV (written as UTF-8 with BOM).
        plots_dir (str | Path): root folder; frames go into ``choice1`` .. ``choice4``.
        plot_side (str): ``'left'``, ``'right'`` or ``'both'`` (see plot_choice_graph).
    """
    df = read_data(input_rec_csv, input_bhv_csv, input_modelfitting)
    df.to_csv(output_csv, index=False, encoding='utf-8-sig')

    categories = range(1, 5)
    for category in categories:
        category_dir = os.path.join(plots_dir, f"choice{category}")
        if not os.path.exists(category_dir):
            os.makedirs(category_dir)

    # Categories 1-2 are drawn in green, 3-4 in red.
    color_mapping = {1: 'darkgreen', 2: 'darkgreen', 3: 'darkred', 4: 'darkred'}

    # Per-category list of snapshots, extended each time that category is chosen.
    history = {category: [] for category in categories}
    human_columns = [f'human_feature_{i}' for i in range(1, 5)]

    for index, row in df.iterrows():
        iSub = row.get('iSub', 'Unknown')  # 'Unknown' when there is no 'iSub' column
        iSession = row['iSession']
        iTrial = row['iTrial']
        chosen = row['choice']

        if pd.isna(chosen):
            print(f"第 {index} 行缺少 'choice' 数据，跳过绘图。")
            continue
        chosen = int(chosen)

        snapshot = {column: row[column] for column in human_columns}
        for i in range(1, 5):
            key = f'choice_{chosen}_feature_{i}'
            snapshot[key] = row[key]
        history[chosen].append(snapshot)

        frame_args = dict(
            iSub=iSub,
            iSession=iSession,
            iTrial=iTrial,
            color_mapping=color_mapping,
            plots_dir=plots_dir,
            plot_side=plot_side,
        )

        # Chosen category first, then the remaining ones in ascending order.
        plot_choice_graph(choice=chosen, features_list=history[chosen], **frame_args)
        for other in categories:
            if other == chosen:
                continue
            if not history[other]:
                print(f"Choice {other} 在第 {index} 行之前没有数据，跳过生成图像。")
                continue
            plot_choice_graph(choice=other, features_list=history[other], **frame_args)

    print(f"处理完成，图表已分别保存到 '{plots_dir}/choice1', '{plots_dir}/choice2', '{plots_dir}/choice3', 和 '{plots_dir}/choice4' 文件夹中。")

In [172]:
# Load the Bayesian model-fitting results and build the per-step category
# centres consumed by process_and_plot().
result_path = Path(project_root) / 'results' / 'Bayesian'
fitting_results = joblib.load(result_path / 'M_Base_fitting_results.joblib')
# NOTE(review): key 9 presumably selects subject 9 (the Task2_9 files used below) — confirm.
fit_result = fitting_results[9]
step_results = fit_result['step_results']

# Import the partition helper.  The module gets its own alias: the instance
# below is named `partition`, and importlib.reload() on that instance (when
# this cell is re-run) would raise TypeError.
import src.Bayesian.utils.partition as partition_module
importlib.reload(partition_module)
from src.Bayesian.utils.partition import Partition

partition = Partition()
all_centers = partition.get_centers(4, 4)

# One [k, centres] pair per fitted step.
# NOTE(review): assumes all_centers[k - 1][1] is the centre dict for partition k — confirm.
input_modelfitting = [[step['k'], all_centers[step['k'] - 1][1]] for step in step_results]

In [230]:
# Inputs/outputs for subject 9 of Task2 (files follow the '{taskID}_{subID}_...'
# naming used by the preprocessing functions).
raw_dir = Path(project_root) / 'data' / 'raw' / 'Task2' 
processed_dir = Path(project_root) / 'data' / 'processed' / 'Task2' 
input_bhv_csv = os.path.join(raw_dir, 'Task2_9_bhv.csv')
# Coded transcript written by process_csv in the recording-analysis section.
input_rec_csv = os.path.join(processed_dir, 'Task2_9_rec.csv')
output_csv = os.path.join(processed_dir, 'Task2_9_processed.csv')
plots_dir = Path(project_root) / 'results' / 'Plots'

# Merge everything, save output_csv and write per-trial PNG frames into plots_dir/choice1..choice4.
process_and_plot(input_rec_csv, input_bhv_csv, input_modelfitting, output_csv, plots_dir, plot_side='both')

Choice 1 在第 0 行之前没有数据，跳过生成图像。
Choice 2 在第 0 行之前没有数据，跳过生成图像。
Choice 3 在第 0 行之前没有数据，跳过生成图像。
Choice 1 在第 1 行之前没有数据，跳过生成图像。
Choice 2 在第 1 行之前没有数据，跳过生成图像。
Choice 1 在第 2 行之前没有数据，跳过生成图像。
Choice 1 在第 3 行之前没有数据，跳过生成图像。
处理完成，图表已分别保存到 '/home/yangjiong/CategoryLearning/results/Plots/choice1', '/home/yangjiong/CategoryLearning/results/Plots/choice2', '/home/yangjiong/CategoryLearning/results/Plots/choice3', 和 '/home/yangjiong/CategoryLearning/results/Plots/choice4' 文件夹中。


In [223]:
import os
import re
import imageio

def extract_session_trial(filename, pattern):
    """Parse iSession and iTrial out of a frame file name.

    Parameters:
        filename (str): the file name to parse.
        pattern (str): regular expression matched from the start of
            *filename*; its first two groups must capture iSession and iTrial.

    Returns:
        tuple: ``(iSession, iTrial)`` as ints, or ``(None, None)`` when the name
        does not match.
    """
    found = re.match(pattern, filename)
    if found is None:
        return None, None
    return int(found.group(1)), int(found.group(2))
    
def create_sorted_gif(plots_dir, output_gif, pattern, duration=0.5):
    """Assemble the PNG frames in *plots_dir* into a GIF ordered by (iSession, iTrial).

    PNG names that do not match *pattern* are reported and skipped; unreadable
    frames are reported and skipped as well.  Nothing is written when no frame
    could be read.

    Parameters:
        plots_dir (str | Path): folder containing the PNG frames.
        output_gif (str | Path): path of the GIF to write.
        pattern (str): regex whose first two groups capture iSession and iTrial
            (see extract_session_trial).
        duration (float): per-frame duration, passed straight to ``imageio.mimsave``.
    """
    # Accept str as well as Path: the frame paths below are built with '/'.
    frame_dir = Path(plots_dir)

    valid_files = []
    for name in os.listdir(frame_dir):
        if not name.endswith('.png'):
            continue
        iSession, iTrial = extract_session_trial(name, pattern)
        if iSession is None or iTrial is None:
            print(f"文件名 '{name}' 不符合预期格式，已跳过。")
            continue
        valid_files.append((iSession, iTrial, name))

    if not valid_files:
        print(f"在文件夹 '{plots_dir}' 中未找到符合格式的PNG图像。")
        return

    valid_files.sort(key=lambda entry: (entry[0], entry[1]))

    # imageio >= 2.16 deprecates the top-level imread in favour of imageio.v2.imread
    # (same behaviour); fall back to the top-level function on older versions.
    imread = getattr(imageio, 'v2', imageio).imread

    images = []
    for _session, _trial, name in valid_files:
        filepath = frame_dir / name
        try:
            images.append(imread(filepath))
        except Exception as e:  # best effort: skip frames that cannot be read
            print(f"读取文件 '{filepath}' 时出错: {e}")

    if not images:
        print(f"没有成功读取任何图像用于 '{output_gif}'。")
        return

    try:
        imageio.mimsave(output_gif, images, duration=duration)
        print(f"GIF已成功创建并保存为 '{output_gif}'。")
    except Exception as e:  # report and continue with the next GIF
        print(f"保存GIF '{output_gif}' 时出错: {e}")

In [231]:
plots_parent_dir = Path(project_root) / 'results' / 'Plots' 

# Sub-folder per category and the regex for its frame names
# ('{iSub}_{iSession}_{iTrial}_c{choice}.png'; the groups capture iSession and iTrial).
# NOTE(review): `\d+` requires a numeric iSub, so frames saved with the
# 'Unknown' fallback of process_and_plot would be skipped.
choices = {
    'choice1': r'^\d+_(\d+)_(\d+)_c1\.png$',
    'choice2': r'^\d+_(\d+)_(\d+)_c2\.png$',
    'choice3': r'^\d+_(\d+)_(\d+)_c3\.png$',
    'choice4': r'^\d+_(\d+)_(\d+)_c4\.png$',
}

# Build one GIF per category sub-folder
for choice, pattern in choices.items():
    sub_dir = plots_parent_dir / choice
    if not sub_dir.exists() or not sub_dir.is_dir():
        print(f"子文件夹 '{sub_dir}' 不存在或不是一个文件夹，已跳过。")
        continue

    # The GIF is written to the Plots parent folder
    output_gif = plots_parent_dir / f'{choice}_animation.gif'

    # Create the GIF (duration is forwarded to imageio.mimsave)
    create_sorted_gif(sub_dir, output_gif, pattern, duration=0.5)

  images.append(imageio.imread(filepath))


GIF已成功创建并保存为 '/home/yangjiong/CategoryLearning/results/Plots/choice1_animation.gif'。
GIF已成功创建并保存为 '/home/yangjiong/CategoryLearning/results/Plots/choice2_animation.gif'。
GIF已成功创建并保存为 '/home/yangjiong/CategoryLearning/results/Plots/choice3_animation.gif'。
GIF已成功创建并保存为 '/home/yangjiong/CategoryLearning/results/Plots/choice4_animation.gif'。
