In [4]:
"""
Inspect an .h5ad file to understand its structure for QC ingestion.

Usage:
  conda run -n venv python scripts/inspect_h5ad.py \
      --h5ad data/raw/weinreb/stateFate_inVitro/stateFate_inVitro_normed_counts.h5ad
"""

import argparse
import numpy as np
import scanpy as sc
import pandas as pd
from pandas.api.types import is_numeric_dtype, is_categorical_dtype


def inspect_h5ad(path: str) -> None:
    print(f"Loading {path}")
    ad = sc.read_h5ad(path, backed=None)

    print("\n=== AnnData overview ===")
    print(ad)
    print("shape (n_cells, n_genes):", ad.shape)
    print("X class:", type(ad.X))

    # ---------- OBS (cell metadata) ----------
    print("\n=== OBS (cell metadata) ===")
    print("obs columns:", list(ad.obs.columns))
    print("\nobs.head():")
    print(ad.obs.head())

    print("\n[obs summary by column]")
    for col in ad.obs.columns:
        s = ad.obs[col]
        nunique = s.nunique()
        print(f"\n---- {col} ----")
        print("dtype:", s.dtype)
        print("n_unique:", nunique)

        if nunique <= 20:
            # categorical-ish: show value counts
            print("value_counts():")
            print(s.value_counts().head(20))
        else:
            # high-cardinality: branch by dtype
            if is_numeric_dtype(s):
                print(
                    "min/mean/max:",
                    float(s.min()),
                    float(s.mean()),
                    float(s.max()),
                )
            elif is_categorical_dtype(s):
                cats = s.cat.categories
                print(f"categorical with {len(cats)} categories")
                print("categories (first 20):", list(cats[:20]))
            else:
                print("example values:", s.iloc[:10].tolist())

    # Highlight likely-important columns for QC / modeling if present
    interesting_obs = [
        "timepoint",
        "day",
        "treatment",
        "condition",
        "sample",
        "batch",
        "clone",
        "clone_id",
        "lineage",
        "cell_type",
        "state",
    ]
    print("\n=== Selected interesting obs columns (if present) ===")
    for col in interesting_obs:
        if col in ad.obs:
            print(f"\n---- {col} ----")
            s = ad.obs[col]
            print("dtype:", s.dtype)
            print("n_unique:", s.nunique())
            print(s.value_counts().head(20))

    # ---------- VAR (gene metadata) ----------
    print("\n=== VAR (gene metadata) ===")
    print("var columns:", list(ad.var.columns))
    print("\nvar.head():")
    print(ad.var.head())
    print("\nvar_names (first 10):")
    print(ad.var_names[:10].tolist())


In [5]:
inspect_h5ad("./data/prep/qc.h5ad")

Loading ./data/prep/qc.h5ad

=== AnnData overview ===
AnnData object with n_obs × n_vars = 130881 × 2000
    obs: 'Library', 'Cell barcode', 'Time point', 'Starting population', 'Cell type annotation', 'Well', 'SPRING-x', 'SPRING-y', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'n_genes', '_scvi_batch', '_scvi_labels'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'diffmap_evals', 'eggfm_meta', 'hvg', 'log1p', 'neighbors', 'pca'
    obsm: 'X_clone_membership', 'X_dcolpca', 'X_diff_dcol', 'X_diff_eggfm', 'X_diff_pca', 'X_diff_pca_double', 'X_diff_pca_x2', 'X_diffmap', 'X_eggfm', 'X_pca', 'X_phate', 'X_scvi'
  

  utils.warn_names_duplicates("obs")
  elif is_categorical_dtype(s):


In [None]:
import numpy as np

# X_geo: (n, d)
# eigvals: (n, d)
# eigvecs: (n, d, d)  # eigvecs[i] is a (d, d) matrix, columns = eigenvectors
# knn_indices: (n, k)  # neighbors for each i

def get_tangent_normal_bases(eigvecs: np.ndarray,
                             eigvals: np.ndarray,
                             tangent_dim: int):
    """
    Return per-point tangent and normal bases based on smallest / largest eigenvalues.
    """
    n, d, _ = eigvecs.shape
    # Sort eigenvalues and vectors from small -> large if not already
    sort_idx = np.argsort(eigvals, axis=1)  # (n, d)
    U_sorted = np.zeros_like(eigvecs)
    lam_sorted = np.zeros_like(eigvals)
    for i in range(n):
        U_sorted[i] = eigvecs[i, :, sort_idx[i]]   # (d, d)
        lam_sorted[i] = eigvals[i, sort_idx[i]]    # (d,)

    # Tangent = first r eigenvectors, Normal = remaining
    T = U_sorted[:, :, :tangent_dim]        # (n, d, r)
    N = U_sorted[:, :, tangent_dim:]        # (n, d, d-r)
    return T, N, lam_sorted, U_sorted


In [None]:
def compute_edge_tangent_normal_ratios(X_geo, knn_indices,
                                       T, N, eps=1e-8):
    """
    X_geo: (n, d)
    knn_indices: (n, k)
    T: (n, d, r)
    N: (n, d, d-r)
    Returns:
        t_norms: (n, k) tangent component norms
        n_norms: (n, k) normal component norms
        ratios: (n, k) = n_norm / (t_norm + eps)
    """
    n, d = X_geo.shape
    k = knn_indices.shape[1]
    r = T.shape[-1]

    t_norms = np.zeros((n, k))
    n_norms = np.zeros((n, k))

    for i in range(n):
        xi = X_geo[i]
        Ti = T[i]   # (d, r)
        Ni = N[i]   # (d, d-r)
        for l in range(k):
            j = knn_indices[i, l]
            v = X_geo[j] - xi  # (d,)

            # Tangent projection: T_i T_i^T v
            v_tan = Ti @ (Ti.T @ v)   # (d,)
            # Normal projection: N_i N_i^T v
            v_nor = Ni @ (Ni.T @ v)   # (d,)

            t_norms[i, l] = np.linalg.norm(v_tan)
            n_norms[i, l] = np.linalg.norm(v_nor)

    ratios = n_norms / (t_norms + eps)
    return t_norms, n_norms, ratios


In [None]:
def estimate_local_gradients_scalar(X_geo, values, knn_indices, ridge=1e-4):
    """
    Approximate gradients of a scalar function f(x) at each point using local
    linear regression on neighbors.

    X_geo: (n, d)
    values: (n,) scalar f(x_i)
    knn_indices: (n, k)
    Returns:
        grads: (n, d)
    """
    n, d = X_geo.shape
    k = knn_indices.shape[1]
    grads = np.zeros((n, d))

    for i in range(n):
        xi = X_geo[i]
        nbrs = knn_indices[i]           # (k,)
        X_diff = X_geo[nbrs] - xi       # (k, d)
        y_diff = values[nbrs] - values[i]  # (k,)

        # Solve min ||X_diff @ beta - y_diff||^2 + ridge*||beta||^2
        A = X_diff.T @ X_diff + ridge * np.eye(d)
        b = X_diff.T @ y_diff
        beta = np.linalg.solve(A, b)    # (d,)
        grads[i] = beta

    return grads


In [None]:
def decompose_gradients_into_tangent_normal(grads, T, N, eps=1e-8):
    """
    grads: (n, d)
    T: (n, d, r)
    N: (n, d, d-r)
    Returns:
        tan_norms: (n,)
        nor_norms: (n,)
        frac_normal: (n,) = ||g_normal|| / (||g|| + eps)
    """
    n, d = grads.shape
    tan_norms = np.zeros(n)
    nor_norms = np.zeros(n)
    frac_normal = np.zeros(n)

    for i in range(n):
        g = grads[i]         # (d,)
        Ti = T[i]            # (d, r)
        Ni = N[i]            # (d, d-r)

        g_tan = Ti @ (Ti.T @ g)
        g_nor = Ni @ (Ni.T @ g)

        tan_norms[i] = np.linalg.norm(g_tan)
        nor_norms[i] = np.linalg.norm(g_nor)
        frac_normal[i] = nor_norms[i] / (np.linalg.norm(g) + eps)

    return tan_norms, nor_norms, frac_normal


In [None]:
def randomize_eigenbases(eigvals, eigvecs, rng=None):
    """
    Shuffle eigenbases across points: for each i, pick a random index j
    and assign eigvecs[j], but keep eigvals[i].
    """
    if rng is None:
        rng = np.random.default_rng()

    n, d, _ = eigvecs.shape
    rand_idx = rng.integers(low=0, high=n, size=n)

    eigvecs_rand = np.zeros_like(eigvecs)
    for i in range(n):
        eigvecs_rand[i] = eigvecs[rand_idx[i]]

    # eigenvalues remain the same
    return eigvals.copy(), eigvecs_rand


In [None]:
def reconstruct_hessians_from_eigs(eigvals, eigvecs):
    """
    eigvals: (n, d)
    eigvecs: (n, d, d)
    Returns:
        H: (n, d, d)
    """
    n, d = eigvals.shape
    H = np.zeros((n, d, d))
    for i in range(n):
        U = eigvecs[i]                    # (d, d)
        L = np.diag(eigvals[i])           # (d, d)
        H[i] = U @ L @ U.T
    return H


In [None]:
def flatten_eigenvalues_to_scalar(eigvals):
    """
    Replace per-point eigenvalues with their mean: lambda_i -> bar_lambda_i.
    """
    bar_lambda = eigvals.mean(axis=1, keepdims=True)  # (n, 1)
    eigvals_flat = np.repeat(bar_lambda, eigvals.shape[1], axis=1)
    return eigvals_flat
