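## Data preprocessing

Two alternative inputs are prepared below: a 10x Visium section read from `Data/T2`, and a counts/locations CSV pair (`cnts.csv`, `locs.csv`). Run only one of them. The CSV cell reassigns `adata`, so whichever loading cell runs last determines the data passed to STAGE.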

In [2]:
import pickle
import warnings
warnings.filterwarnings("ignore")  # ignore all warnings, including any emitted while importing the libraries below

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc

import STAGE

In [2]:
# Load the 10x Visium section (filtered count matrix plus spatial metadata) from input_dir.
input_dir = 'Data/T2'
adata = sc.read_visium(input_dir, count_file='filtered_feature_bc_matrix.h5')
# Gene symbols can repeat across features; suffix duplicates so var_names are unique.
adata.var_names_make_unique()

In [3]:
# Store each spot's Visium array grid position as an (n_spots, 2) array, column index first.
adata.obsm["coord"] = adata.obs[['array_col', 'array_row']].to_numpy()

In [4]:
# Flag the 5000 most variable genes. This adds 'highly_variable' and related columns to
# adata.var; X is not subset or transformed.
# flavor="seurat_v3" expects raw counts, so any normalization has to come after this call.
# The normalize/log1p steps below are currently disabled:
sc.pp.highly_variable_genes(adata, n_top_genes=5000, flavor="seurat_v3")
# sc.pp.normalize_total(adata, target_sum=1e4)
# sc.pp.log1p(adata)

In [5]:
# Inspect the AnnData summary (n_obs × n_vars and its annotation slots) after HVG selection.
adata

AnnData object with n_obs × n_vars = 2903 × 32285
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'spatial', 'hvg'
    obsm: 'spatial', 'coord'

In [3]:
# Alternative input: a counts matrix (rows = spots, columns = genes) and spot locations from CSV.
# NOTE: this reassigns `adata`, replacing the Visium section loaded earlier.
# Raw strings keep the Windows backslashes literal. The previous "...\HN\cnts.csv" relied on
# invalid escape sequences (\c, \l), which newer Python versions warn about.
HN_COUNTS_CSV = r"E:\data\HN\cnts.csv"  # NOTE(review): absolute local path; make configurable
HN_LOCS_CSV = r"E:\data\HN\locs.csv"
COORD_DOWNSCALE = 16   # raw locations are floor-divided by this factor (meaning of 16 not documented here)
N_TOP_GENES_HN = 377   # number of highly variable genes flagged for this dataset

counts = pd.read_csv(HN_COUNTS_CSV, index_col=0)
coords = pd.read_csv(HN_LOCS_CSV, index_col=0)
# Align locations to the count rows by spot label instead of trusting file order.
# Raises KeyError if any counted spot has no location.
coords = coords.loc[counts.index] // COORD_DOWNSCALE
adata = ad.AnnData(X=counts.to_numpy(), obs=coords, var=pd.DataFrame(index=counts.columns.values))
adata.obsm["coord"] = adata.obs.loc[:, ['x', 'y']].to_numpy()
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=N_TOP_GENES_HN)

## Running the STAGE model

In [4]:
# Train STAGE on `adata` in 'generation' mode; the returned object is used below as adata_stage.
# NOTE(review): the loading cell executed just before this one (In [3]) was the CSV input, yet
# data_type='10x' and save_path refer to T2. Confirm these match the dataset actually loaded.
# NOTE(review): coord_sf and the w_* loss weights are interpreted by STAGE; their meaning is not documented here.
stage_params = dict(
    save_path='./T2_HBC_2',
    data_type='10x',
    experiment='generation',
    coord_sf=77,
    train_epoch=10000,
    seed=1234,
    batch_size=512,
    learning_rate=1e-3,
    w_recon=0.1,
    w_w=0.1,
    w_l1=0.1,
    relu=True,
    device='cuda:0',  # first CUDA device
)
adata_stage = STAGE.STAGE(adata, **stage_params)

Epochs: 100%|██████████| 10000/10000 [38:00<00:00,  4.39it/s, latent_loss: 0.16230, recon_loss: 0.19898, total_loss: 0.01822]


In [5]:
# Preview STAGE's output matrix; the next cell reshapes its first 900 * 382 rows into a grid.
adata_stage.X

array([[ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     ,  0.     ],
       ...,
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     , 80.6191 ],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     , 80.76536],
       [ 0.     ,  0.     ,  0.     , ...,  0.     ,  0.     , 80.9117 ]],
      dtype=float32)

In [6]:
# Reshape STAGE's generated expression matrix into a 3-D volume with axes
# (columns of X, GRID_ROWS, GRID_COLS) and pickle it to genes_3D.pkl.
# NOTE(review): assumes the first GRID_ROWS * GRID_COLS rows of adata_stage.X are the generated
# grid spots in row-major order, and that X's columns are genes. Confirm against STAGE's output.
GRID_ROWS, GRID_COLS = 900, 382
EXPR_DIVISOR = 37  # NOTE(review): undocumented scaling constant; `//` floors the float values

n_grid_spots = GRID_ROWS * GRID_COLS  # 343800
generated = adata_stage.X
# Fail loudly if there are too few rows. Otherwise reshape(..., -1) could silently fold a
# short matrix into a wrong third axis whenever the sizes happen to divide evenly.
assert generated.shape[0] >= n_grid_spots, (
    f"expected at least {n_grid_spots} generated spots, got {generated.shape[0]}"
)
gene = generated[:n_grid_spots].reshape((GRID_ROWS, GRID_COLS, -1))
gene = gene // EXPR_DIVISOR
gene = np.transpose(gene, (2, 0, 1))  # move the per-spot feature axis first
with open('genes_3D.pkl', 'wb') as f:
    pickle.dump(gene, f)