In [23]:
import pandas as pd

from models.deconv_model import HIDF
import numpy as np
import torch

# load model parameter
save_path = 'seqFISH3000'
sc_rna_proto_matrix = np.zeros(shape=(1691, 2221)) # proto_gene_matrix shape
st_rna_matrix = np.zeros(shape=(71, 2221)) # st_gene_matrix shape
proto_number = sc_rna_proto_matrix.shape[0]
st_number = st_rna_matrix.shape[0]
proto_cell_type_matrix = np.zeros(shape=(1691, 6)) # proto_cell_type_matrix shape
deconv_model = HIDF(sc_rna_proto_matrix,
                      st_rna_matrix,
                      proto_number,
                      st_number,
                      proto_cell_type_matrix)
saved_model = torch.load(f'{save_path}/trained_hbc.pt')
deconv_model.load_state_dict(saved_model)


<All keys matched successfully>

In [24]:
# loading mapping matrix
mapping_matrix = torch.softmax(deconv_model.mapping_matrix, dim=0)
mapping_matrix = torch.transpose(mapping_matrix, dim0=0, dim1=1)
print(mapping_matrix.shape)

torch.Size([71, 1691])


In [25]:
# generated reconstructed gene expression matrix
st_rec_gene_matrix = torch.matmul(mapping_matrix, deconv_model.proto_gene_matrix)
st_rec_gene_matrix = torch.log(st_rec_gene_matrix + 1)
st_rec_gene_matrix =  torch.nn.functional.softplus(deconv_model.gene_offset_parameter) + torch.nn.functional.softplus(deconv_model.st_offset_parameter) + st_rec_gene_matrix

print(f'rec st gene matrix:{st_rec_gene_matrix.shape}')
proto_cell_type_matrix = deconv_model.proto_cell_type_matrix
base_rec = st_rec_gene_matrix

rec st gene matrix:torch.Size([71, 2221])


In [37]:
from models.utils import check_anndata
cell_type_key = 'cell_type'
sc_reference_path = "../datasets/seqFISH/single/seqFISH_sc.h5ad"
sc_rna_origin_adata = check_anndata(sc_reference_path, True)
cell_type_set_list = sc_rna_origin_adata.obs[cell_type_key].unique().tolist()
cell_type_set_list.sort()
print(cell_type_set_list)
mask_mapping_matrix = mapping_matrix.clone()
target_cell_type_index = 2
target_cell_type = cell_type_set_list[target_cell_type_index]
print(cell_type_set_list[target_cell_type_index])
mask_mapping_matrix[:, deconv_model.proto_cell_type_matrix[:, target_cell_type_index]==1] = 0
# mask target cell type weights
st_mask_rec_gene_matrix = torch.matmul(mask_mapping_matrix, deconv_model.proto_gene_matrix)
st_mask_rec_gene_matrix = torch.log(st_mask_rec_gene_matrix + 1)
st_mask_rec_gene_matrix =  torch.nn.functional.softplus(deconv_model.gene_offset_parameter) + torch.nn.functional.softplus(deconv_model.st_offset_parameter) + st_mask_rec_gene_matrix
sensitive_mse = (base_rec - st_mask_rec_gene_matrix) ** 2
# calculated the sensitive of target cell type


Data matrix:
(1691, 19972)
  (0, 1)	3
  (0, 2)	3
  (0, 4)	1
  (0, 7)	11
  (0, 8)	1
  (0, 12)	1
  (0, 14)	3
  (0, 17)	9
  (0, 19)	5
  (0, 20)	13
  (0, 22)	3
  (0, 23)	7
  (0, 24)	28
  (0, 27)	2
  (0, 29)	22
  (0, 33)	1
  (0, 35)	1
  (0, 36)	2
  (0, 37)	116
  (0, 41)	7
  (0, 43)	3
  (0, 45)	18
  (0, 46)	1
  (0, 48)	1
  (0, 50)	4
  :	:
  (1690, 19326)	1
  (1690, 19335)	2
  (1690, 19348)	1
  (1690, 19357)	1
  (1690, 19362)	1
  (1690, 19379)	1
  (1690, 19388)	1
  (1690, 19396)	1
  (1690, 19500)	3
  (1690, 19517)	1
  (1690, 19547)	3
  (1690, 19554)	1
  (1690, 19582)	1
  (1690, 19611)	1
  (1690, 19614)	1
  (1690, 19623)	1
  (1690, 19648)	1
  (1690, 19657)	1
  (1690, 19732)	1
  (1690, 19795)	1
  (1690, 19810)	1
  (1690, 19832)	1
  (1690, 19876)	1
  (1690, 19912)	1
  (1690, 19920)	2
Data obs:
       cell_type
0        iNeuron
1        iNeuron
2        iNeuron
3        iNeuron
4        iNeuron
...          ...
1686  endo.mural
1687  endo.mural
1688  endo.mural
1689  endo.mural
1690  endo.mural



In [38]:
# select key gene via sensitive analysis
sensitive_mse_np = sensitive_mse.detach().cpu().numpy()
mse_np = np.mean(sensitive_mse_np, axis=0)
index_arr = np.argsort(mse_np)
# sorted gene sensitive
var_names = pd.read_csv(f"{save_path}/st_var_names.csv", index_col=0)
gene_names = var_names['var_names'][index_arr].tolist()
print(gene_names)
# These gene were sorted based on sensitive value
# The higher the ranking of genes, the greater the impact on that cell type

['st8sia6', 'agtr2', 'sp9', 'kcnk5', 'prkag3', 'baiap3', 'clrn1', 'tnfsf18', 'casr', 'grxcr2', 'dmrta1', 'mfng', 'slc34a3', 'cyb5r2', 'chat', 'adam3', 'cd28', 'slc15a3', 'tlr5', 'wnt5b', 'npas1', 'timd4', 'phf11d', 'uba7', 'piezo2', 'bcl3', 'nos1', 'ucp2', 'lox', 'nfatc1', 'krt73', 'cybrd1', 'dnaic1', 'abca8a', 'sash3', 'tmem154', 'alox12', 'adam1b', 'pbx3', 'il21r', 'bmf', 'zfp488', 'abi3', 'retsat', 'pla2g3', 'rin3', 'lfng', 'emilin2', 'enpp1', 'pde3a', 'lrig3', 'prdm1', 'rem1', 'klhdc7a', 'slc37a2', 'maff', 'ace2', 'flt4', 'ccrl2', 'sox7', 'nlrc4', 'slc35f2', 'fzd10', 'vash2', 'osr1', 'erbb4', 'notum', 'dennd1c', 'smad6', 'sfrp2', 'themis2', 'cybb', 'rnf43', 'cysltr1', 'mpeg1', 'hs3st3b1', 'prss16', 'slc13a3', 'ddc', 'myo1f', 'kif19a', 'tfcp2l1', 'gna12', 'gjd2', 'madcam1', 'ppp1r3g', 'tor4a', 'shroom3', 'fam114a1', 'osgin1', 'slc35d3', 'cmklr1', 'cyp27a1', 'arhgef6', 'col4a3', 'npsr1', 'vstm4', 'col11a2', 'ehd2', 'pkd2l1', 'arhgef5', 'erbb3', 'bcam', 'p4ha3', 'rab37', 'tbc1d8b', 't

In [39]:
# Based on these sorted gene list, we can do lots of downstream task
# Such as Enrichment analysis
import gseapy as gp

try:
    enr = gp.enrichr(gene_list=gene_names[-100:],
                     organism='mouse',
                     # gene_sets=['KEGG_2019_Mouse'],
                     gene_sets=['GO_Cellular_Component_2023',],
                     # gene_sets=['KEGG_2019_Mouse',
                     #            'Reactome_Pathways_2024'],
                     outdir=f'{save_path}/{target_cell_type}', top_term=10)
    result = enr.results.head(10)
    print(result)
except Exception as e:
    print(e)
    print('error')

try:
    enr = gp.enrichr(gene_list=gene_names[-100:],
                     organism='mouse',
                     # gene_sets=['KEGG_2019_Mouse'],
                     gene_sets=[
                                'GO_Biological_Process_2023',
                                ],
                     # gene_sets=['KEGG_2019_Mouse',
                     #            'Reactome_Pathways_2024'],
                     outdir=f'{save_path}/{target_cell_type}', top_term=10)
    result = enr.results.head(10)
    print(result)
except Exception as e:
    print(e)
    print('error')

try:
    enr = gp.enrichr(gene_list=gene_names[-100:],
                     organism='mouse',
                     # gene_sets=['KEGG_2019_Mouse'],
                     gene_sets=[
                                'GO_Molecular_Function_2023'],
                     # gene_sets=['KEGG_2019_Mouse',
                     #            'Reactome_Pathways_2024'],
                     outdir=f'{save_path}/{target_cell_type}', top_term=10)
    result = enr.results.head(10)
    print(result)
except Exception as e:
    print(e)
    print('error')