In [None]:
import scanpy as sc
import os
import pandas as pd
import numpy as np

import os, sys
from pathlib import Path

script_dir = Path().resolve()
repo_dir = script_dir
src_dir = repo_dir / 'src'
data_dir = repo_dir / 'data'
sys.path.append(str(src_dir))
sys.path.append(str(data_dir))

from dataset.dataloader import AnnDataDataset

PARAMETERS = {
    'hvgs': 20116,
    'num_genes': 20116,
    # 'hvgs': 5000,
    # 'num_genes': 5000,
    'latent_dimension': 50,
    'target_sum': 10000,
    'batch_size': 128,
    'num_epochs': 1,
}

tm_droplet_data = sc.read(
    r'./data/raw/tabula_muris/TM_droplet.h5ad',
    # backup_url="https://figshare.com/ndownloader/files/23938934",
)
tm_facs_data = sc.read(
    r'./data/raw/tabula_muris/TM_facs.h5ad',
    # backup_url="https://figshare.com/ndownloader/files/23939711",
)

In [2]:
tm_droplet_data_tissues = tm_droplet_data.obs.tissue.tolist()
tm_droplet_data_tissues = {t for t in tm_droplet_data_tissues}
tm_droplet_data_tissues
print(f'{tm_droplet_data_tissues=}')
print(f'{len(tm_droplet_data_tissues)=}')

tm_facs_data_tissues = tm_facs_data.obs.tissue.tolist()
tm_facs_data_tissues = {t for t in tm_facs_data_tissues}
tm_facs_data_tissues
print(f'{tm_facs_data_tissues=}')
print(f'{len(tm_facs_data_tissues)=}')

tm_all_tissues = tm_droplet_data_tissues.union(tm_facs_data_tissues)
# tm_all_tissues
print(f'{len(tm_all_tissues)=}')

# train_tissues = tm_shared_tissues[:-4]
# test_tissues = tm_shared_tissues[-4:]

# print(f'{train_tissues=}')
# print(f'{test_tissues=}')

# train_tissues=['Large_Intestine', 'Spleen', 'Mammary_Gland', 'Lung', 'Kidney', 'Thymus', 'Bladder', 'Tongue', 'Marrow', 'Trachea']
test_tissues={'Skin', 'Liver', 'Limb_Muscle', 'Pancreas'}
train_tissues = tm_all_tissues.difference(test_tissues) # v3,5
print(train_tissues)
print(test_tissues)

tm_droplet_data_tissues={'Limb_Muscle', 'Heart_and_Aorta', 'Fat', 'Kidney', 'Bladder', 'Tongue', 'Liver', 'Mammary_Gland', 'Skin', 'Pancreas', 'Thymus', 'Trachea', 'Marrow', 'Large_Intestine', 'Spleen', 'Lung'}
len(tm_droplet_data_tissues)=16
tm_facs_data_tissues={'Limb_Muscle', 'Kidney', 'SCAT', 'Skin', 'Large_Intestine', 'Diaphragm', 'Brain_Non-Myeloid', 'Brain_Myeloid', 'Marrow', 'Lung', 'Liver', 'Pancreas', 'Thymus', 'Heart', 'Trachea', 'Spleen', 'Aorta', 'BAT', 'Bladder', 'Tongue', 'Mammary_Gland', 'GAT', 'MAT'}
len(tm_facs_data_tissues)=23
len(tm_all_tissues)=25
{'Kidney', 'SCAT', 'Thymus', 'Heart', 'Trachea', 'Large_Intestine', 'Spleen', 'Aorta', 'Heart_and_Aorta', 'BAT', 'Fat', 'Bladder', 'Diaphragm', 'Tongue', 'Brain_Non-Myeloid', 'Mammary_Gland', 'Brain_Myeloid', 'Marrow', 'MAT', 'GAT', 'Lung'}
{'Limb_Muscle', 'Pancreas', 'Liver', 'Skin'}


In [3]:
tm_droplet_data = tm_droplet_data[
    (~tm_droplet_data.obs.cell_ontology_class.isna())
].copy()
tm_facs_data = tm_facs_data[
    (~tm_facs_data.obs.cell_ontology_class.isna())
].copy()

In [4]:
gene_len = pd.read_csv(
    "https://raw.githubusercontent.com/chenlingantelope/HarmonizationSCANVI/master/data/gene_len.txt",
    delimiter=" ",
    header=None,
    index_col=0,
)
gene_len.head()

Unnamed: 0_level_0,1
0,Unnamed: 1_level_1
0610007C21Rik,94.571429
0610007L01Rik,156.0
0610007P08Rik,202.272727
0610007P14Rik,104.0
0610007P22Rik,158.75


In [5]:
import numpy as np
from scipy import sparse

gene_len = gene_len.reindex(tm_facs_data.var.index).dropna()

tm_facs_data = tm_facs_data[:, gene_len.index].copy()   # break the view

gene_len_vec = gene_len[1].values.astype(np.float32)
median_len  = np.median(gene_len_vec)

# column‑wise scaling in CSC format
X = tm_facs_data.X.tocsc(copy=True)        # -> (n_cells × n_genes)
X = X.multiply(1.0 / gene_len_vec)         # divide each column by its length
X = X.multiply(median_len)                 # multiply by the median length
X.data = np.rint(X.data)                   # round only the non‑zero entries

tm_facs_data.X = X.tocsr()                 # store back as CSR (Scanpy’s default)


In [6]:
tm_droplet_train = tm_droplet_data[
    (tm_droplet_data.obs['tissue'].isin(train_tissues))  
    & (~tm_droplet_data.obs.cell_ontology_class.isna())
].copy()

tm_facs_train = tm_facs_data[
    (tm_facs_data.obs['tissue'].isin(train_tissues))  
    & (~tm_facs_data.obs.cell_ontology_class.isna())
].copy()

tm_droplet_train.obs["tech"] = "10x"
tm_facs_train.obs["tech"] = "SS2"
tm_adata_train = tm_droplet_train.concatenate(tm_facs_train)

  tm_adata_train = tm_droplet_train.concatenate(tm_facs_train)


In [7]:
tm_droplet_test = tm_droplet_data[
    (tm_droplet_data.obs['tissue'].isin(test_tissues))  
    & (~tm_droplet_data.obs.cell_ontology_class.isna())
].copy()

tm_facs_test = tm_facs_data[
    (tm_facs_data.obs['tissue'].isin(test_tissues))  
    & (~tm_facs_data.obs.cell_ontology_class.isna())
].copy()

tm_droplet_test.obs["tech"] = "10x"
tm_facs_test.obs["tech"] = "SS2"
tm_adata_test = tm_droplet_test.concatenate(tm_facs_test)

  tm_adata_test = tm_droplet_test.concatenate(tm_facs_test)


In [8]:
print(f'{len(tm_adata_train)=}')
print(f'{len(tm_adata_test)=}')

len(tm_adata_train)=294439
len(tm_adata_test)=61774


In [9]:
sc.pp.normalize_total(tm_adata_train, target_sum=1e4)
sc.pp.log1p(tm_adata_train)
sc.pp.highly_variable_genes(
    tm_adata_train,
    batch_key="tech",
)

tm_adata_train.X = np.nan_to_num(tm_adata_train.X, nan=0)

num_genes = len(tm_adata_train.var.index)
PARAMETERS['hvgs'] = num_genes

hvg_genes = tm_adata_train.var.index[tm_adata_train.var['highly_variable']].tolist()

# tm_adata_train = tm_adata_train[:, tm_adata_train.var.index.isin(hvg_genes)]

In [10]:
sc.pp.normalize_total(tm_adata_test, target_sum=1e4)
sc.pp.log1p(tm_adata_test)

tm_adata_test.X = np.nan_to_num(tm_adata_test.X, nan=0)

# tm_adata_test = tm_adata_test[:, tm_adata_test.var.index.isin(hvg_genes)]

In [11]:
tm_adata_train.obs.rename(columns={'cell_ontology_class': 'Celltype'}, inplace=True)
tm_adata_test.obs.rename(columns={'cell_ontology_class': 'Celltype'}, inplace=True)
tm_adata_test

AnnData object with n_obs × n_vars = 61774 × 18244
    obs: 'age', 'cell', 'Celltype', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'n_genes', 'sex', 'subtissue', 'tissue', 'tissue_free_annotation', 'tech', 'FACS.selection', 'n_counts', 'batch'
    var: 'n_cells-0', 'n_cells-1'
    uns: 'log1p'

In [12]:
celltype_techs = tm_adata_train.obs.groupby("Celltype")["tech"].unique()

# 2) Build a dictionary mapping each cell type to "only_10x", "only_SS2", or "both"
celltype_status = {}
for celltype, tech_list in celltype_techs.items():
    tech_set = set(tech_list)
    if len(tech_set) == 1:
        if "10x" in tech_set:
            celltype_status[celltype] = "only_10x"
        else:
            celltype_status[celltype] = "only_SS2"
    else:
        celltype_status[celltype] = "both"

# 3) Create a new column in .obs indicating whether a cell's type is only_10x, only_SS2, or both
tm_adata_train.obs["celltype_tech_availability"] = (
    tm_adata_train.obs["Celltype"].map(celltype_status)
)

  celltype_techs = tm_adata_train.obs.groupby("Celltype")["tech"].unique()


In [13]:
celltype_techs = tm_adata_test.obs.groupby("Celltype")["tech"].unique()

# 2) Build a dictionary mapping each cell type to "only_10x", "only_SS2", or "both"
celltype_status = {}
for celltype, tech_list in celltype_techs.items():
    tech_set = set(tech_list)
    if len(tech_set) == 1:
        if "10x" in tech_set:
            celltype_status[celltype] = "only_10x"
        else:
            celltype_status[celltype] = "only_SS2"
    else:
        celltype_status[celltype] = "both"

# 3) Create a new column in .obs indicating whether a cell's type is only_10x, only_SS2, or both
tm_adata_test.obs["celltype_tech_availability"] = (
    tm_adata_test.obs["Celltype"].map(celltype_status)
)

  celltype_techs = tm_adata_test.obs.groupby("Celltype")["tech"].unique()


In [14]:
tm_adata_test.obs['Celltype'].replace(
    to_replace='pancreatic ductal cel',
    value='pancreatic ductal cell',
    inplace=True
)

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  tm_adata_test.obs['Celltype'].replace(
  tm_adata_test.obs['Celltype'].replace(


In [15]:
tm_adata_test.obs['celltype_tech_availability']

index
AAACCTGAGATGTCGG-1-9-0-0-0           both
AAAGCAATCGGAAATA-1-9-0-0-0           both
AAAGTAGAGGCCCTTG-1-9-0-0-0           both
AACCGCGAGAAACCGC-1-9-0-0-0           both
AACTCCCAGTTGTCGT-1-9-0-0-0           both
                                   ...   
P9.MAA000907.3_11_M.1.1-1-1-1        both
P9.MAA000927.3_9_M.1.1-1-1-1     only_SS2
P9.MAA000938.3_8_M.1.1-1-1-1     only_SS2
P9.MAA001857.3_38_F.1.1-1-1-1        both
P9.MAA001861.3_39_F.1.1-1-1-1        both
Name: celltype_tech_availability, Length: 61774, dtype: object

In [None]:
import os, sys
from pathlib import Path

script_dir = Path().resolve()
repo_dir = script_dir
src_dir = repo_dir / 'src'
data_dir = repo_dir / 'data'
sys.path.append(str(src_dir))
sys.path.append(str(data_dir))

from dataset.dataloader import AnnDataDataset


In [None]:
tm_dataset = AnnDataDataset(tm_adata_train)
tm_dataloader = DataLoader(tm_dataset, batch_size=PARAMETERS['batch_size'], shuffle=True)

In [None]:
with open(r'./data/pickled/tabula_muris/tm_dataset_train_tissues_length_normalized_v3,5.pkl', 'wb') as f: # NOTE: 3,5 because apparently v3 already has both sexes
    pickle.dump(tm_dataset, f)

with open(r'./data/pickled/tabula_muris/tm_dataloader_train_tissues_length_normalized_v3,5.pkl', 'wb') as f:
    pickle.dump(tm_dataloader, f)

with open(r'./data/pickled/tabula_muris/tm_adata_train_length_normalized_v3,5.pkl', 'wb') as f:
    pickle.dump(tm_adata_train, f)

# with open(r'./data/pickled/tabula_muris/tm_adata_test_v3,5.pkl', 'wb') as f: # NOTE: v3 test already has both sex test tissues
#     pickle.dump(tm_adata_test, f)

In [None]:
# Define functions to precompute data-dependent variables
def precompute_gene_clusters(dataset):
    most_significant_genes_dict = dataset.most_significant_genes_dict
    least_significant_genes_dict = dataset.least_significant_genes_dict
    gene_networks = dataset.gene_networks
    cell_type_categories = dataset.cell_type_categories
    code_to_celltype = dataset.code_to_celltype
    celltype_to_code = dataset.celltype_to_code
    gene_names = dataset.gene_names
    gene_name_to_index = dataset.gene_name_to_index
    index_to_gene_name = dataset.index_to_gene_name
    gene_dispersions = dataset.gene_dispersions
    print('Precomputed gene clusters!')
    return (most_significant_genes_dict, least_significant_genes_dict,
            gene_networks, gene_names, code_to_celltype, celltype_to_code,
            gene_name_to_index, index_to_gene_name, gene_dispersions)

def precompute_mu_sigma(dataloader, most_significant_genes_dict, least_significant_genes_dict, gene_name_to_index):
    all_expression_matrix = []
    cell_types_data = {}
    cell_types_msg_data = {}
    cell_types_lsg_data = {}
    for batch in dataloader:
        expression_matrix, cell_types = batch
        all_expression_matrix.append(expression_matrix)
        
        for cell_type in torch.unique(cell_types):
            cell_type = int(cell_type)
            cell_type_mask = cell_types == cell_type
            cell_type_expression_matrix = expression_matrix[cell_type_mask]
            # All genes
            if cell_type not in cell_types_data:
                cell_types_data[cell_type] = []
            cell_types_data[cell_type].append(cell_type_expression_matrix)

            # Most significant genes
            msg_genes = most_significant_genes_dict[cell_type]
            msg_gene_indices = [gene_name_to_index[g] for g in msg_genes]
            msg_significant_gene_matrix = cell_type_expression_matrix[:, msg_gene_indices]
            if cell_type not in cell_types_msg_data:
                cell_types_msg_data[cell_type] = []
            cell_types_msg_data[cell_type].append(msg_significant_gene_matrix)
            
            # Least significant genes
            lsg_genes = least_significant_genes_dict[cell_type]
            lsg_gene_indices = [gene_name_to_index[g] for g in lsg_genes]
            lsg_significant_gene_matrix = cell_type_expression_matrix[:, lsg_gene_indices]
            if cell_type not in cell_types_lsg_data:
                cell_types_lsg_data[cell_type] = []
            cell_types_lsg_data[cell_type].append(lsg_significant_gene_matrix)

    cell_type_mu_sigma = {}
    cell_type_msg_mu_sigma = {}
    cell_type_lsg_mu_sigma = {}
    # All genes
    for cell_type, cell_type_expression_matrix in cell_types_data.items():
        data_tensor = torch.cat(cell_type_expression_matrix, dim=0)
        mu = torch.mean(data_tensor, dim=0)
        sigma = torch.std(data_tensor, dim=0, unbiased=False)
        sigma = torch.clamp(sigma, min=1e-8)
        cell_type_mu_sigma[int(cell_type)] = (mu, sigma)
    
    # Most significant genes
    for cell_type, matrices in cell_types_msg_data.items():
        data_tensor = torch.cat(matrices, dim=0)
        mu = torch.mean(data_tensor, dim=0)
        sigma = torch.std(data_tensor, dim=0, unbiased=False)
        sigma = torch.clamp(sigma, min=1e-8)
        dispersion = sigma**2 / mu
        cell_type_msg_mu_sigma[int(cell_type)] = (mu, sigma, dispersion)
    
    # Least significant genes
    for cell_type, matrices in cell_types_lsg_data.items():
        data_tensor = torch.cat(matrices, dim=0)
        mu = torch.mean(data_tensor, dim=0)
        sigma = torch.std(data_tensor, dim=0, unbiased=False)
        sigma = torch.clamp(sigma, min=1e-8)
        dispersion = sigma**2 / mu
        cell_type_lsg_mu_sigma[int(cell_type)] = (mu, sigma, dispersion)

    data_tensor = torch.cat(all_expression_matrix, dim=0)
    global_mu_sigma = (torch.mean(data_tensor, dim=0),
                       torch.std(data_tensor, dim=0, unbiased=False))

    return cell_type_mu_sigma, global_mu_sigma, cell_type_msg_mu_sigma, cell_type_lsg_mu_sigma

# Precompute data-dependent variables before model initialization
(most_significant_genes_dict, least_significant_genes_dict,
 gene_networks, gene_names, code_to_celltype, celltype_to_code,
 gene_name_to_index, index_to_gene_name, gene_dispersions) = precompute_gene_clusters(tm_dataset)

cell_type_mu_sigma, global_mu_sigma, cell_type_msg_mu_sigma, cell_type_lsg_mu_sigma = precompute_mu_sigma(
    tm_dataloader, most_significant_genes_dict, least_significant_genes_dict, gene_name_to_index)

In [None]:
precomputed_dir = data_dir / 'pickled' / 'tabula_muris' / 'precomputed'
precomputed_dir.mkdir(parents=True, exist_ok=True)

precomputed_gene_clusters_path =  precomputed_dir / 'tm_dataset_train_tissues_length_normalized_v3,5_precomputed_gene_clusters.pkl'
with open(precomputed_gene_clusters_path, 'wb') as f:
    pickle.dump(
        {
            "most_significant_genes_dict": most_significant_genes_dict,
            "least_significant_genes_dict": least_significant_genes_dict,
            "gene_networks": gene_networks,
            "gene_names": gene_names,
            "code_to_celltype": code_to_celltype,
            "celltype_to_code": celltype_to_code,
            "gene_name_to_index": gene_name_to_index,
            "index_to_gene_name": index_to_gene_name,
            "gene_dispersions": gene_dispersions,
        },
        f,
    )

precomputed_mu_sigma_path = precomputed_dir / 'tm_dataset_train_tissues_length_normalized_v3,5_precomputed_mu_sigma.pkl'
with open(precomputed_mu_sigma_path, "wb") as f:
    pickle.dump(
        {
            "cell_type_mu_sigma": cell_type_mu_sigma,
            "global_mu_sigma": global_mu_sigma,
            "cell_type_msg_mu_sigma": cell_type_msg_mu_sigma,
            "cell_type_lsg_mu_sigma": cell_type_lsg_mu_sigma,
        },
        f,
    )