# stPlus

stPlus is a reference-based method for enhancing spatial transcriptomics data. It draws on all of the information in the reference scRNA-seq data, not only the genes shared with the spatial data, to learn a non-linear embedding of the cells in both datasets and to predict unmeasured spatial gene expression.

The **required inputs** of stPlus are

* **spatial_df**:       normalized and logarithmized original spatial data
* **scrna_df**:         normalized and logarithmized reference scRNA-seq data
* **genes_to_predict**: spatial genes to be predicted

The **output** of stPlus is

* **stPlus_res**:       predicted spatial transcriptomics data
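
The source does not say which normalization recipe produced the demo data. As a minimal sketch, assuming raw count matrices (cell by gene) and a common library-size scaling followed by `log1p`, the inputs could be prepared along these lines. The helper `normalize_and_log` and the `target_sum` value are hypothetical and not part of stPlus.

```python
import numpy as np
import pandas as pd

def normalize_and_log(counts_df, target_sum=1e4):
    """Hypothetical helper: scale each cell (row) to target_sum, then apply log1p."""
    totals = counts_df.sum(axis=1)
    totals = totals.replace(0, 1)  # avoid division by zero for empty cells
    scaled = counts_df.div(totals, axis=0) * target_sum
    return np.log1p(scaled)

# Toy raw counts: 3 cells by 2 genes (illustrative values only)
toy_counts = pd.DataFrame({"GeneA": [1, 0, 4], "GeneB": [3, 0, 4]})
toy_norm = normalize_and_log(toy_counts)
print(toy_norm)
```

The result keeps the cell-by-gene layout that `spatial_df` and `scrna_df` are documented to use.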

Import the stPlus package

Frequently used packages such as pandas (`pd`), numpy (`np`) and torch are imported automatically along with stPlus.

In [1]:
from stPlus import *

View the signature and parameter descriptions of stPlus via

In [2]:
help(stPlus)

Help on function stPlus in module stPlus.model:

stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix='./stPlus', top_k=3000, t_min=5, data_quality=None, random_seed=None, verbose=True, converge_ratio=0.004, max_epoch_num=10000, batch_size=512, learning_rate=None, weight_decay=0.0002)
    spatial_df:       [pandas dataframe] normalized and logarithmized original spatial data (cell by gene)
    scrna_df:         [pandas dataframe] normalized and logarithmized reference scRNA-seq data (cell by gene)
    genes_to_predict: [1D numpy array] spatial genes to be predicted
    save_path_prefix: [str] prefix of path of trained t models with minimal loss
    top_k:            [int] number of highly variable genes to use
    t_min:            [int] number of epochs with minimal loss using to ensemble learning
    data_quality:     [float] user-specified or 1 minus the sparsity of scRNA-seq data (default)
    random_seed:      [int] random seed in torch
    verbose:          [bool] disp

# Getting started with stPlus

Load the normalized and logarithmized spatial and scRNA-seq data, and the genes to predict

The data can be accessed via

```
git clone https://github.com/xy-chen16/stPlus.git
cd stPlus
tar -zxvf data.tar.gz
```

In [3]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
genes_file      = './data/genes_to_predict.txt'

In [4]:
spatial_df = pd.read_csv(spatial_df_file)
scrna_df   = pd.read_csv(scrna_df_file)
genes_to_predict = pd.read_csv(genes_file, header=None).iloc[:,0].values

In [5]:
spatial_df.head()

Unnamed: 0,Gad2,Slc32a1,Crhbp,Cnr1,Vip,Cpne5,Pthlh,Crh,Tbr1,Lamp5,...,Ctps,Anln,Mrc1,Hexb,Ttr,Foxj1,Vtn,Flt1,Apln,Acta2
0,2.777385,2.926465,0.0,1.390133,0.829559,2.119032,0.0,0.621403,2.536136,1.492039,...,2.009762,1.276651,1.001762,2.065889,1.001762,0.829559,0.621403,0.0,1.001762,0.829559
1,3.699013,3.059924,1.123004,0.346999,0.604114,0.346999,0.346999,0.808458,1.462912,0.0,...,1.123004,0.808458,0.346999,0.808458,0.346999,0.346999,0.604114,0.0,0.978048,0.346999
2,3.663039,3.480582,2.220487,0.521361,0.862314,0.0,1.222337,0.521361,2.143296,0.0,...,0.997242,0.862314,0.706299,2.256957,0.294279,0.862314,0.997242,0.0,0.521361,0.706299
3,3.428742,2.682501,0.831733,1.279431,0.499956,0.0,1.184873,0.0,3.079924,0.963793,...,1.279431,0.0,0.963793,1.587579,0.963793,0.280902,0.280902,0.0,0.280902,0.963793
4,2.433613,3.250374,0.0,2.282382,0.955511,0.0,0.0,0.0,1.223775,4.506454,...,1.435085,0.587787,0.955511,1.223775,0.587787,0.587787,0.587787,0.587787,1.223775,0.0


In [6]:
scrna_df.head()

Unnamed: 0,Tspan12,Tshz1,Fnbp1l,Adamts15,Cldn12,Rxfp1,2310042E22Rik,Sema3c,Jam2,Apbb1ip,...,Rab9,Tceanc,Msl3,Arhgap6,Mid1,Vamp7,Tmlhe,Zf12,Kdm5d,Uty
0,0.0,4.942366,4.942366,0.0,3.857929,0.0,0.0,6.236445,3.857929,0.0,...,5.785577,0.0,0.0,0.0,0.0,5.450333,0.0,0.0,0.0,0.0
1,0.0,3.850649,3.850649,0.0,3.850649,0.0,0.0,0.0,0.0,0.0,...,3.850649,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,5.251013,0.0,3.485126,0.0,4.16283,6.674137,3.485126,0.0,...,3.485126,0.0,4.16283,0.0,0.0,4.563094,0.0,0.0,0.0,5.404414
3,4.523832,4.123776,4.808798,0.0,0.0,0.0,4.523832,3.446682,0.0,0.0,...,4.523832,0.0,3.446682,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,4.542944,3.860383,0.0,0.0,0.0,0.0,6.143832,0.0,0.0,...,0.0,0.0,5.230756,0.0,0.0,4.944856,0.0,0.0,0.0,0.0


In [7]:
genes_to_predict

array(['Tesc', 'Pvrl3', 'Grm2'], dtype=object)
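
Before running the model, it can help to confirm how the three inputs relate to one another. The sketch below uses a hypothetical helper, `summarize_gene_overlap` (not part of stPlus), on toy data frames. It assumes genes are columns, as in the demo data, and that a gene can only be predicted if it appears in the scRNA-seq reference.

```python
import pandas as pd

def summarize_gene_overlap(spatial_df, scrna_df, genes_to_predict):
    """Hypothetical helper: report how the gene sets of the inputs overlap."""
    spatial_genes = pd.Index(spatial_df.columns)
    scrna_genes = pd.Index(scrna_df.columns)
    targets = pd.Index(genes_to_predict)
    return {
        # genes measured in both datasets
        "shared": list(spatial_genes.intersection(scrna_genes)),
        # target genes absent from the reference
        "missing_in_reference": list(targets.difference(scrna_genes)),
        # target genes that the spatial data already contains
        "already_measured": list(targets.intersection(spatial_genes)),
    }

# Toy inputs (illustrative values only)
toy_spatial = pd.DataFrame({"Gad2": [1.0, 0.5], "Vip": [0.0, 2.0]})
toy_scrna = pd.DataFrame(
    {"Gad2": [0.3], "Vip": [1.1], "Tesc": [0.0], "Grm2": [2.2]}
)
overlap = summarize_gene_overlap(toy_spatial, toy_scrna, ["Tesc", "Grm2", "Foo"])
print(overlap)
```

An empty `missing_in_reference` list is the case this check is looking for.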

Run stPlus

In [8]:
save_path_prefix = './model/stPlus-demo'
stPlus_res = stPlus(spatial_df, scrna_df, genes_to_predict, save_path_prefix, random_seed=10)

Models will be saved in: ./model/stPlus-demo-5min*.pt

Spatial transcriptomics data: 3405 cells * 33 genes
Reference scRNA-seq data:     1691 cells * 15075 genes
3 genes to be predicted

Start initialization
Start embedding
	[1] recon_loss: 15260.905, pred_loss: 84649.188, total_loss: 99910.093
	[2] recon_loss: 9475.398, pred_loss: 51617.205, total_loss: 61092.603
	[3] recon_loss: 7647.690, pred_loss: 46626.822, total_loss: 54274.512
	[4] recon_loss: 6398.403, pred_loss: 44095.514, total_loss: 50493.917
	[5] recon_loss: 5177.927, pred_loss: 42897.739, total_loss: 48075.666
	[6] recon_loss: 4231.157, pred_loss: 41949.444, total_loss: 46180.600
	[7] recon_loss: 3502.348, pred_loss: 41292.738, total_loss: 44795.086
	[8] recon_loss: 3011.403, pred_loss: 40858.082, total_loss: 43869.486
	[9] recon_loss: 2633.043, pred_loss: 40477.712, total_loss: 43110.755
	[10] recon_loss: 2335.652, pred_loss: 40156.559, total_loss: 42492.210
	[11] recon_loss: 2149.345, pred_loss: 39855.450, total_loss: 42

Inspect the predicted spatial transcriptomics data

In [9]:
stPlus_res.head()

Unnamed: 0,Tesc,Pvrl3,Grm2
0,1.668627,2.048264,0.627106
1,0.720087,2.617475,0.284386
2,0.932298,2.547904,0.46687
3,0.990825,1.857638,0.963752
4,0.091311,2.091799,0.089835


# Reproduction of 5-fold cross validation

Load data

In [10]:
spatial_df_file = './data/osmFISH_df.csv'
scrna_df_file   = './data/Zeisel_df.csv'
raw_spatial_df  = pd.read_csv(spatial_df_file)
raw_scrna_df    = pd.read_csv(scrna_df_file)
print(raw_spatial_df.shape, raw_scrna_df.shape) # cell by gene
raw_shared_gene = np.intersect1d(raw_spatial_df.columns, raw_scrna_df.columns)
print(raw_shared_gene.shape)

(3405, 33) (1691, 15075)
(33,)
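
The cross-validation in the next cell splits the 33 shared genes, not the cells: each fold hides a subset of genes from the spatial data and predicts them from the rest. The numpy-only sketch below illustrates that partition with a hypothetical generator, `gene_folds`. It is a stand-in for a shuffled K-fold split and will not reproduce the exact gene assignment of scikit-learn's `KFold(random_state=0)`.

```python
import numpy as np

def gene_folds(genes, n_splits=5, seed=0):
    """Hypothetical stand-in for a shuffled K-fold split over genes."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(np.asarray(genes))
    # array_split gives the first (n % n_splits) blocks one extra gene
    for test_genes in np.array_split(shuffled, n_splits):
        train_genes = np.setdiff1d(shuffled, test_genes)
        yield train_genes, test_genes

# Placeholder gene names, one per shared gene
toy_genes = ["gene%d" % i for i in range(33)]
fold_sizes = [(len(train), len(test)) for train, test in gene_folds(toy_genes)]
print(fold_sizes)
```

Every gene lands in exactly one test block, which is why the loop in the next cell can fill one column group of `all_pred_res` per fold.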


In [11]:
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=0)
kf.get_n_splits(raw_shared_gene)
torch.manual_seed(10)
idx = 1
for train_ind, test_ind in kf.split(raw_shared_gene):    
    print("\n===== Fold %d =====\nNumber of train genes: %d, Number of test genes: %d"%(idx, len(train_ind), len(test_ind)))
    train_gene = raw_shared_gene[train_ind]
    test_gene  = raw_shared_gene[test_ind]
    test_spatial_df = raw_spatial_df[test_gene]
    spatial_df = raw_spatial_df[train_gene]
    scrna_df   = raw_scrna_df
    
    if idx == 1:
        all_pred_res = pd.DataFrame(np.zeros((spatial_df.shape[0],raw_shared_gene.shape[0])), columns=raw_shared_gene) 
    save_path_prefix = './model/stPlus-demo-fold%d'%(idx)
    stPlus_res = stPlus(spatial_df, scrna_df, test_gene, save_path_prefix)
    all_pred_res[stPlus_res.columns.values] = stPlus_res
    idx += 1


===== Fold 1 =====
Number of train genes: 26, Number of test genes: 7
Models will be saved in: ./model/stPlus-demo-fold1-5min*.pt

Spatial transcriptomics data: 3405 cells * 26 genes
Reference scRNA-seq data:     1691 cells * 15075 genes
7 genes to be predicted

Start initialization
Start embedding
	[1] recon_loss: 11178.695, pred_loss: 69912.709, total_loss: 81091.403
	[2] recon_loss: 6938.868, pred_loss: 43341.697, total_loss: 50280.564
	[3] recon_loss: 5382.792, pred_loss: 38929.662, total_loss: 44312.454
	[4] recon_loss: 4354.111, pred_loss: 36393.928, total_loss: 40748.039
	[5] recon_loss: 3473.185, pred_loss: 35038.585, total_loss: 38511.770
	[6] recon_loss: 2835.064, pred_loss: 34179.996, total_loss: 37015.060
	[7] recon_loss: 2350.828, pred_loss: 33587.845, total_loss: 35938.673
	[8] recon_loss: 1937.088, pred_loss: 33177.183, total_loss: 35114.271
	[9] recon_loss: 1554.804, pred_loss: 32796.420, total_loss: 34351.225
	[10] recon_loss: 1332.759, pred_loss: 32492.870, total_los

	[6] recon_loss: 4994.750, pred_loss: 34898.472, total_loss: 39893.222
	[7] recon_loss: 4360.762, pred_loss: 34344.013, total_loss: 38704.775
	[8] recon_loss: 3907.691, pred_loss: 33956.611, total_loss: 37864.302
	[9] recon_loss: 3613.018, pred_loss: 33646.026, total_loss: 37259.043
	[10] recon_loss: 3375.909, pred_loss: 33437.704, total_loss: 36813.613
	[11] recon_loss: 3133.867, pred_loss: 33228.272, total_loss: 36362.138
	[12] recon_loss: 3018.531, pred_loss: 33024.705, total_loss: 36043.236
	[13] recon_loss: 2949.428, pred_loss: 32852.134, total_loss: 35801.562
	[14] recon_loss: 2913.348, pred_loss: 32657.254, total_loss: 35570.601
	[15] recon_loss: 2878.601, pred_loss: 32508.848, total_loss: 35387.448
	[16] recon_loss: 2868.747, pred_loss: 32349.233, total_loss: 35217.979
	[17] recon_loss: 2857.406, pred_loss: 32187.900, total_loss: 35045.307
	[18] recon_loss: 2861.625, pred_loss: 32069.794, total_loss: 34931.420
Start prediction
	Using model 1 to predict
	Using model 2 to predict

In [12]:
corr_res = calc_corr(raw_spatial_df, all_pred_res, raw_shared_gene)
print(np.mean(corr_res))
print(np.median(corr_res))

0.2269573212116094
0.20835767062356317
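
The two numbers above are the mean and median of one correlation value per gene. The source does not show how `calc_corr` is implemented, so the following is only a sketch of a per-gene comparison, assuming a Spearman rank correlation between measured and predicted expression. The helper `per_gene_spearman` is hypothetical and is not the stPlus function.

```python
import numpy as np
import pandas as pd

def per_gene_spearman(measured_df, predicted_df, genes):
    """Hypothetical stand-in for a per-gene correlation, not the stPlus code."""
    corrs = []
    for gene in genes:
        # Spearman correlation equals the Pearson correlation of the ranks
        ranks_measured = measured_df[gene].rank()
        ranks_predicted = predicted_df[gene].rank()
        corrs.append(ranks_measured.corr(ranks_predicted))
    return np.array(corrs)

# Toy data: g1 is predicted in the right order, g2 in reverse order
measured = pd.DataFrame({"g1": [0.0, 1.0, 2.0, 3.0], "g2": [3.0, 1.0, 2.0, 0.0]})
predicted = pd.DataFrame({"g1": [0.1, 0.5, 0.9, 2.0], "g2": [0.0, 1.5, 1.0, 3.0]})
toy_corr = per_gene_spearman(measured, predicted, ["g1", "g2"])
print(np.mean(toy_corr), np.median(toy_corr))
```

Summarizing with both mean and median, as the notebook does, makes it easier to see whether a few genes dominate the average.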


This demo notebook was run on a GeForce GTX 1080 GPU.

Results may differ slightly on other GPUs, even with identical versions of all Python packages.

For example, the results obtained on a GeForce GTX 1080 Ti GPU are:

```
mean: 0.2275347768890717
median: 0.21000466117482736
```