# stPlus

stPlus is a reference-based method for the enhancement of spatial transcriptomics. Leveraging the holistic information in reference scRNA-seq data but not limited to the genes shared with spatial data, stPlus performs non-linear embedding for cells in both datasets and effectively predicts unmeasured spatial gene expression.

The **required inputs** of stPlus include

* **spatial_df**:       normalized and logarithmized original spatial data
* **scrna_df**:         normalized and logarithmized reference scRNA-seq data
* **genes_to_predict**: spatial genes to be predicted

The **output** of stPlus is

* **stPlus_res**:       predicted spatial transcriptomics data
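
Both inputs are expected to be normalized and logarithmized, but the exact recipe is not specified here. The sketch below shows one common convention (library-size normalization followed by `log(1 + x)`) on a toy count table. The function name `normalize_and_log` and the `scale_factor` value are assumptions made for illustration and are not part of stPlus.

```python
import numpy as np
import pandas as pd

def normalize_and_log(counts_df, scale_factor=1e4):
    """Library-size normalize each cell, then apply log(1 + x).

    counts_df:    raw counts, cells in rows and genes in columns.
    scale_factor: hypothetical per-cell target total (an assumption).
    """
    totals = counts_df.sum(axis=1)
    # Replace zero totals with 1 so that empty cells do not divide by zero.
    totals = totals.replace(0, 1)
    normalized = counts_df.div(totals, axis=0) * scale_factor
    return np.log1p(normalized)

# Toy raw counts: 3 cells by 3 genes (illustrative values only).
raw_counts = pd.DataFrame(
    [[10, 0, 5], [0, 3, 3], [0, 0, 0]],
    columns=["GeneA", "GeneB", "GeneC"],
)
log_df = normalize_and_log(raw_counts)
print(log_df.round(3))
```

If your data were already processed by another pipeline, keep that processing consistent between the spatial and the scRNA-seq tables rather than applying this sketch on top of it.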

Import the stPlus package

Frequently used packages such as pandas (pd), numpy (np) and torch are imported automatically along with stPlus.

In [1]:
from stPlus import *

Check the documentation of the stPlus function via

In [2]:
help(stPlus)

Help on function stPlus in module stPlus.model:

stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix='./stPlus', top_k=2000, t_min=5, data_quality=None, random_seed=None, verbose=True, n_neighbors=50, converge_ratio=0.004, max_epoch_num=10000, batch_size=512, learning_rate=None, weight_decay=0.0002)
    spatial_df:       [pandas dataframe] normalized and logarithmized original spatial data (cell by gene)
    scrna_df:         [pandas dataframe] normalized and logarithmized reference scRNA-seq data (cell by gene)
    genes_to_predict: [1D numpy array] spatial genes to be predicted
    save_path_prefix: [str] prefix of path of trained t models with minimal loss
    top_k:            [int] number of highly variable genes to use
    t_min:            [int] number of epochs with minimal loss using to ensemble learning
    data_quality:     [float] user-specified or 1 minus the sparsity of scRNA-seq data (default)
    random_seed:      [int] random seed in torch
    verbose:     
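
The help text above describes the default `data_quality` as 1 minus the sparsity of the scRNA-seq data. Taking sparsity to mean the fraction of zero entries, that quantity can be sketched as follows. The helper name `estimate_data_quality` is hypothetical, and how stPlus computes the value internally is not shown here.

```python
import numpy as np
import pandas as pd

def estimate_data_quality(expr_df):
    """Return 1 minus sparsity, i.e. the fraction of non-zero entries."""
    values = expr_df.to_numpy()
    sparsity = np.mean(values == 0)  # fraction of entries equal to zero
    return 1.0 - float(sparsity)

# Toy cell-by-gene table with 3 non-zero entries out of 8 (illustrative values).
toy_scrna_df = pd.DataFrame(
    [[0.0, 4.9, 0.0, 3.8], [0.0, 0.0, 0.0, 6.2]],
    columns=["GeneA", "GeneB", "GeneC", "GeneD"],
)
quality = estimate_data_quality(toy_scrna_df)
print(quality)  # 0.375
```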

# Getting Started with stPlus

Load the normalized and logarithmized spatial and scRNA-seq data, and the genes to predict

The data can be accessed via

```
git clone https://github.com/xy-chen16/stPlus.git
cd stPlus
tar -zxvf data.tar.gz
```

In [3]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
genes_file      = './data/genes_to_predict.txt'

In [4]:
spatial_df = pd.read_csv(spatial_df_file)
scrna_df   = pd.read_csv(scrna_df_file)
genes_to_predict = pd.read_csv(genes_file, header=None).iloc[:,0].values

In [5]:
spatial_df.head()

Unnamed: 0,Gad2,Slc32a1,Crhbp,Cnr1,Vip,Cpne5,Pthlh,Crh,Tbr1,Lamp5,...,Ctps,Anln,Mrc1,Hexb,Ttr,Foxj1,Vtn,Flt1,Apln,Acta2
0,2.777385,2.926465,0.0,1.390133,0.829559,2.119032,0.0,0.621403,2.536136,1.492039,...,2.009762,1.276651,1.001762,2.065889,1.001762,0.829559,0.621403,0.0,1.001762,0.829559
1,3.699013,3.059924,1.123004,0.346999,0.604114,0.346999,0.346999,0.808458,1.462912,0.0,...,1.123004,0.808458,0.346999,0.808458,0.346999,0.346999,0.604114,0.0,0.978048,0.346999
2,3.663039,3.480582,2.220487,0.521361,0.862314,0.0,1.222337,0.521361,2.143296,0.0,...,0.997242,0.862314,0.706299,2.256957,0.294279,0.862314,0.997242,0.0,0.521361,0.706299
3,3.428742,2.682501,0.831733,1.279431,0.499956,0.0,1.184873,0.0,3.079924,0.963793,...,1.279431,0.0,0.963793,1.587579,0.963793,0.280902,0.280902,0.0,0.280902,0.963793
4,2.433613,3.250374,0.0,2.282382,0.955511,0.0,0.0,0.0,1.223775,4.506454,...,1.435085,0.587787,0.955511,1.223775,0.587787,0.587787,0.587787,0.587787,1.223775,0.0


In [6]:
scrna_df.head()

Unnamed: 0,Tspan12,Tshz1,Fnbp1l,Adamts15,Cldn12,Rxfp1,2310042E22Rik,Sema3c,Jam2,Apbb1ip,...,Rab9,Tceanc,Msl3,Arhgap6,Mid1,Vamp7,Tmlhe,Zf12,Kdm5d,Uty
0,0.0,4.942366,4.942366,0.0,3.857929,0.0,0.0,6.236445,3.857929,0.0,...,5.785577,0.0,0.0,0.0,0.0,5.450333,0.0,0.0,0.0,0.0
1,0.0,3.850649,3.850649,0.0,3.850649,0.0,0.0,0.0,0.0,0.0,...,3.850649,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,5.251013,0.0,3.485126,0.0,4.16283,6.674137,3.485126,0.0,...,3.485126,0.0,4.16283,0.0,0.0,4.563094,0.0,0.0,0.0,5.404414
3,4.523832,4.123776,4.808798,0.0,0.0,0.0,4.523832,3.446682,0.0,0.0,...,4.523832,0.0,3.446682,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,4.542944,3.860383,0.0,0.0,0.0,0.0,6.143832,0.0,0.0,...,0.0,0.0,5.230756,0.0,0.0,4.944856,0.0,0.0,0.0,0.0


In [7]:
genes_to_predict

array(['Tesc', 'Pvrl3', 'Grm2'], dtype=object)
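
Before running the model, it can help to confirm that every gene to predict is present in the reference, since the predictions are derived from the scRNA-seq data. The sketch below is a minimal check on toy tables. The helper `check_genes_to_predict` is hypothetical and not part of stPlus, and whether stPlus performs a similar check itself is not stated here.

```python
import pandas as pd

def check_genes_to_predict(spatial_df, scrna_df, genes_to_predict):
    """Report genes missing from the reference and genes already measured."""
    genes = list(genes_to_predict)
    missing_in_ref = [g for g in genes if g not in scrna_df.columns]
    already_measured = [g for g in genes if g in spatial_df.columns]
    shared = [g for g in spatial_df.columns if g in scrna_df.columns]
    return missing_in_ref, already_measured, shared

# Toy tables with made-up values; only the column names matter here.
toy_spatial = pd.DataFrame({"Gad2": [2.7, 3.6], "Vip": [0.8, 0.6]})
toy_scrna = pd.DataFrame({"Gad2": [0.0, 1.2], "Vip": [3.1, 0.0], "Tesc": [4.9, 0.0]})

missing, measured, shared = check_genes_to_predict(
    toy_spatial, toy_scrna, ["Tesc", "Grm2"]
)
print("Missing from reference:", missing)
print("Already in spatial data:", measured)
print("Shared genes:", shared)
```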

Run stPlus

In [8]:
save_path_prefix = './model/stPlus-demo'
stPlus_res = stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix, random_seed=10)

Models will be saved in: ./model/stPlus-demo-5min*.pt

Spatial transcriptomics data: 3405 cells * 33 genes
Reference scRNA-seq data:     1691 cells * 15075 genes
3 genes to be predicted

Start initialization
Start embedding
	[1] recon_loss: 17203.548, pred_loss: 90388.025, total_loss: 107591.573
	[2] recon_loss: 10213.595, pred_loss: 54336.520, total_loss: 64550.114
	[3] recon_loss: 8357.641, pred_loss: 49231.158, total_loss: 57588.799
	[4] recon_loss: 7236.344, pred_loss: 46241.893, total_loss: 53478.237
	[5] recon_loss: 6078.648, pred_loss: 44857.473, total_loss: 50936.121
	[6] recon_loss: 5135.076, pred_loss: 43864.600, total_loss: 48999.676
	[7] recon_loss: 4502.876, pred_loss: 43209.645, total_loss: 47712.521
	[8] recon_loss: 4111.535, pred_loss: 42792.704, total_loss: 46904.238
	[9] recon_loss: 3823.475, pred_loss: 42349.922, total_loss: 46173.397
	[10] recon_loss: 3590.203, pred_loss: 41935.582, total_loss: 45525.785
	[11] recon_loss: 3453.445, pred_loss: 41591.686, total_loss: 

Obtain the following predicted spatial transcriptomics data

In [9]:
stPlus_res.head()

Unnamed: 0,Tesc,Pvrl3,Grm2
0,1.314772,1.774769,1.164707
1,0.695193,2.171672,0.468714
2,0.544767,2.187255,0.377925
3,0.925574,2.159383,0.873194
4,0.091284,2.185003,0.089993
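
To work with measured and predicted genes together, the predictions can be appended to the original spatial table as extra columns. The sketch below uses toy tables and assumes that the rows of the prediction table correspond to the same cells, in the same order, as the rows of the spatial table. That alignment is an assumption and should be verified on real data.

```python
import pandas as pd

# Toy stand-ins for spatial_df and stPlus_res (made-up values).
toy_spatial_df = pd.DataFrame({"Gad2": [2.78, 3.70, 3.66], "Vip": [0.83, 0.60, 0.86]})
toy_pred_df = pd.DataFrame({"Tesc": [1.31, 0.70, 0.54], "Grm2": [1.16, 0.47, 0.38]})

# Only concatenate when both tables index the same cells in the same order.
if not toy_spatial_df.index.equals(toy_pred_df.index):
    raise ValueError("Row indices differ; align the cells before combining.")

enhanced_df = pd.concat([toy_spatial_df, toy_pred_df], axis=1)
print(enhanced_df)
```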


# Reproduction of 5-fold cross-validation

Load data

In [10]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
raw_spatial_df  = pd.read_csv(spatial_df_file)
raw_scrna_df    = pd.read_csv(scrna_df_file)
print(raw_spatial_df.shape, raw_scrna_df.shape) # cell by gene
raw_shared_gene = np.intersect1d(raw_spatial_df.columns, raw_scrna_df.columns)
print(raw_shared_gene.shape)

(3405, 33) (1691, 15075)
(33,)


In [11]:
from sklearn.model_selection import KFold
# Split the shared genes (not the cells) into 5 folds
kf = KFold(n_splits=5, shuffle=True, random_state=0)
kf.get_n_splits(raw_shared_gene)
torch.manual_seed(10)
idx = 1
for train_ind, test_ind in kf.split(raw_shared_gene):
    print("\n===== Fold %d =====\nNumber of train genes: %d, Number of test genes: %d"%(idx, len(train_ind), len(test_ind)))
    train_gene = raw_shared_gene[train_ind]
    test_gene  = raw_shared_gene[test_ind]
    # Hold out the test genes; their measured values are the ground truth
    test_spatial_df = raw_spatial_df[test_gene]
    spatial_df = raw_spatial_df[train_gene]
    scrna_df   = raw_scrna_df

    # Allocate the table of predictions once, on the first fold
    if idx == 1:
        all_pred_res = pd.DataFrame(np.zeros((spatial_df.shape[0],raw_shared_gene.shape[0])), columns=raw_shared_gene)
    save_path_prefix = './model/stPlus-demo-fold%d'%(idx)
    # Predict the held-out genes from the training genes and the reference
    stPlus_res = stPlus(spatial_df, scrna_df, test_gene, save_path_prefix)
    # Store this fold's predictions in the matching gene columns
    all_pred_res[stPlus_res.columns.values] = stPlus_res
    idx += 1


===== Fold 1 =====
Number of train genes: 26, Number of test genes: 7
Models will be saved in: ./model/stPlus-demo-fold1-5min*.pt

Spatial transcriptomics data: 3405 cells * 26 genes
Reference scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted

Start initialization
Start embedding
	[1] recon_loss: 11739.449, pred_loss: 74647.239, total_loss: 86386.688
	[2] recon_loss: 7977.249, pred_loss: 45591.339, total_loss: 53568.588
	[3] recon_loss: 6315.242, pred_loss: 40697.307, total_loss: 47012.549
	[4] recon_loss: 5338.903, pred_loss: 38085.516, total_loss: 43424.420
	[5] recon_loss: 4521.835, pred_loss: 36605.408, total_loss: 41127.243
	[6] recon_loss: 3885.873, pred_loss: 35562.838, total_loss: 39448.712
	[7] recon_loss: 3434.955, pred_loss: 34985.452, total_loss: 38420.407
	[8] recon_loss: 3090.548, pred_loss: 34581.929, total_loss: 37672.477
	[9] recon_loss: 2873.076, pred_loss: 34217.339, total_loss: 37090.415
	[10] recon_loss: 2734.121, pred_loss: 33895.281, total_los

	[6] recon_loss: 3144.046, pred_loss: 36754.673, total_loss: 39898.719
	[7] recon_loss: 2642.884, pred_loss: 36131.823, total_loss: 38774.706
	[8] recon_loss: 2291.286, pred_loss: 35632.580, total_loss: 37923.866
	[9] recon_loss: 2064.744, pred_loss: 35276.954, total_loss: 37341.699
	[10] recon_loss: 1926.914, pred_loss: 35039.834, total_loss: 36966.749
	[11] recon_loss: 1858.452, pred_loss: 34791.571, total_loss: 36650.023
	[12] recon_loss: 1827.175, pred_loss: 34564.738, total_loss: 36391.914
	[13] recon_loss: 1769.312, pred_loss: 34372.338, total_loss: 36141.650
	[14] recon_loss: 1712.696, pred_loss: 34178.038, total_loss: 35890.734
	[15] recon_loss: 1664.979, pred_loss: 34033.348, total_loss: 35698.326
	[16] recon_loss: 1641.416, pred_loss: 33875.078, total_loss: 35516.495
	[17] recon_loss: 1632.032, pred_loss: 33717.411, total_loss: 35349.443
	[18] recon_loss: 1637.649, pred_loss: 33601.531, total_loss: 35239.181
Start prediction
	Using model 1 to predict
	Using model 2 to predict

In [12]:
corr_res = calc_corr(raw_spatial_df, all_pred_res, raw_shared_gene)
print(np.median(corr_res))

0.2191236775005986


Note that there might be a slight difference in the results using different GPUs, even with the same version of all Python packages.
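
The median above summarizes the agreement between measured and predicted expression for each gene. The implementation of `calc_corr` is not shown in this document, so the sketch below only illustrates one plausible per-gene metric, the Spearman rank correlation, computed as the Pearson correlation of ranks. The choice of Spearman and the helper name `per_gene_spearman` are assumptions.

```python
import numpy as np
import pandas as pd

def per_gene_spearman(measured_df, predicted_df, genes):
    """Spearman correlation per gene, as Pearson correlation of average ranks."""
    correlations = []
    for gene in genes:
        measured_rank = measured_df[gene].rank()
        predicted_rank = predicted_df[gene].rank()
        correlations.append(np.corrcoef(measured_rank, predicted_rank)[0, 1])
    return np.array(correlations)

# Toy measured and predicted tables: 4 cells by 2 genes (made-up values).
measured = pd.DataFrame({"GeneA": [0.1, 0.5, 0.9, 1.4], "GeneB": [2.0, 1.0, 0.5, 0.0]})
predicted = pd.DataFrame({"GeneA": [0.2, 0.4, 1.1, 1.0], "GeneB": [0.1, 0.3, 0.6, 0.9]})

corr = per_gene_spearman(measured, predicted, ["GeneA", "GeneB"])
print(corr, np.median(corr))
```

A gene whose measured or predicted values are constant across cells has no defined rank correlation and yields `nan` in this sketch, so such genes may need to be filtered before taking the median.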