In [1]:
import math, os
from time import time

import sys
sys.path.append('./spaMultiVAE/') 

os.environ['R_HOME'] = '/home/ws6tg/anaconda3/envs/EnvR43/lib/R/'
import torch
from spaMultiVAE import SPAMULTIVAE
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
import h5py
import scanpy as sc
from preprocess import normalize, geneSelection

In [34]:
# Input directory for this dataset and the mirrored output directory
# (same path with "datasets" swapped for "results").
path="../datasets/10x_human_lymph_node_D1/"
result_path=path.replace("datasets","results")

# Create the output directory up front: the model checkpoint and the final
# .h5ad are both written under result_path, and a missing directory would
# otherwise only surface as a write error at save time.
os.makedirs(result_path, exist_ok=True)

'''
Parameter setting
'''

class Args(object):
    """Plain container for the run's hyper-parameters and output file names.

    Every setting is assigned as an instance attribute in ``__init__``.
    ``model_file`` is built from the module-level ``result_path``, so that
    name must be defined before ``Args()`` is instantiated.
    """

    def __init__(self):
        # Not read anywhere in the visible code; inputs come from .h5ad files.
        self.data_file = 'humantonsil_SVG.h5'
        # A value > 0 enables the corresponding feature-selection branch.
        self.select_genes = 0
        self.select_proteins = 0
        # "auto" is resolved to a concrete integer from the number of spots.
        self.batch_size = "auto"
        self.maxiter = 100
        # NOTE(review): the train_model call in this file passes train_size=1
        # rather than this value -- confirm which one is intended.
        self.train_size = 0.95
        self.patience = 20
        self.lr = 5e-3
        self.weight_decay = 1e-6
        self.gene_noise = 0
        self.protein_noise = 0
        self.dropoutE = 0
        self.dropoutD = 0
        self.encoder_layers = [128, 64]
        self.GP_dim = 2
        self.Normal_dim = 18
        self.gene_decoder_layers = [128]
        self.protein_decoder_layers = [128]
        self.init_beta = 10
        self.min_beta = 4
        self.max_beta = 25
        self.KL_loss = 0.025  
        self.num_samples = 1
        self.fix_inducing_points = True
        # Number of grid intervals per axis used to lay out inducing points.
        self.inducing_point_steps = 19
        self.fixed_gp_params = False
        # Spatial coordinates are min-max scaled and multiplied by this value.
        self.loc_range = 20.
        self.kernel_scale = 20.
        # Checkpoint path; training is skipped when this file already exists.
        self.model_file = result_path+"model.pt"
        # Output file names below are not referenced by the visible code.
        self.final_latent_file = "final_latent.txt"
        self.gene_denoised_counts_file = "gene_denoised_counts.txt"
        self.protein_denoised_counts_file = "protein_denoised_counts.txt"
        self.protein_sigmoid_file = "protein_sigmoid.txt"
        self.gene_enhanced_denoised_counts_file = "gene_enhanced_denoised_counts.txt"
        self.protein_enhanced_denoised_counts_file = "protein_enhanced_denoised_counts.txt"
        self.enhanced_loc_file = "enhanced_loc.txt"
        self.device = "cuda:3"

    def __repr__(self):
        """Return ``Args(name=value, ...)`` listing every attribute.

        Without this, ``print(args)`` shows only the default object address,
        so the run log would not record the configuration that was used.
        """
        items = ", ".join("%s=%r" % (key, value) for key, value in vars(self).items())
        return "%s(%s)" % (type(self).__name__, items)

args = Args()

# Inputs are read from the per-modality .h5ad files under `path`;
# args.data_file is not used here.
adata2=sc.read_h5ad(path+"adata_ADT.h5ad")
adata1=sc.read_h5ad(path+"adata_RNA.h5ad")

# Both modalities are later fed to the model together with a single set of
# coordinates (taken from the RNA object), so they must have the same number
# of spots. Fail early with a clear message instead of deep inside the model.
if adata1.n_obs != adata2.n_obs:
    raise ValueError(
        "RNA and ADT objects have different numbers of spots: %d vs %d"
        % (adata1.n_obs, adata2.n_obs)
    )

# .X may be stored sparse or dense depending on how the .h5ad was written;
# only densify when the matrix actually provides .toarray().
x1 = np.array(adata1.X.toarray() if hasattr(adata1.X, "toarray") else adata1.X).astype('float64')     # gene count matrix
x2 = np.array(adata2.X.toarray() if hasattr(adata2.X, "toarray") else adata2.X).astype('float64')  # protein count matrix
loc = np.array(adata1.obsm["spatial"]).astype('float64')       # location information

# Resolve "auto" to a batch size that grows with the number of spots;
# any other value is coerced to int.
if args.batch_size == "auto":
    if x1.shape[0] <= 1024:
        args.batch_size = 128
    elif x1.shape[0] <= 2048:
        args.batch_size = 256
    else:
        args.batch_size = 512
else:
    args.batch_size = int(args.batch_size)

print(args)

# Optional gene selection: only runs when args.select_genes > 0 (it is 0 in
# the Args defined in this file). Keeps the columns chosen by geneSelection
# and writes that selection to "selected_genes.txt" (a relative path).
if args.select_genes > 0:
    importantGenes = geneSelection(x1, n=args.select_genes, plot=False)
    x1 = x1[:, importantGenes]
    np.savetxt("selected_genes.txt", importantGenes, delimiter=",", fmt="%i")

# Same procedure for proteins, gated on args.select_proteins (also 0 here).
# NOTE(review): the gene-selection helper is reused on the protein matrix --
# confirm it is appropriate for protein counts.
if args.select_proteins > 0:
    importantProteins = geneSelection(x2, n=args.select_proteins, plot=False)
    x2 = x2[:, importantProteins]
    np.savetxt("selected_proteins.txt", importantProteins, delimiter=",", fmt="%i")

# Min-max scale each spatial axis, then stretch to [0, args.loc_range].
scaler = MinMaxScaler()
loc = scaler.fit_transform(loc) * args.loc_range

# Report the shapes of the gene matrix, protein matrix and coordinates.
print(x1.shape, x2.shape, loc.shape, sep="\n")

# Regular 2-D grid of inducing points over the same [0, args.loc_range] square.
# eps pushes the slice stop just past 1 so the grid is meant to include the
# upper edge (mgrid with a float step excludes the stop value).
eps = 1e-5
initial_inducing_points = (
    np.mgrid[
        0:(1+eps):(1./args.inducing_point_steps),
        0:(1+eps):(1./args.inducing_point_steps),
    ].reshape(2, -1).T
    * args.loc_range
)
print(initial_inducing_points.shape)

# Gene counts: wrapped in a fresh AnnData and passed through normalize with
# size factors, log transform and input scaling all switched on.
adata1 = normalize(
    sc.AnnData(x1, dtype="float64"),
    size_factors=True,
    filter_min_counts=False,
    normalize_input=True,
    logtrans_input=True,
)

# Protein counts: same call but without size factors.
adata2 = normalize(
    sc.AnnData(x2, dtype="float64"),
    size_factors=False,
    filter_min_counts=False,
    normalize_input=True,
    logtrans_input=True,
)
# Replace any NaN left by normalisation with 0.
adata2.X = np.nan_to_num(adata2.X, nan=0.0)

# Second protein copy with normalize_input=False; it feeds the GMM that
# estimates the protein background prior.
# NOTE(review): this AnnData is built from the same x2 array as adata2 --
# confirm normalize does not modify x2 in place.
adata2_no_scale = normalize(
    sc.AnnData(x2, dtype="float64"),
    size_factors=False,
    filter_min_counts=False,
    normalize_input=False,
    logtrans_input=True,
)

# Replace NaN with 0.
adata2_no_scale.X = np.nan_to_num(adata2_no_scale.X, nan=0.0)

# Fit a two-component, diagonal-covariance GMM to the unscaled protein values
# and, per protein, take the component with the smaller mean as the initial
# background prior.
# random_state is fixed so the n_init=20 random restarts -- and therefore the
# prior handed to the model -- are the same on every run.
gm = GaussianMixture(n_components=2, covariance_type="diag", n_init=20, random_state=0).fit(adata2_no_scale.X)
# For each protein (column), index of the component with the lower mean.
back_idx = np.argmin(gm.means_, axis=0)
# Inverts log1p on the selected mean (expm1) and then takes the log again.
protein_log_back_mean = np.log(np.expm1(gm.means_[back_idx, np.arange(adata2_no_scale.n_vars)]))
# NOTE(review): only NaN is mapped to 0 here. A background mean of exactly 0
# gives log(0) = -inf, which nan_to_num replaces with the most negative finite
# float by default -- confirm whether that case needs an explicit neginf value.
protein_log_back_mean=np.nan_to_num(protein_log_back_mean, nan=0.0)
# Standard deviation of the selected component (sqrt of its diagonal variance).
protein_log_back_scale = np.sqrt(gm.covariances_[back_idx, np.arange(adata2_no_scale.n_vars)])
print("protein_back_mean shape", protein_log_back_mean.shape)

# Build the model from the settings in `args`, the inducing-point grid and
# the GMM-derived protein background prior. One keyword per line so each
# setting is easy to locate and diff.
model = SPAMULTIVAE(
    gene_dim=adata1.n_vars,
    protein_dim=adata2.n_vars,
    GP_dim=args.GP_dim,
    Normal_dim=args.Normal_dim,
    encoder_layers=args.encoder_layers,
    gene_decoder_layers=args.gene_decoder_layers,
    protein_decoder_layers=args.protein_decoder_layers,
    gene_noise=args.gene_noise,
    protein_noise=args.protein_noise,
    encoder_dropout=args.dropoutE,
    decoder_dropout=args.dropoutD,
    fixed_inducing_points=args.fix_inducing_points,
    initial_inducing_points=initial_inducing_points,
    fixed_gp_params=args.fixed_gp_params,
    kernel_scale=args.kernel_scale,
    N_train=adata1.n_obs,
    KL_loss=args.KL_loss,
    init_beta=args.init_beta,
    min_beta=args.min_beta,
    max_beta=args.max_beta,
    protein_back_mean=protein_log_back_mean,
    protein_back_scale=protein_log_back_scale,
    dtype=torch.float64,
    device=args.device,
    dynamicVAE=False,
)

# print() applies str() to its argument, so this shows the model's string form.
print(model)

# Reuse a saved checkpoint when one exists at args.model_file; otherwise
# train from scratch (saving the weights to that path) and report wall time.
if os.path.isfile(args.model_file):
    model.load_model(args.model_file)
else:
    t0 = time()
    # NOTE(review): train_size is hard-coded to 1 here although
    # args.train_size is 0.95 -- confirm this is deliberate.
    model.train_model(
        pos=loc,
        gene_ncounts=adata1.X,
        gene_raw_counts=adata1.raw.to_adata().X,
        gene_size_factors=adata1.obs["size_factors"].values,
        protein_ncounts=adata2.X,
        protein_raw_counts=adata2.raw.to_adata().X,
        lr=args.lr,
        weight_decay=args.weight_decay,
        batch_size=args.batch_size,
        num_samples=args.num_samples,
        train_size=1,
        maxiter=args.maxiter,
        patience=args.patience,
        save_model=True,
        model_weights=args.model_file,
    )
    print('Training time: %d seconds.' % int(time() - t0))

# Latent representation for every spot, computed in batches.
final_latent = model.batching_latent_samples(
    X=loc,
    gene_Y=adata1.X,
    protein_Y=adata2.X,
    batch_size=args.batch_size,
)

# adata1 was replaced by its normalised version above, so the original RNA
# object is read again and the embedding is attached to it before saving
# under the results directory.
adata = sc.read_h5ad(path + "adata_RNA.h5ad")
adata.obsm["X_SpaMultiVAE"] = final_latent

adata.write_h5ad(path.replace("datasets", "results") + "adataSpaMultiVAE.h5ad")