# Quick and Dirty Regression between Gene Expression and Splicing Counts

In [1]:
import scanpy as sc
import mudata
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import spearmanr, pearsonr
import matplotlib.pyplot as plt

## Load and Extract Data

In [4]:
# Combined gene-expression + splicing MuData file. This is an absolute path on
# the cluster filesystem; edit it here when running the notebook elsewhere.
H5MU_PATH = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/SUBSETTOP5CELLSTYPES_aligned__ge_splice_combined_20250513_035938.h5mu"
mdata = mudata.read_h5mu(H5MU_PATH)

# astype() returns a freshly allocated array/matrix by default, so a preceding
# .copy() only duplicated each full matrix a second time. The layers stored in
# `mdata` are still left untouched.
rna_counts = mdata.mod['rna'].layers['raw_counts'].astype(np.float32)
splicing_counts = mdata.mod['splicing'].layers['cell_by_junction_matrix'].astype(np.float32)

# Densify if needed: anything that is not already an ndarray is assumed to
# expose .toarray() (e.g. a scipy sparse matrix).
if not isinstance(rna_counts, np.ndarray):
    rna_counts = rna_counts.toarray()
if not isinstance(splicing_counts, np.ndarray):
    splicing_counts = splicing_counts.toarray()

# The two matrices are paired sample-by-sample downstream, so fail early with a
# readable message if their row counts disagree.
assert rna_counts.shape[0] == splicing_counts.shape[0], (
    f"row mismatch: rna has {rna_counts.shape[0]} rows, "
    f"splicing has {splicing_counts.shape[0]} rows"
)


  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


In [12]:
def make_loader(X, Y, batch_size=256, shuffle=True):
    """Wrap paired feature/target arrays in a mini-batch ``DataLoader``.

    Args:
        X: array-like of inputs; row ``i`` is paired with row ``i`` of ``Y``.
        Y: array-like of targets with the same number of rows as ``X``.
        batch_size: number of rows per batch.
        shuffle: reshuffle rows every epoch (default ``True``, the previous
            fixed behavior). Pass ``False`` for a deterministic, ordered pass.

    Returns:
        A ``DataLoader`` yielding ``(x_batch, y_batch)`` tensor pairs.

    Note:
        NumPy inputs are wrapped without copying, so the dataset tensors share
        memory with ``X`` and ``Y``; later in-place edits of those arrays are
        visible through the loader.
    """
    # torch.tensor() always copies its input; as_tensor() reuses the NumPy
    # buffer, which avoids holding a second full copy of large count matrices.
    dataset = TensorDataset(torch.as_tensor(X), torch.as_tensor(Y))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

## Define Model

In [13]:
class SimpleRegressor(nn.Module):
    """Fully connected regressor: two ReLU hidden layers and a linear output.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of regression targets per sample.
        hidden_dim: width of both hidden layers. Defaults to 512, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512):
        super().__init__()
        # Layer order and the attribute name `model` are kept as-is so
        # state_dict keys (model.0.weight, ...) remain the same.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.model(x)


## Train Loop

In [None]:
from tqdm import tqdm
import torch.nn as nn
import torch

def train(model, loader, epochs=1, lr=1e-3):
    """Fit ``model`` on the batches from ``loader`` using Adam and an MSE loss.

    Args:
        model: ``nn.Module`` mapping an input batch to a prediction batch.
        loader: iterable yielding ``(x, y)`` tensor pairs.
        epochs: number of full passes over ``loader``.
        lr: learning rate passed to Adam.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # evaluate() switches the model to eval mode and never switches it back;
    # make sure any (re)training after that runs in training mode.
    model.train()
    # Batches are sent to wherever the model's parameters live, so the loop
    # also works when the model has been moved off the CPU.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, y in loop:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Show the most recent batch loss on the progress bar.
            loop.set_postfix(loss=loss.item())


### RNA → Splicing

In [15]:
# Fit a network that predicts splicing (junction) counts from RNA counts.
# Input width = number of RNA columns, output width = number of junction columns.
model_rna2splicing = SimpleRegressor(
    input_dim=rna_counts.shape[1],
    output_dim=splicing_counts.shape[1],
)
loader = make_loader(X=rna_counts, Y=splicing_counts)
train(model=model_rna2splicing, loader=loader)

Epoch 1/10:  23%|██▎       | 74/326 [00:39<02:15,  1.86it/s, loss=5.97e+3]


KeyboardInterrupt: 

### Splicing → RNA

In [None]:
# Fit the reverse direction: predict RNA counts from splicing (junction) counts.
# Input width = number of junction columns, output width = number of RNA columns.
model_splicing2rna = SimpleRegressor(
    input_dim=splicing_counts.shape[1],
    output_dim=rna_counts.shape[1],
)
loader = make_loader(X=splicing_counts, Y=rna_counts)
train(model=model_splicing2rna, loader=loader)

## Evaluation

In [None]:
def evaluate(model, X, Y_true):
    """Run ``model`` on ``X``, print agreement metrics against ``Y_true``.

    The model is switched to eval mode and all of ``X`` goes through it in a
    single forward pass. Predictions are floored at zero before scoring.
    Pearson and Spearman correlations are taken over the flattened matrices;
    R2 and MSE are computed by scikit-learn on the unflattened arrays.

    Args:
        model: trained ``nn.Module``.
        X: input array accepted by ``torch.tensor``.
        Y_true: array of ground-truth targets.

    Returns:
        ``(true, pred)``: ``Y_true`` unchanged and the clipped predictions.
    """
    model.eval()
    inputs = torch.tensor(X)
    with torch.no_grad():
        raw_output = model(inputs).numpy()

    true = Y_true
    # prevent negative predictions
    pred = raw_output.clip(min=0)

    # Flatten once and reuse for both correlation measures.
    flat_true = true.flatten()
    flat_pred = pred.flatten()
    pearson = pearsonr(flat_true, flat_pred)[0]
    spearman = spearmanr(flat_true, flat_pred)[0]
    r2 = r2_score(true, pred)
    mse = mean_squared_error(true, pred)

    print(f"Pearson R: {pearson:.3f}")
    print(f"Spearman R: {spearman:.3f}")
    print(f"R2 Score: {r2:.3f}")
    print(f"MSE: {mse:.3f}")
    return true, pred


In [None]:
# RNA → Splicing evaluation: score the fitted model, then scatter predicted
# against true values.
# NOTE: these are the same matrices the model was trained on, so the reported
# metrics are in-sample.
true, pred = evaluate(model_rna2splicing, rna_counts, splicing_counts)

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(true.flatten(), pred.flatten(), s=1, alpha=0.2)
ax.set_title("RNA → Splicing")
ax.set_xlabel("True")
ax.set_ylabel("Predicted")
plt.show()

In [None]:
# Splicing → RNA evaluation: score the fitted model, then scatter predicted
# against true values.
# NOTE: these are the same matrices the model was trained on, so the reported
# metrics are in-sample.
true, pred = evaluate(model_splicing2rna, splicing_counts, rna_counts)

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(true.flatten(), pred.flatten(), s=1, alpha=0.2)
ax.set_title("Splicing → RNA")
ax.set_xlabel("True")
ax.set_ylabel("Predicted")
plt.show()