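# Simulated data for testing the ScAdver residual adapter

Train ScAdver on the AZ source (reference), project the Phenaros source (query) through the adaptive residual adapter, and inspect the joint latent space with UMAP.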

In [1]:
import os, warnings
warnings.filterwarnings("ignore")

import numpy as np
import scanpy as sc
import torch
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
import pandas as pd

import scadver
from scadver import (
    adversarial_batch_correction,
    transform_query_adaptive,
    set_global_seed,
)

SEED = 42
set_global_seed(SEED)

# Name the accelerator torch can see, for the log only. CUDA is checked first so a
# GPU machine is not misreported as "CPU". The ScAdver calls below pass device="auto"
# and pick their own device (presumably with a similar preference — not verified here).
if torch.cuda.is_available():
    device_name = "CUDA"
elif torch.backends.mps.is_available():
    device_name = "MPS"
else:
    device_name = "CPU"

print("✅ Libraries imported successfully!")
print(f"   Scanpy  : {sc.__version__}")
print(f"   PyTorch : {torch.__version__}")
print(f"   ScAdver : {scadver.__version__}")
print(f"   Device  : {device_name}")


✅ Libraries imported successfully!
   Scanpy  : 1.11.4
   PyTorch : 2.8.0
   ScAdver : 1.1.0
   Device  : MPS


In [4]:
# Load the simulated dataset and let the AnnData repr summarise its contents.
adata = sc.read_h5ad(filename="simulation.h5ad")
adata

AnnData object with n_obs × n_vars = 8000 × 3839
    obs: 'Source', 'imaging_batch', 'perturbation'
    uns: 'Source_colors', 'imaging_batch_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [9]:
# Peek at the first five rows of per-cell metadata.
adata.obs.iloc[:5]

Unnamed: 0,Source,imaging_batch,perturbation
0,AZ,AZ_batch_2,AZ14453229
1,AZ,AZ_batch_2,AZ14170101
2,AZ,AZ_batch_1,AZ14443865
3,AZ,AZ_batch_4,AZ14335593
4,AZ,AZ_batch_1,AZ14480178


In [6]:
# Cells per imaging batch (batch labels are prefixed with their data source).
adata.obs.imaging_batch.value_counts()

imaging_batch
AZ_batch_2          1052
AZ_batch_1          1016
AZ_batch_5          1012
AZ_batch_4           961
AZ_batch_3           959
Phenaros_batch_3     782
Phenaros_batch_2     761
Phenaros_batch_1     737
Phenaros_batch_4     720
Name: count, dtype: int64

In [7]:
# Cells per data source; these two groups become the reference/query split below.
adata.obs.Source.value_counts()

Source
AZ          5000
Phenaros    3000
Name: count, dtype: int64

In [8]:
# Split by data source: AZ is the reference the model is trained on, Phenaros is the
# query projected into that space. Boolean indexing of an AnnData returns a *view*;
# `.copy()` materialises independent objects so later writes (obs/obsm/layers) never
# act on a view. Writing to a view emits ImplicitModificationWarning, which the
# global warnings filter at the top of the notebook would silently hide.
is_ref = adata.obs["Source"] == "AZ"
is_query = adata.obs["Source"] == "Phenaros"
adata_ref = adata[is_ref].copy()
adata_query = adata[is_query].copy()

# Guard against a typo in a source label silently yielding an empty split.
assert adata_ref.n_obs > 0 and adata_query.n_obs > 0, "empty reference or query split"

In [10]:
# Summarise the split. Counts are converted to plain ints so the dict prints as
# {'AZ_batch_2': 1052, ...} instead of NumPy 2's {'AZ_batch_2': np.int64(1052), ...}.
print("\nData split:")
for split_name, split in (("Reference", adata_ref), ("Query", adata_query)):
    counts = {batch: int(n) for batch, n in split.obs["imaging_batch"].value_counts().items()}
    print(f"  {split_name:<10}: {split.shape[0]:,} cells  →  {counts}")


Data split:
  Reference : 5,000 cells  →  {'AZ_batch_2': np.int64(1052), 'AZ_batch_1': np.int64(1016), 'AZ_batch_5': np.int64(1012), 'AZ_batch_4': np.int64(961), 'AZ_batch_3': np.int64(959)}
  Query     : 3,000 cells  →  {'Phenaros_batch_3': np.int64(782), 'Phenaros_batch_2': np.int64(761), 'Phenaros_batch_1': np.int64(737), 'Phenaros_batch_4': np.int64(720)}


In [11]:
%%time
# Train ScAdver on the reference (AZ) cells only. bio_weight / batch_weight presumably
# weight the perturbation-preservation and batch-adversarial losses respectively
# (inferred from the names and the training log — confirm in the scadver docs).
adata_ref_corrected, model, ref_metrics = adversarial_batch_correction(
    adata=adata_ref,
    bio_label="perturbation",       # biological signal to keep
    batch_label="imaging_batch",    # technical factor to correct for
    latent_dim=256,
    epochs=500,
    bio_weight=20.0,
    batch_weight=0.5,
    learning_rate=0.001,
    device="auto",
    return_reconstructed=True,      # also populates layers['ScAdver_reconstructed']
    seed=SEED,
)

print("\n✅ Reference training complete!")
print(f"   Latent embedding         : {adata_ref_corrected.obsm['X_ScAdver'].shape}")
print(f"   Reconstructed expression : {adata_ref_corrected.layers['ScAdver_reconstructed'].shape}")
print("\n📈 Reference metrics:")
for k, v in ref_metrics.items():
    # Metric value types are not visible from here; fall back to str() instead of
    # aborting the cell (after an expensive training run) on a format error.
    try:
        formatted = f"{v:.4f}"
    except (TypeError, ValueError):
        formatted = str(v)
    print(f"   {k}: {formatted}")


🚀 ADVERSARIAL BATCH CORRECTION
   Device: mps
📊 DATA PREPARATION:
   Valid samples: 5000/5000
   Input shape: (5000, 3839)
   Biology labels: 2308 unique
   Batch labels: 5 unique
   Training on all provided data (global correction mode)
🧠 MODEL ARCHITECTURE:
   Input dimension: 3839
   Latent dimension: 256
   Biology classes: 2308
   Batch classes: 5
🏋️ TRAINING MODEL:
   Epochs: 500
   Learning rate: 0.001
   Bio weight: 20.0
   Batch weight: 0.5
   Epoch 100/500 - Bio accuracy (Reference): 0.181 (best: 0.181)
   Epoch 200/500 - Bio accuracy (Reference): 0.603 (best: 0.603)
   Epoch 300/500 - Bio accuracy (Reference): 0.800 (best: 0.800)
   Epoch 400/500 - Bio accuracy (Reference): 0.866 (best: 0.866)
   Epoch 500/500 - Bio accuracy (Reference): 0.916 (best: 0.916)
✅ Training completed! Best biology accuracy: 0.916
🔄 GENERATING CORRECTED EMBEDDING:
   Output embedding shape: (5000, 256)
   Reconstructed expression shape: (5000, 3839)
   ✅ Batch-corrected gene ex

In [13]:
%%time
# Project the query (Phenaros) cells into the reference latent space. Per the log
# below, transform_query_adaptive first measures domain shift with a test adapter
# and, when shift is detected, trains a residual adapter before embedding the query.
adata_query_corrected = transform_query_adaptive(
    model=model,
    adata_query=adata_query,
    adata_reference=adata_ref,      # reference cells used as the alignment target
    bio_label="perturbation",
    adaptation_epochs=200,
    warmup_epochs=40,
    patience=30,                    # presumably early-stopping patience in epochs
    learning_rate=0.0005,
    device="auto",
    return_reconstructed=True,
    seed=SEED,
)

print("\n✅ Query projection complete!")
print(f"   Latent embedding         : {adata_query_corrected.obsm['X_ScAdver'].shape}")
print(f"   Reconstructed expression : {adata_query_corrected.layers['ScAdver_reconstructed'].shape}")


🤖 AUTO-DETECTING DOMAIN SHIFT...
   Strategy: Train test adapter and measure residual magnitude
   📊 Residual Adapter Analysis:
      Residual Magnitude (||R||): 3.3903
      Residual Std Dev: 0.1222
   🎯 Decision: ADAPTER NEEDED
      Confidence: HIGH
   💡 Residual R > 0: Domain shift detected — using adapter


🔬 ADAPTIVE QUERY PROJECTION (Enhanced)
   Device: mps
   Query samples: 3000

🏗️  Initializing enhanced residual adapter...
   Architecture: 256 → [128]*3 → 256  (tanh-bounded, learnable scale)
   Initial adapter scale: 0.0100
   Reference samples for alignment: 5000
   Biological supervision: perturbation (1680 classes)

🏋️  Training enhanced residual adapter...
   Epochs: 200  |  Warmup: 40  |  Patience: 30
   Losses: adversarial + MMD + CORAL + moment + bio + reconstruction
   Epoch   1/200 | Adapter: 347.9632 | Disc: 1.3859 | Align: 0.1639 | Scale: 0.0218 | LR: 0.000500 | Warmup: 0.50  💾 best
   Epoch  10/200 | Adapter: 301.9136 | Disc: 1.1

In [None]:
# Stack reference and query so both are embedded in one UMAP. "Source" is written on
# the combined object rather than by mutating the per-split results from earlier
# cells. concat keeps input order, so the reference rows come first, then the query rows.
adata_all = sc.concat([adata_ref_corrected, adata_query_corrected])
adata_all.obs["Source"] = (
    ["AZ"] * adata_ref_corrected.n_obs + ["Phenaros"] * adata_query_corrected.n_obs
)

print("Computing UMAP...")
# kNN graph on the ScAdver latent space (not PCA). Both steps are seeded from the
# notebook-wide SEED instead of scanpy's implicit default.
sc.pp.neighbors(adata_all, use_rep="X_ScAdver", n_neighbors=15, random_state=SEED)
sc.tl.umap(adata_all, random_state=SEED)
print("UMAP computed.")

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
sc.pl.umap(adata_all, color="Source", ax=axes[0], show=False, title="Data Source")
sc.pl.umap(adata_all, color="imaging_batch", ax=axes[1], show=False, title="Batch Corrected (ScAdver)")
plt.tight_layout()
plt.show()


Computing UMAP...
