In [None]:
# --- 1. 导入必要的库并设置项目路径 ---
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric.data
import umap
from matplotlib.lines import Line2D  # For custom legends
from omegaconf import OmegaConf
from torch_geometric.nn import global_mean_pool
from tqdm import tqdm

# This notebook is expected to live in the project's 'test/' directory, so the
# project root is the parent of the current working directory.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Make the project's `src.*` modules importable; only append once so re-running
# this cell does not keep growing sys.path.
_root_already_on_path = PROJECT_ROOT in sys.path
if not _root_already_on_path:
    sys.path.append(PROJECT_ROOT)
print(f"项目根目录已添加到系统路径: {PROJECT_ROOT}")

from src.models.combined_model import CombinedModel
from src.utils import set_seed, get_device
from src.data_loader import load_gene_graph_data, create_cells_by_day_mapping
from src.evaluate import _get_point_estimate_from_params, _reshape_params_for_loss_or_eval

In [None]:
# ## 2. Configuration
# --- User-tunable parameters ---

# Main experiment configuration file, located relative to the project root.
CONFIG_PATH = os.path.join(PROJECT_ROOT, "configs/main_config.yaml")

# Checkpoint of the already-trained joint (CombinedModel) model to evaluate.
_CHECKPOINT_RELPATH = "results/experiment_stable_start/checkpoints/joint_train_best.pt"
CHECKPOINT_PATH = os.path.join(PROJECT_ROOT, _CHECKPOINT_RELPATH)

# The time points to evaluate are not listed here; they are read from the data.

# --- Visualization parameters ---
# Upper bound on cells sampled per time point (days with fewer cells use all of them).
NUM_SAMPLES_PER_DAY = 200
# UMAP hyper-parameters; UMAP_METRIC is the metric for the expression-space embedding.
UMAP_N_NEIGHBORS, UMAP_MIN_DIST = 15, 0.1
UMAP_METRIC = 'euclidean'
UMAP_RANDOM_STATE = 42

In [None]:
# ## 3. Load configuration, data and model
# --- Load configuration ---
if not os.path.exists(CONFIG_PATH):
    raise FileNotFoundError(f"配置文件未找到: {CONFIG_PATH}")
config = OmegaConf.load(CONFIG_PATH)

# Seed before any random sampling later in the notebook (np.random.choice is used
# to pick cells). NOTE(review): assumes set_seed seeds numpy's global RNG — it is
# defined in src.utils, not visible here.
set_seed(config.seed)
device = get_device(config.training_params.device)
print(f"正在使用设备: {device}")

# --- Load data ---
print("正在加载数据...")
# A relative data directory in the config is resolved against the project root,
# not against the notebook's working directory.
actual_data_dir = config.data_params.data_dir
if not os.path.isabs(actual_data_dir):
    actual_data_dir = os.path.join(PROJECT_ROOT, actual_data_dir)
if not os.path.exists(actual_data_dir):
    raise FileNotFoundError(f"数据目录未找到: {actual_data_dir}")
X_all_np, shared_edge_index, shared_edge_weight, gene_names, cell_names, meta_df = \
    load_gene_graph_data(actual_data_dir, config)
cells_by_day_indices = create_cells_by_day_mapping(meta_df, cell_names)
print("数据加载完成。")

# --- Load model ---
print(f"正在从检查点加载模型: {CHECKPOINT_PATH}")
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"检查点文件未找到: {CHECKPOINT_PATH}")
model = CombinedModel(config)
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

state_dict = checkpoint['model_state_dict']
# Weights saved from a DistributedDataParallel wrapper carry a 'module.' prefix on
# every key; strip it so the keys match the bare CombinedModel. The emptiness check
# avoids `all()` being vacuously True for an empty state dict.
ddp_prefix = 'module.'
if state_dict and all(key.startswith(ddp_prefix) for key in state_dict):
    print("检查点来自DDP模型，正在移除 'module.' 前缀。")
    state_dict = {key[len(ddp_prefix):]: value for key, value in state_dict.items()}
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
print("模型加载完成。")

# --- Model / data parameters ---
is_variational = bool(config.model_params.encoder.get("is_variational", False))
decoder_distribution = config.model_params.decoder.get("distribution", "gaussian")
num_dist_params = config.model_params.decoder.get("num_dist_params", 2)
num_genes = int(config.model_params.num_genes)

# Reconstructions are later reshaped to (n_cells, num_genes), with one row of
# X_all_np per cell. A config/data mismatch must fail here with a clear message
# rather than as an opaque reshape error during evaluation.
if X_all_np.shape[1] != num_genes:
    raise ValueError(
        f"基因数量不匹配: 配置 num_genes={num_genes}，但数据矩阵有 {X_all_np.shape[1]} 列。"
    )

# Evaluate every time point present in the data, in sorted order.
DAYS_TO_EVALUATE = sorted(cells_by_day_indices.keys())
print(f"将自动评估以下所有时间点: {DAYS_TO_EVALUATE}")

In [None]:
# ## 4. Generate reconstructions (VGAE part only)
# For each day: sample cells, encode+decode them with the graph autoencoder, then
# re-encode the reconstructions. Per day, rows are appended as [real, recon] to the
# expression list, the latent list and both label lists, so all four stay row-aligned.


def _expressions_to_batch(expressions, edge_index, target_device):
    """Wrap each expression row as a graph (one scalar feature per gene node) and batch them.

    Args:
        expressions: 2-D array with one row per cell.
        edge_index: gene-graph connectivity shared by every cell's graph.
        target_device: device the returned batch is moved to.

    Returns:
        A torch_geometric Batch with one graph per row of `expressions`.
    """
    graphs = [
        torch_geometric.data.Data(
            x=torch.tensor(expr, dtype=torch.float32).unsqueeze(-1),
            edge_index=edge_index,
        )
        for expr in expressions
    ]
    return torch_geometric.data.Batch.from_data_list(graphs).to(target_device)


def _pool_cell_latents(node_repr, node_batch, n_cells, what):
    """Mean-pool node-level latents into one row per cell and verify the row count.

    Raises:
        RuntimeError: if no latent tensor is available, or pooling does not yield
            exactly `n_cells` rows (which would misalign latents with the labels).
    """
    if node_repr is None or node_repr.numel() == 0:
        raise RuntimeError(f"{what}: 编码器没有返回潜表示，无法构建潜空间UMAP。")
    pooled = global_mean_pool(node_repr, node_batch).detach().cpu().numpy()
    if pooled.shape[0] != n_cells:
        raise RuntimeError(f"{what}: 池化后得到 {pooled.shape[0]} 行，但预期 {n_cells} 个细胞。")
    return pooled


# --- Collected data ---
all_expressions_list = []
all_latent_list = []
all_type_labels_list = []
all_day_labels_list = []

with torch.no_grad():
    for day in tqdm(DAYS_TO_EVALUATE, desc="处理各个时间点"):
        if day not in cells_by_day_indices or len(cells_by_day_indices[day]) == 0:
            print(f"警告: 时间点 {day} 没有有效细胞，跳过。")
            continue

        candidates = cells_by_day_indices[day]
        n_cells = min(NUM_SAMPLES_PER_DAY, len(candidates))
        day_indices = np.random.choice(candidates, n_cells, replace=False)
        real_expressions = X_all_np[day_indices]

        # --- Encode and decode the real cells through the VGAE ---
        # NOTE(review): shared_edge_weight is loaded but not attached to the graphs;
        # confirm the encoder was also trained without edge weights.
        gae_outputs = model.graph_autoencoder(
            _expressions_to_batch(real_expressions, shared_edge_index, device)
        )

        # Real-cell latent: posterior mean for a variational encoder, otherwise the
        # sampled node embeddings.
        real_mu_nodes = gae_outputs.get("mu_nodes")
        real_latent_repr = (
            real_mu_nodes
            if is_variational and real_mu_nodes is not None
            else gae_outputs.get("sampled_z_nodes")
        )
        real_latent = _pool_cell_latents(
            real_latent_repr, gae_outputs.get("z_batch"), n_cells, f"Day {day} Real"
        )

        # --- Expression-space point estimate of the reconstruction ---
        reshaped_params = _reshape_params_for_loss_or_eval(
            gae_outputs.get("reconstructed_params"), (n_cells, num_genes), num_dist_params
        )
        recon_expressions_est = _get_point_estimate_from_params(
            reshaped_params, decoder_distribution, device
        )
        recon_expressions_np = recon_expressions_est.detach().cpu().numpy()

        # --- Re-encode the reconstructions to compare them in latent space ---
        # The first return value is named explicitly: the old code reused `_`, which
        # by then held the 5th return value. Assumes encode() returns
        # (sampled_z_nodes, mu_nodes, <unused>, batch, <unused>) — TODO confirm.
        recon_sampled_z, recon_mu_nodes, _, recon_batch, _ = model.graph_autoencoder.encode(
            _expressions_to_batch(recon_expressions_np, shared_edge_index, device),
            return_pooling_details=False,
        )
        recon_latent_repr = (
            recon_mu_nodes
            if is_variational and recon_mu_nodes is not None
            else recon_sampled_z
        )
        recon_latent = _pool_cell_latents(
            recon_latent_repr, recon_batch, n_cells, f"Day {day} Recon"
        )

        # --- Append data and labels in the same [real, recon] order ---
        all_expressions_list.extend([real_expressions, recon_expressions_np])
        all_latent_list.extend([real_latent, recon_latent])
        for type_label, n_rows in (("Real", len(real_expressions)), ("Recon", len(recon_expressions_np))):
            all_day_labels_list.extend([day] * n_rows)
            all_type_labels_list.extend([type_label] * n_rows)


# --- Combine everything ---
if not all_expressions_list:
    raise ValueError("没有为UMAP收集到任何数据。请检查 DAYS_TO_EVALUATE 是否有效。")

combined_expressions_np = np.concatenate(all_expressions_list, axis=0)
combined_latent_np = np.concatenate(all_latent_list, axis=0)
day_labels_np = np.array(all_day_labels_list)
type_labels_np = np.array(all_type_labels_list)

# The plots index both embeddings with masks built from the label arrays, so every
# array must have one row per collected point.
assert (
    combined_expressions_np.shape[0]
    == combined_latent_np.shape[0]
    == len(day_labels_np)
    == len(type_labels_np)
), "表达、潜空间与标签的行数不一致"

print(f"为UMAP准备的总数据点数量: {combined_expressions_np.shape[0]}")

In [None]:
# ## 5. UMAP降维与可视化
def _format_day(day):
    """Format a day label with two decimals when it is numeric, otherwise as plain text."""
    try:
        return f"{day:.2f}"
    except (TypeError, ValueError):
        # Non-numeric labels (e.g. strings) do not support the 'f' format code.
        return str(day)


def plot_space(embedding, day_labels, type_labels, title, output_filename, output_dir=None):
    """Scatter a UMAP embedding colored by day, marking real vs. reconstructed points.

    Real points are faint circles and reconstructions are opaque crosses in the same
    per-day color. The figure is saved as a 300-dpi image and then shown.

    Args:
        embedding: array whose first two columns are the UMAP coordinates, one row per point.
        day_labels: array with the day of each row of `embedding`.
        type_labels: array with "Real" or "Recon" for each row of `embedding`.
        title: figure title.
        output_filename: file name of the saved plot inside the output directory.
        output_dir: directory for the plot. Defaults to
            ``config.evaluation_params.output_dir`` (fallback "results/eval_output").
            A relative path is resolved against PROJECT_ROOT, like the data directory.
    """
    fig, ax = plt.subplots(figsize=(16, 12))
    unique_days = np.unique(day_labels)
    cmap = plt.get_cmap('viridis', len(unique_days))
    day_to_color = {day: cmap(i) for i, day in enumerate(unique_days)}

    for day in unique_days:
        is_day = day_labels == day
        real_idx = is_day & (type_labels == "Real")
        recon_idx = is_day & (type_labels == "Recon")
        color = day_to_color[day]

        ax.scatter(embedding[real_idx, 0], embedding[real_idx, 1], color=color, marker='o', s=30, alpha=0.1)
        ax.scatter(embedding[recon_idx, 0], embedding[recon_idx, 1], color=color, marker='x', s=40, alpha=1)

    day_legend_elements = [
        Line2D([0], [0], color=day_to_color[day], lw=4, label=f'Day {_format_day(day)}')
        for day in unique_days
    ]
    marker_legend_elements = [
        Line2D([0], [0], marker='o', color='gray', label='Real Data', linestyle='None', markersize=10, alpha=0.1),
        Line2D([0], [0], marker='x', color='gray', label='Reconstructed Data', linestyle='None', markersize=10, alpha=1)
    ]
    # Pin the legend's upper-left corner just outside the right edge of the axes.
    # The previous loc='best' with an anchor at x=1.2 could push it past the figure edge.
    ax.legend(handles=day_legend_elements + marker_legend_elements,
              loc='upper left', bbox_to_anchor=(1.02, 1), fontsize=10)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('UMAP 1', fontsize=14)
    ax.set_ylabel('UMAP 2', fontsize=14)
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout(rect=[0, 0, 0.85, 1])

    if output_dir is None:
        output_dir = config.evaluation_params.get("output_dir", "results/eval_output")
    if not os.path.isabs(output_dir):
        output_dir = os.path.join(PROJECT_ROOT, output_dir)
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, output_filename)
    # bbox_inches='tight' keeps the outside-the-axes legend inside the saved image.
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"UMAP图已保存到: {plot_path}")
    plt.show()

# Latent vectors are compared by direction (cosine distance); expression profiles
# use the configurable UMAP_METRIC.
UMAP_LATENT_METRIC = 'cosine'


def _fit_umap(data, metric):
    """Fit UMAP with the notebook's shared hyper-parameters and return the embedding."""
    reducer = umap.UMAP(
        n_neighbors=UMAP_N_NEIGHBORS,
        min_dist=UMAP_MIN_DIST,
        metric=metric,
        random_state=UMAP_RANDOM_STATE,
    )
    return reducer.fit_transform(data)


# Both embeddings are indexed with masks built from the same label arrays, so a row
# mismatch must be reported up front rather than as a boolean-index error mid-plot.
n_points = len(day_labels_np)
for space_name, matrix in (("表达", combined_expressions_np), ("潜空间", combined_latent_np)):
    if matrix.shape[0] != n_points:
        raise ValueError(f"{space_name}矩阵有 {matrix.shape[0]} 行，但标签有 {n_points} 个；数据未对齐。")

# --- Expression space ---
print("\n正在处理表达空间UMAP...")
embedding_expr = _fit_umap(combined_expressions_np, UMAP_METRIC)
plot_space(embedding_expr, day_labels_np, type_labels_np,
           'UMAP of Real vs. Reconstructed Gene Expression by Day',
           "umap_vgae_reconstruction_expression_space.png")

# --- Latent space ---
print("\n正在处理潜空间UMAP...")
embedding_latent = _fit_umap(combined_latent_np, UMAP_LATENT_METRIC)
plot_space(embedding_latent, day_labels_np, type_labels_np,
           'UMAP of Real vs. Reconstructed Latent Representations by Day',
           "umap_vgae_reconstruction_latent_space.png")