In [1]:
import anndata as ad
import cupy as cp
import harmonypy as hm
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import yaml

from functools import partial
from scipy.io import mmread, mmwrite
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from tqdm.auto import tqdm

In [2]:
# Experiment config. The default is the original author's local path; set the
# HYPERBOLIC_CANCER_CONFIG environment variable to point the notebook at another
# config (or another machine) without editing this cell.
config_path = os.environ.get(
    "HYPERBOLIC_CANCER_CONFIG",
    "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_tsne2.yaml",
)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

In [3]:
def _encode_batch_columns(batch_array, max_categories):
    """Label-encode each batch column and drop columns with too many categories.

    Args:
        batch_array: 2-D integer array, one row per cell and one column per
            batch covariate.
        max_categories: columns with more distinct values than this are dropped.

    Returns:
        ``(encoded, kept)``. ``encoded`` has one column per retained covariate,
        with values remapped to ``0..k-1`` in sorted order of the original values.
        ``kept`` lists the original column indices that were retained.
    """
    encoded_columns = []
    kept = []
    for col in range(batch_array.shape[1]):
        uniques, codes = np.unique(batch_array[:, col], return_inverse=True)
        if len(uniques) <= max_categories:
            encoded_columns.append(codes)
            kept.append(col)
    if encoded_columns:
        encoded = np.stack(encoded_columns, axis=1).astype(batch_array.dtype, copy=False)
    else:
        encoded = np.empty((batch_array.shape[0], 0), dtype=batch_array.dtype)
    return encoded, kept


def load_data(config):
    """Build an AnnData object (cells x genes) with categorical batch covariates.

    Reads ``config["data"]``:
      * ``mtx_path``: Matrix Market file whose rows are genes and whose columns
        are cells (it is transposed below so AnnData rows are cells).
      * ``batch_paths``: tab-separated files without headers, one row per cell.
        All of their integer columns are concatenated side by side.
      * ``max_batch_categories`` (optional, default 1000): batch columns with
        more unique values than this are dropped, because Harmony is slow with
        many high-cardinality categories.

    Returns:
        ``(adata, batch_names, n_cells, n_genes)``, where ``batch_names`` are the
        ``adata.obs`` column names (``batch_0``, ``batch_1``, ...) of the retained
        batch covariates, numbered after filtering.

    Raises:
        ValueError: if no batch files are configured, or if the batch files'
            row count does not match the number of cells.
    """
    data_config = config["data"]
    mtx_path = data_config["mtx_path"]
    batch_paths = data_config["batch_paths"]
    max_categories = data_config.get("max_batch_categories", 1000)

    if not batch_paths:
        raise ValueError("config['data']['batch_paths'] must list at least one batch file")

    # Transpose *before* converting to CSR. `.tocsr().T` would produce a CSC
    # matrix, which makes scanpy's per-cell (row) operations slow.
    X = mmread(mtx_path).T.tocsr()
    adata = ad.AnnData(X=X)

    n_cells, n_genes = X.shape
    adata.var_names = [f"gene_{i}" for i in range(n_genes)]
    adata.obs_names = [f"cell_{i}" for i in range(n_cells)]

    batch_array = np.concatenate(
        [pd.read_csv(path, sep="\t", header=None).values.astype(int) for path in batch_paths],
        axis=1,
    )
    if batch_array.shape[0] != n_cells:
        raise ValueError(
            f"Batch files have {batch_array.shape[0]} rows but the matrix has {n_cells} cells"
        )

    n_columns = batch_array.shape[1]
    batch_array, kept = _encode_batch_columns(batch_array, max_categories)
    dropped = sorted(set(range(n_columns)) - set(kept))
    if dropped:
        import warnings

        warnings.warn(
            f"Dropped batch column(s) {dropped} (0-based, across all batch files): "
            f"more than {max_categories} unique values"
        )

    batch_names = [f"batch_{i}" for i in range(batch_array.shape[1])]
    for i, name in enumerate(batch_names):
        adata.obs[name] = pd.Categorical(batch_array[:, i])

    return adata, batch_names, n_cells, n_genes


def preprocess(adata, config):
    """Apply the optional preprocessing steps listed in ``config["model"]["preprocess"]``.

    Steps run in this fixed order, each gated by a config key:
      * ``normalize``: ``sc.pp.normalize_total`` to ``target_sum`` counts per
        cell (optional key ``target_sum``, default 1e4).
      * ``log1p``: ``sc.pp.log1p``.
      * ``top_genes``: keep only the N highly variable genes found by
        ``sc.pp.highly_variable_genes``. The optional ``hvg_flavor`` key is
        forwarded as ``flavor``; scanpy's default flavor is used when it is absent.
      * ``scale``: ``sc.pp.scale`` without zero-centering.

    Note: the normalize/log1p/scale steps modify the AnnData they receive in
    place, while the HVG step replaces it with a subset copy. Always use the
    returned object.
    """
    preprocess_config = config["model"]["preprocess"]
    if preprocess_config.get("normalize", False):
        target_sum = preprocess_config.get("target_sum", 1e4)
        sc.pp.normalize_total(adata, target_sum=target_sum, inplace=True)
    if preprocess_config.get("log1p", False):
        sc.pp.log1p(adata)
    n_top_genes = preprocess_config.get("top_genes", None)
    if n_top_genes:
        hvg_kwargs = {}
        # NOTE: scanpy's default HVG flavor expects log-normalized input; set
        # `hvg_flavor` (e.g. for raw counts) when log1p is disabled.
        if "hvg_flavor" in preprocess_config:
            hvg_kwargs["flavor"] = preprocess_config["hvg_flavor"]
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, **hvg_kwargs)
        adata = adata[:, adata.var.highly_variable].copy()
    if preprocess_config.get("scale", False):
        sc.pp.scale(adata, zero_center=False)
    return adata


def train(adata, batch_names, config):
    """Run PCA, Harmony batch correction and t-SNE on ``adata``, then save it.

    Args:
        adata: preprocessed AnnData (rows are cells).
        batch_names: ``adata.obs`` columns that Harmony corrects for.
        config: dict with ``model.pca`` / ``model.harmony`` / ``model.tsne``
            keyword arguments, plus ``output_dir`` and ``experiment``, which
            name the saved file ``<output_dir>/<experiment>.h5ad``.

    Returns:
        The same AnnData, with ``obsm["X_pca"]`` and ``obsm["X_pca_harmony"]``
        set and the t-SNE computed on the Harmony-corrected representation.

    Raises:
        ValueError: if ``batch_names`` is empty (nothing for Harmony to correct).
    """
    if not batch_names:
        raise ValueError("train() needs at least one batch covariate for Harmony")

    # Resolve and create the output location before the expensive steps, so a
    # missing directory or config key fails immediately instead of after t-SNE.
    output_dir = config["output_dir"]
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f"{config['experiment']}.h5ad")

    sc.tl.pca(adata, **config["model"]["pca"])

    ho = hm.run_harmony(
        adata.obsm["X_pca"],
        adata.obs,
        vars_use=batch_names,
        **config["model"]["harmony"]
    )
    # Harmony's corrected embedding has components along rows; transpose so
    # rows line up with cells, as AnnData's obsm requires.
    adata.obsm["X_pca_harmony"] = ho.Z_corr.T

    sc.tl.tsne(adata, use_rep="X_pca_harmony", **config["model"]["tsne"])

    adata.write_h5ad(save_path)

    return adata

In [None]:
adata, batch_names, n_cells, n_genes = load_data(config)
# Sanity check: number of distinct categories in each retained batch covariate.
batch_cardinalities = [len(np.unique(adata.obs[name].values)) for name in batch_names]
print(batch_names, n_cells, n_genes, batch_cardinalities)

# Preprocess, then PCA -> Harmony -> t-SNE (the result is also written to disk).
adata = train(preprocess(adata, config), batch_names, config)