# Loading Packages

In [None]:
import os

# Cap the worker-thread pools of the native numeric backends (OpenMP,
# OpenBLAS, MKL, NumExpr) at 64 each. These variables are usually read when
# the corresponding library is first loaded, so set them before importing
# numpy / scanpy.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        "64",
    )
)

In [None]:
import SPIX
from SPIX.image_processing.image_cache import *
import pandas as pd
import scanpy as sc
import squidpy as sq
import numpy as np
import squidpy as sq
import matplotlib.pyplot as plt
import anndata as ad


# Multiscale Workflow - bin3(~2um) Stereo-seq MOSTA data

In [None]:
# Load the E16.5_E1S3 bin3 section (.h5ad) from the local data drive.
adata = sc.read_h5ad(
    '/home/Data_Drive_8TB/chs1151/Vesalius/datas/'
    'E16.5_E1S3_bin3.h5ad'
)

In [None]:
# Swap the first two coordinate columns in place; `x` / `y` hold copies of
# the pre-swap columns 1 and 0 respectively.
x = adata.obsm['spatial'][:, 1].copy()
y = adata.obsm['spatial'][:, 0].copy()
adata.obsm['spatial'][:, 0], adata.obsm['spatial'][:, 1] = x, y

# Mirror both axes: every value v in the first two columns becomes
# max(column) - v. The result is a fresh array that replaces obsm['spatial'].
coords = adata.obsm['spatial'].copy()
coords[:, :2] = coords[:, :2].max(axis=0) - coords[:, :2]
adata.obsm['spatial'] = coords

In [None]:
adata = SPIX.tm.generate_embeddings(
    adata,
    # expression normalisation, feature selection and PCA size
    normalization='log_norm',
    use_hvg_only=True,
    nfeatures=2000,
    dim_reduction='PCA',
    dimensions=30,
    # tiling / raster settings
    use_coords_as_tiles=True,
    coords_max_gap_factor=None,
    raster_stride=10,
    raster_max_pixels_per_tile=400,
    raster_random_seed=42,
    filter_threshold=1,
    # execution
    n_jobs=32,
    force=True,
)


In [None]:
adata = SPIX.ip.smooth_image(
    adata,
    embedding='X_embedding',
    embedding_dims=list(range(30)),
    # methods are listed in the order 'graph' then 'gaussian'
    methods=['graph', 'gaussian'],
    graph_k=30,
    graph_t=50,
    gaussian_sigma=500,
    n_jobs=10,
)

In [None]:
adata = SPIX.ip.equalize_image(
    adata,
    embedding='X_embedding_smooth',
    dimensions=list(range(30)),
    sleft=5,
    sright=5,
)

In [None]:
cache_embedding_image(
    adata,
    embedding='X_embedding_equalize',
    dimensions=list(range(30)),
    key='image_plot_slic',   # cache key reused by segment_image(method=...)
    origin=True,
    figsize=(30, 30),
    fig_dpi=300,
    show=True,
    verbose=False,
)

In [None]:
show_cached_image(adata, key='image_plot_slic', channels=list(range(3)))

In [None]:
SPIX.sp.segment_image(
    adata,
    embedding='X_embedding_equalize',
    dimensions=list(range(30)),
    method='image_plot_slic',
    use_cached_image=True,
    pitch_um=2,
    target_segment_um=500,
    compactness=0.5,
    enforce_connectivity=False,
    origin=True,
    show_image=True,
    verbose=True,
)

In [None]:
SPIX.pl.image_plot(
    adata,
    embedding='X_embedding_segment',
    dimensions=[0, 1, 2],
    figsize=(10, 10),          # overall figure size
    alpha=1,
    origin=True,
    plot_boundaries=True,
    boundary_method='pixel',
    fixed_boundary_color='black',
    boundary_linewidth=1,
)

In [None]:
# Build a spatial neighbour graph for `adata` with generic (non-grid)
# coordinates. NOTE(review): presumably uses squidpy's default spatial key,
# i.e. the transformed adata.obsm['spatial'] from above -- confirm.
sq.gr.spatial_neighbors(adata, coord_type='generic')

In [None]:
# multiscale_svg_advanced2.py
import os, gc, tempfile, shutil, warnings, itertools
import numpy as np, pandas as pd, scanpy as sc, seaborn as sns, matplotlib.pyplot as plt
from joblib import Parallel, delayed
from scipy.stats import rankdata
from sklearn.mixture import GaussianMixture

#################################
# 1) setting
#################################
resolutions    = [2, 8, 16, 30, 50, 100, 250, 500]   # target_segment_um values swept
# alternative sweep: [2, 8, 10, 16, 30, 50, 100, 500]
compactnesses  = [0.5]                               # compactness values swept
# alternative sweep: [0.1, 0.2, 0.5, 1, 3, 5, 10]
dims_use       = list(range(30))                     # embedding dimensions used for segmentation
segment_method = 'image_plot_slic'
embedding_key  = 'X_embedding_equalize'
# Plain bool: a trailing comma (`use_cached_image=True,`) would make this the
# tuple (True,) and that tuple would be forwarded to segment_image.
use_cached_image = True
moran_thresh   = 0      # moranI_threshold passed to perform_pseudo_bulk_analysis
n_jobs         = 3      # joblib workers for the scale sweep
top_changed    = 100
use_memmap     = False  # True: write adata once to a temp .h5ad; workers re-read it from disk
#################################
# 2) single-scale calculation function
#################################
def calc_scale(adata, scale_id, res, comp, adata_path=None):
    """Segment at one (resolution, compactness) setting and rank genes by Moran's I.

    Parameters
    ----------
    adata : AnnData or None
        In-memory object to copy. Ignored when ``adata_path`` is given.
    scale_id : str
        Label used to name the output column (``'rank_' + scale_id``).
    res : number
        Passed to ``SPIX.sp.segment_image`` as ``target_segment_um``.
    comp : number
        Passed to ``SPIX.sp.segment_image`` as ``compactness``.
    adata_path : str, optional
        If given, a fresh object is read from this .h5ad file instead.

    Returns
    -------
    pandas.DataFrame
        A single column ``'rank_<scale_id>'`` indexed like the Moran table.
        Rank 1 is the highest I and ties share the minimum rank. Genes with a
        NaN I keep a NaN rank. The frame is empty (with a warning) if no Moran
        result was produced.
    """
    # Work on a private object so segmentation never mutates the caller's AnnData.
    local = sc.read_h5ad(adata_path) if adata_path else adata.copy()

    # segmentation
    SPIX.sp.segment_image(
        local,
        dimensions=dims_use,
        embedding=embedding_key,
        method=segment_method,
        target_segment_um=res,
        compactness=comp,
        figsize=(30, 30),
        fig_dpi=300,
        enforce_connectivity=False,
        use_cached_image=use_cached_image,
        origin=True,
        verbose=False
    )

    # pseudo-bulk Moran
    _, moran = SPIX.an.perform_pseudo_bulk_analysis(
        local,
        segment_key='Segment',
        normalize_total=True,
        log_transform=True,
        expr_agg='sum',
        moranI_threshold=moran_thresh,
        segment_graph_strategy='collapsed',
        collapse_row_normalize=True,
        perform_pca=False,
        highly_variable=False,
        mode='moran'
    )
    if moran is None or moran.empty:
        warnings.warn(f"[{scale_id}] No MoranI result ")
        return pd.DataFrame()

    # Descending rank by I (ties -> minimum rank). pandas ranks NaN entries as
    # NaN individually; scipy.stats.rankdata may propagate one NaN to the whole
    # array, so it is not used here.
    ranks = moran['I'].rank(ascending=False, method='min')
    return ranks.to_frame('rank_' + scale_id)

#################################
# 3) parallel calculation
#################################
param_grid = list(itertools.product(resolutions, compactnesses))
print(f"▶ Total {len(param_grid)} scales ")

# With use_memmap the AnnData is written once to a temporary .h5ad, and each
# task re-reads it via calc_scale(..., adata_path).
adata_path = None
td = tempfile.mkdtemp() if use_memmap else None
try:
    if td is not None:
        adata_path = os.path.join(td, 'adata_tmp.h5ad')
        adata.write_h5ad(adata_path)

    # Send the in-memory object only when there is no file path. joblib
    # serialises task arguments, so passing `adata` alongside the path would
    # still ship it to every task and defeat use_memmap.
    results = Parallel(n_jobs=n_jobs, backend='loky', verbose=10)(
        delayed(calc_scale)(None if adata_path else adata, f"r{r}_c{c}", r, c, adata_path)
        for r, c in param_grid
    )
finally:
    # Remove the temporary directory even if a worker raised.
    if td is not None:
        shutil.rmtree(td, ignore_errors=True)


In [None]:

# 4) concat rank tables
rank_tables = [tbl for tbl in results if not tbl.empty]
if not rank_tables:
    raise RuntimeError("No scale produced a Moran's I result; nothing to rank.")

rank_mat = pd.concat(rank_tables, axis=1, join='outer')

# Every (resolution, compactness) pair is expected to have a column. A scale
# whose Moran table came back empty was dropped above; re-add it (all NaN) so
# that the per-axis loops below do not fail with a KeyError.
expected_cols = [f"rank_r{r}_c{c}" for r in resolutions for c in compactnesses]
missing_cols = [col for col in expected_cols if col not in rank_mat.columns]
if missing_cols:
    warnings.warn(f"No Moran's I ranks for {missing_cols}; all genes get the worst rank there.")
    rank_mat = rank_mat.reindex(columns=rank_mat.columns.union(expected_cols))

# A gene with no rank at a scale gets one worse than the worst observed rank.
max_rank = int(rank_mat.max().max())
rank_mat = rank_mat.fillna(max_rank + 1)

rank_mat.sort_index(axis=1, inplace=True)

#################################
# 5) resolution / compactness axis
#################################
# resolution axis: mean rank over compactness values at each resolution
res_rank = pd.DataFrame(index=rank_mat.index)
for r in resolutions:
    cols = [f"rank_r{r}_c{c}" for c in compactnesses]
    res_rank[f"res_{r}"] = rank_mat[cols].mean(axis=1)

# compactness axis: mean rank over resolutions at each compactness value
comp_rank = pd.DataFrame(index=rank_mat.index)
for c in compactnesses:
    cols = [f"rank_r{r}_c{c}" for r in resolutions]
    comp_rank[f"comp_{c}"] = rank_mat[cols].mean(axis=1)


In [None]:
df = res_rank.copy(deep=True)   # working copy; category columns get added below

In [None]:
import pandas as pd

# Resolution columns grouped into three scale regions.
early_cols = ['res_2', 'res_8', 'res_16']
mid_cols = ['res_30', 'res_50', 'res_100']
late_cols = ['res_250', 'res_500']

# Mean rank of each gene inside each region.
df['mean_early'] = df.loc[:, early_cols].mean(axis='columns')
df['mean_mid'] = df.loc[:, mid_cols].mean(axis='columns')
df['mean_late'] = df.loc[:, late_cols].mean(axis='columns')

# Relative-gap threshold used by categorize_with_threshold. A gene is given
# its best (lowest mean rank) region only when
# (second_best - best) / second_best >= threshold_ratio; otherwise it is
# labelled 'mixed'. A looser value previously tried was 0.7.
threshold_ratio = 0.93

def categorize_with_threshold(row, ratio=None):
    """Assign a gene to the scale region where its mean rank is clearly best.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row)
        Must provide 'mean_early', 'mean_mid' and 'mean_late'.
    ratio : float, optional
        Minimum relative gap ``(second - best) / second`` needed to call a
        region. Defaults to the module-level ``threshold_ratio``.

    Returns
    -------
    str
        'early', 'mid' or 'late' (the region with the lowest mean rank) when
        its relative gap to the runner-up is at least ``ratio``; otherwise
        'mixed'. Tied regions have zero gap, so they give 'mixed' for any
        positive ratio.
    """
    if ratio is None:
        ratio = threshold_ratio

    regions = {
        'early': row['mean_early'],
        'mid':   row['mean_mid'],
        'late':  row['mean_late'],
    }
    # Lowest and second-lowest mean rank (stable sort keeps early < mid < late on ties).
    (best_region, best_val), (_, second_val) = sorted(regions.items(), key=lambda kv: kv[1])[:2]

    # Only a *relative* gap is tested; there is no absolute-difference criterion.
    if (second_val - best_val) / second_val >= ratio:
        return best_region
    return 'mixed'

# Label every gene with the region it is most specific to (or 'mixed').
df['category'] = df.apply(categorize_with_threshold, axis='columns')

# Optional diagnostics: genes per category, then the first few rows.
print(df['category'].value_counts())
print(df.loc[:, ['mean_early', 'mean_mid', 'mean_late', 'category']].head())


In [None]:
def plot_traj(df, genes, xlabel, title):
    """Draw rank trajectories for selected genes across the columns of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rank table. Each column becomes one x position (labelled with the
        column name), and each requested row label becomes one line.
    genes : iterable
        Row labels of ``df`` to plot.
    xlabel, title : str
        X-axis label and figure title.

    The y-axis is inverted so rank 1 appears at the top. The figure is shown
    with ``plt.show()``; nothing is returned.
    """
    xpos = np.arange(len(df.columns))
    plt.figure(figsize=(8, 4))

    for gene in genes:
        plt.plot(xpos, df.loc[gene], '-o', alpha=0.7, label=gene)

    plt.gca().invert_yaxis()
    plt.xticks(xpos, df.columns, rotation=45)
    plt.xlabel(xlabel)
    plt.ylabel('Rank')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
    plt.tight_layout()
    plt.show()

In [None]:
df['late_m_mid'] = df['mean_late'].sub(df['mean_mid'])   # negative => better (lower) rank at late scales

In [None]:
plot_traj(
    res_rank,
    df.loc[df['category'].eq('early')].sort_values('mean_early').index[:10].tolist(),
    'Resolution',
    'Unique in high resolution',
)

In [None]:
plot_traj(
    res_rank,
    df.loc[df['category'].eq('late')].sort_values('mean_late').index[:10].tolist(),
    'Resolution',
    'Unique in low resolution',
)

In [None]:
plot_traj(
    res_rank,
    df.loc[df['category'].eq('late')].sort_values('mean_mid', ascending=False).index[:10].tolist(),
    'Resolution',
    'Unique in low resolution',
)

In [None]:
plot_traj(
    res_rank,
    df.loc[df['category'].eq('mid')].sort_values('mean_early', ascending=False).index[:10].tolist(),
    'Resolution',
    'Unique in middle resolution',
)

In [None]:
# ==============================================================
# Parallel SPIX grids using in-memory `adata` (no disk reload)
# - Works best on Linux with 'fork' start method.
# - Each worker process inherits the parent's memory (copy-on-write),
#   so we can use the global `adata` directly without reloading.
# ==============================================================

import os
import math
import warnings
import gc
from typing import List, Tuple
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

# Try to enforce 'fork' (Linux) so worker processes inherit the in-memory `adata`.
# With force=True, set_start_method raises ValueError (not RuntimeError) when
# 'fork' is unavailable on this platform (e.g. Windows).
try:
    mp.set_start_method("fork", force=True)
except (RuntimeError, ValueError) as exc:
    warnings.warn(
        f"Could not switch multiprocessing to 'fork' ({exc}); "
        "worker processes will not inherit the in-memory `adata`."
    )

# -----------------------------
# 0) Helpers: ranking + lists
# -----------------------------


# -----------------------------------------
# 1) Worker: render one gene (uses global adata)
# -----------------------------------------
def _render_gene_png_worker(gene: str,
                            out_png: str,
                            segment_key: str,
                            normalize_total: bool,
                            log1p: bool,
                            title_prefix: str,
                            tile_figsize: tuple,
                            tile_dpi: int) -> str:
    """
    Render one gene's map to ``out_png`` and return ``out_png``.

    Executed in a child process. Uses the module-global ``adata`` inherited
    from the parent via fork (Linux). Under 'spawn' the child does not get the
    parent's in-memory ``adata``, so this worker is only expected to work
    with 'fork'.

    - Genes absent from ``adata.var_names`` get a placeholder tile.
    - Otherwise the single-gene embedding is computed with
      ``SPIX.an.add_gene_expression_embedding`` and dimension [0] of
      ``X_gene_embedding`` is plotted with ``SPIX.pl.image_plot``.
    """
    # Import here so that child process resolves the same modules
    import gc
    import SPIX
    import matplotlib.pyplot as plt

    global adata  # inherited from the parent process via fork

    try:
        # Gracefully handle missing gene
        if str(gene) not in set(map(str, adata.var_names)):
            fig = plt.figure(figsize=tile_figsize)
            plt.text(0.5, 0.5, f"{title_prefix}{gene}\n(not in var_names)", ha='center', va='center')
            plt.axis('off')
            fig.savefig(out_png, dpi=tile_dpi, bbox_inches='tight')
            return out_png

        # Compute single-gene embedding (overwrites X_gene_embedding in THIS process only)
        SPIX.an.add_gene_expression_embedding(
            adata,
            genes=[gene],
            segment_key=segment_key,
            normalize_total=normalize_total,
            log1p=log1p
        )

        # Plot using the single dimension [0]
        SPIX.pl.image_plot(
            adata,
            dimensions=[0],                  # single dim only
            embedding='X_gene_embedding',
            boundary_method='pixel',
            imshow_tile_size=10,
            imshow_scale_factor=1,
            figsize=tile_figsize,
            fixed_boundary_color='Black',
            cmap='viridis',
            boundary_linewidth=1,
            show_colorbar=True,
            prioritize_high_values=True,
            title=f"{title_prefix}{gene}",
            alpha=1,
            plot_boundaries=False,
            origin=True
        )
        plt.savefig(out_png, dpi=tile_dpi, bbox_inches='tight')
        return out_png
    finally:
        # Pool workers are reused for several genes: close every open figure
        # (also on error, or if image_plot opened more than one) so they do
        # not accumulate in the child process.
        plt.close('all')
        gc.collect()

# ------------------------------
# 2) Grid assembly
# ------------------------------
def _assemble_grid(tile_paths: List[str], out_path: str, cols: int = 5, suptitle: str = "", dpi: int = 150):
    """Assemble saved PNG tiles into a grid figure and save it to ``out_path``.

    Tiles fill the ``rows x cols`` grid in row-major order, where
    ``rows = ceil(len(tile_paths) / cols)`` (at least 1). Unused cells stay
    blank, and every axis is hidden. The figure is always closed afterwards.
    """
    rows = math.ceil(len(tile_paths) / cols) if tile_paths else 1
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    try:
        if suptitle:
            fig.suptitle(suptitle, fontsize=16)
        for idx, ax in enumerate(axes.flat):
            ax.axis('off')
            if idx < len(tile_paths):
                # Context manager releases the PIL file handle; imshow converts
                # the image to an array immediately.
                with Image.open(tile_paths[idx]) as img:
                    ax.imshow(img)
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(out_path, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)

# -------------------------------------------------------
# 3) Orchestrator: real parallel pipeline per group
# -------------------------------------------------------
def build_grid_for_group_parallel(group_name: str,
                                  genes: List[str],
                                  out_dir: str = "./spix_grids",
                                  cols: int = 5,
                                  tile_figsize=(10,10),
                                  tile_dpi=150,
                                  segment_key='Segment',
                                  normalize_total=True,
                                  log1p=True,
                                  max_workers=6):
    """Render per-gene tiles in parallel processes, then assemble them into one grid PNG.

    Parameters
    ----------
    group_name : str
        Used in the tile directory name, the output file name and tile titles.
    genes : iterable of str
        Genes to render. Any iterable is accepted (list, pandas Index, ...).
        Duplicates are dropped while keeping first-seen order.
    out_dir : str
        Output directory. Tiles go to ``<out_dir>/tiles_<group_name>/``.
    cols : int
        Number of columns in the assembled grid.
    tile_figsize, tile_dpi, segment_key, normalize_total, log1p
        Forwarded to ``_render_gene_png_worker``. ``tile_dpi`` is also the grid dpi.
    max_workers : int
        Size of the ProcessPoolExecutor.

    Returns
    -------
    str or None
        Path of the grid PNG, or None (with a warning) if ``genes`` is empty.
        An exception raised by any worker propagates.
    """
    # Materialise first: a pandas Index has an ambiguous truth value. Also
    # de-duplicate, because two tasks for one gene would race on one tile file.
    genes = list(dict.fromkeys(genes))
    if not genes:
        warnings.warn(f"[{group_name}] No genes to render.")
        return None

    os.makedirs(out_dir, exist_ok=True)
    tmp_dir = os.path.join(out_dir, f"tiles_{group_name}")
    os.makedirs(tmp_dir, exist_ok=True)

    out_png = os.path.join(out_dir, f"grid_{group_name}.png")

    def _tile_name(gene):
        # Keep tiles inside tmp_dir even if a gene name contains a path separator.
        name = str(gene)
        for sep in filter(None, (os.sep, os.altsep)):
            name = name.replace(sep, "_")
        return name

    ordered_tiles = [os.path.join(tmp_dir, f"{_tile_name(g)}.png") for g in genes]

    # Submit one job per gene
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [
            ex.submit(
                _render_gene_png_worker,
                g, png,
                segment_key, normalize_total, log1p,
                f"{group_name} | ",
                tile_figsize, tile_dpi
            )
            for g, png in zip(genes, ordered_tiles)
        ]
        # Ensure all complete; re-raises the first worker exception encountered.
        for f in as_completed(futures):
            f.result()

    # Assemble in original order
    existing_tiles = [p for p in ordered_tiles if os.path.exists(p)]
    _assemble_grid(existing_tiles, out_path=out_png, cols=cols,
                   suptitle=f"{group_name} (n={len(existing_tiles)})", dpi=tile_dpi)

    print(f"[Saved] {out_png}")
    return out_png


In [None]:
import matplotlib
import matplotlib.pyplot as plt

# ==============================================================
# 4) RUN: make top lists & build grids
# ==============================================================

TOP_K = 10                 # number of genes per group
GROUP_COLS = 5             # columns in grid
TILE_FIGSIZE = (10, 10)    # per-tile SPIX figure size
TILE_DPI = 400
MAX_WORKERS = 10
SEGMENT_KEY = "Segment"
NORMALIZE_TOTAL = True
LOG1P = True

OUT_DIR = "./spix_1007_MOSTA_multiscale"

# Remember the active backend (e.g. Jupyter inline) so it can be restored,
# then switch to a non-GUI backend that forked workers inherit.
_previous_backend = matplotlib.get_backend()
matplotlib.use("Agg")
try:
    late_genes = df[df['category'] == 'late']

    build_grid_for_group_parallel(
        group_name=f"moran_late_500_top{TOP_K}",
        genes=late_genes.sort_values('mean_late').head(TOP_K).index.tolist(),
        out_dir=OUT_DIR,
        cols=GROUP_COLS,
        tile_figsize=TILE_FIGSIZE,
        tile_dpi=TILE_DPI,
        segment_key=SEGMENT_KEY,
        normalize_total=NORMALIZE_TOTAL,
        log1p=LOG1P,
        max_workers=MAX_WORKERS
    )

    build_grid_for_group_parallel(
        group_name=f"moran_late_m_mid_500_top{TOP_K}",
        genes=late_genes.sort_values('late_m_mid', ascending=True).head(TOP_K).index.tolist(),
        out_dir=OUT_DIR,
        cols=GROUP_COLS,
        tile_figsize=TILE_FIGSIZE,
        tile_dpi=TILE_DPI,
        segment_key=SEGMENT_KEY,
        normalize_total=NORMALIZE_TOTAL,
        log1p=LOG1P,
        max_workers=MAX_WORKERS
    )
finally:
    # Switch back even if a build failed, so later cells still render as before.
    plt.switch_backend(_previous_backend)
