In [1]:
# In[]
import sys, os
sys.path.append('../')
sys.path.append('../src/')
import numpy as np
import pandas as pd
from scipy import sparse
import networkx as nx

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import time

from sklearn.decomposition import PCA
from sklearn.manifold import MDS

import scDART.diffusion_dist as diff
import scDART.dataset as dataset
import scDART.model as model
import scDART.loss as loss
import scDART.train
import scDART.TI as ti
import scDART.benchmark as bmk
import scDART.de_analy as de

from umap import UMAP

import utils as utils

import scDART.post_align as palign
from scipy.sparse import load_npz

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Font size applied to every matplotlib figure created afterwards.
plt.rcParams.update({"font.size": 20})

In [5]:
# In[] hyperparameters, seeding, data loading and training configuration
# Candidate values considered when scanning for the setting with the highest
# neighborhood overlap score. This cell does not scan: it runs the single
# combination picked by index below.
seeds = [0, 1, 2]
latent_dims = [4, 8, 32]
reg_ds = [1, 10]
reg_gs = [0.01, 1, 10]
reg_mmds = [1, 10, 20, 30]

latent_dim = latent_dims[0]
reg_d = reg_ds[0]
reg_g = reg_gs[1]
# harder to merge, need to make mmd loss larger
reg_mmd = reg_mmds[1]
seed = seeds[0]

learning_rate = 3e-4
n_epochs = 500
use_anchor = False
ts = [30, 50, 70]
use_potential = True
norm = "l1"

# Seed torch (CPU and CUDA) and numpy before anything stochastic runs.
print("Random seed: " + str(seed))
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# Single place to change when the dataset lives somewhere else.
DATA_DIR = "/home/xcx/MYBenchmark-datas/1469"

counts_rna = pd.read_csv(os.path.join(DATA_DIR, "counts_rna.csv"), index_col = 0).values
counts_atac = pd.read_csv(os.path.join(DATA_DIR, "counts_atac.csv"), index_col = 0).values
# NOTE(review): both label frames are read from the same anno.txt -- confirm this is intended.
label_rna = pd.read_csv(os.path.join(DATA_DIR, "anno.txt"), header = None)
label_atac = pd.read_csv(os.path.join(DATA_DIR, "anno.txt"), header = None)
rna_dataset = dataset.dataset(counts = counts_rna, anchor = None)
atac_dataset = dataset.dataset(counts = counts_atac, anchor = None)
coarse_reg = torch.FloatTensor(pd.read_csv(os.path.join(DATA_DIR, "region2gene.csv"), sep = ",", index_col = 0).values).to(device)

# A quarter of the larger dataset per batch; clamped to 1 because DataLoader
# rejects a batch size of 0 (which the plain division yields for < 4 cells).
batch_size = max(1, max(len(rna_dataset), len(atac_dataset)) // 4)

train_rna_loader = DataLoader(rna_dataset, batch_size = batch_size, shuffle = True)
train_atac_loader = DataLoader(atac_dataset, batch_size = batch_size, shuffle = True)

EMBED_CONFIG = {
    # layer widths from the ATAC feature count down to the RNA feature count
    'gact_layers': [atac_dataset.counts.shape[1], 1024, 512, rna_dataset.counts.shape[1]], 
    # layer widths from the RNA feature count down to the latent dimension
    'proj_layers': [rna_dataset.counts.shape[1], 512, 128, latent_dim],
    'learning_rate': learning_rate,
    'n_epochs': n_epochs + 1,
    'use_anchor': use_anchor,
    'reg_d': reg_d,
    'reg_g': reg_g,
    'reg_mmd': reg_mmd,
    'l_dist_type': 'kl',
    'device': device
}

Random seed: 0


In [6]:
# Diffusion distances per modality: dr = "pca" is requested for the RNA counts
# and dr = "lsi" for the ATAC counts, 30 components each.
dist_rna = diff.diffu_distance(
    rna_dataset.counts.numpy(),
    ts = ts,
    use_potential = use_potential,
    dr = "pca",
    n_components = 30,
)
dist_atac = diff.diffu_distance(
    atac_dataset.counts.numpy(),
    ts = ts,
    use_potential = use_potential,
    dr = "lsi",
    n_components = 30,
)

# Divide each matrix by its norm, then move it to the device as a float tensor.
dist_rna = torch.FloatTensor(dist_rna / np.linalg.norm(dist_rna)).to(device)
dist_atac = torch.FloatTensor(dist_atac / np.linalg.norm(dist_atac)).to(device)

# Build the two networks (gene_act first, then encoder, so that parameter
# initialisation consumes the random stream in a fixed order).
gene_act = model.gene_act(
    features = EMBED_CONFIG["gact_layers"], dropout_rate = 0.0, negative_slope = 0.2
).to(device)
encoder = model.Encoder(
    features = EMBED_CONFIG["proj_layers"], dropout_rate = 0.0, negative_slope = 0.2
).to(device)
model_dict = dict(gene_act = gene_act, encoder = encoder)

# One Adam optimizer per network, both using the same learning rate.
opt_genact = optim.Adam(gene_act.parameters(), lr = learning_rate)
opt_encoder = optim.Adam(encoder.parameters(), lr = learning_rate)
opt_dict = dict(gene_act = opt_genact, encoder = opt_encoder)

running time(sec): 1.3912913799285889
running time(sec): 1.2345635890960693
running time(sec): 1.3471260070800781
running time(sec): 1.317655324935913
running time(sec): 1.5719985961914062
running time(sec): 1.32198166847229


In [20]:
import scDART.train as train
import datetime

# Wall-clock timer covering training and the final embedding pass.
starttime = datetime.datetime.now()

# training models
train.match_latent(model = model_dict, opts = opt_dict, dist_atac = dist_atac, dist_rna = dist_rna, 
                data_loader_rna = train_rna_loader, data_loader_atac = train_atac_loader, n_epochs = EMBED_CONFIG["n_epochs"], 
                reg_mtx = coarse_reg, reg_d = EMBED_CONFIG["reg_d"], reg_g = EMBED_CONFIG["reg_g"], reg_mmd = EMBED_CONFIG["reg_mmd"], use_anchor = EMBED_CONFIG["use_anchor"], norm = norm, 
                mode = EMBED_CONFIG["l_dist_type"])

# Embed both modalities without tracking gradients. RNA counts go straight
# through the encoder; ATAC counts pass through gene_act first.
with torch.no_grad():
    z_rna = model_dict["encoder"](rna_dataset.counts.to(device)).cpu().detach()
    z_atac = model_dict["encoder"](model_dict["gene_act"](atac_dataset.counts.to(device))).cpu().detach()

    
# np.save(file = "/home/xcx/results/1469/1-scDART/z_rna_" + str(latent_dim) + "_" + str(reg_d) + "_" + str(reg_g) + "_" + str(reg_mmd) + "_" + str(seed) + "_l1.npy", arr = z_rna.numpy())
# np.save(file = "/home/xcx/results/1469/1-scDART/z_atac_" + str(latent_dim) + "_" + str(reg_d) + "_" + str(reg_g) + "_" + str(reg_mmd) + "_" + str(seed) + "_l1.npy", arr = z_atac.numpy())
# torch.save(model_dict, "/home/xcx/results/1469/1-scDART/model_" + str(latent_dim) + "_" + str(reg_d) + "_" + str(reg_g) + "_" + str(reg_mmd) + "_" + str(seed) + "_l1.pth")

#long running
endtime = datetime.datetime.now()
# timedelta.seconds is only the seconds component and silently drops whole
# days; total_seconds() covers the full duration. Truncated to an int so the
# printed value keeps the same whole-second format.
print(int((endtime - starttime).total_seconds()))

epoch:  0
	 mmd loss: 0.125
	 ATAC dist loss: 0.052
	 RNA dist loss: 0.045
	 gene activity loss: 0.380
	 anchor matching loss: 0.000
epoch:  100
	 mmd loss: 0.104
	 ATAC dist loss: 0.037
	 RNA dist loss: 0.046
	 gene activity loss: 0.334
	 anchor matching loss: 0.000
epoch:  200
	 mmd loss: 0.104
	 ATAC dist loss: 0.041
	 RNA dist loss: 0.032
	 gene activity loss: 0.321
	 anchor matching loss: 0.000
epoch:  300
	 mmd loss: 0.099
	 ATAC dist loss: 0.044
	 RNA dist loss: 0.043
	 gene activity loss: 0.291
	 anchor matching loss: 0.000
epoch:  400
	 mmd loss: 0.097
	 ATAC dist loss: 0.045
	 RNA dist loss: 0.036
	 gene activity loss: 1.238
	 anchor matching loss: 0.000
epoch:  500
	 mmd loss: 0.108
	 ATAC dist loss: 0.049
	 RNA dist loss: 0.040
	 gene activity loss: 0.986
	 anchor matching loss: 0.000
196


In [9]:
# Write the latent embeddings of both modalities to CSV (no index column).
latent_rna = z_rna.numpy()
latent_atac = z_atac.numpy()

# Create the output directory up front so the save cannot fail on a missing
# folder after the long training run.
result_dir = "/home/xcx/results/1469/1-scDART"
os.makedirs(result_dir, exist_ok = True)

df = pd.DataFrame(data=latent_rna)
df.to_csv(os.path.join(result_dir, "z_rna.csv"), index=False)
df = pd.DataFrame(data=latent_atac)
df.to_csv(os.path.join(result_dir, "z_atac.csv"), index=False)