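### Barycentric projection

Project the SM (metabolite) features onto the ST (gene) spot grid. Each ST spot receives the average of the SM spots, weighted by its row of the coupling $P$:

$$\tilde{X}_{SM} = \mathrm{diag}(P\mathbf{1})^{-1} P\, X_{SM}$$

The projected features are stored next to the genes in a single AnnData on the ST grid.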

In [36]:
import scanpy as sc
import anndata as ad
import numpy as np
import scipy.sparse as sp
import pandas as pd

def to_dense_row(x):
    """Densify ``x``.

    SciPy sparse input is converted with ``.toarray()``; anything else goes
    through ``np.asarray``. Despite the name, this is not limited to single rows.
    """
    if not sp.issparse(x):
        return np.asarray(x)
    return x.toarray()

def project_sm_to_st(adata_st, adata_sm, P_csr, eps=1e-12):
    """Barycentrically project SM features onto the ST spots and join them with the ST genes.

    Each ST spot ``i`` receives the coupling-weighted average of the SM rows::

        X_sm_to_st[i] = sum_j P[i, j] * X_sm[j] / sum_j P[i, j]

    Parameters
    ----------
    adata_st : AnnData
        Base object; ``.X`` is (n_st, g). Its obs/obsm/obsp/uns are carried over.
    adata_sm : AnnData
        Source object; ``.X`` is (n_sm, m).
    P_csr : scipy.sparse matrix or numpy.ndarray, shape (n_st, n_sm)
        Coupling between ST spots (rows) and SM spots (columns). Despite the
        name, a dense array also works (the usage cell passes ``np.load(...)``).
    eps : float
        Substitute row sum for ST spots whose coupling row sums to 0. Such
        spots end up with an all-zero projected metabolite vector.

    Returns
    -------
    AnnData
        Joint object on the ST grid with ``X = [genes | projected metabolites]``.
        Var names are prefixed with ``g:`` / ``m:``, and var column ``type`` is
        "ST" or "SM". The unnormalized coupling is stored in
        ``obsm["coupling_to_SM"]``.

    Raises
    ------
    AssertionError
        If ``P_csr.shape != (adata_st.n_obs, adata_sm.n_obs)``.

    Notes
    -----
    ``.X`` of both inputs is used as-is. Apply any transform (e.g. log1p, or
    per-metabolite z-scoring) to the inputs before calling.
    """
    import copy

    expected = (adata_st.n_obs, adata_sm.n_obs)
    assert P_csr.shape == expected, (
        f"coupling has shape {P_csr.shape}, expected (n_ST, n_SM) = {expected}"
    )

    # Row-normalize P so each ST spot's weights sum to 1.
    # Cast to float: with an integer coupling, writing eps into an int array
    # would store 0, and a dense zero row would then become inf * 0 = NaN.
    row_sums = np.asarray(P_csr.sum(axis=1), dtype=float).ravel()
    row_sums[row_sums == 0] = eps  # avoid 1/0; zero rows stay zero after scaling
    # Sparse diag @ P: stays sparse for sparse P, becomes a dense ndarray for dense P.
    P_row = sp.diags(1.0 / row_sums) @ P_csr

    X_st = adata_st.X  # may be sparse
    # Barycentric projection SM -> ST: (n_st x n_sm) @ (n_sm x m) => (n_st x m)
    X_sm_to_st = P_row @ adata_sm.X

    # Concatenate features: genes first, then projected metabolites.
    if sp.issparse(X_st) or sp.issparse(X_sm_to_st):
        X_joint = sp.hstack([X_st, X_sm_to_st], format="csr")
    else:
        X_joint = np.hstack([to_dense_row(X_st), to_dense_row(X_sm_to_st)])

    # Prefix var names so gene and metabolite names cannot collide.
    var_st = adata_st.var.copy()
    var_st["type"] = "ST"
    var_st.index = ["g:" + str(name) for name in var_st.index]
    var_sm = adata_sm.var.copy()
    var_sm["type"] = "SM"
    var_sm.index = ["m:" + str(name) for name in var_sm.index]
    var_joint = pd.concat([var_st, var_sm])

    # New AnnData on the ST grid.
    adata_joint = ad.AnnData(
        X=X_joint,
        obs=adata_st.obs.copy(),
        var=var_joint,
        obsm=adata_st.obsm.copy(),
        obsp=adata_st.obsp.copy() if hasattr(adata_st, "obsp") else None,
        # Deep copy: a shallow copy would share nested dicts (e.g. uns["spatial"])
        # with adata_st, so edits to the joint object would leak back.
        uns=copy.deepcopy(dict(adata_st.uns)),
    )
    adata_joint.uns["joint_base"] = "ST"
    adata_joint.uns["features"] = {"genes_prefix": "g:", "metabolites_prefix": "m:"}
    # Unnormalized coupling, shape (n_ST, n_SM); stored by reference, not copied.
    adata_joint.obsm["coupling_to_SM"] = P_csr

    return adata_joint

In [37]:
# Usage: project SM onto the ST grid, save the joint object, and display it.
DATA_DIR = "/scratch/gpfs/BRAPHAEL/ST_SM"  # change here to point at another dataset
COUPLING_PATH = "coupling_matrix_1.npy"

adata_st = sc.read_h5ad(f"{DATA_DIR}/adata_ST_Y7_T_raw.h5ad")
adata_sm = sc.read_h5ad(f"{DATA_DIR}/adata_SM_Y7_T_raw.h5ad")
P = np.load(COUPLING_PATH)  # (n_ST, n_SM) coupling as a NumPy array
adata_joint = project_sm_to_st(adata_st, adata_sm, P)
# NOTE: the written file also contains obsm["coupling_to_SM"], i.e. the full coupling.
adata_joint.write_h5ad(f"{DATA_DIR}/joint_STbase_Y7.h5ad")
adata_joint

AnnData object with n_obs × n_vars = 2018 × 37900
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome', 'type', 'name'
    uns: 'spatial', 'joint_base', 'features'
    obsm: 'spatial', 'coupling_to_SM'

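### Sampling based on P

Instead of averaging, draw aligned (ST spot, SM spot) pairs whose sampling probabilities are derived from the coupling $P$, using `make_coupled_dataloader`.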

In [3]:
import scanpy as sc
import scipy.sparse as sp
import numpy as np
from mgw.dataloader import make_coupled_dataloader

In [6]:
from itertools import islice

DATA_DIR = "/scratch/gpfs/BRAPHAEL/ST_SM"  # change here to point at another dataset
COUPLING_PATH = "coupling_matrix_1.npy"
N_PREVIEW = 3  # number of sampled pairs to print

# 1) Load your raw AnnData objects
adata_st = sc.read_h5ad(f"{DATA_DIR}/adata_ST_Y7_T_raw.h5ad")
adata_sm = sc.read_h5ad(f"{DATA_DIR}/adata_SM_Y7_T_raw.h5ad")

# 2) Load/prepare your coupling matrix P (n_ST x n_SM)
P = np.load(COUPLING_PATH)
assert P.shape == (adata_st.n_obs, adata_sm.n_obs), (
    f"coupling has shape {P.shape}, expected {(adata_st.n_obs, adata_sm.n_obs)}"
)
assert P.sum() > 0, "coupling matrix has no mass to sample from"

# 3) Optional transforms (e.g., log1p for raw counts).
# Pass the ufunc itself instead of a lambda wrapper: same result, and it stays
# picklable if the loader ever uses worker processes.
st_tf = np.log1p
sm_tf = np.log1p

# 4) Build dataloader
loader = make_coupled_dataloader(
    adata_st,
    adata_sm,
    P,
    batch_size=1,                   # one aligned pair each iteration
    mode="per_row",                 # or "global"
    st_obsm_key=None,               # or "X_pca" etc.
    sm_obsm_key=None,
    st_transform=st_tf,
    sm_transform=sm_tf,
    temperature=0.7,                # <1 => emphasize highest probs
)

# 5) Preview the first N_PREVIEW sampled pairs.
# islice stops after N_PREVIEW batches, rather than drawing one more batch
# only to break out of the loop.
# NOTE(review): no seed is set here, so the printed pairs presumably change
# between runs. Seed the RNG the loader uses if reproducibility matters.
for step, (x_st, x_sm, i, j, p_ij) in enumerate(islice(loader, N_PREVIEW)):
    # x_st: [1, d_st], x_sm: [1, d_sm]
    # i, j: spot indices; p_ij: sampling prob (approx joint)
    print(f"Pair {step}: ST[{i.item()}] ↔ SM[{j.item()}], sample prob≈{p_ij.item():.3e}")
    print(x_st.shape, x_sm.shape, i, j, p_ij)

Pair 0: ST[268] ↔ SM[10545], sample prob≈1.555e-06
torch.Size([1, 36601]) torch.Size([1, 1299]) tensor([268]) tensor([10545]) tensor([1.5551e-06], dtype=torch.float64)
Pair 1: ST[1785] ↔ SM[2111], sample prob≈3.789e-06
torch.Size([1, 36601]) torch.Size([1, 1299]) tensor([1785]) tensor([2111]) tensor([3.7895e-06], dtype=torch.float64)
Pair 2: ST[559] ↔ SM[8015], sample prob≈1.891e-05
torch.Size([1, 36601]) torch.Size([1, 1299]) tensor([559]) tensor([8015]) tensor([1.8907e-05], dtype=torch.float64)
