In [1]:
import preprocessing

import scanpy as sc
import pandas as pd
import pathlib
import pickle
import warnings
import numpy as np

In [2]:
# Seed NumPy's legacy global RNG so anything drawing from np.random repeats across runs.
np.random.seed(0)
# Hide pandas PerformanceWarning messages for the rest of the session.
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)

In [3]:
# Base directory for every file this notebook reads or writes; "." is relative,
# so it resolves against the kernel's current working directory.
ROOT = pathlib.Path(".")

In [4]:
# Read the two input AnnData files that sit next to the notebook (under ROOT).
sa1_path = ROOT / "adata_sa1.h5ad"
sa3_path = ROOT / "adata_sa3.h5ad"
adata_sa1 = sc.read(sa1_path)
adata_sa3 = sc.read(sa3_path)

In [5]:
# Display both AnnData summaries (dimensions plus the available obs/uns/obsm keys).
adata_sa1, adata_sa3

(AnnData object with n_obs × n_vars = 91246 × 1022
     obs: 'X', 'Y', 'Z', 'Tissue_Symbol', 'Maintype_Symbol', 'Subtype_Symbol'
     uns: 'Tissue_Symbol_colors'
     obsm: 'spatial', 'spatial_full',
 AnnData object with n_obs × n_vars = 207684 × 1022
     obs: 'X', 'Y', 'Z', 'Tissue_Symbol', 'Maintype_Symbol', 'Subtype_Symbol'
     uns: 'Tissue_Symbol_colors'
     obsm: 'spatial', 'spatial_full')

In [6]:
# A label is kept only if it annotates strictly more than this many cells.
MIN_CELLS_PER_LABEL = 100


def _frequent_labels(labels, min_count=MIN_CELLS_PER_LABEL):
    """Return the labels occurring strictly more than ``min_count`` times.

    Parameters
    ----------
    labels : pandas.Series
        One label per cell (e.g. a column of ``adata.obs``).
    min_count : int
        Exclusive lower bound on the number of cells a label must have.

    Returns
    -------
    pandas.Index
        The labels whose count exceeds ``min_count``.
    """
    # Count once and reuse the result for both the index and the mask.
    counts = labels.value_counts()
    return counts.index[counts > min_count]


# All four keep-lists are derived from the *unfiltered* objects before any
# subsetting happens, so the two criteria are applied jointly, not sequentially.
keep_sa1_maintype = _frequent_labels(adata_sa1.obs["Maintype_Symbol"])
keep_sa3_maintype = _frequent_labels(adata_sa3.obs["Maintype_Symbol"])
keep_sa1_tissue = _frequent_labels(adata_sa1.obs["Tissue_Symbol"])
keep_sa3_tissue = _frequent_labels(adata_sa3.obs["Tissue_Symbol"])

# Keep a cell only when both its main type and its tissue are frequent enough.
# NOTE: the names are rebound to the filtered copies, so re-running this cell
# recomputes the counts on already-filtered data.
adata_sa1 = adata_sa1[
    adata_sa1.obs["Maintype_Symbol"].isin(keep_sa1_maintype)
    & adata_sa1.obs["Tissue_Symbol"].isin(keep_sa1_tissue)
].copy()
adata_sa3 = adata_sa3[
    adata_sa3.obs["Maintype_Symbol"].isin(keep_sa3_maintype)
    & adata_sa3.obs["Tissue_Symbol"].isin(keep_sa3_tissue)
].copy()
print(adata_sa1.shape, adata_sa3.shape)

(91156, 1022) (207383, 1022)


In [7]:
# Normalize expression and standardize spatial coordinates for both datasets.
# Everything here mutates the AnnData objects in place, so re-running this cell
# would normalize already-normalized data.
for adata in [adata_sa1, adata_sa3]:
    # Rescale every cell to a total of 1e4 counts, then apply log1p.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    # Keep a separate copy of the log-normalized matrix under its own layer name.
    adata.layers["log_normalized"] = adata.X.copy()
    spatial = adata.obsm["spatial"]
    # NOTE(review): assuming `spatial` is a NumPy array, mean()/std() without an
    # axis reduce over all entries, so every coordinate column shares one offset
    # and one scale rather than being standardized per axis — confirm intended.
    adata.obsm["spatial"] = (spatial - spatial.mean()) / spatial.std()

In [8]:
# Rank genes for each "Tissue_Symbol" group with the Wilcoxon method, once per
# dataset; scanpy attaches the result to the AnnData object it is given.
for dataset in (adata_sa1, adata_sa3):
    sc.tl.rank_genes_groups(dataset, groupby="Tissue_Symbol", method="wilcoxon")

In [9]:
# Split the genes into train / validation / test lists with the project helper.
# NOTE(review): cluster="CBXgr" presumably names a Tissue_Symbol group and
# val_tst=(10, 10) the validation/test list sizes — confirm in
# preprocessing.get_genes.
trn_genes, val_genes, tst_genes = preprocessing.get_genes(
    adata_sa1,
    adata_sa3,
    cluster="CBXgr",
    val_tst=(10, 10),
)

In [10]:
# Write each gene list to its own CSV under ROOT. Series.to_csv keeps its
# defaults here, so the positional index is written alongside the gene names.
gene_lists_by_file = {
    "trn_genes.csv": trn_genes,
    "val_genes.csv": val_genes,
    "tst_genes.csv": tst_genes,
}
for filename, genes in gene_lists_by_file.items():
    pd.Series(genes).to_csv(ROOT / filename)

In [11]:
# Sizes of the train / validation / test gene lists, in that order.
tuple(len(genes) for genes in (trn_genes, val_genes, tst_genes))

(1002, 10, 10)

In [12]:
# Save the filtered, log-normalized datasets under ROOT so they can be reloaded by path.
for filename, dataset in (
    ("adata_sa1_norm.h5ad", adata_sa1),
    ("adata_sa3_norm.h5ad", adata_sa3),
):
    dataset.write(ROOT / filename)

In [13]:
# Build one data bundle per dataset through the project helper, handing it the
# paths of the normalized .h5ad files and the gene-list CSVs saved earlier in
# this notebook (it receives paths, not the in-memory objects).
# NOTE(review): the layout of data_1 / data_2 is defined in
# preprocessing.get_spatial_data; this notebook only relies on element [0]
# having one row per cell.
data_1, data_2 = preprocessing.get_spatial_data(
    ROOT / "adata_sa1_norm.h5ad",
    ROOT / "adata_sa3_norm.h5ad",
    trn_genes_path=ROOT / "trn_genes.csv",
    val_genes_path=ROOT / "val_genes.csv",
    tst_genes_path=ROOT / "tst_genes.csv"
)

In [14]:
# Pull the per-cell main-type annotation out of each dataset's obs table.
key = 'Maintype_Symbol'
gt_celltypes, celltypes_to_pull = (
    adata_sa1.obs[key],
    adata_sa3.obs[key],
)

In [15]:
# Sanity check: the first element of each data bundle must have exactly one row
# per label, i.e. the bundles and the label series describe the same cells.
for bundle, labels in ((data_1, gt_celltypes), (data_2, celltypes_to_pull)):
    assert bundle[0].shape[0] == labels.shape[0]

In [16]:
# Bundle both datasets with their cell-type labels and pickle them to ROOT/data.pkl.
payload = (data_1, data_2, gt_celltypes, celltypes_to_pull)
with (ROOT / "data.pkl").open("wb") as fout:
    pickle.dump(payload, fout)