In [13]:
import os
import json
import numpy as np
import umap
import plotly.graph_objects as go
import plotly.express as px
from tqdm import tqdm
from copy import deepcopy


In [14]:

# UMAP hyper-parameters shared by every embedding produced in this notebook.
n_neighbors = 20  # UMAP `n_neighbors`: neighbourhood size considered per sample
min_dist = 0.2    # UMAP `min_dist`: minimum distance between embedded points

# Target dimensionalities for the 2-D and 3-D figures.
n_components_2d, n_components_3d = 2, 3


def load_json_file(file_path):
    """Load a JSON file and return the decoded object.

    The file is decoded explicitly as UTF-8, because JSON text is UTF-8 by
    specification. Relying on the platform default encoding fails on systems
    whose locale encoding is not UTF-8 (e.g. GBK/cp936 on Chinese Windows) as
    soon as the file contains non-ASCII text. The ``utf-8-sig`` codec also
    strips a leading UTF-8 BOM, which ``json.load`` would otherwise reject.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or ``None`` if reading or parsing failed. The
        error is printed rather than raised, so batch processing can continue.
    """
    try:
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            return json.load(f)
    except Exception as e:
        # Deliberate best-effort: report and let the caller skip this file.
        print(f"加载文件 {file_path} 出错: {e}")
        return None

def process_data_for_umap(data):
    """Concatenate every rule's matrix along its second axis for UMAP.

    Args:
        data: Mapping of rule name -> 2-D array-like. Each value is treated as
            a ``(num_neurons, num_times)`` matrix, and all rules must share the
            same shape.

    Returns:
        ``(concatenated, (num_neurons, num_times, num_rules))``, where
        ``concatenated`` has shape ``(num_neurons, num_rules * num_times)`` and
        its column blocks follow the mapping's iteration order.

        ``(None, None)`` is returned instead when ``data`` is empty/``None``,
        when any matrix is not 2-D, or when shapes differ between rules; a
        warning is printed for the last two cases.
    """
    if not data:
        return None, None

    # Convert each rule exactly once and reuse the arrays for both the
    # shape check and the concatenation.
    matrices = {rule: np.asarray(values) for rule, values in data.items()}

    reference_shape = None
    for rule, matrix in matrices.items():
        if matrix.ndim != 2:
            # Without this check, unpacking a non-2-D shape raises ValueError.
            print(f"警告: 规则 {rule} 的矩阵不是二维的 (shape={matrix.shape})")
            return None, None
        if reference_shape is None:
            reference_shape = matrix.shape
        elif matrix.shape != reference_shape:
            print(f"警告: 规则 {rule} 的矩阵形状与其他规则不一致")
            return None, None

    num_neurons, num_times = reference_shape

    # Result shape: (num_neurons, num_rules * num_times).
    concatenated = np.concatenate(list(matrices.values()), axis=1)

    return concatenated, (num_neurons, num_times, len(matrices))


def perform_umap(data, n_components=2):
    """Embed the columns of ``data`` with UMAP.

    ``data`` is transposed before fitting, so each column becomes one sample
    and each row becomes one feature. The hyper-parameters come from the
    module-level ``n_neighbors`` / ``min_dist`` settings. The seed is fixed at
    42 and the model runs with a single job.

    Args:
        data: 2-D array, or ``None``.
        n_components: Dimensionality of the embedding.

    Returns:
        The embedding returned by ``UMAP.fit_transform``, one row per column of
        ``data``, or ``None`` when ``data`` is ``None``.
    """
    if data is None:
        return None

    samples = data.T  # (total_times, num_neurons): one row per time point
    model = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=42,
        n_jobs=1,  # single-threaded
    )
    return model.fit_transform(samples)


def create_visualization(embedding, num_times, num_rules,
                         rules, n_components=2, title="UMAP可视化"):
    """Build a Plotly scatter figure of an embedding, one trace per rule.

    ``embedding`` rows are sliced rule by rule: rule ``i`` occupies rows
    ``i * num_times`` to ``(i + 1) * num_times``. Within each rule, marker
    colour encodes the time index on the Viridis scale.

    Args:
        embedding: 2-D array with at least ``n_components`` columns, or ``None``.
        num_times: Number of rows (time points) per rule.
        num_rules: Number of rules, i.e. traces.
        rules: Rule names used as legend labels, indexed by rule position.
        n_components: 3 builds ``go.Scatter3d`` traces; any other value builds
            2-D ``go.Scatter`` traces from the first two columns.
        title: Figure title.

    Returns:
        A ``go.Figure``, or ``None`` when ``embedding`` is ``None``.
    """
    if embedding is None:
        return None

    is_3d = n_components == 3
    # Colour value per time point; identical for every rule.
    time_index = np.arange(num_times)

    fig = go.Figure()
    for rule_idx in range(num_rules):
        segment = embedding[rule_idx * num_times:(rule_idx + 1) * num_times]

        # Plotly hides a marker colour bar unless `showscale` is set, so enable
        # it explicitly. Only the first trace gets one, since all traces share
        # the same time scale.
        show_scale = rule_idx == 0
        marker = dict(size=2, color=time_index, colorscale='Viridis',
                      showscale=show_scale)
        if show_scale:
            marker['colorbar'] = dict(title='时间点')

        common = dict(mode='markers', name=f'{rules[rule_idx]}', marker=marker)
        if is_3d:
            trace = go.Scatter3d(x=segment[:, 0], y=segment[:, 1],
                                 z=segment[:, 2], **common)
        else:
            trace = go.Scatter(x=segment[:, 0], y=segment[:, 1], **common)
        fig.add_trace(trace)

    if is_3d:
        axis_layout = dict(scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ))
    else:
        # `scene` only styles 3-D plots; 2-D traces use the cartesian axes.
        axis_layout = dict(xaxis_title='UMAP 1', yaxis_title='UMAP 2')

    fig.update_layout(
        title=title,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        **axis_layout
    )

    return fig


def _render_embedding(concatenated_data, umap_components, plot_components, label,
                      num_times, num_rules, rules, file_name, output_dir):
    """Embed the data, build one figure and write it as ``<file_name>_<label>.html``."""
    embedding = perform_umap(concatenated_data, n_components=umap_components)
    if embedding is None:
        return
    fig = create_visualization(
        embedding, num_times, num_rules, rules,
        n_components=plot_components, title=f"{file_name} 的UMAP {label}可视化"
    )
    if fig is not None:
        output_file = os.path.join(output_dir, f"{file_name}_{label.lower()}.html")
        fig.write_html(output_file)
        print(f"已生成{label}可视化: {output_file}")


def process_single_file(file_path, output_dir, generate_2d=True, generate_3d=True):
    """Render the requested UMAP figures for one JSON file.

    Args:
        file_path: JSON file holding a mapping of rule name -> matrix.
        output_dir: Existing directory that receives ``<name>_2d.html`` /
            ``<name>_3d.html``.
        generate_2d: Whether to render the 2-D figure.
        generate_3d: Whether to render the 3-D figure.

    Returns:
        ``False`` if the file could not be loaded, holds no rule mapping, or
        has inconsistent matrices; ``True`` otherwise.
    """
    data = load_json_file(file_path)
    if data is None:
        # Load failure (already reported by load_json_file), or a JSON `null`.
        return False

    # Test this separately from the load failure above: an empty dict is falsy,
    # so a combined `if not data` check would skip this message.
    if not isinstance(data, dict) or not data:
        print(f"文件 {file_path} 不包含任何规则数据")
        return False

    rules = list(data.keys())

    concatenated_data, shapes = process_data_for_umap(data)
    if concatenated_data is None or shapes is None:
        return False

    num_neurons, num_times, num_rules = shapes

    # File name without directory or extension, used for titles and outputs.
    file_name = os.path.splitext(os.path.basename(file_path))[0]

    if generate_2d:
        _render_embedding(concatenated_data, n_components_2d, 2, "2D",
                          num_times, num_rules, rules, file_name, output_dir)
    if generate_3d:
        _render_embedding(concatenated_data, n_components_3d, 3, "3D",
                          num_times, num_rules, rules, file_name, output_dir)

    return True


def process_folder(input_dir, output_dir, generate_2d=True, generate_3d=True):
    """Render UMAP figures for every ``.json`` file under ``input_dir``.

    The directory tree is walked recursively and mirrored under
    ``output_dir``. Each JSON file goes to ``process_single_file`` together
    with its mirrored output sub-directory. A sub-directory is created only
    when the corresponding input directory contains JSON files, so the output
    tree gets no empty folders. Directories and files are visited in sorted
    order so runs are reproducible.

    Args:
        input_dir: Root directory to search for ``.json`` files.
        output_dir: Root directory for the generated HTML files.
        generate_2d: Forwarded to ``process_single_file``.
        generate_3d: Forwarded to ``process_single_file``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for root, dirs, files in os.walk(input_dir):
        # os.walk (top-down) descends into `dirs` in list order, so sorting it
        # in place fixes the traversal order.
        dirs.sort()

        json_files = sorted(name for name in files if name.endswith('.json'))
        if not json_files:
            continue

        relative_path = os.path.relpath(root, input_dir)
        output_subdir = os.path.join(output_dir, relative_path)
        os.makedirs(output_subdir, exist_ok=True)

        for file in json_files:
            file_path = os.path.join(root, file)
            print(f"处理文件: {file_path}")
            process_single_file(file_path, output_subdir, generate_2d, generate_3d)


In [15]:

# Run configuration for this dataset; paths are relative to the notebook's
# working directory.
# NOTE(review): the input path contains "ump" while the output path uses
# "umap". The recorded run output shows the input path exists, so it is left
# as-is; confirm the naming is intentional.
input_folder = "../results/json_temp/ump/flexible_shift/RDH01-PFC"
output_folder = "../results/html_final/umap/flexible_shift/RDH01-PFC"
generate_2d, generate_3d = False, True  # render only the 3-D figure

print(f"开始处理文件夹: {input_folder}")
print(f"结果将保存到: {output_folder}")

process_folder(input_folder, output_folder, generate_2d, generate_3d)

print("处理完成!")

开始处理文件夹: ../results/json_temp/ump/flexible_shift/RDH01-PFC
结果将保存到: ../results/html_final/umap/flexible_shift/RDH01-PFC
处理文件: ../results/json_temp/ump/flexible_shift/RDH01-PFC\alignTrack\correct_all_rules.json
已生成3D可视化: ../results/html_final/umap/flexible_shift/RDH01-PFC\alignTrack\correct_all_rules_3d.html
处理完成!
