# scGPT Gene Masking → Reconstruction → C2S Cell-Type Evaluation

**Pipeline:**
1. Load 24 donor cells (same sample as `c2s_donor_celltype_prediction.ipynb`)
2. **Mask** – randomly zero out `MASK_FRACTION` of each cell's expressed genes
3. **Reconstruct** – feed masked cells through `tdc/scGPT` (with `-1` mask tokens)
   and replace masked positions with scGPT's predicted expression values
4. **Evaluate** – run C2S cell-type prediction on three AnnData objects:
   - `original` (clean baseline)
   - `masked`   (corrupted, masked genes → 0)
   - `reconstructed` (scGPT-repaired)
5. **Compare** results side-by-side using the 3-tier accuracy metric
   (Exactly correct / Partly correct / Not correct)

In [None]:
# Optional: install dependencies
# %pip install -q cell2sentence anndata scanpy datasets transformers pandas numpy scipy PyTDC torch

In [None]:
from pathlib import Path
import re
import random

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import torch

import cell2sentence as cs
from cell2sentence.tasks import predict_cell_types_of_data

  from .autonotebook import tqdm as notebook_tqdm
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]


In [None]:
# ── Configuration ─────────────────────────────────────────────────────────────
# One seed is applied to Python's `random`, NumPy's legacy global RNG and torch
# so that every stochastic step below starts from the same state.
SEED          = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Input data and the donor / sample-size selection used in section 1.
H5AD_PATH     = Path('../../data/dominguez_conde_immune_tissue_two_donors.h5ad')
DONOR_COLUMN  = 'batch_condition'
DONOR_VALUE   = 'A29'
N_CELLS       = 24
TOP_K_GENES   = 200          # genes passed to C2S
MASK_FRACTION = 0.40         # fraction of expressed genes to mask per cell
MASK_VALUE    = -1.0         # scGPT sentinel for masked positions (must be != 0)
C2S_MODEL     = 'vandijklab/C2S-Pythia-410m-cell-type-prediction'

# Use the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)

# Fail early, with the resolved absolute path, if the input file is missing.
assert H5AD_PATH.exists(), f'File not found: {H5AD_PATH.resolve()}'

Device: cpu


## 1 · Load data & sample 24 cells

In [None]:
# Read the full dataset from disk.
adata = ad.read_h5ad(H5AD_PATH)
print('Full dataset shape:', adata.shape)

# Keep only the rows whose DONOR_COLUMN equals DONOR_VALUE, as an independent copy.
adata_donor = adata[adata.obs[DONOR_COLUMN] == DONOR_VALUE].copy()
# Draw N_CELLS distinct row positions with a generator seeded from SEED.
# `replace=False` raises if the donor has fewer than N_CELLS cells.
rng = np.random.default_rng(SEED)
idx = rng.choice(adata_donor.n_obs, size=N_CELLS, replace=False)
adata_small = adata_donor[idx].copy()

# Report the sample size and the ten most frequent ground-truth labels in it.
print(f'Donor {DONOR_VALUE}: {adata_donor.n_obs} cells  →  sampled {adata_small.n_obs}')
print(adata_small.obs['cell_type'].value_counts().head(10))

Full dataset shape: (29773, 36503)
Donor A29: 17327 cells  →  sampled 24
cell_type
macrophage                                               4
memory B cell                                            3
naive thymus-derived CD4-positive, alpha-beta T cell     3
T follicular helper cell                                 2
classical monocyte                                       2
plasma cell                                              2
CD4-positive helper T cell                               1
CD16-negative, CD56-bright natural killer cell, human    1
alveolar macrophage                                      1
effector memory CD4-positive, alpha-beta T cell          1
Name: count, dtype: int64


## 2 · Preprocessing (normalize + log1p)

In [None]:
# NOTE: the three calls below modify `adata_small` in place (their return values
# are not used), so re-running this cell applies normalisation and log1p a second
# time on already-transformed data. Re-run the sampling cell first if needed.
adata_small.var_names_make_unique()
sc.pp.normalize_total(adata_small, target_sum=1e4)
sc.pp.log1p(adata_small)

# Dense expression matrix  (n_cells × n_genes)
import scipy.sparse as sp
# Densify when X is sparse; otherwise copy so later edits never touch adata_small.X.
X_orig = adata_small.X.toarray() if sp.issparse(adata_small.X) else adata_small.X.copy()
# Gene names as a NumPy array so they can be indexed with integer index arrays.
gene_names = np.array(adata_small.var_names.tolist())

print('Expression matrix:', X_orig.shape)
# Fraction of exactly-zero entries over the whole matrix, as a percentage.
print(f'Sparsity: {(X_orig == 0).mean()*100:.1f}% zeros per cell on average')

Expression matrix: (24, 36503)
Sparsity: 95.3% zeros per cell on average


## 3 · Random gene masking

For each cell, a fraction `MASK_FRACTION` of the expressed (non-zero) genes is set to  
`MASK_VALUE = -1` in the matrix that is fed to scGPT. A second copy, in which the same genes are set to `0`, is created for C2S  
(since C2S expects non-negative expression).

In [None]:
# X_scgpt_in : expression matrix sent to scGPT  (masked genes = MASK_VALUE, kept in tokenizer)
# X_masked   : expression matrix for C2S masked baseline  (masked genes = 0)
# mask_record: list of integer index arrays; mask_record[i] holds the gene
#              (column) indices that were masked in cell i

rng_mask = np.random.default_rng(SEED)

X_scgpt_in = X_orig.copy()
X_masked   = X_orig.copy()
mask_record = []          # mask_record[i] -> indices of masked genes in cell i

# Use the real number of rows rather than N_CELLS so the loop and the summary
# stay correct even if the matrix was built from a different sample size.
n_cells_masking = X_orig.shape[0]

total_masked = 0
for i in range(n_cells_masking):
    expressed_idx = np.where(X_orig[i] > 0)[0]

    # A cell without any expressed gene has nothing to mask; record an empty
    # index array so mask_record stays aligned with the rows of X_orig.
    if len(expressed_idx) == 0:
        mask_record.append(np.empty(0, dtype=np.int64))
        continue

    # At least one gene is masked, and never more than are expressed
    # (the upper clamp only matters if MASK_FRACTION is set above 1).
    n_mask = min(len(expressed_idx), max(1, int(len(expressed_idx) * MASK_FRACTION)))
    masked_idx = rng_mask.choice(expressed_idx, size=n_mask, replace=False)

    X_scgpt_in[i, masked_idx] = MASK_VALUE  # -1  → tokenizer keeps these
    X_masked[i, masked_idx]   = 0.0         # 0   → removed from C2S

    mask_record.append(masked_idx)
    total_masked += n_mask

print(f'Masked {total_masked} genes across {n_cells_masking} cells'
      f' (~{total_masked/max(n_cells_masking, 1):.0f} per cell, {MASK_FRACTION*100:.0f}% of expressed)')

Masked 16337 genes across 24 cells (~681 per cell, 40% of expressed)



## 4 · scGPT reconstruction  *(no TDC – native scgpt package)*

We load the model via `huggingface_hub.snapshot_download` and the native
`scgpt.model.TransformerModel` + `scgpt.tokenizer.GeneVocab`.

**Masking strategy:** masked genes are set to `-1` (scGPT's standard mask sentinel).
Since `-1 ≠ 0`, they stay in the gene sequence passed to the model.  
`mlm_output` predicts expression at **every position** – including the `-1` ones.  
Those predicted values replace the masked genes in `X_reconstructed`.


In [None]:

import json
from pathlib import Path
from huggingface_hub import snapshot_download
from scgpt.model import TransformerModel
from scgpt.tokenizer import GeneVocab

# ── Download weights from HuggingFace ────────────────────────────────────────
SCGPT_HF   = 'tdc/scGPT'          # change to 'bowang-lab/scGPT_human' if needed
SCGPT_DIR  = Path('./scgpt_checkpoint')

print(f'Downloading {SCGPT_HF} from HuggingFace …')
# snapshot_download returns the local folder; wrap it in Path for the `/` joins below.
local_dir = snapshot_download(SCGPT_HF, local_dir=str(SCGPT_DIR))
local_dir  = Path(local_dir)
print('Files:', [f.name for f in local_dir.iterdir()])

# ── Vocabulary ────────────────────────────────────────────────────────────────
vocab_path = local_dir / 'vocab.json'
assert vocab_path.exists(), f'vocab.json not found in {local_dir}'
vocab  = GeneVocab.from_file(str(vocab_path))
# Fallbacks: PAD_ID becomes 0 when '<pad>' is absent, and UNK_ID becomes PAD_ID
# when '<unk>' is absent, so out-of-vocabulary genes may share the pad id.
PAD_ID = vocab['<pad>'] if '<pad>' in vocab else 0
UNK_ID = vocab['<unk>'] if '<unk>' in vocab else PAD_ID
print(f'Vocabulary size: {len(vocab)}  |  PAD={PAD_ID}  UNK={UNK_ID}')

# ── Model config ──────────────────────────────────────────────────────────────
config_path = local_dir / 'args.json'   # scGPT saves as args.json
if not config_path.exists():
    # Second candidate name; open() below raises FileNotFoundError if it is missing too.
    config_path = local_dir / 'config.json'
with open(config_path) as f:
    cfg = json.load(f)
# Only the first ten keys are printed to keep the output short.
print('Config keys:', list(cfg.keys())[:10])

# ── Build model architecture ──────────────────────────────────────────────────
# The gene-token embedding inside scGPT's TransformerModel looks up
# `vocab[pad_token]` for its padding index, so the vocabulary object itself
# must be handed to the constructor and must contain the pad token.
assert '<pad>' in vocab, "scGPT vocabulary has no '<pad>' token"

# Each hyper-parameter is read from the checkpoint config. Where the notebook's
# key spelling differs from the one scGPT's training script is believed to
# write (e.g. `n_layers_cls`, `n_bins`), both spellings are tried before the
# hard-coded default.  NOTE(review): confirm key names against the actual
# args.json / config.json shipped with the checkpoint.
scgpt_model = TransformerModel(
    ntoken          = len(vocab),
    d_model         = cfg.get('embsize',     cfg.get('d_model',  512)),
    nhead           = cfg.get('nheads',      cfg.get('nhead',      8)),
    d_hid           = cfg.get('d_hid',       512),
    nlayers         = cfg.get('nlayers',     12),
    nlayers_cls     = cfg.get('nlayers_cls', cfg.get('n_layers_cls', 3)),
    n_cls           = 1,          # not used at inference
    vocab           = vocab,      # required: the model reads vocab[pad_token]
    dropout         = 0.0,
    pad_token       = '<pad>',
    # `pad_value` is the padding value for *expression values*, not a token id.
    # -2 is the value used in scGPT's tutorials; it is only consulted for the
    # 'category' input embedding style.  NOTE(review): confirm for this checkpoint.
    pad_value       = cfg.get('pad_value', -2),
    do_mvc          = cfg.get('GEPC',        cfg.get('do_mvc',  False)),
    do_dab          = cfg.get('do_dab',      False),
    use_batch_labels= cfg.get('use_batch_labels', False),
    num_batch_labels= cfg.get('num_batch_labels', 1),
    domain_spec_batchnorm = cfg.get('dsbn', False),
    input_emb_style = cfg.get('input_emb_style', 'continuous'),
    n_input_bins    = cfg.get('n_input_bins', cfg.get('n_bins', 51)),
    explicit_zero_prob = cfg.get('explicit_zero_prob', False),
    use_fast_transformer = False,   # disable flash-attn (not needed on CPU)
    pre_norm        = cfg.get('pre_norm', False),
)

# ── Load weights ──────────────────────────────────────────────────────────────
ckpt_path = local_dir / 'best_model.pt'
if not ckpt_path.exists():
    ckpt_path = local_dir / 'model.pt'
assert ckpt_path.exists(), f'No model weights found in {local_dir}'

# The checkpoint is a downloaded file, so restrict unpickling to tensors and
# plain containers instead of allowing arbitrary pickled objects.
state = torch.load(ckpt_path, map_location='cpu', weights_only=True)
if isinstance(state, dict) and 'model_state_dict' in state:
    state = state['model_state_dict']

# The model above is built with use_fast_transformer=False, i.e. with standard
# torch attention layers whose parameters are called `in_proj_weight/bias`.
# Checkpoints trained with flash-attention store the same tensors under
# `Wqkv.weight/bias`. Without this rename, strict=False would silently leave
# every attention layer randomly initialised. The rename is a no-op for
# checkpoints that do not contain `Wqkv.` keys.
state = {key.replace('Wqkv.', 'in_proj_'): tensor for key, tensor in state.items()}

# strict=False tolerates heads that exist only on one side (e.g. an MVC decoder);
# anything reported as missing below is NOT loaded and keeps its random init.
missing, unexpected = scgpt_model.load_state_dict(state, strict=False)
if missing:
    print(f'Missing keys ({len(missing)}): {missing[:5]} …')
if unexpected:
    print(f'Unexpected keys ({len(unexpected)}): {unexpected[:5]} …')

scgpt_model = scgpt_model.to(DEVICE)
scgpt_model.eval()
print(f'\nscGPT ready on {DEVICE}.')


In [None]:

from tqdm.auto import tqdm

def gene_names_to_ids(names):
    """Translate gene names into scGPT vocabulary ids.

    Every name found in the module-level `vocab` is replaced by its id; names
    that are not in `vocab` are replaced by `UNK_ID`. The result is a 1-D
    int64 NumPy array in the same order as `names`.
    """
    ids = []
    for gene in names:
        if gene in vocab:
            ids.append(vocab[gene])
        else:
            ids.append(UNK_ID)
    return np.array(ids, dtype=np.int64)

# X_reconstructed starts from the *masked* matrix (masked genes = 0), not from
# the original one: any masked position that scGPT cannot fill (gene not in the
# vocabulary, cell skipped) then stays 0 instead of silently keeping its true
# value, which would leak ground truth into the "reconstructed" condition.
X_reconstructed = X_masked.copy()
n_cells_recon = X_scgpt_in.shape[0]   # actual number of rows, independent of N_CELLS
n_in_vocab  = 0
n_recovered = 0

# NOTE(review): the values fed to the model are log1p-normalised expression.
# scGPT checkpoints are commonly trained on binned values; confirm that the
# predictions written back below are on a scale comparable to X_orig before
# interpreting the C2S results on the reconstructed matrix.

with torch.no_grad():
    for i in tqdm(range(n_cells_recon), desc='scGPT reconstruction'):

        cell_expr = X_scgpt_in[i]          # shape (n_genes,); -1 for masked genes

        # ── Select non-zero genes (includes -1 masked ones) ──────────────────
        nonzero_idx = np.where(cell_expr != 0)[0]

        # Drop genes the scGPT vocabulary does not know. Mapping them to the
        # UNK/pad id while the padding mask says "attend" would feed the model
        # tokens it never saw paired with expression values.
        known_gene = np.array([g in vocab for g in gene_names[nonzero_idx]], dtype=bool)
        nonzero_idx = nonzero_idx[known_gene]
        if len(nonzero_idx) == 0:
            continue

        genes_sel  = gene_names[nonzero_idx]                    # gene name strings
        vals_sel   = cell_expr[nonzero_idx].astype(np.float32)  # expr values (incl. -1)
        gene_ids   = gene_names_to_ids(genes_sel)               # vocab IDs (all in-vocab)

        # ── Build tensors ─────────────────────────────────────────────────────
        src_t   = torch.from_numpy(gene_ids).unsqueeze(0).to(DEVICE)    # (1, L)
        vals_t  = torch.from_numpy(vals_sel).unsqueeze(0).to(DEVICE)    # (1, L)
        # padding mask: False = attend, True = ignore  (no padding here)
        pad_mask = torch.zeros(1, len(gene_ids), dtype=torch.bool, device=DEVICE)

        n_in_vocab += len(nonzero_idx)

        # ── scGPT forward pass ────────────────────────────────────────────────
        # Only arguments that TransformerModel.forward accepts are passed; it
        # has no `output_hidden_states` parameter. mlm_output has shape (1, L)
        # or (1, L, 1) and holds a prediction for every position, including
        # the masked (-1) ones.
        out = scgpt_model(
            src                  = src_t,
            values               = vals_t,
            src_key_padding_mask = pad_mask,
            CLS                  = False,
        )

        mlm = out.get('mlm_output')
        if mlm is None:
            print(f'  Cell {i}: mlm_output not in scGPT output '
                  f'(keys: {list(out.keys())}) – skipping')
            continue

        if mlm.ndim == 3:
            mlm = mlm.squeeze(-1)           # (1, L, 1) → (1, L)
        mlm_np = mlm.squeeze(0).cpu().float().numpy()   # (L,)

        # ── Find masked positions and fill with scGPT predictions ─────────────
        masked_in_local = np.where(vals_sel < 0)[0]    # positions where value was -1
        if len(masked_in_local) == 0:
            continue

        masked_orig_idx = nonzero_idx[masked_in_local]  # back to global gene indices
        # Expression cannot be negative, so negative predictions are clipped to 0.
        predicted_vals  = np.clip(mlm_np[masked_in_local], 0.0, None)

        X_reconstructed[i, masked_orig_idx] = predicted_vals
        n_recovered += len(masked_orig_idx)

print(f'\nReconstruction complete.')
print(f'  Avg genes in scGPT vocab / cell : {n_in_vocab / max(n_cells_recon, 1):.0f}')
# max(..., 1) avoids a ZeroDivisionError when nothing was masked at all.
print(f'  Masked genes recovered          : {n_recovered} / {total_masked}'
      f'  ({n_recovered / max(total_masked, 1) * 100:.1f} %)')


## 5 · Build three AnnData objects

| Name | Expression matrix | Description |
|---|---|---|
| `adata_original` | `X_orig` | Clean, unmodified cells |
| `adata_masked` | `X_masked` | Corrupted: a fraction `MASK_FRACTION` of each cell's expressed genes set to 0 |
| `adata_reconstructed` | `X_reconstructed` | scGPT-repaired: masked genes filled with predicted expression |

In [None]:
import scipy.sparse as sp

def make_adata(X_new, template_adata):
    """Build a new AnnData around `X_new`, reusing the template's annotations.

    `X_new` is stored as a SciPy CSR matrix. `obs` and `var` are copies of the
    template's tables, so the returned object does not share them with
    `template_adata`.
    """
    sparse_X = sp.csr_matrix(X_new)
    obs_table = template_adata.obs.copy()
    var_table = template_adata.var.copy()
    return ad.AnnData(X=sparse_X, obs=obs_table, var=var_table)

# One AnnData per condition; all three reuse the obs/var tables of adata_small.
adata_original     = make_adata(X_orig,         adata_small)
adata_masked       = make_adata(X_masked,        adata_small)
adata_reconstructed = make_adata(X_reconstructed, adata_small)

print('AnnData objects created:')
for name, a in [('original', adata_original), ('masked', adata_masked), ('reconstructed', adata_reconstructed)]:
    # Despite its name, this is the fraction of strictly positive matrix entries.
    mean_nonzero = (a.X.toarray() > 0).mean()
    print(f'  {name:15s}  non-zero fraction: {mean_nonzero:.3f}')

## 6 · C2S Cell-Type Prediction (3×)

> **Runtime note:** each C2S call takes ~9 min on CPU (22 s/cell × 24 cells).  
> Total ≈ 27 min. Grab a coffee ☕

In [None]:
# Metadata columns forwarded to C2S as labels; only those that actually exist
# in adata_small.obs are kept.
label_cols = [c for c in ['cell_type', DONOR_COLUMN, 'tissue', 'sex', 'organism']
              if c in adata_small.obs.columns]

# Load C2S model once (we'll reuse it across all three predictions)
csmodel = cs.CSModel(
    model_name_or_path=C2S_MODEL,
    save_dir='./tmp_c2s_reconstruction_model',
    save_name='pretrained_c2s_inference',
)
print('C2S model loaded:', C2S_MODEL)


def run_c2s(adata_in, tag, csmodel_in):
    """Predict cell types for `adata_in` with C2S and tabulate the outcome.

    The AnnData is converted to an arrow dataset of cell sentences, wrapped in
    a CSData stored under ``./tmp_c2s_<tag>``, and passed to
    `predict_cell_types_of_data` together with `csmodel_in`.

    Returns a DataFrame with one row per cell and the columns `cell_id`,
    `y_true` (the `cell_type` annotation), `y_pred` (the whitespace-stripped
    prediction) and `version` (equal to `tag` on every row).
    """
    print(f'\n─── C2S: {tag} ───')

    # Named `c2s_vocab` locally to keep it distinct from the scGPT `vocab` global.
    sentences_ds, c2s_vocab = cs.CSData.adata_to_arrow(
        adata_in,
        random_state=SEED,
        sentence_delimiter=' ',
        label_col_names=label_cols,
    )
    csdata = cs.CSData.csdata_from_arrow(
        arrow_dataset=sentences_ds,
        vocabulary=c2s_vocab,
        save_dir=f'./tmp_c2s_{tag}',
        save_name='data',
        dataset_backend='arrow',
    )
    raw_predictions = predict_cell_types_of_data(
        csdata=csdata,
        csmodel=csmodel_in,
        n_genes=TOP_K_GENES,
        max_num_tokens=32,
    )

    cell_ids = adata_in.obs_names.astype(str)
    true_labels = adata_in.obs['cell_type'].astype(str).values
    cleaned_predictions = []
    for prediction in raw_predictions:
        cleaned_predictions.append(str(prediction).strip())

    result = pd.DataFrame({
        'cell_id': cell_ids,
        'y_true' : true_labels,
        'y_pred' : cleaned_predictions,
    })
    result['version'] = tag
    return result

In [None]:
# Baseline: C2S predictions on the unmodified expression matrix.
df_original = run_c2s(adata_original, 'original', csmodel)

In [None]:
# Corrupted condition: C2S predictions with the masked genes set to 0.
df_masked = run_c2s(adata_masked, 'masked', csmodel)

In [None]:
# Repaired condition: C2S predictions with the masked genes filled in by scGPT.
df_reconstructed = run_c2s(adata_reconstructed, 'reconstructed', csmodel)

## 7 · 3-Tier Accuracy Evaluation

In [None]:
def normalize(text):
    """Canonicalise a cell-type label for comparison.

    Converts to `str`, trims surrounding whitespace, removes trailing periods
    (and any whitespace that preceded them), lower-cases, and collapses every
    run of whitespace into a single space.
    """
    # The second strip() removes whitespace exposed by dropping a trailing
    # period, e.g. 'T cell .' -> 't cell' rather than 't cell '.
    text = str(text).strip().rstrip('.').strip().lower()
    return re.sub(r'\s+', ' ', text)

def classify(y_true, y_pred):
    """
    Score a prediction against the ground truth on a 3-tier scale.

    2 = Exactly correct  (normalized match)
    1 = Partly correct   (substring containment OR Jaccard word overlap ≥ 30%)
    0 = Not correct      (including an empty prediction or empty ground truth)

    Containment is a plain substring test on the normalized labels. Word
    overlap only considers words longer than two characters.
    """
    t, p = normalize(y_true), normalize(y_pred)
    if t == p:
        return 2
    # The empty string is a substring of every string, so without this guard
    # an empty prediction would be scored as "Partly correct".
    if not t or not p:
        return 0
    if p in t or t in p:
        return 1
    t_w = {w for w in re.findall(r'\b\w+\b', t) if len(w) > 2}
    p_w = {w for w in re.findall(r'\b\w+\b', p) if len(w) > 2}
    if t_w and p_w and len(t_w & p_w) / len(t_w | p_w) >= 0.30:
        return 1
    return 0

# Human-readable name for each numeric score returned by classify().
label_map = {2: 'Exactly correct', 1: 'Partly correct', 0: 'Not correct'}

def score_df(df):
    """Return a copy of `df` with `score` and `verdict` columns added.

    `score` is `classify(y_true, y_pred)` evaluated row by row; `verdict` is
    the `label_map` text for that score. The input frame is left untouched.
    """
    row_scores = df.apply(lambda row: classify(row['y_true'], row['y_pred']), axis=1)
    return df.assign(score=row_scores, verdict=row_scores.map(label_map))

# Add the score / verdict columns to each prediction table. The names are
# rebound to the scored copies returned by score_df.
df_original     = score_df(df_original)
df_masked       = score_df(df_masked)
df_reconstructed = score_df(df_reconstructed)

print('Scoring complete.')

In [None]:
def summary(df, name):
    """Tabulate how many rows of `df` fall into each accuracy tier.

    Returns three dicts (tiers 2, 1, 0 in that order) with the keys `Version`
    (= `name`), `Tier` (text from `label_map`), `Count` and `Pct` (count as a
    percentage of `len(df)`).
    """
    total = len(df)
    tier_counts = {tier: (df['score'] == tier).sum() for tier in (2, 1, 0)}
    return [
        {'Version': name, 'Tier': label_map[tier],
         'Count': count, 'Pct': count / total * 100}
        for tier, count in tier_counts.items()
    ]

# The three scored tables and their display names, defined once so both
# reports below are guaranteed to iterate over the same versions.
# round() rather than int() so that e.g. 0.29 is labelled 29 %, not 28 %.
scored_versions = [
    (df_original,      'Original (clean)'),
    (df_masked,        f'Masked ({round(MASK_FRACTION*100)}% zeroed)'),
    (df_reconstructed, 'Reconstructed (scGPT)'),
]

rows = []
for df, name in scored_versions:
    rows.extend(summary(df, name))

summary_df = pd.DataFrame(rows)

print('=' * 65)
print('  3-TIER ACCURACY COMPARISON')
print('=' * 65)
for name, grp in summary_df.groupby('Version', sort=False):
    # The three tier counts partition the scored cells, so their sum is the
    # denominator the percentages were computed with (not necessarily N_CELLS).
    n_scored = int(grp['Count'].sum())
    print(f'\n  {name}')
    for _, row in grp.iterrows():
        bar = '█' * int(row['Pct'] / 5)   # one block per 5 percentage points
        print(f'    {row["Tier"]:<20}: {row["Count"]:>2}/{n_scored}  ({row["Pct"]:>5.1f}%)  {bar}')
print('=' * 65)

# Combined metric: Exact + Partial
print('\n  Exact + Partial (score ≥ 1):')
for df, name in scored_versions:
    ep = (df['score'] >= 1).sum()
    n_scored = len(df)
    print(f'    {name:<30}: {ep}/{n_scored}  ({ep/n_scored*100:.1f}%)')

In [None]:
# ── Per-cell side-by-side view ────────────────────────────────────────────────
# The three tables are joined by row position, so they must list the same
# cells in the same order; verify that before combining them.
for other_name, other_df in [('masked', df_masked), ('reconstructed', df_reconstructed)]:
    assert df_original['cell_id'].tolist() == other_df['cell_id'].tolist(), (
        f'cell order of the {other_name} predictions differs from the original'
    )

compare = df_original[['cell_id', 'y_true']].copy()
# .to_numpy() assigns by position, so differing index labels cannot misalign rows.
compare['pred_original']     = df_original['y_pred'].to_numpy()
compare['score_original']    = df_original['score'].to_numpy()
compare['pred_masked']       = df_masked['y_pred'].to_numpy()
compare['score_masked']      = df_masked['score'].to_numpy()
compare['pred_reconstructed']  = df_reconstructed['y_pred'].to_numpy()
compare['score_reconstructed'] = df_reconstructed['score'].to_numpy()
compare['scgpt_helped'] = compare['score_reconstructed'] > compare['score_masked']
compare['scgpt_hurt']   = compare['score_reconstructed'] < compare['score_masked']

print(f'scGPT improved prediction: {compare["scgpt_helped"].sum()} cells')
print(f'scGPT hurt prediction    : {compare["scgpt_hurt"].sum()} cells')
print(f'No change                : {(~compare["scgpt_helped"] & ~compare["scgpt_hurt"]).sum()} cells')


def _highlight_effect(col):
    """Colour True cells green in `scgpt_helped` and red in `scgpt_hurt`."""
    colour = {'scgpt_helped': '#c8e6c9', 'scgpt_hurt': '#ffcdd2'}[col.name]
    return [f'background-color: {colour}' if flag else '' for flag in col]


pd.set_option('display.max_colwidth', 55)
# Both flag columns are shown and styled; with only `scgpt_helped` in the
# subset, the red "hurt" highlighting could never be applied.
compare[['y_true',
         'pred_masked', 'score_masked',
         'pred_reconstructed', 'score_reconstructed',
         'scgpt_helped', 'scgpt_hurt']].style.apply(
    _highlight_effect, subset=['scgpt_helped', 'scgpt_hurt']
)

In [None]:
# ── Save full results ─────────────────────────────────────────────────────────
# Writes the per-cell comparison table to the current working directory.
# NOTE: `out` was also used for the scGPT forward output in the reconstruction
# loop; it is rebound to the output path here.
out = Path('./scgpt_reconstruction_results.csv')
compare.to_csv(out, index=False)
print('Saved:', out.resolve())