In [1]:
import os
os.chdir(path='../../')
import numpy as np
import scanpy as sc
from STForte import STGraph
from STForte import STForteModel
from STForte.helper import save_gdata
from STForte.helper import annotation_propagate
trial_name = "trial-DLPFC/multi_slides"
data_name = "adata_673_676_paste2"

  rank_zero_deprecation(
Global seed set to 0


In [2]:
adata = sc.read_h5ad(filename=f"{trial_name}/data/{data_name}.h5ad")
adata.obs['section_id'] = adata.obs['section_id'].astype("str").astype("category")
adata

  utils.warn_names_duplicates("obs")


AnnData object with n_obs × n_vars = 14364 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'spatialLIBD', 'section_id'
    uns: '151673', '151674', '151675', '151676'
    obsm: 'spatial'

In [3]:
adata.layers["log1p"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4, layer="log1p")
sc.pp.log1p(adata, layer="log1p")
adata = adata[:,adata.X.sum(axis=0)!=0]
# sc.pp.highly_variable_genes(adata, n_top_genes=3000, batch_key="section_id", layer="log1p", inplace=True, subset=True)
adata

View of AnnData object with n_obs × n_vars = 14364 × 24155
    obs: 'in_tissue', 'array_row', 'array_col', 'spatialLIBD', 'section_id'
    uns: '151673', '151674', '151675', '151676', 'log1p'
    obsm: 'spatial'
    layers: 'log1p'

In [4]:
# sc.pp.pca(adata, n_comps=300)
# sc.external.pp.harmony_integrate(adata, "section_id", adjusted_basis='X_pca_harmony',)
# sc.external.pp.scanorama_integrate(adata, "section_id", adjusted_basis='X_pca_scanorama')

In [5]:
d = adata.uns[list(adata.uns.keys())[0]]['spatial']['stomic']['scalefactors']['fiducial_diameter_fullres']
stgraph = STGraph.graphFromAnndata(
    adata=adata,
    # attr_loc=["obsm","X_pca_scanorama"],
    d = d 
    # knn=True,
    # k=18*2,
)
stgraph.pca(n_components=300,svd_solver="full")
# stgraph.add_additional_notes("section_id", adata.obs['section_id'].to_numpy())
# stgraph.scvi(batch_id="section_id", scvi_kwargs=dict(n_hidden=300, n_latent=256, n_layers=2))
gdata = stgraph.topyg()



d-based initialize:   0%|          | 0/14364 [00:00<?, ?it/s]

PCA pre-compression for data, from 24155 onto 300-dim.
Scaling data: None; SVD solver: full; random_state=42.
Start compression...	Done! Elapsed time: 518.33s.
FP


In [6]:
# d = adata.uns[list(adata.uns.keys())[0]]['spatial']['stomic']['scalefactors']['fiducial_diameter_fullres']
# stgraph = STGraph.graphFrom3DAnndata(
#     adata=adata,
#     ordered_section_name=np.arange(4),
#     attr_loc=["obsm", "X_pca_harmony"],
#     # attr_loc=["obsm",'X_pca_scanorama'],
#     section_id=['obs', 'section_id'],
#     d=d,
#     knn=True,
#     k=18,
#     between_section_k=18,
# )
# # stgraph.pca(n_components=300)
# # stgraph.padding3D()
# # stgraph.remove_duplicates(r=0.3)
# gdata = stgraph.topyg()

In [7]:
model = STForteModel(adata=adata, gdata=gdata, epochs=550,
                     output_dir='./{:s}/pl_ckpts/'.format(trial_name),
                     module_kwargs=dict(
                        partial_adjacent=True, lmbd_cross=10, lmbd_gan=4,
                        ),
                     )
model

<STForte._model.STForteModel at 0x7f493b9e3ca0>

In [8]:
model.fit()

Global seed set to 42
  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | attr_encoder  | Sequential        | 71.9 K
1 | strc_encoder  | Sequential_2ba8ea | 71.9 K
2 | attr_decoder  | Sequential        | 72.2 K
3 | strc_decoder  | Sequential        | 2.1 K 
4 | discriminator | Sequential        | 1.1 K 
----------------------------------------------------
219 K     Trainable params
0         Non-trainable params
219 K     Total params
0.877     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=550` reached.


In [9]:
if not os.path.exists(f"{trial_name}/data"):
    os.makedirs(f"{trial_name}/data")   
save_gdata(gdata, path=f"{trial_name}/data/gdata_multi.pkl")
adata.write_h5ad(f"{trial_name}/data/trial_multi.h5ad")

In [10]:
import torch
z_attr, z_strc, _, _, _, _ = model._get_module_output()
adata.obsm['STForte_ATTR'] = z_attr.detach().numpy()
adata.obsm['STForte_TOPO'] = z_strc.detach().numpy()
adata.obsm['STForte_COMB'] = torch.cat([z_attr,z_strc],dim=1).detach().numpy()
if not os.path.exists(f"./{trial_name}/outputs"):
    os.makedirs(f"./{trial_name}/outputs")
adata.write(f"./{trial_name}/outputs/stforte_multi.h5ad")

  adata.obsm['STForte_ATTR'] = z_attr.detach().numpy()
  utils.warn_names_duplicates("obs")


In [11]:
# For analysis in padding resolution
adata_sp = model.get_result_anndata(adj_mat=False)
adata_sp.obs['section_id'] = adata.obs['section_id']
adata_sp.obs['section_id'].astype("str").astype("category")
adata_sp.write_h5ad("./{:s}/outputs/sp_multi.h5ad".format(trial_name))

  utils.warn_names_duplicates("obs")
