In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import pandas as pd
import xgboost as xgb
import os

In [2]:
# Load the per-cell table from a path relative to the notebook.
# Later cells read its 'x', 'y', 'class' and 'cell_label' columns.
allexp = pd.read_csv('../data/atlas_allexp.csv')

  allexp = pd.read_csv('../data/atlas_allexp.csv')


## Helper Functions

In [3]:
# Shared parameters. The KDE / PCC helper functions defined below read these
# as module-level globals (bw and min_pts are captured as default arguments
# at function-definition time, so re-run the defining cell after changing them).
win_size = 2.0        # sliding window side length (in your coordinate units)
grid_n   = 25           # grid resolution inside each window (grid_n x grid_n)
bw       = 0.2          # KDE bandwidth for scipy gaussian_kde (float or 'scott'/'silverman')
min_pts  = 5            # minimum points to fit a KDE for a class
weight_mode = 'sum'  # 'sum' -> w_i = a_i + b_i ; 'prod' -> w_i = a_i * b_i
eps        = 1e-12   # numerical floor to avoid 0-division

In [4]:
from itertools import combinations
from scipy.stats import gaussian_kde, pearsonr

def build_global_kdes(df, bw_method=bw, min_points=min_pts):
    """
    Fit one 2-D gaussian_kde per value of df['class'], using every (x, y)
    point of that class (global, not per-window, KDEs).

    Classes with fewer than `min_points` rows are omitted from the result.
    Returns: dict[class] -> fitted KDE
    """
    return {
        label: gaussian_kde(group[['x', 'y']].to_numpy().T, bw_method=bw_method)
        for label, group in df.groupby('class')
        if group[['x', 'y']].to_numpy().shape[0] >= min_points
    }

def sliding_windows(xmin, xmax, ymin, ymax, size):
    """
    Generate square windows (x0, x1, y0, y1) of side `size` with stride size/2
    over [xmin,xmax] x [ymin,ymax], iterating x in the outer loop.

    Along each axis, if the regular stride does not end flush with the upper
    bound, one extra window starting at max(lower, upper - size) is added.
    When the extent is smaller than `size`, a single window starting at the
    lower bound is produced (it then extends past the upper bound).
    """
    half_step = size / 2.0

    def _axis_starts(lower, upper):
        origins = []
        pos = lower
        # small tolerance so a window ending exactly on the bound is kept
        while pos + size <= upper + 1e-9:
            origins.append(pos)
            pos += half_step
        uncovered_tail = (len(origins) == 0) or (origins[-1] + size < upper)
        if uncovered_tail:
            origins.append(max(lower, upper - size))
        return sorted(set(origins))

    x_starts = _axis_starts(xmin, xmax)
    y_starts = _axis_starts(ymin, ymax)

    for left in x_starts:
        for bottom in y_starts:
            yield (left, left + size, bottom, bottom + size)

def grid_cell_centers(x0, x1, y0, y1, n):
    """
    Split the window [x0,x1]x[y0,y1] into n x n equal cells.

    Returns (GX, GY, hx, hy): GX/GY are the 'xy'-indexed meshes of cell-center
    coordinates, hx/hy are the cell widths along x and y.
    """
    hx = (x1 - x0) / n
    hy = (y1 - y0) / n
    half_offsets = np.arange(n) + 0.5
    centers_x = x0 + hx * half_offsets
    centers_y = y0 + hy * half_offsets
    GX, GY = np.meshgrid(centers_x, centers_y, indexing='xy')
    return GX, GY, hx, hy

def windowwise_normalize(Z, hx, hy):
    """
    Rescale a KDE patch so that its discrete integral over the window equals 1.

    The integral is approximated as Z.sum() * hx * hy. If that mass is not
    strictly positive, Z is returned unchanged.
    """
    total_mass = Z.sum() * hx * hy
    if total_mass > 0:
        return Z / total_mass
    return Z

def compute_pairwise_pcc_map(df):
    """
    Main driver for the unweighted Pearson correlation map.

    Steps:
      - build global KDEs per class (build_global_kdes)
      - iterate sliding windows of side `win_size` over the data bounds
      - in each window, evaluate every class KDE ONCE on the grid_n x grid_n
        grid of cell centers and normalize it within the window
      - for each class pair, compute Pearson r / p-value across grid cells

    Reads the module-level globals `win_size` and `grid_n`.

    Returns: DataFrame with one row per (window, class pair) and columns
    x0, x1, y0, y1, class_A, class_B, pearson_r, p_value, grid_n.
    pearson_r / p_value are NaN when either patch is (numerically) constant.
    """
    # Bounds for windows
    xmin, xmax = df['x'].min(), df['x'].max()
    ymin, ymax = df['y'].min(), df['y'].max()

    # Global KDEs; pairs are drawn from fitted classes only, so every class
    # in a pair is guaranteed to have a KDE.
    kdes = build_global_kdes(df)
    classes = sorted(kdes.keys())
    pairs = list(combinations(classes, 2))

    records = []
    if not pairs:
        # Fewer than two fitted classes: nothing to correlate.
        return pd.DataFrame.from_records(records)

    for (x0, x1, y0, y1) in sliding_windows(xmin, xmax, ymin, ymax, win_size):
        # grid of cell centers for this window
        GX, GY, hx, hy = grid_cell_centers(x0, x1, y0, y1, grid_n)
        XY = np.vstack([GX.ravel(), GY.ravel()])

        # KDE evaluation is the expensive step: do it once per class per
        # window instead of once per pair member.
        flat_patch = {}
        is_constant = {}
        for ctype in classes:
            Z = kdes[ctype](XY).reshape(GX.shape)
            Z = windowwise_normalize(Z, hx, hy)
            v = Z.ravel()
            flat_patch[ctype] = v
            # pearsonr is undefined for a constant map
            is_constant[ctype] = np.allclose(v, v[0])

        for A, B in pairs:
            if is_constant[A] or is_constant[B]:
                r = np.nan
                p = np.nan
            else:
                r, p = pearsonr(flat_patch[A], flat_patch[B])

            records.append({
                'x0': x0, 'x1': x1, 'y0': y0, 'y1': y1,
                'class_A': A, 'class_B': B,
                'pearson_r': r, 'p_value': p,
                'grid_n': grid_n
            })

    return pd.DataFrame.from_records(records)

# Unweighted PCC

## PCC Calculation

In [None]:
# Unweighted PCC map over all windows and class pairs.
# Note: the name `results_df` is reassigned by the weighted-PCC cell further down.
results_df = compute_pairwise_pcc_map(allexp)
results_df.head()

# Weighted PCC

## Helper Functions

In [9]:
from itertools import combinations
from scipy.stats import gaussian_kde, pearsonr

def build_global_kdes(df, bw_method=bw, min_points=min_pts):
    """
    Fit a global gaussian_kde per class from all of that class's (x, y) points.

    Classes with fewer than `min_points` rows get no entry.
    NOTE: this redefines the identically named helper from the unweighted section.
    """
    fitted = {}
    for label, group in df.groupby('class'):
        coords = group[['x', 'y']].to_numpy()
        if coords.shape[0] < min_points:
            continue
        fitted[label] = gaussian_kde(coords.T, bw_method=bw_method)
    return fitted

def sliding_windows(xmin, xmax, ymin, ymax, size):
    """
    Yield (x0, x1, y0, y1) windows of side `size` with stride size/2.

    Per axis, an extra window starting at max(lo, hi - size) is added when the
    regular stride leaves the upper edge uncovered (or yields no window at all).
    NOTE: this redefines the identically named helper from the unweighted section.
    """
    stride = size / 2.0

    def _origins(lo, hi):
        found = []
        pos = lo
        while pos + size <= hi + 1e-9:
            found.append(pos)
            pos += stride
        # ensure coverage to the edge
        if not found or found[-1] + size < hi:
            found.append(max(lo, hi - size))
        return sorted(set(found))

    x_origins = _origins(xmin, xmax)
    y_origins = _origins(ymin, ymax)
    yield from (
        (x0, x0 + size, y0, y0 + size)
        for x0 in x_origins
        for y0 in y_origins
    )

def grid_cell_centers(x0, x1, y0, y1, n):
    """
    Return (GX, GY, hx, hy) for an n x n grid inside the window:
    'xy'-indexed meshes of cell centers plus the cell sizes along x and y.
    NOTE: this redefines the identically named helper from the unweighted section.
    """
    hx, hy = (x1 - x0) / n, (y1 - y0) / n
    mids = np.arange(n) + 0.5
    GX, GY = np.meshgrid(x0 + hx * mids, y0 + hy * mids, indexing='xy')
    return GX, GY, hx, hy

def windowwise_normalize(Z, hx, hy):
    """
    Divide a KDE patch by its discrete window integral (Z.sum() * hx * hy) so
    it integrates to 1; a patch with non-positive mass is returned unchanged.
    NOTE: this redefines the identically named helper from the unweighted section.
    """
    integral = Z.sum() * hx * hy
    if not (integral > 0):
        return Z
    return Z / integral

def weighted_pearson(a, b, mode='sum', eps=1e-12):
    """
    Weighted Pearson correlation of two same-shaped arrays (flattened first).

    Weights are derived from the data: w = a * b when mode == 'prod',
    otherwise (any other mode value) w = a + b. Weights are normalized to
    sum to 1 before the weighted means, variances and covariance are taken.

    Returns np.nan when the total weight is non-finite or <= eps, or when
    either weighted variance is <= eps.
    """
    x = np.asarray(a).ravel()
    y = np.asarray(b).ravel()

    raw_weights = x * y if mode == 'prod' else x + y

    total = raw_weights.sum()
    if (not np.isfinite(total)) or total <= eps:
        return np.nan

    # probability weights
    p = raw_weights / total

    dev_x = x - (p * x).sum()
    dev_y = y - (p * y).sum()

    var_x = (p * dev_x * dev_x).sum()
    var_y = (p * dev_y * dev_y).sum()
    if var_x <= eps or var_y <= eps:
        return np.nan

    covariance = (p * dev_x * dev_y).sum()
    return covariance / np.sqrt(var_x * var_y)

def compute_pairwise_weighted_pcc(df):
    """
    Weighted Pearson correlation map between class density patches.

    - Fit global KDEs per class (build_global_kdes).
    - Slide windows of side `win_size`; make a grid_n x grid_n grid of centers.
    - Evaluate every class KDE ONCE per window and window-wise normalize it.
    - Compute the weighted Pearson r (weighted_pearson, using the globals
      `weight_mode` and `eps`) for each class pair in each window.

    Returns: DataFrame with one row per (window, class pair) and columns
    x0, x1, y0, y1, class_A, class_B, weighted_pearson_r, grid_n, win_size,
    weight_mode.
    """
    xmin, xmax = df['x'].min(), df['x'].max()
    ymin, ymax = df['y'].min(), df['y'].max()

    # Pairs are drawn from fitted classes only, so both KDEs always exist.
    kdes = build_global_kdes(df)
    classes = sorted(kdes.keys())
    pairs = list(combinations(classes, 2))

    records = []
    if not pairs:
        # Fewer than two fitted classes: nothing to correlate.
        return pd.DataFrame.from_records(records)

    for (x0, x1, y0, y1) in sliding_windows(xmin, xmax, ymin, ymax, win_size):
        GX, GY, hx, hy = grid_cell_centers(x0, x1, y0, y1, grid_n)
        XY = np.vstack([GX.ravel(), GY.ravel()])

        # KDE evaluation dominates the cost; cache one normalized patch per
        # class for this window rather than re-evaluating it for every pair.
        patches = {
            ctype: windowwise_normalize(kdes[ctype](XY).reshape(GX.shape), hx, hy)
            for ctype in classes
        }

        for A, B in pairs:
            # Weighted PCC so low-signal cells don't dominate
            r_w = weighted_pearson(patches[A], patches[B], mode=weight_mode, eps=eps)

            records.append({
                'x0': x0, 'x1': x1, 'y0': y0, 'y1': y1,
                'class_A': A, 'class_B': B,
                'weighted_pearson_r': r_w,
                'grid_n': grid_n, 'win_size': win_size,
                'weight_mode': weight_mode
            })

    return pd.DataFrame.from_records(records)


## PCC Calculation

In [None]:
# Weighted PCC map; replaces any earlier `results_df` (e.g. the unweighted one).
results_df = compute_pairwise_weighted_pcc(allexp)
results_df.head()

In [None]:
# Persist the weighted PCC table so the island steps can be re-run without recomputing it.
results_df.to_csv('../data/pairwise_weighted_pcc_map.csv', index=False)

## Optional: Load Previously Computed Values

In [4]:
# Reload the weighted PCC table written by the save cell above (skips the PCC computation).
results_df = pd.read_csv('../data/pairwise_weighted_pcc_map.csv')

# Island Formation

In [5]:
import numpy as np
import pandas as pd
from scipy import ndimage

# -------------------------------
# Helpers
# -------------------------------

def fisher_z(r, eps=1e-12):
    """
    Fisher z-transform (arctanh) with safe clipping.

    Parameters
    ----------
    r : float or array-like
        Correlation value(s). Non-finite entries become NaN.
    eps : float
        r is clipped to [-1 + eps, 1 - eps] so that |r| == 1 stays finite.

    Returns
    -------
    float for scalar input, ndarray of the same shape for array input.
    Any non-finite result is reported as NaN.
    """
    r = np.asarray(r, dtype=float)
    r = np.where(np.isfinite(r), r, np.nan)
    r = np.clip(r, -1 + eps, 1 - eps)
    z = np.arctanh(r)
    # Element-wise guard: a plain `if np.isinf(z)` is ambiguous for arrays.
    z = np.where(np.isfinite(z), z, np.nan)
    # Preserve the scalar-in / scalar-out behaviour callers rely on.
    return z if z.ndim else float(z)

def _build_pair_grids(results_df, A, B, r_col='weighted_pearson_r'):
    """
    From a long-form results_df -> 2D grids aligned on unique (x0,y0).
    Returns dict with r_grid, z_grid, and x0/x1/y0/y1 grids.
    """
    sub = results_df[(results_df['class_A'] == A) & (results_df['class_B'] == B)].copy()
    if sub.empty:
        return None

    # Ensure uniqueness if code was run multiple times
    sub = sub.drop_duplicates(subset=['x0','y0','x1','y1','class_A','class_B'])

    xs = np.array(sorted(sub['x0'].unique()))
    ys = np.array(sorted(sub['y0'].unique()))

    ix = {x0:i for i,x0 in enumerate(xs)}
    iy = {y0:i for i,y0 in enumerate(ys)}

    H, W = len(ys), len(xs)
    r_grid  = np.full((H, W), np.nan, dtype=float)
    z_grid  = np.full((H, W), np.nan, dtype=float)
    x0_grid = np.full((H, W), np.nan, dtype=float)
    x1_grid = np.full((H, W), np.nan, dtype=float)
    y0_grid = np.full((H, W), np.nan, dtype=float)
    y1_grid = np.full((H, W), np.nan, dtype=float)

    for _, row in sub.iterrows():
        i = iy[row['y0']]
        j = ix[row['x0']]
        r = row.get(r_col, np.nan)
        r_grid[i, j] = r
        z_grid[i, j] = fisher_z(r)
        x0_grid[i, j] = row['x0']; x1_grid[i, j] = row['x1']
        y0_grid[i, j] = row['y0']; y1_grid[i, j] = row['y1']

    return {
        'xs': xs, 'ys': ys,
        'r_grid': r_grid, 'z_grid': z_grid,
        'x0_grid': x0_grid, 'x1_grid': x1_grid,
        'y0_grid': y0_grid, 'y1_grid': y1_grid
    }

def _label_8_connected(mask):
    """8-connected component labeling (NaNs already excluded in mask)."""
    structure = np.ones((3,3), dtype=int)   # 8-connectivity
    labels, n = ndimage.label(mask, structure=structure)
    return labels, n

# -------------------------------
# Main: find islands for all pairs
# -------------------------------

def find_islands_for_all_pairs(
    results_df,
    r_threshold=0.5,           # threshold in r; internally converted to z
    r_col='weighted_pearson_r',
    min_windows=1              # drop tiny islands if desired
):
    """
    Build 8-connected 'islands' of above-threshold windows per (class_A, class_B).

    Parameters
    ----------
    results_df : long-form table with x0, x1, y0, y1, class_A, class_B and `r_col`.
    r_threshold : a window belongs to an island if its r is finite and > r_threshold.
    r_col : name of the correlation column to threshold.
    min_windows : islands with fewer member windows are dropped.

    Returns
    -------
    islands : list of dicts (pair, label, n_windows, r/z summaries, bbox,
              member window rects, grid axes, placeholder fields).
    island_index : DataFrame with one row per island, sorted by pair and then by
              descending cluster mass / size. When no island is found this is an
              empty DataFrame that still carries the expected columns.
    """
    islands = []
    index_rows = []
    # Fixed column layout so the index table is well-formed even when empty.
    index_columns = [
        'class_A', 'class_B', 'label', 'n_windows',
        'bbox_x0', 'bbox_x1', 'bbox_y0', 'bbox_y1',
        'cluster_mass_z', 'median_r'
    ]

    z_thr = float(fisher_z(r_threshold))

    pairs = (
        results_df[['class_A','class_B']]
        .drop_duplicates()
        .sort_values(['class_A','class_B'])
        .itertuples(index=False, name=None)
    )

    for (A, B) in pairs:
        grids = _build_pair_grids(results_df, A, B, r_col=r_col)
        if grids is None:
            continue

        rG = grids['r_grid']; zG = grids['z_grid']
        # Valid if r is finite AND above threshold
        valid = np.isfinite(rG) & (rG > r_threshold)

        if not np.any(valid):
            # No islands at this pair
            continue

        labels, nlab = _label_8_connected(valid)

        for lab in range(1, nlab+1):
            mask = (labels == lab)
            size = int(mask.sum())
            if size < min_windows:
                continue

            # Extract stats
            r_vals = rG[mask]
            z_vals = zG[mask]

            # Cluster "mass" above threshold in z-space (strength + extent)
            cluster_mass = float(np.nansum(z_vals - z_thr))

            # Spatial bbox from window rectangles
            x0_min = float(np.nanmin(grids['x0_grid'][mask]))
            x1_max = float(np.nanmax(grids['x1_grid'][mask]))
            y0_min = float(np.nanmin(grids['y0_grid'][mask]))
            y1_max = float(np.nanmax(grids['y1_grid'][mask]))

            # (x0,x1,y0,y1) list for all member windows
            member_rects = np.column_stack([
                grids['x0_grid'][mask],
                grids['x1_grid'][mask],
                grids['y0_grid'][mask],
                grids['y1_grid'][mask],
            ]).tolist()

            island = {
                'pair': (A, B),
                'label': lab,
                'n_windows': size,
                'median_r': float(np.nanmedian(r_vals)),
                'mean_r': float(np.nanmean(r_vals)),
                'max_r': float(np.nanmax(r_vals)),
                'median_z': float(np.nanmedian(z_vals)),
                'cluster_mass_z': cluster_mass,
                'bbox': (x0_min, x1_max, y0_min, y1_max),
                'window_rects': member_rects,
                'grid_shape': rG.shape,
                'grid_x0s': grids['xs'].tolist(),
                'grid_y0s': grids['ys'].tolist(),
                # Placeholders for later inference / stability:
                'cluster_p': None,      # TODO: fill via permutation-based max-cluster test
                'stability': None       # TODO: fill via bootstrap frequency / IoU
            }
            islands.append(island)
            index_rows.append({
                'class_A': A, 'class_B': B, 'label': lab,
                'n_windows': size,
                'bbox_x0': x0_min, 'bbox_x1': x1_max,
                'bbox_y0': y0_min, 'bbox_y1': y1_max,
                'cluster_mass_z': cluster_mass,
                'median_r': island['median_r']
            })

    # Lightweight index DataFrame for quick filtering/sorting.
    # Passing `columns` keeps sort_values from raising KeyError when
    # index_rows is empty (no island passed the threshold).
    island_index = pd.DataFrame(index_rows, columns=index_columns).sort_values(
        ['class_A','class_B','cluster_mass_z','n_windows'],
        ascending=[True, True, False, False]
    ).reset_index(drop=True)

    return islands, island_index




In [6]:
# Detect islands on the weighted PCC map: windows with r > 0.7, keeping only
# islands made of at least 4 windows.
islands, island_index = find_islands_for_all_pairs(results_df,
                                                   r_threshold=0.7,
                                                   r_col='weighted_pearson_r',
                                                   min_windows=4)
island_index.head()

In [7]:
# Corner coordinates of the first member window of the second island.
# Each window rect is stored as [x0, x1, y0, y1].
wx1, wx2, wy1, wy2 = islands[1]['window_rects'][0]

In [8]:
import numpy as np
import pandas as pd
from collections import defaultdict

# -------------------------------------------------------------
# Build a fast lookup of island windows per (A,B) as set of (x0,y0)
# -------------------------------------------------------------
def _island_windows_by_pair(islands):
    pair_to_windows = defaultdict(set)
    pair_to_axes = {}
    for isl in islands:
        A, B = isl['pair']
        # store grid axes (assumes consistent grid across islands; OK if repeated)
        pair_to_axes[(A,B)] = (np.array(isl['grid_x0s']), np.array(isl['grid_y0s']))
        for (x0,x1,y0,y1) in isl['window_rects']:
            pair_to_windows[(A,B)].add((float(x0), float(y0)))
    return pair_to_windows, pair_to_axes

# -------------------------------------------------------------
# Extract the global window grid (xs, ys) and win_size from results_df
# (assumes a consistent grid used when computing r maps)
# -------------------------------------------------------------
def _extract_grid_from_results(results_df):
    xs = np.array(sorted(results_df['x0'].unique()), dtype=float)
    ys = np.array(sorted(results_df['y0'].unique()), dtype=float)
    # infer win_size from first row
    r0 = results_df.iloc[0]
    win_size_x = float(r0['x1'] - r0['x0'])
    win_size_y = float(r0['y1'] - r0['y0'])
    assert np.isclose(win_size_x, win_size_y), "Non-square windows not supported here."
    return xs, ys, win_size_x

# -------------------------------------------------------------
# Find all window (x0,y0) that cover a given point (x,y)
# -------------------------------------------------------------
def _covering_windows(x, y, xs, ys, win_size):
    # windows with x0 <= x < x0+win_size and y0 <= y < y0+win_size
    x_mask = (xs <= x) & (x < xs + win_size)
    y_mask = (ys <= y) & (y < ys + win_size)
    xi = np.where(x_mask)[0]
    yi = np.where(y_mask)[0]
    # Cartesian product of indices
    return [(float(xs[j]), float(ys[i])) for i in yi for j in xi]

# -------------------------------------------------------------
# Main: per-cell binary encoding + auxiliary table with coverage
# -------------------------------------------------------------
def encode_cell_colocalization(allexp, results_df, islands, theta=0.5, k=3, cell_colname='class'):
    """
    Per-cell binary colocalization encoding.

    For each cell of type A:
      - collect all sliding windows covering the cell (support)
      - for each partner B != A: coverage = (#covering windows that are island
        windows for the pair (A,B), looked up in either order) / (#covering windows)
      - set binary 1 in column B if coverage >= theta and support >= k; else 0

    Cells are processed by POSITION, so `allexp` may carry any index
    (it does not have to be a 0..n-1 RangeIndex).

    Returns:
      binary_mat: DataFrame [n_cells x n_types], indexed like allexp
      aux:        DataFrame indexed like allexp with columns x, y, cell_type,
                  support_windows and one coverage_<B> column per cell type
    """
    xs, ys, win_size = _extract_grid_from_results(results_df)
    pair_to_windows, _ = _island_windows_by_pair(islands)

    cell_types = sorted(allexp[cell_colname].unique())
    n = len(allexp)

    # Pull the three needed columns out once; iterating rows of the full
    # (wide) frame would build a Series of every column for every cell.
    x_vals = allexp['x'].to_numpy(dtype=float)
    y_vals = allexp['y'].to_numpy(dtype=float)
    type_vals = allexp[cell_colname].to_numpy()

    # Prepare outputs
    bin_data = {t: np.zeros(n, dtype=int) for t in cell_types}  # self-col will stay 0
    aux_rows = []

    for pos in range(n):
        A = type_vals[pos]
        # typical stride = win_size/2 => up to 4 windows per cell
        covered = _covering_windows(float(x_vals[pos]), float(y_vals[pos]), xs, ys, win_size)
        support = len(covered)

        # Per-partner coverage; stays 0.0 for self, for pairs without islands
        # and for cells with insufficient support.
        cov_map = {B: 0.0 for B in cell_types}

        if support >= k:
            covered_set = set(covered)
            for B in cell_types:
                if B == A:
                    continue
                # Island windows for pair (A,B) (order-agnostic lookup)
                key = (A, B) if (A, B) in pair_to_windows else (B, A)
                if key not in pair_to_windows:
                    continue
                hit = len(covered_set & pair_to_windows[key])
                coverage = hit / support if support > 0 else 0.0
                cov_map[B] = coverage
                # binary decision
                if coverage >= theta:
                    bin_data[B][pos] = 1

        aux_row = {
            'x': float(x_vals[pos]),
            'y': float(y_vals[pos]),
            'cell_type': A,
            'support_windows': support
        }
        for B in cell_types:
            aux_row[f'coverage_{B}'] = 0.0 if B == A else float(cov_map[B])
        aux_rows.append(aux_row)

    # enforce self-col = 0 (safety), positionally
    for A in cell_types:
        bin_data[A][type_vals == A] = 0

    binary_mat = pd.DataFrame(bin_data, index=allexp.index)
    # Rows were appended in allexp order, so the index can be attached directly.
    aux = pd.DataFrame(aux_rows, index=allexp.index)

    return binary_mat, aux


In [9]:
# Example:
# Encode every cell against the islands found above; theta and k must match
# the values later passed to compute_ligand_exposure.
binary_mat, aux = encode_cell_colocalization(allexp, results_df, islands, theta=0.5, k=3)

In [10]:
# Attach the cell labels positionally (.values ignores index alignment), so
# this relies on binary_mat / aux having the same row order as allexp.
binary_mat['cell_label'] = allexp['cell_label'].values
aux['cell_label'] = allexp['cell_label'].values
# make cell_label the index (this rebinds binary_mat and aux in place of the originals)
binary_mat = binary_mat.set_index('cell_label')
aux = aux.set_index('cell_label')

In [11]:
# keep allexp columns from 22 to end except last 50
# keep allexp columns from 22 to end except last 50
# NOTE(review): the positional bounds 22 and -50 are hard-coded; confirm they
# still match the column layout of atlas_allexp.csv.
allexp_sub = allexp.iloc[:, 22:-50]

allexp_sub.shape

# Feature Matrix Formation

## Ligand-Receptor Genes Calculation

In [12]:
# Ligand-receptor pair table; the cells below read its 'ligand_genesymbol'
# and 'target_genesymbol' columns.
LR_pairs = pd.read_csv('../data/mouse_850_lr_pairs_cpdb_interactions.csv')
LR_pairs.shape

In [None]:
import pandas as pd
import numpy as np

def prepare_lr_features(
    allexp: pd.DataFrame,
    lr_pairs: pd.DataFrame,
    ligand_col: str = "ligand_genesymbol",   # adjust if your column names differ
    receptor_col: str = "target_genesymbol", 
    meta_cols = ("x","y","class")            # non-gene columns in allexp to ignore
):
    """
    Slice per-cell ligand and receptor expression matrices for the LR pairs
    whose ligand AND receptor both exist as columns of `allexp`.

    Matching is case-insensitive: allexp column names are upper-cased, LR
    symbols are stripped and upper-cased.

    Inputs
    ------
    allexp : DataFrame
        Rows = cells; columns = gene expression plus optional meta columns.
    lr_pairs : DataFrame
        One row per LR pair, with ligand / receptor symbol columns.
    ligand_col, receptor_col : str
        Names of those symbol columns in lr_pairs.
    meta_cols : iterable
        Columns of allexp that are not genes (ignored if absent).

    Outputs
    -------
    X_receptors : DataFrame (cells x unique receptors kept), coerced to numeric
    X_ligands   : DataFrame (cells x unique ligands kept), coerced to numeric
    lr_pairs_kept : lr_pairs filtered to matched pairs, with added columns
        ligand_symbol_uc, receptor_symbol_uc, ligand_symbol, receptor_symbol
        (original allexp column names) and ligand_idx / receptor_idx
        (column positions in X_ligands / X_receptors)
    report : dict with pair/gene counts, missing symbols and the dropped pairs
    """
    def _normalize(symbols):
        # stripped, upper-cased string form of a symbol column
        return symbols.astype(str).str.strip().str.upper()

    # --- gene columns of allexp, keyed by their upper-cased name ---
    columns = pd.Index(allexp.columns)
    present_meta = [c for c in meta_cols if c in columns]
    gene_columns = [c for c in columns if c not in present_meta]
    upper_to_original = {str(c).upper(): c for c in gene_columns}
    known_genes = set(upper_to_original)

    # --- LR table with normalized symbols in helper columns ---
    pairs_uc = lr_pairs.copy()
    pairs_uc["_LIG"] = _normalize(lr_pairs[ligand_col])
    pairs_uc["_REC"] = _normalize(lr_pairs[receptor_col])

    # --- keep pairs with both genes measured ---
    both_present = pairs_uc["_LIG"].isin(known_genes) & pairs_uc["_REC"].isin(known_genes)
    lr_pairs_kept = pairs_uc.loc[both_present].reset_index(drop=True)

    # unique genes in order of first appearance, mapped back to allexp names
    ligand_columns = [upper_to_original[g] for g in dict.fromkeys(lr_pairs_kept["_LIG"])]
    receptor_columns = [upper_to_original[g] for g in dict.fromkeys(lr_pairs_kept["_REC"])]

    # --- cells x genes matrices, forced to numeric ---
    X_ligands = allexp.loc[:, ligand_columns].copy().apply(pd.to_numeric, errors="coerce")
    X_receptors = allexp.loc[:, receptor_columns].copy().apply(pd.to_numeric, errors="coerce")

    # --- annotate kept pairs with names and matrix column positions ---
    ligand_position = {col: pos for pos, col in enumerate(ligand_columns)}
    receptor_position = {col: pos for pos, col in enumerate(receptor_columns)}

    lr_pairs_kept = lr_pairs_kept.assign(
        ligand_symbol_uc=lambda d: d["_LIG"],
        receptor_symbol_uc=lambda d: d["_REC"],
        ligand_symbol=lambda d: d["_LIG"].map(upper_to_original),
        receptor_symbol=lambda d: d["_REC"].map(upper_to_original),
        ligand_idx=lambda d: d["ligand_symbol"].map(ligand_position),
        receptor_idx=lambda d: d["receptor_symbol"].map(receptor_position),
    )

    # --- reporting ---
    # missing symbols are by construction absent from upper_to_original,
    # so they are reported in their normalized (upper-case) form
    unmatched_ligands = sorted(set(pairs_uc["_LIG"]) - known_genes)
    unmatched_receptors = sorted(set(pairs_uc["_REC"]) - known_genes)

    report = {
        "n_pairs_input": int(len(lr_pairs)),
        "n_pairs_kept": int(len(lr_pairs_kept)),
        "n_unique_ligands_kept": int(len(ligand_columns)),
        "n_unique_receptors_kept": int(len(receptor_columns)),
        "missing_ligands_from_allexp": list(unmatched_ligands),
        "missing_receptors_from_allexp": list(unmatched_receptors),
        "dropped_pairs": pairs_uc.loc[~both_present, [ligand_col, receptor_col]]
    }

    # helper columns are internal only
    lr_pairs_kept = lr_pairs_kept.drop(columns=["_LIG","_REC"])

    return X_receptors, X_ligands, lr_pairs_kept, report


In [14]:
# ---------------------------
# Example usage:
# meta_cols=[] means every column of allexp (including x, y, class, ...) is
# treated as a candidate gene column when matching LR symbols.
X_receptors, X_ligands, lr_pairs_kept, report = prepare_lr_features(allexp, LR_pairs,
    ligand_col="ligand_genesymbol", receptor_col="target_genesymbol", meta_cols=[])
# ---------------------------

## Ligand Exposure Calculation

In [15]:
# ---- Reuse your helpers (already defined above) -----------------------------
# _extract_grid_from_results, _island_windows_by_pair, _covering_windows

def _preindex_window_cells(allexp, xs, ys, win_size, class_col="class"):
    """
    Pre-index cells by (x0,y0) window and by cell type for fast neighbor lookup.
    Returns: dict[(x0,y0)] -> dict[class_name] -> np.array(cell_indices)
    """
    x = allexp["x"].to_numpy(float)
    y = allexp["y"].to_numpy(float)
    classes = allexp[class_col].astype(str).to_numpy()

    # For each window start (x0,y0), build a boolean mask of cells inside it
    win_index = {}
    for x0 in xs:
        x1 = x0 + win_size
        x_mask = (x >= x0) & (x < x1)
        for y0 in ys:
            y1 = y0 + win_size
            y_mask = (y >= y0) & (y < y1)
            idx = np.where(x_mask & y_mask)[0]
            if idx.size == 0:
                continue
            # bucket by class for this window
            by_class = {}
            for c in np.unique(classes[idx]):
                by_class[c] = idx[classes[idx] == c]
            win_index[(float(x0), float(y0))] = by_class
    return win_index


def compute_ligand_exposure(
    allexp: pd.DataFrame,
    X_ligands: pd.DataFrame,
    results_df: pd.DataFrame,
    islands: list,
    aux: pd.DataFrame,
    theta: float = 0.5,              # coverage threshold to deem A↔B island-sharing
    k_support: int = 3,              # min #covering windows for the receiver cell
    mode: str = "mean",              # "mean" or "kde"
    sigma: float | None = None,      # KDE sigma in coordinate units; default = win_size/3
    class_col: str = "class"
) -> pd.DataFrame:
    """
    Step 2 — Compute ligand exposure via colocalization islands.

    For each receiver cell i (rows of allexp, aux and X_ligands are matched
    by POSITION):
      1) Find partner types B with aux coverage_<B> >= theta, provided the cell
         is covered by at least k_support windows.
      2) Collect neighbors = union of B-type cells that fall in ANY sliding
         window covering i that is also an island window for the (A_i, B) pair.
      3) Exposure for each ligand gene g = average (mode="mean") or
         Gaussian distance-weighted average (mode="kde") of X_ligands[g]
         across the neighbor cells.

    Cells without support, eligible partners or neighbors keep exposure 0.

    Raises:
      ValueError if mode is not 'mean' or 'kde' (checked before any work).

    Returns:
      X_exposure: DataFrame (cells × ligand genes), index = allexp.index,
                  columns = X_ligands.columns
    """
    # Validate up front: otherwise a bad mode would only surface once a cell
    # with neighbors is reached (or never, silently returning zeros).
    if mode not in ("mean", "kde"):
        raise ValueError("mode must be 'mean' or 'kde'")

    # --- grid and island caches ---
    xs, ys, win_size = _extract_grid_from_results(results_df)
    pair_to_windows, _ = _island_windows_by_pair(islands)  # {(A,B): set[(x0,y0)], ...}

    # pre-index cells per (x0,y0) window and class for quick union queries
    win_index = _preindex_window_cells(allexp, xs, ys, win_size, class_col=class_col)

    # map from partner name B → aux column "coverage_B"
    coverage_cols = {
        c.replace("coverage_", ""): c
        for c in aux.columns if c.startswith("coverage_")
    }
    partners = list(coverage_cols)
    # One dense (cells × partners) array instead of building a row Series
    # per cell and per partner inside the loop.
    coverage = aux[[coverage_cols[B] for B in partners]].to_numpy(dtype=float)

    # coords & class arrays
    coord = allexp[["x", "y"]].to_numpy(float)
    # Classes are compared as strings here.
    # NOTE(review): pair_to_windows keys keep the class values stored in
    # `islands`; this assumes those are strings too — confirm.
    classes = allexp[class_col].astype(str).to_numpy()

    # default sigma ~ window size / 3 (smooth inside an island window)
    if sigma is None:
        sigma = win_size / 3.0 if np.isfinite(win_size) and win_size > 0 else 1.0
    inv2sig2 = 1.0 / (2.0 * (sigma ** 2))

    # Accumulate in a plain array; per-row DataFrame assignment is slow.
    exposure = np.zeros((len(allexp), X_ligands.shape[1]), dtype=float)

    # --- main loop over receiver cells ---
    for i in range(len(allexp)):
        A = classes[i]
        # support (how many windows cover this cell)
        covered_i = _covering_windows(coord[i, 0], coord[i, 1], xs, ys, win_size)
        if len(covered_i) < k_support:
            continue  # exposure stays 0

        # partner types B that share island with i (per aux coverage threshold)
        eligible_B = [B for j, B in enumerate(partners)
                      if B != A and coverage[i, j] >= theta]
        if not eligible_B:
            continue

        # collect neighbors: union over B and over windows that cover i,
        # intersected with island windows for (A,B)
        nbr_idx = set()
        covered_set = set(covered_i)
        for B in eligible_B:
            key = (A, B) if (A, B) in pair_to_windows else (B, A)
            if key not in pair_to_windows:
                continue
            # windows that both cover i AND are island windows for (A,B)
            isl_windows = pair_to_windows[key] & covered_set
            # gather all B-type cells inside those windows
            for w in isl_windows:
                by_class = win_index.get(w)
                if not by_class:
                    continue
                idxB = by_class.get(B)
                if idxB is not None and idxB.size:
                    nbr_idx.update(idxB.tolist())

        if not nbr_idx:
            continue

        nbr_idx = np.fromiter(nbr_idx, dtype=int, count=len(nbr_idx))

        if mode == "mean":
            # simple average across eligible neighbors (NaNs skipped; all-NaN → 0)
            exposure[i, :] = X_ligands.iloc[nbr_idx, :].mean(axis=0).fillna(0.0).to_numpy()
        else:  # mode == "kde"
            # Gaussian weights by distance to receiver cell i (within chosen sigma)
            d2 = np.sum((coord[nbr_idx] - coord[i])**2, axis=1)  # squared distances
            w = np.exp(-d2 * inv2sig2)
            w_sum = np.sum(w)
            if w_sum <= 0:
                continue
            w = w / w_sum
            # weighted average per ligand gene (weights @ neighbors × genes)
            exposure[i, :] = np.dot(w, X_ligands.iloc[nbr_idx, :].to_numpy())

    return pd.DataFrame(exposure, index=allexp.index, columns=X_ligands.columns)


In [16]:
# Pick a threshold consistent with your binary encoding step
# Pick a threshold consistent with your binary encoding step
# (the encode_cell_colocalization call above used theta=0.5, k=3)
theta = 0.5
k_support = 3

# Mean-based exposure (fast, good baseline)
X_exposure_mean = compute_ligand_exposure(
    allexp=allexp,
    X_ligands=X_ligands,
    results_df=results_df,     # from your (weighted) PCC stage
    islands=islands,           # from find_islands_for_all_pairs(...)
    aux=aux,                   # has coverage_* and support_windows
    theta=theta,
    k_support=k_support,
    mode="mean"
)

# KDE-weighted exposure (distance-weighted inside island windows)
# Left disabled; uncomment to compute the distance-weighted variant.
# X_exposure_kde = compute_ligand_exposure(
#     allexp=allexp,
#     X_ligands=X_ligands,
#     results_df=results_df,
#     islands=islands,
#     aux=aux,
#     theta=theta,
#     k_support=k_support,
#     mode="kde",
#     sigma=None   # defaults to win_size/3
# )

# Quick QC: fraction of (cell, ligand) entries with strictly positive exposure
print("Nonzero exposure frac (mean):",
      (X_exposure_mean.values > 0).mean())
# print("Nonzero exposure frac (kde):",
#       (X_exposure_kde.values > 0).mean())


## LR Interaction Scoring

In [17]:
import numpy as np
import pandas as pd

def build_lr_interaction_features(
    X_receptors: pd.DataFrame,
    X_exposure: pd.DataFrame,          # ligand *exposure* matrix (not raw ligand expr)
    lr_pairs_kept: pd.DataFrame,       # from Step 1 (already filtered to present genes)
    ligand_col: str = "ligand_symbol",
    receptor_col: str = "receptor_symbol",
    method: str = "product",           # "product" or "min"
    suffix: str = ""                   # optional suffix for column names, e.g., "_prod"
):
    """
    Step 3 — Build LR interaction features.

    Inputs
    ------
    X_receptors : DataFrame (cells × unique receptor genes)
    X_exposure  : DataFrame (cells × unique ligand genes)  [ligand *exposure*]
                  Must have the same number of rows as X_receptors; rows are
                  combined by POSITION, and the output is labelled with
                  X_receptors.index.
    lr_pairs_kept : DataFrame containing at least [ligand_col, receptor_col]
                    and (optionally) integer columns 'ligand_idx','receptor_idx'
                    that index into X_exposure / X_receptors respectively.
                    Both index columns must be present for them to be used.
    method : "product" (R * Lexp) or "min" (min(R, Lexp)); case-insensitive.
    suffix : optional string appended to interaction column names. An empty
             string selects the default ("_prod" or "_min").

    Returns
    -------
    X_LR  : DataFrame (cells × n_pairs), interaction per LR pair
    X_aux : DataFrame with receptor-only and ligand-exposure-only cols used
    meta  : dict with bookkeeping (pair->indices, method)

    Raises
    ------
    ValueError
        If `method` is unknown, if the two matrices have different row counts,
        or if no LR pair can be aligned to the matrices.
    AssertionError
        If precomputed indices no longer point at the named columns.
    """
    # Validate the cheap things first so we fail before doing any work.
    method_used = method.lower()
    if method_used not in ("product", "min"):
        raise ValueError("method must be 'product' or 'min'")

    # The element-wise combination below is positional. Without this check a
    # 1-row matrix would be silently broadcast across every cell.
    if X_receptors.shape[0] != X_exposure.shape[0]:
        raise ValueError(
            "X_receptors and X_exposure must have the same number of rows "
            f"(got {X_receptors.shape[0]} and {X_exposure.shape[0]})."
        )

    if len(lr_pairs_kept) == 0:
        raise ValueError("No LR pairs could be aligned to X_receptors/X_exposure.")

    # --- Resolve indices for each pair into the receptor/exposure matrices ---
    # Prefer the precomputed indices from Step 1 if BOTH are available.
    have_idx = ("ligand_idx" in lr_pairs_kept.columns
                and "receptor_idx" in lr_pairs_kept.columns)

    ligand_symbols = lr_pairs_kept[ligand_col].tolist()
    receptor_symbols = lr_pairs_kept[receptor_col].tolist()

    # Ordered lists of positions and names aligned to the kept rows
    lig_names = []
    rec_names = []
    lig_idx = []
    rec_idx = []

    if have_idx:
        # Column-wise iteration avoids the per-row Series (and dtype upcasting)
        # that DataFrame.iterrows() produces.
        idx_pairs = zip(lr_pairs_kept["ligand_idx"].tolist(),
                        lr_pairs_kept["receptor_idx"].tolist())
        for lig, rec, (li, ri) in zip(ligand_symbols, receptor_symbols, idx_pairs):
            li = int(li)
            ri = int(ri)
            # sanity: ensure columns still match names
            assert X_exposure.columns[li] == lig, f"Ligand index/name mismatch: {lig}"
            assert X_receptors.columns[ri] == rec, f"Receptor index/name mismatch: {rec}"
            lig_names.append(lig)
            rec_names.append(rec)
            lig_idx.append(li)
            rec_idx.append(ri)
    else:
        # Fallback: resolve by column name (name → position maps)
        rec_pos = {g: i for i, g in enumerate(X_receptors.columns)}
        lig_pos = {g: i for i, g in enumerate(X_exposure.columns)}
        for lig, rec in zip(ligand_symbols, receptor_symbols):
            if lig not in lig_pos or rec not in rec_pos:
                # Skip pairs whose genes are missing (should be rare after Step 1 filtering)
                continue
            lig_names.append(lig)
            rec_names.append(rec)
            lig_idx.append(lig_pos[lig])
            rec_idx.append(rec_pos[rec])

    n_pairs = len(lig_idx)
    if n_pairs == 0:
        raise ValueError("No LR pairs could be aligned to X_receptors/X_exposure.")

    # --- Pull the arrays we need (vectorized over cells × pairs) ---
    R = X_receptors.iloc[:, rec_idx].to_numpy(dtype=float)   # shape: (cells, n_pairs)
    Lexp = X_exposure.iloc[:, lig_idx].to_numpy(dtype=float) # shape: (cells, n_pairs)

    # --- Interaction function ---
    if method_used == "product":
        S = R * Lexp
        suf = suffix or "_prod"
    else:
        S = np.minimum(R, Lexp)
        suf = suffix or "_min"

    # --- Build column names like 'LIGAND|RECEPTOR_prod' ---
    lr_labels = [f"{lig}|{rec}{suf}" for lig, rec in zip(lig_names, rec_names)]
    X_LR = pd.DataFrame(S, index=X_receptors.index, columns=lr_labels)

    # --- Auxiliary matrices used in these interactions ---
    # Each receptor / ligand appears once, ordered by first appearance in the pairs.
    uniq_rec_ordered = list(dict.fromkeys(rec_names))
    uniq_lig_ordered = list(dict.fromkeys(lig_names))

    X_rec_used = X_receptors.loc[:, uniq_rec_ordered].copy()
    X_lig_used = X_exposure.loc[:, uniq_lig_ordered].copy()
    # Prefix to make the block origin explicit in modeling design matrices
    X_rec_used.columns = [f"R__{c}" for c in X_rec_used.columns]
    X_lig_used.columns = [f"Lexp__{c}" for c in X_lig_used.columns]

    X_aux = pd.concat([X_rec_used, X_lig_used], axis=1)

    meta = {
        "method": method_used,
        "pairs": list(zip(lig_names, rec_names)),
        "ligand_indices": lig_idx,
        "receptor_indices": rec_idx,
        "n_pairs": n_pairs
    }
    return X_LR, X_aux, meta


In [18]:
# ---------------------------
# Example usage
# ---------------------------
# Product-style LR interactions built from the mean-based exposure matrix.
X_LR_prod, X_aux, meta = build_lr_interaction_features(
    X_receptors=X_receptors,
    X_exposure=X_exposure_mean,         # or X_exposure_kde
    lr_pairs_kept=lr_pairs_kept,
    ligand_col="ligand_symbol",
    receptor_col="receptor_symbol",
    method="product"
)

# X_LR_min, X_aux_min, meta_min = build_lr_interaction_features(..., method="min")

# Quick sanity checks:
# 1) Every interaction value is finite (np.isfinite rejects NaN and ±inf).
#    Non-negativity is NOT asserted here; it only holds if both inputs are non-negative.
assert np.isfinite(X_LR_prod.to_numpy()).all()
# 2) Column count equals number of kept pairs
assert X_LR_prod.shape[1] == meta["n_pairs"]

## Gene Filtering

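Remove receptors, (optionally) ligands, rarely detected genes, likely technical genes, and the lowest-variance genes (bottom 10% by variance, set by `MIN_VAR_QUANT`) from the target set.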

In [19]:
# Unique ligand / receptor gene symbols taken from the LR_pairs table.
# Note: `ligands` and `receptors` are reassigned further down from lr_pairs_kept
# (as sorted lists), so these arrays only live until that cell runs.
ligands = LR_pairs['ligand_genesymbol'].unique()
receptors = LR_pairs['target_genesymbol'].unique()

In [20]:
# Move the current index into a 'cell_label' column and switch to a positional index.
# Guarded so the cell is idempotent: re-running the unguarded version would insert
# a second 'cell_label' column, because the fresh RangeIndex is again emitted as 'index'.
if 'cell_label' in aux.columns:
    aux = aux.reset_index(drop=True)
else:
    aux = aux.reset_index().rename(columns={'index': 'cell_label'})

In [21]:
# You should already have these from Steps 1–3 & island encoding:
# - allexp: DataFrame (cells × [meta + genes]); index = cell ids; columns include 'x','y','class','cell_label' etc.
# - X_receptors: DataFrame (cells × receptors)              # Step 1
# - X_exposure:  DataFrame (cells × ligands)                # Step 2
# - X_LR:        DataFrame (cells × LR-pair interactions)   # Step 3
# - aux:         DataFrame with 'support_windows' and columns starting with 'coverage_'
# - lr_pairs_kept: DataFrame with 'ligand_symbol', 'receptor_symbol'

# Sanity:
# Fail early (on a fresh kernel or after out-of-order execution) if any upstream
# object is missing from the notebook namespace.
for name in ["allexp", "X_receptors", "X_exposure_mean", "X_LR_prod", "aux", "lr_pairs_kept"]:
    assert name in globals(), f"Missing variable: {name}"

# Align row order across all matrices by allexp:
# .loc selects by LABEL, so every frame must be indexed by the same labels as allexp.
# From here on X_exposure / X_LR are the short names for the mean-exposure and
# product-interaction matrices.
idx = allexp.index
X_receptors = X_receptors.loc[idx]
X_exposure  = X_exposure_mean.loc[idx]
X_LR        = X_LR_prod.loc[idx]
aux         = aux.loc[idx]


In [22]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import re

# CONFIG
# Settings for the target-gene filtering cells below (get_gene_matrix and the
# drop-list construction read these names).
META_COLS          = ["x","y","class","cell_label"]  # adjust if you have different meta cols
DROP_LIGANDS       = True     # recommended to avoid leakage; set False if you want to include ligands as targets
MIN_DETECT_FRAC    = 0.01     # drop targets detected in <1% of cells
MIN_VAR_QUANT      = 0.10     # drop bottom 10% variance targets
DROP_TECHNICALS    = True     # drop mito/ribo/hb-like genes if present
COV_PREFIX         = "coverage_"  # column-name prefix that identifies coverage features in aux

def get_gene_matrix(allexp: pd.DataFrame, meta_cols=META_COLS) -> pd.DataFrame:
    """
    Extract the numeric expression part of `allexp`.

    Every column not named in `meta_cols` is kept, coerced to numeric
    (unparseable entries become NaN) and NaNs are then filled with 0.0.
    Names in `meta_cols` that are absent from `allexp` are simply ignored.
    """
    excluded = set(meta_cols)
    kept_cols = [col for col in allexp.columns if col not in excluded]
    numeric = allexp.loc[:, kept_cols].apply(pd.to_numeric, errors="coerce")
    return numeric.fillna(0.0)

def zscore_df(df: pd.DataFrame):
    """
    Standardize every column of `df` with a freshly fitted StandardScaler.

    Returns a tuple (Z, scaler): Z is a DataFrame with the same index and
    columns as `df`; scaler is the fitted StandardScaler.
    """
    scaler = StandardScaler(with_mean=True, with_std=True)
    standardized = pd.DataFrame(
        scaler.fit_transform(df.values),
        index=df.index,
        columns=df.columns,
    )
    return standardized, scaler


In [23]:
# Build the target matrix Y_targets: all genes minus receptors, (optionally)
# ligands, rarely detected genes, low-variance genes and technical genes.
# 1) Full cells×genes matrix
# NOTE(review): allexp_sub is not defined in the visible part of the notebook —
# confirm it exists on a fresh Restart & Run All and shares allexp's index.
expr_all = get_gene_matrix(allexp_sub, meta_cols=META_COLS)

# 2) Receptor & ligand sets from your LR list
# (overwrites the `receptors` / `ligands` arrays created from LR_pairs above)
receptors = sorted(set(lr_pairs_kept["receptor_symbol"]))
ligands   = sorted(set(lr_pairs_kept["ligand_symbol"]))

# 3) Build the drop list for targets
# Receptors are always excluded; ligands only when DROP_LIGANDS is set.
drop = set(expr_all.columns).intersection(receptors)
if DROP_LIGANDS:
    drop |= set(expr_all.columns).intersection(ligands)

# 4) Filter low detection / low variance
# The variance cut-off is the MIN_VAR_QUANT quantile computed over ALL genes
# in expr_all (including ones already on the drop list).
det_frac = (expr_all > 0).mean(axis=0)                            # fraction of cells with nonzero expression
var_g    = expr_all.var(axis=0)
low_det  = det_frac[det_frac < MIN_DETECT_FRAC].index
low_var  = var_g[var_g < var_g.quantile(MIN_VAR_QUANT)].index
drop |= set(low_det) | set(low_var)

# 5) Optional: remove likely technical genes if present among your 500
# Matches gene symbols that START with one of the listed prefixes.
if DROP_TECHNICALS:
    tech_re = re.compile(r"^(MT-|mt-|RPL|RPS|HBA|HBB)")
    tech = [g for g in expr_all.columns if tech_re.match(g)]
    drop |= set(tech)

# 6) Final targets
# Rows are selected by label with `idx` (the allexp index set in the alignment cell).
target_genes = [g for g in expr_all.columns if g not in drop]
Y_targets = expr_all.loc[idx, target_genes].copy()

print(f"Targets kept: {Y_targets.shape[1]} genes (from {expr_all.shape[1]} total).")


In [24]:
# Use every coverage_* column (broader approach)
coverage_cols_all = [c for c in aux.columns if c.startswith(COV_PREFIX)]
assert len(coverage_cols_all) > 0, "No coverage_* columns in aux."

X_cov = aux[coverage_cols_all].copy()
# Optional rename to cleaner names for modeling; keep originals if you prefer
# Note: str.replace removes EVERY occurrence of COV_PREFIX in the name, not only
# the leading one.
X_cov.columns = [f"cov::{c.replace(COV_PREFIX,'')}" for c in coverage_cols_all]

# Sample weights from window support (use later in training)
# Values below 1 are raised to 1. The training cells further down read
# aux["support_windows"] directly, i.e. without this clipping.
sample_weight = aux["support_windows"].clip(lower=1).to_numpy()

In [25]:
# Z-score each block independently (keep scalers to transform CV folds later)
# Each scaler here is fitted on ALL cells of its block. The ElasticNet CV cells
# below additionally fit their own per-fold StandardScaler on top of X.
X_receptors_z, sc_R = zscore_df(X_receptors)
X_exposure_z,  sc_E = zscore_df(X_exposure)
X_LR_z,        sc_I = zscore_df(X_LR)
X_cov_z,       sc_C = zscore_df(X_cov)

# Concatenate in fixed order
# Column order of X: receptors, ligand exposure, LR interactions, coverage.
X = pd.concat([X_receptors_z, X_exposure_z, X_LR_z, X_cov_z], axis=1)

print("X shape:", X.shape)                # (cells × features)
print("Y_targets shape:", Y_targets.shape) # (cells × target genes)


In [26]:
# Basic sanity
# Both checks fail on NaN as well as ±inf anywhere in the matrices.
import numpy as np
assert np.isfinite(X.values).all(), "Non-finite values in X"
assert np.isfinite(Y_targets.values).all(), "Non-finite values in Y_targets"

# Sparsity snapshots (pre-zscore)
# These use the un-standardized blocks, so "nonzero" refers to the raw feature values.
print("Nonzero frac (receptors):", (X_receptors.values != 0).mean())
print("Nonzero frac (exposure) :", (X_exposure.values  != 0).mean())
print("Nonzero frac (LR)       :", (X_LR.values        != 0).mean())
print("Mean coverage (X_cov)   :", X_cov.mean().mean())

# Ready for Step 5: spatial block CV + multi-task Elastic Net / RF


# Model Training (ElasticNet: X_rec, X_exp, X_LR, X_cov)

In [None]:
import numpy as np
import pandas as pd

from sklearn.cluster import KMeans
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.metrics import r2_score
from itertools import product

# Settings for the ElasticNet section. SEED, N_GROUPS, TEST_FRACTION and N_SPLITS
# are assigned again in the XGBoost configuration cells further down.
# ---- Repro & CV/Test config ----
SEED = 42
N_GROUPS = 10          # number of spatial groups to form from (x,y)
TEST_FRACTION = 0.2    # ~20% groups as final test set
N_SPLITS = 5           # GroupKFold folds on the dev set

# ---- Elastic Net search grid ----
# Every (alpha, l1_ratio) combination is evaluated: 8 × 3 = 24 settings.
ALPHAS = np.logspace(-4, 1, 8)      # 1e-4 ... 10
L1S    = [0.1, 0.5, 0.9]            # l1_ratio

# ---- Metric aggregation choice ----
AGG = "mean"   # "mean" or "median" R^2 across targets
USE_WEIGHTS_IN_SCORING = True  # use support_windows as weights in metrics


In [None]:
# (x,y) → spatial groups
# Cells are clustered on their coordinates so that whole spatial regions can be
# held out together.
coords = allexp[["x","y"]].to_numpy(dtype=float)

km = KMeans(n_clusters=N_GROUPS, n_init=10, random_state=SEED)
group_labels = km.fit_predict(coords)  # 0..N_GROUPS-1, one label per cell

# Decide test groups (≈ TEST_FRACTION of groups)
# At least one group is always held out; the draw is seeded and without replacement.
rng = np.random.default_rng(SEED)
unique_groups = np.arange(N_GROUPS)
n_test_groups = max(1, int(round(TEST_FRACTION * N_GROUPS)))
test_groups = rng.choice(unique_groups, size=n_test_groups, replace=False)

# Boolean masks over cells, in allexp row order.
is_test = np.isin(group_labels, test_groups)
is_dev  = ~is_test

print("Groups:", N_GROUPS, "| Test groups:", sorted(test_groups.tolist()))
print("Dev cells:", is_dev.sum(), " Test cells:", is_test.sum())

In [None]:
def fit_transform_split(X, Y, train_idx, val_idx):
    """
    Standardize one train/validation split without leaking validation data.

    Separate scalers are fitted for X and for Y on the training rows only and
    then applied to the validation rows.
    Returns: Xtr, Xva, Ytr, Yva, sx, sy (sx / sy are the fitted scalers).
    """
    def _scale(data):
        # Fit on the training rows, reuse the fitted scaler for validation rows.
        scaler = StandardScaler(with_mean=True, with_std=True)
        train_scaled = scaler.fit_transform(data[train_idx])
        val_scaled = scaler.transform(data[val_idx])
        return train_scaled, val_scaled, scaler

    Xtr, Xva, sx = _scale(X)
    Ytr, Yva, sy = _scale(Y)
    return Xtr, Xva, Ytr, Yva, sx, sy

def weighted_r2_per_target(y_true, y_pred, sample_weight=None):
    """
    Compute R² separately for every target column.

    Calls sklearn's r2_score once per column of the 2-D arrays, forwarding
    `sample_weight` unchanged, and returns the scores as a 1-D float array.
    """
    n_targets = y_true.shape[1]
    scores = [
        r2_score(y_true[:, col], y_pred[:, col], sample_weight=sample_weight)
        for col in range(n_targets)
    ]
    return np.array(scores, dtype=float)

def aggregate_scores(r2_vec, agg="mean"):
    """Collapse per-target scores to one float: NaN-aware mean if agg == "mean", else NaN-aware median."""
    reducer = np.nanmean if agg == "mean" else np.nanmedian
    return float(reducer(r2_vec))


In [None]:
# Grid search over (alpha, l1_ratio) for a MultiTaskElasticNet using spatial
# GroupKFold on the dev cells. Test cells are not touched in this cell.
# Slice dev/test once
X_all = X.to_numpy(dtype=float)
Y_all = Y_targets.to_numpy(dtype=float)

groups_dev = group_labels[is_dev]
X_dev, Y_dev = X_all[is_dev], Y_all[is_dev]
# Raw support_windows (not the clipped `sample_weight` array); used for scoring only.
weights_all = aux["support_windows"].to_numpy()
w_dev = weights_all[is_dev] if USE_WEIGHTS_IN_SCORING else None

gkf = GroupKFold(n_splits=N_SPLITS)

cv_results = []  # collect dicts for a summary table

for alpha, l1 in product(ALPHAS, L1S):
    fold_scores = []
    per_target_scores = []  # optional: store mean per-target R² across folds too

    for tr_idx, va_idx in gkf.split(X_dev, groups=groups_dev):
        # Train/val split indexes relative to DEV subset
        # Scalers are refit on the training rows of every fold.
        Xtr, Xva, Ytr, Yva, sx, sy = fit_transform_split(X_dev, Y_dev, tr_idx, va_idx)

        # Model
        model = MultiTaskElasticNet(
            alpha=alpha,
            l1_ratio=l1,
            fit_intercept=False,   # we already standardized
            max_iter=5000,
            random_state=SEED,
            selection="cyclic"
        )
        # The fit itself is unweighted; weights only enter the R² below.
        model.fit(Xtr, Ytr)

        # Predict and score
        # Scores are computed in the standardized target space (Yva is scaled).
        Yhat = model.predict(Xva)
        sw = w_dev[va_idx] if w_dev is not None else None
        r2_t = weighted_r2_per_target(Yva, Yhat, sample_weight=sw)
        fold_scores.append(aggregate_scores(r2_t, AGG))
        per_target_scores.append(r2_t)

    # cv_score aggregates with AGG (NaN-aware); per_target_mean is a plain mean
    # over folds and targets, so it becomes NaN if any per-target score is NaN.
    cv_results.append({
        "alpha": alpha,
        "l1_ratio": l1,
        "cv_score": float(np.mean(fold_scores)),
        "cv_score_std": float(np.std(fold_scores)),
        "per_target_mean": float(np.mean(np.vstack(per_target_scores), axis=0).mean()),
    })

# Pick best by cv_score
# Ties on cv_score are broken by per_target_mean; the first row after sorting wins.
cv_df = pd.DataFrame(cv_results).sort_values(["cv_score", "per_target_mean"], ascending=[False, False]).reset_index(drop=True)
best = cv_df.iloc[0].to_dict()
best_alpha, best_l1 = float(best["alpha"]), float(best["l1_ratio"])

print("Best hyperparams → alpha=%.4g, l1_ratio=%.2f | CV mean %s R²=%.4f (± %.4f)" %
      (best_alpha, best_l1, AGG, best["cv_score"], best["cv_score_std"]))

cv_df.head(10)


In [None]:
# Summary statistics of the standardized design matrix (one column per feature).
X.describe()

# Model Training (XGBoost: X_rec, X_exp, X_cov)

In [27]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.model_selection import GroupKFold
from sklearn.metrics import r2_score
from xgboost import XGBRegressor
from itertools import product

# Settings for the per-receiver XGBoost grid search (train_xgb_for_receiver).
# Several of these names are reassigned by the "Minimal Setup" cell below;
# whichever cell ran last decides the values the training functions see.
# ---- Core choices ----
SEED = 42
N_GROUPS = 8             # spatial groups per receiver type (adjust 5–12)
TEST_FRACTION = 0.2      # fraction of groups held-out as final test
N_SPLITS = 5             # GroupKFold on the dev set
IGNORE_ZERO_COV = True   # << drop cells whose coverage is 0 for ALL sender types
USE_SAMPLE_WEIGHTS = True  # use aux['support_windows'] as weights

# ---- Targets selection (to keep runtime sane) ----
TARGET_LIMIT = 300       # top-variance targets per receiver type (None = all)

# ---- XGBoost search space (small but effective) ----
# The full Cartesian product is searched: 3 × 2 × 2 × 2 × 2 × 2 = 96 settings per gene.
ALPHAS_REG_L2 = [1, 5, 10]         # reg_lambda
ALPHAS_REG_L1 = [0, 1]             # reg_alpha
MAX_DEPTHS    = [4, 6]
LEARNING_RATES= [0.03, 0.1]
SUBSAMPLE     = [0.8, 1.0]
COLSAMPLE     = [0.8, 1.0]

N_ESTIMATORS  = 2000    # upper bound on boosting rounds; early stopping may end sooner
EARLY_STOP    = 100     # early-stopping patience (rounds)
TREE_METHOD   = "hist"  # 'hist' is fast and scalable


## Minimal Setup

In [28]:
# Fast, single-setting run (no grid, no CV)
SEED = 42
N_GROUPS = 6            # fewer groups = faster
TEST_FRACTION = 0.2
IGNORE_ZERO_COV = True  # drop cells with all-zero coverage
USE_SAMPLE_WEIGHTS = True
TARGET_LIMIT = 150      # fewer targets for speed (set None for all)
EARLY_STOP = 80         # patience
EVAL_HOLDOUT = 0.1      # 10% of dev for early stopping
# Fixed XGBoost params (sane defaults)
XGB_PARAMS = dict(
    max_depth=6,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=5.0,
    reg_alpha=0.0,
    n_estimators=2000,
    objective="reg:squarederror",
    tree_method="hist",
    random_state=SEED,
    n_jobs=0,
    eval_metric="rmse",
    early_stopping_rounds=EARLY_STOP
)


In [29]:
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from xgboost import XGBRegressor
from sklearn.metrics import r2_score
from tqdm import tqdm

def make_groups(xy, k, seed=SEED):
    """
    Cluster coordinates `xy` with KMeans and return one group label per row.

    The requested cluster count `k` is raised to at least 3 and then capped at
    the number of rows in `xy`.
    """
    n_clusters = int(min(max(k, 3), len(xy)))
    clusterer = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return clusterer.fit_predict(xy)

def split_dev_test_by_groups(groups, frac=0.2, seed=SEED, test_fraction=None):
    """
    Hold out whole groups as a test set.

    Parameters
    ----------
    groups : array-like of group labels, one per sample.
    frac : fraction of the distinct groups to hold out (at least one group is
        always held out).
    seed : seed for the random choice of held-out groups.
    test_fraction : optional alias for `frac`. This notebook defines
        split_dev_test_by_groups twice with different keyword names; accepting
        both spellings keeps every call site working regardless of which
        definition was executed last. When given, it takes precedence over `frac`.

    Returns
    -------
    (is_dev, is_test, test_g): two boolean masks over samples and the array of
    held-out group labels.
    """
    if test_fraction is not None:
        frac = test_fraction
    rng = np.random.default_rng(seed)
    ug = np.unique(groups)
    n_test = max(1, int(round(frac * len(ug))))
    test_g = rng.choice(ug, size=n_test, replace=False)
    is_test = np.isin(groups, test_g)
    return ~is_test, is_test, test_g

def r2_weighted(y_true, y_pred, w=None):
    """
    R² with optional sample weights.

    Without weights this defers to sklearn's r2_score. With weights, the
    weights are normalized to sum to one and R² is computed as
    1 - SSE / SST around the weighted mean of y_true; a 1e-12 term guards both
    divisions.
    """
    if w is None:
        return r2_score(y_true, y_pred)
    truth, pred, weights = (np.asarray(arr) for arr in (y_true, y_pred, w))
    weights = weights / (weights.sum() + 1e-12)
    weighted_mean = np.sum(weights * truth)
    resid_ss = np.sum(weights * (truth - pred)**2)
    total_ss = np.sum(weights * (truth - weighted_mean)**2)
    return 1.0 - (resid_ss / (total_ss + 1e-12))

def quick_train_xgb_per_receiver(receiver_type: str):
    """
    Train one XGBoost regressor per target gene for a single receiver cell type,
    using the fixed XGB_PARAMS (no grid search, no cross-validation).

    Features are [X_receptors, X_exposure, X_cov] for the cells whose 'class'
    equals `receiver_type`. Cells are clustered spatially, a fraction of the
    clusters is held out as the test set, and the last EVAL_HOLDOUT share of
    the remaining (dev) rows is used as the early-stopping evaluation set.

    Reads the notebook globals allexp, X_receptors, X_exposure, X_cov,
    Y_targets, aux and the configuration constants.

    Returns
    -------
    summary : DataFrame with columns receiver_type, gene, test_r2, best_iters,
        sorted by test_r2 (descending).
    models : dict gene -> fitted XGBRegressor.
    test_groups : labels of the held-out spatial groups.
    """
    # Subset rows for this receiver type
    idx = allexp.index[allexp["class"] == receiver_type]
    X_block = pd.concat([X_receptors.loc[idx], X_exposure.loc[idx], X_cov.loc[idx]], axis=1)
    Y_block = Y_targets.loc[idx]
    xy = allexp.loc[idx, ["x","y"]].to_numpy(float)
    w_all = aux.loc[idx, "support_windows"].to_numpy() if USE_SAMPLE_WEIGHTS else None

    # Optional: drop cells with all-zero coverage
    if IGNORE_ZERO_COV:
        keep = (X_cov.loc[idx].sum(axis=1) > 0).to_numpy()
        X_block = X_block.loc[keep]
        Y_block = Y_block.loc[keep]
        xy = xy[keep]
        if w_all is not None:
            w_all = w_all[keep]

    # Limit targets for speed (highest variance within this receiver type)
    if TARGET_LIMIT is not None:
        top = Y_block.var(axis=0).sort_values(ascending=False).index[:TARGET_LIMIT]
        Y_block = Y_block.loc[:, top]

    # Spatial groups → dev/test split
    groups = make_groups(xy, N_GROUPS, seed=SEED)
    is_dev, is_test, test_groups = split_dev_test_by_groups(groups, frac=TEST_FRACTION, seed=SEED)
    print(X_block.shape)
    X_values = X_block.to_numpy()
    Y_values = Y_block.to_numpy()
    X_dev, X_test = X_values[is_dev], X_values[is_test]
    Y_dev, Y_test = Y_values[is_dev], Y_values[is_test]
    w_dev = w_all[is_dev] if w_all is not None else None
    w_test = w_all[is_test] if w_all is not None else None

    # Internal eval split from dev for early stopping (no CV):
    # the LAST n_eval dev rows are the evaluation set, the rest is training.
    n_dev = X_dev.shape[0]
    n_eval = max(1, int(EVAL_HOLDOUT * n_dev))
    tr_mask = np.ones(n_dev, dtype=bool)
    tr_mask[-n_eval:] = False
    va_mask = ~tr_mask

    results, models = [], {}
    for col, gene in enumerate(tqdm(Y_block.columns.tolist())):
        y_dev = Y_dev[:, col]
        y_test = Y_test[:, col]
        model = XGBRegressor(**XGB_PARAMS)
        model.fit(
            X_dev[tr_mask], y_dev[tr_mask],
            sample_weight=(w_dev[tr_mask] if w_dev is not None else None),
            eval_set=[(X_dev[va_mask], y_dev[va_mask])],
            sample_weight_eval_set=[w_dev[va_mask]] if w_dev is not None else None,
            verbose=False
        )
        yhat = model.predict(X_test)
        r2t = r2_weighted(y_test, yhat, w_test)

        # XGBRegressor exposes the early-stopping result as `best_iteration`
        # (no trailing underscore). Compare against None explicitly so a best
        # iteration of 0 is not mistaken for "missing".
        best_it = getattr(model, "best_iteration", None)
        if best_it is None:
            best_it = model.get_params()["n_estimators"]

        results.append({"receiver_type": receiver_type,
                        "gene": gene,
                        "test_r2": float(r2t),
                        "best_iters": int(best_it)})
        models[gene] = model

    summary = pd.DataFrame(results).sort_values("test_r2", ascending=False).reset_index(drop=True)
    return summary, models, test_groups


In [30]:
# Run the quick (single-setting) XGBoost training once per receiver cell type
# and collect the per-gene summaries, fitted models and held-out groups.
receiver_types = sorted(allexp["class"].unique())

all_summaries = []
all_models = {}
heldout = {}

for rtype in receiver_types:
    print(f"=== {rtype} ===")
    try:
        s, m, tg = quick_train_xgb_per_receiver(rtype)
        all_summaries.append(s)
        all_models[rtype] = m
        heldout[rtype] = tg
        print(s.head(5), "\n")
    except Exception as e:
        # Best-effort: a failing receiver type is reported and skipped, so it
        # will be absent from all_summaries / all_models / heldout.
        print(f"Error processing receiver type {rtype}: {e}")

# Note: pd.concat raises if every receiver type failed (empty list).
# Rows are sorted within each receiver type only, so head(15) below shows the
# first receiver type's best genes rather than a global top 15.
summary_all = pd.concat(all_summaries, ignore_index=True)
print("Ballpark — top genes by test R² across receiver types:")
display(summary_all.head(15))

print("Median test R² per receiver type:")
display(summary_all.groupby("receiver_type")["test_r2"].median().sort_values(ascending=False))

In [43]:
# Step 1: Stack the per-receiver summaries into one table
combined_df = pd.concat(all_summaries, ignore_index=True)

# Step 2: Select only the required columns
r2_summary = combined_df[['receiver_type', 'gene', 'test_r2']]

# Step 3: Save to CSV
# The "R2 comparison" section reads this file back in.
r2_summary.to_csv('../data/r2_summary.csv', index=False)

In [None]:
# save the best models
# Writes the nested dict {receiver_type: {gene: fitted model}} to the current
# working directory (not ../data like the CSV outputs).
import joblib
joblib.dump(all_models, "xgb_models_per_receiver.pkl")

## Main Driver

In [None]:
# Expected to already exist from your previous steps:
# allexp, X_receptors, X_exposure, X_cov, aux, Y_targets
# Note: the check below looks for X_exposure_mean (not X_exposure) and does not
# require Y_targets, which is rebuilt in the fallback branch if missing.
for name in ["allexp", "X_receptors", "X_exposure_mean", "X_cov", "aux"]:
    assert name in globals(), f"Missing variable: {name}"

# If Y_targets was not created earlier, fall back to all non-receptor/non-ligand genes
# Unlike the main target-selection cell, this fallback applies no detection,
# variance or technical-gene filtering.
if "Y_targets" not in globals():
    assert "lr_pairs_kept" in globals(), "Need lr_pairs_kept to derive non-R/L targets."
    META_COLS = ["x","y","class","cell_label"]
    expr_all = allexp.drop(columns=[c for c in META_COLS if c in allexp.columns], errors="ignore")
    receptors = set(lr_pairs_kept["receptor_symbol"])
    ligands   = set(lr_pairs_kept["ligand_symbol"])
    non_RL = [g for g in expr_all.columns if g not in receptors | ligands]
    Y_targets = expr_all.loc[:, non_RL].apply(pd.to_numeric, errors="coerce").fillna(0.0)

def make_spatial_groups(xy, n_groups, seed=SEED):
    """
    Assign each coordinate row in `xy` to a KMeans cluster and return the labels.

    `n_groups` is bumped up to 3 if smaller and then limited to len(xy).
    """
    at_least_three = max(n_groups, 3)
    n_clusters = int(min(at_least_three, len(xy)))
    return KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(xy)

def split_dev_test_by_groups(groups, test_fraction=0.2, seed=SEED, frac=None):
    """
    Hold out whole groups as a test set.

    Parameters
    ----------
    groups : array-like of group labels, one per sample.
    test_fraction : fraction of the distinct groups to hold out (at least one
        group is always held out).
    seed : seed for the random choice of held-out groups.
    frac : optional alias for `test_fraction`. This definition replaces an
        earlier one in the notebook whose keyword was `frac`; accepting both
        spellings keeps quick_train_xgb_per_receiver working after this cell
        has been executed. When given, it takes precedence over `test_fraction`.

    Returns
    -------
    (is_dev, is_test, test_g): two boolean masks over samples and the array of
    held-out group labels.
    """
    if frac is not None:
        test_fraction = frac
    rng = np.random.default_rng(seed)
    ug = np.unique(groups)
    n_test = max(1, int(round(test_fraction * len(ug))))
    test_g = rng.choice(ug, size=n_test, replace=False)
    is_test = np.isin(groups, test_g)
    return ~is_test, is_test, test_g

def r2_weighted(y_true, y_pred, w=None):
    """
    Coefficient of determination, optionally sample-weighted.

    w=None delegates to sklearn's r2_score. Otherwise weights are rescaled to
    sum to one and the result is 1 - SSE/SST, with SST taken around the
    weighted mean of y_true and 1e-12 added to each denominator.
    """
    if w is not None:
        obs = np.asarray(y_true)
        fit = np.asarray(y_pred)
        wts = np.asarray(w)
        wts = wts / (wts.sum() + 1e-12)
        center = np.sum(wts * obs)
        err_ss = np.sum(wts * (obs - fit)**2)
        tot_ss = np.sum(wts * (obs - center)**2)
        return 1.0 - (err_ss / (tot_ss + 1e-12))
    return r2_score(y_true, y_pred)


In [None]:
from tqdm import tqdm

def train_xgb_for_receiver(receiver_type: str):
    """
    Trains XGBoost models for each target gene within a receiver cell type.
    Uses spatial GroupKFold (groups via KMeans on x,y within this subset).
    Inputs: X = [X_receptors, X_exposure, X_cov]; skips X_LR.
    Optionally drops rows with all-zero coverage when IGNORE_ZERO_COV=True.

    Per gene, every combination in the hyperparameter grid is scored by the mean
    weighted R² over the GroupKFold folds; the best combination is refit on the
    dev rows (minus a trailing 10% kept for early stopping) and evaluated on the
    held-out spatial groups.

    Reads the notebook globals allexp, X_receptors, X_exposure, X_cov, Y_targets,
    aux and the configuration constants.

    Returns:
      summary_df: per-gene CV and test metrics + best params
      models: dict gene -> fitted final dev model
      test_groups: which spatial groups were held out
    """
    # --- subset rows for this receiver type ---
    idx_all = allexp.index[allexp["class"] == receiver_type]
    assert len(idx_all) > 0, f"No cells for receiver type {receiver_type}"

    X_block = pd.concat([X_receptors.loc[idx_all],
                         X_exposure.loc[idx_all],
                         X_cov.loc[idx_all]], axis=1)
    Y_block = Y_targets.loc[idx_all]
    xy = allexp.loc[idx_all, ["x","y"]].to_numpy(float)
    w_all = aux.loc[idx_all, "support_windows"].to_numpy() if USE_SAMPLE_WEIGHTS else None

    # Optionally drop cells with zero coverage across all sender types
    if IGNORE_ZERO_COV:
        # NOTE: cov_cols is assigned but not used anywhere below.
        cov_cols = X_cov.columns
        mask_nonzero_cov = (X_cov.loc[idx_all, :].sum(axis=1) > 0).to_numpy()
        X_block = X_block.loc[mask_nonzero_cov]
        Y_block = Y_block.loc[mask_nonzero_cov]
        xy      = xy[mask_nonzero_cov]
        if w_all is not None:
            w_all = w_all[mask_nonzero_cov]

    # Pick targets (optionally limit to top-variance)
    # Variance is computed within this receiver type, after the coverage filter.
    if TARGET_LIMIT is not None:
        var = Y_block.var(axis=0)
        top_genes = var.sort_values(ascending=False).index[:TARGET_LIMIT]
        Y_block = Y_block.loc[:, top_genes]

    # Build spatial groups within this receiver type
    groups = make_spatial_groups(xy, n_groups=N_GROUPS, seed=SEED)

    # Dev/Test split (by groups)
    is_dev, is_test, test_groups = split_dev_test_by_groups(groups, test_fraction=TEST_FRACTION, seed=SEED)

    X_dev, X_test = X_block.to_numpy()[is_dev], X_block.to_numpy()[is_test]
    Y_dev, Y_test = Y_block.to_numpy()[is_dev], Y_block.to_numpy()[is_test]
    groups_dev    = groups[is_dev]
    w_dev = w_all[is_dev] if w_all is not None else None
    w_test= w_all[is_test] if w_all is not None else None

    # Grid of params to try
    # Tuple order: (max_depth, learning_rate, subsample, colsample_bytree, reg_lambda, reg_alpha)
    param_grid = list(product(MAX_DEPTHS, LEARNING_RATES, SUBSAMPLE, COLSAMPLE, ALPHAS_REG_L2, ALPHAS_REG_L1))

    # Never ask for more folds than there are distinct dev groups.
    gkf = GroupKFold(n_splits=min(N_SPLITS, len(np.unique(groups_dev))))
    genes = Y_block.columns.tolist()

    results = []
    models = {}

    # gi is 1-based; gi-1 is the column of this gene in Y_dev / Y_test.
    for gi, gene in enumerate(genes, 1):
        y_dev = Y_dev[:, gi-1]
        y_test= Y_test[:, gi-1]

        best_score, best_params, best_n_rounds = -np.inf, None, None

        # --- CV hyperparameter search ---
        for (max_depth, eta, subs, colsub, reg_l2, reg_l1) in param_grid:
            fold_scores, nrounds = [], []

            for tr_idx, va_idx in gkf.split(X_dev, groups=groups_dev):
                Xtr, Xva = X_dev[tr_idx], X_dev[va_idx]
                ytr, yva = y_dev[tr_idx], y_dev[va_idx]
                wtr = w_dev[tr_idx] if w_dev is not None else None
                wva = w_dev[va_idx] if w_dev is not None else None

                model = XGBRegressor(
                    max_depth=max_depth,
                    learning_rate=eta,
                    subsample=subs,
                    colsample_bytree=colsub,
                    reg_lambda=reg_l2,
                    reg_alpha=reg_l1,
                    n_estimators=N_ESTIMATORS,
                    objective="reg:squarederror",
                    tree_method=TREE_METHOD,
                    random_state=SEED,
                    n_jobs=0,
                    eval_metric="rmse",
                    early_stopping_rounds=EARLY_STOP
                )
                # The validation fold is both the early-stopping set and the
                # set the fold is scored on.
                model.fit(
                    Xtr, ytr,
                    sample_weight=wtr,
                    eval_set=[(Xva, yva)],
                    sample_weight_eval_set=[wva] if wva is not None else None,
                    verbose=False
                )

                yhat = model.predict(Xva)
                r2   = r2_weighted(yva, yhat, wva)
                fold_scores.append(r2)
                nrounds.append(model.best_iteration if model.best_iteration is not None else N_ESTIMATORS)

            # Strict '>' keeps the first grid entry on ties.
            mean_score = float(np.mean(fold_scores))
            if mean_score > best_score:
                best_score = mean_score
                best_params = dict(
                    max_depth=max_depth, learning_rate=eta, subsample=subs,
                    colsample_bytree=colsub, reg_lambda=reg_l2, reg_alpha=reg_l1
                )
                best_n_rounds = int(np.median(nrounds))

        # --- Refit on ALL dev with best params ---
        # n_estimators is the median CV round count, floored at 50. The patience
        # is at most a third of that count (falling back to EARLY_STOP when the
        # count is 0).
        # NOTE(review): if no grid entry yields a finite mean score, best_params
        # stays None and the ** expansion below fails — confirm whether that can occur.
        final = XGBRegressor(
            **best_params,
            n_estimators=max(best_n_rounds, 50),
            objective="reg:squarederror",
            tree_method=TREE_METHOD,
            random_state=SEED,
            n_jobs=0,
            eval_metric="rmse",
            early_stopping_rounds=min(EARLY_STOP, best_n_rounds//3 if best_n_rounds else EARLY_STOP)
        )
        # Use a small internal split from dev for early stopping without leaking test
        # Here we reserve last 10% dev indices as eval (deterministic for reproducibility)
        # As a result the final model is trained on the first ~90% of dev rows only.
        n_dev = X_dev.shape[0]
        n_eval = max(1, int(0.1 * n_dev))
        tr_mask = np.ones(n_dev, dtype=bool); tr_mask[-n_eval:] = False
        va_mask = ~tr_mask

        wtr = w_dev[tr_mask] if w_dev is not None else None
        wva = w_dev[va_mask] if w_dev is not None else None

        final.fit(
            X_dev[tr_mask], y_dev[tr_mask],
            sample_weight=wtr,
            eval_set=[(X_dev[va_mask], y_dev[va_mask])],
            sample_weight_eval_set=[wva] if wva is not None else None,
            verbose=False,
        )

        # Test evaluation
        yhat_test = final.predict(X_test)
        r2_test = r2_weighted(y_test, yhat_test, w_test)

        # Store
        # best_n_estimators falls back to the configured n_estimators when
        # best_iteration is missing or 0 (the `or` treats 0 as falsy).
        results.append({
            "receiver_type": receiver_type,
            "gene": gene,
            "cv_mean_r2": best_score,
            "test_r2": float(r2_test),
            "best_n_estimators": int(getattr(final, "best_iteration", None) or final.get_params()["n_estimators"]),
            **best_params
        })
        models[gene] = final

    summary_df = pd.DataFrame(results).sort_values(["test_r2","cv_mean_r2"], ascending=[False, False]).reset_index(drop=True)
    return summary_df, models, test_groups


In [None]:
# Run the full grid-search training for every receiver cell type.
# Unlike the quick driver above, errors are not caught here: a failure for one
# receiver type stops the loop.
receiver_types = sorted(allexp["class"].unique())

# These reuse (and overwrite) the all_summaries / all_models names filled by the
# quick-training driver cell.
all_summaries = []
all_models = {}
heldout_groups = {}

for rtype in tqdm(receiver_types):
    print(f"=== Training for receiver type: {rtype} ===")
    summary_df, models, test_groups = train_xgb_for_receiver(rtype)
    all_summaries.append(summary_df)
    all_models[rtype] = models
    heldout_groups[rtype] = test_groups
    print(summary_df.head(5), "\n")

# Rows stay grouped by receiver type (each block sorted by test R²), so head(15)
# shows the first receiver type's best genes, not a global ranking.
summary_all = pd.concat(all_summaries, ignore_index=True)
print("Top genes overall by test R²:")
display(summary_all.head(15))


In [None]:
# What fraction of genes get >0 R² on test, per receiver type?
# `hit` is a boolean per (receiver type, gene); its group mean is the share of hits.
perf = (summary_all.assign(hit=lambda d: d["test_r2"] > 0)
                    .groupby("receiver_type")["hit"]
                    .mean()
                    .sort_values(ascending=False))
print("Share of targets with positive test R² (by receiver type):")
display(perf)

# Save results (optional)
# import joblib
# joblib.dump({"summaries": summary_all, "models": all_models, "heldout_groups": heldout_groups},
#             "xgb_per_receiver_models.pkl")


# Downstream Analyses

## R2 comparison

In [47]:
import pandas as pd
# Reload the per-gene test R² table written earlier (columns: receiver_type, gene, test_r2).
df = pd.read_csv('../data/r2_summary.csv')

In [49]:
import numpy as np
import matplotlib.pyplot as plt

def ecdf(y):
    """
    Empirical cumulative distribution function of `y`.

    Returns a pair (x, F): `x` is `y` sorted ascending and `F[i] = (i + 1) / n`
    is the fraction of samples <= x[i]. For empty input both arrays are empty.
    """
    ordered = np.sort(y)
    count = len(ordered)
    if count > 0:
        # Step heights 1/n, 2/n, ..., 1
        return ordered, np.arange(1, count + 1) / count
    return np.array([]), np.array([])

# df should be a pandas DataFrame with columns: gene, cell_type, test_r2
cell_types = sorted(df["cell_type"].unique())

plt.figure(figsize=(8, 5))

# One ECDF step curve per cell type. R² values are clipped to [-1, 1] before the
# ECDF is computed; cell types with no non-null R² values are not drawn.
for ct in cell_types:
    r2_vals = df.loc[df["cell_type"] == ct, "test_r2"].dropna().clip(-1, 1).values
    x, F = ecdf(r2_vals)
    if len(x) > 0:
        plt.step(x, F, where="post", label=f"{ct}", lw=1.8)

# optional thresholds
# for thr in [0.1, 0.2]:
#     plt.axvline(thr, ls="--", lw=1.0, color="grey")
#     for i, ct in enumerate(cell_types):
#         r2_vals = df.loc[df["cell_type"] == ct, "test_r2"].dropna().values
#         pct = 100 * np.mean(r2_vals >= thr)
#         plt.text(thr, 0.05 + i*0.05, f"{ct}: {pct:.1f}% ≥ {thr}",
#                  rotation=90, va="bottom", ha="right", fontsize=8)

plt.xlabel(r"$R^2$")
plt.ylabel("Fraction of genes (ECDF)")
plt.title("Per-gene $R^2$ ECDF by cell type")
# The x-axis starts at -0.2, so genes with R² below that are off-axis and a curve
# can already be above 0 at the left edge of the plot.
plt.xlim(-0.2, 1.0)
plt.ylim(0, 1.0)
plt.legend(loc="lower right", fontsize=8)
plt.grid(alpha=0.4)
plt.show()


## p value calculation

In [31]:
# Cell 1
import numpy as np
import pandas as pd
from contextlib import contextmanager

@contextmanager
def swap_global(name, new_value):
    """
    Temporarily bind the module-level global `name` to `new_value`.

    On exit (normal or via exception) the previous binding is restored. If the
    name did not exist before entering the context, it is removed again.

    Parameters
    ----------
    name : str
        Name of the global variable to replace.
    new_value : object
        Value bound to `name` for the duration of the `with` block.
    """
    g = globals()
    # Sentinel so that "global was not defined" is distinguishable from
    # "global was defined with the value None"; the latter must be restored,
    # not deleted.
    missing = object()
    old = g.get(name, missing)
    g[name] = new_value
    try:
        yield
    finally:
        if old is missing:
            # pop() rather than del: tolerate the body having removed the name itself.
            g.pop(name, None)
        else:
            g[name] = old

def permutation_pvalue(observed, null_samples, greater_is_better=True):
    """
    One-sided permutation p-value with +1 smoothing.

    p = (1 + k) / (1 + n), where n is the number of null samples and k counts the
    null samples at least as extreme as `observed` (>= when `greater_is_better`,
    <= otherwise). Returns NaN when there are no null samples or `observed` is NaN.
    """
    null_arr = np.asarray(null_samples, float)
    n_null = null_arr.size
    if n_null == 0 or np.isnan(observed):
        return np.nan
    if greater_is_better:
        n_extreme = np.sum(null_arr >= observed)
    else:
        n_extreme = np.sum(null_arr <= observed)
    return (1.0 + n_extreme) / (1.0 + n_null)


In [32]:
# Cell 2
import gc
from xgboost import XGBRegressor
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import r2_score

SAVE_MODELS_OBS = False  # set to False if RAM is tight

def quick_train_xgb_per_receiver(receiver_type: str, *, save_models: bool = True):
    """
    Fit one XGBRegressor per target gene for a single receiver type (no CV).

    Cells whose allexp["class"] equals `receiver_type` are selected. Features are the
    column-wise concatenation of X_receptors, X_exposure and X_cov; targets come from
    Y_targets. Cells are split into dev/test by spatial groups, and the last
    EVAL_HOLDOUT fraction of dev rows is passed to `fit` as the eval set.

    Parameters
    ----------
    receiver_type : str
        Value of allexp["class"] to train for.
    save_models : bool, keyword-only
        If True, every fitted model is kept and returned; if False, each model is
        deleted right after scoring to limit memory use.

    Returns
    -------
    summary : pandas.DataFrame
        Columns receiver_type, gene, test_r2, best_iters; sorted by test_r2 descending.
    models : dict or None
        gene -> fitted model when `save_models` is True, otherwise None.
    test_groups :
        Third value returned by split_dev_test_by_groups for this receiver type.
    """
    # Subset rows for this receiver type
    idx = allexp.index[allexp["class"] == receiver_type]
    X_block = pd.concat([X_receptors.loc[idx], X_exposure.loc[idx], X_cov.loc[idx]], axis=1)
    Y_block = Y_targets.loc[idx]
    xy = allexp.loc[idx, ["x", "y"]].to_numpy(float)
    w_all = aux.loc[idx, "support_windows"].to_numpy() if USE_SAMPLE_WEIGHTS else None

    # Optional: drop cells with all-zero coverage
    if IGNORE_ZERO_COV:
        keep = (X_cov.loc[idx].sum(axis=1) > 0).to_numpy()
        X_block = X_block.loc[keep]
        Y_block = Y_block.loc[keep]
        xy = xy[keep]
        if w_all is not None:
            w_all = w_all[keep]

    # Limit targets for speed: keep the TARGET_LIMIT highest-variance genes
    if TARGET_LIMIT is not None:
        top = Y_block.var(axis=0).sort_values(ascending=False).index[:TARGET_LIMIT]
        Y_block = Y_block.loc[:, top]

    # Spatial groups -> dev/test split
    groups = make_groups(xy, N_GROUPS, seed=SEED)
    is_dev, is_test, test_groups = split_dev_test_by_groups(groups, frac=TEST_FRACTION, seed=SEED)

    X_all = X_block.to_numpy()
    Y_all = Y_block.to_numpy()
    X_dev, X_test = X_all[is_dev], X_all[is_test]
    Y_dev, Y_test = Y_all[is_dev], Y_all[is_test]
    w_dev = w_all[is_dev] if w_all is not None else None
    w_test = w_all[is_test] if w_all is not None else None

    # Internal eval split from dev (no CV): the trailing n_eval dev rows are held out.
    n_dev = X_dev.shape[0]
    n_eval = max(1, int(EVAL_HOLDOUT * n_dev))
    tr_mask = np.ones(n_dev, dtype=bool)
    tr_mask[-n_eval:] = False
    va_mask = ~tr_mask

    results = []
    models = {} if save_models else None
    genes = Y_block.columns.tolist()

    for col, gene in enumerate(tqdm(genes, leave=False)):
        y_dev = Y_dev[:, col]
        y_test = Y_test[:, col]

        model = XGBRegressor(**XGB_PARAMS)
        model.fit(
            X_dev[tr_mask], y_dev[tr_mask],
            sample_weight=(w_dev[tr_mask] if w_dev is not None else None),
            eval_set=[(X_dev[va_mask], y_dev[va_mask])],
            sample_weight_eval_set=[w_dev[va_mask]] if w_dev is not None else None,
            verbose=False
        )
        yhat = model.predict(X_test)
        r2t = r2_weighted(y_test, yhat, w_test)

        # The attribute is `best_iteration` (no trailing underscore), matching the
        # other training function in this notebook. Compare against None explicitly
        # so a best iteration of 0 is reported rather than replaced by n_estimators.
        best_it = getattr(model, "best_iteration", None)
        if best_it is None:
            best_it = model.get_params()["n_estimators"]

        results.append({
            "receiver_type": receiver_type,
            "gene": gene,
            "test_r2": float(r2t),
            "best_iters": int(best_it)
        })

        if save_models:
            models[gene] = model
        else:
            # free Booster memory ASAP
            del model
            gc.collect()

    summary = pd.DataFrame(results).sort_values("test_r2", ascending=False).reset_index(drop=True)
    return summary, models, test_groups


In [34]:
# Cell 3
# Observed (un-shuffled) run: train and score every receiver type once.
receiver_types = sorted(allexp["class"].unique())

all_summaries = []
# all_models stays None when SAVE_MODELS_OBS is False, so later cells that index
# all_models[...] need models from a run that retained them.
all_models = {} if SAVE_MODELS_OBS else None
heldout = {}   # receiver_type -> test_groups returned by the training function

for rtype in receiver_types:
    print(f"=== {rtype} ===")
    try:
        s, m, tg = quick_train_xgb_per_receiver(rtype, save_models=SAVE_MODELS_OBS)
        all_summaries.append(s)
        if SAVE_MODELS_OBS:
            all_models[rtype] = m
        heldout[rtype] = tg
        display(s.head(5))
    except Exception as e:
        # Best effort: a failing receiver type is reported and skipped, so it will
        # simply be absent from summary_all.
        print(f"Error processing receiver type {rtype}: {e}")

# NOTE(review): pd.concat raises if all_summaries is empty, i.e. if every receiver
# type failed above.
summary_all = pd.concat(all_summaries, ignore_index=True)
print("Ballpark — top genes by test R² across receiver types:")
display(summary_all.head(15))

print("Median test R² per receiver type:")
display(summary_all.groupby("receiver_type")["test_r2"].median().sort_values(ascending=False))


In [35]:
# Cell 4
import gc

# Number of label-shuffle iterations. With the +1-smoothed p-value used in
# permutation_pvalue, the smallest attainable p-value is 1 / (K + 1).
K = 5
SEED_BASE = 42
keep_cols = ["receiver_type", "gene", "test_r2"]
null_runs = []

for it in range(K):
    print(f"\n--- Shuffle iter {it+1}/{K} ---")
    # Create a permuted copy of allexp 'class' (seeded per iteration for reproducibility)
    allexp_shuf = allexp.copy()
    rng = np.random.default_rng(SEED_BASE + it)
    permuted = allexp_shuf["class"].to_numpy().copy()
    rng.shuffle(permuted)
    allexp_shuf["class"] = permuted  # label shuffle

    # Train using light mode (no model retention). The training function reads the
    # global `allexp`, so it is swapped for the shuffled copy while training runs.
    iter_summaries = []
    with swap_global("allexp", allexp_shuf):
        for rtype in sorted(allexp_shuf["class"].unique()):
            try:
                s, _, _ = quick_train_xgb_per_receiver(rtype, save_models=False)
                iter_summaries.append(s[keep_cols])
            except Exception as e:
                print(f"  Skipped {rtype} this iter due to error: {e}")

    if iter_summaries:
        shuf_df = pd.concat(iter_summaries, ignore_index=True)
        shuf_df["iter"] = it
        null_runs.append(shuf_df)
        del shuf_df
    else:
        # pd.concat([]) would raise and abort the remaining iterations; an iteration
        # in which every receiver type failed just contributes no null samples.
        print(f"  No receiver type succeeded in iter {it+1}; no null samples recorded")

    # free per-iter leftovers
    del allexp_shuf, iter_summaries
    gc.collect()

if not null_runs:
    raise RuntimeError("Every shuffle iteration failed; no null distribution was built.")

null_all = pd.concat(null_runs, ignore_index=True)
display(null_all.head())


In [38]:
# Rebuild null_all from whatever iterations are currently stored in null_runs
# (same statement as at the end of Cell 4).
null_all = pd.concat(null_runs, ignore_index=True)
display(null_all.head())

In [39]:
# Cell 5
# Observed test R² per (receiver_type, gene)
obs = summary_all[["receiver_type", "gene", "test_r2"]].rename(columns={"test_r2": "test_r2_obs"}).copy()

# Aggregate null stats
agg = (null_all
       .groupby(["receiver_type", "gene"])
       .agg(null_mean_r2=("test_r2", "mean"),
            null_std_r2 =("test_r2", "std"),
            n_iter      =("test_r2", "size"))
       .reset_index())

# Left merge: observed pairs with no null samples keep NaN null stats.
merged = obs.merge(agg, on=["receiver_type", "gene"], how="left")

# p-values
# Index the null R² values by (receiver_type, gene) once, instead of re-filtering
# the whole null table for every observed row.
null_by_key = {
    key: vals.to_numpy()
    for key, vals in null_all.groupby(["receiver_type", "gene"])["test_r2"]
}
no_null = np.array([], dtype=float)  # permutation_pvalue returns NaN for an empty null

pvals = [
    permutation_pvalue(obs_r2, null_by_key.get((rt, g), no_null), greater_is_better=True)
    for rt, g, obs_r2 in zip(merged["receiver_type"], merged["gene"], merged["test_r2_obs"])
]

merged["p_value"] = pvals

# Within each receiver type: smallest p-value first, ties broken by larger observed R².
result_pvals = merged.sort_values(
    ["receiver_type", "p_value", "test_r2_obs"],
    ascending=[True, True, False]
).reset_index(drop=True)

display(result_pvals.head(20))


In [43]:
# The p-values are kept exactly as computed by permutation_pvalue.
# An earlier revision of this cell divided every p-value by 10. That rescaling has no
# statistical justification, and because it overwrote the column in place it
# compounded on every re-run of the cell, so it was removed.
# With +1 smoothing the smallest attainable p-value is 1 / (1 + n_null); report it so
# the resolution of the test is explicit before thresholding at 0.05.
n_null_max = result_pvals["n_iter"].max()
if pd.notna(n_null_max):
    p_floor = 1.0 / (1.0 + float(n_null_max))
    print(f"Smallest attainable p-value with {int(n_null_max)} null runs: {p_floor:.3f}")
    if p_floor >= 0.05:
        print("[warn] No gene can reach p < 0.05 with this many shuffles; "
              "increase K in the shuffle cell to resolve smaller p-values.")

In [44]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Copy for plotting
plot_df = result_pvals.copy()
# Rows with a NaN p-value fail the `< 0.05` comparison and land in "Not significant".
plot_df["sig_group"] = np.where(plot_df["p_value"] < 0.05, "Significant (p<0.05)", "Not significant")

# Violin of observed test R², split by significance group. cut=0 keeps each violin
# within the observed data range; inner="box" draws a box plot inside each violin.
plt.figure(figsize=(8, 5))
sns.violinplot(
    data=plot_df,
    x="sig_group", y="test_r2_obs",
    inner="box", cut=0, palette=["#66c2a5", "#fc8d62"]
)
plt.title("Distribution of Test R² by Significance")
plt.ylabel("Test R²")
plt.xlabel("")
plt.grid(axis="y", linestyle="--", alpha=0.5)
plt.show()


## L-R Responsive Genes

- Keep targets with positive test R² in coverage>0 cells.
- For each ligand (or sender type) and receptor, use SHAP and in-silico perturbations (ligand blockade: set the ligand-exposure feature to 0; receptor knockout: set the receptor feature to 0) to derive gene response signatures.

### Config

In [None]:
# Assumes you already have these (from your training step):
# - allexp: DataFrame with at least ['x','y','class'] columns; index = cell ids
# - X_receptors, X_exposure, X_cov: feature blocks (same row order as allexp)
# - Y_targets: target genes (non-R/L)
# - all_models: dict {receiver_type: {gene: trained XGBRegressor}}
# - heldout (optional): dict {receiver_type: array/list of test group ids}
#   (from your training function that returned test_groups)

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import r2_score

# Configuration for the downstream (responsive-gene / SHAP / perturbation) analysis.
SEED = 42            # seed for KMeans grouping, test re-splitting and SHAP subsampling
N_GROUPS = 6         # must match what you used during training (or close)
TEST_FRACTION = 0.2  # only used if we need to re-split (no heldout available)
USE_SAMPLE_WEIGHTS = True  # weight metrics by aux['support_windows'] if available
SAMPLE_SHAP = 5000   # cap SHAP to this many cells per gene for speed
TOP_GENES_SHAP = 50  # compute SHAP/perturbations for top-N responsive genes per receiver type


### Helpers

In [None]:
def make_groups(xy, k, seed=SEED):
    """
    Assign each row of `xy` to a KMeans cluster and return the cluster labels.

    The requested number of clusters `k` is raised to at least 3 and then capped
    at the number of rows in `xy`.
    """
    n_clusters = int(min(max(k, 3), len(xy)))  # guardrails
    clusterer = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return clusterer.fit_predict(xy)

def get_test_mask_for_receiver(receiver_type, xy, groups):
    """
    Build the boolean test mask for one receiver type from its group labels.

    If a global `heldout` mapping exists and contains `receiver_type`, those saved
    group ids define the test set. Otherwise a deterministic re-split is made:
    max(1, round(TEST_FRACTION * n_groups)) group ids are drawn without replacement
    with a generator seeded by SEED, and a warning line is printed.

    `xy` is accepted for signature compatibility but is not used here.

    Returns (is_test, test_group_ids).
    """
    have_saved = 'heldout' in globals() and receiver_type in heldout
    if not have_saved:
        # No saved split: re-split deterministically (best effort)
        rng = np.random.default_rng(SEED)
        unique_groups = np.unique(groups)
        n_test = max(1, int(round(TEST_FRACTION * len(unique_groups))))
        tg = rng.choice(unique_groups, size=n_test, replace=False)
        is_test = np.isin(groups, tg)
        print(f"[warn] Re-splitting test groups for {receiver_type}: {sorted(tg.tolist())}")
        return is_test, tg

    # Reuse the held-out groups saved during training
    tg = np.array(heldout[receiver_type])
    return np.isin(groups, tg), tg

def r2_weighted(y_true, y_pred, w=None):
    """
    Coefficient of determination, optionally sample-weighted.

    With `w=None` this defers to `r2_score(y_true, y_pred)`. Otherwise the weights
    are normalised to sum to (approximately) one and R² is computed as
    1 - SSE / SST around the weighted mean of `y_true`; a 1e-12 floor is added to
    both the weight sum and SST to avoid division by zero.
    """
    if w is None:
        return r2_score(y_true, y_pred)

    truth = np.asarray(y_true)
    pred = np.asarray(y_pred)
    weights = np.asarray(w)
    weights = weights / (weights.sum() + 1e-12)

    weighted_mean = np.sum(weights * truth)
    resid_ss = np.sum(weights * (truth - pred) ** 2)
    total_ss = np.sum(weights * (truth - weighted_mean) ** 2)
    return 1.0 - (resid_ss / (total_ss + 1e-12))


### Find L-R responsive genes

In [None]:
# Will populate:
# - responsive_tables: dict[receiver_type] -> DataFrame with per-gene R² (test, cov>0)
# - responsive_genes: dict[receiver_type] -> list of genes with positive R²

responsive_tables = {}
responsive_genes = {}

# Optional weights
w_all = None
if USE_SAMPLE_WEIGHTS and 'aux' in globals() and 'support_windows' in aux.columns:
    w_all = aux['support_windows'].to_numpy()

# all_models is None when the observed run did not retain models; treat that as
# "no models for any receiver type" instead of failing on the membership test.
models_by_type = all_models if all_models is not None else {}

# Feature column order is the same for every receiver type, so resolve it once.
rec_cols = X_receptors.columns.tolist()
lig_cols = X_exposure.columns.tolist()
cov_cols = X_cov.columns.tolist()

for rtype in sorted(allexp['class'].unique()):
    if rtype not in models_by_type:
        print(f"[skip] No models for receiver type: {rtype}")
        continue

    idx = allexp.index[allexp['class'] == rtype]
    if len(idx) < 10:
        print(f"[skip] Too few cells for {rtype}")
        continue

    # Build feature matrix in THE SAME ORDER used for training
    X_block = pd.concat([X_receptors.loc[idx, rec_cols],
                         X_exposure.loc[idx, lig_cols],
                         X_cov.loc[idx, cov_cols]], axis=1)
    Y_block = Y_targets.loc[idx, :]
    xy = allexp.loc[idx, ['x','y']].to_numpy(float)
    groups = make_groups(xy, N_GROUPS, seed=SEED)
    is_test, test_groups = get_test_mask_for_receiver(rtype, xy, groups)

    # Restrict **test** to coverage>0 cells
    cov_sum = X_cov.loc[idx, :].sum(axis=1).to_numpy()
    test_covpos = is_test & (cov_sum > 0)

    if test_covpos.sum() < 20:
        print(f"[warn] Few test cells with coverage>0 for {rtype}: n={test_covpos.sum()}")

    X_test = X_block.to_numpy()[test_covpos]
    Y_test = Y_block.to_numpy()[test_covpos]
    # Select weights by cell label, as the training and perturbation code does,
    # rather than assuming aux and allexp share the same row order.
    w_test = (aux.loc[idx, 'support_windows'].to_numpy()[test_covpos]
              if w_all is not None else None)

    results = []
    for gene, model in models_by_type[rtype].items():
        if gene not in Y_block.columns:
            continue
        y_true = Y_test[:, Y_block.columns.get_loc(gene)]
        y_pred = model.predict(X_test)
        r2t = r2_weighted(y_true, y_pred, w_test)
        results.append({"receiver_type": rtype, "gene": gene, "test_r2_covpos": float(r2t)})

    if not results:
        # An empty result list yields a column-less frame on which sort_values fails.
        print(f"[skip] No scorable gene models for {rtype}")
        continue

    # Named resp_table (not `df`) so the R² summary frame loaded earlier is not overwritten.
    resp_table = pd.DataFrame(results).sort_values("test_r2_covpos", ascending=False).reset_index(drop=True)
    responsive_tables[rtype] = resp_table
    responsive_genes[rtype] = resp_table.loc[resp_table["test_r2_covpos"] > 0, "gene"].tolist()

print("Done. Example responsive table:")
# Fall back to an empty, correctly-shaped frame when no receiver type produced a table.
example_table = next(iter(responsive_tables.values()),
                     pd.DataFrame(columns=["receiver_type", "gene", "test_r2_covpos"]))
example_table.head(10)


### SHAP: Feature -> Gene Importance

In [None]:
import shap

shap_results = {}  # dict[(rtype, gene)] -> DataFrame(feature, mean_abs_shap, block)

# Feature layout is identical for every receiver type and gene: resolve it once.
rec_cols = X_receptors.columns.tolist()
lig_cols = X_exposure.columns.tolist()
cov_cols = X_cov.columns.tolist()
feat_names = rec_cols + lig_cols + cov_cols
# Block label of each feature, in the same order as feat_names.
blocks = (["receptor"] * len(rec_cols) +
          ["ligand_exposure"] * len(lig_cols) +
          ["coverage"] * len(cov_cols))

for rtype, df_resp in responsive_tables.items():
    idx = allexp.index[allexp['class'] == rtype]

    X_block = pd.concat([X_receptors.loc[idx, rec_cols],
                         X_exposure.loc[idx, lig_cols],
                         X_cov.loc[idx, cov_cols]], axis=1)
    xy = allexp.loc[idx, ['x','y']].to_numpy(float)
    groups = make_groups(xy, N_GROUPS, seed=SEED)
    is_test, _ = get_test_mask_for_receiver(rtype, xy, groups)
    cov_sum = X_cov.loc[idx, :].sum(axis=1).to_numpy()
    test_covpos = is_test & (cov_sum > 0)

    X_test = X_block.to_numpy()[test_covpos]
    if X_test.shape[0] == 0:
        print(f"[skip SHAP] No cov>0 test cells for {rtype}")
        continue

    # Subsample cells for speed (seeded, so the same cells are used on every run)
    n = X_test.shape[0]
    if n > SAMPLE_SHAP:
        rng = np.random.default_rng(SEED)
        take = np.sort(rng.choice(n, size=SAMPLE_SHAP, replace=False))
        X_shap = X_test[take]
    else:
        X_shap = X_test

    # Top-N responsive genes for SHAP (df_resp is already sorted by R², descending)
    genes_top = df_resp.loc[df_resp["test_r2_covpos"] > 0, "gene"].head(TOP_GENES_SHAP).tolist()

    for gene in genes_top:
        model = all_models[rtype].get(gene, None)
        if model is None:
            continue

        explainer = shap.Explainer(model)  # works well with xgboost
        sv = explainer(X_shap)             # shap values
        mean_abs = np.mean(np.abs(sv.values), axis=0)  # per-feature

        out = pd.DataFrame({
            "feature": feat_names,
            "mean_abs_shap": mean_abs,
            "block": blocks
        }).sort_values("mean_abs_shap", ascending=False).reset_index(drop=True)

        shap_results[(rtype, gene)] = out

# Example: show top features for one (rtype, gene).
# `key` is None (and the example frame empty) when no SHAP table was produced,
# instead of the cell failing with StopIteration.
key = next(iter(shap_results), None)
print("Example SHAP ranking for:", key)
example_shap = shap_results.get(key, pd.DataFrame(columns=["feature", "mean_abs_shap", "block"]))
example_shap.head(10)


### In-silico perturbations

In [None]:
def safe_weighted_mean(x, w):
    """
    Weighted mean of `x` that degrades gracefully.

    - Empty `x` -> NaN.
    - `w is None` -> plain mean of `x`.
    - Otherwise entries where either `x` or `w` is non-finite are dropped; if nothing
      remains -> NaN; if the remaining weights sum to a non-finite or non-positive
      value -> plain mean of the remaining `x`; else sum(w * x) / sum(w).
    """
    values = np.asarray(x)
    if values.size == 0:
        return np.nan
    if w is None:
        return float(np.mean(values))

    weights = np.asarray(w)
    finite = np.isfinite(values) & np.isfinite(weights)
    values = values[finite]
    weights = weights[finite]
    if values.size == 0:
        return np.nan

    total = weights.sum()
    if np.isfinite(total) and total > 0:
        return float(np.sum(weights * values) / total)
    # Unusable weight total: fall back to the unweighted mean
    return float(np.mean(values))

def perturb_signature(receiver_type, genes, mode="ligand_block"):
    """
    In-silico perturbation: zero one feature at a time and record the change in
    each gene model's prediction on the coverage>0 test cells of `receiver_type`.

    Parameters
    ----------
    receiver_type : value of allexp['class'] whose cells and models are used.
    genes : iterable of target gene names; genes without a model are skipped.
    mode : "ligand_block" zeroes each X_exposure column in turn;
           "receptor_KO" zeroes each X_receptors column in turn.

    Returns
    -------
    DataFrame with columns receiver_type, target_gene, perturbed_feature, mode,
    delta_mean, delta_weighted_mean (delta = perturbed - baseline prediction),
    sorted by delta_weighted_mean descending. Empty (same columns) when there are
    no test cells or no gene produced any rows.

    Raises
    ------
    ValueError if `mode` is not one of the two supported values.
    """
    out_cols = ["receiver_type", "target_gene", "perturbed_feature", "mode",
                "delta_mean", "delta_weighted_mean"]

    idx = allexp.index[allexp['class'] == receiver_type]
    rec_cols = X_receptors.columns.tolist()
    lig_cols = X_exposure.columns.tolist()
    cov_cols = X_cov.columns.tolist()

    X_block = pd.concat([X_receptors.loc[idx, rec_cols],
                         X_exposure.loc[idx, lig_cols],
                         X_cov.loc[idx, cov_cols]], axis=1)
    xy = allexp.loc[idx, ['x','y']].to_numpy(float)

    groups = make_groups(xy, N_GROUPS, seed=SEED)
    is_test, _ = get_test_mask_for_receiver(receiver_type, xy, groups)
    cov_sum = X_cov.loc[idx, :].sum(axis=1).to_numpy()
    test_covpos = is_test & (cov_sum > 0)

    X_test = X_block.to_numpy()[test_covpos]
    # Weights below 1 are raised to 1 so no test cell is weighted out entirely;
    # safe_weighted_mean still guards against unusable weight totals.
    w_test = (aux.loc[idx, "support_windows"].to_numpy()[test_covpos]
              if USE_SAMPLE_WEIGHTS else None)
    if w_test is not None:
        w_test = np.clip(w_test, 1, None)

    if X_test.shape[0] == 0:
        # nothing to evaluate; return empty frame
        return pd.DataFrame(columns=out_cols)

    # Column offset of the perturbed block inside the concatenated feature matrix
    if mode == "ligand_block":
        pert_feats = lig_cols; start = len(rec_cols)
    elif mode == "receptor_KO":
        pert_feats = rec_cols; start = 0
    else:
        raise ValueError("mode must be 'ligand_block' or 'receptor_KO'")

    # One working copy for all perturbations: each column is zeroed, scored and then
    # restored, instead of copying the whole matrix for every feature.
    Xp = X_test.copy()

    rows = []
    for gene in genes:
        model = all_models[receiver_type].get(gene)
        if model is None:
            continue
        y0 = model.predict(X_test)  # baseline prediction

        for j, name in enumerate(pert_feats):
            col_idx = start + j
            saved_col = Xp[:, col_idx].copy()
            Xp[:, col_idx] = 0.0  # blockade / KO
            y1 = model.predict(Xp)
            Xp[:, col_idx] = saved_col  # restore before the next perturbation
            delta = y1 - y0
            m  = float(np.mean(delta)) if delta.size else np.nan
            mw = safe_weighted_mean(delta, w_test)
            rows.append({"receiver_type": receiver_type,
                         "target_gene": gene,
                         "perturbed_feature": name,
                         "mode": mode,
                         "delta_mean": m,
                         "delta_weighted_mean": mw})

    if not rows:
        # No model / no perturbable feature: sort_values would fail on a column-less frame.
        return pd.DataFrame(columns=out_cols)

    return pd.DataFrame(rows).sort_values("delta_weighted_mean", ascending=False).reset_index(drop=True)
# Build response signatures for top responsive genes per receiver type
ligand_signatures = {}    # receiver_type -> DataFrame from perturb_signature(mode="ligand_block")
receptor_signatures = {}  # receiver_type -> DataFrame from perturb_signature(mode="receptor_KO")

for rtype, df_resp in responsive_tables.items():
    # Same gene selection as the SHAP cell: positive R², first TOP_GENES_SHAP rows.
    genes_top = df_resp.loc[df_resp["test_r2_covpos"] > 0, "gene"].head(TOP_GENES_SHAP).tolist()
    if not genes_top:
        continue
    ligand_signatures[rtype]  = perturb_signature(rtype, genes_top, mode="ligand_block")
    receptor_signatures[rtype]= perturb_signature(rtype, genes_top, mode="receptor_KO")

# Example: top ligand blockade effects for one receiver type
# NOTE(review): next() raises StopIteration if ligand_signatures is empty.
k = next(iter(ligand_signatures.keys()))
print("Example ligand blockade signature for:", k)
ligand_signatures[k].head(10)


### Summary

In [None]:
# 1) Per receiver type: number of LR-responsive genes
summary_counts = {rtype: len(genes) for rtype, genes in responsive_genes.items()}
print("LR-responsive gene counts per receiver type:")
# A bare expression in the middle of a cell is not rendered, so display explicitly.
display(summary_counts)


def _top_features_by_abs_delta(signatures, n_top=10):
    """
    For each receiver type, rank perturbed features by the mean |delta_weighted_mean|
    across target genes and keep the `n_top` largest.

    `signatures` maps receiver_type -> perturbation DataFrame; receiver types with an
    empty DataFrame are omitted. Returns dict[receiver_type] -> Series indexed by
    perturbed_feature, sorted descending.
    """
    top = {}
    for rtype, sig in signatures.items():
        if sig.empty:
            continue
        ranked = (sig.assign(abs_delta=lambda d: d["delta_weighted_mean"].abs())
                     .groupby("perturbed_feature")["abs_delta"].mean()
                     .sort_values(ascending=False))
        top[rtype] = ranked.head(n_top)
    return top


# 2) For each receiver type, top ligands by average |Δ| across target genes
top_ligands_per_receiver = _top_features_by_abs_delta(ligand_signatures)
print("Top ligands per receiver (by mean |Δ|):")
for rtype, s in top_ligands_per_receiver.items():
    print("—", rtype)
    display(s)

# 3) Likewise, top receptors by average |Δ|
top_receptors_per_receiver = _top_features_by_abs_delta(receptor_signatures)
print("Top receptors per receiver (by mean |Δ|):")
for rtype, s in top_receptors_per_receiver.items():
    print("—", rtype)
    display(s)

# 4) Optional: export for enrichment (gene response sets per ligand/receptor)
# Example: for a receiver type rtype and ligand L, collect top-affected genes:
# rtype = k; L = top_ligands_per_receiver[rtype].index[0]
# sigL = ligand_signatures[rtype].query("perturbed_feature == @L").nlargest(200, 'delta_weighted_mean')
# gene_list = sigL['target_gene'].tolist()  # feed to GO/KEGG enrichment tools


# Save Models and Results

In [None]:
import os

MODEL_DIR = "./saved_models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Save all models per receiver type and gene
# Layout: <MODEL_DIR>/<receiver type with spaces replaced by "_">/<gene>.json
# NOTE(review): the loading cell keys models by directory name, so a receiver type
# containing spaces comes back under its underscored name, not the original one.
for rtype, gene_models in all_models.items():
    subdir = os.path.join(MODEL_DIR, rtype.replace(" ", "_"))
    os.makedirs(subdir, exist_ok=True)
    for gene, model in gene_models.items():
        path = os.path.join(subdir, f"{gene}.json")
        model.save_model(path)


## Load Models (Use Later)

In [None]:
import xgboost as xgb
import os
MODEL_DIR = "./saved_models"
loaded_models = {}  # directory name (receiver type) -> {gene: XGBRegressor}
for rtype in os.listdir(MODEL_DIR):
    rpath = os.path.join(MODEL_DIR, rtype)
    # Only sub-directories hold models; skip stray files in MODEL_DIR.
    if not os.path.isdir(rpath):
        continue
    gene_models = {}
    for fname in os.listdir(rpath):
        if fname.endswith(".json"):
            gene = fname[:-5]  # strip the ".json" suffix to recover the gene name
            m = xgb.XGBRegressor()
            m.load_model(os.path.join(rpath, fname))
            gene_models[gene] = m
    loaded_models[rtype] = gene_models


## Run Loaded Models

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

def build_blocks_for_receiver(receiver_type: str):
    """
    Assemble the feature matrix, targets, coordinates and weights for one receiver type.

    Returns (X_block, Y_block, xy, w_all):
    - X_block: X_receptors, X_exposure and X_cov concatenated column-wise,
    - Y_block: matching rows of Y_targets (restricted to the TARGET_LIMIT
      highest-variance genes when TARGET_LIMIT is set),
    - xy: float array of the "x"/"y" columns of allexp,
    - w_all: aux["support_windows"] values, or None when USE_SAMPLE_WEIGHTS is off.
    When IGNORE_ZERO_COV is set, cells whose X_cov row sums to 0 are dropped from all four.

    Raises AssertionError if no cell has this receiver type.
    """
    cell_ids = allexp.index[allexp["class"] == receiver_type]
    assert len(cell_ids) > 0, f"No cells for receiver type {receiver_type}"

    feature_parts = [X_receptors.loc[cell_ids],
                     X_exposure.loc[cell_ids],
                     X_cov.loc[cell_ids]]
    X_block = pd.concat(feature_parts, axis=1)
    Y_block = Y_targets.loc[cell_ids]
    xy = allexp.loc[cell_ids, ["x","y"]].to_numpy(float)

    w_all = None
    if USE_SAMPLE_WEIGHTS:
        w_all = aux.loc[cell_ids, "support_windows"].to_numpy()

    if IGNORE_ZERO_COV:
        has_cov = (X_cov.loc[cell_ids, :].sum(axis=1) > 0).to_numpy()
        X_block = X_block.loc[has_cov]
        Y_block = Y_block.loc[has_cov]
        xy = xy[has_cov]
        if w_all is not None:
            w_all = w_all[has_cov]

    # Optional target limit used in training: keep the highest-variance genes
    if TARGET_LIMIT is not None:
        ranked = Y_block.var(axis=0).sort_values(ascending=False)
        Y_block = Y_block.loc[:, ranked.index[:TARGET_LIMIT]]

    return X_block, Y_block, xy, w_all

def get_test_mask(xy, seed=SEED, n_groups=N_GROUPS, test_fraction=TEST_FRACTION, saved_test_groups=None):
    """
    Rebuild the spatial dev/test split for the cells at coordinates `xy`.

    When `saved_test_groups` (an iterable of group ids) is given, cells in those
    groups form the test set and all others the dev set. Otherwise the split is
    delegated to split_dev_test_by_groups with the given fraction and seed.

    Returns (is_dev, is_test, groups).
    """
    # NOTE(review): quick_train_xgb_per_receiver in this notebook builds groups with
    # make_groups(...) and calls split_dev_test_by_groups(..., frac=...), whereas this
    # function uses make_spatial_groups(...) and the keyword test_fraction=. Confirm
    # both helpers exist with these signatures and produce the training-time split.
    groups = make_spatial_groups(xy, n_groups=n_groups, seed=seed)
    if saved_test_groups is None:
        is_dev, is_test, _ = split_dev_test_by_groups(groups, test_fraction=test_fraction, seed=seed)
    else:
        is_test = np.isin(groups, list(saved_test_groups))
        is_dev = ~is_test
    return is_dev, is_test, groups

def evaluate_loaded_models_on_test(
    loaded_models_for_type: dict,
    receiver_type: str,
    saved_test_groups: set | None = None,
):
    """
    Score already-loaded per-gene models on the test split of one receiver type.

    Parameters
    ----------
    loaded_models_for_type : dict gene -> model exposing `.predict`.
    receiver_type : receiver type whose blocks are rebuilt via build_blocks_for_receiver.
    saved_test_groups : optional group ids forwarded to get_test_mask.

    Returns
    -------
    (summary_df, per_gene_preds)
        summary_df has columns receiver_type, gene, test_r2, n_test, sorted by test_r2
        descending. per_gene_preds maps gene -> dict with y_true, y_pred, weights,
        test_mask_index (index labels of the test rows) and groups_test.
    """
    # Rebuild the blocks and the split
    X_block, Y_block, xy, w_all = build_blocks_for_receiver(receiver_type)
    _is_dev, is_test, groups = get_test_mask(xy, saved_test_groups=saved_test_groups)

    X_test = X_block.to_numpy()[is_test]
    w_test = None if w_all is None else w_all[is_test]

    # IMPORTANT: the column order of X_block is assumed to match the order used when
    # the models were trained; no reindexing by feature name happens here.
    results = []
    per_gene_preds = {}
    target_genes = Y_block.columns

    for gene, model in tqdm(loaded_models_for_type.items(), desc=f"Testing {receiver_type}"):
        # Genes outside the current target set are skipped silently
        if gene in target_genes:
            y_test = Y_block.loc[:, gene].to_numpy()[is_test]
            yhat_test = model.predict(X_test)
            r2_test = r2_weighted(y_test, yhat_test, w_test)

            results.append({
                "receiver_type": receiver_type,
                "gene": gene,
                "test_r2": float(r2_test),
                "n_test": int(X_test.shape[0]),
            })
            per_gene_preds[gene] = {
                "y_true": y_test,
                "y_pred": yhat_test,
                "weights": w_test,
                # original cell indices, kept for traceability
                "test_mask_index": X_block.index[is_test].to_numpy(),
                "groups_test": groups[is_test],
            }

    summary_df = pd.DataFrame(results)
    summary_df = summary_df.sort_values("test_r2", ascending=False).reset_index(drop=True)
    return summary_df, per_gene_preds

# ---- Run for all receiver types you loaded ----
all_summaries = []   # per-receiver summary frames
all_preds = {}       # receiver_type -> per-gene prediction dicts

for rtype, gene_models in loaded_models.items():
    # If you saved test_groups during training, load them here:
    # saved_tg = saved_test_groups_map.get(rtype, None)
    saved_tg = None  # None -> get_test_mask rebuilds the split from seed/params
    try:
        summary_df, preds = evaluate_loaded_models_on_test(
            loaded_models_for_type=gene_models,
            receiver_type=rtype,
            saved_test_groups=saved_tg,
        )
        all_summaries.append(summary_df.assign(receiver_type=rtype))
        all_preds[rtype] = preds
    except Exception as e:
        # Best effort: report the failure and continue with the next receiver type.
        print(f"Error evaluating {rtype}: {e}")

# NOTE(review): pd.concat raises if every receiver type failed (empty all_summaries).
final_summary = pd.concat(all_summaries, ignore_index=True)
print(final_summary.head())


## Save LR Perturbations

In [None]:
import joblib

# Persist the downstream-analysis dictionaries with joblib, one .pkl file per object.
RESULTS_DIR = "./saved_results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Save responsive tables
joblib.dump(responsive_tables, os.path.join(RESULTS_DIR, "responsive_tables.pkl"))

# Save responsive gene lists
joblib.dump(responsive_genes, os.path.join(RESULTS_DIR, "responsive_genes.pkl"))

# Save SHAP results
joblib.dump(shap_results, os.path.join(RESULTS_DIR, "shap_results.pkl"))

# Save perturbation signatures
joblib.dump(ligand_signatures, os.path.join(RESULTS_DIR, "ligand_signatures.pkl"))
joblib.dump(receptor_signatures, os.path.join(RESULTS_DIR, "receptor_signatures.pkl"))


In [7]:
# save the above as csv
import pandas as pd
import os

# Create a directory for the CSV files
CSV_DIR = "./saved_csv"
os.makedirs(CSV_DIR, exist_ok=True)

# The result objects are dicts of DataFrames / ragged lists, which the DataFrame
# constructor cannot turn into a meaningful table. Each one is flattened to a
# long-format table first and written under the same file name as before.

# Save responsive tables: dict[receiver_type] -> DataFrame, stacked row-wise
# (each table already carries a receiver_type column).
responsive_tables_long = (
    pd.concat(list(responsive_tables.values()), ignore_index=True)
    if responsive_tables
    else pd.DataFrame(columns=["receiver_type", "gene", "test_r2_covpos"])
)
responsive_tables_long.to_csv(os.path.join(CSV_DIR, "responsive_tables.csv"), index=False)

# Save responsive gene lists: one (receiver_type, gene) row per responsive gene
responsive_genes_long = pd.DataFrame(
    [(rtype, gene) for rtype, genes in responsive_genes.items() for gene in genes],
    columns=["receiver_type", "gene"],
)
responsive_genes_long.to_csv(os.path.join(CSV_DIR, "responsive_genes.csv"), index=False)

# Save SHAP results: dict[(receiver_type, gene)] -> DataFrame(feature, mean_abs_shap, block);
# the key is attached to every row so the stacked table stays self-describing.
shap_cols = ["receiver_type", "gene", "feature", "mean_abs_shap", "block"]
shap_frames = [
    tbl.assign(receiver_type=rtype, gene=gene)[shap_cols]
    for (rtype, gene), tbl in shap_results.items()
]
shap_long = (
    pd.concat(shap_frames, ignore_index=True)
    if shap_frames
    else pd.DataFrame(columns=shap_cols)
)
shap_long.to_csv(os.path.join(CSV_DIR, "shap_results.csv"), index=False)

# Save perturbation signature keys for downstream lookups
pd.DataFrame({"receiver_type": list(ligand_signatures.keys())}).to_csv(
    os.path.join(CSV_DIR, "ligand_signatures_keys.csv"),
    index=False,
)
pd.DataFrame({"receiver_type": list(receptor_signatures.keys())}).to_csv(
    os.path.join(CSV_DIR, "receptor_signatures_keys.csv"),
    index=False,
)


## Load LR Perturbations

In [1]:
import joblib
import os
import pandas as pd
RESULTS_DIR = "./saved_results"

In [2]:
# Reload the saved analysis dictionaries from RESULTS_DIR.
# NOTE(review): joblib.load unpickles arbitrary objects; only load files from a trusted source.
responsive_tables = joblib.load(os.path.join(RESULTS_DIR, "responsive_tables.pkl"))
responsive_genes = joblib.load(os.path.join(RESULTS_DIR, "responsive_genes.pkl"))
shap_results = joblib.load(os.path.join(RESULTS_DIR, "shap_results.pkl"))
ligand_signatures = joblib.load(os.path.join(RESULTS_DIR, "ligand_signatures.pkl"))
receptor_signatures = joblib.load(os.path.join(RESULTS_DIR, "receptor_signatures.pkl"))


In [18]:
# One-shot iterator over the receiver types that have a ligand signature;
# once exhausted, this cell must be re-run to start over.
iterator = iter(ligand_signatures.keys())

In [19]:
# Export every ligand / receptor perturbation signature to its own CSV file.
# Iterating the dict directly keeps this cell re-runnable: it does not depend on a
# one-shot iterator created in another cell.
os.makedirs('../data/lr_perturbations/ligands', exist_ok=True)
os.makedirs('../data/lr_perturbations/receptors', exist_ok=True)

for k, ligand_df in ligand_signatures.items():
    ligand_df.to_csv(f'../data/lr_perturbations/ligands/{k}.csv', index=False)
    receptor_df = receptor_signatures.get(k)
    if receptor_df is None:
        # Keep exporting the remaining receiver types instead of aborting on a KeyError.
        print(f"[warn] No receptor signature for {k}; only the ligand CSV was written")
        continue
    receptor_df.to_csv(f'../data/lr_perturbations/receptors/{k}.csv', index=False)

In [3]:
# Show the first receiver type that has a ligand signature (raises StopIteration if there is none).
next(iter(ligand_signatures.keys()))