
# Notebook for AnnData Exploration, Processing, and Stacking

This notebook:
1. Imports required libraries.
2. Reads and inspects ATSE and gene expression AnnData files.
3. Computes the splicing junction-ratio layer and replaces NaN and infinite values with zeros.
4. Merges AnnData objects (ATSE + GE) into one combined AnnData.
5. Creates or updates relevant `.obs`, `.var`, and `.layers`.
6. Writes out the final combined AnnData object.


## 0. Set Paths and Configuration Here

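Change these paths to point to your own input files and output location. Set `REDO_JUNC_RATIO = True` to recompute the `junc_ratio` splicing layer even if it already exists in the ATSE AnnData.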

In [3]:
ATSE_DATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_Anndata_ATSE_counts_with_waypoints_20250211_171237.h5ad"
GE_DATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_Anndata_GeneExpression_20250211_171237.h5ad"
OUTPUT_COMBINED_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_BRAINONLY_Combined_GE_ATSE.h5ad"
REDO_JUNC_RATIO = True

print("ATSE data path:", ATSE_DATA_PATH)
print("GE data path:  ", GE_DATA_PATH)
print("Output path:   ", OUTPUT_COMBINED_PATH)

ATSE data path: /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_Anndata_ATSE_counts_with_waypoints_20250211_171237.h5ad
GE data path:   /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_Anndata_GeneExpression_20250211_171237.h5ad
Output path:    /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_BRAINONLY_Combined_GE_ATSE.h5ad
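Because the input files are large, it can save time to fail fast on a mistyped path before loading anything. Below is a minimal sketch using only the standard library. It assumes you only need to know that the inputs exist and that the output directory is present. `check_paths` is a hypothetical helper and is not part of the original notebook:

```python
import os

def check_paths(input_paths, output_path):
    """Hypothetical helper: return a list of problems with the configured paths.

    An empty list means every input file exists and the output directory exists.
    """
    problems = []
    for path in input_paths:
        if not os.path.isfile(path):
            problems.append(f"missing input file: {path}")
    out_dir = os.path.dirname(os.path.abspath(output_path))
    if not os.path.isdir(out_dir):
        problems.append(f"output directory does not exist: {out_dir}")
    return problems

# Usage with this notebook's configuration variables:
# problems = check_paths([ATSE_DATA_PATH, GE_DATA_PATH], OUTPUT_COMBINED_PATH)
# assert not problems, "\n".join(problems)
```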


## 1. Imports

In [21]:
import jax
import jaxlib
print("jax version:", jax.__version__)
print("jaxlib version:", jaxlib.__version__)

import scvi
import h5py
import anndata as ad
import pandas as pd
import scipy.sparse as sp
import numpy as np
from scipy.sparse import csr_matrix, hstack, vstack

jax version: 0.5.0
jaxlib version: 0.5.0


## 2. Load ATSE and Gene Expression AnnData

In [5]:
atse_anndata = ad.read_h5ad(ATSE_DATA_PATH)
print("ATSE AnnData:", atse_anndata)

ATSE AnnData: AnnData object with n_obs × n_vars = 19942 × 76811
    obs: 'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped'
    var: 'junction_id', 'event_id', 'splice_motif', 'label_5_prime', 'label_3_prime', 'annotation_status', 'gene_name', 'gene_id', 'num_junctions', 'position_off_5_prime', 'position_off_3_prime', 'CountJuncs', 'non_zero_count_cells', 'non_zero_cell_prop', 'annotation_status_score', 'non_zero_cell_prop_score', 'splice_motif_score', 'junction_id_index'
    uns: 'age_colors', 'neighbors', 'pca_explained_variance_ratio', 'tissue_colors', 'umap'
    obsm: 'X_pca', 'X_umap', 'phi_init_100_waypoints', 'phi_init_30_waypoints'
    varm: 'psi_init_100_waypoints', 'psi_init_30_waypoints'
    layers: 'cell_by_cluster_matrix', 'cell_by_junction_matrix'
    obsp: 'connectivities', 'distances'


In [6]:
ge_anndata = ad.read_h5ad(GE_DATA_PATH)
print("GE AnnData:", ge_anndata)

GE AnnData: AnnData object with n_obs × n_vars = 19942 × 7150
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'cell_clean', 'cell_id', 'batch', 'old_cell_id_index', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'leiden'
    var: 'n_cells', 'mouse_gene_name'
    uns: 'age_colors', 'cell_type_grouped_colors', 'leiden', 'log1p', 'neighbors', 'umap'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_counts'
    obsp: 'connectivities', 'distances'


## 3. Create `.var` DataFrames for Each Modality

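We give both modalities the same `.var` columns: `ID` (the gene name for gene expression, the junction ID for splicing) and `modality`. This lets the two frames be concatenated row-wise, with gene expression first and splicing second.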

In [7]:
gene_expr_var = pd.DataFrame(
    {
        "ID": ge_anndata.var["mouse_gene_name"],  # from the GE AnnData
        "modality": "Gene_Expression",
    },
    index=ge_anndata.var.index
)

splicing_var = pd.DataFrame(
    {
        "ID": atse_anndata.var["junction_id"],  # from the ATSE AnnData
        "modality": "Splicing",
    },
    index=atse_anndata.var.index
)

combined_var = pd.concat([gene_expr_var, splicing_var])
print("combined_var shape:", combined_var.shape)

combined_var shape: (83961, 2)
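Two properties of `combined_var` matter later. First, its rows must follow the same order as the columns of the combined matrix (gene expression first, then splicing). Second, its index should be unique, because AnnData warns about non-unique variable names. Below is a minimal pandas sketch of both checks. `check_var_layout` and `ensure_unique_var_index` are hypothetical helpers. Prefixing labels with the modality is just one possible scheme; AnnData's own `var_names_make_unique()` appends numeric suffixes instead.

```python
import pandas as pd

def check_var_layout(combined_var, n_ge):
    """Hypothetical check: the first n_ge rows are gene expression, the rest splicing.

    This must hold for a combined matrix built with hstack([GE, splicing]).
    """
    modality = combined_var["modality"].to_numpy()
    return bool((modality[:n_ge] == "Gene_Expression").all()
                and (modality[n_ge:] == "Splicing").all())

def ensure_unique_var_index(combined_var):
    """Hypothetical helper: prefix index labels with their modality if any label repeats.

    Duplicates within a single modality would still remain after prefixing.
    """
    if combined_var.index.is_unique:
        return combined_var
    out = combined_var.copy()
    out.index = pd.Index(
        [f"{mod}:{label}" for mod, label in zip(out["modality"], out.index)]
    )
    return out

# Usage in this notebook:
# assert check_var_layout(combined_var, ge_anndata.n_vars)
# combined_var = ensure_unique_var_index(combined_var)
```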


## 4. Create `.obs` DataFrame

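We assume both AnnData objects contain the same cells in the same row order, so cells are paired by row position. The columns of the combined `.obs` are copied from the ATSE AnnData. Adjust the selection if you need other metadata.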

In [8]:
combined_obs = pd.DataFrame(
    {
        "batch_id": atse_anndata.obs["batch"],
        "age": atse_anndata.obs["age"],
        "cell_ontology_class": atse_anndata.obs["cell_ontology_class"],
        "cell_type_grouped": atse_anndata.obs["cell_type_grouped"],
        "mouse.id": atse_anndata.obs["mouse.id"],
        "sex": atse_anndata.obs["sex"],
        "modality": "paired",
    },
    index=atse_anndata.obs.index
)
print("combined_obs shape:", combined_obs.shape)

combined_obs shape: (19942, 7)
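Pairing by row position silently produces wrong results if the two files list cells in different orders, so it is worth checking before building the combined object. Below is a minimal sketch. It assumes that either a shared index or a shared `cell_id` column identifies each cell (both `.obs` frames above have such a column). `obs_are_paired` is a hypothetical helper. If the check fails and both objects share obs names, one option is to reorder with something like `ge_anndata = ge_anndata[atse_anndata.obs_names].copy()`.

```python
import numpy as np
import pandas as pd

def obs_are_paired(obs_a, obs_b, key=None):
    """Hypothetical check that two .obs frames describe the same cells in the same order.

    If key is None, compare the indices (obs_names). Otherwise compare the values
    of a shared column such as "cell_id", as strings so categorical columns work.
    """
    if len(obs_a) != len(obs_b):
        return False
    if key is None:
        return bool(obs_a.index.equals(obs_b.index))
    a = obs_a[key].astype(str).to_numpy()
    b = obs_b[key].astype(str).to_numpy()
    return bool(np.array_equal(a, b))

# Usage in this notebook:
# assert obs_are_paired(atse_anndata.obs, ge_anndata.obs)
# assert obs_are_paired(atse_anndata.obs, ge_anndata.obs, key="cell_id")
```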


## 5. Create and/or Fix Splicing "junc_ratio" Layer

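If `"junc_ratio"` does not exist, or `REDO_JUNC_RATIO` is `True`, we compute it as `junc_counts / atse_counts`. Here `junc_counts` is the `cell_by_junction_matrix` layer and `atse_counts` is the `cell_by_cluster_matrix` layer. Entries where `atse_counts` is zero yield NaN (0/0) or, if the junction count is nonzero, infinity.  
We then replace all NaN and infinite values with `0.0`.  
If the layer is sparse, we convert it to dense, fix the values, and convert it back to sparse.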
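Densifying is memory-hungry at this scale. A single 19942 × 76811 float64 array takes about 12 GB, and the cell below holds at least three of them at once (both inputs and the ratio). Below is a sparse-aware alternative. It assumes, as the dense code implies, that 0 is the desired ratio wherever the division is undefined. `sparse_junc_ratio` is a hypothetical helper, not the notebook's original code.

```python
import numpy as np
import scipy.sparse as sp

def sparse_junc_ratio(junc_counts, atse_counts):
    """Sketch: junction counts / ATSE (cluster) counts without densifying.

    Only entries stored in junc_counts are divided. Every other entry of the
    result is 0, which matches replacing NaN (0/0) with 0 in the dense version.
    Entries whose denominator is 0 are also set to 0, like the dense +Inf -> 0.
    """
    junc = sp.coo_matrix(junc_counts, dtype=np.float64)
    atse = sp.csr_matrix(atse_counts, dtype=np.float64)
    if junc.shape != atse.shape:
        raise ValueError(f"shape mismatch: {junc.shape} vs {atse.shape}")
    # Denominator at each stored (row, col) position of the junction matrix.
    denom = np.asarray(atse[junc.row, junc.col]).ravel()
    values = np.divide(
        junc.data, denom, out=np.zeros_like(junc.data), where=denom != 0
    )
    ratio = sp.csr_matrix((values, (junc.row, junc.col)), shape=junc.shape)
    ratio.eliminate_zeros()  # drop the 0 values produced by zero denominators
    return ratio
```

In the cell below, steps a) through f) could then be replaced by `atse_anndata.layers["junc_ratio"] = sparse_junc_ratio(atse_anndata.layers["cell_by_junction_matrix"], atse_anndata.layers["cell_by_cluster_matrix"])`.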

In [10]:
# 1) Compute "junc_ratio" if it is missing or REDO_JUNC_RATIO is set
if "junc_ratio" not in atse_anndata.layers or REDO_JUNC_RATIO:
    # a) Retrieve junc_counts and atse_counts
    junc_counts = atse_anndata.layers["cell_by_junction_matrix"]
    atse_counts = atse_anndata.layers["cell_by_cluster_matrix"]

    # b) Convert both to dense if either is sparse
    if sp.issparse(junc_counts):
        junc_counts = junc_counts.toarray()
    if sp.issparse(atse_counts):
        atse_counts = atse_counts.toarray()

    # c) Ensure float dtype, then compute ratio
    junc_counts = junc_counts.astype(float)
    atse_counts = atse_counts.astype(float)
    ratio = junc_counts / atse_counts  # may create Inf/NaN if atse_counts == 0

    # d) Replace NaNs, +Inf, -Inf with 0
    np.nan_to_num(ratio, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

    # e) Convert back to sparse if desired
    ratio_sparse = csr_matrix(ratio)

    # f) Store the ratio as a new layer
    atse_anndata.layers["junc_ratio"] = ratio_sparse

# 2) Now fix any NaNs or infs in "junc_ratio" (in case it existed already)
splicing_counts = atse_anndata.layers["junc_ratio"]
if sp.issparse(splicing_counts):
    dense_counts = splicing_counts.astype(float).toarray()
    np.nan_to_num(dense_counts, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
    splicing_counts = csr_matrix(dense_counts)
else:
    splicing_counts = splicing_counts.astype(float)
    np.nan_to_num(splicing_counts, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

# 3) Store the cleaned-up ratio back in layers
atse_anndata.layers["junc_ratio"] = splicing_counts



## 6. Combine Count Data (Gene Expression + Splicing)

- We assume the gene expression is in `.layers["raw_counts"]` in `ge_anndata`.
- We assume the splicing data is in `.layers["junc_ratio"]` in `atse_anndata`.
- We concatenate them column-wise (`hstack`), gene expression first and splicing second, which matches the row order of `combined_var`.
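Because the gene-expression block comes first, splicing column `j` ends up at column `n_genes + j` of the combined matrix. A per-column modality label can then recover either block. Below is a toy sketch. The matrices are small made-up stand-ins for the real layers.

```python
import numpy as np
import scipy.sparse as sp

# Toy stand-ins (made-up values) for ge_anndata.layers["raw_counts"]
# and atse_anndata.layers["junc_ratio"]: 4 cells, 3 genes, 5 junctions.
ge = sp.csr_matrix(np.array([[1, 0, 2], [0, 3, 0], [4, 0, 0], [0, 0, 5]], dtype=float))
splice = sp.csr_matrix(np.array([
    [0.5, 0.0, 0.0, 1.0, 0.0],
    [0.0, 0.2, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.3, 0.0, 0.7],
]))
n_genes = ge.shape[1]

# format="csr" guarantees a CSR result (with format=None, scipy may return COO).
combined = sp.hstack([ge, splice], format="csr")

# One modality label per column, in the same order as the hstack.
modality = np.array(["Gene_Expression"] * n_genes + ["Splicing"] * splice.shape[1])
is_splicing = modality == "Splicing"

# Boolean column masks recover each block; splicing column j sits at n_genes + j.
recovered_ge = combined[:, ~is_splicing]
recovered_splice = combined[:, is_splicing]
```

In the notebook, `combined_var["modality"].to_numpy() == "Splicing"` would play the role of `is_splicing`.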

In [12]:
gene_expr_counts = ge_anndata.layers["raw_counts"]
splicing_counts = atse_anndata.layers["junc_ratio"]

# format="csr" guarantees a CSR result (hstack may otherwise return COO)
combined_counts = sp.hstack([gene_expr_counts, splicing_counts], format="csr")
print("Combined counts shape:", combined_counts.shape)

Combined counts shape: (19942, 83961)


## 7. Create Combined AnnData

We build the combined AnnData with the merged matrix (gene-expression raw counts followed by junction ratios) as `.X`, `combined_obs` as `.obs`, and `combined_var` as `.var`.


In [13]:
combined_adata = ad.AnnData(
    X=combined_counts,
    obs=combined_obs,
    var=combined_var
)
print("combined_adata shape:", combined_adata.shape)


combined_adata shape: (19942, 83961)


## 8. Add "Padded" Layers for cell-by-junction/cluster Matrices

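The `cell_by_junction_matrix` (`junc_counts`) and `cell_by_cluster_matrix` (`atse_counts`) layers only span the splicing columns, but every layer of the combined AnnData must match its full shape. We therefore prepend an all-zero block with one column per gene-expression feature, so that each layer lines up with the splicing block of `.X`.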

In [14]:
padding_amount = gene_expr_counts.shape[1]  # number of gene expression columns
print("Padding amount:", padding_amount)

cell_by_junction_matrix = atse_anndata.layers["cell_by_junction_matrix"]
cell_by_cluster_matrix = atse_anndata.layers["cell_by_cluster_matrix"]

num_rows_junction, num_cols_junction = cell_by_junction_matrix.shape
num_rows_cluster, num_cols_cluster = cell_by_cluster_matrix.shape

# Create padding of shape (num_rows, padding_cols)
padding_junction = csr_matrix((num_rows_junction, padding_amount))
padding_cluster = csr_matrix((num_rows_cluster, padding_amount))

# Horizontally stack the padding with the original matrices
cell_by_junction_matrix_padded = hstack([padding_junction, cell_by_junction_matrix], format="csr")
cell_by_cluster_matrix_padded = hstack([padding_cluster, cell_by_cluster_matrix], format="csr")

# Store in combined_adata
combined_adata.layers["cell_by_junction_matrix"] = cell_by_junction_matrix_padded
combined_adata.layers["cell_by_cluster_matrix"] = cell_by_cluster_matrix_padded

Padding amount: 7150
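The padding step can be wrapped in a small helper that also makes the alignment easy to check. After padding, the columns before `padding_amount` are empty, and the remaining columns reproduce the original layer. `pad_left` is a hypothetical name; the logic is the same `hstack` the cell above uses.

```python
import numpy as np
import scipy.sparse as sp

def pad_left(matrix, n_pad_cols):
    """Prepend n_pad_cols all-zero columns to a (sparse or dense) matrix, returning CSR.

    Mirrors the padding cell: the splicing layer then occupies the last
    matrix.shape[1] columns, matching the splicing block of the combined .X.
    """
    matrix = sp.csr_matrix(matrix)
    padding = sp.csr_matrix((matrix.shape[0], n_pad_cols), dtype=matrix.dtype)
    return sp.hstack([padding, matrix], format="csr")

# Usage in this notebook:
# padded = pad_left(atse_anndata.layers["cell_by_junction_matrix"], ge_anndata.n_vars)
```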


## 9. Write Out the Final Combined AnnData

This final file should be ready for use with `MULTIVISPLICE`.

In [15]:
combined_adata.write(OUTPUT_COMBINED_PATH)
print(f"Final combined AnnData written to: {OUTPUT_COMBINED_PATH}")

Final combined AnnData written to: /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_BRAINONLY_Combined_GE_ATSE.h5ad
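Printing the re-read object confirms its shape and fields but not its values. Below is a minimal sketch of a value-level comparison that works for sparse or dense inputs and avoids densifying when both inputs are sparse. `matrices_equal` is a hypothetical helper. Because the next cell overwrites `combined_adata`, read the file into a separate variable first (for example `reread = ad.read_h5ad(OUTPUT_COMBINED_PATH)`). Then compare `combined_adata.X` with `reread.X`, and do the same for each layer.

```python
import numpy as np
import scipy.sparse as sp

def matrices_equal(a, b):
    """Return True if two matrices (sparse or dense) have the same shape and values.

    When both are sparse, this counts differing entries without densifying.
    """
    if a.shape != b.shape:
        return False
    if sp.issparse(a) and sp.issparse(b):
        diff = sp.csr_matrix(a) - sp.csr_matrix(b)
        diff.eliminate_zeros()
        return diff.nnz == 0
    a_dense = a.toarray() if sp.issparse(a) else np.asarray(a)
    b_dense = b.toarray() if sp.issparse(b) else np.asarray(b)
    return bool(np.array_equal(a_dense, b_dense))
```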


To double-check that the write succeeded, read the file back and inspect it:

In [18]:
combined_adata = ad.read_h5ad(OUTPUT_COMBINED_PATH)
print(f"Combined AnnData Read from {OUTPUT_COMBINED_PATH}")
print(combined_adata)
print(combined_adata.X)

Combined AnnData Read from /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/BRAIN_ONLY/02112025/TMS_BRAINONLY_Combined_GE_ATSE.h5ad
AnnData object with n_obs × n_vars = 19942 × 83961
    obs: 'batch_id', 'age', 'cell_ontology_class', 'cell_type_grouped', 'mouse.id', 'sex', 'modality'
    var: 'ID', 'modality'
    layers: 'cell_by_cluster_matrix', 'cell_by_junction_matrix'
<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 75955761 stored elements and shape (19942, 83961)>
  Coords	Values
  (0, 5854)	60.0
  (0, 575)	53.0
  (0, 3664)	1.0
  (0, 2222)	26.0
  (0, 6781)	119.0
  (0, 6718)	1.0
  (0, 305)	32.0
  (0, 1865)	54.0
  (0, 6978)	76.0
  (0, 4243)	36.0
  (0, 2359)	27.0
  (0, 3726)	50.0
  (0, 16)	143.0
  (0, 7052)	87.0
  (0, 2708)	1.0
  (0, 1135)	33.0
  (0, 1795)	111.0
  (0, 3277)	1.0
  (0, 5421)	19.0
  (0, 737)	12.0
  (0, 3047)	167.0
  (0, 531)	48.0
  (0, 1485)	123.0
  (0, 3604)	162.0
  (0, 5283)	118.0
  :	:
  (19941, 1157)	4.0
  (19941, 4582)	2.0
  (19941, 6806