In [8]:
# Loading reference single-cell and target ST adata

import torch

from models.utils import check_anndata, set_seed

# Fix the random seed up front so that the stochastic steps of the pipeline
# are repeatable between "Restart & Run All" executions.
set_seed(0)

# Single-cell reference AnnData.
# NOTE(review): the second positional argument (True) presumably enables the
# printed summary of the matrix / obs table - confirm against check_anndata.
sc_reference_path = "../datasets/seqFISH/single/seqFISH_sc.h5ad"
sc_rna_origin_adata = check_anndata(sc_reference_path, True)

# Spatial-transcriptomics target AnnData (plain string: there is nothing to interpolate).
st_target_path = "../datasets/seqFISH//spatial/seqFISH_st3000.h5ad"
st_rna_origin_adata = check_anndata(st_target_path, True)

Data matrix:
(1691, 19972)
  (0, 1)	3
  (0, 2)	3
  (0, 4)	1
  (0, 7)	11
  (0, 8)	1
  (0, 12)	1
  (0, 14)	3
  (0, 17)	9
  (0, 19)	5
  (0, 20)	13
  (0, 22)	3
  (0, 23)	7
  (0, 24)	28
  (0, 27)	2
  (0, 29)	22
  (0, 33)	1
  (0, 35)	1
  (0, 36)	2
  (0, 37)	116
  (0, 41)	7
  (0, 43)	3
  (0, 45)	18
  (0, 46)	1
  (0, 48)	1
  (0, 50)	4
  :	:
  (1690, 19326)	1
  (1690, 19335)	2
  (1690, 19348)	1
  (1690, 19357)	1
  (1690, 19362)	1
  (1690, 19379)	1
  (1690, 19388)	1
  (1690, 19396)	1
  (1690, 19500)	3
  (1690, 19517)	1
  (1690, 19547)	3
  (1690, 19554)	1
  (1690, 19582)	1
  (1690, 19611)	1
  (1690, 19614)	1
  (1690, 19623)	1
  (1690, 19648)	1
  (1690, 19657)	1
  (1690, 19732)	1
  (1690, 19795)	1
  (1690, 19810)	1
  (1690, 19832)	1
  (1690, 19876)	1
  (1690, 19912)	1
  (1690, 19920)	2
Data obs:
       cell_type
0        iNeuron
1        iNeuron
2        iNeuron
3        iNeuron
4        iNeuron
...          ...
1686  endo.mural
1687  endo.mural
1688  endo.mural
1689  endo.mural
1690  endo.mural



In [9]:
# Preprocessing sc and st adata
from models.utils import filter_genes
import scanpy as sc
# Name of the obs column that holds the cell-type annotation.
cell_type_key = 'cell_type'

# Restrict both datasets to the gene set chosen by filter_genes
# (use_deg=True, no explicit cap on the number of genes).
sc_rna_adata, st_rna_adata = filter_genes(
    sc_rna_adata=sc_rna_origin_adata,
    st_rna_adata=st_rna_origin_adata,
    cell_type_key=cell_type_key,
    n_genes=None,
    use_deg=True,
)

# Library-size normalisation, applied in the same order as before:
# filtered ST, filtered sc, then the unfiltered sc reference.
for adata in (st_rna_adata, sc_rna_adata, sc_rna_origin_adata):
    sc.pp.normalize_total(adata, target_sum=1e4)

# Only the unfiltered sc reference is additionally log1p-transformed.
sc.pp.log1p(sc_rna_origin_adata)

common gene number:2781
deg gene number:2221


In [10]:
# Generate proto tree for HIDF training
from models.BranchBound import generate_proto_generator
# Settings forwarded to the prototype-tree builder.
resolution = 0.3
save_path = 'seqFISH3000'

proto_generator = generate_proto_generator(
    sc_rna_adata=sc_rna_adata,
    sc_origin_rna_adata=sc_rna_origin_adata,
    sc_omics_adata=None,
    resolution=resolution,
    save_path=save_path,
    save_bool=False,
)

# The depth of the tree is used as the number of training iterations below.
iterator_times = depth = proto_generator.calculate_depth()
print(f'depth: {depth}')

         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
depth: 5


In [11]:
# Start Iterator Training Process

from models.utils import softmax_to_logits_matrix
from models.train import Context
from torch import optim
from models.deconv_trainer import HIDF_Trainer
from torch.utils.data import DataLoader
from models.deconv_dataset import Spatial_Exp_Dataset
from models.deconv_model import HIDF
from sklearn.neighbors import NearestNeighbors
import numpy as np
from models.utils import conver_adata_X_to_numpy
# ---------------------------------------------------------------------------
# Settings shared by every training round.
# ---------------------------------------------------------------------------
# Spatial regularization parameters (they depend on the platform type).
reg_lambda = 1e-2
k = 9
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the cell does not fail on CPU-only machines.
# NOTE(review): assumes HIDF_Trainer accepts 'cpu' as device_name - confirm.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
lr = 1e-3
epoch = 300
batch_size = 1024


def _build_optimizer(model, base_lr):
    """Return a fresh AdamW optimizer over the model's three parameter groups.

    Every group sets its own learning rate (0.1); `base_lr` is only the
    optimizer-level default.
    """
    return optim.AdamW(
        [
            {'params': model.gene_offset_parameter, 'lr': 0.1},
            {'params': model.st_offset_parameter, 'lr': 0.1},
            {'params': model.mapping_matrix, 'lr': 0.1},
        ],
        lr=base_lr)


def _train_then_deconv(trainer, dataset, pre_cell_type_matrix):
    """Run one training pass followed by one deconvolution pass.

    `pre_cell_type_matrix` is stored on the training context as-is (None in
    the first round, the previous round's result afterwards).
    Returns the pair (st_cell_type_matrix, mapping_matrix) read from the
    deconvolution context.
    """
    train_ctx = Context(epoch=epoch, batch_size=batch_size, save_model_path=None, random_seed=None)
    train_ctx.st_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    train_ctx.pre_cell_type_matrix = pre_cell_type_matrix
    train_ctx.reg_lambda = reg_lambda
    train_ctx.constrain_loss_list = []
    train_ctx.regular_loss_list = []
    train_ctx.rec_gene_loss_list = []
    trainer.train(train_ctx)

    # Deconvolution iterates the spots in their original order (no shuffling).
    deconv_ctx = Context(epoch=epoch, batch_size=batch_size, save_model_path=None, random_seed=None)
    deconv_ctx.st_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    trainer.deconv(deconv_ctx)
    return deconv_ctx.st_cell_type_matrix, deconv_ctx.mapping_matrix


for round_idx in range(iterator_times):
    if round_idx == 0:
        # ---- First round: build the model, dataset and trainer ----
        proto_exp_matrix, proto_cell_type_matrix = proto_generator.generate_rna_proto_matrix()
        sc_proto_adata = sc.AnnData(X=proto_exp_matrix)

        cell_type_set_list = proto_generator.current_cell_type_set_list

        st_rna_matrix = conver_adata_X_to_numpy(st_rna_adata.X)
        sc_rna_proto_matrix = np.array(sc_proto_adata.X)

        proto_number = sc_rna_proto_matrix.shape[0]
        st_number = st_rna_matrix.shape[0]

        # k-nearest-neighbour graph over the spot coordinates.
        spatial = st_rna_origin_adata.obsm['spatial']
        neighbors = k
        neigh = NearestNeighbors(n_neighbors=neighbors)
        mtx = np.array(spatial)
        neigh.fit(mtx)
        A = neigh.kneighbors_graph(mtx)

        # One row of neighbour indices per spot.
        # NOTE(review): np.zeros defaults to float64, so the indices are stored
        # as floats; kept as-is because Spatial_Exp_Dataset is not visible here.
        neighbor_index = np.zeros(shape=(st_number, neighbors))
        # A dedicated loop variable is used so the outer round counter is not clobbered.
        for spot_idx in range(st_number):
            neighbor_index[spot_idx, :] = A[spot_idx, :].indices

        deconv_model = HIDF(sc_rna_proto_matrix,
                            st_rna_matrix,
                            proto_number,
                            st_number,
                            proto_cell_type_matrix)

        st_rna_dataset = Spatial_Exp_Dataset(data=st_rna_matrix, neighbor_index=neighbor_index)

        deconv_trainer = HIDF_Trainer(model=deconv_model, train_dataset=st_rna_dataset,
                                      test_dataset=None, continue_train=False,
                                      trained_model_path=None, device_name=device_name, lr=lr,
                                      save_path=save_path)
        deconv_trainer.opt = _build_optimizer(deconv_model, lr)

        st_cell_type_matrix, mapping_matrix = _train_then_deconv(
            deconv_trainer, st_rna_dataset, pre_cell_type_matrix=None)
        print(f'st_cell_type_matrix shape: {st_cell_type_matrix.shape}')
        print(f'mapping_matrix shape: {mapping_matrix.shape}')

    else:
        # ---- Later rounds: update the prototypes, then retrain ----
        _, new_sim_matrix = proto_generator.update_current_proto_matrix_with_sim(
            weight_list=weight_list,
            threshold=0,
            proto_latent_matrix=None,
            sim_matrix=mapping_matrix)

        print(f'new_sim_matrix:{new_sim_matrix.shape}')

        proto_exp_matrix, proto_cell_type_matrix = proto_generator.generate_rna_proto_matrix()
        print(f'proto shape:{proto_exp_matrix.shape}')

        cell_type_set_list = proto_generator.current_cell_type_set_list
        # new_sim_matrix is equivalent to a softmax output, so it cannot be used
        # directly; it has to be converted back into softmax inputs (logits).
        new_sim_matrix = softmax_to_logits_matrix(new_sim_matrix)
        # shape: (proto_number, st_number) <- shape: (st_number, proto_number)
        new_sim_matrix = new_sim_matrix.transpose()

        deconv_model.update_after_train(new_proto_gene_matrix=proto_exp_matrix,
                                        new_proto_mapping_matrix=new_sim_matrix,
                                        new_proto_cell_type_matrix=proto_cell_type_matrix)

        # A new optimizer is created after the model update, as in the first round.
        deconv_trainer.opt = _build_optimizer(deconv_model, lr)

        # The previous round's cell-type matrix is handed to the trainer as a tensor.
        previous_cell_type = torch.tensor(st_cell_type_matrix, dtype=torch.float32, device=device_name)
        st_cell_type_matrix, mapping_matrix = _train_then_deconv(
            deconv_trainer, st_rna_dataset, pre_cell_type_matrix=previous_cell_type)

    # Column-wise maximum of the mapping matrix; consumed by the next round's
    # call to update_current_proto_matrix_with_sim.
    weight_list = np.max(np.array(mapping_matrix), axis=0)

device :cuda:0


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 300/300 [00:01<00:00, 278.35it/s]


[[0.14366985857486725, 0.13910244405269623, 0.13498824834823608, 0.1312718391418457, 0.12790162861347198, 0.12483004480600357, 0.1220184713602066, 0.11943834275007248, 0.11706967651844025, 0.11489787697792053, 0.11291106045246124, 0.11109822243452072, 0.10944855958223343, 0.10795080661773682, 0.1065930724143982, 0.10536330938339233, 0.10424992442131042, 0.1032419428229332, 0.10232909023761749, 0.10150168836116791, 0.10075082629919052, 0.10006853193044662, 0.09944802522659302, 0.09888357669115067, 0.09837035089731216, 0.09790419787168503, 0.09748136252164841, 0.09709815680980682, 0.0967506393790245, 0.0964348316192627, 0.09614681452512741, 0.09588325768709183, 0.09564155340194702, 0.09541971236467361, 0.09521632641553879, 0.09503012895584106, 0.09486006200313568, 0.09470473229885101, 0.09456272423267365, 0.09443259239196777, 0.09431301057338715, 0.09420286118984222, 0.0941014364361763, 0.09400807321071625, 0.09392235428094864, 0.09384380280971527, 0.09377191960811615, 0.0937061160802841

100%|██████████| 300/300 [00:01<00:00, 278.58it/s]


[[0.09243980050086975, 0.09238068759441376, 0.0924665778875351, 0.09214786440134048, 0.09212099015712738, 0.09205654263496399, 0.09188927710056305, 0.0918109193444252, 0.09179641306400299, 0.09172260761260986, 0.09161429852247238, 0.09155397117137909, 0.09152700752019882, 0.09147275984287262, 0.09139804542064667, 0.09134781360626221, 0.09132199734449387, 0.0912829041481018, 0.09122685343027115, 0.09118176996707916, 0.09115171432495117, 0.09111639857292175, 0.09107208997011185, 0.09103556722402573, 0.09100916236639023, 0.0909772738814354, 0.09093933552503586, 0.0909072756767273, 0.09088002890348434, 0.0908503383398056, 0.09081992506980896, 0.09079288691282272, 0.09076809138059616, 0.09074194729328156, 0.09071460366249084, 0.09068963676691055, 0.09066750109195709, 0.09064565598964691, 0.09062406420707703, 0.0906042754650116, 0.09058532118797302, 0.09056666493415833, 0.09054888784885406, 0.09053291380405426, 0.0905180275440216, 0.09050407260656357, 0.09049065411090851, 0.09047750383615494

100%|██████████| 300/300 [00:01<00:00, 259.53it/s]


[[0.0900648906826973, 0.09026525169610977, 0.09007295966148376, 0.08997008204460144, 0.08996538072824478, 0.0898292139172554, 0.0896873027086258, 0.08965186774730682, 0.0896301344037056, 0.08953993767499924, 0.08944868296384811, 0.08940169215202332, 0.08936262130737305, 0.08930492401123047, 0.08923985809087753, 0.08917982131242752, 0.08912645280361176, 0.08907750993967056, 0.089027538895607, 0.08897601068019867, 0.08892697840929031, 0.08887840062379837, 0.08882764726877213, 0.088776595890522, 0.08872632682323456, 0.08867760747671127, 0.08862946182489395, 0.08858097344636917, 0.08853437006473541, 0.0884898230433464, 0.08844351768493652, 0.08839469403028488, 0.0883466899394989, 0.08830071985721588, 0.08825548738241196, 0.0882108137011528, 0.08816767483949661, 0.08812557905912399, 0.08808279782533646, 0.08803998678922653, 0.0879983976483345, 0.08795814216136932, 0.08791957050561905, 0.08788163959980011, 0.08784343302249908, 0.08780445158481598, 0.0877663642168045, 0.08772943168878555, 0.0

100%|██████████| 300/300 [00:01<00:00, 268.22it/s]


[[0.08632900565862656, 0.08737636357545853, 0.08652199804782867, 0.08667372912168503, 0.08691500127315521, 0.0867445319890976, 0.0864548608660698, 0.08637433499097824, 0.08649852126836777, 0.08658083528280258, 0.08650164306163788, 0.08637037873268127, 0.08631221204996109, 0.08633828908205032, 0.08638393878936768, 0.08638911694288254, 0.08634759485721588, 0.08629503846168518, 0.08626948297023773, 0.08627961575984955, 0.08629988133907318, 0.08629877120256424, 0.08627203106880188, 0.08624327927827835, 0.08623585104942322, 0.08624789863824844, 0.0862576961517334, 0.08625023066997528, 0.08623142540454865, 0.08621768653392792, 0.08621705323457718, 0.08622261881828308, 0.08622290939092636, 0.08621490746736526, 0.08620551973581314, 0.08620157837867737, 0.08620248734951019, 0.0862026959657669, 0.08619850128889084, 0.08619160950183868, 0.08618691563606262, 0.08618635684251785, 0.08618646115064621, 0.08618345856666565, 0.08617827296257019, 0.08617441356182098, 0.08617305755615234, 0.0861721560359

100%|██████████| 300/300 [00:01<00:00, 245.93it/s]

[[0.0858364999294281, 0.08699490875005722, 0.08599647134542465, 0.0862259641289711, 0.08651575446128845, 0.08633290976285934, 0.08601353317499161, 0.08591122925281525, 0.08603547513484955, 0.08613678812980652, 0.08607133477926254, 0.08594068884849548, 0.08587765693664551, 0.08590513467788696, 0.08596087992191315, 0.08597716689109802, 0.08593878149986267, 0.0858796238899231, 0.0858469009399414, 0.08585993945598602, 0.08588963747024536, 0.08589303493499756, 0.08586502820253372, 0.08583668619394302, 0.08583343029022217, 0.08584882318973541, 0.08585850894451141, 0.08585003763437271, 0.08583179861307144, 0.0858205035328865, 0.0858231708407402, 0.085830919444561, 0.08583122491836548, 0.08582251518964767, 0.08581404387950897, 0.08581320196390152, 0.08581697195768356, 0.08581819385290146, 0.0858147144317627, 0.08581028133630753, 0.08580860495567322, 0.08580915629863739, 0.0858091339468956, 0.08580712229013443, 0.08580426871776581, 0.08580286800861359, 0.0858033150434494, 0.08580392599105835, 0




In [12]:
# Read out deconvolution results: cell type proportion matrix
import pandas as pd
import os

save_path = 'seqFISH3000'
# The proto generator was created with save_bool=False, so nothing visible in
# this notebook is guaranteed to have created the output directory; create it
# here so that to_csv does not fail on a missing folder.
os.makedirs(save_path, exist_ok=True)

# Cell-type proportion matrix: one row per ST spot, one column per cell type.
new_df = pd.DataFrame(st_cell_type_matrix)
new_df.columns = cell_type_set_list
# Rows are labelled X0, X1, ... in spot order.
new_df.index = [f'X{i}' for i in range(st_rna_adata.shape[0])]
new_df.to_csv(f'{save_path}/cell_type_results.csv')
print(new_df)
# new_df is the cell type proportion matrix


         Olig  astrocytes   eNeuron  endo.mural   iNeuron  microglia
X0   0.005885    0.396474  0.000386    0.564672  0.000180   0.032403
X1   0.045221    0.349456  0.005111    0.379621  0.161414   0.059178
X2   0.005685    0.523440  0.014369    0.269996  0.022654   0.163856
X3   0.004223    0.351257  0.185759    0.072694  0.326329   0.059738
X4   0.001883    0.021683  0.482022    0.175549  0.232579   0.086284
..        ...         ...       ...         ...       ...        ...
X66  0.343186    0.013594  0.496034    0.026313  0.093683   0.027189
X67  0.383048    0.027250  0.401302    0.032297  0.098585   0.057518
X68  0.251808    0.486537  0.010102    0.214432  0.002573   0.034548
X69  0.662470    0.268907  0.000930    0.013932  0.005447   0.048314
X70  0.883231    0.013802  0.003130    0.005171  0.000140   0.094526

[71 rows x 6 columns]


In [13]:
# Read out deconvolution results: cell-spot mapping matrix
save_path = 'seqFISH3000'

# Cell-spot mapping matrix: the index holds the obs names of the ST adata,
# the columns hold the proto_type names.
mapping_df = pd.DataFrame(mapping_matrix)
mapping_df.index = st_rna_adata.obs_names.tolist()
mapping_df.columns = proto_generator.current_proto_type_set_list

# proto_type paired with obs_names; obs_names are the obs names of the
# sc reference adata.
obs_names_df = pd.DataFrame({
    'proto_type': proto_generator.current_proto_type_list,
    'obs_names': proto_generator.current_cell_obs_name_list,
})

# Gene (var) names of the filtered ST adata.
new_st_gene_list = st_rna_adata.var_names.tolist()
var_names_df = pd.DataFrame({'var_names': new_st_gene_list})

# Write and print each table, in the same order as before.
for frame, csv_path in (
        (mapping_df, f'{save_path}/map_results.csv'),
        (obs_names_df, f'{save_path}/proto_type_obs_names_meta.csv'),
        (var_names_df, f'{save_path}/st_var_names.csv'),
):
    frame.to_csv(csv_path)
    print(frame)
# training finished, next step can be seen in visualize_tutorial, analysis_tutorial and Interpretability_tutorial

    Olig_0_0_0_0  Olig_0_0_0_1  Olig_0_0_0_10  Olig_0_0_0_100  Olig_0_0_0_101  \
0   3.501234e-07  3.600846e-07   4.274239e-07    4.004756e-07    4.027794e-07   
1   3.257382e-07  3.222299e-07   3.275607e-07    3.098855e-07    3.144546e-07   
2   2.893596e-07  2.780317e-07   2.800761e-07    2.937659e-07    3.322470e-07   
3   2.703758e-07  2.887691e-07   2.692723e-07    2.689736e-07    2.893762e-07   
4   2.186998e-07  2.111610e-07   2.190456e-07    2.281110e-07    2.222856e-07   
..           ...           ...            ...             ...             ...   
66  5.498379e-06  4.818993e-06   5.757078e-06    4.019264e-06    4.701515e-06   
67  4.434631e-06  4.448635e-06   5.582931e-06    4.125894e-06    4.242161e-06   
68  3.770513e-07  3.672226e-07   3.634147e-07    3.590075e-07    3.531640e-07   
69  6.634106e-07  6.292111e-07   6.976379e-07    6.096389e-07    4.959672e-07   
70  2.410595e-07  2.508112e-07   2.528461e-07    2.518433e-07    2.264307e-07   

    Olig_0_0_0_102  Olig_0_

In [14]:
# Saving trained Parameters, including mapping matrix, gene offset parameter and st offset parameter
# Persist the model's state dict next to the CSV results.
torch.save(deconv_model.state_dict(), f"{save_path}/trained_hbc.pt")

# Print the shape of each model tensor; the (label, attribute) pairs keep the
# exact output text and order of the previous print statements.
for label, attr_name in (
        ('', 'mapping_matrix'),
        ('proto_gene_matrix shape :', 'proto_gene_matrix'),
        ('st_gene_matrix shape:', 'st_gene_matrix'),
        ('', 'proto_cell_type_matrix'),
        ('', 'gene_offset_parameter'),
        ('', 'st_offset_parameter'),
):
    print(f'{label}{getattr(deconv_model, attr_name).shape}')
# these parameter shapes will be used in the Interpretability tutorial

torch.Size([1691, 71])
proto_gene_matrix shape :torch.Size([1691, 2221])
st_gene_matrix shape:torch.Size([71, 2221])
torch.Size([1691, 6])
torch.Size([1, 2221])
torch.Size([71, 1])
