In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_distances
from sklearn.neighbors import NearestNeighbors
from collections import Counter

import warnings
warnings.filterwarnings("ignore")

In [None]:
def diagonal_merge(df1, df2):
    """
    Stack two z-matrices block-diagonally.

    The result has the rows of ``df1`` followed by the rows of ``df2`` and the
    columns of ``df1`` followed by the columns of ``df2``.  ``df1`` occupies the
    upper-left block, ``df2`` the lower-right block, and every other entry is 0.

    args:
        df1: DataFrame placed in the upper-left corner (may be empty)
        df2: DataFrame placed in the lower-right corner

    returns:
        ``df2`` itself when ``df1`` is empty, otherwise a new DataFrame whose
        values share one numeric dtype (at least int64, promoted to float when
        either input holds floats).
    """

    # Nothing accumulated yet: the merge is just the second frame.
    if df1.empty:
        return df2

    rows1, cols1 = df1.shape
    rows2, cols2 = df2.shape

    upper_left = df1.to_numpy()
    lower_right = df2.to_numpy()

    # Choose the dtype up front instead of writing float/bool values into an
    # integer frame and relying on pandas to upcast on assignment.
    block_dtypes = [block.dtype for block in (upper_left, lower_right) if block.size]
    merged = np.zeros((rows1 + rows2, cols1 + cols2),
                      dtype=np.result_type(np.int64, *block_dtypes))
    merged[:rows1, :cols1] = upper_left
    merged[rows1:, cols1:] = lower_right

    return pd.DataFrame(
        merged,
        index=list(df1.index) + list(df2.index),
        columns=list(df1.columns) + list(df2.columns),
    )

In [3]:
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
from itertools import product

def FindingNN(adata, k = 25, n_components=100, random_state=None):
    """
    Build a symmetric 0/1 adjacency matrix linking each cell to its nearest
    neighbors within its own batch and to its mutual nearest neighbors in
    every other batch.

    Similarity is the cosine similarity between cells after PCA of ``adata.X``.

    args:
        adata: Data; reads ``adata.X``, ``adata.obs["batch"]`` and ``adata.obs_names``
        k: Number of neighbors, the default parameter is 25
        n_components: Number of principal components (default 100); clamped to
            the data dimensions so small inputs do not make PCA fail
        random_state: Forwarded to ``PCA`` (default None keeps sklearn's default)

    returns:
        (n_cells, n_cells) integer numpy array, rows/columns in ``adata.obs_names`` order.
    """

    data = adata.X
    batch_of_cell = adata.obs["batch"]

    # PCA dimensionality reduction; never request more components than exist.
    n_comp = min(n_components, *data.shape)
    pca = PCA(n_components=n_comp, random_state=random_state)
    data_pca = pca.fit_transform(data)

    # Cell-by-cell cosine similarity, labelled by cell name.
    cosine_df = pd.DataFrame(
        cosine_similarity(data_pca), index=adata.obs.index, columns=adata.obs.index
    )

    batches = batch_of_cell.unique()
    cells_by_batch = {batch: batch_of_cell[batch_of_cell == batch].index.tolist()
                      for batch in batches}

    # k most similar cells inside the same batch (the cell itself excluded).
    nearest_neighbors_same_batch = {}
    for batch in batches:
        members = cells_by_batch[batch]
        within = cosine_df.loc[members, members]
        for cell, similarities in within.iterrows():
            nearest_neighbors_same_batch[cell] = (
                similarities.drop(cell).nlargest(k).index.tolist()
            )

    # For every ordered batch pair, compute each cell's top-k list in the other
    # batch once, rather than re-deriving it for every (cell, neighbor) pair.
    top_k_across = {}
    for batch, other_batch in product(batches, batches):
        if batch == other_batch:
            continue
        across = cosine_df.loc[cells_by_batch[batch], cells_by_batch[other_batch]]
        top_k_across[(batch, other_batch)] = {
            cell: similarities.nlargest(k).index.tolist()
            for cell, similarities in across.iterrows()
        }

    # Keep a cross-batch neighbor only when the relation is mutual.
    nearest_neighbors_diff_batch = defaultdict(list)
    for batch, other_batch in product(batches, batches):
        if batch == other_batch:
            continue
        reverse = {cell: set(neighbors)
                   for cell, neighbors in top_k_across[(other_batch, batch)].items()}
        for cell, neighbors in top_k_across[(batch, other_batch)].items():
            for neighbor in neighbors:
                if cell in reverse[neighbor]:
                    nearest_neighbors_diff_batch[cell].append(neighbor)

    # Create adjacency matrix M
    position = {name: idx for idx, name in enumerate(adata.obs_names)}
    n = len(adata)
    M = np.zeros((n, n), dtype=int)

    for i, cell_i in enumerate(adata.obs_names):
        for neighbor_table in (nearest_neighbors_same_batch, nearest_neighbors_diff_batch):
            neighbor_indices = [position[neighbor]
                                for neighbor in neighbor_table.get(cell_i, [])]
            # Set both directions so M stays symmetric.
            M[i, neighbor_indices] = 1
            M[neighbor_indices, i] = 1

    return M


In [4]:
from scipy.sparse import coo_matrix

def compute_S_matrix(Z_matrix_total, M):
    """
    Construct similarity matrices between different clusters based on the z-matrix and adjacency matrix

    For every stored entry (i, j) of ``M`` the outer product of row i and row j
    of the z-matrix is accumulated, i.e. ``S = Z^T A Z`` with ``A`` the edge
    indicator of ``M``.  Each entry is then divided by
    ``sqrt(size_i * size_j)`` (column sums of the z-matrix) when both sizes are
    non-zero, and the diagonal is set to 0.

    args:
        Z_matrix_total: z-matrix (cells x clusters); rows are matched to ``M`` by position
        M: adjacency matrix (cells x cells)

    returns:
        DataFrame (clusters x clusters) indexed and labelled by the z-matrix columns.
    """

    Z = Z_matrix_total.to_numpy(dtype=float)

    # One unit of weight per stored edge, whatever value M holds there.
    edges = coo_matrix(M)
    adjacency = coo_matrix(
        (np.ones(len(edges.row)), (edges.row, edges.col)), shape=edges.shape
    ).tocsr()

    # Sum of outer(Z_i, Z_j) over all edges, done as two matrix products.
    S = Z.T @ (adjacency @ Z)

    # Column sums of the z-matrix.
    type_sums = Z.sum(axis=0)
    populated = type_sums != 0
    both_populated = np.outer(populated, populated)
    S[both_populated] = S[both_populated] / np.sqrt(
        np.outer(type_sums, type_sums)[both_populated]
    )

    np.fill_diagonal(S, 0)
    S_df = pd.DataFrame(S, index=Z_matrix_total.columns, columns=Z_matrix_total.columns)

    return S_df



In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
def nn_search(X, Y=None, k=25):
    """
    Computes nearest neighbors in Y for points in X
    args:
        X: nxd array-like of query points
        Y: mxd array-like of data points (optional, defaults to X)
        k: number of neighbors; clamped to the number of data points in Y

    returns:
        (Dis, Ids): distances and indices as returned by
        ``NearestNeighbors.kneighbors``.  When Y is X, each point is looked up
        in the set it belongs to.
    """
    if Y is None:
        Y = X
    # Asking for more neighbors than there are data points makes sklearn raise.
    n_neighbors = min(k, len(Y))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(Y)
    Dis, Ids = nbrs.kneighbors(X)
    return Dis, Ids
def adjust_weights0(model,K = 3,number_of_items= 2000):
    """
    Adjust weights to ensure monotonicity by flipping signs.

    For each layer pair (layer1, layer2), (layer3, layer4), (layer5, layer6)
    and each item ``i`` with hidden unit ``h`` in ``K*i .. K*i+K-1``: if the
    first layer's weight ``[h, i]`` and the second layer's weight ``[i, h]``
    have opposite signs, the second layer's weight is negated in place.

    args:
        model: module exposing ``layer1`` .. ``layer6``, each with a ``weight`` tensor
        K: number of hidden units per item
        number_of_items: number of items (genes)
    """
    with torch.no_grad():
        layer_pairs = (
            (model.layer1, model.layer2),
            (model.layer3, model.layer4),
            (model.layer5, model.layer6),
        )
        for expand, contract in layer_pairs:
            device = expand.weight.device
            # hidden[h] is a hidden-unit index, item[h] the item that owns it.
            hidden = torch.arange(K * number_of_items, device=device)
            item = torch.arange(number_of_items, device=device).repeat_interleave(K)

            w_in = expand.weight[hidden, item]
            w_out = contract.weight[item, hidden]
            # Comparing signs is equivalent to "product < 0" without underflow.
            opposite = torch.sign(w_in) * torch.sign(w_out) < 0
            contract.weight[item[opposite], hidden[opposite]] = -w_out[opposite]
def adjust_weights(model,K = 3,number_of_items= 2000):
    """
    Adjust weights to ensure monotonicity by zeroing conflicting weights.

    For each layer pair (layer1, layer2), (layer3, layer4), (layer5, layer6)
    and each item ``i`` with hidden unit ``h`` in ``K*i .. K*i+K-1``: if the
    first layer's weight ``[h, i]`` and the second layer's weight ``[i, h]``
    have opposite signs, the second layer's weight is set to 0 in place.

    args:
        model: module exposing ``layer1`` .. ``layer6``, each with a ``weight`` tensor
        K: number of hidden units per item
        number_of_items: number of items (genes)
    """
    with torch.no_grad():
        layer_pairs = (
            (model.layer1, model.layer2),
            (model.layer3, model.layer4),
            (model.layer5, model.layer6),
        )
        for expand, contract in layer_pairs:
            device = expand.weight.device
            # hidden[h] is a hidden-unit index, item[h] the item that owns it.
            hidden = torch.arange(K * number_of_items, device=device)
            item = torch.arange(number_of_items, device=device).repeat_interleave(K)

            w_in = expand.weight[hidden, item]
            w_out = contract.weight[item, hidden]
            # Comparing signs is equivalent to "product < 0" without underflow.
            opposite = torch.sign(w_in) * torch.sign(w_out) < 0
            contract.weight[item[opposite], hidden[opposite]] = 0


def gaussian_kernel(x, y, sigma):
    """
    Pairwise Gaussian kernel between the rows of x and the rows of y.

    Entry (i, j) is exp(-beta * ||x_i - y_j||^2) with beta = 1 / (0.5 * sigma^2).

    args:
        x: first set of points, one per row
        y: second set of points, one per row
        sigma: standard deviation (bandwidth)
    """
    squared_distance = torch.cdist(x, y)**2
    beta = 1. / (0.5*sigma ** 2)
    kernel = torch.exp(-beta * squared_distance)
    return kernel
    

def compute_mmd(x, y, sigma1,sigma2,sigma3,Z1,Z2):
    """
    Weighted maximum mean discrepancy between two sets of samples, using the
    sum of three Gaussian kernels.

    Every kernel term is averaged with weights given by the outer product of
    the per-sample weights, normalised by the sum of those weights.

    args:
        x: the first group of samples
        y: the second group of samples
        sigma1: the first standard deviation of a mixture Gaussian kernel
        sigma2: the second standard deviation of a mixture Gaussian kernel
        sigma3: the third standard deviation of a mixture Gaussian kernel
        Z1: the weight of each sample in the first group of samples
        Z2: the weight of each sample in the second group of samples
    """
    def mixture_kernel(a, b):
        # Sum of the three bandwidths, accumulated left to right.
        total = gaussian_kernel(a, b, sigma1)
        total = total + gaussian_kernel(a, b, sigma2)
        total = total + gaussian_kernel(a, b, sigma3)
        return total

    def weighted_mean(weights, kernel):
        return torch.sum(weights*kernel)/torch.sum(weights)

    column1 = Z1.unsqueeze(1)
    column2 = Z2.unsqueeze(1)

    within_x = weighted_mean(torch.mm(column1, column1.t()), mixture_kernel(x, x))
    within_y = weighted_mean(torch.mm(column2, column2.t()), mixture_kernel(y, y))
    between = weighted_mean(torch.mm(column1, column2.t()), mixture_kernel(x, y))

    return within_x + within_y - 2 * between
    
# Class: loss module wrapping compute_mmd with three fixed kernel bandwidths
class MMDLoss(nn.Module):
    """Module form of ``compute_mmd`` that stores the three kernel bandwidths."""

    def __init__(self, sigma1,sigma2,sigma3):
        super().__init__()
        # Bandwidths of the three Gaussian kernels passed to compute_mmd.
        self.sigma1, self.sigma2, self.sigma3 = sigma1, sigma2, sigma3

    def forward(self, x, y,Z1,Z2):
        """Return compute_mmd(x, y, ...) with per-sample weights Z1 and Z2."""
        bandwidths = (self.sigma1, self.sigma2, self.sigma3)
        return compute_mmd(x, y, *bandwidths, Z1, Z2)

# Initialize weights
def init_weights_positive(m):
    """
    Re-initialise the weight of an ``nn.Linear`` module uniformly in [-0.5, 0.5].

    Modules of any other type are left untouched.  Despite the name, the drawn
    weights are not restricted to positive values.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.uniform_(m.weight, a=-1/2, b=1/2)

In [6]:
def find_columns(df, start_row, s, visited=None):
    """
    Search for clusters that may be similar based on the similarity matrix and similarity threshold

    Depth-first search over the graph whose edges are the entries of ``df``
    strictly greater than ``s``.  From each cluster, neighbors of the same
    batch are explored before neighbors of other batches, each group in the
    column order of ``df``, so the returned order is deterministic.

    args:
        df: similarity matrix (square DataFrame, rows and columns are cluster names)
        start_row: a starting cluster
        s: similarity threshold
        visited: set of clusters already reached; updated in place (a new set
            is created when omitted)

    returns:
        (results, visited): clusters newly reached from ``start_row`` in visit
        order, and the updated visited set.  ``start_row`` itself is included
        only if it is reached back through another cluster.
    """
    if visited is None:
        visited = set()

    # Batch prefix is everything up to the last underscore ("batch10_3" ->
    # "batch10_"); fall back to the first six characters for other names.
    head, sep, _ = start_row.rpartition('_')
    batch_prefix = head + sep if sep else start_row[0:6]

    above = (df.loc[start_row] > s).to_numpy()
    candidates = list(dict.fromkeys(df.columns[above]))  # unique, column order
    same_batch = [name for name in candidates if name.startswith(batch_prefix)]
    other_batch = [name for name in candidates if not name.startswith(batch_prefix)]

    results = []
    for column in same_batch + other_batch:
        if column not in visited:
            visited.add(column)
            results.append(column)
            additional_results, visited = find_columns(df, column, s, visited)
            results.extend(additional_results)
    return results, visited

In [7]:
def find_threshold(df, batch_name):
    """
    Select similarity threshold for merging and matching clusters

    Candidate thresholds 0.5, 0.55, ... (below 20) are scanned in increasing
    order.  For each one, the clusters of the query batch are grouped with
    ``find_columns`` and the sizes of the groups that contain at least one
    reference (``batch0_``) cluster are summed.  The scan stops at the first
    threshold whose total falls below the total obtained at 0.5; the last
    threshold before that is returned.

    args:
        df: similarity matrix
        batch_name: the name of query batch (its clusters are the columns
            named ``<batch_name>_<i>``)

    returns:
        the selected threshold (numpy float)
    """
    # Match on "<batch>_" so that e.g. "batch1" does not also select "batch10_*".
    query_prefix = batch_name + '_'
    query_clusters = [name for name in df.columns if name.startswith(query_prefix)]

    baseline_matched = None
    threshold = None
    for s in np.arange(0.5, 20, 0.05):
        grouped = set()
        matched = 0
        for start_cluster in query_clusters:
            if start_cluster in grouped:
                continue
            results, _ = find_columns(df, start_cluster, s)
            reached = set(results)

            # Split the group into query-batch and reference-batch clusters.
            query_members = [name for name in reached if name.startswith(query_prefix)]
            reference_members = [name for name in reached if name.startswith('batch0_')]
            grouped.update(query_members)
            # Only groups anchored to a reference cluster count as matched.
            if reference_members:
                matched += len(query_members) + len(reference_members)

        if baseline_matched is None:
            baseline_matched = matched
        if matched < baseline_matched:
            break
        threshold = s
    return threshold
    

        

In [514]:
import random
def _dense_matrix(X):
    """Return X as a dense numpy array (objects exposing ``toarray`` are densified)."""
    return X.toarray() if hasattr(X, "toarray") else np.asarray(X)


def _gmm_posteriors(X, max_components=20):
    """Fit Gaussian mixtures with 2..max_components components and return the
    posterior probability matrix of the model with the lowest BIC."""
    from sklearn.mixture import GaussianMixture

    best_gmm = None
    best_bic = np.inf
    for n_components in range(2, max_components + 1):
        gmm = GaussianMixture(n_components=n_components, random_state=123)
        gmm.fit(X)
        bic = gmm.bic(X)
        # If the current BIC is smaller, update the best model
        if bic < best_bic:
            best_bic = bic
            best_gmm = gmm
    return best_gmm.predict_proba(X)


def _graph_cluster_labels(adata_batch, cluster_fn, key, ASW, resolution):
    """Run Louvain/Leiden on one batch and return the label column ``obs[key]``.

    With ``ASW`` the resolution is scanned from 0.05 in steps of 0.01 and the
    scan stops as soon as the silhouette scores on both the PCA and the UMAP
    embedding drop below the previous step; the previous resolution is used.
    """
    from sklearn.metrics import silhouette_score

    if ASW:
        asw_pca_prev = 0
        asw_umap_prev = 0
        for resolution in np.arange(0.05, 3, 0.01):
            cluster_fn(adata_batch, resolution=resolution)
            labels = adata_batch.obs[key]
            # NOTE(review): a single cluster ends the scan at this resolution
            # (silhouette needs >= 2 labels); confirm `break` rather than
            # `continue` is the intended behaviour.
            if len(labels.value_counts()) == 1:
                break
            asw_pca = silhouette_score(adata_batch.obsm['X_pca'], labels)
            asw_umap = silhouette_score(adata_batch.obsm['X_umap'], labels)
            if asw_pca < asw_pca_prev and asw_umap < asw_umap_prev:
                resolution = resolution - 0.01
                break
            asw_pca_prev = asw_pca
            asw_umap_prev = asw_umap
    cluster_fn(adata_batch, resolution=resolution)
    return adata_batch.obs[key]


def _build_z_matrix(adata, methods, ASW, resolution):
    """Cluster every batch separately and merge the per-batch membership
    matrices block-diagonally.

    Rows are labelled with the cell's position in ``adata`` and returned in
    that order; columns are named ``<batch>_<cluster index>``.
    """
    batch_of_cell = adata.obs['batch']
    Z_matrix_total = pd.DataFrame()
    for batch in batch_of_cell.unique():
        in_batch = (batch_of_cell == batch).values
        adata_batch = adata[in_batch].copy()
        sc.pp.pca(adata_batch)
        sc.pp.neighbors(adata_batch)
        sc.tl.umap(adata_batch)

        if methods == 'GMM':
            zmatrix = pd.DataFrame(_gmm_posteriors(adata_batch.obsm["X_umap"]))
        else:
            if methods == 'Louvain':
                cluster_fn, key = sc.tl.louvain, 'louvain'
            else:
                cluster_fn, key = sc.tl.leiden, 'leiden'
            labels = _graph_cluster_labels(adata_batch, cluster_fn, key, ASW, resolution)
            # Integer one-hot so the merged matrix stays numeric.
            zmatrix = pd.get_dummies(labels).astype(int)

        zmatrix.columns = ['%s_%d' % (batch, i) for i in range(zmatrix.shape[1])]
        zmatrix.index = np.where(in_batch)[0]
        Z_matrix_total = diagonal_merge(Z_matrix_total, zmatrix)

    # Batches are merged one after another; restore AnnData row order so that
    # positional access (compute_S_matrix, adjacency matrix) lines up with cells.
    return Z_matrix_total.sort_index()


def _matched_cluster_groups(S, query_clusters, z, threshold):
    """Group the clusters of query batch ``z`` with the reference clusters they
    are linked to in ``S`` above ``threshold``.

    Returns a list of ``(query_indices, reference_indices)`` pairs holding the
    numeric suffixes of the cluster names.  Groups lacking either a query or a
    reference cluster are skipped: their sample weights would sum to zero and
    make the MMD term 0/0.
    """
    groups = []
    visited_all = []
    for start_cluster in query_clusters:
        if start_cluster in visited_all:
            continue
        results, _ = find_columns(S, start_cluster, s=threshold)
        reached = set(results)

        # Divide the results into two parts based on batch information
        query_members = [name for name in reached if name.startswith(z + '_')]
        reference_members = [name for name in reached if name.startswith('batch0_')]
        visited_all = visited_all + query_members
        if not query_members or not reference_members:
            continue
        groups.append((
            sorted(int(name.split('_')[1]) for name in query_members),
            sorted(int(name.split('_')[1]) for name in reference_members),
        ))
    return groups


class _OrderPreservingGenerator(nn.Module):
    """Three stacked masked blocks mapping a query batch towards the reference.

    Each block expands every item into ``K`` hidden units reachable only from
    that item (``mask1``), maps them back to the same item (``mask2``) through
    a leaky ReLU, and adds a contribution from the other items' hidden units
    (``mask3``) only where the input value is zero.  With ``n_labels > 0``
    (Partial model) the cluster-membership matrix is concatenated to the input
    of every block.
    """

    def __init__(self, number_of_items, K, n_labels=0, partial=False):
        super().__init__()
        self.number_of_items = number_of_items
        self.partial = partial
        in_features = number_of_items + n_labels
        hidden_features = K * number_of_items

        # Layers are created in the order layer1..layer6 so that the random
        # initialisation stream is consumed in a fixed order.
        self.layer1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.layer2 = nn.Linear(in_features=hidden_features, out_features=number_of_items)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.1)
        self.layer3 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.layer4 = nn.Linear(in_features=hidden_features, out_features=number_of_items)
        self.layer5 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.layer6 = nn.Linear(in_features=hidden_features, out_features=number_of_items)
        for layer in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            nn.init.zeros_(layer.bias)
        self.bias1 = nn.Parameter(torch.zeros(number_of_items))
        self.bias2 = nn.Parameter(torch.zeros(number_of_items))
        self.bias3 = nn.Parameter(torch.zeros(number_of_items))

        # own_units[i, h] == 1 iff hidden unit h belongs to item i (h in K*i..K*i+K-1).
        own_units = torch.eye(number_of_items).repeat_interleave(K, dim=1)
        mask1 = torch.ones(hidden_features, in_features)
        mask1[:, :number_of_items] = own_units.t()  # label columns stay fully connected
        self.register_buffer('mask1', mask1)
        self.register_buffer('mask2', own_units)
        self.register_buffer('mask3', 1 - own_units)

    def _block(self, inputs, label, zero_mask, expand, contract, cross_bias):
        if label is not None:
            inputs = torch.cat(dim=1, tensors=[inputs, label])
        hidden = F.linear(inputs, expand.weight * self.mask1, expand.bias)
        own = self.leaky_relu(F.linear(hidden, contract.weight * self.mask2, contract.bias))
        cross_weight = contract.weight * self.mask3
        if self.partial:
            # The Partial model additionally scales the cross-item weights.
            cross_weight = cross_weight / self.number_of_items
        cross = F.linear(hidden, cross_weight, cross_bias)
        return own + zero_mask * cross / self.number_of_items

    def forward(self, x, zero_mask, label=None):
        output = self._block(x, label, zero_mask, self.layer1, self.layer2, self.bias1)
        output = self._block(output, label, zero_mask, self.layer3, self.layer4, self.bias2)
        output = x / 2 + output
        return self._block(output, label, zero_mask, self.layer5, self.layer6, self.bias3)


def order_preserving_correction(adata, option = 'Global', preprocessing = False, methods = 'Louvain', ASW = True, resolution=1, epochs = 250):
    """
    Batch effect correction with the order-preserving feature

    Batches are relabelled ``batch0``, ``batch1``, ... by decreasing size
    (``adata.obs['batch']`` is overwritten in place); ``batch0`` is the
    reference.  Each batch is clustered, clusters are matched across batches,
    and for every other batch a generator is trained with a weighted MMD loss
    to map its cells onto the matched reference clusters.

    args:
        adata: The input data file in AnnData format, which can contain either single-cell RNA sequencing or bulk RNA sequencing data.
        option: A string parameter that specifies the type of monotonic model to use. Set to 'Global' by default, which utilizes a global monotonic model. Alternatively, a 'Partial' model can be chosen.
        preprocessing: A boolean parameter (default is False) that indicates whether to preprocess the data. It is recommended to preprocess the raw count data.
        methods: A string that specifies the clustering algorithm for initialization. The default method is 'Louvain', with alternative options including 'Leiden' and 'GMM'.
        ASW: A boolean parameter (default is True). When using the Louvain or Leiden algorithms, it determines the resolution parameter based on the Average Silhouette Width (ASW).
        resolution: A numeric parameter (default is 1) that sets the custom resolution parameter when ASW is set to False.
        epochs: An integer that defines the number of training iterations, with a default value of 250.

    returns:
        DataFrame of expression values: the reference batch unchanged, followed
        by the corrected cells of each remaining batch (row index restarts at 0
        for every batch).

    raises:
        ValueError: unknown ``option``/``methods``, empty z-matrix, or a batch
            none of whose clusters can be matched to a reference cluster.
    """
    if option not in ('Global', 'Partial'):
        raise ValueError("option must be 'Global' or 'Partial'.")
    if methods not in ('Louvain', 'Leiden', 'GMM'):
        raise ValueError("methods must be 'Louvain', 'Leiden' or 'GMM'.")

    random.seed(123)
    batch_counts = adata.obs['batch'].value_counts()
    # Create a new batch label
    new_labels = {old: f'batch{i}' for i, old in enumerate(batch_counts.sort_values(ascending=False).index)}
    # Replace the original batch label
    adata.obs['batch'] = adata.obs['batch'].map(new_labels)

    if preprocessing == True:
        sc.pp.filter_cells(adata, min_genes=10)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        adata.obs["batch"] = adata.obs["batch"].astype('category')
        sc.pp.highly_variable_genes(adata, n_top_genes=2000, batch_key='batch')
        adata = adata[:,adata.var["highly_variable"]==True]

    # Cluster membership of every cell, then cluster-to-cluster similarity.
    Z_matrix_total = _build_z_matrix(adata, methods, ASW, resolution)
    if Z_matrix_total.empty:
        raise ValueError("Z_matrix_total is not defined or is empty.")
    M = FindingNN(adata)
    S = compute_S_matrix(Z_matrix_total, M)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # the number of genes
    number_of_items = adata.X.shape[1]

    # the number of nodes
    K = 3

    partial = option == 'Partial'
    expression = _dense_matrix(adata.X)
    batch_of_cell = adata.obs['batch']

    is_reference = (batch_of_cell == 'batch0').values
    reference_batch = pd.DataFrame(expression[is_reference])
    resultall = reference_batch
    target_data = torch.tensor(reference_batch.values).to(device).float()

    # Kernel bandwidths: half, one and two times the median (over reference
    # cells) of the mean distance to the 25 nearest reference cells.
    Dis, _ = nn_search(reference_batch, k=25)
    median_distance = np.percentile(Dis.mean(axis=1), 50)
    mmdloss = MMDLoss(sigma1=median_distance/2, sigma2=median_distance, sigma3=median_distance*2)

    reference_columns = Z_matrix_total.columns.str.startswith('batch0_')
    Z2_all = torch.tensor(
        Z_matrix_total.loc[np.flatnonzero(is_reference), reference_columns].to_numpy(dtype=np.float32)
    ).to(device)

    query_batch_name = [name for name in batch_of_cell.unique() if name != 'batch0']
    for z in query_batch_name:
        is_query = (batch_of_cell == z).values
        train_1 = pd.DataFrame(expression[is_query])
        raw_source = torch.tensor(train_1.values).to(device)
        # 1 where the observed value is exactly zero, 0 elsewhere.
        zero_mask = (raw_source == 0).float()
        source_data = raw_source.float()

        # "<batch>_" prefix so that e.g. batch1 does not also select batch10_*.
        query_columns = Z_matrix_total.columns.str.startswith(z + '_')
        Z1_all = torch.tensor(
            Z_matrix_total.loc[np.flatnonzero(is_query), query_columns].to_numpy(dtype=np.float32)
        ).to(device)
        Label_source = Z1_all if partial else None

        seed = 123
        np.random.seed(seed)
        torch.manual_seed(seed)

        generator = _OrderPreservingGenerator(
            number_of_items, K,
            n_labels=int(query_columns.sum()) if partial else 0,
            partial=partial,
        ).to(device)
        generator.apply(init_weights_positive)
        adjust_weights0(generator,K = K,number_of_items=number_of_items)

        lr = 0.01
        g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

        # Cluster matching does not change during training: compute it once.
        threshold = find_threshold(S,z)
        cluster_groups = _matched_cluster_groups(
            S, Z_matrix_total.columns[query_columns], z, threshold
        )
        if not cluster_groups:
            raise ValueError(
                f"No cluster of {z} could be matched to a reference (batch0) cluster "
                f"at threshold {threshold}."
            )
        group_weights = [
            (Z1_all[:, query_indices].sum(dim=1), Z2_all[:, reference_indices].sum(dim=1))
            for query_indices, reference_indices in cluster_groups
        ]

        for epoch in range(epochs):
            # Training Generator
            g_optimizer.zero_grad()
            g_output = generator(source_data, zero_mask, Label_source)

            g_loss = 0
            for query_weight, reference_weight in group_weights:
                g_loss = g_loss + mmdloss(g_output, target_data, query_weight, reference_weight)
            g_loss.backward()
            g_optimizer.step()

            if (epoch+1) % 5 == 0:
                # Re-impose the sign constraint that the optimiser may have broken.
                adjust_weights(generator,K = K,number_of_items=number_of_items)
                print(f'Epoch {epoch+1}/{epochs},  Generator Loss2: {g_loss.item()}')

        with torch.no_grad():
            g_output_numpy = generator(source_data, zero_mask, Label_source).cpu().numpy()
        result = pd.DataFrame(g_output_numpy)
        result.columns = resultall.columns
        resultall = pd.concat([resultall,result], axis=0)
    return resultall
        