## MMD-MA

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import rbf_kernel
import umap
import os


class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy Loss"""
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """计算高斯核矩阵"""
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        
        L2_distance = ((total0-total1)**2).sum(2)
        
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
        
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
        
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        batch_size = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, 
                                     kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        
        loss = torch.mean(XX + YY - XY - YX)
        return loss


class FeatureEncoder(nn.Module):
    """特征编码器网络"""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.2):
        super(FeatureEncoder, self).__init__()
        
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        
        self.encoder = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.encoder(x)


class MMDDAIntegrator:
    """MMD-DA多模态数据整合器"""
    
    def __init__(self, gex_dim, morpho_dim, latent_dim=128, hidden_dims=None, 
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.latent_dim = latent_dim
        
        if hidden_dims is None:
            hidden_dims = [256, 128]
        
        # 创建编码器
        self.gex_encoder = FeatureEncoder(gex_dim, hidden_dims, latent_dim).to(device)
        self.morpho_encoder = FeatureEncoder(morpho_dim, hidden_dims, latent_dim).to(device)
        
        # MMD损失函数
        self.mmd_loss = MMDLoss().to(device)
        
        # 优化器
        self.optimizer = optim.Adam(
            list(self.gex_encoder.parameters()) + list(self.morpho_encoder.parameters()),
            lr=0.001, weight_decay=1e-5
        )
        
        # 标准化器
        self.gex_scaler = StandardScaler()
        self.morpho_scaler = StandardScaler()
        
    def train(self, gex_data, morpho_data, epochs=100, batch_size=128, lambda_mmd=1.0, 
              verbose=True, reconstruction_loss=True):
        """训练MMD-DA模型"""
        
        # 数据标准化
        gex_data_norm = self.gex_scaler.fit_transform(gex_data)
        morpho_data_norm = self.morpho_scaler.fit_transform(morpho_data)
        
        # 转换为张量
        gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
        morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
        
        # 创建数据加载器
        gex_dataset = TensorDataset(gex_tensor)
        morpho_dataset = TensorDataset(morpho_tensor)
        
        gex_loader = DataLoader(gex_dataset, batch_size=batch_size, shuffle=True)
        morpho_loader = DataLoader(morpho_dataset, batch_size=batch_size, shuffle=True)
        
        # 训练循环
        losses = []
        
        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0
            
            # 获取数据迭代器
            gex_iter = iter(gex_loader)
            morpho_iter = iter(morpho_loader)
            
            while True:
                try:
                    gex_batch = next(gex_iter)[0]
                except StopIteration:
                    gex_iter = iter(gex_loader)
                    gex_batch = next(gex_iter)[0]
                
                try:
                    morpho_batch = next(morpho_iter)[0]
                except StopIteration:
                    break
                
                # 确保批次大小一致
                min_size = min(gex_batch.size(0), morpho_batch.size(0))
                gex_batch = gex_batch[:min_size]
                morpho_batch = morpho_batch[:min_size]
                
                # 前向传播
                gex_encoded = self.gex_encoder(gex_batch)
                morpho_encoded = self.morpho_encoder(morpho_batch)
                
                # 计算MMD损失
                mmd_loss = self.mmd_loss(gex_encoded, morpho_encoded)
                
                total_loss = lambda_mmd * mmd_loss
                
                # 反向传播
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()
                
                epoch_loss += total_loss.item()
                num_batches += 1
            
            avg_loss = epoch_loss / num_batches
            losses.append(avg_loss)
            
            if verbose and (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.6f}')
        
        return losses
    
    def transform(self, gex_data, morpho_data):
        """将数据转换到共同的潜在空间"""
        self.gex_encoder.eval()
        self.morpho_encoder.eval()
        
        with torch.no_grad():
            # 标准化
            gex_data_norm = self.gex_scaler.transform(gex_data)
            morpho_data_norm = self.morpho_scaler.transform(morpho_data)
            
            # 转换为张量
            gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
            morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
            
            # 编码
            gex_encoded = self.gex_encoder(gex_tensor).cpu().numpy()
            morpho_encoded = self.morpho_encoder(morpho_tensor).cpu().numpy()
            
        return gex_encoded, morpho_encoded


def calculate_celltype_accuracy(gex_encoded, morpho_encoded, rna_family_labels, k=1):
    """计算细胞类型匹配准确率"""
    
    # 计算从Morpho到GEX的匹配率
    nbrs_gex = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(gex_encoded)
    distances_m2g, indices_m2g = nbrs_gex.kneighbors(morpho_encoded)
    
    # 获取最近邻的标签（排除自身）
    nearest_gex_labels = rna_family_labels[indices_m2g[:, 1:]]  # 排除第一个（自身）
    morpho_labels = rna_family_labels
    
    # 计算匹配数量
    matches_m2g = 0
    for i in range(len(morpho_labels)):
        if morpho_labels[i] in nearest_gex_labels[i]:
            matches_m2g += 1
    
    morpho_to_gex_accuracy = matches_m2g / len(morpho_labels)
    
    # 计算从GEX到Morpho的匹配率
    nbrs_morpho = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(morpho_encoded)
    distances_g2m, indices_g2m = nbrs_morpho.kneighbors(gex_encoded)
    
    nearest_morpho_labels = rna_family_labels[indices_g2m[:, 1:]]  # 排除第一个（自身）
    gex_labels = rna_family_labels
    
    matches_g2m = 0
    for i in range(len(gex_labels)):
        if gex_labels[i] in nearest_morpho_labels[i]:
            matches_g2m += 1
    
    gex_to_morpho_accuracy = matches_g2m / len(gex_labels)
    
    # 计算平均准确率
    average_accuracy = (morpho_to_gex_accuracy + gex_to_morpho_accuracy) / 2
    
    return morpho_to_gex_accuracy, gex_to_morpho_accuracy, average_accuracy


def main():
    """主函数：完整的MMD-DA整合流程"""
    
    # 数据路径
    gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
    morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
    rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
    
    # 输出路径
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/MMDDA/"
    os.makedirs(output_dir, exist_ok=True)
    
    print("Loading data...")
    
    # 加载基因表达数据
    gex_df = pd.read_csv(gene_expression_path, header=None)
    gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载形态学数据
    morpho_df = pd.read_csv(morphology_path, header=0)
    morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载RNA family标签
    try:
        rna_df = pd.read_csv(rna_family_path, header=0)
        if rna_df.shape[1] == 1:
            rna_family_labels = rna_df.iloc[:, 0].values
        else:
            rna_family_labels = rna_df.iloc[:, 1].values
        
        # 确保所有数据长度一致
        min_samples = min(len(gex_data), len(morpho_data), len(rna_family_labels))
        gex_data = gex_data[:min_samples]
        morpho_data = morpho_data[:min_samples]
        rna_family_labels = rna_family_labels[:min_samples]
        
        print(f"Data loaded successfully:")
        print(f"  - GEX data shape: {gex_data.shape}")
        print(f"  - Morpho data shape: {morpho_data.shape}")
        print(f"  - RNA family labels: {len(np.unique(rna_family_labels))} unique types")
        
    except Exception as e:
        print(f"Error loading RNA family labels: {e}")
        return
    
    # 创建并训练MMD-DA模型
    print("\nInitializing MMD-DA model...")
    integrator = MMDDAIntegrator(
        gex_dim=gex_data.shape[1],
        morpho_dim=morpho_data.shape[1],
        latent_dim=128,
        hidden_dims=[256, 128]
    )
    
    print("Training MMD-DA model...")
    losses = integrator.train(
        gex_data, morpho_data, 
        epochs=200, 
        batch_size=64, 
        lambda_mmd=1.0,
        verbose=True
    )
    
    # 转换数据到共同空间
    print("\nTransforming data to common latent space...")
    gex_encoded, morpho_encoded = integrator.transform(gex_data, morpho_data)
    
    print(f"Encoded data shapes:")
    print(f"  - GEX encoded: {gex_encoded.shape}")
    print(f"  - Morpho encoded: {morpho_encoded.shape}")
    
    # 计算细胞类型匹配准确率
    print("\nCalculating celltype accuracy...")
    morpho_to_gex_acc, gex_to_morpho_acc, avg_acc = calculate_celltype_accuracy(
        gex_encoded, morpho_encoded, rna_family_labels
    )
    
    # 输出结果
    print(f"\n" + "="*50)
    print(f"CELLTYPE ACCURACY RESULTS:")
    print(f"="*50)
    print(f"Morpho to GEX accuracy: {morpho_to_gex_acc:.4f}")
    print(f"GEX to Morpho accuracy: {gex_to_morpho_acc:.4f}")
    print(f"Average accuracy: {avg_acc:.4f}")
    print(f"="*50)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(gex_encoded)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(morpho_encoded)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")
    
    return {
        'morpho_to_gex_accuracy': morpho_to_gex_acc,
        'gex_to_morpho_accuracy': gex_to_morpho_acc,
        'average_accuracy': avg_acc,
        'gex_encoded': gex_encoded,
        'morpho_encoded': morpho_encoded,
        'integrator': integrator
    }


if __name__ == "__main__":
    results = main()

  from .autonotebook import tqdm as notebook_tqdm


Loading data...
Data loaded successfully:
  - GEX data shape: (645, 2000)
  - Morpho data shape: (645, 645)
  - RNA family labels: 10 unique types

Initializing MMD-DA model...
Training MMD-DA model...
Epoch [10/200], Loss: 0.149540
Epoch [20/200], Loss: 0.119339
Epoch [30/200], Loss: 0.124792
Epoch [40/200], Loss: 0.099399
Epoch [50/200], Loss: 0.115688
Epoch [60/200], Loss: 0.072340
Epoch [70/200], Loss: 0.091670
Epoch [80/200], Loss: 0.088127
Epoch [90/200], Loss: 0.078209
Epoch [100/200], Loss: 0.091072
Epoch [110/200], Loss: 0.100726
Epoch [120/200], Loss: 0.100471
Epoch [130/200], Loss: 0.085039
Epoch [140/200], Loss: 0.074465
Epoch [150/200], Loss: 0.057716
Epoch [160/200], Loss: 0.075180
Epoch [170/200], Loss: 0.063125
Epoch [180/200], Loss: 0.062496
Epoch [190/200], Loss: 0.078186
Epoch [200/200], Loss: 0.068741

Transforming data to common latent space...
Encoded data shapes:
  - GEX encoded: (645, 128)
  - Morpho encoded: (645, 128)

Calculating celltype accuracy...

CELLTYP

  warn(


  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/MMDDA/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/MMDDA/morpho_umap.csv

Done!


## Cycel GAN

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score
import umap
from tqdm import tqdm
import os

class Generator(nn.Module):
    """Generator network for CycleGAN"""
    def __init__(self, input_dim, output_dim, hidden_dim=512):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.network(x)

class Discriminator(nn.Module):
    """Discriminator network for CycleGAN"""
    def __init__(self, input_dim, hidden_dim=512):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 4, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.network(x)

class CycleGANIntegrator:
    def __init__(self, gex_dim=2000, morpho_dim=645, latent_dim=256, device='cuda'):
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # 初始化生成器和判别器
        self.G_gex2morpho = Generator(gex_dim, latent_dim).to(self.device)
        self.G_morpho2gex = Generator(morpho_dim, latent_dim).to(self.device)
        self.D_gex = Discriminator(latent_dim).to(self.device)
        self.D_morpho = Discriminator(latent_dim).to(self.device)
        
        # 重建生成器（用于cycle consistency）
        self.G_gex_recon = Generator(latent_dim, gex_dim).to(self.device)
        self.G_morpho_recon = Generator(latent_dim, morpho_dim).to(self.device)
        
        # 损失函数
        self.adversarial_loss = nn.BCELoss()
        self.cycle_loss = nn.L1Loss()
        self.identity_loss = nn.L1Loss()
        
        # 优化器
        self.optimizer_G = optim.Adam(
            list(self.G_gex2morpho.parameters()) + 
            list(self.G_morpho2gex.parameters()) +
            list(self.G_gex_recon.parameters()) + 
            list(self.G_morpho_recon.parameters()),
            lr=0.0002, betas=(0.5, 0.999)
        )
        self.optimizer_D = optim.Adam(
            list(self.D_gex.parameters()) + list(self.D_morpho.parameters()),
            lr=0.0002, betas=(0.5, 0.999)
        )
        
    def load_data(self):
        """加载数据"""
        # 基因表达数据
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # 形态学数据
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family标签
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            print(f"RNA family labels loaded: {len(np.unique(self.rna_family_labels))} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
        
        # 数据标准化
        self.gex_data = (self.gex_data - np.mean(self.gex_data, axis=0)) / (np.std(self.gex_data, axis=0) + 1e-8)
        self.morpho_data = (self.morpho_data - np.mean(self.morpho_data, axis=0)) / (np.std(self.morpho_data, axis=0) + 1e-8)
        
        # 转换为tensor
        self.gex_tensor = torch.FloatTensor(self.gex_data).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data).to(self.device)
        
    def train_epoch(self, epoch):
        """训练一个epoch"""
        batch_size = 64
        n_batches = len(self.gex_data) // batch_size
        
        total_g_loss = 0
        total_d_loss = 0
        
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(self.gex_data))
            
            real_gex = self.gex_tensor[start_idx:end_idx]
            real_morpho = self.morpho_tensor[start_idx:end_idx]
            
            batch_size_actual = real_gex.size(0)
            valid = torch.ones(batch_size_actual, 1).to(self.device)
            fake = torch.zeros(batch_size_actual, 1).to(self.device)
            
            # =======================
            # 训练生成器
            # =======================
            self.optimizer_G.zero_grad()
            
            # GEX -> 潜在空间 -> Morpho -> 潜在空间 -> GEX (cycle)
            gex_latent = self.G_gex2morpho(real_gex)
            morpho_recon = self.G_morpho_recon(gex_latent)
            gex_cycle = self.G_gex_recon(self.G_morpho2gex(morpho_recon))
            
            # Morpho -> 潜在空间 -> GEX -> 潜在空间 -> Morpho (cycle)
            morpho_latent = self.G_morpho2gex(real_morpho)
            gex_recon = self.G_gex_recon(morpho_latent)
            morpho_cycle = self.G_morpho_recon(self.G_gex2morpho(gex_recon))
            
            # 对抗损失
            d_gex_fake = self.D_gex(morpho_latent)
            d_morpho_fake = self.D_morpho(gex_latent)
            
            loss_GAN_gex = self.adversarial_loss(d_gex_fake, valid)
            loss_GAN_morpho = self.adversarial_loss(d_morpho_fake, valid)
            
            # Cycle consistency损失
            loss_cycle_gex = self.cycle_loss(gex_cycle, real_gex)
            loss_cycle_morpho = self.cycle_loss(morpho_cycle, real_morpho)
            
            # Identity损失（可选，帮助保持数据特性）
            identity_gex = self.G_gex_recon(self.G_gex2morpho(real_gex))
            identity_morpho = self.G_morpho_recon(self.G_morpho2gex(real_morpho))
            loss_identity_gex = self.identity_loss(identity_gex, real_gex)
            loss_identity_morpho = self.identity_loss(identity_morpho, real_morpho)
            
            # 总生成器损失
            loss_G = (loss_GAN_gex + loss_GAN_morpho + 
                     10 * (loss_cycle_gex + loss_cycle_morpho) + 
                     0.5 * (loss_identity_gex + loss_identity_morpho))
            
            loss_G.backward()
            self.optimizer_G.step()
            
            # =======================
            # 训练判别器
            # =======================
            self.optimizer_D.zero_grad()
            
            # 真实数据
            d_gex_real = self.D_gex(self.G_gex2morpho(real_gex).detach())
            d_morpho_real = self.D_morpho(self.G_morpho2gex(real_morpho).detach())
            
            loss_D_real_gex = self.adversarial_loss(d_gex_real, valid)
            loss_D_real_morpho = self.adversarial_loss(d_morpho_real, valid)
            
            # 假数据
            d_gex_fake = self.D_gex(morpho_latent.detach())
            d_morpho_fake = self.D_morpho(gex_latent.detach())
            
            loss_D_fake_gex = self.adversarial_loss(d_gex_fake, fake)
            loss_D_fake_morpho = self.adversarial_loss(d_morpho_fake, fake)
            
            # 总判别器损失
            loss_D = ((loss_D_real_gex + loss_D_fake_gex) + 
                     (loss_D_real_morpho + loss_D_fake_morpho)) / 2
            
            loss_D.backward()
            self.optimizer_D.step()
            
            total_g_loss += loss_G.item()
            total_d_loss += loss_D.item()
        
        return total_g_loss / n_batches, total_d_loss / n_batches
    
    def get_integrated_embeddings(self):
        """获取整合后的嵌入"""
        self.G_gex2morpho.eval()
        self.G_morpho2gex.eval()
        
        with torch.no_grad():
            gex_embeddings = self.G_gex2morpho(self.gex_tensor).cpu().numpy()
            morpho_embeddings = self.G_morpho2gex(self.morpho_tensor).cpu().numpy()
        
        return gex_embeddings, morpho_embeddings
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """计算细胞类型匹配准确率"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        # 使用最近邻搜索
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        # GEX到Morpho的匹配率
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        # Morpho到GEX的匹配率
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        # 平均匹配率
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, epochs=200):
        """训练CycleGAN模型"""
        print("Starting CycleGAN training...")
        
        for epoch in tqdm(range(epochs), desc="Training"):
            g_loss, d_loss = self.train_epoch(epoch)
            
            # 每10个epoch打印一次
            if (epoch + 1) % 10 == 0:
                gex_emb, morpho_emb = self.get_integrated_embeddings()
                gex2morpho_acc, morpho2gex_acc, avg_acc = self.calculate_celltype_accuracy(gex_emb, morpho_emb)
                
                print(f"Epoch {epoch+1}/{epochs}")
                print(f"  G Loss: {g_loss:.4f}, D Loss: {d_loss:.4f}")
                if avg_acc is not None:
                    print(f"  GEX->Morpho Acc: {gex2morpho_acc:.4f}")
                    print(f"  Morpho->GEX Acc: {morpho2gex_acc:.4f}")
                    print(f"  Average Acc: {avg_acc:.4f}")
                print("-" * 50)

# 使用示例
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/CycleGAN/"
    os.makedirs(output_dir, exist_ok=True)
    
    # 初始化模型
    integrator = CycleGANIntegrator(gex_dim=2000, morpho_dim=645, latent_dim=256)
    
    # 加载数据
    integrator.load_data()
    
    # 训练模型
    integrator.train(epochs=200)
    
    # 获取最终的整合嵌入
    print("\nGetting latent space embeddings...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # 计算最终准确率
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL RESULTS:")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
Starting CycleGAN training...


Training:   5%|▌         | 10/200 [00:51<17:55,  5.66s/it]

Epoch 10/200
  G Loss: 12.3280, D Loss: 1.2921
  GEX->Morpho Acc: 0.1907
  Morpho->GEX Acc: 0.0899
  Average Acc: 0.1403
--------------------------------------------------


Training:  10%|█         | 20/200 [01:29<12:09,  4.05s/it]

Epoch 20/200
  G Loss: 11.5809, D Loss: 1.1592
  GEX->Morpho Acc: 0.1752
  Morpho->GEX Acc: 0.1891
  Average Acc: 0.1822
--------------------------------------------------


Training:  15%|█▌        | 30/200 [02:09<12:28,  4.40s/it]

Epoch 30/200
  G Loss: 12.1029, D Loss: 1.0447
  GEX->Morpho Acc: 0.1597
  Morpho->GEX Acc: 0.1054
  Average Acc: 0.1326
--------------------------------------------------


Training:  20%|██        | 40/200 [03:06<12:48,  4.80s/it]

Epoch 40/200
  G Loss: 12.4366, D Loss: 1.2524
  GEX->Morpho Acc: 0.1674
  Morpho->GEX Acc: 0.1070
  Average Acc: 0.1372
--------------------------------------------------


Training:  25%|██▌       | 50/200 [03:50<09:34,  3.83s/it]

Epoch 50/200
  G Loss: 11.1682, D Loss: 1.5045
  GEX->Morpho Acc: 0.1690
  Morpho->GEX Acc: 0.1054
  Average Acc: 0.1372
--------------------------------------------------


Training:  30%|███       | 60/200 [04:32<09:50,  4.22s/it]

Epoch 60/200
  G Loss: 11.2318, D Loss: 0.9618
  GEX->Morpho Acc: 0.1612
  Morpho->GEX Acc: 0.1814
  Average Acc: 0.1713
--------------------------------------------------


Training:  35%|███▌      | 70/200 [05:37<18:08,  8.38s/it]

Epoch 70/200
  G Loss: 11.0012, D Loss: 1.1030
  GEX->Morpho Acc: 0.1271
  Morpho->GEX Acc: 0.1147
  Average Acc: 0.1209
--------------------------------------------------


Training:  40%|████      | 80/200 [06:17<08:05,  4.05s/it]

Epoch 80/200
  G Loss: 10.9131, D Loss: 1.0765
  GEX->Morpho Acc: 0.1349
  Morpho->GEX Acc: 0.1023
  Average Acc: 0.1186
--------------------------------------------------


Training:  45%|████▌     | 90/200 [07:03<09:20,  5.10s/it]

Epoch 90/200
  G Loss: 11.0908, D Loss: 1.2711
  GEX->Morpho Acc: 0.2357
  Morpho->GEX Acc: 0.1349
  Average Acc: 0.1853
--------------------------------------------------


Training:  50%|█████     | 100/200 [07:48<06:44,  4.04s/it]

Epoch 100/200
  G Loss: 11.7458, D Loss: 0.8237
  GEX->Morpho Acc: 0.1442
  Morpho->GEX Acc: 0.1178
  Average Acc: 0.1310
--------------------------------------------------


Training:  55%|█████▌    | 110/200 [08:30<05:51,  3.90s/it]

Epoch 110/200
  G Loss: 11.6305, D Loss: 1.1920
  GEX->Morpho Acc: 0.1628
  Morpho->GEX Acc: 0.1535
  Average Acc: 0.1581
--------------------------------------------------


Training:  60%|██████    | 120/200 [09:16<06:53,  5.17s/it]

Epoch 120/200
  G Loss: 10.9671, D Loss: 1.1249
  GEX->Morpho Acc: 0.2202
  Morpho->GEX Acc: 0.1008
  Average Acc: 0.1605
--------------------------------------------------


Training:  65%|██████▌   | 130/200 [09:58<04:35,  3.94s/it]

Epoch 130/200
  G Loss: 12.1283, D Loss: 1.0639
  GEX->Morpho Acc: 0.1287
  Morpho->GEX Acc: 0.1860
  Average Acc: 0.1574
--------------------------------------------------


Training:  70%|███████   | 140/200 [10:37<04:08,  4.15s/it]

Epoch 140/200
  G Loss: 11.5284, D Loss: 0.9027
  GEX->Morpho Acc: 0.1628
  Morpho->GEX Acc: 0.1054
  Average Acc: 0.1341
--------------------------------------------------


Training:  75%|███████▌  | 150/200 [11:22<03:23,  4.06s/it]

Epoch 150/200
  G Loss: 13.1879, D Loss: 0.8193
  GEX->Morpho Acc: 0.1767
  Morpho->GEX Acc: 0.2357
  Average Acc: 0.2062
--------------------------------------------------


Training:  80%|████████  | 160/200 [12:54<11:34, 17.37s/it]

Epoch 160/200
  G Loss: 12.7722, D Loss: 0.5358
  GEX->Morpho Acc: 0.1659
  Morpho->GEX Acc: 0.1488
  Average Acc: 0.1574
--------------------------------------------------


Training:  85%|████████▌ | 170/200 [13:48<02:16,  4.54s/it]

Epoch 170/200
  G Loss: 12.1149, D Loss: 0.7979
  GEX->Morpho Acc: 0.1876
  Morpho->GEX Acc: 0.1426
  Average Acc: 0.1651
--------------------------------------------------


Training:  90%|█████████ | 180/200 [14:23<01:10,  3.51s/it]

Epoch 180/200
  G Loss: 12.1058, D Loss: 0.8732
  GEX->Morpho Acc: 0.1736
  Morpho->GEX Acc: 0.2047
  Average Acc: 0.1891
--------------------------------------------------


Training:  95%|█████████▌| 190/200 [15:14<00:48,  4.86s/it]

Epoch 190/200
  G Loss: 13.6506, D Loss: 1.0646
  GEX->Morpho Acc: 0.1597
  Morpho->GEX Acc: 0.0837
  Average Acc: 0.1217
--------------------------------------------------


Training: 100%|██████████| 200/200 [15:59<00:00,  4.80s/it]

Epoch 200/200
  G Loss: 11.9456, D Loss: 0.8363
  GEX->Morpho Acc: 0.1845
  Morpho->GEX Acc: 0.1473
  Average Acc: 0.1659
--------------------------------------------------

Getting latent space embeddings...






FINAL RESULTS:
GEX -> Morpho Accuracy: 0.1845
Morpho -> GEX Accuracy: 0.1473
Average Accuracy: 0.1659

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...


  warn(


  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/CycleGAN/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/CycleGAN/morpho_umap.csv

Done!


## DCCA

In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
import umap
import os
from scipy.linalg import sqrtm, inv
import warnings
warnings.filterwarnings('ignore')


class DCCAEncoder(nn.Module):
    """DCCA编码器网络"""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.1):
        super(DCCAEncoder, self).__init__()
        
        layers = []
        prev_dim = input_dim
        
        for i, hidden_dim in enumerate(hidden_dims):
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            if i < len(hidden_dims) - 1:
                layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        
        self.encoder = nn.Sequential(*layers)
        
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        if torch.isnan(x).any():
            x = torch.nan_to_num(x, nan=0.0)
        
        output = self.encoder(x)
        
        if torch.isnan(output).any():
            output = torch.nan_to_num(output, nan=0.0)
        
        return output


class DCCALoss(nn.Module):
    """DCCA损失函数"""
    def __init__(self, outdim_size, use_all_singular_values=False, regularization=1e-3):
        super(DCCALoss, self).__init__()
        self.outdim_size = outdim_size
        self.use_all_singular_values = use_all_singular_values
        self.regularization = regularization

    def forward(self, H1, H2):
        batch_size = H1.size(0)
        
        if torch.isnan(H1).any() or torch.isnan(H2).any():
            return torch.tensor(0.01, device=H1.device, requires_grad=True)
        
        if torch.isinf(H1).any() or torch.isinf(H2).any():
            return torch.tensor(0.01, device=H1.device, requires_grad=True)
        
        H1_centered = H1 - H1.mean(dim=0, keepdim=True)
        H2_centered = H2 - H2.mean(dim=0, keepdim=True)
        
        if H1_centered.std() < 1e-8 or H2_centered.std() < 1e-8:
            return torch.tensor(0.01, device=H1.device, requires_grad=True)
        
        try:
            SigmaHat12 = torch.matmul(H1_centered.t(), H2_centered) / (batch_size - 1)
            SigmaHat11 = torch.matmul(H1_centered.t(), H1_centered) / (batch_size - 1)
            SigmaHat22 = torch.matmul(H2_centered.t(), H2_centered) / (batch_size - 1)
            
            SigmaHat11 += self.regularization * torch.eye(H1.size(1), device=H1.device)
            SigmaHat22 += self.regularization * torch.eye(H2.size(1), device=H2.device)
            
            if torch.isnan(SigmaHat11).any() or torch.isnan(SigmaHat22).any() or torch.isnan(SigmaHat12).any():
                return torch.tensor(0.01, device=H1.device, requires_grad=True)
            
            try:
                L11 = torch.linalg.cholesky(SigmaHat11)
                SigmaHat11_inv_sqrt = torch.cholesky_inverse(L11)
                
                L22 = torch.linalg.cholesky(SigmaHat22)
                SigmaHat22_inv_sqrt = torch.cholesky_inverse(L22)
                
            except RuntimeError:
                eigenvalues1, eigenvectors1 = torch.linalg.eigh(SigmaHat11)
                eigenvalues1 = torch.clamp(eigenvalues1, min=1e-6)
                SigmaHat11_inv_sqrt = torch.matmul(
                    torch.matmul(eigenvectors1, torch.diag(1.0 / torch.sqrt(eigenvalues1))),
                    eigenvectors1.t()
                )
                
                eigenvalues2, eigenvectors2 = torch.linalg.eigh(SigmaHat22)
                eigenvalues2 = torch.clamp(eigenvalues2, min=1e-6)
                SigmaHat22_inv_sqrt = torch.matmul(
                    torch.matmul(eigenvectors2, torch.diag(1.0 / torch.sqrt(eigenvalues2))),
                    eigenvectors2.t()
                )
            
            T = torch.matmul(torch.matmul(SigmaHat11_inv_sqrt, SigmaHat12), SigmaHat22_inv_sqrt)
            
            if torch.isnan(T).any() or torch.isinf(T).any():
                return torch.tensor(0.01, device=H1.device, requires_grad=True)
            
            U, S, V = torch.linalg.svd(T)
            
            if torch.isnan(S).any() or torch.isinf(S).any():
                return torch.tensor(0.01, device=H1.device, requires_grad=True)
            
            if self.use_all_singular_values:
                corr = torch.sum(S)
            else:
                corr = torch.sum(S[:min(self.outdim_size, len(S))])
            
            return -corr
            
        except Exception as e:
            return torch.tensor(0.01, device=H1.device, requires_grad=True)


class DCCAIntegrator:
    """DCCA多模态数据整合器"""
    
    def __init__(self, gex_dim, morpho_dim, latent_dim=64, hidden_dims=None, 
                 use_all_singular_values=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.latent_dim = latent_dim
        self.use_all_singular_values = use_all_singular_values
        
        if hidden_dims is None:
            hidden_dims = [256, 128]
        
        self.gex_encoder = DCCAEncoder(gex_dim, hidden_dims, latent_dim).to(device)
        self.morpho_encoder = DCCAEncoder(morpho_dim, hidden_dims, latent_dim).to(device)
        
        self.dcca_loss = DCCALoss(latent_dim, use_all_singular_values).to(device)
        
        self.optimizer = optim.Adam(
            list(self.gex_encoder.parameters()) + list(self.morpho_encoder.parameters()),
            lr=0.001, weight_decay=1e-4
        )
        
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', patience=20, factor=0.5, verbose=True
        )
        
        self.gex_scaler = StandardScaler()
        self.morpho_scaler = StandardScaler()
        
        self.losses = []
        self.correlations = []
        
    def compute_correlation(self, H1, H2):
        """计算两个表示之间的典型相关性"""
        H1_centered = H1 - H1.mean(dim=0)
        H2_centered = H2 - H2.mean(dim=0)
        
        correlation_matrix = torch.corrcoef(torch.cat([H1_centered.t(), H2_centered.t()], dim=0))
        cross_corr = correlation_matrix[:H1.size(1), H1.size(1):]
        
        U, S, V = torch.linalg.svd(cross_corr)
        
        if self.use_all_singular_values:
            return torch.mean(S)
        else:
            return torch.mean(S[:min(self.latent_dim, len(S))])
    
    def train(self, gex_data, morpho_data, epochs=200, batch_size=128, verbose=True):
        """训练DCCA模型"""
        
        print("Preprocessing data...")
        
        gex_data = np.nan_to_num(gex_data, nan=0.0, posinf=1e6, neginf=-1e6)
        morpho_data = np.nan_to_num(morpho_data, nan=0.0, posinf=1e6, neginf=-1e6)
        
        gex_data_norm = self.gex_scaler.fit_transform(gex_data)
        morpho_data_norm = self.morpho_scaler.fit_transform(morpho_data)
        
        gex_data_norm = np.nan_to_num(gex_data_norm, nan=0.0)
        morpho_data_norm = np.nan_to_num(morpho_data_norm, nan=0.0)
        
        gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
        morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
        
        dataset = TensorDataset(gex_tensor, morpho_tensor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        
        print(f"Starting DCCA training for {epochs} epochs...")
        
        best_correlation = -1
        patience_counter = 0
        
        for epoch in range(epochs):
            epoch_loss = 0.0
            epoch_corr = 0.0
            num_batches = 0
            
            self.gex_encoder.train()
            self.morpho_encoder.train()
            
            for batch_idx, (gex_batch, morpho_batch) in enumerate(dataloader):
                if gex_batch.size(0) < 8:
                    continue
                
                try:
                    gex_encoded = self.gex_encoder(gex_batch)
                    morpho_encoded = self.morpho_encoder(morpho_batch)
                    
                    if torch.isnan(gex_encoded).any() or torch.isnan(morpho_encoded).any():
                        continue
                    
                    loss = self.dcca_loss(gex_encoded, morpho_encoded)
                    
                    if torch.isnan(loss) or torch.isinf(loss):
                        continue
                    
                    self.optimizer.zero_grad()
                    loss.backward()
                    
                    torch.nn.utils.clip_grad_norm_(
                        list(self.gex_encoder.parameters()) + list(self.morpho_encoder.parameters()), 
                        max_norm=1.0
                    )
                    
                    self.optimizer.step()
                    
                    with torch.no_grad():
                        correlation = self.compute_correlation(gex_encoded, morpho_encoded)
                        if not torch.isnan(correlation):
                            epoch_corr += correlation.item()
                    
                    epoch_loss += loss.item()
                    num_batches += 1
                    
                except Exception as e:
                    continue
            
            if num_batches == 0:
                continue
            
            avg_loss = epoch_loss / num_batches
            avg_corr = epoch_corr / num_batches
            
            self.losses.append(avg_loss)
            self.correlations.append(avg_corr)
            
            self.scheduler.step(avg_corr)
            
            if avg_corr > best_correlation:
                best_correlation = avg_corr
                patience_counter = 0
                self.best_gex_encoder_state = self.gex_encoder.state_dict().copy()
                self.best_morpho_encoder_state = self.morpho_encoder.state_dict().copy()
            else:
                patience_counter += 1
            
            if verbose and (epoch + 1) % 20 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.6f}, '
                      f'Correlation: {avg_corr:.6f}, Best: {best_correlation:.6f}')
            
            if patience_counter >= 50:
                print(f"Early stopping at epoch {epoch+1}")
                break
        
        if hasattr(self, 'best_gex_encoder_state'):
            self.gex_encoder.load_state_dict(self.best_gex_encoder_state)
            self.morpho_encoder.load_state_dict(self.best_morpho_encoder_state)
            print(f"Loaded best model with correlation: {best_correlation:.6f}")
        
        return self.losses, self.correlations
    
    def transform(self, gex_data, morpho_data):
        """将数据转换到DCCA学习的表示空间"""
        self.gex_encoder.eval()
        self.morpho_encoder.eval()
        
        with torch.no_grad():
            gex_data_norm = self.gex_scaler.transform(gex_data)
            morpho_data_norm = self.morpho_scaler.transform(morpho_data)
            
            gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
            morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
            
            gex_encoded = self.gex_encoder(gex_tensor).cpu().numpy()
            morpho_encoded = self.morpho_encoder(morpho_tensor).cpu().numpy()
            
        return gex_encoded, morpho_encoded


def calculate_celltype_accuracy(gex_encoded, morpho_encoded, rna_family_labels, k=1):
    """计算细胞类型匹配准确率"""
    
    nbrs_gex = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(gex_encoded)
    distances_m2g, indices_m2g = nbrs_gex.kneighbors(morpho_encoded)
    
    nearest_gex_indices = indices_m2g[:, 0]
    nearest_gex_labels = rna_family_labels[nearest_gex_indices]
    morpho_labels = rna_family_labels
    
    matches_m2g = np.sum(morpho_labels == nearest_gex_labels)
    morpho_to_gex_accuracy = matches_m2g / len(morpho_labels)
    
    nbrs_morpho = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(morpho_encoded)
    distances_g2m, indices_g2m = nbrs_morpho.kneighbors(gex_encoded)
    
    nearest_morpho_indices = indices_g2m[:, 0]
    nearest_morpho_labels = rna_family_labels[nearest_morpho_indices]
    gex_labels = rna_family_labels
    
    matches_g2m = np.sum(gex_labels == nearest_morpho_labels)
    gex_to_morpho_accuracy = matches_g2m / len(gex_labels)
    
    average_accuracy = (morpho_to_gex_accuracy + gex_to_morpho_accuracy) / 2
    
    return morpho_to_gex_accuracy, gex_to_morpho_accuracy, average_accuracy


def main():
    """主函数：完整的DCCA整合流程"""
    
    # 数据路径
    gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
    morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
    rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
    
    # 输出路径
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/DCCA/"
    os.makedirs(output_dir, exist_ok=True)
    
    print("Loading data...")
    
    # 加载基因表达数据
    gex_df = pd.read_csv(gene_expression_path, header=None)
    gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载形态学数据
    morpho_df = pd.read_csv(morphology_path, header=0)
    morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载RNA family标签
    try:
        rna_df = pd.read_csv(rna_family_path, header=0)
        if rna_df.shape[1] == 1:
            rna_family_labels = rna_df.iloc[:, 0].values
        else:
            rna_family_labels = rna_df.iloc[:, 1].values
        
        min_samples = min(len(gex_data), len(morpho_data), len(rna_family_labels))
        gex_data = gex_data[:min_samples]
        morpho_data = morpho_data[:min_samples]
        rna_family_labels = rna_family_labels[:min_samples]
        
        print(f"Data loaded successfully:")
        print(f"  - GEX data shape: {gex_data.shape}")
        print(f"  - Morpho data shape: {morpho_data.shape}")
        print(f"  - RNA family labels: {len(np.unique(rna_family_labels))} unique types")
        
    except Exception as e:
        print(f"Error loading RNA family labels: {e}")
        return
    
    # 创建并训练DCCA模型
    print("\nInitializing DCCA model...")
    integrator = DCCAIntegrator(
        gex_dim=gex_data.shape[1],
        morpho_dim=morpho_data.shape[1],
        latent_dim=64,
        hidden_dims=[256, 128],
        use_all_singular_values=False
    )
    
    print("Training DCCA model...")
    losses, correlations = integrator.train(
        gex_data, morpho_data,
        epochs=200,
        batch_size=32,
        verbose=True
    )
    
    # 转换数据到DCCA表示空间
    print("\nTransforming data to DCCA representation space...")
    gex_encoded, morpho_encoded = integrator.transform(gex_data, morpho_data)
    
    print(f"Encoded data shapes:")
    print(f"  - GEX encoded: {gex_encoded.shape}")
    print(f"  - Morpho encoded: {morpho_encoded.shape}")
    
    # 计算细胞类型匹配准确率
    print("\nCalculating celltype accuracy...")
    morpho_to_gex_acc, gex_to_morpho_acc, avg_acc = calculate_celltype_accuracy(
        gex_encoded, morpho_encoded, rna_family_labels
    )
    
    # 输出结果
    print(f"\n" + "="*60)
    print(f"DCCA INTEGRATION RESULTS:")
    print(f"="*60)
    print(f"  Morpho to GEX accuracy: {morpho_to_gex_acc:.4f}")
    print(f"  GEX to Morpho accuracy: {gex_to_morpho_acc:.4f}")
    print(f"  Average accuracy: {avg_acc:.4f}")
    print(f"="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(gex_encoded)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(morpho_encoded)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")


if __name__ == "__main__":
    main()

Loading data...
Data loaded successfully:
  - GEX data shape: (645, 2000)
  - Morpho data shape: (645, 645)
  - RNA family labels: 10 unique types

Initializing DCCA model...
Training DCCA model...
Preprocessing data...
Starting DCCA training for 200 epochs...
Epoch [20/200], Loss: -6086.182886, Correlation: 0.787640, Best: 0.791076
Epoch [40/200], Loss: -6238.226270, Correlation: 0.832393, Best: 0.832393
Epoch [60/200], Loss: -6286.326758, Correlation: 0.839704, Best: 0.840527
Epoch [80/200], Loss: -6304.940503, Correlation: 0.843580, Best: 0.850997
Epoch [100/200], Loss: -6337.102148, Correlation: 0.853867, Best: 0.860918
Epoch [120/200], Loss: -6373.642456, Correlation: 0.858300, Best: 0.861411
Epoch [140/200], Loss: -6463.071997, Correlation: 0.867312, Best: 0.867312
Epoch [160/200], Loss: -6494.428760, Correlation: 0.869706, Best: 0.871038
Epoch [180/200], Loss: -6484.048901, Correlation: 0.870034, Best: 0.871038
Epoch [200/200], Loss: -6593.796240, Correlation: 0.879479, Best: 0.

## UnionCom

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import umap
from tqdm import tqdm
import os

class UnionComIntegrator:
    """
    UnionCom implementation for multi-modal data integration
    Based on the UnionCom algorithm for single-cell multi-omics integration
    """
    
    def __init__(self, latent_dim=50, lambda_s=1.0, lambda_f=1.0, lambda_d=1.0, 
                 knn_k=10, max_iter=200, lr=0.01, device='cuda'):
        """
        Initialize UnionCom integrator
        
        Args:
            latent_dim: Dimension of the integrated latent space
            lambda_s: Weight for structure preservation loss
            lambda_f: Weight for feature matching loss  
            lambda_d: Weight for domain alignment loss
            knn_k: Number of neighbors for KNN graph construction
            max_iter: Maximum number of iterations
            lr: Learning rate
            device: Computing device
        """
        self.latent_dim = latent_dim
        self.lambda_s = lambda_s
        self.lambda_f = lambda_f
        self.lambda_d = lambda_d
        self.knn_k = knn_k
        self.max_iter = max_iter
        self.lr = lr
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Will be initialized after loading data
        self.encoder_gex = None
        self.encoder_morpho = None
        
    def load_data(self):
        """Load multi-modal data"""
        # Gene expression data
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # Morphology data
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family labels
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            print(f"RNA family labels loaded: {len(np.unique(self.rna_family_labels))} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
        
        # Data normalization
        self.gex_mean = np.mean(self.gex_data, axis=0)
        self.gex_std = np.std(self.gex_data, axis=0) + 1e-8
        self.morpho_mean = np.mean(self.morpho_data, axis=0)
        self.morpho_std = np.std(self.morpho_data, axis=0) + 1e-8
        
        self.gex_data = (self.gex_data - self.gex_mean) / self.gex_std
        self.morpho_data = (self.morpho_data - self.morpho_mean) / self.morpho_std
        
        self.n_samples = len(self.gex_data)
        self.gex_dim = self.gex_data.shape[1]
        self.morpho_dim = self.morpho_data.shape[1]
        
        # Initialize encoders
        self._initialize_encoders()
        
        # Convert to tensors
        self.gex_tensor = torch.FloatTensor(self.gex_data).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data).to(self.device)
        
        # Build KNN graphs for structure preservation
        self._build_knn_graphs()
        
    def _initialize_encoders(self):
        """Initialize encoder networks"""
        
        class Encoder(nn.Module):
            def __init__(self, input_dim, output_dim):
                super(Encoder, self).__init__()
                hidden_dim = max(256, min(512, input_dim // 2))
                self.network = nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.BatchNorm1d(hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(hidden_dim // 2, output_dim)
                )
            
            def forward(self, x):
                return self.network(x)
        
        self.encoder_gex = Encoder(self.gex_dim, self.latent_dim).to(self.device)
        self.encoder_morpho = Encoder(self.morpho_dim, self.latent_dim).to(self.device)
        
        # Initialize optimizers
        self.optimizer = optim.Adam(
            list(self.encoder_gex.parameters()) + list(self.encoder_morpho.parameters()),
            lr=self.lr
        )
        
    def _build_knn_graphs(self):
        """Build KNN graphs for structure preservation"""
        print("Building KNN graphs...")
        
        # Build KNN graph for GEX data
        nbrs_gex = NearestNeighbors(n_neighbors=self.knn_k, metric='euclidean').fit(self.gex_data)
        _, indices_gex = nbrs_gex.kneighbors(self.gex_data)
        self.knn_indices_gex = indices_gex
        
        # Build KNN graph for Morpho data  
        nbrs_morpho = NearestNeighbors(n_neighbors=self.knn_k, metric='euclidean').fit(self.morpho_data)
        _, indices_morpho = nbrs_morpho.kneighbors(self.morpho_data)
        self.knn_indices_morpho = indices_morpho
        
        print(f"KNN graphs built with k={self.knn_k}")
        
    def _structure_preservation_loss(self, embeddings_gex, embeddings_morpho):
        """Calculate structure preservation loss"""
        loss = 0.0
        
        # Structure preservation for GEX
        for i in range(len(embeddings_gex)):
            anchor_emb = embeddings_gex[i]
            neighbor_indices = self.knn_indices_gex[i]
            
            # Calculate distances in original space
            anchor_orig = self.gex_tensor[i]
            neighbors_orig = self.gex_tensor[neighbor_indices]
            orig_distances = torch.norm(neighbors_orig - anchor_orig.unsqueeze(0), dim=1)
            
            # Calculate distances in embedding space
            neighbors_emb = embeddings_gex[neighbor_indices]
            emb_distances = torch.norm(neighbors_emb - anchor_emb.unsqueeze(0), dim=1)
            
            # Structure preservation loss (preserve relative distances)
            loss += torch.mean((orig_distances - emb_distances) ** 2)
            
        # Structure preservation for Morpho
        for i in range(len(embeddings_morpho)):
            anchor_emb = embeddings_morpho[i]
            neighbor_indices = self.knn_indices_morpho[i]
            
            # Calculate distances in original space
            anchor_orig = self.morpho_tensor[i]
            neighbors_orig = self.morpho_tensor[neighbor_indices]
            orig_distances = torch.norm(neighbors_orig - anchor_orig.unsqueeze(0), dim=1)
            
            # Calculate distances in embedding space
            neighbors_emb = embeddings_morpho[neighbor_indices]
            emb_distances = torch.norm(neighbors_emb - anchor_emb.unsqueeze(0), dim=1)
            
            # Structure preservation loss
            loss += torch.mean((orig_distances - emb_distances) ** 2)
            
        return loss / (2 * self.n_samples)
    
    def _feature_matching_loss(self, embeddings_gex, embeddings_morpho):
        """Calculate feature matching loss between modalities"""
        # Calculate mean embeddings for each modality
        mean_gex = torch.mean(embeddings_gex, dim=0)
        mean_morpho = torch.mean(embeddings_morpho, dim=0)
        
        # Feature matching loss (align distributions)
        feature_loss = torch.mean((mean_gex - mean_morpho) ** 2)
        
        # Add covariance alignment
        cov_gex = torch.cov(embeddings_gex.T)
        cov_morpho = torch.cov(embeddings_morpho.T)
        cov_loss = torch.mean((cov_gex - cov_morpho) ** 2)
        
        return feature_loss + 0.1 * cov_loss
    
    def _domain_alignment_loss(self, embeddings_gex, embeddings_morpho):
        """Calculate domain alignment loss using Maximum Mean Discrepancy (MMD)"""
        
        def compute_kernel(x, y, sigma=1.0):
            """Compute RBF kernel matrix"""
            x_size = x.size(0)
            y_size = y.size(0)
            dim = x.size(1)
            
            x = x.unsqueeze(1)  # (x_size, 1, dim)
            y = y.unsqueeze(0)  # (1, y_size, dim)
            
            tiled_x = x.expand(x_size, y_size, dim)
            tiled_y = y.expand(x_size, y_size, dim)
            
            kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
            return torch.exp(-kernel_input / sigma)
        
        # Compute MMD loss
        x_kernel = compute_kernel(embeddings_gex, embeddings_gex)
        y_kernel = compute_kernel(embeddings_morpho, embeddings_morpho) 
        xy_kernel = compute_kernel(embeddings_gex, embeddings_morpho)
        
        mmd_loss = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
        return mmd_loss
    
    def train_step(self):
        """Perform one training step"""
        self.optimizer.zero_grad()
        
        # Forward pass
        embeddings_gex = self.encoder_gex(self.gex_tensor)
        embeddings_morpho = self.encoder_morpho(self.morpho_tensor)
        
        # Calculate losses
        structure_loss = self._structure_preservation_loss(embeddings_gex, embeddings_morpho)
        feature_loss = self._feature_matching_loss(embeddings_gex, embeddings_morpho)
        domain_loss = self._domain_alignment_loss(embeddings_gex, embeddings_morpho)
        
        # Total loss
        total_loss = (self.lambda_s * structure_loss + 
                     self.lambda_f * feature_loss + 
                     self.lambda_d * domain_loss)
        
        # Backward pass
        total_loss.backward()
        self.optimizer.step()
        
        return total_loss.item(), structure_loss.item(), feature_loss.item(), domain_loss.item()
    
    def get_integrated_embeddings(self):
        """Get integrated embeddings"""
        self.encoder_gex.eval()
        self.encoder_morpho.eval()
        
        with torch.no_grad():
            gex_embeddings = self.encoder_gex(self.gex_tensor).cpu().numpy()
            morpho_embeddings = self.encoder_morpho(self.morpho_tensor).cpu().numpy()
        
        return gex_embeddings, morpho_embeddings
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """Calculate cell type matching accuracy"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        # Use nearest neighbor search
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        # GEX to Morpho matching accuracy
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        # Morpho to GEX matching accuracy
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        # Average matching accuracy
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, verbose_interval=20):
        """Train UnionCom model"""
        print("Starting UnionCom training...")
        
        for iteration in tqdm(range(self.max_iter), desc="Training UnionCom"):
            # Training step
            total_loss, struct_loss, feat_loss, dom_loss = self.train_step()
            
            # Calculate accuracy periodically
            if (iteration + 1) % verbose_interval == 0:
                gex_emb, morpho_emb = self.get_integrated_embeddings()
                gex2morpho_acc, morpho2gex_acc, avg_acc = self.calculate_celltype_accuracy(gex_emb, morpho_emb)
                
                print(f"\nIteration {iteration+1}/{self.max_iter}")
                print(f"  Total Loss: {total_loss:.4f}")
                print(f"  Structure Loss: {struct_loss:.4f}")
                print(f"  Feature Loss: {feat_loss:.4f}")
                print(f"  Domain Loss: {dom_loss:.4f}")
                if avg_acc is not None:
                    print(f"  GEX->Morpho Acc: {gex2morpho_acc:.4f}")
                    print(f"  Morpho->GEX Acc: {morpho2gex_acc:.4f}")
                    print(f"  Average Acc: {avg_acc:.4f}")
                print("-" * 50)

# Usage example
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/UnionCom/"
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize UnionCom integrator
    integrator = UnionComIntegrator(
        latent_dim=50,
        lambda_s=1.0,
        lambda_f=1.0,
        lambda_d=1.0,
        knn_k=10,
        max_iter=500,
        lr=0.001
    )
    
    # Load data
    integrator.load_data()
    
    # Train model
    integrator.train(verbose_interval=25)
    
    # Get final integrated embeddings
    print("\nGetting latent space embeddings...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # Calculate final accuracy
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL UNIONCOM RESULTS:")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
Building KNN graphs...
KNN graphs built with k=10
Starting UnionCom training...


Training UnionCom:   5%|▌         | 25/500 [00:38<15:13,  1.92s/it]


Iteration 25/500
  Total Loss: 153.3509
  Structure Loss: 151.6015
  Feature Loss: 1.7415
  Domain Loss: 0.0079
  GEX->Morpho Acc: 0.2078
  Morpho->GEX Acc: 0.1845
  Average Acc: 0.1961
--------------------------------------------------


Training UnionCom:  10%|█         | 50/500 [01:09<08:57,  1.19s/it]


Iteration 50/500
  Total Loss: 37.8613
  Structure Loss: 34.8849
  Feature Loss: 2.9522
  Domain Loss: 0.0242
  GEX->Morpho Acc: 0.1705
  Morpho->GEX Acc: 0.2171
  Average Acc: 0.1938
--------------------------------------------------


Training UnionCom:  15%|█▌        | 75/500 [01:38<07:35,  1.07s/it]


Iteration 75/500
  Total Loss: 15.1456
  Structure Loss: 14.5406
  Feature Loss: 0.5881
  Domain Loss: 0.0169
  GEX->Morpho Acc: 0.1798
  Morpho->GEX Acc: 0.1659
  Average Acc: 0.1729
--------------------------------------------------


Training UnionCom:  20%|██        | 100/500 [02:07<08:42,  1.31s/it]


Iteration 100/500
  Total Loss: 8.1320
  Structure Loss: 7.7341
  Feature Loss: 0.3800
  Domain Loss: 0.0178
  GEX->Morpho Acc: 0.1473
  Morpho->GEX Acc: 0.1628
  Average Acc: 0.1550
--------------------------------------------------


Training UnionCom:  25%|██▌       | 125/500 [02:39<08:16,  1.32s/it]


Iteration 125/500
  Total Loss: 5.1507
  Structure Loss: 4.8344
  Feature Loss: 0.2968
  Domain Loss: 0.0195
  GEX->Morpho Acc: 0.1721
  Morpho->GEX Acc: 0.1643
  Average Acc: 0.1682
--------------------------------------------------


Training UnionCom:  30%|███       | 150/500 [03:06<06:53,  1.18s/it]


Iteration 150/500
  Total Loss: 3.6101
  Structure Loss: 3.3535
  Feature Loss: 0.2376
  Domain Loss: 0.0190
  GEX->Morpho Acc: 0.1845
  Morpho->GEX Acc: 0.1767
  Average Acc: 0.1806
--------------------------------------------------


Training UnionCom:  35%|███▌      | 175/500 [03:36<06:19,  1.17s/it]


Iteration 175/500
  Total Loss: 2.6644
  Structure Loss: 2.4539
  Feature Loss: 0.1923
  Domain Loss: 0.0182
  GEX->Morpho Acc: 0.1953
  Morpho->GEX Acc: 0.1829
  Average Acc: 0.1891
--------------------------------------------------


Training UnionCom:  40%|████      | 200/500 [04:06<06:03,  1.21s/it]


Iteration 200/500
  Total Loss: 2.0448
  Structure Loss: 1.8556
  Feature Loss: 0.1714
  Domain Loss: 0.0179
  GEX->Morpho Acc: 0.2016
  Morpho->GEX Acc: 0.1798
  Average Acc: 0.1907
--------------------------------------------------


Training UnionCom:  45%|████▌     | 225/500 [04:38<05:06,  1.12s/it]


Iteration 225/500
  Total Loss: 1.6197
  Structure Loss: 1.4416
  Feature Loss: 0.1603
  Domain Loss: 0.0177
  GEX->Morpho Acc: 0.1953
  Morpho->GEX Acc: 0.1798
  Average Acc: 0.1876
--------------------------------------------------


Training UnionCom:  50%|█████     | 250/500 [05:09<04:42,  1.13s/it]


Iteration 250/500
  Total Loss: 1.3161
  Structure Loss: 1.1466
  Feature Loss: 0.1519
  Domain Loss: 0.0177
  GEX->Morpho Acc: 0.1969
  Morpho->GEX Acc: 0.1783
  Average Acc: 0.1876
--------------------------------------------------


Training UnionCom:  55%|█████▌    | 275/500 [05:41<04:22,  1.17s/it]


Iteration 275/500
  Total Loss: 1.0908
  Structure Loss: 0.9296
  Feature Loss: 0.1436
  Domain Loss: 0.0176
  GEX->Morpho Acc: 0.1907
  Morpho->GEX Acc: 0.1736
  Average Acc: 0.1822
--------------------------------------------------


Training UnionCom:  60%|██████    | 300/500 [06:13<04:09,  1.25s/it]


Iteration 300/500
  Total Loss: 0.9203
  Structure Loss: 0.7669
  Feature Loss: 0.1359
  Domain Loss: 0.0175
  GEX->Morpho Acc: 0.1876
  Morpho->GEX Acc: 0.1643
  Average Acc: 0.1760
--------------------------------------------------


Training UnionCom:  65%|██████▌   | 325/500 [06:45<04:06,  1.41s/it]


Iteration 325/500
  Total Loss: 0.7887
  Structure Loss: 0.6426
  Feature Loss: 0.1287
  Domain Loss: 0.0174
  GEX->Morpho Acc: 0.1798
  Morpho->GEX Acc: 0.1628
  Average Acc: 0.1713
--------------------------------------------------


Training UnionCom:  70%|███████   | 350/500 [07:18<03:41,  1.48s/it]


Iteration 350/500
  Total Loss: 0.6864
  Structure Loss: 0.5457
  Feature Loss: 0.1233
  Domain Loss: 0.0175
  GEX->Morpho Acc: 0.1752
  Morpho->GEX Acc: 0.1566
  Average Acc: 0.1659
--------------------------------------------------


Training UnionCom:  75%|███████▌  | 375/500 [07:48<02:31,  1.21s/it]


Iteration 375/500
  Total Loss: 0.6013
  Structure Loss: 0.4692
  Feature Loss: 0.1149
  Domain Loss: 0.0172
  GEX->Morpho Acc: 0.1783
  Morpho->GEX Acc: 0.1550
  Average Acc: 0.1667
--------------------------------------------------


Training UnionCom:  80%|████████  | 400/500 [08:17<02:06,  1.27s/it]


Iteration 400/500
  Total Loss: 0.5355
  Structure Loss: 0.4081
  Feature Loss: 0.1105
  Domain Loss: 0.0168
  GEX->Morpho Acc: 0.1783
  Morpho->GEX Acc: 0.1597
  Average Acc: 0.1690
--------------------------------------------------


Training UnionCom:  85%|████████▌ | 425/500 [08:48<01:25,  1.14s/it]


Iteration 425/500
  Total Loss: 0.4773
  Structure Loss: 0.3573
  Feature Loss: 0.1031
  Domain Loss: 0.0170
  GEX->Morpho Acc: 0.1798
  Morpho->GEX Acc: 0.1581
  Average Acc: 0.1690
--------------------------------------------------


Training UnionCom:  90%|█████████ | 450/500 [09:26<01:29,  1.80s/it]


Iteration 450/500
  Total Loss: 0.4302
  Structure Loss: 0.3148
  Feature Loss: 0.0985
  Domain Loss: 0.0169
  GEX->Morpho Acc: 0.1876
  Morpho->GEX Acc: 0.1581
  Average Acc: 0.1729
--------------------------------------------------


Training UnionCom:  95%|█████████▌| 475/500 [10:02<00:33,  1.34s/it]


Iteration 475/500
  Total Loss: 0.3923
  Structure Loss: 0.2794
  Feature Loss: 0.0965
  Domain Loss: 0.0164
  GEX->Morpho Acc: 0.1953
  Morpho->GEX Acc: 0.1566
  Average Acc: 0.1760
--------------------------------------------------


Training UnionCom: 100%|██████████| 500/500 [10:41<00:00,  1.28s/it]


Iteration 500/500
  Total Loss: 0.3565
  Structure Loss: 0.2481
  Feature Loss: 0.0918
  Domain Loss: 0.0166
  GEX->Morpho Acc: 0.1907
  Morpho->GEX Acc: 0.1550
  Average Acc: 0.1729
--------------------------------------------------

Getting latent space embeddings...

FINAL UNIONCOM RESULTS:
GEX -> Morpho Accuracy: 0.1907
Morpho -> GEX Accuracy: 0.1550
Average Accuracy: 0.1729

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...





  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/UnionCom/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/UnionCom/morpho_umap.csv

Done!


## SCIM

In [5]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
import umap
import os
import warnings
warnings.filterwarnings('ignore')


class MutualInformationLoss(nn.Module):
    """互信息损失函数"""
    def __init__(self, sigma=1.0, num_bins=50):
        super(MutualInformationLoss, self).__init__()
        self.sigma = sigma
        self.num_bins = num_bins
    
    def gaussian_kernel(self, x, y, sigma):
        """高斯核函数"""
        dist = torch.cdist(x, y, p=2)
        return torch.exp(-dist**2 / (2 * sigma**2))
    
    def entropy_estimate_kde(self, x):
        """使用核密度估计计算熵"""
        n = x.size(0)
        K = self.gaussian_kernel(x, x, self.sigma)
        K = K - torch.eye(n, device=x.device)
        density = torch.sum(K, dim=1) / (n - 1)
        density = torch.clamp(density, min=1e-10)
        entropy = -torch.mean(torch.log(density))
        return entropy
    
    def mutual_information_kde(self, x, y):
        """使用核密度估计计算互信息"""
        h_x = self.entropy_estimate_kde(x)
        h_y = self.entropy_estimate_kde(y)
        
        xy = torch.cat([x, y], dim=1)
        h_xy = self.entropy_estimate_kde(xy)
        
        mi = h_x + h_y - h_xy
        return mi
    
    def forward(self, x, y):
        """计算负互信息作为损失"""
        mi = self.mutual_information_kde(x, y)
        return -mi


class SCIMEncoder(nn.Module):
    """SCIM编码器网络"""
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.1):
        super(SCIMEncoder, self).__init__()
        
        layers = []
        prev_dim = input_dim
        
        for i, hidden_dim in enumerate(hidden_dims):
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            if i < len(hidden_dims) - 1:
                layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        
        self.encoder = nn.Sequential(*layers)
        
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        return self.encoder(x)


class ContrastiveLoss(nn.Module):
    """对比学习损失函数"""
    def __init__(self, temperature=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
    
    def forward(self, z1, z2):
        batch_size = z1.size(0)
        
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        
        similarity_matrix = torch.matmul(z1, z2.t()) / self.temperature
        
        positive_samples = torch.diag(similarity_matrix)
        
        mask = torch.eye(batch_size, device=z1.device).bool()
        negative_samples = similarity_matrix.masked_select(~mask).view(batch_size, -1)
        
        logits = torch.cat([positive_samples.unsqueeze(1), negative_samples], dim=1)
        labels = torch.zeros(batch_size, dtype=torch.long, device=z1.device)
        
        loss = nn.functional.cross_entropy(logits, labels)
        return loss


class SCIMIntegrator:
    """SCIM多模态数据整合器"""
    
    def __init__(self, gex_dim, morpho_dim, latent_dim=128, hidden_dims=None, 
                 mi_weight=1.0, contrastive_weight=0.5,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.latent_dim = latent_dim
        self.mi_weight = mi_weight
        self.contrastive_weight = contrastive_weight
        
        if hidden_dims is None:
            hidden_dims = [256, 128]
        
        print(f"Initializing SCIM with dimensions: GEX={gex_dim}, Morpho={morpho_dim}, Latent={latent_dim}")
        
        self.gex_encoder = SCIMEncoder(gex_dim, hidden_dims, latent_dim).to(device)
        self.morpho_encoder = SCIMEncoder(morpho_dim, hidden_dims, latent_dim).to(device)
        
        self.mi_loss = MutualInformationLoss(sigma=1.0).to(device)
        self.contrastive_loss = ContrastiveLoss(temperature=0.1).to(device)
        self.reconstruction_loss = nn.MSELoss()
        
        self.gex_decoder = self._create_decoder(latent_dim, hidden_dims, gex_dim).to(device)
        self.morpho_decoder = self._create_decoder(latent_dim, hidden_dims, morpho_dim).to(device)
        
        all_params = (list(self.gex_encoder.parameters()) + 
                     list(self.morpho_encoder.parameters()) +
                     list(self.gex_decoder.parameters()) +
                     list(self.morpho_decoder.parameters()))
        
        self.optimizer = optim.Adam(all_params, lr=0.001, weight_decay=1e-4)
        
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=15, factor=0.5, verbose=True
        )
        
        self.gex_scaler = StandardScaler()
        self.morpho_scaler = StandardScaler()
        
        print("SCIM initialization completed!")
    
    def _create_decoder(self, input_dim, hidden_dims, output_dim):
        """创建重构解码器"""
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in reversed(hidden_dims):
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        
        return nn.Sequential(*layers)
    
    def train(self, gex_data, morpho_data, epochs=200, batch_size=64, 
              reconstruction_weight=0.1, verbose=True):
        """训练SCIM模型"""
        
        print("Preprocessing data...")
        
        gex_data = np.nan_to_num(gex_data, nan=0.0, posinf=1e6, neginf=-1e6)
        morpho_data = np.nan_to_num(morpho_data, nan=0.0, posinf=1e6, neginf=-1e6)
        
        gex_data_norm = self.gex_scaler.fit_transform(gex_data)
        morpho_data_norm = self.morpho_scaler.fit_transform(morpho_data)
        
        gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
        morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
        
        dataset = TensorDataset(gex_tensor, morpho_tensor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        
        print(f"Starting SCIM training for {epochs} epochs...")
        
        best_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(epochs):
            epoch_total_loss = 0.0
            num_batches = 0
            
            self.gex_encoder.train()
            self.morpho_encoder.train()
            self.gex_decoder.train()
            self.morpho_decoder.train()
            
            for batch_idx, (gex_batch, morpho_batch) in enumerate(dataloader):
                if gex_batch.size(0) < 8:
                    continue
                
                try:
                    gex_encoded = self.gex_encoder(gex_batch)
                    morpho_encoded = self.morpho_encoder(morpho_batch)
                    
                    if torch.isnan(gex_encoded).any() or torch.isnan(morpho_encoded).any():
                        continue
                    
                    mi_loss = self.mi_loss(gex_encoded, morpho_encoded)
                    contrastive_loss = self.contrastive_loss(gex_encoded, morpho_encoded)
                    
                    gex_reconstructed = self.gex_decoder(gex_encoded)
                    morpho_reconstructed = self.morpho_decoder(morpho_encoded)
                    
                    recon_loss_gex = self.reconstruction_loss(gex_reconstructed, gex_batch)
                    recon_loss_morpho = self.reconstruction_loss(morpho_reconstructed, morpho_batch)
                    reconstruction_loss = (recon_loss_gex + recon_loss_morpho) / 2
                    
                    total_loss = (self.mi_weight * mi_loss + 
                                 self.contrastive_weight * contrastive_loss +
                                 reconstruction_weight * reconstruction_loss)
                    
                    if torch.isnan(total_loss) or torch.isinf(total_loss):
                        continue
                    
                    self.optimizer.zero_grad()
                    total_loss.backward()
                    
                    torch.nn.utils.clip_grad_norm_(
                        list(self.gex_encoder.parameters()) + 
                        list(self.morpho_encoder.parameters()) +
                        list(self.gex_decoder.parameters()) +
                        list(self.morpho_decoder.parameters()), 
                        max_norm=1.0
                    )
                    
                    self.optimizer.step()
                    
                    epoch_total_loss += total_loss.item()
                    num_batches += 1
                    
                except Exception as e:
                    continue
            
            if num_batches == 0:
                continue
            
            avg_loss = epoch_total_loss / num_batches
            
            self.scheduler.step(avg_loss)
            
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                self.best_gex_encoder_state = self.gex_encoder.state_dict().copy()
                self.best_morpho_encoder_state = self.morpho_encoder.state_dict().copy()
            else:
                patience_counter += 1
            
            if verbose and (epoch + 1) % 20 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.6f}')
            
            if patience_counter >= 40:
                print(f"Early stopping at epoch {epoch+1}")
                break
        
        if hasattr(self, 'best_gex_encoder_state'):
            self.gex_encoder.load_state_dict(self.best_gex_encoder_state)
            self.morpho_encoder.load_state_dict(self.best_morpho_encoder_state)
            print(f"Loaded best model with loss: {best_loss:.6f}")
    
    def transform(self, gex_data, morpho_data):
        """将数据转换到SCIM学习的表示空间"""
        self.gex_encoder.eval()
        self.morpho_encoder.eval()
        
        with torch.no_grad():
            gex_data_norm = self.gex_scaler.transform(gex_data)
            morpho_data_norm = self.morpho_scaler.transform(morpho_data)
            
            gex_tensor = torch.FloatTensor(gex_data_norm).to(self.device)
            morpho_tensor = torch.FloatTensor(morpho_data_norm).to(self.device)
            
            gex_encoded = self.gex_encoder(gex_tensor).cpu().numpy()
            morpho_encoded = self.morpho_encoder(morpho_tensor).cpu().numpy()
            
        return gex_encoded, morpho_encoded


def calculate_celltype_accuracy(gex_encoded, morpho_encoded, rna_family_labels, k=1):
    """计算细胞类型匹配准确率"""
    
    nbrs_gex = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(gex_encoded)
    distances_m2g, indices_m2g = nbrs_gex.kneighbors(morpho_encoded)
    
    nearest_gex_indices = indices_m2g[:, 0]
    nearest_gex_labels = rna_family_labels[nearest_gex_indices]
    morpho_labels = rna_family_labels
    
    matches_m2g = np.sum(morpho_labels == nearest_gex_labels)
    morpho_to_gex_accuracy = matches_m2g / len(morpho_labels)
    
    nbrs_morpho = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(morpho_encoded)
    distances_g2m, indices_g2m = nbrs_morpho.kneighbors(gex_encoded)
    
    nearest_morpho_indices = indices_g2m[:, 0]
    nearest_morpho_labels = rna_family_labels[nearest_morpho_indices]
    gex_labels = rna_family_labels
    
    matches_g2m = np.sum(gex_labels == nearest_morpho_labels)
    gex_to_morpho_accuracy = matches_g2m / len(gex_labels)
    
    average_accuracy = (morpho_to_gex_accuracy + gex_to_morpho_accuracy) / 2
    
    return morpho_to_gex_accuracy, gex_to_morpho_accuracy, average_accuracy


def main():
    """主函数：完整的SCIM整合流程"""
    
    # 数据路径
    gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
    morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
    rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
    
    # 输出路径
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/SCIM/"
    os.makedirs(output_dir, exist_ok=True)
    
    print("Loading data...")
    
    # 加载基因表达数据
    gex_df = pd.read_csv(gene_expression_path, header=None)
    gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载形态学数据
    morpho_df = pd.read_csv(morphology_path, header=0)
    morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
    
    # 加载RNA family标签
    try:
        rna_df = pd.read_csv(rna_family_path, header=0)
        if rna_df.shape[1] == 1:
            rna_family_labels = rna_df.iloc[:, 0].values
        else:
            rna_family_labels = rna_df.iloc[:, 1].values
        
        min_samples = min(len(gex_data), len(morpho_data), len(rna_family_labels))
        gex_data = gex_data[:min_samples]
        morpho_data = morpho_data[:min_samples]
        rna_family_labels = rna_family_labels[:min_samples]
        
        print(f"Data loaded successfully:")
        print(f"  - GEX data shape: {gex_data.shape}")
        print(f"  - Morpho data shape: {morpho_data.shape}")
        print(f"  - RNA family labels: {len(np.unique(rna_family_labels))} unique types")
        
    except Exception as e:
        print(f"Error loading RNA family labels: {e}")
        return
    
    # 创建并训练SCIM模型
    print("\nInitializing SCIM model...")
    integrator = SCIMIntegrator(
        gex_dim=gex_data.shape[1],
        morpho_dim=morpho_data.shape[1],
        latent_dim=128,
        hidden_dims=[256, 128],
        mi_weight=1.0,
        contrastive_weight=0.5
    )
    
    print("Training SCIM model...")
    integrator.train(
        gex_data, morpho_data,
        epochs=200,
        batch_size=32,
        reconstruction_weight=0.1,
        verbose=True
    )
    
    # 转换数据到SCIM表示空间
    print("\nTransforming data to SCIM representation space...")
    gex_encoded, morpho_encoded = integrator.transform(gex_data, morpho_data)
    
    print(f"Encoded data shapes:")
    print(f"  - GEX encoded: {gex_encoded.shape}")
    print(f"  - Morpho encoded: {morpho_encoded.shape}")
    
    # 计算细胞类型匹配准确率
    print("\nCalculating celltype accuracy...")
    morpho_to_gex_acc, gex_to_morpho_acc, avg_acc = calculate_celltype_accuracy(
        gex_encoded, morpho_encoded, rna_family_labels
    )
    
    # 输出结果
    print(f"\n" + "="*60)
    print(f"SCIM INTEGRATION RESULTS:")
    print(f"="*60)
    print(f"  Morpho to GEX accuracy: {morpho_to_gex_acc:.4f}")
    print(f"  GEX to Morpho accuracy: {gex_to_morpho_acc:.4f}")
    print(f"  Average accuracy: {avg_acc:.4f}")
    print(f"="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(gex_encoded)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(morpho_encoded)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")


if __name__ == "__main__":
    main()

Loading data...
Data loaded successfully:
  - GEX data shape: (645, 2000)
  - Morpho data shape: (645, 645)
  - RNA family labels: 10 unique types

Initializing SCIM model...
Initializing SCIM with dimensions: GEX=2000, Morpho=645, Latent=128
SCIM initialization completed!
Training SCIM model...
Preprocessing data...
Starting SCIM training for 200 epochs...
Epoch [20/200], Loss: -3.788324
Epoch [40/200], Loss: -21.455040
Epoch [60/200], Loss: -21.932619
Epoch [80/200], Loss: -22.007526
Epoch [100/200], Loss: -22.477897
Epoch [120/200], Loss: -22.579822
Epoch [140/200], Loss: -22.593367
Epoch [160/200], Loss: -22.637714
Epoch [180/200], Loss: -22.654634
Epoch [200/200], Loss: -22.637306
Loaded best model with loss: -22.742101

Transforming data to SCIM representation space...
Encoded data shapes:
  - GEX encoded: (645, 128)
  - Morpho encoded: (645, 128)

Calculating celltype accuracy...

SCIM INTEGRATION RESULTS:
  Morpho to GEX accuracy: 0.6558
  GEX to Morpho accuracy: 0.5814
  Avera

## SCJoint

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import umap
from tqdm import tqdm
import os

class scJointIntegrator:
    """
    scJoint implementation for multi-modal data integration
    Based on the scJoint algorithm using transfer learning and adversarial training
    """
    
    def __init__(self, latent_dim=64, hidden_dims=[512, 256], lambda_adv=1.0, lambda_recon=10.0,
                 lambda_kl=1.0, max_epochs=200, batch_size=64, lr=0.001, device='cuda'):
        """
        Initialize scJoint integrator
        
        Args:
            latent_dim: Dimension of the shared latent space
            hidden_dims: Hidden layer dimensions
            lambda_adv: Weight for adversarial loss
            lambda_recon: Weight for reconstruction loss
            lambda_kl: Weight for KL divergence loss
            max_epochs: Maximum training epochs
            batch_size: Training batch size
            lr: Learning rate
            device: Computing device
        """
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.lambda_adv = lambda_adv
        self.lambda_recon = lambda_recon
        self.lambda_kl = lambda_kl
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Will be initialized after loading data
        self.encoder_gex = None
        self.encoder_morpho = None
        self.decoder_gex = None
        self.decoder_morpho = None
        self.discriminator = None
        
    def load_data(self):
        """Load multi-modal data"""
        # Gene expression data
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # Morphology data
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family labels
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            print(f"RNA family labels loaded: {len(np.unique(self.rna_family_labels))} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
        
        # Data normalization
        self.gex_mean = np.mean(self.gex_data, axis=0)
        self.gex_std = np.std(self.gex_data, axis=0) + 1e-8
        self.morpho_mean = np.mean(self.morpho_data, axis=0)
        self.morpho_std = np.std(self.morpho_data, axis=0) + 1e-8
        
        self.gex_data_norm = (self.gex_data - self.gex_mean) / self.gex_std
        self.morpho_data_norm = (self.morpho_data - self.morpho_mean) / self.morpho_std
        
        self.n_samples = len(self.gex_data)
        self.gex_dim = self.gex_data.shape[1]
        self.morpho_dim = self.morpho_data.shape[1]
        
        # Initialize networks
        self._initialize_networks()
        
        # Convert to tensors
        self.gex_tensor = torch.FloatTensor(self.gex_data_norm).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data_norm).to(self.device)
        
        print(f"Data loaded and normalized. Shape: GEX {self.gex_tensor.shape}, Morpho {self.morpho_tensor.shape}")
        
    def _initialize_networks(self):
        """Initialize encoder, decoder, and discriminator networks"""
        
        class Encoder(nn.Module):
            def __init__(self, input_dim, latent_dim, hidden_dims):
                super(Encoder, self).__init__()
                layers = []
                prev_dim = input_dim
                
                for hidden_dim in hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                self.encoder = nn.Sequential(*layers)
                
                # Variational components
                self.mu_layer = nn.Linear(prev_dim, latent_dim)
                self.logvar_layer = nn.Linear(prev_dim, latent_dim)
                
            def forward(self, x):
                h = self.encoder(x)
                mu = self.mu_layer(h)
                logvar = self.logvar_layer(h)
                return mu, logvar
            
            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return mu + eps * std
                else:
                    return mu
        
        class Decoder(nn.Module):
            def __init__(self, latent_dim, output_dim, hidden_dims):
                super(Decoder, self).__init__()
                layers = []
                prev_dim = latent_dim
                
                # Reverse the hidden dimensions for decoder
                reversed_hidden_dims = hidden_dims[::-1]
                
                for hidden_dim in reversed_hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, output_dim))
                self.decoder = nn.Sequential(*layers)
                
            def forward(self, z):
                return self.decoder(z)
        
        class Discriminator(nn.Module):
            def __init__(self, latent_dim):
                super(Discriminator, self).__init__()
                self.discriminator = nn.Sequential(
                    nn.Linear(latent_dim, 128),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(128, 64),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(64, 32),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(32, 1),
                    nn.Sigmoid()
                )
                
            def forward(self, z):
                return self.discriminator(z)
        
        # Initialize networks
        self.encoder_gex = Encoder(self.gex_dim, self.latent_dim, self.hidden_dims).to(self.device)
        self.encoder_morpho = Encoder(self.morpho_dim, self.latent_dim, self.hidden_dims).to(self.device)
        self.decoder_gex = Decoder(self.latent_dim, self.gex_dim, self.hidden_dims).to(self.device)
        self.decoder_morpho = Decoder(self.latent_dim, self.morpho_dim, self.hidden_dims).to(self.device)
        self.discriminator = Discriminator(self.latent_dim).to(self.device)
        
        # Initialize optimizers
        self.optimizer_ae = optim.Adam(
            list(self.encoder_gex.parameters()) + 
            list(self.encoder_morpho.parameters()) +
            list(self.decoder_gex.parameters()) + 
            list(self.decoder_morpho.parameters()),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        self.optimizer_disc = optim.Adam(
            self.discriminator.parameters(),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        print("Networks initialized successfully")
        
    def _kl_divergence(self, mu, logvar):
        """Calculate KL divergence loss"""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    def _reconstruction_loss(self, recon, original):
        """Calculate reconstruction loss"""
        return F.mse_loss(recon, original, reduction='sum')
    
    def train_epoch(self, epoch):
        """Train one epoch"""
        self.encoder_gex.train()
        self.encoder_morpho.train()
        self.decoder_gex.train()
        self.decoder_morpho.train()
        self.discriminator.train()
        
        # Create data indices
        indices = torch.randperm(self.n_samples)
        n_batches = (self.n_samples + self.batch_size - 1) // self.batch_size
        
        total_ae_loss = 0
        total_disc_loss = 0
        
        for batch_idx in range(n_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, self.n_samples)
            batch_indices = indices[start_idx:end_idx]
            
            gex_batch = self.gex_tensor[batch_indices]
            morpho_batch = self.morpho_tensor[batch_indices]
            batch_size_actual = len(batch_indices)
            
            # ==========================================
            # Train Autoencoder (Generators)
            # ==========================================
            self.optimizer_ae.zero_grad()
            
            # Encode
            gex_mu, gex_logvar = self.encoder_gex(gex_batch)
            morpho_mu, morpho_logvar = self.encoder_morpho(morpho_batch)
            
            # Reparameterize
            gex_z = self.encoder_gex.reparameterize(gex_mu, gex_logvar)
            morpho_z = self.encoder_morpho.reparameterize(morpho_mu, morpho_logvar)
            
            # Decode (reconstruction)
            gex_recon = self.decoder_gex(gex_z)
            morpho_recon = self.decoder_morpho(morpho_z)
            
            # Cross-modal reconstruction (transfer learning)
            gex_cross_recon = self.decoder_gex(morpho_z)
            morpho_cross_recon = self.decoder_morpho(gex_z)
            
            # Reconstruction losses
            gex_recon_loss = self._reconstruction_loss(gex_recon, gex_batch)
            morpho_recon_loss = self._reconstruction_loss(morpho_recon, morpho_batch)
            
            # Cross-modal reconstruction losses
            gex_cross_loss = self._reconstruction_loss(gex_cross_recon, gex_batch)
            morpho_cross_loss = self._reconstruction_loss(morpho_cross_recon, morpho_batch)
            
            # KL divergence losses
            gex_kl_loss = self._kl_divergence(gex_mu, gex_logvar)
            morpho_kl_loss = self._kl_divergence(morpho_mu, morpho_logvar)
            
            # Adversarial losses (fool the discriminator)
            gex_disc_fake = self.discriminator(gex_z)
            morpho_disc_fake = self.discriminator(morpho_z)
            
            # Labels for adversarial training
            valid_labels = torch.ones(batch_size_actual, 1).to(self.device)
            
            gex_adv_loss = F.binary_cross_entropy(gex_disc_fake, valid_labels)
            morpho_adv_loss = F.binary_cross_entropy(morpho_disc_fake, valid_labels)
            
            # Total autoencoder loss
            ae_loss = (self.lambda_recon * (gex_recon_loss + morpho_recon_loss + 
                                          gex_cross_loss + morpho_cross_loss) +
                      self.lambda_kl * (gex_kl_loss + morpho_kl_loss) +
                      self.lambda_adv * (gex_adv_loss + morpho_adv_loss))
            
            ae_loss.backward()
            self.optimizer_ae.step()
            
            # ==========================================
            # Train Discriminator
            # ==========================================
            self.optimizer_disc.zero_grad()
            
            # Real samples (from prior distribution)
            real_z = torch.randn(batch_size_actual, self.latent_dim).to(self.device)
            real_labels = torch.ones(batch_size_actual, 1).to(self.device)
            fake_labels = torch.zeros(batch_size_actual, 1).to(self.device)
            
            # Discriminator on real samples
            real_disc = self.discriminator(real_z)
            real_loss = F.binary_cross_entropy(real_disc, real_labels)
            
            # Discriminator on fake samples (detached to avoid gradients to encoder)
            fake_gex_z = self.encoder_gex.reparameterize(gex_mu.detach(), gex_logvar.detach())
            fake_morpho_z = self.encoder_morpho.reparameterize(morpho_mu.detach(), morpho_logvar.detach())
            
            fake_gex_disc = self.discriminator(fake_gex_z)
            fake_morpho_disc = self.discriminator(fake_morpho_z)
            
            fake_gex_loss = F.binary_cross_entropy(fake_gex_disc, fake_labels)
            fake_morpho_loss = F.binary_cross_entropy(fake_morpho_disc, fake_labels)
            
            # Total discriminator loss
            disc_loss = real_loss + (fake_gex_loss + fake_morpho_loss) / 2
            
            disc_loss.backward()
            self.optimizer_disc.step()
            
            total_ae_loss += ae_loss.item()
            total_disc_loss += disc_loss.item()
        
        return total_ae_loss / n_batches, total_disc_loss / n_batches
    
    def get_integrated_embeddings(self):
        """Get integrated embeddings in the shared latent space"""
        self.encoder_gex.eval()
        self.encoder_morpho.eval()
        
        with torch.no_grad():
            gex_mu, _ = self.encoder_gex(self.gex_tensor)
            morpho_mu, _ = self.encoder_morpho(self.morpho_tensor)
            
            # Use mean (mu) as the embedding for deterministic results
            gex_embeddings = gex_mu.cpu().numpy()
            morpho_embeddings = morpho_mu.cpu().numpy()
        
        return gex_embeddings, morpho_embeddings
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """Calculate cell type matching accuracy"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        # Use nearest neighbor search
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        # GEX to Morpho matching accuracy
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        # Morpho to GEX matching accuracy
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        # Average matching accuracy
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, verbose_interval=20):
        """Train scJoint model"""
        print("Starting scJoint training...")
        
        for epoch in tqdm(range(self.max_epochs), desc="Training scJoint"):
            ae_loss, disc_loss = self.train_epoch(epoch)
            
            # Calculate accuracy periodically
            if (epoch + 1) % verbose_interval == 0:
                gex_emb, morpho_emb = self.get_integrated_embeddings()
                gex2morpho_acc, morpho2gex_acc, avg_acc = self.calculate_celltype_accuracy(gex_emb, morpho_emb)
                
                print(f"\nEpoch {epoch+1}/{self.max_epochs}")
                print(f"  Autoencoder Loss: {ae_loss:.4f}")
                print(f"  Discriminator Loss: {disc_loss:.4f}")
                if avg_acc is not None:
                    print(f"  GEX->Morpho Acc: {gex2morpho_acc:.4f}")
                    print(f"  Morpho->GEX Acc: {morpho2gex_acc:.4f}")
                    print(f"  Average Acc: {avg_acc:.4f}")
                print("-" * 50)

# Usage example
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scJoint/"
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize scJoint integrator
    integrator = scJointIntegrator(
        latent_dim=64,
        hidden_dims=[512, 256],
        lambda_adv=1.0,
        lambda_recon=10.0,
        lambda_kl=1.0,
        max_epochs=200,
        batch_size=64,
        lr=0.001
    )
    
    # Load data
    integrator.load_data()
    
    # Train model
    integrator.train(verbose_interval=20)
    
    # Get final integrated embeddings
    print("\nGetting latent space embeddings...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # Calculate final accuracy
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL scJOINT RESULTS:")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
Networks initialized successfully
Data loaded and normalized. Shape: GEX torch.Size([645, 2000]), Morpho torch.Size([645, 645])
Starting scJoint training...


Training scJoint:  10%|█         | 20/200 [00:28<04:04,  1.36s/it]


Epoch 20/200
  Autoencoder Loss: 2271037.8551
  Discriminator Loss: 0.2504
  GEX->Morpho Acc: 0.3333
  Morpho->GEX Acc: 0.3364
  Average Acc: 0.3349
--------------------------------------------------


Training scJoint:  20%|██        | 40/200 [00:52<03:18,  1.24s/it]


Epoch 40/200
  Autoencoder Loss: 2196729.4290
  Discriminator Loss: 0.1589
  GEX->Morpho Acc: 0.3969
  Morpho->GEX Acc: 0.4202
  Average Acc: 0.4085
--------------------------------------------------


Training scJoint:  30%|███       | 60/200 [01:20<03:34,  1.53s/it]


Epoch 60/200
  Autoencoder Loss: 2157874.6080
  Discriminator Loss: 0.0517
  GEX->Morpho Acc: 0.4295
  Morpho->GEX Acc: 0.4946
  Average Acc: 0.4620
--------------------------------------------------


Training scJoint:  40%|████      | 80/200 [01:53<03:17,  1.65s/it]


Epoch 80/200
  Autoencoder Loss: 2119367.9517
  Discriminator Loss: 0.0912
  GEX->Morpho Acc: 0.3829
  Morpho->GEX Acc: 0.5023
  Average Acc: 0.4426
--------------------------------------------------


Training scJoint:  50%|█████     | 100/200 [02:20<02:24,  1.44s/it]


Epoch 100/200
  Autoencoder Loss: 2077106.3750
  Discriminator Loss: 0.0685
  GEX->Morpho Acc: 0.4062
  Morpho->GEX Acc: 0.4574
  Average Acc: 0.4318
--------------------------------------------------


Training scJoint:  60%|██████    | 120/200 [02:43<01:46,  1.33s/it]


Epoch 120/200
  Autoencoder Loss: 2026913.9219
  Discriminator Loss: 0.0414
  GEX->Morpho Acc: 0.4248
  Morpho->GEX Acc: 0.4636
  Average Acc: 0.4442
--------------------------------------------------


Training scJoint:  70%|███████   | 140/200 [03:09<01:27,  1.46s/it]


Epoch 140/200
  Autoencoder Loss: 1997084.4105
  Discriminator Loss: 0.0722
  GEX->Morpho Acc: 0.4248
  Morpho->GEX Acc: 0.4589
  Average Acc: 0.4419
--------------------------------------------------


Training scJoint:  80%|████████  | 160/200 [03:34<00:55,  1.39s/it]


Epoch 160/200
  Autoencoder Loss: 1975421.4432
  Discriminator Loss: 0.0476
  GEX->Morpho Acc: 0.4171
  Morpho->GEX Acc: 0.4837
  Average Acc: 0.4504
--------------------------------------------------


Training scJoint:  90%|█████████ | 180/200 [03:58<00:21,  1.08s/it]


Epoch 180/200
  Autoencoder Loss: 1929905.3295
  Discriminator Loss: 0.0930
  GEX->Morpho Acc: 0.4016
  Morpho->GEX Acc: 0.4884
  Average Acc: 0.4450
--------------------------------------------------


Training scJoint: 100%|██████████| 200/200 [04:24<00:00,  1.32s/it]


Epoch 200/200
  Autoencoder Loss: 1897902.4403
  Discriminator Loss: 0.0491
  GEX->Morpho Acc: 0.4062
  Morpho->GEX Acc: 0.4930
  Average Acc: 0.4496
--------------------------------------------------

Getting latent space embeddings...

FINAL scJOINT RESULTS:
GEX -> Morpho Accuracy: 0.4062
Morpho -> GEX Accuracy: 0.4930
Average Accuracy: 0.4496

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...





  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scJoint/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scJoint/morpho_umap.csv

Done!


## SciCAN

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import umap
from tqdm import tqdm
import os

class sciCANIntegrator:
    """
    sciCAN implementation for multi-modal data integration
    Based on the sciCAN algorithm using Canonical Correlation Analysis and 
    Adversarial Networks for single-cell multi-omics integration
    """
    
    def __init__(self, cca_dim=50, adversarial_dim=128, hidden_dims=[512, 256], 
                 lambda_cca=1.0, lambda_adv=1.0, lambda_recon=10.0, lambda_cycle=5.0,
                 max_epochs=200, batch_size=64, lr=0.001, device='cuda'):
        """
        Initialize sciCAN integrator
        
        Args:
            cca_dim: Dimension of CCA space
            adversarial_dim: Dimension of adversarial feature space
            hidden_dims: Hidden layer dimensions
            lambda_cca: Weight for CCA loss
            lambda_adv: Weight for adversarial loss
            lambda_recon: Weight for reconstruction loss
            lambda_cycle: Weight for cycle consistency loss
            max_epochs: Maximum training epochs
            batch_size: Training batch size
            lr: Learning rate
            device: Computing device
        """
        self.cca_dim = cca_dim
        self.adversarial_dim = adversarial_dim
        self.hidden_dims = hidden_dims
        self.lambda_cca = lambda_cca
        self.lambda_adv = lambda_adv
        self.lambda_recon = lambda_recon
        self.lambda_cycle = lambda_cycle
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Networks will be initialized after loading data
        self.encoder_gex = None
        self.encoder_morpho = None
        self.decoder_gex = None
        self.decoder_morpho = None
        self.cca_projector_gex = None
        self.cca_projector_morpho = None
        self.discriminator = None
        
    def load_data(self):
        """Load multi-modal data"""
        # Gene expression data
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # Morphology data
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family labels
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            print(f"RNA family labels loaded: {len(np.unique(self.rna_family_labels))} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
        
        # Data normalization
        self.gex_mean = np.mean(self.gex_data, axis=0)
        self.gex_std = np.std(self.gex_data, axis=0) + 1e-8
        self.morpho_mean = np.mean(self.morpho_data, axis=0)
        self.morpho_std = np.std(self.morpho_data, axis=0) + 1e-8
        
        self.gex_data_norm = (self.gex_data - self.gex_mean) / self.gex_std
        self.morpho_data_norm = (self.morpho_data - self.morpho_mean) / self.morpho_std
        
        self.n_samples = len(self.gex_data)
        self.gex_dim = self.gex_data.shape[1]
        self.morpho_dim = self.morpho_data.shape[1]
        
        # Initialize networks
        self._initialize_networks()
        
        # Convert to tensors
        self.gex_tensor = torch.FloatTensor(self.gex_data_norm).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data_norm).to(self.device)
        
        print(f"Data loaded and normalized. Shape: GEX {self.gex_tensor.shape}, Morpho {self.morpho_tensor.shape}")
        
    def _initialize_networks(self):
        """Initialize neural networks for sciCAN"""
        
        class Encoder(nn.Module):
            def __init__(self, input_dim, hidden_dims, output_dim):
                super(Encoder, self).__init__()
                layers = []
                prev_dim = input_dim
                
                for hidden_dim in hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, output_dim))
                layers.append(nn.Tanh())
                self.encoder = nn.Sequential(*layers)
                
            def forward(self, x):
                return self.encoder(x)
        
        class Decoder(nn.Module):
            def __init__(self, input_dim, hidden_dims, output_dim):
                super(Decoder, self).__init__()
                layers = []
                prev_dim = input_dim
                
                reversed_hidden_dims = hidden_dims[::-1]
                
                for hidden_dim in reversed_hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, output_dim))
                self.decoder = nn.Sequential(*layers)
                
            def forward(self, x):
                return self.decoder(x)
        
        class CCAProjector(nn.Module):
            def __init__(self, input_dim, cca_dim):
                super(CCAProjector, self).__init__()
                self.projector = nn.Sequential(
                    nn.Linear(input_dim, cca_dim),
                    nn.BatchNorm1d(cca_dim)
                )
                
            def forward(self, x):
                return self.projector(x)
        
        class Discriminator(nn.Module):
            def __init__(self, input_dim):
                super(Discriminator, self).__init__()
                self.discriminator = nn.Sequential(
                    nn.Linear(input_dim, 256),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(256, 128),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(128, 64),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(64, 1),
                    nn.Sigmoid()
                )
                
            def forward(self, x):
                return self.discriminator(x)
        
        # Initialize networks
        self.encoder_gex = Encoder(self.gex_dim, self.hidden_dims, self.adversarial_dim).to(self.device)
        self.encoder_morpho = Encoder(self.morpho_dim, self.hidden_dims, self.adversarial_dim).to(self.device)
        
        self.decoder_gex = Decoder(self.adversarial_dim, self.hidden_dims, self.gex_dim).to(self.device)
        self.decoder_morpho = Decoder(self.adversarial_dim, self.hidden_dims, self.morpho_dim).to(self.device)
        
        self.cca_projector_gex = CCAProjector(self.adversarial_dim, self.cca_dim).to(self.device)
        self.cca_projector_morpho = CCAProjector(self.adversarial_dim, self.cca_dim).to(self.device)
        
        self.discriminator = Discriminator(self.adversarial_dim).to(self.device)
        
        # Initialize optimizers
        self.optimizer_generators = optim.Adam(
            list(self.encoder_gex.parameters()) + 
            list(self.encoder_morpho.parameters()) +
            list(self.decoder_gex.parameters()) + 
            list(self.decoder_morpho.parameters()) +
            list(self.cca_projector_gex.parameters()) +
            list(self.cca_projector_morpho.parameters()),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        self.optimizer_discriminator = optim.Adam(
            self.discriminator.parameters(),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        print("sciCAN networks initialized successfully")
        
    def _canonical_correlation_loss(self, h1, h2):
        """Calculate Canonical Correlation Analysis loss"""
        h1_centered = h1 - torch.mean(h1, dim=0, keepdim=True)
        h2_centered = h2 - torch.mean(h2, dim=0, keepdim=True)
        
        n = h1.size(0)
        c11 = torch.mm(h1_centered.t(), h1_centered) / (n - 1)
        c22 = torch.mm(h2_centered.t(), h2_centered) / (n - 1)
        c12 = torch.mm(h1_centered.t(), h2_centered) / (n - 1)
        
        eps = 1e-4
        c11 = c11 + eps * torch.eye(c11.size(0)).to(self.device)
        c22 = c22 + eps * torch.eye(c22.size(0)).to(self.device)
        
        correlation = torch.trace(c12) / (torch.sqrt(torch.trace(c11)) * torch.sqrt(torch.trace(c22)) + 1e-8)
        
        return -correlation
    
    def _reconstruction_loss(self, recon, original):
        """Calculate reconstruction loss"""
        return F.mse_loss(recon, original)
    
    def _cycle_consistency_loss(self, cycle_recon, original):
        """Calculate cycle consistency loss"""
        return F.l1_loss(cycle_recon, original)
    
    def train_epoch(self, epoch):
        """Train one epoch"""
        self.encoder_gex.train()
        self.encoder_morpho.train()
        self.decoder_gex.train()
        self.decoder_morpho.train()
        self.cca_projector_gex.train()
        self.cca_projector_morpho.train()
        self.discriminator.train()
        
        indices = torch.randperm(self.n_samples)
        n_batches = (self.n_samples + self.batch_size - 1) // self.batch_size
        
        total_gen_loss = 0
        total_disc_loss = 0
        
        for batch_idx in range(n_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, self.n_samples)
            batch_indices = indices[start_idx:end_idx]
            
            gex_batch = self.gex_tensor[batch_indices]
            morpho_batch = self.morpho_tensor[batch_indices]
            batch_size_actual = len(batch_indices)
            
            valid = torch.ones(batch_size_actual, 1).to(self.device)
            fake = torch.zeros(batch_size_actual, 1).to(self.device)
            
            # Train Generators
            self.optimizer_generators.zero_grad()
            
            gex_features = self.encoder_gex(gex_batch)
            morpho_features = self.encoder_morpho(morpho_batch)
            
            gex_cca = self.cca_projector_gex(gex_features)
            morpho_cca = self.cca_projector_morpho(morpho_features)
            
            gex_recon = self.decoder_gex(gex_features)
            morpho_recon = self.decoder_morpho(morpho_features)
            
            gex_recon_loss = self._reconstruction_loss(gex_recon, gex_batch)
            morpho_recon_loss = self._reconstruction_loss(morpho_recon, morpho_batch)
            
            gex_from_morpho = self.decoder_gex(morpho_features)
            morpho_from_gex = self.decoder_morpho(gex_features)
            
            morpho_features_cycle = self.encoder_morpho(morpho_from_gex)
            gex_features_cycle = self.encoder_gex(gex_from_morpho)
            
            gex_cycle_recon = self.decoder_gex(gex_features_cycle)
            morpho_cycle_recon = self.decoder_morpho(morpho_features_cycle)
            
            gex_cycle_loss = self._cycle_consistency_loss(gex_cycle_recon, gex_batch)
            morpho_cycle_loss = self._cycle_consistency_loss(morpho_cycle_recon, morpho_batch)
            
            cca_loss = self._canonical_correlation_loss(gex_cca, morpho_cca)
            
            gex_disc_fake = self.discriminator(gex_features)
            morpho_disc_fake = self.discriminator(morpho_features)
            
            gex_adv_loss = F.binary_cross_entropy(gex_disc_fake, valid)
            morpho_adv_loss = F.binary_cross_entropy(morpho_disc_fake, valid)
            
            gen_loss = (self.lambda_recon * (gex_recon_loss + morpho_recon_loss) +
                       self.lambda_cycle * (gex_cycle_loss + morpho_cycle_loss) +
                       self.lambda_cca * cca_loss +
                       self.lambda_adv * (gex_adv_loss + morpho_adv_loss))
            
            gen_loss.backward()
            self.optimizer_generators.step()
            
            # Train Discriminator
            self.optimizer_discriminator.zero_grad()
            
            real_features = torch.randn(batch_size_actual, self.adversarial_dim).to(self.device)
            
            real_pred = self.discriminator(real_features)
            real_loss = F.binary_cross_entropy(real_pred, valid)
            
            fake_gex_features = self.encoder_gex(gex_batch).detach()
            fake_morpho_features = self.encoder_morpho(morpho_batch).detach()
            
            fake_gex_pred = self.discriminator(fake_gex_features)
            fake_morpho_pred = self.discriminator(fake_morpho_features)
            
            fake_gex_loss = F.binary_cross_entropy(fake_gex_pred, fake)
            fake_morpho_loss = F.binary_cross_entropy(fake_morpho_pred, fake)
            
            disc_loss = real_loss + (fake_gex_loss + fake_morpho_loss) / 2
            
            disc_loss.backward()
            self.optimizer_discriminator.step()
            
            total_gen_loss += gen_loss.item()
            total_disc_loss += disc_loss.item()
        
        return total_gen_loss / n_batches, total_disc_loss / n_batches
    
    def get_integrated_embeddings(self):
        """Get integrated embeddings in the CCA space"""
        self.encoder_gex.eval()
        self.encoder_morpho.eval()
        self.cca_projector_gex.eval()
        self.cca_projector_morpho.eval()
        
        with torch.no_grad():
            gex_features = self.encoder_gex(self.gex_tensor)
            morpho_features = self.encoder_morpho(self.morpho_tensor)
            
            gex_embeddings = self.cca_projector_gex(gex_features).cpu().numpy()
            morpho_embeddings = self.cca_projector_morpho(morpho_features).cpu().numpy()
        
        return gex_embeddings, morpho_embeddings
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """Calculate cell type matching accuracy"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, verbose_interval=20):
        """Train sciCAN model"""
        print("Starting sciCAN training...")
        
        for epoch in tqdm(range(self.max_epochs), desc="Training sciCAN"):
            gen_loss, disc_loss = self.train_epoch(epoch)
            
            if (epoch + 1) % verbose_interval == 0:
                gex_emb_cca, morpho_emb_cca = self.get_integrated_embeddings()
                gex2morpho_acc_cca, morpho2gex_acc_cca, avg_acc_cca = self.calculate_celltype_accuracy(
                    gex_emb_cca, morpho_emb_cca)
                
                print(f"\nEpoch {epoch+1}/{self.max_epochs}")
                print(f"  Generator Loss: {gen_loss:.4f}")
                print(f"  Discriminator Loss: {disc_loss:.4f}")
                if avg_acc_cca is not None:
                    print(f"  CCA Space - GEX->Morpho: {gex2morpho_acc_cca:.4f}, Morpho->GEX: {morpho2gex_acc_cca:.4f}, Avg: {avg_acc_cca:.4f}")
                print("-" * 50)

# Usage example
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/sciCAN/"
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize sciCAN integrator
    integrator = sciCANIntegrator(
        cca_dim=50,
        adversarial_dim=128,
        hidden_dims=[512, 256],
        lambda_cca=1.0,
        lambda_adv=1.0,
        lambda_recon=10.0,
        lambda_cycle=5.0,
        max_epochs=200,
        batch_size=64,
        lr=0.001
    )
    
    # Load data
    integrator.load_data()
    
    # Train model
    integrator.train(verbose_interval=20)
    
    # Get final integrated embeddings (CCA space)
    print("\nGetting latent space embeddings (CCA space)...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # Calculate final accuracy
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL sciCAN RESULTS (CCA Space):")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
sciCAN networks initialized successfully
Data loaded and normalized. Shape: GEX torch.Size([645, 2000]), Morpho torch.Size([645, 645])
Starting sciCAN training...


Training sciCAN:   0%|          | 1/200 [00:02<06:53,  2.08s/it]

Training sciCAN:  10%|█         | 20/200 [00:46<06:56,  2.32s/it]


Epoch 20/200
  Generator Loss: 23.8070
  Discriminator Loss: 0.3552
  CCA Space - GEX->Morpho: 0.2822, Morpho->GEX: 0.3085, Avg: 0.2953
--------------------------------------------------


Training sciCAN:  20%|██        | 40/200 [01:30<05:31,  2.07s/it]


Epoch 40/200
  Generator Loss: 22.9078
  Discriminator Loss: 0.2704
  CCA Space - GEX->Morpho: 0.3256, Morpho->GEX: 0.3349, Avg: 0.3302
--------------------------------------------------


Training sciCAN:  30%|███       | 60/200 [02:17<05:26,  2.33s/it]


Epoch 60/200
  Generator Loss: 21.6729
  Discriminator Loss: 0.2382
  CCA Space - GEX->Morpho: 0.2496, Morpho->GEX: 0.2388, Avg: 0.2442
--------------------------------------------------


Training sciCAN:  40%|████      | 80/200 [03:01<04:37,  2.31s/it]


Epoch 80/200
  Generator Loss: 22.2730
  Discriminator Loss: 0.2998
  CCA Space - GEX->Morpho: 0.2403, Morpho->GEX: 0.2682, Avg: 0.2543
--------------------------------------------------


Training sciCAN:  50%|█████     | 100/200 [03:45<03:26,  2.06s/it]


Epoch 100/200
  Generator Loss: 22.0953
  Discriminator Loss: 0.2922
  CCA Space - GEX->Morpho: 0.2605, Morpho->GEX: 0.2667, Avg: 0.2636
--------------------------------------------------


Training sciCAN:  60%|██████    | 120/200 [04:29<03:01,  2.27s/it]


Epoch 120/200
  Generator Loss: 24.2103
  Discriminator Loss: 0.2308
  CCA Space - GEX->Morpho: 0.2140, Morpho->GEX: 0.3457, Avg: 0.2798
--------------------------------------------------


Training sciCAN:  70%|███████   | 140/200 [05:14<02:15,  2.25s/it]


Epoch 140/200
  Generator Loss: 21.4291
  Discriminator Loss: 0.3290
  CCA Space - GEX->Morpho: 0.2388, Morpho->GEX: 0.2481, Avg: 0.2434
--------------------------------------------------


Training sciCAN:  80%|████████  | 160/200 [05:58<01:38,  2.47s/it]


Epoch 160/200
  Generator Loss: 22.7924
  Discriminator Loss: 0.3033
  CCA Space - GEX->Morpho: 0.2698, Morpho->GEX: 0.2620, Avg: 0.2659
--------------------------------------------------


Training sciCAN:  90%|█████████ | 180/200 [06:38<00:43,  2.16s/it]


Epoch 180/200
  Generator Loss: 23.6537
  Discriminator Loss: 0.3259
  CCA Space - GEX->Morpho: 0.2837, Morpho->GEX: 0.2527, Avg: 0.2682
--------------------------------------------------


Training sciCAN: 100%|██████████| 200/200 [07:21<00:00,  2.21s/it]


Epoch 200/200
  Generator Loss: 23.0546
  Discriminator Loss: 0.2644
  CCA Space - GEX->Morpho: 0.2264, Morpho->GEX: 0.2884, Avg: 0.2574
--------------------------------------------------

Getting latent space embeddings (CCA space)...






FINAL sciCAN RESULTS (CCA Space):
GEX -> Morpho Accuracy: 0.2264
Morpho -> GEX Accuracy: 0.2884
Average Accuracy: 0.2574

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...
  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/sciCAN/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/sciCAN/morpho_umap.csv

Done!


## ScDART

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import umap
from tqdm import tqdm
import os

class scDARTIntegrator:
    """
    scDART implementation for multi-modal data integration
    Based on the scDART algorithm using domain adversarial neural networks
    for single-cell multi-omics data integration
    """
    
    def __init__(self, latent_dim=128, feature_dims=[512, 256], domain_dims=[256, 128], 
                 lambda_domain=1.0, lambda_recon=1.0, lambda_cluster=0.1, lambda_entropy=0.01,
                 max_epochs=200, batch_size=64, lr=0.001, gradient_reversal_alpha=1.0, device='cuda'):
        """
        Initialize scDART integrator
        
        Args:
            latent_dim: Dimension of the shared feature space
            feature_dims: Hidden dimensions for feature extractor
            domain_dims: Hidden dimensions for domain classifier
            lambda_domain: Weight for domain adversarial loss
            lambda_recon: Weight for reconstruction loss
            lambda_cluster: Weight for clustering loss
            lambda_entropy: Weight for entropy loss
            max_epochs: Maximum training epochs
            batch_size: Training batch size
            lr: Learning rate
            gradient_reversal_alpha: Strength of gradient reversal
            device: Computing device
        """
        self.latent_dim = latent_dim
        self.feature_dims = feature_dims
        self.domain_dims = domain_dims
        self.lambda_domain = lambda_domain
        self.lambda_recon = lambda_recon
        self.lambda_cluster = lambda_cluster
        self.lambda_entropy = lambda_entropy
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.gradient_reversal_alpha = gradient_reversal_alpha
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Networks will be initialized after loading data
        self.feature_extractor = None
        self.domain_classifier = None
        self.decoder_gex = None
        self.decoder_morpho = None
        
    def load_data(self):
        """Load multi-modal data"""
        # Gene expression data
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # Morphology data
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family labels
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            # Convert labels to numeric for clustering loss
            unique_labels = np.unique(self.rna_family_labels)
            self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
            self.numeric_labels = np.array([self.label_to_idx[label] for label in self.rna_family_labels])
            self.n_classes = len(unique_labels)
            
            print(f"RNA family labels loaded: {self.n_classes} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
            self.n_classes = 10  # Default number of clusters
        
        # Data normalization
        self.gex_mean = np.mean(self.gex_data, axis=0)
        self.gex_std = np.std(self.gex_data, axis=0) + 1e-8
        self.morpho_mean = np.mean(self.morpho_data, axis=0)
        self.morpho_std = np.std(self.morpho_data, axis=0) + 1e-8
        
        self.gex_data_norm = (self.gex_data - self.gex_mean) / self.gex_std
        self.morpho_data_norm = (self.morpho_data - self.morpho_mean) / self.morpho_std
        
        self.n_samples = len(self.gex_data)
        self.gex_dim = self.gex_data.shape[1]
        self.morpho_dim = self.morpho_data.shape[1]
        
        # Initialize networks
        self._initialize_networks()
        
        # Convert to tensors
        self.gex_tensor = torch.FloatTensor(self.gex_data_norm).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data_norm).to(self.device)
        
        # Domain labels (0 for GEX, 1 for Morpho)
        self.domain_labels_gex = torch.zeros(self.n_samples, dtype=torch.long).to(self.device)
        self.domain_labels_morpho = torch.ones(self.n_samples, dtype=torch.long).to(self.device)
        
        print(f"Data loaded and normalized. Shape: GEX {self.gex_tensor.shape}, Morpho {self.morpho_tensor.shape}")
        
    def _initialize_networks(self):
        """Initialize neural networks"""
        
        class GradientReversalLayer(torch.autograd.Function):
            """Gradient Reversal Layer for domain adversarial training"""
            @staticmethod
            def forward(ctx, x, alpha):
                ctx.alpha = alpha
                return x.view_as(x)
            
            @staticmethod
            def backward(ctx, grad_output):
                return grad_output.neg() * ctx.alpha, None
        
        def grad_reverse(x, alpha):
            return GradientReversalLayer.apply(x, alpha)
        
        class FeatureExtractor(nn.Module):
            def __init__(self, gex_dim, morpho_dim, latent_dim, hidden_dims):
                super(FeatureExtractor, self).__init__()
                
                # Separate encoders for each modality
                self.gex_encoder = self._build_encoder(gex_dim, hidden_dims, latent_dim)
                self.morpho_encoder = self._build_encoder(morpho_dim, hidden_dims, latent_dim)
                
            def _build_encoder(self, input_dim, hidden_dims, output_dim):
                layers = []
                prev_dim = input_dim
                
                for hidden_dim in hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, output_dim))
                return nn.Sequential(*layers)
            
            def forward(self, x, modality):
                if modality == 'gex':
                    return self.gex_encoder(x)
                elif modality == 'morpho':
                    return self.morpho_encoder(x)
                else:
                    raise ValueError("Modality must be 'gex' or 'morpho'")
        
        class DomainClassifier(nn.Module):
            def __init__(self, latent_dim, hidden_dims):
                super(DomainClassifier, self).__init__()
                layers = []
                prev_dim = latent_dim
                
                for hidden_dim in hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.3)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, 2))  # Binary classification: GEX vs Morpho
                self.classifier = nn.Sequential(*layers)
                
            def forward(self, x, alpha):
                x = grad_reverse(x, alpha)
                return self.classifier(x)
        
        class Decoder(nn.Module):
            def __init__(self, latent_dim, output_dim, hidden_dims):
                super(Decoder, self).__init__()
                layers = []
                prev_dim = latent_dim
                
                # Reverse hidden dimensions for decoder
                reversed_hidden_dims = hidden_dims[::-1]
                
                for hidden_dim in reversed_hidden_dims:
                    layers.extend([
                        nn.Linear(prev_dim, hidden_dim),
                        nn.BatchNorm1d(hidden_dim),
                        nn.ReLU(),
                        nn.Dropout(0.2)
                    ])
                    prev_dim = hidden_dim
                
                layers.append(nn.Linear(prev_dim, output_dim))
                self.decoder = nn.Sequential(*layers)
                
            def forward(self, z):
                return self.decoder(z)
        
        class ClusteringHead(nn.Module):
            def __init__(self, latent_dim, n_clusters):
                super(ClusteringHead, self).__init__()
                self.cluster_head = nn.Sequential(
                    nn.Linear(latent_dim, 128),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(128, n_clusters),
                    nn.Softmax(dim=1)
                )
                
            def forward(self, x):
                return self.cluster_head(x)
        
        # Initialize networks
        self.feature_extractor = FeatureExtractor(
            self.gex_dim, self.morpho_dim, self.latent_dim, self.feature_dims
        ).to(self.device)
        
        self.domain_classifier = DomainClassifier(
            self.latent_dim, self.domain_dims
        ).to(self.device)
        
        self.decoder_gex = Decoder(
            self.latent_dim, self.gex_dim, self.feature_dims
        ).to(self.device)
        
        self.decoder_morpho = Decoder(
            self.latent_dim, self.morpho_dim, self.feature_dims
        ).to(self.device)
        
        self.clustering_head = ClusteringHead(
            self.latent_dim, self.n_classes
        ).to(self.device)
        
        # Store gradient reversal function
        self.grad_reverse = grad_reverse
        
        # Initialize optimizers
        self.optimizer_main = optim.Adam(
            list(self.feature_extractor.parameters()) + 
            list(self.decoder_gex.parameters()) + 
            list(self.decoder_morpho.parameters()) +
            list(self.clustering_head.parameters()),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        self.optimizer_domain = optim.Adam(
            self.domain_classifier.parameters(),
            lr=self.lr, betas=(0.5, 0.999)
        )
        
        print("scDART networks initialized successfully")
        
    def _clustering_loss(self, features, target_distribution=None):
        """Calculate clustering loss using target distribution"""
        cluster_probs = self.clustering_head(features)
        
        if target_distribution is None:
            # Use uniform distribution as target
            target_distribution = torch.ones_like(cluster_probs) / self.n_classes
        
        # KL divergence loss
        kl_loss = F.kl_div(
            torch.log(cluster_probs + 1e-8), 
            target_distribution, 
            reduction='batchmean'
        )
        
        return kl_loss
    
    def _entropy_loss(self, features):
        """Calculate entropy loss to encourage confident predictions"""
        cluster_probs = self.clustering_head(features)
        entropy = -torch.sum(cluster_probs * torch.log(cluster_probs + 1e-8), dim=1)
        return torch.mean(entropy)
    
    def train_epoch(self, epoch):
        """Train one epoch"""
        self.feature_extractor.train()
        self.domain_classifier.train()
        self.decoder_gex.train()
        self.decoder_morpho.train()
        self.clustering_head.train()
        
        # Create data indices
        indices = torch.randperm(self.n_samples)
        n_batches = (self.n_samples + self.batch_size - 1) // self.batch_size
        
        total_main_loss = 0
        total_domain_loss = 0
        
        # Adaptive alpha for gradient reversal
        p = float(epoch) / self.max_epochs
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
        alpha *= self.gradient_reversal_alpha
        
        for batch_idx in range(n_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, self.n_samples)
            batch_indices = indices[start_idx:end_idx]
            
            gex_batch = self.gex_tensor[batch_indices]
            morpho_batch = self.morpho_tensor[batch_indices]
            domain_labels_gex_batch = self.domain_labels_gex[batch_indices]
            domain_labels_morpho_batch = self.domain_labels_morpho[batch_indices]
            
            batch_size_actual = len(batch_indices)
            
            # Train Main Networks
            self.optimizer_main.zero_grad()
            
            # Extract features
            gex_features = self.feature_extractor(gex_batch, 'gex')
            morpho_features = self.feature_extractor(morpho_batch, 'morpho')
            
            # Reconstruction
            gex_recon = self.decoder_gex(gex_features)
            morpho_recon = self.decoder_morpho(morpho_features)
            
            # Cross-modal reconstruction
            gex_cross_recon = self.decoder_gex(morpho_features)
            morpho_cross_recon = self.decoder_morpho(gex_features)
            
            # Reconstruction losses
            gex_recon_loss = F.mse_loss(gex_recon, gex_batch)
            morpho_recon_loss = F.mse_loss(morpho_recon, morpho_batch)
            gex_cross_loss = F.mse_loss(gex_cross_recon, gex_batch)
            morpho_cross_loss = F.mse_loss(morpho_cross_recon, morpho_batch)
            
            total_recon_loss = (gex_recon_loss + morpho_recon_loss + 
                              gex_cross_loss + morpho_cross_loss)
            
            # Domain adversarial loss
            all_features = torch.cat([gex_features, morpho_features], dim=0)
            all_domain_labels = torch.cat([domain_labels_gex_batch, domain_labels_morpho_batch], dim=0)
            
            domain_pred = self.domain_classifier(all_features, alpha)
            domain_adv_loss = F.cross_entropy(domain_pred, all_domain_labels)
            
            # Clustering losses
            gex_cluster_loss = self._clustering_loss(gex_features)
            morpho_cluster_loss = self._clustering_loss(morpho_features)
            total_cluster_loss = gex_cluster_loss + morpho_cluster_loss
            
            # Entropy losses
            gex_entropy_loss = self._entropy_loss(gex_features)
            morpho_entropy_loss = self._entropy_loss(morpho_features)
            total_entropy_loss = gex_entropy_loss + morpho_entropy_loss
            
            # Total main loss
            main_loss = (self.lambda_recon * total_recon_loss +
                        self.lambda_domain * domain_adv_loss +
                        self.lambda_cluster * total_cluster_loss +
                        self.lambda_entropy * total_entropy_loss)
            
            main_loss.backward()
            self.optimizer_main.step()
            
            # Train Domain Classifier
            self.optimizer_domain.zero_grad()
            
            gex_features_detached = self.feature_extractor(gex_batch, 'gex').detach()
            morpho_features_detached = self.feature_extractor(morpho_batch, 'morpho').detach()
            
            all_features_detached = torch.cat([gex_features_detached, morpho_features_detached], dim=0)
            
            domain_pred_real = self.domain_classifier(all_features_detached, 0.0)
            domain_loss = F.cross_entropy(domain_pred_real, all_domain_labels)
            
            domain_loss.backward()
            self.optimizer_domain.step()
            
            total_main_loss += main_loss.item()
            total_domain_loss += domain_loss.item()
        
        return total_main_loss / n_batches, total_domain_loss / n_batches
    
    def get_integrated_embeddings(self):
        """Get integrated embeddings in the shared feature space"""
        self.feature_extractor.eval()
        
        with torch.no_grad():
            gex_embeddings = self.feature_extractor(self.gex_tensor, 'gex').cpu().numpy()
            morpho_embeddings = self.feature_extractor(self.morpho_tensor, 'morpho').cpu().numpy()
        
        return gex_embeddings, morpho_embeddings
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """Calculate cell type matching accuracy"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, verbose_interval=20):
        """Train scDART model"""
        print("Starting scDART training...")
        
        for epoch in tqdm(range(self.max_epochs), desc="Training scDART"):
            main_loss, domain_loss = self.train_epoch(epoch)
            
            if (epoch + 1) % verbose_interval == 0:
                gex_emb, morpho_emb = self.get_integrated_embeddings()
                gex2morpho_acc, morpho2gex_acc, avg_acc = self.calculate_celltype_accuracy(gex_emb, morpho_emb)
                
                print(f"\nEpoch {epoch+1}/{self.max_epochs}")
                print(f"  Main Loss: {main_loss:.4f}")
                print(f"  Domain Loss: {domain_loss:.4f}")
                if avg_acc is not None:
                    print(f"  GEX->Morpho Acc: {gex2morpho_acc:.4f}")
                    print(f"  Morpho->GEX Acc: {morpho2gex_acc:.4f}")
                    print(f"  Average Acc: {avg_acc:.4f}")
                print("-" * 50)

# Usage example
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scDART/"
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize scDART integrator
    integrator = scDARTIntegrator(
        latent_dim=128,
        feature_dims=[512, 256],
        domain_dims=[256, 128],
        lambda_domain=1.0,
        lambda_recon=1.0,
        lambda_cluster=0.1,
        lambda_entropy=0.01,
        max_epochs=200,
        batch_size=64,
        lr=0.001,
        gradient_reversal_alpha=1.0
    )
    
    # Load data
    integrator.load_data()
    
    # Train model
    integrator.train(verbose_interval=20)
    
    # Get final integrated embeddings
    print("\nGetting latent space embeddings...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # Calculate final accuracy
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL scDART RESULTS:")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
scDART networks initialized successfully
Data loaded and normalized. Shape: GEX torch.Size([645, 2000]), Morpho torch.Size([645, 645])
Starting scDART training...


Training scDART:  10%|█         | 20/200 [00:26<04:02,  1.35s/it]


Epoch 20/200
  Main Loss: 2.8273
  Domain Loss: 0.7019
  GEX->Morpho Acc: 0.2202
  Morpho->GEX Acc: 0.3318
  Average Acc: 0.2760
--------------------------------------------------


Training scDART:  14%|█▍        | 28/200 [00:38<04:06,  1.43s/it]

Training scDART:  20%|██        | 40/200 [00:55<03:47,  1.42s/it]


Epoch 40/200
  Main Loss: 2.6777
  Domain Loss: 0.6097
  GEX->Morpho Acc: 0.2791
  Morpho->GEX Acc: 0.3054
  Average Acc: 0.2922
--------------------------------------------------


Training scDART:  30%|███       | 60/200 [01:21<03:18,  1.42s/it]


Epoch 60/200
  Main Loss: 2.7183
  Domain Loss: 0.6246
  GEX->Morpho Acc: 0.2698
  Morpho->GEX Acc: 0.2946
  Average Acc: 0.2822
--------------------------------------------------


Training scDART:  40%|████      | 80/200 [01:49<02:36,  1.30s/it]


Epoch 80/200
  Main Loss: 2.5507
  Domain Loss: 0.5971
  GEX->Morpho Acc: 0.3132
  Morpho->GEX Acc: 0.2512
  Average Acc: 0.2822
--------------------------------------------------


Training scDART:  50%|█████     | 100/200 [02:17<02:16,  1.37s/it]


Epoch 100/200
  Main Loss: 2.6574
  Domain Loss: 0.7248
  GEX->Morpho Acc: 0.2775
  Morpho->GEX Acc: 0.3008
  Average Acc: 0.2891
--------------------------------------------------


Training scDART:  60%|██████    | 120/200 [02:50<01:54,  1.43s/it]


Epoch 120/200
  Main Loss: 2.6673
  Domain Loss: 0.6804
  GEX->Morpho Acc: 0.3194
  Morpho->GEX Acc: 0.3163
  Average Acc: 0.3178
--------------------------------------------------


Training scDART:  70%|███████   | 140/200 [03:17<01:23,  1.40s/it]


Epoch 140/200
  Main Loss: 2.7772
  Domain Loss: 0.7212
  GEX->Morpho Acc: 0.3411
  Morpho->GEX Acc: 0.2202
  Average Acc: 0.2806
--------------------------------------------------


Training scDART:  80%|████████  | 160/200 [03:45<00:53,  1.34s/it]


Epoch 160/200
  Main Loss: 2.8382
  Domain Loss: 0.7134
  GEX->Morpho Acc: 0.3411
  Morpho->GEX Acc: 0.3318
  Average Acc: 0.3364
--------------------------------------------------


Training scDART:  90%|█████████ | 180/200 [04:13<00:27,  1.39s/it]


Epoch 180/200
  Main Loss: 2.6912
  Domain Loss: 0.6995
  GEX->Morpho Acc: 0.3581
  Morpho->GEX Acc: 0.3969
  Average Acc: 0.3775
--------------------------------------------------


Training scDART: 100%|██████████| 200/200 [04:42<00:00,  1.41s/it]


Epoch 200/200
  Main Loss: 2.5664
  Domain Loss: 0.6722
  GEX->Morpho Acc: 0.3690
  Morpho->GEX Acc: 0.3953
  Average Acc: 0.3822
--------------------------------------------------

Getting latent space embeddings...






FINAL scDART RESULTS:
GEX -> Morpho Accuracy: 0.3690
Morpho -> GEX Accuracy: 0.3953
Average Accuracy: 0.3822

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...
  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scDART/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/scDART/morpho_umap.csv

Done!


## STACI

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import umap
from tqdm import tqdm
import os

class STACIIntegrator:
    """
    STACI implementation for multi-modal data integration
    Based on the STACI (Spatial-Temporal Attention for Cross-modal Integration) algorithm
    using attention mechanisms and contrastive learning for single-cell multi-omics integration
    """
    
    def __init__(self, embedding_dim=128, attention_heads=8, attention_layers=3, 
                 temperature=0.1, lambda_contrast=1.0, lambda_recon=1.0, lambda_attention=0.5,
                 max_epochs=200, batch_size=64, lr=0.001, device='cuda'):
        """
        Initialize STACI integrator
        
        Args:
            embedding_dim: Dimension of the shared embedding space
            attention_heads: Number of attention heads
            attention_layers: Number of attention layers
            temperature: Temperature parameter for contrastive loss
            lambda_contrast: Weight for contrastive loss
            lambda_recon: Weight for reconstruction loss
            lambda_attention: Weight for attention alignment loss
            max_epochs: Maximum training epochs
            batch_size: Training batch size
            lr: Learning rate
            device: Computing device
        """
        self.embedding_dim = embedding_dim
        self.attention_heads = attention_heads
        self.attention_layers = attention_layers
        self.temperature = temperature
        self.lambda_contrast = lambda_contrast
        self.lambda_recon = lambda_recon
        self.lambda_attention = lambda_attention
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        
        # Networks will be initialized after loading data
        self.encoder_gex = None
        self.encoder_morpho = None
        self.decoder_gex = None
        self.decoder_morpho = None
        self.cross_attention = None
        self.projector_gex = None
        self.projector_morpho = None
        
    def load_data(self):
        """Load multi-modal data"""
        # Gene expression data
        gene_expression_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/exon_data_top2000.csv"
        gex_df = pd.read_csv(gene_expression_path, header=None)
        self.gex_data = gex_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # Morphology data
        morphology_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/gw_dist.csv"
        morpho_df = pd.read_csv(morphology_path, header=0)
        self.morpho_data = morpho_df.iloc[:, 1:].to_numpy().astype(np.float32)
        
        # RNA family labels
        rna_family_path = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/scala/rna_family_matched.csv"
        try:
            rna_df = pd.read_csv(rna_family_path, header=0)
            if rna_df.shape[1] == 1:
                self.rna_family_labels = rna_df.iloc[:, 0].values
            else:
                self.rna_family_labels = rna_df.iloc[:, 1].values
            
            min_samples = min(len(self.gex_data), len(self.morpho_data))
            self.rna_family_labels = self.rna_family_labels[:min_samples]
            self.gex_data = self.gex_data[:min_samples]
            self.morpho_data = self.morpho_data[:min_samples]
            
            print(f"RNA family labels loaded: {len(np.unique(self.rna_family_labels))} unique types")
            print(f"Data shapes - GEX: {self.gex_data.shape}, Morpho: {self.morpho_data.shape}")
        except Exception as e:
            print(f"Warning: Could not load RNA family labels: {e}")
            self.rna_family_labels = None
        
        # Data normalization
        self.gex_mean = np.mean(self.gex_data, axis=0)
        self.gex_std = np.std(self.gex_data, axis=0) + 1e-8
        self.morpho_mean = np.mean(self.morpho_data, axis=0)
        self.morpho_std = np.std(self.morpho_data, axis=0) + 1e-8
        
        self.gex_data_norm = (self.gex_data - self.gex_mean) / self.gex_std
        self.morpho_data_norm = (self.morpho_data - self.morpho_mean) / self.morpho_std
        
        self.n_samples = len(self.gex_data)
        self.gex_dim = self.gex_data.shape[1]
        self.morpho_dim = self.morpho_data.shape[1]
        
        # Initialize networks
        self._initialize_networks()
        
        # Convert to tensors
        self.gex_tensor = torch.FloatTensor(self.gex_data_norm).to(self.device)
        self.morpho_tensor = torch.FloatTensor(self.morpho_data_norm).to(self.device)
        
        print(f"Data loaded and normalized. Shape: GEX {self.gex_tensor.shape}, Morpho {self.morpho_tensor.shape}")
        
    def _initialize_networks(self):
        """Initialize neural networks for STACI"""
        
        class Encoder(nn.Module):
            def __init__(self, input_dim, embedding_dim):
                super(Encoder, self).__init__()
                hidden_dim = max(256, min(512, input_dim // 2))
                self.encoder = nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.LayerNorm(hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim // 2, embedding_dim),
                    nn.LayerNorm(embedding_dim)
                )
                
            def forward(self, x):
                return self.encoder(x)
        
        class Decoder(nn.Module):
            def __init__(self, embedding_dim, output_dim):
                super(Decoder, self).__init__()
                hidden_dim = max(256, min(512, output_dim // 2))
                self.decoder = nn.Sequential(
                    nn.Linear(embedding_dim, hidden_dim // 2),
                    nn.LayerNorm(hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim // 2, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, output_dim)
                )
                
            def forward(self, x):
                return self.decoder(x)
        
        class CrossModalAttention(nn.Module):
            def __init__(self, embedding_dim, num_heads, num_layers):
                super(CrossModalAttention, self).__init__()
                self.embedding_dim = embedding_dim
                self.num_heads = num_heads
                
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=embedding_dim,
                    nhead=num_heads,
                    dim_feedforward=embedding_dim * 2,
                    dropout=0.1,
                    activation='relu',
                    batch_first=True
                )
                self.attention_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
                
                self.query_proj = nn.Linear(embedding_dim, embedding_dim)
                self.key_proj = nn.Linear(embedding_dim, embedding_dim)
                self.value_proj = nn.Linear(embedding_dim, embedding_dim)
                
                self.output_proj = nn.Linear(embedding_dim, embedding_dim)
                
            def forward(self, gex_embeddings, morpho_embeddings):
                batch_size = gex_embeddings.size(0)
                
                gex_seq = gex_embeddings.unsqueeze(1)
                morpho_seq = morpho_embeddings.unsqueeze(1)
                
                joint_seq = torch.cat([gex_seq, morpho_seq], dim=1)
                
                attended = self.attention_layers(joint_seq)
                
                gex_attended = attended[:, 0, :]
                morpho_attended = attended[:, 1, :]
                
                gex_query = self.query_proj(gex_attended)
                morpho_key = self.key_proj(morpho_attended)
                morpho_value = self.value_proj(morpho_attended)
                
                gex_cross_attended = self._cross_attention(gex_query, morpho_key, morpho_value)
                
                morpho_query = self.query_proj(morpho_attended)
                gex_key = self.key_proj(gex_attended)
                gex_value = self.value_proj(gex_attended)
                
                morpho_cross_attended = self._cross_attention(morpho_query, gex_key, gex_value)
                
                gex_final = self.output_proj(gex_cross_attended + gex_attended)
                morpho_final = self.output_proj(morpho_cross_attended + morpho_attended)
                
                return gex_final, morpho_final
            
            def _cross_attention(self, query, key, value):
                attention_weights = torch.softmax(
                    torch.sum(query * key, dim=-1, keepdim=True) / np.sqrt(self.embedding_dim), 
                    dim=-1
                )
                return attention_weights * value
        
        class ContrastiveProjector(nn.Module):
            def __init__(self, embedding_dim, projection_dim=128):
                super(ContrastiveProjector, self).__init__()
                self.projector = nn.Sequential(
                    nn.Linear(embedding_dim, projection_dim),
                    nn.ReLU(),
                    nn.Linear(projection_dim, projection_dim),
                    nn.LayerNorm(projection_dim)
                )
                
            def forward(self, x):
                return F.normalize(self.projector(x), dim=-1)
        
        # Initialize networks
        self.encoder_gex = Encoder(self.gex_dim, self.embedding_dim).to(self.device)
        self.encoder_morpho = Encoder(self.morpho_dim, self.embedding_dim).to(self.device)
        
        self.decoder_gex = Decoder(self.embedding_dim, self.gex_dim).to(self.device)
        self.decoder_morpho = Decoder(self.embedding_dim, self.morpho_dim).to(self.device)
        
        self.cross_attention = CrossModalAttention(
            self.embedding_dim, self.attention_heads, self.attention_layers
        ).to(self.device)
        
        self.projector_gex = ContrastiveProjector(self.embedding_dim).to(self.device)
        self.projector_morpho = ContrastiveProjector(self.embedding_dim).to(self.device)
        
        # Initialize optimizer
        self.optimizer = optim.AdamW(
            list(self.encoder_gex.parameters()) + 
            list(self.encoder_morpho.parameters()) +
            list(self.decoder_gex.parameters()) + 
            list(self.decoder_morpho.parameters()) +
            list(self.cross_attention.parameters()) +
            list(self.projector_gex.parameters()) +
            list(self.projector_morpho.parameters()),
            lr=self.lr, weight_decay=1e-4
        )
        
        print("STACI networks initialized successfully")
        
    def _contrastive_loss(self, gex_proj, morpho_proj):
        """Calculate contrastive loss between paired samples"""
        batch_size = gex_proj.size(0)
        
        similarity_matrix = torch.mm(gex_proj, morpho_proj.t()) / self.temperature
        
        labels = torch.arange(batch_size).to(self.device)
        
        loss_gex_to_morpho = F.cross_entropy(similarity_matrix, labels)
        loss_morpho_to_gex = F.cross_entropy(similarity_matrix.t(), labels)
        
        return (loss_gex_to_morpho + loss_morpho_to_gex) / 2
    
    def _attention_alignment_loss(self, gex_embeddings, morpho_embeddings, 
                                 gex_attended, morpho_attended):
        """Calculate attention alignment loss to encourage meaningful cross-modal attention"""
        gex_attention_change = F.mse_loss(gex_attended, gex_embeddings)
        morpho_attention_change = F.mse_loss(morpho_attended, morpho_embeddings)
        
        target_change = 0.1
        gex_alignment_loss = (gex_attention_change - target_change) ** 2
        morpho_alignment_loss = (morpho_attention_change - target_change) ** 2
        
        return (gex_alignment_loss + morpho_alignment_loss) / 2
    
    def _reconstruction_loss(self, recon, original):
        """Calculate reconstruction loss"""
        return F.mse_loss(recon, original)
    
    def train_epoch(self, epoch):
        """Train one epoch"""
        self.encoder_gex.train()
        self.encoder_morpho.train()
        self.decoder_gex.train()
        self.decoder_morpho.train()
        self.cross_attention.train()
        self.projector_gex.train()
        self.projector_morpho.train()
        
        indices = torch.randperm(self.n_samples)
        n_batches = (self.n_samples + self.batch_size - 1) // self.batch_size
        
        total_loss = 0
        
        for batch_idx in range(n_batches):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, self.n_samples)
            batch_indices = indices[start_idx:end_idx]
            
            gex_batch = self.gex_tensor[batch_indices]
            morpho_batch = self.morpho_tensor[batch_indices]
            
            self.optimizer.zero_grad()
            
            gex_embeddings = self.encoder_gex(gex_batch)
            morpho_embeddings = self.encoder_morpho(morpho_batch)
            
            gex_attended, morpho_attended = self.cross_attention(gex_embeddings, morpho_embeddings)
            
            gex_recon = self.decoder_gex(gex_attended)
            morpho_recon = self.decoder_morpho(morpho_attended)
            
            gex_cross_recon = self.decoder_gex(morpho_attended)
            morpho_cross_recon = self.decoder_morpho(gex_attended)
            
            gex_proj = self.projector_gex(gex_attended)
            morpho_proj = self.projector_morpho(morpho_attended)
            
            gex_recon_loss = self._reconstruction_loss(gex_recon, gex_batch)
            morpho_recon_loss = self._reconstruction_loss(morpho_recon, morpho_batch)
            gex_cross_loss = self._reconstruction_loss(gex_cross_recon, gex_batch)
            morpho_cross_loss = self._reconstruction_loss(morpho_cross_recon, morpho_batch)
            
            recon_loss = (gex_recon_loss + morpho_recon_loss + 
                         gex_cross_loss + morpho_cross_loss) / 4
            
            contrast_loss = self._contrastive_loss(gex_proj, morpho_proj)
            
            attention_loss = self._attention_alignment_loss(
                gex_embeddings, morpho_embeddings, gex_attended, morpho_attended
            )
            
            total_batch_loss = (self.lambda_recon * recon_loss + 
                              self.lambda_contrast * contrast_loss +
                              self.lambda_attention * attention_loss)
            
            total_batch_loss.backward()
            
            torch.nn.utils.clip_grad_norm_(
                list(self.encoder_gex.parameters()) + 
                list(self.encoder_morpho.parameters()) +
                list(self.decoder_gex.parameters()) + 
                list(self.decoder_morpho.parameters()) +
                list(self.cross_attention.parameters()) +
                list(self.projector_gex.parameters()) +
                list(self.projector_morpho.parameters()),
                max_norm=1.0
            )
            
            self.optimizer.step()
            
            total_loss += total_batch_loss.item()
        
        return total_loss / n_batches
    
    def get_integrated_embeddings(self):
        """Get integrated embeddings after cross-modal attention"""
        self.encoder_gex.eval()
        self.encoder_morpho.eval()
        self.cross_attention.eval()
        
        with torch.no_grad():
            gex_embeddings = self.encoder_gex(self.gex_tensor)
            morpho_embeddings = self.encoder_morpho(self.morpho_tensor)
            
            gex_attended, morpho_attended = self.cross_attention(gex_embeddings, morpho_embeddings)
            
            gex_final = gex_attended.cpu().numpy()
            morpho_final = morpho_attended.cpu().numpy()
        
        return gex_final, morpho_final
    
    def calculate_celltype_accuracy(self, gex_embeddings, morpho_embeddings):
        """Calculate cell type matching accuracy"""
        if self.rna_family_labels is None:
            print("No RNA family labels available for accuracy calculation")
            return None, None, None
        
        nbrs_morpho = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(morpho_embeddings)
        nbrs_gex = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gex_embeddings)
        
        _, indices_gex2morpho = nbrs_morpho.kneighbors(gex_embeddings)
        gex2morpho_matches = 0
        for i, nearest_idx in enumerate(indices_gex2morpho.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                gex2morpho_matches += 1
        gex2morpho_accuracy = gex2morpho_matches / len(self.rna_family_labels)
        
        _, indices_morpho2gex = nbrs_gex.kneighbors(morpho_embeddings)
        morpho2gex_matches = 0
        for i, nearest_idx in enumerate(indices_morpho2gex.flatten()):
            if self.rna_family_labels[i] == self.rna_family_labels[nearest_idx]:
                morpho2gex_matches += 1
        morpho2gex_accuracy = morpho2gex_matches / len(self.rna_family_labels)
        
        average_accuracy = (gex2morpho_accuracy + morpho2gex_accuracy) / 2
        
        return gex2morpho_accuracy, morpho2gex_accuracy, average_accuracy
    
    def train(self, verbose_interval=20):
        """Train STACI model"""
        print("Starting STACI training...")
        
        for epoch in tqdm(range(self.max_epochs), desc="Training STACI"):
            total_loss = self.train_epoch(epoch)
            
            if (epoch + 1) % verbose_interval == 0:
                gex_emb, morpho_emb = self.get_integrated_embeddings()
                gex2morpho_acc, morpho2gex_acc, avg_acc = self.calculate_celltype_accuracy(gex_emb, morpho_emb)
                
                print(f"\nEpoch {epoch+1}/{self.max_epochs}")
                print(f"  Total Loss: {total_loss:.4f}")
                if avg_acc is not None:
                    print(f"  GEX->Morpho Acc: {gex2morpho_acc:.4f}")
                    print(f"  Morpho->GEX Acc: {morpho2gex_acc:.4f}")
                    print(f"  Average Acc: {avg_acc:.4f}")
                print("-" * 50)

# Usage example
if __name__ == "__main__":
    # 创建输出目录
    output_dir = "/home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/STACI/"
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize STACI integrator
    integrator = STACIIntegrator(
        embedding_dim=128,
        attention_heads=8,
        attention_layers=3,
        temperature=0.1,
        lambda_contrast=1.0,
        lambda_recon=1.0,
        lambda_attention=0.5,
        max_epochs=200,
        batch_size=64,
        lr=0.001
    )
    
    # Load data
    integrator.load_data()
    
    # Train model
    integrator.train(verbose_interval=20)
    
    # Get final integrated embeddings
    print("\nGetting latent space embeddings...")
    final_gex_embeddings, final_morpho_embeddings = integrator.get_integrated_embeddings()
    
    # Calculate final accuracy
    final_gex2morpho_acc, final_morpho2gex_acc, final_avg_acc = integrator.calculate_celltype_accuracy(
        final_gex_embeddings, final_morpho_embeddings
    )
    
    print("\n" + "="*60)
    print("FINAL STACI RESULTS:")
    print("="*60)
    print(f"GEX -> Morpho Accuracy: {final_gex2morpho_acc:.4f}")
    print(f"Morpho -> GEX Accuracy: {final_morpho2gex_acc:.4f}")
    print(f"Average Accuracy: {final_avg_acc:.4f}")
    print("="*60)
    
    # UMAP降维
    print("\nPerforming UMAP dimensionality reduction...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    
    # GEX UMAP
    print("  - Computing GEX UMAP...")
    gex_umap = reducer.fit_transform(final_gex_embeddings)
    gex_umap_df = pd.DataFrame(gex_umap, columns=['UMAP1', 'UMAP2'])
    gex_umap_path = os.path.join(output_dir, "gex_umap.csv")
    gex_umap_df.to_csv(gex_umap_path, index=False)
    print(f"  - GEX UMAP saved to: {gex_umap_path}")
    
    # Morpho UMAP
    print("  - Computing Morpho UMAP...")
    morpho_umap = reducer.fit_transform(final_morpho_embeddings)
    morpho_umap_df = pd.DataFrame(morpho_umap, columns=['UMAP1', 'UMAP2'])
    morpho_umap_path = os.path.join(output_dir, "morpho_umap.csv")
    morpho_umap_df.to_csv(morpho_umap_path, index=False)
    print(f"  - Morpho UMAP saved to: {morpho_umap_path}")
    
    print("\nDone!")

Using device: cpu
RNA family labels loaded: 10 unique types
Data shapes - GEX: (645, 2000), Morpho: (645, 645)
STACI networks initialized successfully
Data loaded and normalized. Shape: GEX torch.Size([645, 2000]), Morpho torch.Size([645, 645])
Starting STACI training...


Training STACI:  10%|█         | 20/200 [00:30<04:43,  1.58s/it]


Epoch 20/200
  Total Loss: 0.7844
  GEX->Morpho Acc: 0.5209
  Morpho->GEX Acc: 0.5209
  Average Acc: 0.5209
--------------------------------------------------


Training STACI:  17%|█▋        | 34/200 [00:53<04:16,  1.55s/it]

Training STACI:  20%|██        | 40/200 [01:03<04:42,  1.76s/it]


Epoch 40/200
  Total Loss: 0.6066
  GEX->Morpho Acc: 0.5008
  Morpho->GEX Acc: 0.4651
  Average Acc: 0.4829
--------------------------------------------------


Training STACI:  30%|███       | 60/200 [01:35<04:10,  1.79s/it]


Epoch 60/200
  Total Loss: 0.5140
  GEX->Morpho Acc: 0.5333
  Morpho->GEX Acc: 0.4822
  Average Acc: 0.5078
--------------------------------------------------


Training STACI:  40%|████      | 80/200 [02:11<03:08,  1.57s/it]


Epoch 80/200
  Total Loss: 0.4555
  GEX->Morpho Acc: 0.5845
  Morpho->GEX Acc: 0.4884
  Average Acc: 0.5364
--------------------------------------------------


Training STACI:  50%|█████     | 100/200 [02:42<02:30,  1.50s/it]


Epoch 100/200
  Total Loss: 0.4360
  GEX->Morpho Acc: 0.5628
  Morpho->GEX Acc: 0.5178
  Average Acc: 0.5403
--------------------------------------------------


Training STACI:  60%|██████    | 120/200 [03:13<02:07,  1.60s/it]


Epoch 120/200
  Total Loss: 0.3928
  GEX->Morpho Acc: 0.6047
  Morpho->GEX Acc: 0.5349
  Average Acc: 0.5698
--------------------------------------------------


Training STACI:  70%|███████   | 140/200 [03:45<01:32,  1.55s/it]


Epoch 140/200
  Total Loss: 0.3689
  GEX->Morpho Acc: 0.6047
  Morpho->GEX Acc: 0.5411
  Average Acc: 0.5729
--------------------------------------------------


Training STACI:  80%|████████  | 160/200 [04:17<01:01,  1.53s/it]


Epoch 160/200
  Total Loss: 0.3593
  GEX->Morpho Acc: 0.6202
  Morpho->GEX Acc: 0.5240
  Average Acc: 0.5721
--------------------------------------------------


Training STACI:  90%|█████████ | 180/200 [04:51<00:31,  1.56s/it]


Epoch 180/200
  Total Loss: 0.3330
  GEX->Morpho Acc: 0.6512
  Morpho->GEX Acc: 0.5333
  Average Acc: 0.5922
--------------------------------------------------


Training STACI: 100%|██████████| 200/200 [05:24<00:00,  1.62s/it]


Epoch 200/200
  Total Loss: 0.3162
  GEX->Morpho Acc: 0.6589
  Morpho->GEX Acc: 0.5767
  Average Acc: 0.6178
--------------------------------------------------

Getting latent space embeddings...

FINAL STACI RESULTS:
GEX -> Morpho Accuracy: 0.6589
Morpho -> GEX Accuracy: 0.5767
Average Accuracy: 0.6178

Performing UMAP dimensionality reduction...
  - Computing GEX UMAP...





  - GEX UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/STACI/gex_umap.csv
  - Computing Morpho UMAP...
  - Morpho UMAP saved to: /home/users/turbodu/kzlinlab/projects/morpho_integration/out/turbo/writeup20/STACI/morpho_umap.csv

Done!


## SCOT

In [7]:
import numpy as np
import torch
import ot
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler, normalize
import pandas as pd
import umap
from pathlib import Path
import warnings

class SCOTv2(object):
    """Simplified SCOTv2 for cross-modal alignment."""
    def __init__(self, data):
        assert isinstance(data, list) and len(data) >= 2, "Input must be list of ≥2 numpy arrays"
        self.data = data
        self.marginals = []
        self.graphs = []
        self.graphDists = []
        self.couplings = []
        self.integrated_data = []

    def _init_marginals(self):
        for i in range(len(self.data)):
            n = self.data[i].shape[0]
            self.marginals.append(torch.ones(n) / n)
        return self.marginals

    def _normalize(self, norm="l2", bySample=True):
        for i in range(len(self.data)):
            if norm == "zscore":
                scaler = StandardScaler()
                self.data[i] = scaler.fit_transform(self.data[i])
            else:
                axis = 1 if bySample else 0
                self.data[i] = normalize(self.data[i], norm=norm, axis=axis)
        return self.data

    def construct_graph(self, k=20, mode="connectivity", metric="correlation"):
        include_self = True if mode == "connectivity" else False
        for i in range(len(self.data)):
            self.graphs.append(
                kneighbors_graph(self.data[i], n_neighbors=k, mode=mode,
                                 metric=metric, include_self=include_self)
            )
        return self.graphs

    def init_graph_distances(self):
        for i in range(len(self.data)):
            sp = dijkstra(csgraph=csr_matrix(self.graphs[i]),
                          directed=False, return_predecessors=False)
            max_dist = np.nanmax(sp[sp != np.inf])
            sp[sp > max_dist] = max_dist
            self.graphDists.append(sp / sp.max())
        return self.graphDists

    def direct_balanced_ot(self, a, dx, b, dy, eps=0.1):
        a_np, b_np, dx_np, dy_np = a.numpy(), b.numpy(), dx.numpy(), dy.numpy()
        try:
            coupling, log = ot.gromov.entropic_gromov_wasserstein(
                dx_np, dy_np, a_np, b_np, epsilon=eps,
                max_iter=10000, tol=1e-9, log=True, verbose=False
            )
            if np.any(np.isnan(coupling)):
                coupling = np.outer(a_np, b_np)
            coupling /= coupling.sum(axis=1, keepdims=True)
            return torch.tensor(coupling), True
        except Exception as e:
            print(f"GW OT failed: {e}")
            coupling = np.outer(a_np, b_np)
            coupling /= coupling.sum(axis=1, keepdims=True)
            return torch.tensor(coupling), False

    def find_correspondences(self, normalize=True, norm="l2", bySample=True, k=20,
                             mode="connectivity", metric="correlation", eps=0.1):
        if normalize:
            self._normalize(norm=norm, bySample=bySample)
        self._init_marginals()
        self.construct_graph(k=k, mode=mode, metric=metric)
        self.init_graph_distances()
        a, b = torch.Tensor(self.marginals[0]), torch.Tensor(self.marginals[1])
        dx, dy = torch.Tensor(self.graphDists[0]), torch.Tensor(self.graphDists[1])
        coupling, _ = self.direct_balanced_ot(a, dx, b, dy, eps=eps)
        self.couplings.append(coupling)
        return self.couplings

    def barycentric_projection(self):
        aligned = [self.data[0]]
        coupling = self.couplings[0].numpy()
        projected = np.matmul(coupling, self.data[1])
        aligned.append(projected)
        self.integrated_data = aligned
        return aligned

    def align(self, normalize=True, norm="l2", bySample=True, k=20,
              mode="connectivity", metric="correlation", eps=0.1):
        self.find_correspondences(normalize=normalize, norm=norm, bySample=bySample,
                                  k=k, mode=mode, metric=metric, eps=eps)
        return self.barycentric_projection()


# ========== MAIN ==========
if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    # 输入路径
    X_path = "/Users/apple/Desktop/KLin_Group/Project_2024/data/Morpho_data/dataset/Scala/exon_data_top2000.csv"
    Y_path = "/Users/apple/Desktop/KLin_Group/Project_2024/data/Morpho_data/dataset/Scala/gw_dist.csv"

    # 输出路径
    output_dir = Path("/Users/apple/Desktop/SCOT")
    output_dir.mkdir(parents=True, exist_ok=True)

    # ===== 正确读取数据 =====
    # 让 pandas 自动识别第一行 header
    # 第一列为 cell ID，当 index 用，数据部分是剩下的 2000 列
    X_df = pd.read_csv(X_path, index_col=0)   # header=0 默认
    Y_df = pd.read_csv(Y_path, index_col=0)   # 距离矩阵，同样第一列是 index

    X = X_df.to_numpy()   # 期望 (645, 2000)
    Y = Y_df.to_numpy()   # 期望 (645, 645)

    min_n = min(X.shape[0], Y.shape[0])
    X, Y = X[:min_n, :], Y[:min_n, :]

    print(f"Dataset shapes - X: {X.shape}, Y: {Y.shape}")
    assert X.shape[0] == Y.shape[0], "X 和 Y 的行数不一致，请检查输入文件是否匹配"

    # 运行 SCOT 对齐
    SCOT = SCOTv2([Y, X])   # 我们把 X 对齐到 Y
    aligned = SCOT.align(normalize=True, k=50, eps=0.1)

    # 生成 UMAP
    print("Generating 2D UMAP embeddings...")
    reducer = umap.UMAP(n_components=2, random_state=42)
    umap_Y = reducer.fit_transform(aligned[0])
    umap_X = reducer.fit_transform(aligned[1])

    # 保存结果
    pd.DataFrame(umap_Y, columns=["UMAP1", "UMAP2"]).to_csv(output_dir / "Y_umap.csv", index=False)
    pd.DataFrame(umap_X, columns=["UMAP1", "UMAP2"]).to_csv(output_dir / "X_umap.csv", index=False)

    print("✅ Saved:")
    print(f"  {output_dir}/Y_umap.csv")
    print(f"  {output_dir}/X_umap.csv")


Dataset shapes - X: (644, 2000), Y: (644, 645)
Generating 2D UMAP embeddings...
✅ Saved:
  /Users/apple/Desktop/SCOT/Y_umap.csv
  /Users/apple/Desktop/SCOT/X_umap.csv
