# stPlus

stPlus is a reference-based method for enhancing spatial transcriptomics data. It leverages the holistic information in the reference scRNA-seq data, rather than only the genes shared with the spatial data, to learn a non-linear embedding of the cells in both datasets and predict unmeasured spatial gene expression.

The **required inputs** of stPlus are:

* **spatial_df**:       normalized and logarithmized original spatial data
* **scrna_df**:         normalized and logarithmized reference scRNA-seq data
* **genes_to_predict**: spatial genes to be predicted

The **output** of stPlus is

* **stPlus_res**:       predicted spatial transcriptomics data
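
Both inputs must already be normalized and logarithmized. The exact preprocessing is not prescribed here, so the following is only a minimal sketch assuming a common recipe: scale each cell to a fixed total and apply log(1 + x). The helper name `normalize_log1p`, the `target_sum` value and the toy counts are illustrative assumptions, not part of stPlus.

```python
import numpy as np
import pandas as pd

def normalize_log1p(counts_df, target_sum=1e4):
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x).

    Assumed recipe for illustration; stPlus itself only requires that the
    input is already normalized and logarithmized.
    """
    totals = counts_df.sum(axis=1)
    # Avoid division by zero for cells without any counts.
    totals = totals.replace(0, 1)
    scaled = counts_df.div(totals, axis=0) * target_sum
    return np.log1p(scaled)

# Toy cell-by-gene count matrix (hypothetical values).
raw_counts = pd.DataFrame(
    {"GeneA": [1, 0, 3], "GeneB": [3, 0, 1]},
    index=["cell1", "cell2", "cell3"],
)
normalized = normalize_log1p(raw_counts)
print(normalized)
```

The result keeps the cell-by-gene layout that `spatial_df` and `scrna_df` are expected to have.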

We first import stPlus.

Frequently used packages such as pandas (pd), numpy (np) and torch are imported automatically along with stPlus.

In [1]:
from stPlus import *

The signature and parameters of stPlus can be inspected via

In [2]:
help(stPlus)

Help on function stPlus in module stPlus.model:

stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix='./stPlus', top_k=3000, t_min=5, data_quality=None, random_seed=None, verbose=True, converge_ratio=0.004, max_epoch_num=10000, batch_size=512, learning_rate=None, weight_decay=0.0002)
    spatial_df:       [pandas dataframe] normalized and logarithmized original spatial data (cell by gene)
    scrna_df:         [pandas dataframe] normalized and logarithmized reference scRNA-seq data (cell by gene)
    genes_to_predict: [1D numpy array] spatial genes to be predicted
    save_path_prefix: [str] prefix of path of trained t models with minimal loss
    top_k:            [int] number of highly variable genes to use
    t_min:            [int] number of epochs with minimal loss using to ensemble learning
    data_quality:     [float] user-specified or 1 minus the sparsity of scRNA-seq data (default)
    random_seed:      [int] random seed in torch
    verbose:          [bool] disp
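
According to the help text, `data_quality` defaults to 1 minus the sparsity of the scRNA-seq data. A minimal sketch of that quantity follows, assuming sparsity means the fraction of zero entries in the matrix. Whether stPlus computes it in exactly this way is an assumption, and `estimate_data_quality` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd

def estimate_data_quality(scrna_df):
    """Return 1 - sparsity, where sparsity is the fraction of zero entries."""
    values = scrna_df.to_numpy()
    sparsity = np.mean(values == 0)
    return 1.0 - sparsity

# Toy cell-by-gene matrix with 5 zeros out of 8 entries (hypothetical values).
toy_scrna = pd.DataFrame(
    [[0.0, 4.9, 0.0, 3.8],
     [0.0, 0.0, 5.2, 0.0]],
    columns=["GeneA", "GeneB", "GeneC", "GeneD"],
)
quality = estimate_data_quality(toy_scrna)
print(quality)
```

A value computed this way could be passed explicitly as `data_quality=...` if the default is not wanted.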

# Getting started with stPlus

Load the normalized and logarithmized spatial and scRNA-seq data, and the genes to predict.

The data can be obtained by cloning the repository: `git clone git://github.com/xy-chen16/stPlus.git`

In [6]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
genes_file      = './data/genes_to_predict.txt'

In [7]:
print('Loading data')
spatial_df = pd.read_csv(spatial_df_file)
scrna_df   = pd.read_csv(scrna_df_file)
genes_to_predict = pd.read_csv(genes_file, header=None).iloc[:,0].values

Loading data


In [8]:
spatial_df.head()

| | Gad2 | Slc32a1 | Crhbp | Cnr1 | Vip | Cpne5 | Pthlh | Crh | Tbr1 | Lamp5 | ... | Ctps | Anln | Mrc1 | Hexb | Ttr | Foxj1 | Vtn | Flt1 | Apln | Acta2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2.777385 | 2.926465 | 0.0 | 1.390133 | 0.829559 | 2.119032 | 0.0 | 0.621403 | 2.536136 | 1.492039 | ... | 2.009762 | 1.276651 | 1.001762 | 2.065889 | 1.001762 | 0.829559 | 0.621403 | 0.0 | 1.001762 | 0.829559 |
| 1 | 3.699013 | 3.059924 | 1.123004 | 0.346999 | 0.604114 | 0.346999 | 0.346999 | 0.808458 | 1.462912 | 0.0 | ... | 1.123004 | 0.808458 | 0.346999 | 0.808458 | 0.346999 | 0.346999 | 0.604114 | 0.0 | 0.978048 | 0.346999 |
| 2 | 3.663039 | 3.480582 | 2.220487 | 0.521361 | 0.862314 | 0.0 | 1.222337 | 0.521361 | 2.143296 | 0.0 | ... | 0.997242 | 0.862314 | 0.706299 | 2.256957 | 0.294279 | 0.862314 | 0.997242 | 0.0 | 0.521361 | 0.706299 |
| 3 | 3.428742 | 2.682501 | 0.831733 | 1.279431 | 0.499956 | 0.0 | 1.184873 | 0.0 | 3.079924 | 0.963793 | ... | 1.279431 | 0.0 | 0.963793 | 1.587579 | 0.963793 | 0.280902 | 0.280902 | 0.0 | 0.280902 | 0.963793 |
| 4 | 2.433613 | 3.250374 | 0.0 | 2.282382 | 0.955511 | 0.0 | 0.0 | 0.0 | 1.223775 | 4.506454 | ... | 1.435085 | 0.587787 | 0.955511 | 1.223775 | 0.587787 | 0.587787 | 0.587787 | 0.587787 | 1.223775 | 0.0 |


In [9]:
scrna_df.head()

| | Tspan12 | Tshz1 | Fnbp1l | Adamts15 | Cldn12 | Rxfp1 | 2310042E22Rik | Sema3c | Jam2 | Apbb1ip | ... | Rab9 | Tceanc | Msl3 | Arhgap6 | Mid1 | Vamp7 | Tmlhe | Zf12 | Kdm5d | Uty |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 4.942366 | 4.942366 | 0.0 | 3.857929 | 0.0 | 0.0 | 6.236445 | 3.857929 | 0.0 | ... | 5.785577 | 0.0 | 0.0 | 0.0 | 0.0 | 5.450333 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 0.0 | 3.850649 | 3.850649 | 0.0 | 3.850649 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 3.850649 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 0.0 | 0.0 | 5.251013 | 0.0 | 3.485126 | 0.0 | 4.16283 | 6.674137 | 3.485126 | 0.0 | ... | 3.485126 | 0.0 | 4.16283 | 0.0 | 0.0 | 4.563094 | 0.0 | 0.0 | 0.0 | 5.404414 |
| 3 | 4.523832 | 4.123776 | 4.808798 | 0.0 | 0.0 | 0.0 | 4.523832 | 3.446682 | 0.0 | 0.0 | ... | 4.523832 | 0.0 | 3.446682 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 0.0 | 4.542944 | 3.860383 | 0.0 | 0.0 | 0.0 | 0.0 | 6.143832 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 5.230756 | 0.0 | 0.0 | 4.944856 | 0.0 | 0.0 | 0.0 | 0.0 |


In [10]:
genes_to_predict

['Tesc', 'Pvrl3', 'Grm2']
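
Before running the model, it can help to confirm how the three inputs relate to each other. The sketch below reports the genes shared by both datasets, any target genes missing from the scRNA-seq reference, and any target genes that are already measured spatially. It assumes that target genes must be present in the reference in order to be predicted. `check_inputs` is a hypothetical helper and the toy frames only borrow gene names from the data above.

```python
import numpy as np
import pandas as pd

def check_inputs(spatial_df, scrna_df, genes_to_predict):
    """Summarize gene overlap between spatial data, reference data and targets."""
    # Genes measured in both datasets.
    shared = np.intersect1d(spatial_df.columns, scrna_df.columns)
    # Target genes that the reference does not contain.
    missing = np.setdiff1d(genes_to_predict, scrna_df.columns)
    # Target genes that are already part of the spatial data.
    already_measured = np.intersect1d(genes_to_predict, spatial_df.columns)
    return shared, missing, already_measured

# Toy frames with a single cell each (hypothetical values).
toy_spatial = pd.DataFrame([[2.7, 0.8]], columns=["Gad2", "Vip"])
toy_scrna = pd.DataFrame([[1.0, 0.0, 3.2, 0.5]], columns=["Gad2", "Vip", "Tesc", "Grm2"])
toy_targets = np.array(["Tesc", "Pvrl3", "Grm2"])

shared, missing, already_measured = check_inputs(toy_spatial, toy_scrna, toy_targets)
print("shared:", list(shared))
print("missing from reference:", list(missing))
print("already measured:", list(already_measured))
```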

Run stPlus

In [None]:
save_path_prefix = './model/stPlus-demo'
stPlus_res = stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix, random_seed=10)
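
The `save_path_prefix` above points into a `./model` directory. Whether stPlus creates a missing directory itself is not stated here, so one cautious option is to create the parent directory before training. This is a sketch only: `ensure_parent_dir` is a hypothetical helper, and it is demonstrated in a temporary directory rather than the working tree.

```python
import os
import tempfile

def ensure_parent_dir(path_prefix):
    """Create the directory part of a path prefix if it does not exist yet."""
    parent = os.path.dirname(path_prefix)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return parent

# Temporary location used purely for demonstration.
workdir = tempfile.mkdtemp()
demo_prefix = os.path.join(workdir, "model", "stPlus-demo")
model_dir = ensure_parent_dir(demo_prefix)
print(os.path.isdir(model_dir))
```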

This yields the following predicted spatial transcriptomics data:

In [9]:
stPlus_res.head()

| | Tesc | Pvrl3 | Grm2 |
|---|---|---|---|
| 0 | 1.947522 | 1.568999 | 0.918464 |
| 1 | 0.832874 | 1.935234 | 0.472614 |
| 2 | 0.790593 | 2.180426 | 0.344876 |
| 3 | 0.96256 | 1.850475 | 0.85958 |
| 4 | 0.091255 | 2.221686 | 0.089853 |


In [None]:
# the results should be
#       Tesc	Pvrl3   	Grm2
# 0	1.668627	2.048264	0.627106
# 1	0.720087	2.617475	0.284386
# 2	0.932298	2.547904	0.466870
# 3	0.990825	1.857638	0.963752
# 4	0.091311	2.091799	0.089835

# Reproduction of 5-fold cross validation

In [3]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
raw_spatial_df  = pd.read_csv(spatial_df_file)
raw_scrna_df    = pd.read_csv(scrna_df_file)
print(raw_spatial_df.shape, raw_scrna_df.shape) # cell by gene
raw_shared_gene = np.intersect1d(raw_spatial_df.columns, raw_scrna_df.columns)
print(raw_shared_gene.shape)

(3405, 33) (1691, 15075)
(33,)


In [4]:
from sklearn.model_selection import KFold

# Split the shared genes (not the cells) into 5 folds
kf = KFold(n_splits=5, shuffle=True, random_state=0)
torch.manual_seed(10)

# Collects the held-out predictions of all folds, one column per shared gene
all_pred_res = pd.DataFrame(np.zeros((raw_spatial_df.shape[0], raw_shared_gene.shape[0])), columns=raw_shared_gene)

for idx, (train_ind, test_ind) in enumerate(kf.split(raw_shared_gene), start=1):
    print("\n===== Fold %d =====\nNumber of train genes: %d, Number of test genes: %d"%(idx, len(train_ind), len(test_ind)))
    train_gene = raw_shared_gene[train_ind]
    test_gene  = raw_shared_gene[test_ind]
    # Measured values of the held-out genes (ground truth for this fold)
    test_spatial_df = raw_spatial_df[test_gene]
    # Hide the held-out genes from the spatial input; stPlus predicts them
    spatial_df = raw_spatial_df[train_gene]
    scrna_df   = raw_scrna_df

    save_path_prefix = './model/stPlus-demo-fold%d'%(idx)
    stPlus_res = stPlus(spatial_df, scrna_df, test_gene, save_path_prefix)
    all_pred_res[stPlus_res.columns.values] = stPlus_res


===== Fold 1 =====
Number of train genes: 26, Number of test genes: 7
Models will be saved in: /home/chenshengquan/data/SpaGE/model/stPlus-fold1-5min*.pt

Spatial transcriptomics data: 3405 cells * 26 genes
Reference scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted

Start initialization
Start embedding
	[1] recon_loss: 10993.100, pred_loss: 69690.604, total_loss: 80683.704
	[2] recon_loss: 6743.004, pred_loss: 43190.807, total_loss: 49933.811
	[3] recon_loss: 5153.566, pred_loss: 38914.559, total_loss: 44068.125
	[4] recon_loss: 4075.303, pred_loss: 36493.186, total_loss: 40568.489
	[5] recon_loss: 3156.481, pred_loss: 35117.788, total_loss: 38274.270
	[6] recon_loss: 2388.118, pred_loss: 34161.463, total_loss: 36549.580
	[7] recon_loss: 1819.737, pred_loss: 33498.450, total_loss: 35318.187
	[8] recon_loss: 1403.374, pred_loss: 33091.804, total_loss: 34495.179
	[9] recon_loss: 1141.133, pred_loss: 32747.483, total_loss: 33888.615
	[10] recon_loss: 985.588, pred_los

	[14] recon_loss: 1229.982, pred_loss: 32368.523, total_loss: 33598.505
	[15] recon_loss: 1223.410, pred_loss: 32235.319, total_loss: 33458.729
	[16] recon_loss: 1233.465, pred_loss: 32045.407, total_loss: 33278.872
	[17] recon_loss: 1235.767, pred_loss: 31918.920, total_loss: 33154.687
Start prediction
	Using model 1 to predict
	Using model 2 to predict
	Using model 3 to predict
	Using model 4 to predict
	Using model 5 to predict
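
The cross validation above splits genes, not cells. With 33 shared genes and 5 folds, the folds cannot all be the same size. The short sketch below reproduces only the split, with the same `KFold` settings but placeholder gene names instead of the real ones, to show the fold sizes and that each gene is held out exactly once.

```python
import numpy as np
from sklearn.model_selection import KFold

# Placeholder names standing in for the 33 shared genes.
genes = np.array(["gene%d" % i for i in range(33)])

kf = KFold(n_splits=5, shuffle=True, random_state=0)
fold_sizes = []
held_out = []
for train_ind, test_ind in kf.split(genes):
    fold_sizes.append((len(train_ind), len(test_ind)))
    held_out.extend(test_ind.tolist())

print(fold_sizes)
# Every gene index appears in exactly one test fold.
print(sorted(held_out) == list(range(33)))
```

The first three folds hold out 7 genes and train on 26, which matches the Fold 1 log above; the last two hold out 6.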


In [5]:
corr_res = calc_corr(raw_spatial_df, all_pred_res, raw_shared_gene)
print(np.mean(corr_res))
print(np.median(corr_res))

0.2275347768890717
0.21000466117482736
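
`calc_corr` is provided by stPlus and its implementation is not shown in this notebook. As a rough illustration of a per-gene evaluation, the sketch below computes one correlation per gene between measured and predicted values. It assumes Spearman rank correlation, which may differ from what `calc_corr` actually does. `per_gene_spearman` and the toy values are hypothetical.

```python
import numpy as np
import pandas as pd

def per_gene_spearman(measured_df, predicted_df, genes):
    """Spearman correlation per gene, computed as the Pearson correlation of ranks."""
    scores = []
    for gene in genes:
        measured_rank = measured_df[gene].rank()
        predicted_rank = predicted_df[gene].rank()
        scores.append(measured_rank.corr(predicted_rank))
    return np.array(scores)

# Toy measured and predicted expression for four cells (hypothetical values).
measured = pd.DataFrame({"GeneA": [0.1, 0.5, 0.9, 1.3], "GeneB": [2.0, 1.0, 4.0, 3.0]})
predicted = pd.DataFrame({"GeneA": [0.2, 0.4, 1.0, 5.0], "GeneB": [4.0, 3.0, 2.0, 1.0]})

scores = per_gene_spearman(measured, predicted, ["GeneA", "GeneB"])
print(scores, np.mean(scores), np.median(scores))
```

Summarizing the per-gene scores by their mean and median mirrors the two numbers printed above.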


This demo notebook was run on a GeForce GTX 1080 GPU.

Note that the results may differ slightly on other GPUs, even with identical versions of all Python packages.

For example, the results obtained on a GeForce GTX 1080 Ti GPU are:

mean: 0.2275347768890717, median: 0.21000466117482736

In [None]:
# the results should be
# mean: 0.2269573212116094
# median: 0.20835767062356317
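
Because the numbers vary slightly between GPUs, a reproduction is better compared against the reference values with a tolerance than for exact equality. The sketch below uses the two pairs of values quoted in this notebook. The tolerance of 5e-3 is an arbitrary assumption, and `within_tolerance` is a hypothetical helper.

```python
import numpy as np

# Values quoted in this notebook.
reference = {"mean": 0.2269573212116094, "median": 0.20835767062356317}
observed = {"mean": 0.2275347768890717, "median": 0.21000466117482736}

def within_tolerance(observed, reference, atol=5e-3):
    """Check each metric against its reference using an absolute tolerance."""
    return {
        key: bool(np.isclose(observed[key], reference[key], rtol=0.0, atol=atol))
        for key in reference
    }

diffs = {key: abs(observed[key] - reference[key]) for key in reference}
check = within_tolerance(observed, reference)
print(diffs)
print(check)
```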