In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scvelo as scv
from scvi.nn import Encoder, FCLayers
import anndata
import pandas as pd
import scipy.sparse as sp
import scanpy as sc

In [29]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [30]:
# Load the simulated dataset from an .h5ad file (path is relative to the
# notebook's working directory), print the AnnData summary, and display the
# per-gene `n_isoforms` column of `adata.var` as the cell output.
adata = anndata.read_h5ad('../../test_simulated_data_continuous_0.h5ad')
print(adata)
adata.var['n_isoforms']

AnnData object with n_obs × n_vars = 2000 × 1000
    obs: 'cell_state', 'pseudotime'
    var: 'gene_category', 'n_isoforms'
    uns: 'proportions_ground_truth', 'proportions_observed', 'spliced_isoform_counts'
    obsm: 'isoform_counts', 'proportion'
    layers: 'alpha', 'beta', 'gamma', 'spliced', 'unspliced'


gene0      5
gene1      3
gene2      3
gene3      4
gene4      5
          ..
gene995    2
gene996    4
gene997    2
gene998    3
gene999    5
Name: n_isoforms, Length: 1000, dtype: int64

In [4]:
class IsoDataset(Dataset):
    """Per-cell feature rows built from unspliced and isoform counts.

    Each row is the horizontal concatenation of ``adata.layers["unspliced"]``
    and ``adata.obsm["isoform_counts"]`` (in that order), stored as a single
    float32 tensor in ``self.X``. Indexing returns ``(row, idx)``: the
    requested row(s) together with the index that was passed in.
    """

    def __init__(self, adata):
        unspliced = self._to_dense(adata.layers["unspliced"])

        # The isoform table may be DataFrame-like (anything exposing
        # ``.values``), a scipy sparse matrix, or already a dense array.
        isoforms = adata.obsm["isoform_counts"]
        if hasattr(isoforms, "values"):
            isoforms = isoforms.values
        isoforms = self._to_dense(isoforms)

        # Cells x (genes + isoforms): this width is the encoder input size.
        features = np.hstack([unspliced, isoforms]).astype(np.float32)
        self.X = torch.from_numpy(features)
        self.n_cells = self.X.shape[0]

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as-is unless it is scipy-sparse, then densify it."""
        return matrix.toarray() if sp.issparse(matrix) else matrix

    def __len__(self):
        return self.n_cells

    def __getitem__(self, idx):
        # Hand the index back so callers can look up per-cell parameters.
        return self.X[idx], idx

In [5]:
class IsoveloEncoder(Encoder):
    """
    Variational encoder over the concatenated [unspliced, isoform] counts.

    Thin wrapper around ``scvi.nn.Encoder``: the constructor maps this class's
    argument names onto the parent's (``input_dim`` -> ``n_input``,
    ``latent_dim`` -> ``n_output``, ``hidden_dim`` -> ``n_hidden``) and
    ``forward`` repackages the parent's three return values as a dictionary.
    """
    def __init__(self,
                 input_dim:int,
                 hidden_dim=32,
                 latent_dim = 128,
                 n_layers=2,
                 dropout_rate=0.1,
                 distribution="normal",
                 use_batch_norm=True,
                 use_layer_norm=False,
                 var_activation=nn.Softplus(),
                 activation_fn=nn.ReLU,
                 **kwargs):
        # Collect the renamed arguments once, then forward everything.
        parent_args = dict(
            n_input=input_dim,
            n_output=latent_dim,
            n_layers=n_layers,
            n_hidden=hidden_dim,
            dropout_rate=dropout_rate,
            distribution=distribution,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            var_activation=var_activation,
            activation_fn=activation_fn,
        )
        super().__init__(**parent_args, **kwargs)

    def forward(self, x: torch.Tensor, *cat_list: torch.Tensor):
        """
        Encode ``x`` and return the parent's outputs keyed by name.

        :param x: Concatenated tensor of [U, Isoforms]
        :param cat_list: Optional extra tensors passed straight to the parent
        :return: A dictionary with 'qz_m', 'qz_v', 'z' (presumably posterior
                 mean, posterior variance and a sampled latent, following
                 scvi's Encoder - confirm against the installed scvi version)
        """
        qz_mean, qz_var, latent = super().forward(x, *cat_list)
        return dict(qz_m=qz_mean, qz_v=qz_var, z=latent)

In [6]:
def build_encoder_from_adata(adata, **enc_kwargs):
    """Create an ``IsoveloEncoder`` sized to the feature columns of ``adata``.

    The input width is the number of columns of ``adata.layers["unspliced"]``
    plus the number of columns of ``adata.obsm["isoform_counts"]``. All other
    keyword arguments are forwarded to ``IsoveloEncoder`` unchanged.
    """
    n_gene_cols = adata.layers["unspliced"].shape[1]
    n_isoform_cols = adata.obsm["isoform_counts"].shape[1]
    return IsoveloEncoder(input_dim=n_gene_cols + n_isoform_cols, **enc_kwargs)

In [7]:
class IsoveloDecoder(nn.Module):
    """Predict per-cell kinetic rates from a latent cell embedding.

    Learnable pieces:
      * ``cell_time`` - free parameter of shape (n_cells, 1); not a function of z.
      * ``gamma``     - free parameter of shape (1, n_isoforms), stored pre-softplus.
      * alpha network - z -> (batch, n_genes), softplus output.
      * beta network  - z -> (batch, n_isoforms), softplus output.

    When ``init_*`` arrays are supplied (e.g. taken from scVelo), the rate
    parameters are initialised through an inverse softplus, so that the first
    forward pass returns ``init_gamma`` exactly (up to float32 precision) and
    values close to ``init_alpha`` / ``init_beta_iso``.

    Args:
        n_cells: Number of rows of the ``cell_time`` parameter.
        n_genes: Output width of the alpha network.
        n_isoforms: Output width of the beta network and of ``gamma``.
        latent_dim: Width of the latent input ``z``.
        hidden_dim: Width of the single hidden layer of each network.
        init_time: Optional initial cell times, shape (n_cells,).
        init_alpha: Optional initial alpha per gene, shape (n_genes,).
        init_beta_iso: Optional initial beta per isoform, shape (n_isoforms,).
        init_gamma: Optional initial gamma per isoform, shape (n_isoforms,).
        device: Stored on ``self.device`` for reference only; the module is
            moved with the usual ``.to(...)``.
    """
    def __init__(self,
                 n_cells: int,
                 n_genes: int,
                 n_isoforms: int,
                 latent_dim: int = 128,
                 hidden_dim: int = 256,
                 # Initialization values provided by scVelo
                 init_time: np.ndarray = None,      # shape: (n_cells, )
                 init_alpha: np.ndarray = None,     # shape: (n_genes, )
                 init_beta_iso: np.ndarray = None,  # shape: (n_isoforms, ) <- (beta_gene * proportion)
                 init_gamma: np.ndarray = None,     # shape: (n_isoforms, ) <- (gamma_gene * proportion)
                 device: torch.device = None):

        super().__init__()

        self.n_genes = n_genes
        self.n_isoforms = n_isoforms
        self.device = device

        # --- A. Cell Time (t) ---
        # Independent parameter, does not depend on z (cell * 1)
        self.cell_time = nn.Parameter(torch.randn(n_cells, 1))
        if init_time is not None:
            with torch.no_grad():
                self.cell_time.copy_(self._as_float_tensor(init_time).reshape(-1, 1))

        # --- B. Gamma (γ) ---
        # Independent parameter, does not depend on z (1 * isoform)
        self.gamma = nn.Parameter(torch.ones(1, n_isoforms))
        if init_gamma is not None:
            inv_gamma = self._inverse_softplus(init_gamma)
            with torch.no_grad():
                self.gamma.copy_(self._as_float_tensor(inv_gamma).reshape(1, -1))

        # --- C. Alpha Network (α) ---
        # Input: z -> Output: Alpha (Cell * Gene)
        self.alpha_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.alpha_fc2 = nn.Linear(hidden_dim, n_genes)  # final layer

        if init_alpha is not None:
            # Keep the final-layer weights tiny so the initial output is
            # dominated by the bias, which is set so softplus(bias) == init_alpha.
            nn.init.xavier_normal_(self.alpha_fc2.weight, gain=0.01)
            inv_alpha = self._inverse_softplus(init_alpha)
            with torch.no_grad():
                self.alpha_fc2.bias.copy_(self._as_float_tensor(inv_alpha))

        # --- D. Beta Network (β) ---
        # Input: z -> Output: Beta (Cell * Isoform)
        self.beta_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.beta_fc2 = nn.Linear(hidden_dim, n_isoforms)

        if init_beta_iso is not None:
            # Same scheme as alpha, with isoform-level beta values.
            nn.init.xavier_normal_(self.beta_fc2.weight, gain=0.01)
            inv_beta = self._inverse_softplus(init_beta_iso)
            with torch.no_grad():
                self.beta_fc2.bias.copy_(self._as_float_tensor(inv_beta))

    @staticmethod
    def _inverse_softplus(x_np, eps=1e-6):
        """Return y such that softplus(y) == max(x, eps), elementwise.

        Uses ``x + log(1 - exp(-x))`` rather than ``log(exp(x) - 1)``: the two
        are algebraically equal, but the latter overflows to ``inf`` once
        ``exp(x)`` exceeds the floating-point range.
        """
        x_safe = np.maximum(np.asarray(x_np, dtype=np.float64), eps)
        return x_safe + np.log(-np.expm1(-x_safe))

    @staticmethod
    def _as_float_tensor(values):
        """Convert an array-like to a float32 tensor."""
        return torch.as_tensor(np.asarray(values), dtype=torch.float32)

    def forward(self, z: torch.Tensor, cell_indices = None):
        """
        z: [Batch, Latent_dim]
        cell_indices: [Batch] indices into ``cell_time``; when omitted, the
            whole ``cell_time`` parameter (all cells) is returned.

        Returns a dict with keys "cell_time", "gene_alpha", "isoform_beta",
        "isoform_gamma".
        """
        # 1. Get cell time
        if cell_indices is not None:
            t = self.cell_time[cell_indices] # [Batch, 1]
        else:
            t = self.cell_time

        # 2. Get alpha (non-negative)
        h_alpha = F.relu(self.alpha_fc1(z))
        alpha = F.softplus(self.alpha_fc2(h_alpha)) # [Batch, n_genes]

        # 3. Get beta (non-negative)
        h_beta = F.relu(self.beta_fc1(z))
        beta = F.softplus(self.beta_fc2(h_beta))   # [Batch, n_isoforms]

        # 4. Get gamma (non-negative)
        gamma = F.softplus(self.gamma)             # [1, n_isoforms]

        return {"cell_time": t, "gene_alpha": alpha, "isoform_beta": beta, "isoform_gamma":gamma}

In [17]:
from scvelo.tools import latent_time


# Encoder smoke test: size the encoder from `adata` and push every cell
# through it once.
model = build_encoder_from_adata(adata, hidden_dim=32, latent_dim = 128).to(device)
dataset = IsoDataset(adata)
outputs = model(dataset.X.to(device))

# Decoder smoke test on random toy inputs (20 cells, 10 genes, 15 isoforms).
# The draws are made in the same order as before: time, alpha, beta, gamma, z.
test_time = np.random.rand(20)
test_alpha = np.random.rand(10)
test_beta = np.random.rand(15)
test_gamma = np.random.rand(15)
test_z = torch.from_numpy(np.random.rand(20,128)).float().to(device)

decoder = IsoveloDecoder(
    n_cells=20,
    n_genes=10,
    n_isoforms=15,
    latent_dim=128,
    hidden_dim=256,
    init_time=test_time,
    init_alpha=test_alpha,
    init_beta_iso=test_beta,
    init_gamma=test_gamma,
    device = device,
).to(device)

# This value is discarded: only the last expression of a cell is displayed.
decoder.cell_time.detach().cpu().numpy()
decoder(test_z)

{'cell_time': Parameter containing:
 tensor([[0.3647],
         [0.9789],
         [0.8849],
         [0.9412],
         [0.7309],
         [0.7280],
         [0.5466],
         [0.4591],
         [0.2717],
         [0.7356],
         [0.2825],
         [0.4853],
         [0.8373],
         [0.4981],
         [0.4323],
         [0.3400],
         [0.4265],
         [0.9966],
         [0.1515],
         [0.1439]], device='cuda:0', requires_grad=True),
 'gene_alpha': tensor([[0.5238, 0.0322, 0.2024, 0.6503, 0.8556, 0.5879, 0.7142, 0.2903, 0.6915,
          0.3397],
         [0.5224, 0.0321, 0.2030, 0.6509, 0.8557, 0.5864, 0.7158, 0.2911, 0.6913,
          0.3395],
         [0.5228, 0.0321, 0.2024, 0.6513, 0.8538, 0.5879, 0.7132, 0.2909, 0.6909,
          0.3404],
         [0.5229, 0.0322, 0.2023, 0.6506, 0.8556, 0.5897, 0.7131, 0.2906, 0.6900,
          0.3397],
         [0.5230, 0.0322, 0.2032, 0.6507, 0.8555, 0.5883, 0.7140, 0.2905, 0.6890,
          0.3400],
         [0.5239, 0.0321, 

In [None]:

# Mini-batch loader over the per-cell feature rows.
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,                        # Shuffle each epoch
    num_workers=0,                       # Depends on the number of cpus
    pin_memory=(device.type == "cuda"),  # Pinned host memory only helps CUDA transfers;
                                         # on a CPU-only run it just triggers a warning
)

# NOTE(review): no training loop is present in this cell; the code below runs
# the encoder in whatever state `model` currently is.

# After training, store all intermediate parameters
model.eval()  # disable dropout / use running batch-norm statistics
X_tensor = dataset.X.to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    results = model(X_tensor)

z = results["z"].cpu().numpy()
qz_m = results["qz_m"].cpu().numpy()
qz_v = results["qz_v"].cpu().numpy()

# Attach the latent outputs to the AnnData object, one row per cell.
adata.obsm["X_isovelo_z"] = z
adata.obsm["latent_qz_m"] = qz_m
adata.obsm["latent_qz_v"] = qz_v

Test

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class IsoveloDecoder(nn.Module):
    """Decode a latent cell state into reconstructed unspliced / isoform abundances.

    For each cell the decoder builds a "virtual history" by scaling the latent
    vector linearly from 0 towards ``z_final`` in ``n_steps`` steps, evaluates
    state-dependent alpha (per gene) and beta (per isoform) along that history,
    and integrates the kinetics with explicit Euler steps of size
    ``T / n_steps``::

        du/dt = alpha - beta_gene * u          (beta_gene = sum of isoform betas)
        ds/dt = beta_iso * u_gene - gamma * s

    Args:
        n_cells: Number of cells (rows of the per-cell time parameter).
        n_genes: Number of genes (width of ``u``).
        n_isoforms: Number of isoforms (width of ``s``).
        g2i_mask: Gene-to-isoform membership matrix of shape
            [n_genes, n_isoforms]; stored as a float buffer.
        latent_dim: Width of the latent input.
        hidden_dim: Hidden width of the alpha / beta networks.
        n_steps: Number of Euler steps (must be >= 1).
        init_time: Optional initial cell times, shape (n_cells,).
        init_alpha: Optional initial alpha per gene, shape (n_genes,).
        init_beta_iso: Optional initial beta per isoform, shape (n_isoforms,).
        init_gamma: Optional initial gamma per isoform, shape (n_isoforms,).
        device: Device the mask buffer is created on. ``forward`` takes its
            working device from the input tensor, so the module keeps working
            after a later ``.to(...)``.

    All positive quantities (time, alpha, beta, gamma) are stored pre-softplus;
    the ``init_*`` values are passed through an inverse softplus so that the
    first forward pass reproduces ``init_time`` and ``init_gamma``.
    """
    def __init__(self,
                 n_cells: int,
                 n_genes: int,
                 n_isoforms: int,
                 g2i_mask: torch.Tensor,
                 latent_dim: int = 128,
                 hidden_dim: int = 256,
                 n_steps: int = 10,  # number of slices the time axis is cut into for integration
                 init_time: np.ndarray = None,
                 init_alpha: np.ndarray = None,
                 init_beta_iso: np.ndarray = None,
                 init_gamma: np.ndarray = None,
                 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

        super().__init__()
        if n_steps < 1:
            raise ValueError("n_steps must be a positive integer")

        self.n_genes = n_genes
        self.n_isoforms = n_isoforms
        self.n_steps = n_steps  # number of integration steps
        self.device = device

        self.register_buffer('g2i_mask', g2i_mask.float().to(device))

        # --- Parameters ---
        # 1. Cell time (one independent value per cell). forward() applies a
        #    softplus, so the supplied times are stored inverse-softplus'ed;
        #    otherwise the initial T would be softplus(init_time), not init_time.
        if init_time is not None:
            t_init = self._inverse_softplus(
                torch.as_tensor(init_time, dtype=torch.float32).reshape(-1, 1)
            )
        else:
            t_init = torch.rand(n_cells, 1) * 5.0
        self.cell_time_param = nn.Parameter(t_init)

        # 2. Gamma (constant per isoform, stored pre-softplus)
        self.gamma_param = nn.Parameter(torch.zeros(1, n_isoforms))
        if init_gamma is not None:
            inv_gamma = self._inverse_softplus(
                torch.as_tensor(init_gamma, dtype=torch.float32).reshape(1, -1)
            )
            with torch.no_grad():
                self.gamma_param.copy_(inv_gamma)

        # 3. Networks for Alpha/Beta (z -> parameter)
        self.alpha_fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_genes)
        )
        self.beta_fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_isoforms)
        )

        # Bias initialisation: tiny final-layer weights so the initial output
        # is dominated by the bias, and bias = inverse_softplus(init value).
        if init_alpha is not None:
            nn.init.xavier_normal_(self.alpha_fc[-1].weight, gain=0.01)
            inv_alpha = self._inverse_softplus(torch.as_tensor(init_alpha, dtype=torch.float32))
            with torch.no_grad():
                self.alpha_fc[-1].bias.copy_(inv_alpha)
        if init_beta_iso is not None:
            nn.init.xavier_normal_(self.beta_fc[-1].weight, gain=0.01)
            inv_beta = self._inverse_softplus(torch.as_tensor(init_beta_iso, dtype=torch.float32))
            with torch.no_grad():
                self.beta_fc[-1].bias.copy_(inv_beta)

    @staticmethod
    def _inverse_softplus(x, eps: float = 1e-6):
        """Return y with softplus(y) == clamp(x, min=eps), elementwise.

        Computed as ``x + log(1 - exp(-x))``, which equals ``log(exp(x) - 1)``
        but does not overflow for large ``x``; the clamp keeps the logarithm
        defined for zero or negative inputs.
        """
        x = torch.clamp(x, min=eps)
        return x + torch.log(-torch.expm1(-x))

    def forward(self, z_final, cell_indices):
        """
        z_final: [Batch, Latent] current latent state of each cell
        cell_indices: [Batch] indices used to look up each cell's time

        Returns a dict with "u_hat" [Batch, n_genes], "s_hat" [Batch, n_isoforms],
        "t" [Batch, 1], the last-step "alpha" / "beta", and "gamma" [1, n_isoforms].
        """
        batch_size = z_final.shape[0]
        # Follow the input's device instead of the constructor-time one, which
        # is stale if the module was moved afterwards.
        work_device = z_final.device

        # 1. Total elapsed time T of every cell in this batch
        T = F.softplus(self.cell_time_param[cell_indices]) # [Batch, 1]

        # 2. Virtual history: the cell is assumed to evolve from z=0 to
        #    z=z_final. Interpolation ratios are 0, 1/n, ..., (n-1)/n.
        steps_ratio = torch.linspace(
            0, 1 - 1/self.n_steps, self.n_steps,
            device=work_device, dtype=z_final.dtype,
        )

        # z_history[b, s, :] = z_final[b, :] * steps_ratio[s]  -> [Batch, Steps, Latent]
        z_history = torch.einsum('bl,s->bsl', z_final, steps_ratio)

        # 3. Evaluate alpha and beta for all history points in one pass:
        #    [Batch * Steps, Latent] -> [Batch * Steps, Features]
        flat_z = z_history.reshape(-1, z_final.shape[1])

        alpha_flat = F.softplus(self.alpha_fc(flat_z))
        beta_iso_flat = F.softplus(self.beta_fc(flat_z))

        # Back to [Batch, Steps, Features]
        alpha_seq = alpha_flat.reshape(batch_size, self.n_steps, self.n_genes)
        beta_iso_seq = beta_iso_flat.reshape(batch_size, self.n_steps, self.n_isoforms)

        # Gene-level beta = sum of the betas of the gene's isoforms
        beta_gene_seq = torch.einsum('bsi,gi->bsg', beta_iso_seq, self.g2i_mask)

        # Gamma (constant over time)
        gamma = F.softplus(self.gamma_param)

        # 4. Numerical integration (explicit Euler)
        dt = T / self.n_steps # [Batch, 1] duration of one step

        # u and s start at 0
        u = alpha_seq.new_zeros(batch_size, self.n_genes)
        s = alpha_seq.new_zeros(batch_size, self.n_isoforms)

        for i in range(self.n_steps):
            # Parameters at the current step
            alpha_t = alpha_seq[:, i, :]
            beta_gene_t = beta_gene_seq[:, i, :]
            beta_iso_t = beta_iso_seq[:, i, :]

            # dU = production - splicing
            du = alpha_t - beta_gene_t * u

            # dS = splicing - degradation; the source term of each isoform is
            # beta_iso * (unspliced abundance of its gene)
            u_expanded = torch.matmul(u, self.g2i_mask)
            ds = beta_iso_t * u_expanded - gamma * s

            # Euler update
            u = u + dt * du
            s = s + dt * ds

            # Keep abundances non-negative
            u = F.relu(u)
            s = F.relu(s)

        return {
            "u_hat": u,
            "s_hat": s,
            "t": T,
            "alpha": alpha_seq[:, -1, :], # alpha at the last step, for reference
            "beta": beta_iso_seq[:, -1, :],
            "gamma": gamma
        }

In [11]:
import torch
import numpy as np

# 1. Select the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on: {device}")

# 2. Define the dimensions (a small simulated problem)
n_cells = 20
n_genes = 10
n_isoforms = 15
latent_dim = 128
hidden_dim = 256

# 3. Build the required mock inputs
# (A) Gene-to-isoform mask (required)
# Build a [Genes, Isoforms] matrix in which every column (isoform) holds exactly one 1, marking the gene it belongs to
mask_np = np.zeros((n_genes, n_isoforms))
# For simplicity, assign each isoform to a randomly drawn gene
for i in range(n_isoforms):
    g_idx = np.random.randint(0, n_genes)
    mask_np[g_idx, i] = 1
g2i_mask = torch.from_numpy(mask_np)

# (B) Initialization values (random stand-ins for scVelo results)
# Shapes: time is (n_cells,), alpha is (n_genes,), beta/gamma are (n_isoforms,)
test_time = np.random.rand(n_cells) * 5.0 # simulated times in [0, 5)
test_alpha = np.random.rand(n_genes)
test_beta_iso = np.random.rand(n_isoforms)
test_gamma = np.random.rand(n_isoforms)

# (C) Simulated latent input z
test_z = torch.randn(n_cells, latent_dim).to(device) # [Batch, 128]

# (D) Simulated batch indices (for a full-data test this is simply range(n_cells))
# For a batch coming from a DataLoader these would be the indices of that batch
test_indices = torch.arange(n_cells).to(device)

# 4. Instantiate the decoder
# Note: the cell defining the IsoveloDecoder class must have been run before this one
decoder = IsoveloDecoder(
    n_cells=n_cells,
    n_genes=n_genes,
    n_isoforms=n_isoforms,
    g2i_mask=g2i_mask,       # <--- key new argument
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    n_steps=10,              # number of numerical-integration steps
    init_time=test_time,
    init_alpha=test_alpha,
    init_beta_iso=test_beta_iso,
    init_gamma=test_gamma,
    device=device
).to(device)

# 5. Run the forward-pass test
print("--- Starting Forward Pass ---")
try:
    # forward needs the indices to look up each cell's time
    outputs = decoder(test_z, test_indices)
    
    # 6. Inspect the outputs
    u_hat = outputs["u_hat"]
    s_hat = outputs["s_hat"]
    pred_time = outputs["t"]
    
    print("\n✅ Forward pass successful!")
    print(f"Reconstructed U shape: {u_hat.shape} (Expected: {n_cells}, {n_genes})")
    print(f"Reconstructed S shape: {s_hat.shape} (Expected: {n_cells}, {n_isoforms})")
    print(f"Inferred Time shape:   {pred_time.shape} (Expected: {n_cells}, 1)")
    
    # Check for NaNs (a common problem with numerical integration)
    if torch.isnan(u_hat).any() or torch.isnan(s_hat).any():
        print("⚠️ Warning: Output contains NaNs. Check initialization or learning rates.")
    else:
        print("数值检查通过: No NaNs detected.")
        
    # Look at a few actual values (they should be non-negative)
    print(f"\nExample U_hat (first cell):\n{u_hat[0, :5].detach().cpu().numpy()}")
    
except Exception as e:
    # Any failure is reported with its traceback instead of aborting the cell.
    print(f"\n❌ Error during forward pass: {e}")
    import traceback
    traceback.print_exc()

Running on: cuda
--- Starting Forward Pass ---

✅ Forward pass successful!
Reconstructed U shape: torch.Size([20, 10]) (Expected: 20, 10)
Reconstructed S shape: torch.Size([20, 15]) (Expected: 20, 15)
Inferred Time shape:   torch.Size([20, 1]) (Expected: 20, 1)
数值检查通过: No NaNs detected.

Example U_hat (first cell):
[0.98614293 0.6515614  0.70129627 1.1449945  0.72132283]


In [12]:
# Display the current `outputs` dict (the decoder forward-pass result) in full.
outputs

{'u_hat': tensor([[0.9861, 0.6516, 0.7013, 1.1450, 0.7213, 0.1793, 0.6613, 0.4312, 0.1813,
          1.5473],
         [0.6498, 0.4466, 0.5104, 0.7735, 0.5817, 0.1442, 0.5040, 0.4521, 0.1635,
          1.0480],
         [2.6094, 1.1398, 0.7589, 3.1038, 1.0989, 0.2259, 1.7275, 0.4090, 0.1932,
          3.4930],
         [1.2076, 0.7583, 0.6639, 1.5062, 0.9735, 0.2168, 0.7520, 0.4659, 0.1817,
          1.7877],
         [2.3933, 0.9522, 0.6311, 3.2229, 0.9823, 0.1757, 1.4936, 0.5275, 0.1734,
          3.6760],
         [1.1300, 0.6371, 0.7708, 1.1434, 0.6233, 0.2014, 0.6648, 0.3928, 0.1545,
          1.6125],
         [0.7795, 0.5358, 0.5507, 1.0127, 0.6285, 0.1635, 0.4810, 0.4568, 0.1901,
          1.2098],
         [1.6265, 1.0227, 0.6674, 2.0116, 0.8331, 0.2312, 1.0964, 0.5128, 0.1816,
          2.3464],
         [2.4238, 1.1571, 0.7659, 2.8902, 0.8844, 0.1796, 1.6040, 0.5028, 0.1894,
          3.5638],
         [2.2764, 1.2852, 0.6608, 2.6655, 1.0902, 0.1464, 1.7238, 0.4555, 0.1604,


In [98]:
# --- Isoform filter ----------------------------------------------------------
# An isoform is kept when (a) its count summed over all cells is >= 500 and
# (b) its proportion is >= 0.02 in at least 10 cells.
prop = adata.obsm['proportion']
iso_df = adata.obsm['isoform_counts']

prop_keep = prop >= 0.02
cells_passing_count = prop_keep.sum(axis=0)
keep_passing = cells_passing_count >= 10

iso_sums = iso_df.sum(axis=0)
keep_mask_count = iso_sums >= 500

keep_isoforms = keep_mask_count & keep_passing
filtered_iso_df = iso_df.loc[:, keep_isoforms]

# --- Isoform -> gene lookup --------------------------------------------------
# The gene name is everything before the last '_' of the isoform column name.
isoform_names = iso_df.columns
try:
    gene_map = pd.Series(
        [name.rsplit('_', 1)[0] for name in isoform_names],
        index=isoform_names,
    )
except Exception as e:
    raise ValueError(f"Error parsing isoform with names '_'. Error: {e}")

# --- Per-gene count of surviving isoforms, written to adata.var --------------
remaining_isoforms = gene_map[keep_isoforms.values]
new_counts = pd.Series(remaining_isoforms).value_counts(sort=False)

adata.var['filtered_n_isoforms'] = 0
genes_to_update = new_counts.index.intersection(adata.var_names)
adata.var.loc[genes_to_update, 'filtered_n_isoforms'] = new_counts[genes_to_update]

# --- Highly variable genes, computed on a normalised copy --------------------
adata_hvg = adata.copy()
sc.pp.normalize_total(adata_hvg)
sc.pp.log1p(adata_hvg)
sc.pp.highly_variable_genes(adata_hvg, n_top_genes=500, flavor='seurat')
hvg_genes = set(adata_hvg.var_names[adata_hvg.var['highly_variable']])

# --- Group the surviving isoform columns by gene -----------------------------
relevant_genes = gene_map[filtered_iso_df.columns].unique()

current_iso_cols = filtered_iso_df.columns
temp_gene_map = gene_map[current_iso_cols]

gene_to_iso_cols = {}
for iso, gene in temp_gene_map.items():
    gene_to_iso_cols.setdefault(gene, []).append(iso)

In [99]:
# Display the gene -> list of kept isoform column names mapping in full.
gene_to_iso_cols

{'gene0': ['gene0_isoform3'],
 'gene1': ['gene1_isoform2'],
 'gene2': ['gene2_isoform0', 'gene2_isoform1'],
 'gene3': ['gene3_isoform1'],
 'gene4': ['gene4_isoform4'],
 'gene5': ['gene5_isoform2', 'gene5_isoform3'],
 'gene6': ['gene6_isoform1'],
 'gene7': ['gene7_isoform0', 'gene7_isoform1', 'gene7_isoform2'],
 'gene8': ['gene8_isoform3'],
 'gene9': ['gene9_isoform1'],
 'gene10': ['gene10_isoform1'],
 'gene11': ['gene11_isoform2'],
 'gene12': ['gene12_isoform2'],
 'gene13': ['gene13_isoform3'],
 'gene14': ['gene14_isoform2'],
 'gene15': ['gene15_isoform3'],
 'gene16': ['gene16_isoform1'],
 'gene17': ['gene17_isoform0', 'gene17_isoform1'],
 'gene18': ['gene18_isoform2'],
 'gene19': ['gene19_isoform3'],
 'gene20': ['gene20_isoform0'],
 'gene21': ['gene21_isoform0', 'gene21_isoform1'],
 'gene22': ['gene22_isoform0'],
 'gene23': ['gene23_isoform0'],
 'gene24': ['gene24_isoform0', 'gene24_isoform1'],
 'gene25': ['gene25_isoform0'],
 'gene26': ['gene26_isoform1'],
 'gene27': ['gene27_isoform

In [None]:
# NOTE(review): `pickle.FALSE` is a protocol-0 pickle byte string, i.e. a
# non-empty bytes object, so it is truthy. It is used below as the default of
# the `normalized` parameter, which makes `if not normalized:` behave as if
# normalized=True. The builtin `False` is almost certainly what was intended -
# confirm and fix together with the function signature.
from pickle import FALSE
import scanpy as sc
import scvelo as scv
import numpy as np
import pandas as pd
import scipy.sparse as sp
from tqdm import tqdm

def preprocess_and_initialize_scvelo(
    adata, 
    isoform_key="isoform_counts", 
    proportion_key = "proportion",
    min_isoform_counts=10, 
    min_cell_counts = 10,
    min_isoform_prop=0.05, 
    n_top_genes=800,
    n_top_splicing = 500,
    min_cells_spanning = 5,
    isoform_delimiter="_",
    normalized = FALSE
):
    """
    Prefilter Isoforms and Genes.
    1. Filter cells, remove low total counts cells.
    2. Filter Isoforms: Remove low expression and low global proportion isoforms.
    3. Filter Genes, keep highly variable genes and isoform proportion variable genes.
    4. Run scVelo dynamical mode.
    5. Return initialization parameters for VAE.
    
    Parameters:
    adata: including layers['unspliced'] and obsm[isoform_key]
    isoform_key: key of isoform count in adata.obsm
    """

    if adata.X is None:
        adata.X = adata.layers['spliced'] + adata.layers['unspliced']

    # 1. Filter cells, remove low total counts cells.
    initial_cell_count = adata.n_obs
    sc.pp.filter_cells(adata, min_counts=min_cell_counts)
    print(f"Filtered cells from {initial_cell_count} to {adata.n_obs} (min_counts={min_cell_total_counts})")

    # 2. Filter Isoforms: Remove low expression and low global proportion isoforms.
    iso_df = adata.obsm[isoform_key]
    iso_df = iso_df.loc[adata.obs_names]
    isoform_names = iso_df.columns
    try:
        gene_map = pd.Series([x.rsplit(isoform_delimiter, 1)[0] for x in isoform_names], index=isoform_names)
    except Exception as e:
        raise ValueError(f"Error parsing isoform with names '{isoform_delimiter}'. Error: {e}")

    iso_sum = iso_df.sum(axis=0)
    keep_mask_count = iso_sum >= min_isoform_counts

    iso_prop = adata.obsm[proportion_key]
    high_prop_mask = iso_prop >= min_isoform_prop
    cells_passing_count = high_prop_mask.sum(axis=0)
    keep_mask_prop = cells_passing_count >= min_cells_spanning
    keep_isoforms = keep_mask_count & keep_mask_prop

    filtered_iso_df = iso_df.loc[:, keep_isoforms]
    adata.obsm[isoform_key] = filtered_iso_df
    filtered_prop_df = iso_prop.loc[:, keep_isoforms]
    adata.obsm[proportion_key] = filtered_prop_df

    print(f"Filtered isoforms from {iso_df.shape[1]} to {filtered_iso_df.shape[1]} based on isoform expression and global proportion.")

    remaining_isoforms = gene_map[keep_isoforms.values]
    new_counts = pd.Series(remaining_isoforms).value_counts(sort=False)
    adata.var['filtered_n_isoforms'] = 0
    genes_to_update = new_counts.index.intersection(adata.var_names)
    adata.var.loc[genes_to_update, 'filtered_n_isoforms'] = new_counts[genes_to_update]

    # 3. Filter Genes, keep highly variable genes and isoform proportion variable genes.
    adata_hvg = adata.copy()
    if not normalized:
        sc.pp.normalize_total(adata_hvg)
        sc.pp.log1p(adata_hvg)
        sc.pp.highly_variable_genes(adata_hvg, n_top_genes=n_top_genes, flavor='seurat')
        hvg_genes = set(adata_hvg.var_names[adata_hvg.var['highly_variable']])
    else:
        sc.pp.highly_variable_genes(adata_hvg, n_top_genes=n_top_genes, flavor='seurat')
        hvg_genes = set(adata_hvg.var_names[adata_hvg.var['highly_variable']])
    
    print("Detecting isoform proportion variance...")
    relevant_genes = gene_map[filtered_iso_df.columns].unique()
    high_var_iso_genes = set()

    gene_to_iso_cols = {}
    current_iso_cols = filtered_iso_df.columns
    temp_gene_map = gene_map[current_iso_cols]
    
    for iso, gene in temp_gene_map.items():
        if gene not in gene_to_iso_cols:
            gene_to_iso_cols[gene] = []
        gene_to_iso_cols[gene].append(iso)
    splicing_scores = {}

    for gene, isos in tqdm(gene_to_iso_cols.items(), desc="Processing Splicing Variance"):
        if len(isos) < 2:
            continue # Because we only care proportion for n_isoform>=2
        
        sub_df = filtered_iso_df[isos]
        mat = sub_df.values
        
        row_sums = mat.sum(axis=1, keepdims=True) + 1e-6 
        props = mat / row_sums # (Cells x Isoforms) 
    
        var_props = np.var(props, axis=0).mean()
        splicing_scores[gene] = var_props

    # 选取 Splicing Variance 最高的基因
    if splicing_scores:
        sorted_genes = sorted(splicing_scores, key=splicing_scores.get, reverse=True)
        high_var_iso_genes = set(sorted_genes[:n_top_splicing])


    
    








    # 3. Aggregate Isoforms to get Gene Spliced Counts.
    S_gene_mat = iso_df.values @ g2i_mask.T
    
    
    
    
    print("--- 1. Mapping Isoforms to Genes ---")
    # 假设 isoform 的名字格式为 "GeneName-IsoformID" 或者你有办法从 adata.uns 中获取
    # 这里我们假设 adata.obsm[isoform_key] 是一个 DataFrame，列名是 Isoform 名字
    # 如果是 numpy array，你需要额外传入 isoform_names
    
    if isinstance(adata.obsm[isoform_key], pd.DataFrame):
        isoform_df = adata.obsm[isoform_key]
    else:
        # 如果是 numpy/sparse，必须有对应的名字，这里假设存在 adata.uns['isoform_names']
        isoform_df = pd.DataFrame(
            adata.obsm[isoform_key].toarray() if sp.issparse(adata.obsm[isoform_key]) else adata.obsm[isoform_key],
            index=adata.obs_names,
            columns=adata.uns.get('isoform_names', [f"Iso_{i}" for i in range(adata.obsm[isoform_key].shape[1])])
        )

    # 建立映射: 假设 isoform 名字包含 gene 名字 (e.g., "TP53-001")
    # 如果你的数据结构不同，请修改这里的 mapping 逻辑
    # 比如: gene_map = {iso: gene for iso, gene in ...}
    # 这里为了通用性，我们假设 isoform name 能够 parse 出 gene name，或者我们先简单处理：
    # **重要**：实际项目中你需要一个确定的 Gene-Isoform 对应表。
    # 这里我们做一个假设示例：假设我们已经有了 g2i_mapping
    # 为了代码能跑，这里我需要你提供 mapping，或者我帮你生成一个 dummy 的
    # 实际使用请替换为: gene_of_isoform = parse_gene_from_isoform(iso_name)
    
    # 暂时跳过复杂的 name parsing，假设 adata.var_names 涵盖了所有基因
    # 我们需要构建一个 Mask [n_genes, n_isoforms]
    isoform_names = isoform_df.columns
    gene_names = adata.var_names
    
    # [用户需自定义] 这里是一个占位符，请替换为真实的映射逻辑
    # 比如: gene_name = iso_name.split('-')[0]
    # 下面代码假设 isoform_names 里包含 gene 信息
    isoform_to_gene = {} 
    for iso in isoform_names:
        # ⚠️ 请修改这里：根据你的数据格式提取 Gene Name
        # 示例：如果是 "GeneA_Iso1"，则 split('_')[0]
        # 这里尝试直接在 adata.var_names 里找匹配 (较慢，仅作演示)
        found = False
        for g in gene_names:
            if g in iso: # 这是一个很弱的匹配，请务必用精确逻辑替换
                isoform_to_gene[iso] = g
                found = True
                break
        if not found:
            isoform_to_gene[iso] = None # 标记为删除

    print("--- 2. Filtering Isoforms ---")
    # A. 表达量过滤
    total_iso_counts = isoform_df.sum(axis=0)
    valid_iso_mask = total_iso_counts > min_isoform_counts
    
    # B. 比例过滤 (Proportion)
    # 先计算 Gene total counts
    gene_spliced_sum = pd.DataFrame(0.0, index=adata.obs_names, columns=gene_names)
    
    # 加速聚合 (实际应使用矩阵乘法，这里用 pandas 演示逻辑)
    # 构建矩阵映射
    g2i_mat = np.zeros((len(gene_names), len(isoform_names)))
    for i, iso in enumerate(isoform_names):
        g_name = isoform_to_gene.get(iso)
        if g_name is not None:
            g_idx = adata.var_names.get_loc(g_name)
            g2i_mat[g_idx, i] = 1
            
    # 计算 Gene Spliced Counts (用于 scvelo)
    # S_gene = I_iso @ Mask.T
    I_mat = isoform_df.values
    S_gene_mat = I_mat @ g2i_mat.T
    
    # 计算每个 isoform 的全局平均 proportion
    # Prop_i = Sum(I_i) / Sum(S_gene_of_i)
    avg_proportions = []
    isoforms_to_keep = []
    
    for i, iso in enumerate(isoform_names):
        if not valid_iso_mask[i]: continue
        
        g_name = isoform_to_gene.get(iso)
        if g_name is None: continue
        
        g_idx = adata.var_names.get_loc(g_name)
        
        # 计算比例
        iso_sum = I_mat[:, i].sum()
        gene_sum = S_gene_mat[:, g_idx].sum()
        
        if gene_sum == 0: prop = 0
        else: prop = iso_sum / gene_sum
        
        if prop > min_isoform_prop:
            isoforms_to_keep.append(iso)
            avg_proportions.append(prop)
            
    print(f"Kept {len(isoforms_to_keep)} / {len(isoform_names)} isoforms.")
    
    # 更新 adata.obsm
    adata.obsm[isoform_key] = isoform_df[isoforms_to_keep]
    # 更新 Mask (G x I_new)
    # 重建 g2i_mask 对应新的 isoforms
    final_isoforms = isoforms_to_keep
    final_g2i_mask = np.zeros((len(gene_names), len(final_isoforms)))
    
    isoform_gene_map_final = [isoform_to_gene[iso] for iso in final_isoforms]
    
    for i, g_name in enumerate(isoform_gene_map_final):
        g_idx = adata.var_names.get_loc(g_name)
        final_g2i_mask[g_idx, i] = 1
        
    final_g2i_mask_tensor = torch.from_numpy(final_g2i_mask).float()

    print("--- 3. Constructing Gene-Level Data for scVelo ---")
    # scVelo 需要 layers['spliced'] 和 layers['unspliced']
    # Spliced 来自 Isoform 的聚合
    S_gene_final = adata.obsm[isoform_key].values @ final_g2i_mask.T
    adata.layers['spliced'] = sp.csr_matrix(S_gene_final)
    
    # 确保 unspliced 也是 sparse
    if not sp.issparse(adata.layers['unspliced']):
        adata.layers['unspliced'] = sp.csr_matrix(adata.layers['unspliced'])

    print("--- 4. Filtering Genes (Expression Var OR Isoform Var) ---")
    # 1. 常规 HVG (基于 Count)
    scv.pp.filter_genes(adata, min_shared_counts=20)
    scv.pp.normalize_per_cell(adata)
    scv.pp.filter_genes_dispersion(adata, n_top_genes=n_top_genes)
    hvg_genes = set(adata.var_names[adata.var['highly_variable']])
    
    # 2. Isoform Switch Genes (基于 Proportion Variance)
    # 计算每个细胞中，每个 Gene 下 Isoform 分布的熵或方差，这里用简单的 proportion 方差
    # 如果一个基因的 Isoform 比例在不同细胞间变化大，则保留该基因
    # 为了简化，我们只保留那些已经在 isoform 过滤步骤中留下了至少 2 个 isoform 的基因
    iso_counts_per_gene = final_g2i_mask.sum(axis=1) # [n_genes]
    multi_iso_genes = adata.var_names[iso_counts_per_gene > 1]
    
    # 合并保留列表
    genes_to_keep = list(hvg_genes.union(set(multi_iso_genes)))
    adata = adata[:, genes_to_keep].copy()
    
    # 同时需要切分 g2i_mask 对应剩下的 genes
    # 这是一个痛点：adata 切分后 var_names 变了，mask 也要变
    # 重新构建 mask (快速版)
    # 更好的方法是在 Dataset init 里做，但这里需要给 scvelo 跑
    
    print(f"Final Genes: {adata.n_vars}, Final Isoforms: {len(final_isoforms)}")

    print("--- 5. Running scVelo (Dynamical Mode) ---")
    scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
    
    # 核心：使用 dynamical model 恢复参数
    scv.tl.recover_dynamics(adata, var_names='all', n_jobs=8)
    
    # 计算 velocity 和 latent time
    scv.tl.velocity(adata, mode='dynamical')
    scv.tl.latent_time(adata)
    
    print("--- 6. Extracting Parameters for Initialization ---")
    # 提取参数
    # scVelo 的 fit_alpha, fit_beta, fit_gamma 是 Gene level 的
    # 它们可能包含 NaN (如果 fitting 失败)，需要填充
    
    def get_param(key, default=1.0):
        """Fetch one per-gene column of ``adata.var`` as a float32 array.

        Args:
            key: Column name in ``adata.var`` (called below with
                ``'fit_alpha'``, ``'fit_beta'`` and ``'fit_gamma'``).
            default: Value substituted for NaN entries, which may be present
                when a gene's fit did not succeed.

        Returns:
            A float32 NumPy array with every NaN replaced by ``default``.
        """
        # np.nan_to_num returns a new array, so adata.var is left untouched.
        # With its default settings it also maps +/-inf to large finite values.
        filled = np.nan_to_num(adata.var[key].values, nan=default)
        return filled.astype(np.float32)

    init_alpha_g = get_param('fit_alpha')
    init_beta_g = get_param('fit_beta')
    init_gamma_g = get_param('fit_gamma')
    
    init_time = adata.obs['latent_time'].values.astype(np.float32)
    # scVelo latent time 是 [0, 1]，可能需要 rescale 到 [0, 20] 左右以配合 ODE
    init_time = init_time * 20.0 

    # 计算 Isoform Level 的 Beta 和 Gamma
    # 逻辑：Beta_iso = Beta_gene * avg_proportion
    # 逻辑：Gamma_iso = Gamma_gene (假设降解率近似) 或者也是 * proportion?
    # 通常降解是分子特有的。但如果没有 isoform specific 数据，假设 Gamma_iso ≈ Gamma_gene 是合理的起点，
    # 或者假设 Beta_iso * U = S_iso * Gamma_iso (稳态假设) -> Gamma_iso = Beta_iso * U / S_iso
    # 既然你要求 "gamma同理" (isoform proportion * gene gamma)，我们按你的要求做：
    
    # 我们需要知道当前 adata 的 genes 对应的 isoforms 的全局比例
    # 注意：此时 adata 已经 filter 过了，mask 也要对应切片
    
    # 1. 对齐 Mask 到当前的 adata.var_names
    current_genes = adata.var_names
    # 重新映射 final_isoforms 到 current_genes
    subset_mask = np.zeros((len(current_genes), len(final_isoforms)))
    for i, iso in enumerate(final_isoforms):
        g_name = isoform_gene_map_final[i]
        if g_name in current_genes:
            g_idx = adata.var_names.get_loc(g_name)
            subset_mask[g_idx, i] = 1
    
    subset_mask_tensor = torch.from_numpy(subset_mask).float()
    
    # 2. 计算平均 Proportion (Global)
    # I: [Cells, Isoforms], S_g: [Cells, Genes]
    # 这是一个全局常数向量 [n_isoforms]
    I_val = adata.obsm[isoform_key].values # 注意：这里要是 subset 后的 isoforms
    # 上面代码没有 inplace 更新 adata.obsm 的列，这里需要注意
    # 修正：上面 adata.obsm[isoform_key] 已经是 filtered isoforms 了
    
    total_I = I_val.sum(axis=0) # [n_isoforms]
    
    # 映射 Gene Sum 到 Isoform 维度
    # total_G[j] = Sum of gene counts for gene corresponding to isoform j
    gene_counts = (I_val @ subset_mask.T).sum(axis=0) # [n_genes]
    
    # 扩展 gene_counts 到 isoform 维度
    # iso_gene_total = gene_counts[gene_index_of_isoform]
    # 利用矩阵乘法: [n_genes] @ [n_genes, n_isoforms] -> [n_isoforms]
    # 但 mask 是 0/1，可以直接乘
    iso_gene_total = gene_counts @ subset_mask
    
    iso_proportions = np.divide(total_I, iso_gene_total, out=np.zeros_like(total_I), where=iso_gene_total!=0)
    
    # 3. 计算 Init Values
    # init_beta_g: [n_genes]
    # 扩展到 isoform: [n_isoforms]
    beta_g_expanded = init_beta_g @ subset_mask
    gamma_g_expanded = init_gamma_g @ subset_mask
    
    init_beta_iso = beta_g_expanded * iso_proportions
    init_gamma_iso = gamma_g_expanded * iso_proportions # 按你要求的逻辑
    
    return {
        "adata": adata,
        "g2i_mask": subset_mask_tensor,
        "init_time": init_time,
        "init_alpha": init_alpha_g,
        "init_beta_iso": init_beta_iso,
        "init_gamma": init_gamma_iso,
        "final_isoforms": final_isoforms
    }

In [14]:
# Assumes `adata` has already been loaded and that adata.obsm['isoform_counts'] exists.
# Important: the isoform -> gene mapping must match your data; adapt the
# `isoform_to_gene` logic inside preprocess_and_initialize_scvelo if it does not.
#
# NOTE(review): the visible part of preprocess_and_initialize_scvelo assigns to
# adata.layers[...] and runs scVelo's in-place gene filters on the object it is
# working with. Unless the function copies its input first (not confirmed here),
# that would shrink the module-level `adata` and make this cell non-idempotent,
# so a copy is passed to keep the loaded data intact across re-runs.
results = preprocess_and_initialize_scvelo(
    adata.copy(),
    isoform_key="isoform_counts",
    min_isoform_counts=20,
    min_isoform_prop=0.05
)

processed_adata = results['adata']
g2i_mask = results['g2i_mask']

# Sanity check: the gene-to-isoform mask is built as [n_genes, n_isoforms] over
# the returned (filtered) genes and the retained isoforms.
assert tuple(g2i_mask.shape) == (processed_adata.n_vars, len(results['final_isoforms'])), \
    "g2i_mask shape does not match the filtered genes / isoforms"

# Instantiate the VAE decoder, initialised from the scVelo-derived parameters.
decoder = IsoveloDecoder(
    n_cells=processed_adata.n_obs,
    n_genes=processed_adata.n_vars,
    n_isoforms=len(results['final_isoforms']),
    g2i_mask=g2i_mask,
    init_time=results['init_time'],
    init_alpha=results['init_alpha'],
    init_beta_iso=results['init_beta_iso'],
    init_gamma=results['init_gamma'],
    device=device
).to(device)

--- 1. Mapping Isoforms to Genes ---
--- 2. Filtering Isoforms ---
Kept 1 / 3468 isoforms.
--- 3. Constructing Gene-Level Data for scVelo ---
--- 4. Filtering Genes (Expression Var OR Isoform Var) ---
Filtered out 999 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Skip filtering by dispersion since number of variables are less than `n_top_genes`.


KeyError: 'highly_variable'