In [1]:
import anndata as ad
import cupy as cp
import harmonypy as hm
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
import yaml

from functools import partial
from scipy.io import mmread
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

In [2]:
# Experiment config (YAML). Set the HYPERBOLIC_CANCER_CONFIG environment
# variable to point this notebook at another machine/experiment without
# editing it; the default is the original author's local path.
config_path = os.environ.get(
    "HYPERBOLIC_CANCER_CONFIG",
    "/home/romainlhardy/code/hyperbolic-cancer/configs/lung/lung_umap25.yaml",
)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

In [3]:
def load_data(config):
    """Load the count matrix and per-cell batch labels into an AnnData object.

    Parameters
    ----------
    config : dict
        Parsed experiment config. Reads ``config["data"]["mtx_path"]`` (a
        Matrix Market file, transposed on load so that rows are cells) and
        ``config["data"]["batch_paths"]`` (tab-separated, header-less files;
        each column is one batch covariate, one row per cell).

    Returns
    -------
    adata : anndata.AnnData
        Cells x genes matrix (CSR), with synthetic ``cell_i`` / ``gene_i``
        names and one categorical ``obs`` column per batch covariate.
    batch_names : list of str
        Names of the batch columns added to ``adata.obs`` (``batch_0``, ...).
    n_cells, n_genes : int
        Shape of the loaded matrix.

    Raises
    ------
    ValueError
        If no batch files are configured, or a batch file's row count does
        not match the number of cells.
    """
    mtx_path = config["data"]["mtx_path"]
    batch_paths = config["data"]["batch_paths"]
    if not batch_paths:
        raise ValueError("config['data']['batch_paths'] must list at least one batch file")

    # The file is stored genes x cells. Transpose the COO result first and
    # convert afterwards so the cells x genes matrix is CSR; the previous
    # `.tocsr().T` silently produced a CSC matrix instead.
    mtx = mmread(mtx_path).T.tocsr()
    adata = ad.AnnData(X=mtx)

    n_cells, n_genes = mtx.shape
    adata.var_names = [f"gene_{i}" for i in range(n_genes)]
    adata.obs_names = [f"cell_{i}" for i in range(n_cells)]

    batch_indices = []
    for path in batch_paths:
        values = pd.read_csv(path, sep="\t", header=None).values
        # Rows must align 1:1 with cells; fail with a clear message instead of
        # an opaque concatenate/assignment error further down.
        if values.shape[0] != n_cells:
            raise ValueError(
                f"Batch file {path!r} has {values.shape[0]} rows, "
                f"expected {n_cells} (one per cell)"
            )
        batch_indices.append(values)
    batch_array = np.concatenate(batch_indices, axis=1)

    batch_names = [f"batch_{i}" for i in range(batch_array.shape[1])]
    for i, name in enumerate(batch_names):
        # Batch labels are discrete identifiers, so store them as categorical.
        adata.obs[name] = batch_array[:, i]
        adata.obs[name] = adata.obs[name].astype("category")

    return adata, batch_names, n_cells, n_genes


def preprocess(adata, config):
    """Run the preprocessing steps enabled in ``config["model"]["preprocess"]``.

    Steps run in a fixed order, each only when its config entry is truthy:
    total-count normalization (``normalize``), log1p (``log1p``),
    highly-variable-gene selection keeping ``top_genes`` genes, and scaling
    (``scale``). A missing (or null) section or key disables that step.

    Note: the scanpy calls operate in place, so normalization/log1p modify the
    caller's ``adata``. HVG selection replaces it with a gene-subset copy, and
    later steps then no longer touch the caller's object. Always use the
    return value.

    Returns
    -------
    anndata.AnnData
        The preprocessed object (the input itself, or an HVG-subset copy).
    """
    preprocess_config = config["model"].get("preprocess") or {}
    if preprocess_config.get("normalize"):
        sc.pp.normalize_total(adata)
    if preprocess_config.get("log1p"):
        sc.pp.log1p(adata)
    top_genes = preprocess_config.get("top_genes")
    if top_genes:
        # Annotates adata.var["highly_variable"], which is used to subset below.
        sc.pp.highly_variable_genes(adata, n_top_genes=top_genes)
        # .copy() materializes the view into a standalone AnnData.
        adata = adata[:, adata.var.highly_variable].copy()
    if preprocess_config.get("scale"):
        sc.pp.scale(adata)
    return adata


def train(adata, batch_names, config):
    """Compute a batch-corrected embedding and save the result as ``.h5ad``.

    Pipeline: PCA -> Harmony correction of the PCA coordinates over
    ``batch_names`` -> kNN graph on the corrected coordinates -> UMAP.

    Parameters
    ----------
    adata : anndata.AnnData
        Preprocessed data. Modified in place: the scanpy calls store their
        results on it, and the Harmony output is added as
        ``obsm["X_pca_harmony"]``.
    batch_names : list of str
        Columns of ``adata.obs`` that Harmony corrects for.
    config : dict
        ``config["model"]`` may hold keyword-argument dicts under ``"pca"``,
        ``"harmony"``, ``"neighbors"`` and ``"umap"``; missing or null
        entries fall back to the library defaults. ``config["output_dir"]``
        and ``config["experiment"]`` determine the output path.

    Returns
    -------
    anndata.AnnData
        The same ``adata`` object, after it has been written to
        ``<output_dir>/<experiment>.h5ad``.
    """
    model_config = config["model"]

    # Resolve the output location and create it up front, so a missing key or
    # directory fails fast instead of after the slow Harmony run.
    output_dir = config["output_dir"]
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, f"{config['experiment']}.h5ad")

    sc.tl.pca(adata, **(model_config.get("pca") or {}))

    ho = hm.run_harmony(
        adata.obsm["X_pca"],
        adata.obs,
        vars_use=batch_names,
        **(model_config.get("harmony") or {}),
    )
    # Z_corr is laid out components x cells; obsm needs one row per cell.
    adata.obsm["X_pca_harmony"] = ho.Z_corr.T

    sc.pp.neighbors(adata, use_rep="X_pca_harmony", **(model_config.get("neighbors") or {}))

    sc.tl.umap(adata, **(model_config.get("umap") or {}))

    adata.write_h5ad(save_path)

    return adata

In [4]:
# Load counts + batch labels, apply the configured preprocessing, then embed
# (PCA -> Harmony -> neighbors -> UMAP); `train` also writes the .h5ad output.
adata, batch_names, n_cells, n_genes = load_data(config)
adata = train(preprocess(adata, config), batch_names, config)

2025-04-21 19:59:39,461 - harmonypy - INFO - Computing initial centroids with sklearn.KMeans...
2025-04-21 20:03:08,051 - harmonypy - INFO - sklearn.KMeans initialization complete.
2025-04-21 20:03:08,353 - harmonypy - INFO - Iteration 1 of 15
2025-04-21 20:03:48,025 - harmonypy - INFO - Iteration 2 of 15
2025-04-21 20:04:26,833 - harmonypy - INFO - Iteration 3 of 15
2025-04-21 20:05:07,981 - harmonypy - INFO - Iteration 4 of 15
2025-04-21 20:05:47,993 - harmonypy - INFO - Iteration 5 of 15
2025-04-21 20:06:27,452 - harmonypy - INFO - Iteration 6 of 15
2025-04-21 20:07:07,684 - harmonypy - INFO - Iteration 7 of 15
2025-04-21 20:07:46,558 - harmonypy - INFO - Iteration 8 of 15
2025-04-21 20:08:13,673 - harmonypy - INFO - Iteration 9 of 15
2025-04-21 20:08:33,203 - harmonypy - INFO - Iteration 10 of 15
2025-04-21 20:08:54,101 - harmonypy - INFO - Iteration 11 of 15
2025-04-21 20:09:13,768 - harmonypy - INFO - Iteration 12 of 15
2025-04-21 20:09:34,332 - harmonypy - INFO - Iteration 13 of