In [3]:
import scanpy as sc
import numpy as np
import scipy.sparse as sp
from sklearn.neighbors import NearestNeighbors
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphConv
from torch_geometric.data import Data
import torch.optim as optim
from sklearn.cluster import KMeans

In [4]:
file = "BrainAgingSpatialAtlas_MERFISH.h5ad"

adata = sc.read_h5ad(file)
print(adata)

AnnData object with n_obs × n_vars = 378918 × 374
    obs: 'fov', 'center_x', 'center_y', 'min_x', 'max_x', 'min_y', 'max_y', 'age', 'clust_annot', 'slice', 'sex_ontology_term_id', 'suspension_type', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'tissue_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'donor_id', 'is_primary_data', 'cell_type_annot', 'tissue_type', 'cell_type', 'assay', 'disease', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type'
    uns: 'batch_condition', 'citation', 'organism', 'organism_ontology_term_id', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_pca', 'X_spatial_coords', 'X_umap', 'spatial_coords'


In [5]:
sc.pp.filter_cells(adata, min_genes=5)
sc.pp.filter_genes(adata, min_cells=5)

sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, n_comps=50)

In [6]:
# ========== Node features ==========
X_pca = adata.obsm['X_pca']             # (N_cells, 50 PCs)
X = torch.tensor(X_pca, dtype=torch.float)

# ========== Spatial coordinates ==========
if 'X_spatial_coords' in adata.obsm_keys():
    coords = adata.obsm['X_spatial_coords']
elif 'spatial_coords' in adata.obsm_keys():
    coords = adata.obsm['spatial_coords']
else:
    coords = adata.obs[['center_x', 'center_y']].to_numpy()

coords = coords.astype(float)

# ========== Domain labels ==========
domains = adata.obs['clust_annot'].astype('category')
domain_idx = domains.cat.codes.to_numpy()      # integer labels
n_domains = len(domains.cat.categories)

print("Nodes:", X.shape)
print("Domains:", n_domains)

Nodes: torch.Size([378918, 50])
Domains: 43


In [7]:
# ----- Build spatial kNN graph -----
k = 8
n_cells = coords.shape[0]

nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree')
nbrs.fit(coords)
distances, indices = nbrs.kneighbors(coords)

rows = []
cols = []
weights = []

for i in range(n_cells):
    for j, d in zip(indices[i, 1:], distances[i, 1:]):
        rows.append(i)
        cols.append(j)
        weights.append(np.exp(-d))

# Build sparse adjacency in COO format
A_spatial = sp.coo_matrix((weights, (rows, cols)), shape=(n_cells, n_cells))

# Symmetrize adjacency (may convert format internally)
A_spatial = A_spatial.maximum(A_spatial.T)

# ⭐ Convert back to COO explicitly (fix your error here!)
A_spatial = A_spatial.tocoo()

# Now you can safely access row, col
edge_index = np.vstack([A_spatial.row, A_spatial.col])
edge_index = torch.tensor(edge_index, dtype=torch.long)

edge_attr = torch.tensor(A_spatial.data, dtype=torch.float)

print("Spatial edges:", edge_index.shape[1])

Spatial edges: 3470060


In [8]:
# Map domain names to numeric indices
domain_to_idx = {d: i for i, d in enumerate(domains.cat.categories)}
domain_per_cell = domain_idx

# Initialize domain adjacency counts
DD_counts = np.zeros((n_domains, n_domains), dtype=float)

rows = A_spatial.row
cols = A_spatial.col

for i_cell, j_cell in zip(rows, cols):
    d_i = domain_per_cell[i_cell]
    d_j = domain_per_cell[j_cell]
    DD_counts[d_i, d_j] += 1
    DD_counts[d_j, d_i] += 1

eps = 1e-8
DD_phys = DD_counts / (DD_counts.sum(axis=1, keepdims=True) + eps)

# Compute domain mean vectors in PCA space
domain_mean = np.zeros((n_domains, X_pca.shape[1]))
domain_counts = np.zeros(n_domains)

for cell, d in enumerate(domain_per_cell):
    domain_mean[d] += X_pca[cell]
    domain_counts[d] += 1

domain_counts[domain_counts == 0] = 1
domain_mean = domain_mean / domain_counts[:, None]

from numpy.linalg import norm

DD_sem = np.zeros((n_domains, n_domains))

for i in range(n_domains):
    for j in range(n_domains):
        vi, vj = domain_mean[i], domain_mean[j]
        DD_sem[i, j] = np.dot(vi, vj) / (norm(vi)*norm(vj) + eps)

alpha = 0.5   # weight between physical and semantic
DD_combined = alpha * DD_phys + (1 - alpha) * DD_sem

# threshold weak edges
thr = 0.1
DD_mask = (DD_combined > thr).astype(float)
DD_final = DD_combined * DD_mask

DD_rows, DD_cols = np.where(DD_final > 0)
DD_weights = DD_final[DD_rows, DD_cols]

DD_edge_index = torch.tensor(np.vstack([DD_rows, DD_cols]), dtype=torch.long)
DD_edge_attr  = torch.tensor(DD_weights, dtype=torch.float)

print("DDG edges:", DD_edge_index.shape[1])

from torch_geometric.data import Data

data = Data(
    x=X,                        # PCA features
    edge_index=edge_index,      # spatial graph edges
    edge_attr=edge_attr         # spatial edge weights
)

# Add domain info + DDG
data.y_domain = torch.tensor(domain_per_cell, dtype=torch.long)
data.n_domains = n_domains
data.DD_edge_index = DD_edge_index
data.DD_edge_attr  = DD_edge_attr

print(data)

DDG edges: 483
Data(x=[378918, 50], edge_index=[2, 3470060], edge_attr=[3470060], y_domain=[378918], n_domains=43, DD_edge_index=[2, 483], DD_edge_attr=[483])


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphConv  # you can switch to GCNConv or GATConv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ============================
# 1. Encoder: MLP + GNN → μ, logσ
# ============================

class DRASIEncoder(nn.Module):
    def __init__(self, in_dim, hidden_mlp=128, hidden_gnn=128, latent_dim=32):
        super().__init__()
        # MLP to compress input features (PCA) before GNN
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_mlp),
            nn.ReLU(),
            nn.Linear(hidden_mlp, hidden_gnn),
            nn.ReLU()
        )
        # GNN layers over spatial graph
        self.gnn1 = GraphConv(hidden_gnn, hidden_gnn)
        self.gnn2 = GraphConv(hidden_gnn, hidden_gnn)

        # Heads for VAE parameters
        self.mu_head     = nn.Linear(hidden_gnn, latent_dim)
        self.logvar_head = nn.Linear(hidden_gnn, latent_dim)

    def forward(self, x, edge_index):
        # x: (N, in_dim), edge_index: (2, E)
        h = self.mlp(x)
        h = F.relu(self.gnn1(h, edge_index))
        h = F.relu(self.gnn2(h, edge_index))
        mu     = self.mu_head(h)
        logvar = self.logvar_head(h)
        return mu, logvar


# =======================
# 2. Reparameterization
# =======================

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# =======================
# 3. Decoder: z → x̂
# =======================

class DRASIDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dec=128, out_dim=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dec),
            nn.ReLU(),
            nn.Linear(hidden_dec, out_dim)
        )

    def forward(self, z):
        return self.net(z)


# =======================
# 4. Full DRA-SI Model
# =======================

class DRASIModel(nn.Module):
    def __init__(self, in_dim, n_domains, latent_dim=32):
        super().__init__()
        self.encoder = DRASIEncoder(in_dim, latent_dim=latent_dim)
        self.decoder = DRASIDecoder(latent_dim, out_dim=in_dim)
        self.n_domains = n_domains

    def forward(self, data):
        # data.x, data.edge_index
        x = data.x
        edge_index = data.edge_index

        mu, logvar = self.encoder(x, edge_index)
        z = reparameterize(mu, logvar)
        x_recon = self.decoder(z)

        return x_recon, mu, logvar, z


# ======================================
# 5. Loss components: Recon + KL + DDG
# ======================================

def loss_recon(x, x_recon):
    # MSE between reconstructed and original PCA features
    return F.mse_loss(x_recon, x)

def loss_kl(mu, logvar):
    # KL divergence between q(z|x) and N(0, I)
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

def compute_domain_means(z, domain_idx, n_domains):
    """
    z: (N, latent_dim)
    domain_idx: (N,) int
    Returns: (n_domains, latent_dim) mean embeddings.
    """
    latent_dim = z.size(1)
    device = z.device

    sums = torch.zeros(n_domains, latent_dim, device=device)
    counts = torch.zeros(n_domains, device=device)

    # accumulate sums & counts
    sums.index_add_(0, domain_idx, z)
    ones = torch.ones_like(domain_idx, dtype=torch.float, device=device)
    counts.index_add_(0, domain_idx, ones)

    counts = counts.clamp(min=1.0)
    means = sums / counts.unsqueeze(1)
    return means

def loss_ddg(z, domain_idx, DD_edge_index, DD_edge_attr, n_domains, lambda_ddg=1.0):
    """
    z: (N, latent_dim)
    domain_idx: (N,) int
    DD_edge_index: (2, E_D)
    DD_edge_attr: (E_D,)
    """
    device = z.device
    DD_edge_index = DD_edge_index.to(device)
    DD_edge_attr  = DD_edge_attr.to(device)
    domain_idx    = domain_idx.to(device)

    domain_means = compute_domain_means(z, domain_idx, n_domains)  # (D, latent_dim)

    d_i = DD_edge_index[0]  # (E_D,)
    d_j = DD_edge_index[1]

    mu_i = domain_means[d_i]   # (E_D, latent_dim)
    mu_j = domain_means[d_j]

    diff = mu_i - mu_j
    dist_sq = (diff * diff).sum(dim=1)   # (E_D,)

    weighted = DD_edge_attr * dist_sq
    return lambda_ddg * weighted.mean()


# ======================
# 6. Total loss wrapper
# ======================

def total_loss(data, model, lambda_kl=1e-3, lambda_ddg=0.1):
    """
    data: PyG Data
    model: DRASIModel
    """
    x = data.x
    y_domain = data.y_domain
    DD_edge_index = data.DD_edge_index
    DD_edge_attr  = data.DD_edge_attr
    n_domains = data.n_domains

    x_recon, mu, logvar, z = model(data)

    L_rec = loss_recon(x, x_recon)
    L_kl  = loss_kl(mu, logvar)
    L_DDG = loss_ddg(z, y_domain, DD_edge_index, DD_edge_attr, n_domains)

    L = L_rec + lambda_kl * L_kl + L_DDG
    logs = {
        "loss": L.item(),
        "rec": L_rec.item(),
        "kl": L_kl.item(),
        "ddg": L_DDG.item()
    }
    return L, logs


Using device: cpu


In [10]:
import torch.optim as optim

in_dim = data.x.size(1)
n_domains = int(data.n_domains)

model = DRASIModel(in_dim=in_dim, n_domains=n_domains, latent_dim=32).to(device)
data = data.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 100  # you can increase to 200–300 once things work

for epoch in range(1, n_epochs + 1):
    model.train()
    optimizer.zero_grad()

    L, logs = total_loss(data, model, lambda_kl=1e-3, lambda_ddg=0.1)
    L.backward()
    optimizer.step()

    if epoch % 10 == 0 or epoch == 1:
        print(
            f"Epoch {epoch:03d} | "
            f"loss={logs['loss']:.4f} | "
            f"rec={logs['rec']:.4f} | "
            f"kl={logs['kl']:.4f} | "
            f"ddg={logs['ddg']:.4f}"
        )

# After training, get latent embeddings
model.eval()
with torch.no_grad():
    _, _, _, z = model(data)   # z: (N_cells, latent_dim)

z_cpu = z.cpu().numpy()
print("Latent embedding shape:", z_cpu.shape)


Epoch 001 | loss=2.8592 | rec=2.7212 | kl=3.3570 | ddg=0.1346
Epoch 010 | loss=2.2824 | rec=2.2697 | kl=0.3398 | ddg=0.0123
Epoch 020 | loss=2.2579 | rec=2.2506 | kl=0.9907 | ddg=0.0063
Epoch 030 | loss=2.2357 | rec=2.2245 | kl=2.0832 | ddg=0.0091
Epoch 040 | loss=2.1967 | rec=2.1788 | kl=3.6401 | ddg=0.0143
Epoch 050 | loss=2.1106 | rec=2.0906 | kl=4.6898 | ddg=0.0153
Epoch 060 | loss=2.0061 | rec=1.9806 | kl=5.2722 | ddg=0.0202
Epoch 070 | loss=1.9298 | rec=1.9014 | kl=4.1097 | ddg=0.0243
Epoch 080 | loss=1.8682 | rec=1.8374 | kl=4.3752 | ddg=0.0264
Epoch 090 | loss=1.8465 | rec=1.8107 | kl=4.0641 | ddg=0.0317
Epoch 100 | loss=1.7925 | rec=1.7509 | kl=4.2492 | ddg=0.0374
Latent embedding shape: (378918, 32)


In [11]:
adata.obsm['X_drasi'] = z_cpu
print(adata.obsm['X_drasi'].shape)

(378918, 32)


In [17]:
import scib_metrics

from scib_metrics import silhouette_batch  # ✅ top-level, not from .benchmark

X = adata.obsm["X_drasi"]                     # (n_cells, n_features)
labels = adata.obs["cell_type"].to_numpy()    # or whatever your cell-type column is
batch = adata.obs["slice"].to_numpy()         # your batch key

bASW = silhouette_batch(
    X=X,
    labels=labels,
    batch=batch,
    rescale=True,               # keep [0,1], higher is better
    metric="euclidean",
    between_cluster_distances="nearest",  # standard ASW, not BRAS
)

print("bASW:", bASW)


bASW: 0.98426855
