# Notebook for MuData Creation from GE and ATSE AnnData

This notebook:
1. Reads and inspects ATSE and gene expression AnnData files.
2. Fixes NaNs in the splicing data.
3. Creates modality-specific `.obs`, `.var`, and `.layers` for each AnnData.
4. Creates a MuData object with modalities “rna”, “junc_counts”, “cell_by_junction_matrix”, 
    and “cell_by_cluster_matrix”.
5. Writes out the final MuData object for use with MULTIVISPLICE.

## 0. Set Paths and Configuration

In [1]:
ROOT_PATH = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/"

ATSE_DATA_PATH = ROOT_PATH + "aligned_splicing_data_20250513_035938.h5ad"
GE_DATA_PATH = ROOT_PATH + "aligned_gene_expression_data_20250513_035938.h5ad"
OUTPUT_MUDATA_PATH = ROOT_PATH + "aligned__ge_splice_combined_20250513_035938.h5mu"
REDO_JUNC_RATIO = False

print("ATSE data path:", ATSE_DATA_PATH)
print("GE data path:  ", GE_DATA_PATH)
print("Output MuData path:", OUTPUT_MUDATA_PATH)

ATSE data path: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/aligned_splicing_data_20250513_035938.h5ad
GE data path:   /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/aligned_gene_expression_data_20250513_035938.h5ad
Output MuData path: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/aligned__ge_splice_combined_20250513_035938.h5mu


## 1. Imports

In [1]:
import anndata as ad
import pandas as pd
import scipy.sparse as sp
import numpy as np
from scipy.sparse import csr_matrix, hstack, vstack
import h5py
import anndata as ad
import mudata as mu
import scanpy as sc

## 2. Load ATSE and Gene Expression AnnData

In [2]:
atse_anndata = ad.read_h5ad("/gpfs/commons/groups/knowles_lab/Karin/TMS_MODELING/DATA_FILES/SIMULATED/simulated_data_2025-03-12.h5ad")
print("ATSE AnnData:", atse_anndata)

ATSE AnnData: AnnData object with n_obs × n_vars = 19942 × 9798
    obs: 'cell_id', 'age', 'batch', 'cell_ontology_class', 'method', 'mouse.id', 'sex', 'tissue', 'old_cell_id_index', 'cell_clean', 'cell_id_index', 'subtissue_clean', 'cell_type_grouped', 'cell_type'
    var: 'junction_id', 'event_id', 'splice_motif', 'label_5_prime', 'label_3_prime', 'annotation_status', 'gene_name', 'gene_id', 'num_junctions', 'position_off_5_prime', 'position_off_3_prime', 'CountJuncs', 'non_zero_count_cells', 'non_zero_cell_prop', 'annotation_status_score', 'non_zero_cell_prop_score', 'splice_motif_score', 'junction_id_index', 'chr', 'start', 'end', 'index', '0', '1', '2', '3', '4', '5', '6', '7', '8', 'sample_label', 'difference', 'true_label'
    uns: 'age_colors', 'cell_type_colors', 'neighbors', 'pca_explained_variance_ratio', 'tissue_colors', 'umap'
    obsm: 'X_leafletFA', 'X_pca', 'X_umap', 'phi_init_100_waypoints', 'phi_init_30_waypoints'
    varm: 'psi_init_100_waypoints', 'psi_init_30_waypo



In [None]:
print(atse_anndata.var["event_id"])

0       chr10_100087456_100088152_+
1       chr10_100087456_100089195_+
2       chr10_100088220_100089195_+
3       chr10_100578431_100582262_-
4       chr10_100578431_100583913_-
                   ...             
9793       chrX_98563680_98564715_-
9794       chrX_98564333_98564715_-
9795       chrX_99145569_99146477_+
9796       chrX_99145569_99146751_+
9797       chrX_99146570_99146751_+
Name: junction_id, Length: 9798, dtype: object


In [4]:
ge_anndata = ad.read_h5ad(GE_DATA_PATH)
print("GE AnnData:", ge_anndata)

GE AnnData: AnnData object with n_obs × n_vars = 157418 × 19022
    obs: 'cell_id', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'tissue', 'dataset', 'batch', 'subtissue_clean', 'broad_cell_type', 'cell_id_index', 'cell_name', 'library_size'
    var: 'index', 'gene_name', 'gene_id', 'mean_transcript_length', 'mean_intron_length', 'num_transcripts', 'transcript_biotypes'
    obsm: 'X_library_size'
    layers: 'length_norm', 'log_norm', 'predicted_log_norm_tms', 'raw_counts'


In [5]:
# assert that cell_id is in the exact same order in both ge_anndata.obs and atse_anndata.obs
assert (ge_anndata.obs["cell_id"].values == atse_anndata.obs["cell_id"].values).all()
# assert that cell_id is in the exact same order in both ge_anndata.obs and atse_anndata.obs
assert (atse_anndata.obs["cell_id"].values == ge_anndata.obs["cell_id"].values).all()

In [6]:
# rescale by overall median transcript length (didn't do this in preprocessing of GE AnnData)
ge_anndata.layers["length_norm"] = ge_anndata.layers["length_norm"] * np.median(ge_anndata.var["mean_transcript_length"])
# make sure to round down to get integer counts (this is CSR)
ge_anndata.layers["length_norm"].data = np.floor(ge_anndata.layers["length_norm"].data)
print(ge_anndata.layers["length_norm"])

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 593435257 stored elements and shape (157418, 19022)>
  Coords	Values
  (0, 21)	51.0
  (0, 54)	29.0
  (0, 137)	54.0
  (0, 185)	8.0
  (0, 340)	3.0
  (0, 372)	35.0
  (0, 415)	1.0
  (0, 419)	120.0
  (0, 455)	4.0
  (0, 482)	184.0
  (0, 510)	397.0
  (0, 535)	1.0
  (0, 557)	5.0
  (0, 708)	120.0
  (0, 804)	133.0
  (0, 900)	47.0
  (0, 1018)	20.0
  (0, 1121)	37.0
  (0, 1146)	1.0
  (0, 1147)	17.0
  (0, 1178)	12.0
  (0, 1180)	109.0
  (0, 1189)	1.0
  (0, 1219)	7.0
  (0, 1231)	0.0
  :	:
  (157417, 18963)	14.0
  (157417, 18965)	45.0
  (157417, 18968)	2.0
  (157417, 18969)	67.0
  (157417, 18971)	55.0
  (157417, 18972)	52.0
  (157417, 18974)	31.0
  (157417, 18976)	1.0
  (157417, 18987)	5.0
  (157417, 18988)	53.0
  (157417, 18990)	52.0
  (157417, 18992)	1.0
  (157417, 18996)	8.0
  (157417, 18997)	9.0
  (157417, 18998)	1.0
  (157417, 19006)	35.0
  (157417, 19008)	20.0
  (157417, 19010)	0.0
  (157417, 19012)	9.0
  (157417, 19013)	6.0
  (157417,

In [7]:
# Recalculate library size using length normalized counts
ge_anndata.obsm["X_library_size"] = ge_anndata.layers["length_norm"].sum(axis=1)
print(ge_anndata.obsm["X_library_size"])

[[ 104530.]
 [ 415320.]
 [ 330365.]
 ...
 [ 873490.]
 [1368928.]
 [ 431221.]]




In [8]:
# Ensure ge_anndata.var_names are gene names
ge_anndata.var_names = ge_anndata.var["gene_name"]
ge_anndata.var_names

Index(['0610005C13Rik', '0610009B22Rik', '0610009L18Rik', '0610010F05Rik',
       '0610010K14Rik', '0610030E20Rik', '0610031O16Rik', '0610037L13Rik',
       '0610038B21Rik', '0610040B10Rik',
       ...
       'Zswim7', 'Zw10', 'Zwilch', 'Zwint', 'Zxdc', 'Zyg11a', 'Zyg11b', 'Zyx',
       'Zzef1', 'Zzz3'],
      dtype='object', name='gene_name', length=19022)

In [9]:
# Do processing required to calculate most highly variable genes
# mitochondrial genes, "MT-" for human, "Mt-" for mouse
ge_anndata.var["mt"] = ge_anndata.var_names.str.startswith("mt-")
# ribosomal genes
ge_anndata.var["ribo"] = ge_anndata.var_names.str.startswith(("Rps", "Rpl"))
# hemoglobin genes
ge_anndata.var["hb"] = ge_anndata.var_names.str.contains("^Hb[^(P)]")
sc.pp.calculate_qc_metrics(
    ge_anndata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=True
)

In [10]:
# count number of cells with pct_counts_ribo > 40%
print(f"Number of cells with pct_counts_ribo > 40%: {(ge_anndata.obs['pct_counts_ribo'] > 40).sum()}")
# count number of cells with pct_counts_hb > 40%
print(f"Number of cells with pct_counts_hb > 40%: {(ge_anndata.obs['pct_counts_hb'] > 40).sum()}")
# count number of cells with pct_counts_mt > 40%
print(f"Number of cells with pct_counts_mt > 40%: {(ge_anndata.obs['pct_counts_mt'] > 40).sum()}")

Number of cells with pct_counts_ribo > 40%: 8
Number of cells with pct_counts_hb > 40%: 2
Number of cells with pct_counts_mt > 40%: 0


In [11]:
# Step 1: Create a working copy of length-normalized data
ge_anndata.layers["working_norm"] = ge_anndata.layers["length_norm"].copy()

# Step 2: Normalize and log-transform the working layer
sc.pp.normalize_total(ge_anndata, layer="working_norm", inplace=True)
sc.pp.log1p(ge_anndata, layer="working_norm")

# Step 3: Compute highly variable genes on working layer
sc.pp.highly_variable_genes(
    ge_anndata, n_top_genes=5000, layer="working_norm", batch_key="dataset"
)

# Step 4: Subset to HVGs
ge_anndata = ge_anndata[:, ge_anndata.var["highly_variable"]]

# Step 5: Assign unmodified length-normalized data to .X
ge_anndata.X = ge_anndata.layers["length_norm"]

print(f"The .X of ge_anndata is layer: {ge_anndata.X} corresponding to {ge_anndata.layers['length_norm']}")



The .X of ge_anndata is layer: <Compressed Sparse Row sparse matrix of dtype 'float32'
	with 96059694 stored elements and shape (157418, 5000)>
  Coords	Values
  (0, 154)	35.0
  (0, 194)	397.0
  (0, 204)	5.0
  (0, 256)	120.0
  (0, 461)	12.0
  (0, 484)	22.0
  (0, 495)	4.0
  (0, 500)	18.0
  (0, 528)	68.0
  (0, 549)	80.0
  (0, 552)	1.0
  (0, 567)	1.0
  (0, 568)	11.0
  (0, 583)	35.0
  (0, 605)	98.0
  (0, 608)	3.0
  (0, 610)	16.0
  (0, 633)	3.0
  (0, 640)	1119.0
  (0, 651)	36.0
  (0, 671)	6.0
  (0, 677)	50.0
  (0, 690)	9.0
  (0, 691)	3.0
  (0, 694)	1.0
  :	:
  (157417, 4946)	5.0
  (157417, 4947)	21.0
  (157417, 4949)	3.0
  (157417, 4950)	25.0
  (157417, 4951)	8.0
  (157417, 4952)	25.0
  (157417, 4956)	26.0
  (157417, 4957)	3.0
  (157417, 4959)	10.0
  (157417, 4961)	13.0
  (157417, 4962)	19.0
  (157417, 4963)	9.0
  (157417, 4964)	22.0
  (157417, 4966)	32.0
  (157417, 4970)	1.0
  (157417, 4974)	27.0
  (157417, 4976)	1.0
  (157417, 4977)	15.0
  (157417, 4980)	7.0
  (157417, 4981)	39.0
  (15741

In [12]:
# reset atse_anndata.obs
atse_anndata.obs.reset_index(drop=True, inplace=True)
ge_anndata.obs.reset_index(drop=True, inplace=True)

# assert that cell_id is in the exact same order in both ge_anndata.obs and atse_anndata.obs
assert (ge_anndata.obs["cell_id"].values == atse_anndata.obs["cell_id"].values).all()
assert (atse_anndata.obs["cell_id"].values == ge_anndata.obs["cell_id"].values).all()

  new_obj.index = new_index


## 3. Create `.var` DataFrames for Each Modality

Here we create modality-specific `.var` metadata. You might later use these to update the
corresponding AnnData objects inside the MuData container.

In [13]:
gene_expr_var = pd.DataFrame(
    {
        "ID": ge_anndata.var["gene_id"],  # from the GE AnnData
        "modality": "Gene_Expression",
    },
    index=ge_anndata.var.index
)

splicing_var = pd.DataFrame(
    {
        "ID": atse_anndata.var["junction_id"],  # from the ATSE AnnData
        "modality": "Splicing",
    },
    index=atse_anndata.var.index
)

ge_anndata.var = gene_expr_var.copy()
atse_anndata.var = splicing_var.copy()

## 4. Create a Common `.obs` DataFrame

You can decide which AnnData’s `.obs` to use (or merge them) if both contain the same information.
Here we assume ATSE and GE have matching `obs` indices; we take the ATSE `obs`.

In [14]:
common_obs = atse_anndata.obs.copy()
common_obs["modality"] = "paired"  # if needed; adjust as required
print("Common obs shape:", common_obs.shape)

# Update both AnnData objects:
ge_anndata.obs = common_obs.copy()
atse_anndata.obs = common_obs.copy()

Common obs shape: (157418, 14)


: 

## 5. Compute or Fix Splicing `junc_ratio` Layer

Here we check if `junc_ratio` needs to be recomputed. It is computed as:
`junc_ratio = cell_by_junction_matrix / cell_by_cluster_matrix`
and any NaNs/Inf values are replaced by zeros.


In [None]:
# %% [markdown]
# ### 5.1 Build junc_ratio + psi_mask on the filtered data

# %%
import numpy as np
from scipy import sparse
from scipy.sparse import csr_matrix, issparse
import gc

# grab the splicing modality
splicing = atse_anndata  # if you later rename it to 'splicing', otherwise: atse_anndata

cell_by_junc    = splicing.layers["cell_by_junction_matrix"]
cell_by_cluster = splicing.layers["cell_by_cluster_matrix"]

# 1) ensure CSR format
if not issparse(cell_by_junc):
    cell_by_junc = csr_matrix(cell_by_junc)
if not issparse(cell_by_cluster):
    cell_by_cluster = csr_matrix(cell_by_cluster)

# 2) build psi_mask (1 wherever cluster>0)
mask = cell_by_cluster.copy()
mask.data = np.ones_like(mask.data, dtype=np.uint8)
splicing.layers["psi_mask"] = mask

# 3) compute junc_ratio = junction / cluster, nan→0
cj = cell_by_junc.toarray()
cc = cell_by_cluster.toarray()

junc_ratio = np.divide(
    cj,
    cc,
    out=np.zeros_like(cj, dtype=float),
    where=(cc != 0),
)
# 4) assign back as dense or sparse (dense is fine)
splicing.layers["junc_ratio"] = junc_ratio

print("New splicing layers:", list(splicing.layers.keys()))
print(f"  junc_ratio shape: {junc_ratio.shape}, psi_mask nnz: {mask.nnz}")

# 5) cleanup
del cell_by_junc, cell_by_cluster, cj, cc, mask
gc.collect()


In [None]:
print(atse_anndata.layers['junc_ratio'])

## 6. Create a MuData Object

Instead of stacking into one AnnData, we create a MuData container.

For MULTIVISPLICE, the new setup expects modalities with the following keys:
- `rna` : gene expression counts,
- `junc_ratio` : raw splicing/junction count data,
- `cell_by_junction_matrix` and `cell_by_cluster_matrix` as additional layers.

We can use the GE AnnData for gene expression and the ATSE AnnData for all splicing-related data.
(If needed, make copies so that modalities are independent.)


Option 1: Use the GE AnnData for RNA and the ATSE AnnData for splicing modalities.
(You can also combine or pre-process further if desired.)

In [None]:
mdata = mu.MuData({
    "rna": ge_anndata,
    "splicing": atse_anndata
})

# assert "library_size" in ge_anndata.obs, "'library_size' not found in ge_anndata.obs"
mdata.obsm["X_library_size"] = ge_anndata.obsm["X_library_size"]

# # Confirm it's stored correctly
# print("Library size moved to mdata.obsm['library_size'] with shape:", mdata.obsm["library_size"].shape)


# List of shared obs fields to pull up
shared_obs_keys = [
    'cell_id', 'age', 'cell_ontology_class', 'mouse.id', 'sex', 'tissue', 'dataset', 'broad_cell_type', 'cell_id_index', 'cell_name', 'modality'
]

# We'll assume 'rna' modality has them all and they match 'splicing'
for key in shared_obs_keys:
    assert key in mdata["rna"].obs, f"{key} not found in 'rna' obs"
    assert key in mdata["splicing"].obs, f"{key} not found in 'splicing' obs"
    assert (mdata["rna"].obs[key] == mdata["splicing"].obs[key]).all(), f"{key} values differ between modalities"
    mdata.obs[key] = mdata["rna"].obs[key]
    
print("MuData object created with modalities:", list(mdata.mod.keys()))

## 7. Write Out the Final MuData Object

The combined MuData object is now ready for use with `MULTIVISPLICE`. Save it as an H5MU file.

In [None]:
mdata.write(OUTPUT_MUDATA_PATH)
print(f"MuData object written to {OUTPUT_MUDATA_PATH}")

## 8. Verify the Output

Read the MuData object back in to ensure everything is correct.

In [None]:
mdata_loaded = mu.read_h5mu(OUTPUT_MUDATA_PATH)
print("Loaded MuData modalities:", list(mdata_loaded.mod.keys()))
print(mdata_loaded)