In [None]:
# --- 核心 PyTorch 库 ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# --- 数据处理库 ---
import scanpy as sc
import numpy as np
import pandas as pd
import anndata

# --- 辅助库 ---
import os
import math
import random
from tqdm.auto import tqdm
import warnings

# Silence noisy warnings: scanpy's UserWarnings and every FutureWarning.
warnings.filterwarnings('ignore', category=UserWarning, module='scanpy')
warnings.filterwarnings('ignore', category=FutureWarning)

# Report library versions so outputs can be tied to an environment.
for _lib_label, _lib in (("PyTorch", torch), ("Scanpy", sc)):
    print(f"{_lib_label} 版本: {_lib.__version__}")
del _lib_label, _lib

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")

In [None]:
# --- 1. Data ---
# Input AnnData file. Override with the CHRONOFORMER_DATA_PATH environment
# variable; defaults to the original cluster location.
# NOTE(review): "Tedd.19_Embryo2" appears to be an alternative dataset used with
# this notebook — confirm before switching.
DATA_PATH = os.environ.get(
    "CHRONOFORMER_DATA_PATH",
    "/gpfs/hybrid/data/public/TEDD/link_cells/Tedd.19_Embryo.link.h5ad",
)
INPUT_DIM = 3000       # D: number of genes (the loaded data must have n_vars == 3000)

# --- 2. "Cell bag" hyperparameters ---
N_CELLS_PER_BAG = 511                 # N: number of cells per "sentence"
TOTAL_SEQ_LEN = N_CELLS_PER_BAG + 1   # +1 for the TIME token; derived so the two cannot drift apart

# --- 3. Pre-training hyperparameters ---
BATCH_SIZE = 128        # B: number of cell bags processed per batch
N_EPOCHS = 50           # total number of training epochs
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
N_SAMPLES_PER_EPOCH = 10000  # each epoch consists of 10k randomly generated cell bags

# --- 4. Static masking hyperparameters ---
PROB_MASK_GENE = 0.10     # (MGM) 10% of genes are masked across the whole bag
PROB_MASK_CELL = 0.15     # (MFM) 15% of cells are selected
PROB_MASK_FEATURE = 0.20  # (MFM) for each selected cell, 20% of its genes are masked

# --- 5. Transformer hyperparameters ---
MODEL_DIM = 256          # Transformer hidden size (D_model)
N_HEADS = 8              # number of attention heads
N_LAYERS = 6             # number of encoder layers
FFN_DIM = MODEL_DIM * 4  # feed-forward inner dimension
DROPOUT = 0.1
# Multi-head attention splits MODEL_DIM evenly across heads.
assert MODEL_DIM % N_HEADS == 0, f"MODEL_DIM ({MODEL_DIM}) must be divisible by N_HEADS ({N_HEADS})"

# --- 6. Checkpoints ---
CHECKPOINT_DIR = "./checkpoints_ChronoFormer_MoE"
MODEL_SAVE_PATH = os.path.join(CHECKPOINT_DIR, "chronoformer_moe_pretrained.pth")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- 7. Reproducibility ---
# Cell bags are generated at random, so seed every RNG the notebook imports to
# make Restart & Run All reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the AnnData object. A missing file must stop the notebook: continuing
# would either raise a confusing NameError below or, worse, silently reuse a
# stale `adata` left in the kernel by an earlier run.
try:
    adata = sc.read_h5ad(DATA_PATH)
except FileNotFoundError:
    print(f"错误: 找不到数据文件 {DATA_PATH}")
    raise

print(adata)
# The model's input layer is sized by INPUT_DIM, so the gene count must match.
assert adata.n_vars == INPUT_DIM, f"数据 n_vars ({adata.n_vars}) 与 INPUT_DIM ({INPUT_DIM}) 不匹配"

# Preprocessing: log1p followed by per-gene scaling clipped at max_value=10.
print("--- 2. 预处理: Log1p 和 Scale ---")
if 'log1p' in adata.uns:
    # scanpy's log1p records itself under adata.uns['log1p']; its presence is
    # taken here to mean both log1p and scale were already applied.
    print("数据已预处理。")
else:
    sc.pp.log1p(adata)
    sc.pp.scale(adata, max_value=10)

    # Replace any non-finite values left after scaling: NaN -> 0, ±inf -> ±10.
    print("  将 scale() 产生的 NaN 替换为 0...")
    if not isinstance(adata.X, np.ndarray):
        # Sparse matrix: only the explicitly stored entries (.data) can be non-finite.
        adata.X.data = np.nan_to_num(adata.X.data, nan=0.0, posinf=10.0, neginf=-10.0)
    else:
        adata.X = np.nan_to_num(adata.X, nan=0.0, posinf=10.0, neginf=-10.0)
    print("预处理完成。")

# a. Gene-expression tensor: densify adata.X (if sparse) and copy it into a float32 tensor.
print(f"正在将 {adata.n_obs} x {adata.n_vars} 矩阵转换为密集张量...")
data_tensor = torch.tensor(
    adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(),
    dtype=torch.float32,
)
print(f"data_tensor 形状: {data_tensor.shape}")

# b. Cell-type labels: integer codes of the categorical 'Celltype' column.
print("正在处理 'Celltype' 标签...")
adata.obs['Celltype'] = adata.obs['Celltype'].astype('category')
celltype_labels_tensor = torch.tensor(adata.obs['Celltype'].cat.codes.values, dtype=torch.long)
N_CELLTYPES = len(adata.obs['Celltype'].cat.categories)
print(f"找到 {N_CELLTYPES} 种细胞类型。")
# pandas encodes missing categorical values as code -1, which lies outside
# [0, N_CELLTYPES) and would be an invalid class index if used as a target
# downstream. Surface it here instead of letting it pass silently.
n_unlabeled_cells = int((celltype_labels_tensor < 0).sum())
if n_unlabeled_cells > 0:
    print(f"*** 警告: {n_unlabeled_cells} 个细胞缺少 'Celltype' 标签 (编码为 -1)。***")

# c. Per-cell timestamps (used to compute Δt).
print("正在处理 'time' 标签...")
time_tensor = torch.tensor(adata.obs['time'].values, dtype=torch.float32)
# NaN timestamps would propagate into every Δt computed from them.
n_missing_time = int(torch.isnan(time_tensor).sum())
if n_missing_time > 0:
    print(f"*** 警告: {n_missing_time} 个细胞的 'time' 为 NaN。***")

# --- 4. Build sampler data structures ---

# a. Static sampler: map each 'Timepoint' value to the positional indices of its cells.
print("构建 'Timepoint' 静态采样器...")
adata.obs['cell_index_int'] = np.arange(adata.n_obs)
# observed=True: when 'Timepoint' is categorical, only categories that actually
# occur get a group. Without it, unused categories (e.g. left over after
# subsetting) produce empty index arrays that cannot be sampled from.
# Cells with a missing Timepoint are dropped (groupby's default dropna=True).
static_sampler_map = {
    tp: group_df['cell_index_int'].values
    for tp, group_df in adata.obs.groupby('Timepoint', observed=True)
}
print(f"静态采样器已为 {len(static_sampler_map)} 个时间点准备就绪。")

# b. Dynamic lineage sampler: collect (prev, self, next) positional-index
#    triplets for every cell whose 'prev_cell_id' and 'next_cell_id' both
#    point at valid cells.
print("构建“谱系”动态采样器...")


def _as_cell_index(raw, n_obs):
    """Convert one raw prev/next link value into a positional cell index.

    Args:
        raw: a single value from the 'prev_cell_id' / 'next_cell_id' column.
        n_obs: number of cells; valid indices are in [0, n_obs).

    Returns:
        The index as an int, or None when the value is missing, non-numeric,
        non-integral (e.g. 3.7 or inf), or out of range.
    """
    if pd.isna(raw):
        return None
    if isinstance(raw, (int, np.integer)):
        idx = int(raw)
    else:
        try:
            as_float = float(raw)
        except (TypeError, ValueError):
            return None
        # Reject fractional values (int() would silently truncate them to a
        # wrong cell) and ±inf (int() would raise OverflowError).
        if not as_float.is_integer():
            return None
        idx = int(as_float)
    return idx if 0 <= idx < n_obs else None


dynamic_lineage_list = []
self_indices = adata.obs['cell_index_int'].values
prev_indices = adata.obs['prev_cell_id'].values
next_indices = adata.obs['next_cell_id'].values
n_total_obs = adata.n_obs

for i in tqdm(range(n_total_obs), total=n_total_obs):
    idx_prev = _as_cell_index(prev_indices[i], n_total_obs)
    idx_next = _as_cell_index(next_indices[i], n_total_obs)
    # Only cells with both neighbours resolvable form a complete triplet.
    if idx_prev is None or idx_next is None:
        continue
    dynamic_lineage_list.append((idx_prev, int(self_indices[i]), idx_next))

print(f"找到 {len(dynamic_lineage_list)} 个完整的谱系三元组 (Prev -> Self -> Next)")
if len(dynamic_lineage_list) == 0:
    # NOTE(review): an empty list may also mean the link columns hold cell
    # names rather than positional indices — confirm the column format.
    print("\n*** 警告: 动态谱系列表为空。动态任务将无法训练。***")
else:
    print("成功！动态训练数据已准备就绪。")