# Notebook for MuData Creation from GE and ATSE AnnData

This notebook:
1. Reads and inspects the ATSE (splicing) and gene expression (GE) AnnData files.
2. Computes the splicing `junc_ratio` layer if it is missing (or if requested) and replaces NaN/infinite values with zeros.
3. Builds modality-specific `.var` tables and a shared `.obs` table for both AnnData objects.
4. Creates a MuData object with two modalities: `rna` (gene expression) and `splicing` (with the
    `junc_ratio`, `cell_by_junction_matrix`, and `cell_by_cluster_matrix` layers).
5. Writes out the final MuData object for use with MULTIVISPLICE.

## 0. Set Paths and Configuration

In [1]:
ROOT_PATH = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/"

ATSE_DATA_PATH = ROOT_PATH + "mouse_foundation_data_20250502_155802_splice.h5ad"
GE_DATA_PATH = ROOT_PATH + "mouse_foundation_data_20250502_155802_ge.h5ad"
OUTPUT_MUDATA_PATH = ROOT_PATH + "mouse_foundation_data_20250502_155802_ge_splice_combined.h5mu"
REDO_JUNC_RATIO = False  # set to True to recompute junc_ratio even if the layer already exists

print("ATSE data path:", ATSE_DATA_PATH)
print("GE data path:  ", GE_DATA_PATH)
print("Output MuData path:", OUTPUT_MUDATA_PATH)

ATSE data path: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_splice.h5ad
GE data path:   /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_ge.h5ad
Output MuData path: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_ge_splice_combined.h5mu


## 1. Imports

In [4]:
import anndata as ad
import mudata as mu
import pandas as pd
import scipy.sparse as sp
import numpy as np
from scipy.sparse import csr_matrix

# scvi-tools / JAX environment check
import scvi

import jax
import jaxlib
print("jax version:", jax.__version__)
print("jaxlib version:", jaxlib.__version__)

import h5py

jax version: 0.4.35
jaxlib version: 0.4.35


## 2. Load ATSE and Gene Expression AnnData

In [3]:
atse_anndata = ad.read_h5ad(ATSE_DATA_PATH)
print("ATSE AnnData:", atse_anndata)

ATSE AnnData: AnnData object with n_obs × n_vars = 157418 × 34845
    obs: 'cell_id_index', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'subtissue', 'tissue', 'dataset', 'cell_name', 'cell_id', 'broad_cell_type', 'seqtech', 'cell_clean'
    var: 'junction_id', 'event_id', 'splice_motif', 'annotation_status', 'gene_name', 'gene_id', 'num_junctions', 'position_off_5_prime', 'position_off_3_prime', 'CountJuncs', 'junction_id_index'
    layers: 'cell_by_cluster_matrix', 'cell_by_junction_matrix', 'junc_ratio'




In [4]:
ge_anndata = ad.read_h5ad(GE_DATA_PATH)
print("GE AnnData:", ge_anndata)

GE AnnData: AnnData object with n_obs × n_vars = 157418 × 19022
    obs: 'cell_id', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'tissue', 'dataset', 'batch', 'subtissue_clean', 'broad_cell_type', 'cell_id_index', 'cell_name', 'library_size'
    var: 'gene_symbol', 'gene_name', 'gene_id', 'mean_transcript_length', 'mean_intron_length', 'num_transcripts', 'transcript_biotypes'
    obsm: 'X_library_size'
    layers: 'length_norm', 'log_norm', 'predicted_log_norm_tms', 'raw_counts'




## 3. Create `.var` DataFrames for Each Modality

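Here we replace each AnnData's `.var` with a minimal, modality-specific table: an `ID` column (the gene ID
for GE, the junction ID for ATSE) and a `modality` label. This drops the other original `.var` columns
(for example `gene_name`, `event_id`, and `splice_motif`), so keep a copy first if later steps need them.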

In [5]:
gene_expr_var = pd.DataFrame(
    {
        "ID": ge_anndata.var["gene_id"],  # from the GE AnnData
        "modality": "Gene_Expression",
    },
    index=ge_anndata.var.index
)

splicing_var = pd.DataFrame(
    {
        "ID": atse_anndata.var["junction_id"],  # from the ATSE AnnData
        "modality": "Splicing",
    },
    index=atse_anndata.var.index
)

ge_anndata.var = gene_expr_var.copy()
atse_anndata.var = splicing_var.copy()
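If the dropped annotations are needed later, one option is to add the two new columns on top of the existing table with `pandas.DataFrame.assign` instead of replacing it. A minimal sketch, using a toy table as a stand-in for `ge_anndata.var` (the IDs and names below are made up):

```python
import pandas as pd

# Toy stand-in for ge_anndata.var (hypothetical values)
original_var = pd.DataFrame(
    {"gene_id": ["g1", "g2"], "gene_name": ["GeneA", "GeneB"]},
    index=["GeneA", "GeneB"],
)

# Add the `ID` and `modality` columns the notebook creates, keeping every original column
gene_expr_var = original_var.assign(ID=original_var["gene_id"], modality="Gene_Expression")
print(gene_expr_var.columns.tolist())
```

On the real object this would read `ge_anndata.var = ge_anndata.var.assign(ID=ge_anndata.var["gene_id"], modality="Gene_Expression")`, and similarly for the ATSE `.var` with `junction_id`.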

## 4. Create a Common `.obs` DataFrame

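Both AnnData objects describe the same cells, so we give them one shared `.obs`, taken from the ATSE data.
This assumes the ATSE and GE `obs` indices list the same cells in the same order; the cell below does not
check this, and assigning a new `.obs` does not reorder the expression matrix, so a mismatch would silently
mislabel cells. GE-only columns such as `batch`, `subtissue_clean`, and `library_size` are dropped from the
`rna` modality as a result.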

In [6]:
common_obs = atse_anndata.obs.copy()
common_obs["modality"] = "paired"  # mark every cell as measured in both modalities
print("Common obs shape:", common_obs.shape)

# Update both AnnData objects:
ge_anndata.obs = common_obs.copy()
atse_anndata.obs = common_obs.copy()

Common obs shape: (157418, 14)
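Because the cell above relies on the two indices matching, a check before sharing `.obs` can catch a silent mislabeling. A minimal sketch, assuming the cell labels are the `obs` indices of each AnnData; the helper `align_to_reference` is hypothetical and returns the positions that reorder the GE cells into ATSE order:

```python
import pandas as pd

def align_to_reference(ref_index, other_index):
    """Return positions that reorder `other_index` to match `ref_index`.

    Raises ValueError unless both indices hold exactly the same unique labels.
    """
    ref = pd.Index(ref_index)
    other = pd.Index(other_index)
    if not (ref.is_unique and other.is_unique):
        raise ValueError("cell indices must be unique")
    if len(ref) != len(other) or not ref.isin(other).all():
        raise ValueError("cell indices differ between modalities")
    # For each reference label, its position in `other`
    return other.get_indexer(ref)

# Toy example: the same three cells in a different order (hypothetical barcodes)
atse_cells = ["c1", "c2", "c3"]
ge_cells = ["c2", "c3", "c1"]
order = align_to_reference(atse_cells, ge_cells)
print([ge_cells[i] for i in order])
```

With the real data, `order = align_to_reference(atse_anndata.obs_names, ge_anndata.obs_names)` followed by `ge_anndata = ge_anndata[order].copy()` would put the GE cells in ATSE order before the shared `.obs` is assigned.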


## 5. Compute or Fix Splicing `junc_ratio` Layer

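If `junc_ratio` is missing, or `REDO_JUNC_RATIO` is `True`, the layer is recomputed elementwise as
`junc_ratio = cell_by_junction_matrix / cell_by_cluster_matrix`,
evaluated only at nonzero junction counts and set to 0 wherever the cluster total is 0. In every case,
any remaining NaN or infinite values in the layer are then replaced by zeros.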
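To make the formula concrete, here is a toy example with 2 cells and 4 junctions, where junctions 0-1 form one cluster and junctions 2-3 another, and `cell_by_cluster_matrix` repeats each cluster's total across its junction columns (an assumption about the layer layout that matches the elementwise formula). It computes the ratio two ways: a coordinate lookup like the one in the cell below, and an alternative that multiplies by the reciprocal of the stored cluster totals. Both helper names are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy layers (hypothetical values): rows are cells, columns are junctions
cell_by_junction = csr_matrix(np.array([[3, 1, 0, 0],
                                        [0, 2, 5, 5]], dtype=float))
cell_by_cluster = csr_matrix(np.array([[4, 4, 0, 0],
                                       [2, 2, 10, 10]], dtype=float))

def ratio_by_lookup(junc, cluster):
    """Divide at the stored junction entries only, looking up cluster totals by coordinate."""
    coo = junc.tocoo()
    num = coo.data.astype(float)
    den = np.asarray(cluster.tocsr()[coo.row, coo.col], dtype=float).ravel()
    vals = np.divide(num, den, out=np.zeros_like(num), where=den != 0)
    return csr_matrix((vals, (coo.row, coo.col)), shape=junc.shape)

def ratio_by_reciprocal(junc, cluster):
    """Multiply elementwise by 1 / cluster total; zero totals become implicit zeros."""
    inv = cluster.tocsr().astype(float, copy=True)
    inv.data = np.divide(1.0, inv.data, out=np.zeros_like(inv.data), where=inv.data != 0)
    inv.eliminate_zeros()
    return csr_matrix(junc.multiply(inv))

a = ratio_by_lookup(cell_by_junction, cell_by_cluster)
b = ratio_by_reciprocal(cell_by_junction, cell_by_cluster)
print(a.toarray())
print(np.allclose(a.toarray(), b.toarray()))
```

Both versions give 0 wherever the junction count or the cluster total is 0. The reciprocal version avoids fancy-indexing a CSR matrix at every stored coordinate (about a billion here), so it may scale better on the full data, though that is worth benchmarking before switching.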


In [7]:
from scipy.sparse import csr_matrix, issparse
import numpy as np

if "junc_ratio" not in atse_anndata.layers or REDO_JUNC_RATIO:
    cell_by_junc = atse_anndata.layers["cell_by_junction_matrix"]
    cell_by_cluster = atse_anndata.layers["cell_by_cluster_matrix"]

    # Coordinates and values of the nonzero junction counts (sparse or dense input)
    if issparse(cell_by_junc):
        junc_coo = cell_by_junc.tocoo()
        row, col = junc_coo.row, junc_coo.col
        junc_data = junc_coo.data.astype(float)
    else:
        row, col = np.nonzero(cell_by_junc)
        junc_data = np.asarray(cell_by_junc[row, col], dtype=float)

    # Cluster totals at the same coordinates, flattened to a 1-D array
    if issparse(cell_by_cluster):
        cluster_vals = np.asarray(cell_by_cluster.tocsr()[row, col], dtype=float).ravel()
    else:
        cluster_vals = np.asarray(cell_by_cluster[row, col], dtype=float)

    # Divide only where the cluster total is nonzero; all other entries stay 0
    ratio_data = np.divide(
        junc_data,
        cluster_vals,
        out=np.zeros_like(junc_data),
        where=cluster_vals != 0
    )

    # Build the sparse ratio matrix with the same shape as the input layers
    ratio_matrix = csr_matrix((ratio_data, (row, col)), shape=atse_anndata.shape)
    atse_anndata.layers["junc_ratio"] = ratio_matrix

# Final NaN/inf scrub: also covers a pre-existing junc_ratio layer that was not recomputed
splicing_ratio = atse_anndata.layers["junc_ratio"]
if issparse(splicing_ratio):
    splicing_ratio = splicing_ratio.copy()
    splicing_ratio.data = np.nan_to_num(splicing_ratio.data, nan=0.0, posinf=0.0, neginf=0.0)
else:
    splicing_ratio = np.nan_to_num(np.asarray(splicing_ratio, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)
atse_anndata.layers["junc_ratio"] = splicing_ratio


In [7]:
print(atse_anndata.layers['junc_ratio'])

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 1036427639 stored elements and shape (157418, 34845)>
  Coords	Values
  (0, 197)	0.10923481588268069
  (0, 198)	-0.10923481588265521
  (0, 348)	-0.14492827599315417
  (0, 349)	0.2712504288962023
  (0, 350)	-0.12632215290306226
  (0, 358)	0.21830223720291528
  (0, 359)	-0.21830223720293218
  (0, 389)	-0.030185492868850552
  (0, 390)	0.030185492868867292
  (0, 471)	0.22918618381758282
  (0, 472)	-0.22918618381752612
  (0, 505)	-0.017186076899810214
  (0, 506)	0.017186076899781466
  (0, 563)	-0.40972488985415306
  (0, 564)	-0.29206162002713676
  (0, 565)	0.831267266379959
  (0, 566)	-0.07320007073989347
  (0, 567)	-0.05628068575870901
  (0, 576)	-0.03308391273053365
  (0, 577)	0.03308391273052658
  (0, 823)	-0.17652836794577326
  (0, 824)	0.41063247019988236
  (0, 825)	-0.23410410225412417
  (0, 826)	0.06095086199442612
  (0, 827)	-0.060950861994476004
  :	:
  (157417, 34754)	-0.03141646214777989
  (157417, 34755)	0.00717088569
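A ratio of non-negative counts must lie in [0, 1], so the negative entries printed above show that the stored layer is not the plain ratio from the formula. Because `REDO_JUNC_RATIO = False` and the input file already contained `junc_ratio`, the existing layer was kept; it appears to have been transformed upstream. A small helper (hypothetical name `summarize_ratio_layer`) can report the value range of a sparse or dense layer before it is written out:

```python
import numpy as np
from scipy.sparse import csr_matrix, issparse

def summarize_ratio_layer(layer):
    """Return min, max, and the count of finite stored values outside [0, 1]."""
    values = layer.data if issparse(layer) else np.asarray(layer).ravel()
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None  # nothing stored
    outside = int(np.count_nonzero((values < 0) | (values > 1)))
    return {"min": float(values.min()), "max": float(values.max()), "n_outside_0_1": outside}

# Toy layer with one negative entry (hypothetical values)
toy = csr_matrix(np.array([[0.25, -0.25, 0.0],
                           [0.0, 1.0, 0.5]]))
print(summarize_ratio_layer(toy))
```

On the real data this would be `summarize_ratio_layer(atse_anndata.layers["junc_ratio"])`; for a sparse layer only stored values are inspected, so implicit zeros are not counted.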

## 6. Create a MuData Object

Instead of stacking into one AnnData, we create a MuData container.

For MULTIVISPLICE, the MuData object built here has two modalities:
- `rna`: gene expression from the GE AnnData (raw counts are in its `raw_counts` layer),
- `splicing`: the ATSE AnnData, whose layers hold `junc_ratio` (junction usage ratios), `cell_by_junction_matrix` (junction counts), and `cell_by_cluster_matrix` (each junction's cluster total).

We use the GE AnnData for gene expression and the ATSE AnnData for all splicing-related data.
(If the modalities must stay independent of the original objects, pass copies instead.)

In [8]:
mdata = mu.MuData({
    "rna": ge_anndata,
    "splicing": atse_anndata
})

# Store the per-cell library size from the GE data at the MuData level
mdata.obsm["X_library_size"] = ge_anndata.obsm["X_library_size"]


# List of shared obs fields to pull up
shared_obs_keys = [
    'cell_id', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'tissue', 'dataset', 'broad_cell_type', 'cell_id_index', 'cell_name', 'modality'
]

# Check that each key exists in both modalities with identical values, then copy it up to mdata.obs
for key in shared_obs_keys:
    assert key in mdata["rna"].obs, f"{key} not found in 'rna' obs"
    assert key in mdata["splicing"].obs, f"{key} not found in 'splicing' obs"
    assert (mdata["rna"].obs[key] == mdata["splicing"].obs[key]).all(), f"{key} values differ between modalities"
    mdata.obs[key] = mdata["rna"].obs[key]
    
print("MuData object created with modalities:", list(mdata.mod.keys()))



MuData object created with modalities: ['rna', 'splicing']
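The loop above compares columns with `(a == b).all()`, which reports a mismatch for any column holding NaN, because NaN never equals NaN. Here both `.obs` tables are copies of `common_obs`, so this only matters if a shared column contains missing values. A hedged alternative, sketched with plain pandas DataFrames standing in for the two `.obs` tables (the helper name `shared_obs_columns` is hypothetical), uses `Series.equals`, which treats NaNs in the same position as equal but also requires matching dtypes:

```python
import numpy as np
import pandas as pd

def shared_obs_columns(obs_a, obs_b, keys):
    """Return a DataFrame of the `keys` whose values are identical in obs_a and obs_b."""
    if not obs_a.index.equals(obs_b.index):
        raise ValueError("obs indices differ between modalities")
    out = {}
    for key in keys:
        if key not in obs_a or key not in obs_b:
            raise KeyError(f"{key} missing from one modality")
        # Series.equals: same shape, dtype, and values; aligned NaNs count as equal
        if not obs_a[key].equals(obs_b[key]):
            raise ValueError(f"{key} values differ between modalities")
        out[key] = obs_a[key]
    return pd.DataFrame(out, index=obs_a.index)

# Toy obs tables with a NaN in the same place (hypothetical values)
rna_obs = pd.DataFrame({"age": [3.0, np.nan], "sex": ["f", "m"]}, index=["c1", "c2"])
splicing_obs = rna_obs.copy()
print(shared_obs_columns(rna_obs, splicing_obs, ["age", "sex"]))
```

With the real object, `shared = shared_obs_columns(mdata["rna"].obs, mdata["splicing"].obs, shared_obs_keys)` followed by `for key in shared: mdata.obs[key] = shared[key]` would replace the assert loop.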


## 7. Write Out the Final MuData Object

The combined MuData object is now ready for use with `MULTIVISPLICE`. Save it as an H5MU file.

In [9]:
mdata.write(OUTPUT_MUDATA_PATH)
print(f"MuData object written to {OUTPUT_MUDATA_PATH}")

MuData object written to /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_ge_splice_combined.h5mu




## 8. Verify the Output

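Read the file back in and print its summary to confirm that both modalities, their layers, and the
container-level `obs` and `obsm` fields were written as expected.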

In [5]:
mdata_loaded = mu.read_h5mu(OUTPUT_MUDATA_PATH)
print("Loaded MuData modalities:", list(mdata_loaded.mod.keys()))
print(mdata_loaded)



Loaded MuData modalities: ['rna', 'splicing']
MuData object with n_obs × n_vars = 157418 × 53867
  obs:	'cell_id', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'tissue', 'dataset', 'broad_cell_type', 'cell_id_index', 'cell_name', 'modality'
  obsm:	'X_library_size'
  2 modalities
    rna:	157418 x 19022
      obs:	'cell_id_index', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'subtissue', 'tissue', 'dataset', 'cell_name', 'cell_id', 'broad_cell_type', 'seqtech', 'cell_clean', 'modality'
      var:	'ID', 'modality'
      obsm:	'X_library_size'
      layers:	'length_norm', 'log_norm', 'predicted_log_norm_tms', 'raw_counts'
    splicing:	157418 x 34845
      obs:	'cell_id_index', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'subtissue', 'tissue', 'dataset', 'cell_name', 'cell_id', 'broad_cell_type', 'seqtech', 'cell_clean', 'modality'
      var:	'ID', 'modality'
      layers:	'cell_by_cluster_matrix', 'cell_by_junction_matrix', 'junc_ratio'


In [7]:
print(mdata_loaded.obsm['X_library_size'])

[[ 53.12896806]
 [224.7704842 ]
 [179.97726018]
 ...
 [484.43131097]
 [747.11240078]
 [238.18621119]]
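Printing the summary relies on reading it by eye. A small checker can turn the same expectations into assertions; this is a sketch with a hypothetical helper `verify_modalities` that compares plain dictionaries, and the expected shapes and layer names are copied from the summary printed above:

```python
def verify_modalities(observed, expected):
    """Compare modality shapes and layer names; return a list of problems (empty if all match).

    Both arguments map modality name -> {"shape": (n_obs, n_vars), "layers": iterable of names}.
    """
    problems = []
    for name, exp in expected.items():
        if name not in observed:
            problems.append(f"missing modality {name!r}")
            continue
        got = observed[name]
        if tuple(got["shape"]) != tuple(exp["shape"]):
            problems.append(f"{name}: shape {tuple(got['shape'])}, expected {tuple(exp['shape'])}")
        missing = sorted(set(exp["layers"]) - set(got["layers"]))
        if missing:
            problems.append(f"{name}: missing layers {missing}")
    return problems

# Expected values copied from the printed MuData summary
EXPECTED = {
    "rna": {"shape": (157418, 19022),
            "layers": ["length_norm", "log_norm", "predicted_log_norm_tms", "raw_counts"]},
    "splicing": {"shape": (157418, 34845),
                 "layers": ["cell_by_cluster_matrix", "cell_by_junction_matrix", "junc_ratio"]},
}

# With the loaded object:
# observed = {k: {"shape": a.shape, "layers": list(a.layers.keys())} for k, a in mdata_loaded.mod.items()}
# problems = verify_modalities(observed, EXPECTED)
# assert not problems, problems
```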
