# Notebook for MuData Creation from GE and ATSE AnnData

This notebook:
1. Reads and inspects the ATSE (splicing) and gene expression (GE) AnnData files.
2. Creates modality-specific `.var` metadata and a shared `.obs` table.
3. Recomputes the splicing `junc_ratio` layer and replaces NaN/Inf values with zeros.
4. Creates a MuData object with two modalities: `rna` (gene expression) and `splicing` (the ATSE data, with layers `junc_ratio`, `cell_by_junction_matrix`, and `cell_by_cluster_matrix`).
5. Writes out the final MuData object for use with MULTIVISPLICE.

## 0. Set Paths and Configuration

In [15]:
```python
ATSE_DATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_Anndata_ATSE_counts_with_waypoints_20250209_165655.h5ad"
GE_DATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_Anndata_GeneExpression_20250209_165655.h5ad"
OUTPUT_MUDATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_MUData_GE_ATSE_20250209_165655.h5mu"
# If True, recompute the junc_ratio layer even if the ATSE file already contains one
REDO_JUNC_RATIO = True

print("ATSE data path:", ATSE_DATA_PATH)
print("GE data path:  ", GE_DATA_PATH)
print("Output MuData path:", OUTPUT_MUDATA_PATH)
```

```
ATSE data path: /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_Anndata_ATSE_counts_with_waypoints_20250209_165655.h5ad
GE data path:   /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_Anndata_GeneExpression_20250209_165655.h5ad
Output MuData path: /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_MUData_GE_ATSE_20250209_165655.h5mu
```
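Because the inputs and the output all live on a shared filesystem, failing fast on a mistyped path saves a long load. A minimal sketch using only the standard library; `check_paths` is a hypothetical helper, not part of the original notebook:

```python
import os


def check_paths(input_paths, output_path):
    """Return a list of problems with the configured paths (empty if all look fine).

    Hypothetical helper: checks that each input file exists and that the
    output file's parent directory exists and is writable.
    """
    problems = []
    for path in input_paths:
        if not os.path.isfile(path):
            problems.append(f"missing input file: {path}")
    out_dir = os.path.dirname(os.path.abspath(output_path))
    if not os.path.isdir(out_dir):
        problems.append(f"missing output directory: {out_dir}")
    elif not os.access(out_dir, os.W_OK):
        problems.append(f"output directory not writable: {out_dir}")
    return problems


# Usage in the notebook (assumes the path variables defined above):
# problems = check_paths([ATSE_DATA_PATH, GE_DATA_PATH], OUTPUT_MUDATA_PATH)
# assert not problems, "\n".join(problems)
```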


## 1. Imports

In [2]:
```python
import anndata as ad
import mudata as mu
import pandas as pd
import scipy.sparse as sp
import numpy as np
from scipy.sparse import csr_matrix, hstack, vstack

# Other dependencies used in this environment (scvi-tools, JAX, h5py)
import scvi
import jax
import jaxlib
import h5py

print("jax version:", jax.__version__)
print("jaxlib version:", jaxlib.__version__)
```

```
jax version: 0.4.35
jaxlib version: 0.4.35
```
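The cell above records the JAX versions. A sketch that records the versions of the data packages as well, using `importlib.metadata` from the standard library; the package list is an assumption and can be extended:

```python
from importlib import metadata


def package_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# "scvi-tools" is the distribution name of the package imported as scvi
for name, version in package_versions(
    ["anndata", "mudata", "scipy", "numpy", "pandas", "scvi-tools", "h5py"]
).items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```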


## 2. Load ATSE and Gene Expression AnnData

In [3]:
```python
atse_anndata = ad.read_h5ad(ATSE_DATA_PATH)
print("ATSE AnnData:", atse_anndata)
```

```
ATSE AnnData: AnnData object with n_obs × n_vars = 106199 × 72460
    obs: 'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped'
    var: 'junction_id', 'event_id', 'splice_motif', 'label_5_prime', 'label_3_prime', 'annotation_status', 'gene_name', 'gene_id', 'num_junctions', 'position_off_5_prime', 'position_off_3_prime', 'CountJuncs', 'non_zero_count_cells', 'non_zero_cell_prop', 'annotation_status_score', 'non_zero_cell_prop_score', 'splice_motif_score', 'junction_id_index'
    uns: 'age_colors', 'neighbors', 'pca_explained_variance_ratio', 'tissue_colors', 'umap'
    obsm: 'X_pca', 'X_umap', 'phi_init_100_waypoints', 'phi_init_30_waypoints'
    varm: 'psi_init_100_waypoints', 'psi_init_30_waypoints'
    layers: 'cell_by_cluster_matrix', 'cell_by_junction_matrix'
    obsp: 'connectivities', 'distances'
```


In [4]:
```python
ge_anndata = ad.read_h5ad(GE_DATA_PATH)
print("GE AnnData:", ge_anndata)
```

```
GE AnnData: AnnData object with n_obs × n_vars = 106199 × 7918
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'cell_clean', 'cell_id', 'batch', 'old_cell_id_index', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'leiden'
    var: 'n_cells', 'mouse_gene_name'
    uns: 'age_colors', 'cell_type_grouped_colors', 'leiden', 'log1p', 'neighbors', 'umap'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_counts'
    obsp: 'connectivities', 'distances'
```
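The ATSE layers are large sparse matrices. Before dividing them, a quick summary of each layer's storage type, entry count, and density helps estimate memory use. A minimal sketch with a hypothetical `summarize_layers` helper; it accepts any mapping of names to NumPy arrays or SciPy sparse matrices, such as `atse_anndata.layers`:

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse


def summarize_layers(layers):
    """Return {name: (type name, shape, entries, density)} for each layer.

    "entries" is the number of stored entries for sparse layers (including
    explicitly stored zeros) and the number of nonzero entries for dense ones.
    """
    summary = {}
    for name in layers:
        mat = layers[name]
        n_total = mat.shape[0] * mat.shape[1]
        n_entries = mat.nnz if issparse(mat) else int(np.count_nonzero(mat))
        summary[name] = (type(mat).__name__, mat.shape, n_entries, n_entries / n_total)
    return summary


# Toy example; in the notebook, pass atse_anndata.layers instead
toy_layers = {"counts": csr_matrix(np.array([[0, 2], [3, 0]])), "dense": np.eye(2)}
for name, (kind, shape, n_entries, density) in summarize_layers(toy_layers).items():
    print(f"{name}: {kind} {shape}, {n_entries} entries, density {density:.3f}")
```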


## 3. Create `.var` DataFrames for Each Modality

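Here we build modality-specific `.var` tables with two columns, `ID` (the gene name for GE, the junction ID for ATSE) and `modality`, and assign them directly to each AnnData. This replaces the original `.var`, so its other columns are not carried into the MuData object.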

In [5]:
```python
gene_expr_var = pd.DataFrame(
    {
        "ID": ge_anndata.var["mouse_gene_name"],  # from the GE AnnData
        "modality": "Gene_Expression",
    },
    index=ge_anndata.var.index
)

splicing_var = pd.DataFrame(
    {
        "ID": atse_anndata.var["junction_id"],  # from the ATSE AnnData
        "modality": "Splicing",
    },
    index=atse_anndata.var.index
)

ge_anndata.var = gene_expr_var.copy()
atse_anndata.var = splicing_var.copy()
```
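If downstream analysis needs the original `.var` columns (for example `gene_name`, `event_id`, or `num_junctions` from the ATSE file), one option is to add `ID` and `modality` alongside them with `DataFrame.assign` instead of replacing the table. A sketch on a toy table whose column names mimic the ATSE `.var`; the values are invented for illustration:

```python
import pandas as pd

# Toy stand-in for atse_anndata.var (column names mirror the real file, values are invented)
toy_var = pd.DataFrame(
    {
        "junction_id": ["j1", "j2", "j3"],
        "event_id": ["e1", "e1", "e2"],
        "gene_name": ["GeneA", "GeneA", "GeneB"],
    },
    index=["0", "1", "2"],
)

# Add the two MuData-facing columns while keeping every original column
splicing_var_full = toy_var.assign(ID=toy_var["junction_id"], modality="Splicing")
print(splicing_var_full.columns.tolist())
```

In the notebook the same call would be `atse_anndata.var.assign(ID=atse_anndata.var["junction_id"], modality="Splicing")`. The trade-off is a larger `.var` in the written file.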

## 4. Create a Common `.obs` DataFrame

You can decide which AnnData's `.obs` to use (or merge them) when both describe the same cells.
Here we take the ATSE `.obs`, which assumes that ATSE and GE have identical `obs` indices in the same order.

In [6]:
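```python
# The shared obs is only valid if both objects list the same cells in the same order
assert atse_anndata.obs_names.equals(ge_anndata.obs_names), "ATSE and GE cells differ or are in a different order"

common_obs = atse_anndata.obs.copy()
common_obs["modality"] = "paired"  # every cell is measured in both modalities; adjust as required
print("Common obs shape:", common_obs.shape)

# Update both AnnData objects:
ge_anndata.obs = common_obs.copy()
atse_anndata.obs = common_obs.copy()
```

```
Common obs shape: (106199, 14)
```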
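Taking only the ATSE `.obs` drops GE-only columns such as `leiden`, `n_genes`, and `n_counts`. If you prefer to merge them, one sketch is to append just the columns that exist only in the GE table, aligned on the cell index. The toy tables below stand in for `atse_anndata.obs` and `ge_anndata.obs`:

```python
import pandas as pd

# Toy stand-ins for the two .obs tables (same cells, same order)
atse_obs = pd.DataFrame({"age": ["3m", "24m"], "tissue": ["Lung", "Liver"]}, index=["c1", "c2"])
ge_obs = pd.DataFrame(
    {"age": ["3m", "24m"], "leiden": ["0", "5"], "n_counts": [1200.0, 980.0]},
    index=["c1", "c2"],
)

# Columns present only in the GE table (Index.difference returns them sorted)
ge_only = ge_obs.columns.difference(atse_obs.columns)

# Left-join them onto the ATSE table by cell index; shared columns keep the ATSE values
merged_obs = atse_obs.join(ge_obs[ge_only])
print(merged_obs.columns.tolist())
```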


## 5. Compute or Fix Splicing `junc_ratio` Layer

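Here we recompute `junc_ratio` if the layer is missing or if `REDO_JUNC_RATIO` is `True`. It is computed element-wise, at the entries stored in the junction count matrix, as
`junc_ratio = cell_by_junction_matrix / cell_by_cluster_matrix`,
with the ratio set to zero wherever the cluster count is zero. Any remaining NaN/Inf values are replaced by zeros, and the result is stored as a sparse CSR layer.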
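A toy example of the same computation, using SciPy only. It also shows a detail of the sparse path: an explicitly stored zero in the junction matrix stays stored as `0.0` in the ratio matrix, which is one way the stored `0.0` entries in the printed layer below can arise. The matrices are invented for illustration:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy junction counts: 2 cells x 3 junctions, with an explicitly stored zero at (0, 1)
rows = np.array([0, 0, 1, 1])
cols = np.array([0, 1, 1, 2])
junc = csr_matrix((np.array([3.0, 0.0, 2.0, 5.0]), (rows, cols)), shape=(2, 3))
# Toy cluster totals (the ratio denominator)
cluster = csr_matrix(np.array([[4.0, 4.0, 0.0], [0.0, 2.0, 5.0]]))

coo = junc.tocoo()  # keeps every stored entry, including the explicit zero
cluster_vals = np.asarray(cluster[coo.row, coo.col], dtype=float).ravel()
ratio_vals = np.divide(
    coo.data, cluster_vals, out=np.zeros_like(coo.data), where=cluster_vals != 0
)
ratio = csr_matrix((ratio_vals, (coo.row, coo.col)), shape=junc.shape)

print(ratio.nnz)  # stored entries, including the explicit zero
print(ratio.toarray())
```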


In [7]:
```python
from scipy.sparse import csr_matrix, issparse
import numpy as np

if "junc_ratio" not in atse_anndata.layers or REDO_JUNC_RATIO:
    cell_by_junc = atse_anndata.layers["cell_by_junction_matrix"]
    cell_by_cluster = atse_anndata.layers["cell_by_cluster_matrix"]
    shape = cell_by_junc.shape

    # Collect the (row, col) positions and junction counts to divide
    if issparse(cell_by_junc):
        # Stored entries only; explicitly stored zeros are kept and become stored 0.0 ratios
        junc_coo = cell_by_junc.tocoo()
        junc_data = junc_coo.data.astype(float)
        row, col = junc_coo.row, junc_coo.col
    else:
        dense_junc = np.asarray(cell_by_junc)
        row, col = np.nonzero(dense_junc)
        junc_data = dense_junc[row, col].astype(float)  # 1-D, aligned with row/col

    # Cluster counts at the same positions, as a flat 1-D float array
    if issparse(cell_by_cluster):
        cluster_vals = np.asarray(cell_by_cluster.tocsr()[row, col], dtype=float).ravel()
    else:
        cluster_vals = np.asarray(cell_by_cluster)[row, col].astype(float)

    # Divide, leaving 0 wherever the cluster count is 0
    ratio_data = np.divide(
        junc_data,
        cluster_vals,
        out=np.zeros_like(junc_data),
        where=cluster_vals != 0,
    )

    # Replace NaN or inf (just in case)
    ratio_data = np.nan_to_num(ratio_data, nan=0.0, posinf=0.0, neginf=0.0)

    # Build the sparse ratio matrix at the original positions
    atse_anndata.layers["junc_ratio"] = csr_matrix((ratio_data, (row, col)), shape=shape)

# Final NaN/inf scrub, in case an existing junc_ratio layer was kept rather than recomputed
splicing_ratio = atse_anndata.layers["junc_ratio"]
if issparse(splicing_ratio):
    splicing_ratio = splicing_ratio.copy()
    splicing_ratio.data = np.nan_to_num(splicing_ratio.data, nan=0.0, posinf=0.0, neginf=0.0)
else:
    splicing_ratio = np.nan_to_num(np.asarray(splicing_ratio, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)
atse_anndata.layers["junc_ratio"] = splicing_ratio
```
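A quick post-check of the finished layer can catch problems before the large write: it counts non-finite values and reports the value range. If a junction's count never exceeds its cluster's count (an assumption about how the two input layers were built), every ratio should lie in [0, 1]. A sketch with a hypothetical `check_ratio_layer` helper:

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse


def check_ratio_layer(mat):
    """Return (n_nonfinite, min, max) over stored (sparse) or all (dense) values.

    min/max are computed over finite values only; both are None if there are none.
    """
    values = mat.data if issparse(mat) else np.asarray(mat, dtype=float).ravel()
    finite_mask = np.isfinite(values)
    n_nonfinite = int(np.count_nonzero(~finite_mask))
    finite = values[finite_mask]
    if finite.size == 0:
        return n_nonfinite, None, None
    return n_nonfinite, float(finite.min()), float(finite.max())


# Usage in the notebook:
# n_bad, lo, hi = check_ratio_layer(atse_anndata.layers["junc_ratio"])
# assert n_bad == 0 and lo >= 0.0 and hi <= 1.0, (n_bad, lo, hi)
toy = csr_matrix(np.array([[0.25, 0.0], [1.0, 0.5]]))
print(check_ratio_layer(toy))
```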


In [11]:
```python
print(atse_anndata.layers['junc_ratio'])
```

```
<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 783307782 stored elements and shape (106199, 72460)>
  Coords	Values
  (0, 0)	0.0
  (0, 1)	0.25
  (0, 2)	0.0
  (0, 3)	0.0
  (0, 4)	0.0
  (0, 5)	0.0
  (0, 6)	0.0
  (0, 7)	0.0
  (0, 8)	0.0
  (0, 9)	0.0
  (0, 10)	0.25
  (0, 11)	0.0
  (0, 12)	0.0
  (0, 13)	0.0
  (0, 14)	0.25
  (0, 15)	0.0
  (0, 16)	0.25
  (0, 17)	0.0
  (0, 18)	0.0
  (0, 19)	1.0
  (0, 27)	0.0
  (0, 28)	1.0
  (0, 29)	0.0
  (0, 30)	0.0
  (0, 31)	0.0
  :	:
  (106198, 72230)	0.0
  (106198, 72231)	0.0
  (106198, 72232)	0.0
  (106198, 72233)	0.0
  (106198, 72234)	1.0
  (106198, 72244)	0.0
  (106198, 72245)	0.0
  (106198, 72246)	0.0
  (106198, 72247)	1.0
  (106198, 72325)	0.3826086956521739
  (106198, 72326)	0.0
  (106198, 72327)	0.4260869565217391
  (106198, 72328)	0.0
  (106198, 72329)	0.19130434782608696
  (106198, 72330)	0.6636363636363637
  (106198, 72331)	0.0
  (106198, 72332)	0.33636363636363636
  (106198, 72333)	0.3595505617977528
  (106198, 72334)	0.0
```

## 6. Create a MuData Object

Instead of stacking everything into one AnnData, we create a MuData container. For MULTIVISPLICE, it has two modalities:
- `rna`: the GE AnnData (raw counts are in its `raw_counts` layer),
- `splicing`: the ATSE AnnData, whose layers hold `junc_ratio` (junction usage ratios), `cell_by_junction_matrix` (junction counts), and `cell_by_cluster_matrix` (cluster-level counts used as the ratio denominator).

After building the container, we copy the shared cell metadata into the global `mdata.obs`, first checking that both modalities agree on every column.
(If needed, pass copies of the AnnData objects so that the modalities stay independent of the originals.)

In [13]:
```python
mdata = mu.MuData({
    "rna": ge_anndata,
    "splicing": atse_anndata,
})

# Shared obs fields to pull up into the global mdata.obs
shared_obs_keys = [
    'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id',
    'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'cell_type_grouped', 'modality'
]

# Both modalities received the same obs in section 4; verify each column before copying it.
# Series.equals (unlike (a == b).all()) treats NaNs in matching positions as equal.
for key in shared_obs_keys:
    assert key in mdata["rna"].obs, f"{key} not found in 'rna' obs"
    assert key in mdata["splicing"].obs, f"{key} not found in 'splicing' obs"
    assert mdata["rna"].obs[key].equals(mdata["splicing"].obs[key]), f"{key} values differ between modalities"
    mdata.obs[key] = mdata["rna"].obs[key]

print("MuData object created with modalities:", list(mdata.mod.keys()))
```

```
MuData object created with modalities: ['rna', 'splicing']
```
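When checking that two metadata columns agree, `(a == b).all()` returns `False` for any column containing missing values, because NaN never equals itself element-wise. `Series.equals` treats NaNs in matching positions as equal. A toy illustration:

```python
import numpy as np
import pandas as pd

a = pd.Series(["Lung", np.nan, "Liver"])
b = pd.Series(["Lung", np.nan, "Liver"])

print((a == b).all())  # False: NaN == NaN is False element-wise
print(a.equals(b))     # True: NaNs in matching positions count as equal
```

Note that `Series.equals` is also stricter in another way: the two Series must have the same index and dtype.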


## 7. Write Out the Final MuData Object

The combined MuData object is now ready for use with `MULTIVISPLICE`. Save it as an H5MU file.

In [16]:
```python
mdata.write(OUTPUT_MUDATA_PATH)
print(f"MuData object written to {OUTPUT_MUDATA_PATH}")
```

```
MuData object written to /gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/ALL_CELLS/022025/TMS_MUData_GE_ATSE_20250209_165655.h5mu
```


## 8. Verify the Output

Read the MuData object back in to check that both modalities, their shapes, and their metadata survived the round trip.

In [17]:
```python
mdata_loaded = mu.read_h5mu(OUTPUT_MUDATA_PATH)
print("Loaded MuData modalities:", list(mdata_loaded.mod.keys()))
print(mdata_loaded)
```

```
Loaded MuData modalities: ['rna', 'splicing']
MuData object with n_obs × n_vars = 106199 × 80378
  obs:	'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'cell_type_grouped', 'modality'
  var:	'ID', 'modality'
  2 modalities
    rna:	106199 x 7918
      obs:	'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'modality'
      var:	'ID', 'modality'
      uns:	'age_colors', 'cell_type_grouped_colors', 'leiden', 'log1p', 'neighbors', 'umap'
      obsm:	'X_pca', 'X_umap'
      layers:	'raw_counts'
      obsp:	'connectivities', 'distances'
    splicing:	106199 x 72460
      obs:	'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'modality'
```
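Beyond printing, the round trip can be checked programmatically by comparing each loaded modality's shape and layer names with the in-memory object. A sketch, assuming only that both objects expose a `.mod` mapping of AnnData-like objects with `.shape` and `.layers` (as `mdata` and `mdata_loaded` do); `compare_modalities` is a hypothetical helper:

```python
from types import SimpleNamespace


def compare_modalities(expected, loaded):
    """Return human-readable differences between two MuData-like objects."""
    diffs = []
    diffs += [f"missing modality: {m}" for m in sorted(set(expected.mod) - set(loaded.mod))]
    diffs += [f"unexpected modality: {m}" for m in sorted(set(loaded.mod) - set(expected.mod))]
    for name in sorted(set(expected.mod) & set(loaded.mod)):
        exp_mod, got_mod = expected.mod[name], loaded.mod[name]
        if tuple(exp_mod.shape) != tuple(got_mod.shape):
            diffs.append(f"{name}: shape {exp_mod.shape} != {got_mod.shape}")
        exp_layers, got_layers = set(exp_mod.layers.keys()), set(got_mod.layers.keys())
        if exp_layers != got_layers:
            diffs.append(f"{name}: layers {sorted(exp_layers)} != {sorted(got_layers)}")
    return diffs


def fake_modality(shape, layer_names):
    """Stand-in for an AnnData: only .shape and .layers are provided."""
    return SimpleNamespace(shape=shape, layers=dict.fromkeys(layer_names))


# Usage in the notebook:
# assert not compare_modalities(mdata, mdata_loaded), compare_modalities(mdata, mdata_loaded)

# Toy check with stand-in objects
a = SimpleNamespace(mod={"rna": fake_modality((3, 2), ["raw_counts"]),
                         "splicing": fake_modality((3, 4), ["junc_ratio"])})
b = SimpleNamespace(mod={"rna": fake_modality((3, 2), ["raw_counts"]),
                         "splicing": fake_modality((3, 4), [])})
print(compare_modalities(a, b))
```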