## Demo notebook: training a DGAT model

In this notebook, we train a DGAT model on a single sample, the Tonsil dataset. You can substitute your own dataset, provided as AnnData objects containing the RNA and protein data. The trained model is then used in **Demo_Predict_2.ipynb** to predict protein expression from the RNA data of a Lymph Node sample.

Please replace **sample_list**, **model_save_dir**, **pyg_save_dir**, and **dataset_save_dir** with your own sample names and paths.

In [1]:
import numpy as np

import scanpy as sc
import torch
import warnings
warnings.filterwarnings('ignore')

from utils.Preprocessing import qc_control_cytassist, normalize, clean_protein_names, preprocess_train_list
from Model.Train_and_Predict import train

import random
import os
# Workaround for the "duplicate OpenMP runtime" error (OMP: Error #15) seen in some environments.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

sample_list = ['Tonsil']

# Directory containing the {sample}_RNA.h5ad and {sample}_ADT.h5ad files.
dataset_save_dir = './DGAT_training_datasets'

# Directory for the graph data in PyG format (any directory you like).
# Building a graph can take 3 minutes or more, depending on the scale of the dataset;
# the saved graph data can be reused for later training runs.
pyg_save_dir = './pyg_data'

# Directory to save the trained DGAT model.
model_save_dir = './DGAT_pretrained_models'


os.makedirs(model_save_dir, exist_ok=True)

### Loading dataset
Each sample is read from two files in **dataset_save_dir**: **{sample}_RNA.h5ad** (RNA) and **{sample}_ADT.h5ad** (protein). To use your own dataset, provide both files as AnnData objects.

In [2]:
adata_list = []
pdata_list = []
for sample in sample_list:
    print(sample)
    adata = sc.read_h5ad(f'{dataset_save_dir}/{sample}_RNA.h5ad')
    pdata = sc.read_h5ad(f'{dataset_save_dir}/{sample}_ADT.h5ad')
    adata_list.append(adata)
    pdata_list.append(pdata)


Tonsil
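
Because training pairs RNA and protein measurements, it can help to confirm that each RNA/protein pair describes the same spots before preprocessing. The sketch below is a minimal check and not part of DGAT: `check_paired` is a hypothetical helper, and it assumes the two objects are expected to share identical spot barcodes in `obs_names`, which may not hold for every dataset. It relies only on the `obs_names` attribute, so the stand-in objects used here can be swapped for the AnnData objects loaded above.

```python
from types import SimpleNamespace


def check_paired(adata, pdata):
    """Return the spot barcodes that appear in only one of the two objects.

    Hypothetical helper (not part of DGAT). Works with any objects that
    expose `obs_names`, including AnnData.
    """
    rna_spots = set(adata.obs_names)
    protein_spots = set(pdata.obs_names)
    return {
        "rna_only": sorted(rna_spots - protein_spots),
        "protein_only": sorted(protein_spots - rna_spots),
    }


# Stand-ins for AnnData objects; the barcodes are made up for illustration.
rna = SimpleNamespace(obs_names=["AAAC-1", "AAAG-1", "AACT-1"])
adt = SimpleNamespace(obs_names=["AAAC-1", "AAAG-1", "AAGG-1"])

mismatch = check_paired(rna, adt)
print(mismatch)
```

If both lists in the result are empty, the two objects cover the same set of spots.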


### Preprocessing
In this step, we perform quality control and normalization, and find the genes and proteins common to all datasets.


In [3]:
# Performs quality control and normalization on the training datasets
common_gene, common_protein = preprocess_train_list(adata_list, pdata_list)

Common genes before QC: 18085
Common proteins before QC: 35
Num of encoding genes: 30
Common genes after QC: 17434
Common proteins after QC: 31
Common gene names saved to common_gene_17434.txt
Common protein names saved to common_protein_31.txt


In [4]:
adata_list = [adata[:,common_gene] for adata in adata_list]
pdata_list = [pdata[:,common_protein] for pdata in pdata_list]
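
The idea of keeping only the features shared by every sample can be illustrated without any DGAT code. The following is a minimal sketch of an order-preserving intersection; `common_features` is a hypothetical helper and the gene names are illustrative. It is not the implementation inside `preprocess_train_list`, which also applies quality control.

```python
def common_features(name_lists):
    """Names present in every list, in the order of the first list."""
    if not name_lists:
        return []
    shared = set(name_lists[0]).intersection(*name_lists[1:])
    return [name for name in name_lists[0] if name in shared]


# Illustrative feature names for two hypothetical samples.
sample_a = ["CD3E", "CD4", "CD8A", "MS4A1"]
sample_b = ["MS4A1", "CD4", "CD3E", "PTPRC"]
print(common_features([sample_a, sample_b]))
```

With a single training sample, as in this notebook, the intersection is simply that sample's own feature list.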

### Training



In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train the DGAT model on the preprocessed training datasets.
# train() takes the list of processed ST datasets, the list of protein datasets,
# and the directory in which to save the processed graph data (so it can be reused in later runs).
model_components = train(adata_list, pdata_list, pyg_save_dir)

model_save_path = f"{model_save_dir}/{len(common_gene)}_gene_{len(common_protein)}_protein"
print(model_save_path)
os.makedirs(model_save_path, exist_ok=True)
# Save the state dict of each of the four model components.
torch.save(model_components['mRNA_encoder'].state_dict(), f'{model_save_path}/encoder_mRNA.pth')
torch.save(model_components['mRNA_decoder'].state_dict(), f'{model_save_path}/decoder_mRNA.pth')
torch.save(model_components['protein_encoder'].state_dict(), f'{model_save_path}/encoder_protein.pth')
torch.save(model_components['protein_decoder'].state_dict(), f'{model_save_path}/decoder_protein.pth')



Common genes: 17434
Common proteins: 31
Creating or loading dataset
Creating and saving preprocessed data for sample 'CytAssist_FFPE_Protein_Expression_Human_Tonsil'
Dataset ready
Using single device for both models


                                                          

Epoch [1/100] Total Loss: 12.9259 | mRNA Recon: 1.3340 | Protein Recon: 0.9171 | Alignment: 1.2841 | Protein Pred: 0.9071 | mRNA Pred: 1.3334
Epoch 1: EB evidence = -2.4350


                                                          

Epoch [2/100] Total Loss: 10.9490 | mRNA Recon: 1.3014 | Protein Recon: 0.5432 | Alignment: 1.0218 | Protein Pred: 0.5248 | mRNA Pred: 1.3022
Epoch 2: EB evidence = -1.4304


                                                          

Epoch [3/100] Total Loss: 9.9475 | mRNA Recon: 1.2659 | Protein Recon: 0.4257 | Alignment: 0.7465 | Protein Pred: 0.3922 | mRNA Pred: 1.2692
Epoch 3: EB evidence = -0.2269


                                                          

Epoch [4/100] Total Loss: 9.3565 | mRNA Recon: 1.2305 | Protein Recon: 0.3731 | Alignment: 0.6035 | Protein Pred: 0.3312 | mRNA Pred: 1.2340
Epoch 4: EB evidence = 4.7506


                                                          

Epoch [5/100] Total Loss: 8.9712 | mRNA Recon: 1.1945 | Protein Recon: 0.3443 | Alignment: 0.5374 | Protein Pred: 0.3061 | mRNA Pred: 1.1990
Epoch 5: EB evidence = 0.3355


                                                          

Epoch [6/100] Total Loss: 8.6864 | mRNA Recon: 1.1595 | Protein Recon: 0.3345 | Alignment: 0.5004 | Protein Pred: 0.2967 | mRNA Pred: 1.1639
Epoch 6: EB evidence = 0.5407


                                                          

Epoch [7/100] Total Loss: 8.4378 | mRNA Recon: 1.1261 | Protein Recon: 0.3299 | Alignment: 0.4750 | Protein Pred: 0.2908 | mRNA Pred: 1.1301
Epoch 7: EB evidence = 0.9744


                                                          

Epoch [8/100] Total Loss: 8.2083 | mRNA Recon: 1.0956 | Protein Recon: 0.3235 | Alignment: 0.4559 | Protein Pred: 0.2839 | mRNA Pred: 1.0996
Epoch 8: EB evidence = 1.0267


                                                          

Epoch [9/100] Total Loss: 8.0106 | mRNA Recon: 1.0695 | Protein Recon: 0.3167 | Alignment: 0.4400 | Protein Pred: 0.2778 | mRNA Pred: 1.0730
Epoch 9: EB evidence = 0.1994


                                                           

Epoch [10/100] Total Loss: 7.8445 | mRNA Recon: 1.0489 | Protein Recon: 0.3077 | Alignment: 0.4238 | Protein Pred: 0.2722 | mRNA Pred: 1.0517
Epoch 10: Learning rate reduced.
Epoch 10: EB evidence = 0.9638


                                                           

Epoch [11/100] Total Loss: 7.7180 | mRNA Recon: 1.0352 | Protein Recon: 0.2977 | Alignment: 0.4086 | Protein Pred: 0.2661 | mRNA Pred: 1.0372
Epoch 11: EB evidence = 0.8752


                                                           

Epoch [12/100] Total Loss: 7.6501 | mRNA Recon: 1.0293 | Protein Recon: 0.2895 | Alignment: 0.3963 | Protein Pred: 0.2624 | mRNA Pred: 1.0306
Epoch 12: EB evidence = 0.7097


                                                           

Epoch [13/100] Total Loss: 7.5999 | mRNA Recon: 1.0264 | Protein Recon: 0.2819 | Alignment: 0.3831 | Protein Pred: 0.2585 | mRNA Pred: 1.0273
Epoch 13: EB evidence = 0.9485


                                                           

Epoch [14/100] Total Loss: 7.5657 | mRNA Recon: 1.0255 | Protein Recon: 0.2753 | Alignment: 0.3709 | Protein Pred: 0.2553 | mRNA Pred: 1.0261
Epoch 14: EB evidence = 0.9427


                                                           

Epoch [15/100] Total Loss: 7.5359 | mRNA Recon: 1.0252 | Protein Recon: 0.2699 | Alignment: 0.3597 | Protein Pred: 0.2515 | mRNA Pred: 1.0256
Epoch 15: EB evidence = 0.8094


                                                           

Epoch [16/100] Total Loss: 7.5045 | mRNA Recon: 1.0245 | Protein Recon: 0.2652 | Alignment: 0.3495 | Protein Pred: 0.2476 | mRNA Pred: 1.0249
Epoch 16: EB evidence = 0.7354


                                                           

Epoch [17/100] Total Loss: 7.4654 | mRNA Recon: 1.0227 | Protein Recon: 0.2592 | Alignment: 0.3392 | Protein Pred: 0.2434 | mRNA Pred: 1.0232
Epoch 17: EB evidence = 0.9145


                                                           

Epoch [18/100] Total Loss: 7.4232 | mRNA Recon: 1.0201 | Protein Recon: 0.2543 | Alignment: 0.3302 | Protein Pred: 0.2391 | mRNA Pred: 1.0209
Epoch 18: EB evidence = 0.9142


                                                           

Epoch [19/100] Total Loss: 7.3755 | mRNA Recon: 1.0171 | Protein Recon: 0.2482 | Alignment: 0.3217 | Protein Pred: 0.2339 | mRNA Pred: 1.0182
Epoch 19: EB evidence = 0.8805


                                                           

Epoch [20/100] Total Loss: 7.3311 | mRNA Recon: 1.0144 | Protein Recon: 0.2413 | Alignment: 0.3143 | Protein Pred: 0.2293 | mRNA Pred: 1.0156
Epoch 20: Learning rate reduced.
Epoch 20: EB evidence = 1.0575


                                                           

Epoch [21/100] Total Loss: 7.2907 | mRNA Recon: 1.0122 | Protein Recon: 0.2357 | Alignment: 0.3075 | Protein Pred: 0.2244 | mRNA Pred: 1.0134
Epoch 21: EB evidence = 1.0429


                                                           

Epoch [22/100] Total Loss: 7.2605 | mRNA Recon: 1.0107 | Protein Recon: 0.2315 | Alignment: 0.3010 | Protein Pred: 0.2208 | mRNA Pred: 1.0119
Epoch 22: EB evidence = -inf


                                                           

Epoch [23/100] Total Loss: 7.2323 | mRNA Recon: 1.0097 | Protein Recon: 0.2269 | Alignment: 0.2962 | Protein Pred: 0.2167 | mRNA Pred: 1.0108
Epoch 23: EB evidence = 0.8810


                                                           

Epoch [24/100] Total Loss: 7.2078 | mRNA Recon: 1.0088 | Protein Recon: 0.2225 | Alignment: 0.2908 | Protein Pred: 0.2135 | mRNA Pred: 1.0098
Epoch 24: EB evidence = 1.0563


                                                           

Epoch [25/100] Total Loss: 7.1843 | mRNA Recon: 1.0082 | Protein Recon: 0.2181 | Alignment: 0.2859 | Protein Pred: 0.2101 | mRNA Pred: 1.0092
Epoch 25: EB evidence = 0.9710


                                                           

Epoch [26/100] Total Loss: 7.1607 | mRNA Recon: 1.0076 | Protein Recon: 0.2131 | Alignment: 0.2808 | Protein Pred: 0.2066 | mRNA Pred: 1.0086
Epoch 26: EB evidence = 0.9592


                                                           

Epoch [27/100] Total Loss: 7.1373 | mRNA Recon: 1.0070 | Protein Recon: 0.2084 | Alignment: 0.2756 | Protein Pred: 0.2033 | mRNA Pred: 1.0081
Epoch 27: EB evidence = 0.9625


                                                           

Epoch [28/100] Total Loss: 7.1140 | mRNA Recon: 1.0064 | Protein Recon: 0.2028 | Alignment: 0.2709 | Protein Pred: 0.2002 | mRNA Pred: 1.0076
Epoch 28: EB evidence = 0.9776


                                                           

Epoch [29/100] Total Loss: 7.0919 | mRNA Recon: 1.0057 | Protein Recon: 0.1983 | Alignment: 0.2663 | Protein Pred: 0.1972 | mRNA Pred: 1.0071
Epoch 29: EB evidence = 0.9467


                                                           

Epoch [30/100] Total Loss: 7.0685 | mRNA Recon: 1.0050 | Protein Recon: 0.1937 | Alignment: 0.2624 | Protein Pred: 0.1936 | mRNA Pred: 1.0065
Epoch 30: Learning rate reduced.
Epoch 30: EB evidence = 1.0484


                                                           

Epoch [31/100] Total Loss: 7.0515 | mRNA Recon: 1.0042 | Protein Recon: 0.1905 | Alignment: 0.2585 | Protein Pred: 0.1918 | mRNA Pred: 1.0059
Epoch 31: EB evidence = 0.9746


                                                           

Epoch [32/100] Total Loss: 7.0351 | mRNA Recon: 1.0036 | Protein Recon: 0.1875 | Alignment: 0.2553 | Protein Pred: 0.1896 | mRNA Pred: 1.0054
Epoch 32: EB evidence = 0.9897


                                                           

Epoch [33/100] Total Loss: 7.0202 | mRNA Recon: 1.0030 | Protein Recon: 0.1853 | Alignment: 0.2525 | Protein Pred: 0.1876 | mRNA Pred: 1.0048
Epoch 33: EB evidence = 0.9900


                                                           

Epoch [34/100] Total Loss: 7.0077 | mRNA Recon: 1.0024 | Protein Recon: 0.1838 | Alignment: 0.2497 | Protein Pred: 0.1860 | mRNA Pred: 1.0042
Epoch 34: EB evidence = 0.9926


                                                           

Epoch [35/100] Total Loss: 6.9932 | mRNA Recon: 1.0018 | Protein Recon: 0.1821 | Alignment: 0.2466 | Protein Pred: 0.1840 | mRNA Pred: 1.0036
Epoch 35: EB evidence = -inf


                                                           

Epoch [36/100] Total Loss: 6.9807 | mRNA Recon: 1.0012 | Protein Recon: 0.1813 | Alignment: 0.2437 | Protein Pred: 0.1822 | mRNA Pred: 1.0031
Epoch 36: EB evidence = 0.9706


                                                           

Epoch [37/100] Total Loss: 6.9653 | mRNA Recon: 1.0007 | Protein Recon: 0.1797 | Alignment: 0.2404 | Protein Pred: 0.1798 | mRNA Pred: 1.0024
Epoch 37: EB evidence = 0.9878


                                                           

Epoch [38/100] Total Loss: 6.9556 | mRNA Recon: 1.0003 | Protein Recon: 0.1789 | Alignment: 0.2380 | Protein Pred: 0.1785 | mRNA Pred: 1.0020
Epoch 38: EB evidence = 0.9852


                                                           

Epoch [39/100] Total Loss: 6.9427 | mRNA Recon: 0.9998 | Protein Recon: 0.1771 | Alignment: 0.2353 | Protein Pred: 0.1767 | mRNA Pred: 1.0015
Epoch 39: EB evidence = 0.9905


                                                           

Epoch [40/100] Total Loss: 6.9309 | mRNA Recon: 0.9994 | Protein Recon: 0.1757 | Alignment: 0.2325 | Protein Pred: 0.1750 | mRNA Pred: 1.0010
Epoch 40: Learning rate reduced.
Epoch 40: EB evidence = 0.9823


                                                           

Epoch [41/100] Total Loss: 6.9195 | mRNA Recon: 0.9989 | Protein Recon: 0.1739 | Alignment: 0.2302 | Protein Pred: 0.1734 | mRNA Pred: 1.0005
Epoch 41: EB evidence = 0.9902


                                                           

Epoch [42/100] Total Loss: 6.9097 | mRNA Recon: 0.9986 | Protein Recon: 0.1733 | Alignment: 0.2282 | Protein Pred: 0.1717 | mRNA Pred: 1.0002
Epoch 42: EB evidence = 0.9880


                                                           

Epoch [43/100] Total Loss: 6.9015 | mRNA Recon: 0.9981 | Protein Recon: 0.1719 | Alignment: 0.2265 | Protein Pred: 0.1709 | mRNA Pred: 0.9998
Epoch 43: EB evidence = 0.9920


                                                           

Epoch [44/100] Total Loss: 6.8921 | mRNA Recon: 0.9977 | Protein Recon: 0.1708 | Alignment: 0.2244 | Protein Pred: 0.1697 | mRNA Pred: 0.9994
Epoch 44: EB evidence = 0.9760


                                                           

Epoch [45/100] Total Loss: 6.8818 | mRNA Recon: 0.9973 | Protein Recon: 0.1700 | Alignment: 0.2228 | Protein Pred: 0.1679 | mRNA Pred: 0.9990
Epoch 45: EB evidence = 0.9757


                                                           

Epoch [46/100] Total Loss: 6.8733 | mRNA Recon: 0.9968 | Protein Recon: 0.1687 | Alignment: 0.2207 | Protein Pred: 0.1671 | mRNA Pred: 0.9986
Epoch 46: EB evidence = 0.9917


                                                           

Epoch [47/100] Total Loss: 6.8642 | mRNA Recon: 0.9963 | Protein Recon: 0.1676 | Alignment: 0.2188 | Protein Pred: 0.1660 | mRNA Pred: 0.9981
Epoch 47: EB evidence = 1.0051


                                                           

Epoch [48/100] Total Loss: 6.8558 | mRNA Recon: 0.9959 | Protein Recon: 0.1672 | Alignment: 0.2173 | Protein Pred: 0.1647 | mRNA Pred: 0.9978
Epoch 48: EB evidence = 1.0267


                                                           

Epoch [49/100] Total Loss: 6.8481 | mRNA Recon: 0.9954 | Protein Recon: 0.1660 | Alignment: 0.2158 | Protein Pred: 0.1640 | mRNA Pred: 0.9973
Epoch 49: EB evidence = 0.9778


                                                           

Epoch [50/100] Total Loss: 6.8377 | mRNA Recon: 0.9949 | Protein Recon: 0.1652 | Alignment: 0.2141 | Protein Pred: 0.1624 | mRNA Pred: 0.9969
Epoch 50: Learning rate reduced.
Epoch 50: EB evidence = 1.0086


                                                           

Epoch [51/100] Total Loss: 6.8277 | mRNA Recon: 0.9943 | Protein Recon: 0.1639 | Alignment: 0.2123 | Protein Pred: 0.1611 | mRNA Pred: 0.9965
Epoch 51: EB evidence = 0.9553
--> EB early stopping at epoch 51 (mean_evidence=0.9897, threshold=0.96)
./DGAT_models/17434_gene_31_protein
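
The training log above can be parsed to track the loss terms and the EB evidence across epochs. The sketch below is a minimal, standard-library example; the regular expressions and the `parse_log` and `trailing_mean` helpers are assumptions based only on the log format shown here, not on DGAT's code. As an observation rather than a documented rule, the mean of the last ten logged evidence values (epochs 42 to 51) rounds to the reported `mean_evidence=0.9897`.

```python
import re

# Lines copied from the log above; in practice, read the captured output instead.
log_text = """
Epoch [22/100] Total Loss: 7.2605 | mRNA Recon: 1.0107 | Protein Recon: 0.2315 | Alignment: 0.3010 | Protein Pred: 0.2208 | mRNA Pred: 1.0119
Epoch 22: EB evidence = -inf
Epoch [51/100] Total Loss: 6.8277 | mRNA Recon: 0.9943 | Protein Recon: 0.1639 | Alignment: 0.2123 | Protein Pred: 0.1611 | mRNA Pred: 0.9965
Epoch 51: EB evidence = 0.9553
"""

# Patterns assumed from the log format shown in this notebook.
LOSS_LINE = re.compile(r"^Epoch \[(\d+)/(\d+)\] (.+)$")
EVIDENCE_LINE = re.compile(r"^Epoch (\d+): EB evidence = (\S+)$")


def parse_log(text):
    """Return ({epoch: {loss name: value}}, {epoch: EB evidence})."""
    losses, evidence = {}, {}
    for line in text.splitlines():
        line = line.strip()
        match = LOSS_LINE.match(line)
        if match:
            terms = {}
            for part in match.group(3).split(" | "):
                name, value = part.split(": ")
                terms[name] = float(value)
            losses[int(match.group(1))] = terms
            continue
        match = EVIDENCE_LINE.match(line)
        if match:
            # float() also accepts the "-inf" values that appear in the log.
            evidence[int(match.group(1))] = float(match.group(2))
    return losses, evidence


def trailing_mean(values, window=10):
    """Mean of the last `window` values (the window size is an assumption)."""
    tail = values[-window:]
    return sum(tail) / len(tail)


losses, evidence = parse_log(log_text)
print(losses[51]["Total Loss"], evidence[22])

# EB evidence for epochs 42 to 51, copied from the log above.
last_ten = [0.9880, 0.9920, 0.9760, 0.9757, 0.9917,
            1.0051, 1.0267, 0.9778, 1.0086, 0.9553]
print(round(trailing_mean(last_ten), 4))
```

Under this reading, a window that contains a `-inf` value (as logged at epochs 22 and 35) has a mean of `-inf` and so cannot exceed any finite threshold.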
