# scGPT Gene Masking → Reconstruction → C2S Cell-Type Evaluation

**Pipeline:**
1. Load 24 donor cells (same sample as `c2s_donor_celltype_prediction.ipynb`)
2. **Mask** – randomly zero out `MASK_FRACTION` of each cell's expressed genes
3. **Reconstruct** – feed the masked cells (masked genes set to the `-1` mask value) through a
   self-contained PyTorch re-implementation of scGPT loaded from a local checkpoint,
   and replace the masked positions with scGPT's predicted expression values
4. **Evaluate** – run C2S cell-type prediction on three AnnData objects:
   - `original` (clean baseline)
   - `masked`   (corrupted, masked genes → 0)
   - `reconstructed` (scGPT-repaired)
5. **Compare** results side-by-side using the 3-tier accuracy metric
   (Exactly correct / Partly correct / Not correct)

In [18]:

# No extra installs needed – everything required is already in the environment.
# (torch, numpy, scanpy, anndata, cell2sentence, transformers, tqdm)
print('Environment ready.')


Environment ready.


In [19]:
from pathlib import Path
import re
import random

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import torch

import cell2sentence as cs
from cell2sentence.tasks import predict_cell_types_of_data

In [20]:

# ── Configuration ─────────────────────────────────────────────────────────────
SEED          = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

H5AD_PATH     = Path('../../data/dominguez_conde_immune_tissue_two_donors.h5ad')
DONOR_COLUMN  = 'batch_condition'
DONOR_VALUE   = 'A29'
N_CELLS       = 24
TOP_K_GENES   = 200          # genes passed to C2S
MASK_FRACTION = 0.40         # fraction of expressed genes to mask per cell
# scGPT sentinels. Both are also stored in the checkpoint's args.json, which is
# read again when the model is loaded:
MASK_VALUE    = -1           # marks positions for scGPT to reconstruct
PAD_VALUE     = -2           # padding value, distinct from the mask
C2S_MODEL     = 'vandijklab/C2S-Pythia-410m-cell-type-prediction'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)

assert H5AD_PATH.exists(), f'File not found: {H5AD_PATH.resolve()}'


Device: cpu


## 1 · Load data & sample 24 cells

In [21]:
adata = ad.read_h5ad(H5AD_PATH)
print('Full dataset shape:', adata.shape)

adata_donor = adata[adata.obs[DONOR_COLUMN] == DONOR_VALUE].copy()
rng = np.random.default_rng(SEED)
idx = rng.choice(adata_donor.n_obs, size=N_CELLS, replace=False)
adata_small = adata_donor[idx].copy()

print(f'Donor {DONOR_VALUE}: {adata_donor.n_obs} cells  →  sampled {adata_small.n_obs}')
print(adata_small.obs['cell_type'].value_counts().head(10))

Full dataset shape: (29773, 36503)
Donor A29: 17327 cells  →  sampled 24
cell_type
macrophage                                               4
memory B cell                                            3
naive thymus-derived CD4-positive, alpha-beta T cell     3
T follicular helper cell                                 2
classical monocyte                                       2
plasma cell                                              2
CD4-positive helper T cell                               1
CD16-negative, CD56-bright natural killer cell, human    1
alveolar macrophage                                      1
effector memory CD4-positive, alpha-beta T cell          1
Name: count, dtype: int64


## 2 · Preprocessing (normalize + log1p)

In [22]:
adata_small.var_names_make_unique()
sc.pp.normalize_total(adata_small, target_sum=1e4)
sc.pp.log1p(adata_small)

# Dense expression matrix  (n_cells × n_genes)
import scipy.sparse as sp
X_orig = adata_small.X.toarray() if sp.issparse(adata_small.X) else adata_small.X.copy()
gene_names = np.array(adata_small.var_names.tolist())

print('Expression matrix:', X_orig.shape)
print(f'Sparsity: {(X_orig == 0).mean()*100:.1f}% zeros per cell on average')

Expression matrix: (24, 36503)
Sparsity: 95.3% zeros per cell on average
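Scanpy's `normalize_total(target_sum=1e4)` followed by `log1p` rescales every cell to 10,000 total counts and then takes the natural log of 1 + x. The dense NumPy sketch below (a simplified stand-in; scanpy also handles sparse input and options not used here) makes two consequences explicit: zeros stay zeros, so the sparsity printed above is the same as in the raw counts, and `np.expm1` undoes the log step. The helper name `cp10k_log1p` is illustrative.

```python
import numpy as np

def cp10k_log1p(counts, target_sum=1e4):
    """Scale each cell (row) to `target_sum` total counts, then apply natural log1p."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    # Cells with zero total counts are left as all-zero rows.
    scaled = np.divide(counts * target_sum, totals,
                       out=np.zeros_like(counts), where=totals > 0)
    return np.log1p(scaled)

counts = np.array([[1.0, 0.0, 3.0],
                   [5.0, 5.0, 0.0]])
X_demo = cp10k_log1p(counts)
print(np.expm1(X_demo).sum(axis=1))   # each row sums back to target_sum
print((X_demo == 0).mean())           # same zero fraction as `counts`
```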


## 3 · Random gene masking

For each cell, `MASK_FRACTION` of the expressed (non-zero) genes are masked. In the copy
sent to scGPT they are set to `MASK_VALUE = -1`; in a second copy used for C2S they are
set to `0`, because C2S expects non-negative expression and a zero removes the gene from
the cell sentence.
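A toy version of the masking step on a 2 × 6 matrix shows how the two copies diverge: the scGPT input keeps masked genes as `-1`, while the C2S copy drops them to `0`. The helper `mask_expressed` is a hypothetical illustration of the loop in the next cell, with an added guard for cells that express nothing.

```python
import numpy as np

def mask_expressed(X, frac, mask_value=-1.0, seed=0):
    """Mask `frac` of each row's non-zero entries (illustrative helper).

    Returns (X_scgpt_in, X_masked, mask_record): masked entries are
    `mask_value` in the scGPT copy and 0.0 in the C2S copy.
    """
    rng = np.random.default_rng(seed)
    X_scgpt_in, X_masked = X.copy(), X.copy()
    mask_record = []
    for i in range(X.shape[0]):
        expressed = np.flatnonzero(X[i] > 0)
        if expressed.size == 0:                 # nothing to mask in an empty cell
            mask_record.append(np.array([], dtype=int))
            continue
        n_mask = max(1, int(expressed.size * frac))
        idx = rng.choice(expressed, size=n_mask, replace=False)
        X_scgpt_in[i, idx] = mask_value         # stays in scGPT's gene sequence
        X_masked[i, idx] = 0.0                  # drops out of the C2S sentence
        mask_record.append(idx)
    return X_scgpt_in, X_masked, mask_record

X_toy = np.array([[0.0, 1.2, 0.0, 2.5, 0.7, 3.1],
                  [0.4, 0.0, 0.0, 0.0, 1.9, 0.0]])
X_in, X_m, rec = mask_expressed(X_toy, frac=0.40)
print([len(r) for r in rec])   # → [1, 1]  (int(4*0.4)=1; max(1, int(2*0.4))=1)
```

Note the `max(1, ...)` floor: a cell with very few expressed genes always loses at least one, so its effective mask fraction can exceed `MASK_FRACTION`.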

In [23]:
# X_scgpt_in : expression matrix sent to scGPT  (masked genes = -1, kept in tokenizer)
# X_masked   : expression matrix for C2S masked baseline  (masked genes = 0)
# mask_record: list of integer index arrays (masked gene positions per cell)

rng_mask = np.random.default_rng(SEED)

X_scgpt_in = X_orig.copy()
X_masked   = X_orig.copy()
mask_record = []          # mask_record[i] -> indices of masked genes in cell i

total_masked = 0
for i in range(X_orig.shape[0]):
    expressed_idx = np.where(X_orig[i] > 0)[0]
    n_mask = max(1, int(len(expressed_idx) * MASK_FRACTION))
    masked_idx = rng_mask.choice(expressed_idx, size=n_mask, replace=False)

    X_scgpt_in[i, masked_idx] = MASK_VALUE  # -1  → tokenizer keeps these
    X_masked[i, masked_idx]   = 0.0         # 0   → removed from C2S

    mask_record.append(masked_idx)
    total_masked += n_mask

print(f'Masked {total_masked} genes across {N_CELLS} cells'
      f' (~{total_masked/N_CELLS:.0f} per cell, {MASK_FRACTION*100:.0f}% of expressed)')

Masked 16337 genes across 24 cells (~681 per cell, 40% of expressed)



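## 4 · scGPT reconstruction  *(no TDC, no `scgpt` package)*

The scGPT architecture is re-implemented below in plain PyTorch (adapted from the
MIT-licensed scGPT repository); the vocabulary (`vocab.json`), config (`args.json`)
and weights (`best_model.pt`) are loaded from a local checkpoint folder.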

**Masking strategy:** masked genes are set to `-1` (scGPT's standard mask sentinel).
Since `-1 ≠ 0`, they stay in the gene sequence passed to the model.  
`mlm_output` predicts expression at **every position** – including the `-1` ones.  
Those predicted values replace the masked genes in `X_reconstructed`.
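The replacement step itself is a scatter: find the `-1` positions among the genes fed to scGPT, take the model's prediction at those positions, clip at zero and write the values back into a copy of the original cell. In the sketch below `fake_pred` stands in for `mlm_output` and its values are made up; `fill_masked` is a hypothetical helper that mirrors the loop further down.

```python
import numpy as np

def fill_masked(x_orig_row, x_in_row, pred_row, mask_value=-1.0):
    """Copy of x_orig_row with masked genes replaced by clipped predictions.

    x_in_row : scGPT input for one cell (mask_value = masked, 0 = not expressed)
    pred_row : one prediction per gene with x_in_row != 0, in the same order
    """
    nonzero_idx = np.flatnonzero(x_in_row != 0)            # genes fed to the model
    local_masked = np.flatnonzero(x_in_row[nonzero_idx] == mask_value)
    out = x_orig_row.copy()
    out[nonzero_idx[local_masked]] = np.clip(pred_row[local_masked], 0.0, None)
    return out

x_orig_demo = np.array([0.0, 1.2, 0.0, 2.5, 0.7])
x_in_demo   = np.array([0.0, -1.0, 0.0, 2.5, -1.0])   # genes 1 and 4 masked
fake_pred   = np.array([0.9, 2.4, -0.3])              # one value per fed gene
print(fill_masked(x_orig_demo, x_in_demo, fake_pred))
```

Gene 1 receives 0.9, and gene 4's negative prediction is clipped to 0, which silently removes it from the reconstructed cell sentence.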


In [24]:

# ── scGPT model – self-contained PyTorch implementation ───────────────────────
# Adapted from https://github.com/bowang-lab/scGPT (MIT License).
# No scgpt library required – only torch; weights come from a local checkpoint.

import json
from typing import Mapping, Optional

import torch
import torch.nn as nn
from torch import Tensor


# ── GeneVocab: replaces scgpt.tokenizer.GeneVocab ────────────────────────────
class GeneVocab:
    """Minimal gene-name ↔ token-ID vocabulary loaded from vocab.json."""

    def __init__(self, gene_to_id: dict):
        self._g2i = gene_to_id
        self._i2g = {v: k for k, v in gene_to_id.items()}

    @classmethod
    def from_file(cls, path):
        with open(path) as f:
            data = json.load(f)
        # Format A: {"gene_name": int_id, ...}
        if isinstance(data, dict) and all(isinstance(v, int) for v in data.values()):
            return cls(data)
        # Format B: torchtext-style {"itos": [...], ...}
        if isinstance(data, dict) and "itos" in data:
            return cls({t: i for i, t in enumerate(data["itos"])})
        # Format C: list of tokens
        if isinstance(data, list):
            return cls({t: i for i, t in enumerate(data)})
        raise ValueError(f"Cannot parse vocab format in {path}")

    def __contains__(self, gene):  return gene in self._g2i
    def __getitem__(self, gene):   return self._g2i[gene]
    def __len__(self):             return len(self._g2i)


# ── Sub-modules ───────────────────────────────────────────────────────────────
class GeneEncoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.enc_norm  = nn.LayerNorm(embedding_dim)
    def forward(self, x):
        return self.enc_norm(self.embedding(x))


class ContinuousValueEncoder(nn.Module):
    """Projects scalar expression values → d_model embedding."""
    def __init__(self, d_model, dropout=0.1, max_value=512):
        super().__init__()
        self.max_value = max_value
        self.linear1   = nn.Linear(1, d_model)
        self.linear2   = nn.Linear(d_model, d_model)
        self.norm      = nn.LayerNorm(d_model)
        self.dropout   = nn.Dropout(dropout)
    def forward(self, x):                          # x: (B, L)
        x = x.unsqueeze(-1).clamp(-self.max_value, self.max_value)
        x = self.linear1(x).relu()
        return self.dropout(self.norm(self.linear2(x)))


class CategoryValueEncoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.enc_norm  = nn.LayerNorm(embedding_dim)
    def forward(self, x):
        return self.enc_norm(self.embedding(x.long()))


class ExprDecoder(nn.Module):
    """MLP that predicts scalar expression from transformer hidden states."""
    def __init__(self, d_model, explicit_zero_prob=False, use_batch_labels=False):
        super().__init__()
        d_in = d_model * 2 if use_batch_labels else d_model
        self.fc = nn.Sequential(
            nn.Linear(d_in, d_model), nn.LeakyReLU(),
            nn.Linear(d_model, d_model), nn.LeakyReLU(),
            nn.Linear(d_model, 1),
        )
        self.explicit_zero_prob = explicit_zero_prob
        if explicit_zero_prob:
            self.zero_logit = nn.Sequential(
                nn.Linear(d_in, d_model), nn.LeakyReLU(),
                nn.Linear(d_model, d_model), nn.LeakyReLU(),
                nn.Linear(d_model, 1),
            )
    def forward(self, x):
        pred = self.fc(x).squeeze(-1)           # (B, L)
        out  = {"pred": pred}
        if self.explicit_zero_prob:
            out["zero_probs"] = torch.sigmoid(self.zero_logit(x).squeeze(-1))
        return out


class ClsDecoder(nn.Module):
    def __init__(self, d_model, n_cls, nlayers=3, activation=nn.ReLU):
        super().__init__()
        layers = []
        for _ in range(nlayers - 1):
            layers += [nn.Linear(d_model, d_model), activation(), nn.LayerNorm(d_model)]
        layers.append(nn.Linear(d_model, n_cls))
        self.fc = nn.Sequential(*layers)
    def forward(self, x):
        return self.fc(x)


class MVCDecoder(nn.Module):
    """Masked-value-consistency decoder (inner-product style)."""
    def __init__(self, d_model, arch_style="inner product", query_activation=nn.Sigmoid,
                 hidden_activation=nn.PReLU, explicit_zero_prob=False, use_batch_labels=False):
        super().__init__()
        self.arch_style        = arch_style
        self.explicit_zero_prob = explicit_zero_prob
        self.gene2query        = nn.Linear(d_model, d_model)
        self.query_activation  = query_activation()
        self.W                 = nn.Linear(d_model, d_model, bias=False)
        if explicit_zero_prob:
            self.fc_zero = nn.Linear(d_model, 1)
    def forward(self, cell_emb, gene_embs):
        query     = self.query_activation(self.gene2query(gene_embs))  # (B, L, d)
        cell_emb_ = self.W(cell_emb).unsqueeze(2)                      # (B, d, 1)
        pred      = torch.bmm(query, cell_emb_).squeeze(2)             # (B, L)
        out = {"mvc_output": pred}
        if self.explicit_zero_prob:
            out["mvc_zero_probs"] = torch.sigmoid(self.fc_zero(gene_embs).squeeze(-1))
        return out


class AdversarialDiscriminator(nn.Module):
    def __init__(self, d_model, n_cls, nlayers=3, activation=nn.LeakyReLU, reverse_grad=False):
        super().__init__()
        layers = []
        for _ in range(nlayers - 1):
            layers += [nn.Linear(d_model, d_model), activation(), nn.LayerNorm(d_model)]
        layers.append(nn.Linear(d_model, n_cls))
        self.fc = nn.Sequential(*layers)
    def forward(self, x):
        return self.fc(x)


# ── TransformerModel ──────────────────────────────────────────────────────────
class TransformerModel(nn.Module):
    """
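    scGPT TransformerModel – self-contained PyTorch reimplementation.
    Parameter names mirror the original scgpt module so checkpoint keys can be
    matched by name. Always check the missing/unexpected keys reported by
    `load_state_dict`: checkpoints trained with flash attention, for example,
    name the attention projections differently.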
    """

    def __init__(
        self,
        ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int,
        nlayers_cls: int = 3, n_cls: int = 1, vocab=None,
        dropout: float = 0.5, pad_token: str = "<pad>", pad_value: int = 0,
        do_mvc: bool = False, do_dab: bool = False,
        use_batch_labels: bool = False, num_batch_labels: Optional[int] = None,
        domain_spec_batchnorm: bool = False,
        input_emb_style: str = "continuous", n_input_bins: Optional[int] = None,
        cell_emb_style: str = "cls", mvc_decoder_style: str = "inner product",
        ecs_threshold: float = 0.3, explicit_zero_prob: bool = False,
        use_fast_transformer: bool = False,   # ignored – always use standard attn
        fast_transformer_backend: str = "flash",
        pre_norm: bool = False,
    ):
        super().__init__()
        self.d_model             = d_model
        self.do_mvc              = do_mvc
        self.do_dab              = do_dab
        self.use_batch_labels    = use_batch_labels
        self.cell_emb_style      = cell_emb_style
        self.explicit_zero_prob  = explicit_zero_prob
        self.pad_value           = pad_value
        self.ecs_threshold       = ecs_threshold

        pad_idx = vocab[pad_token] if (vocab is not None and pad_token in vocab) else pad_value

        # Embeddings
        self.encoder       = GeneEncoder(ntoken, d_model, padding_idx=pad_idx)
        if input_emb_style == "continuous":
            self.value_encoder = ContinuousValueEncoder(d_model, dropout)
        elif input_emb_style == "category":
            self.value_encoder = CategoryValueEncoder(n_input_bins, d_model, padding_idx=0)
        else:
            self.value_encoder = nn.Identity()

        if use_batch_labels:
            self.batch_encoder = nn.Embedding(num_batch_labels, d_model)

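        # BatchNorm over the summed embeddings. If the checkpoint has no `bn.*` keys,
        # the default init (running mean 0, running var 1, weight 1, bias 0) makes it
        # an approximate identity in eval mode. `dsbn` is a stub that only keeps weight keys.
        self.bn = nn.BatchNorm1d(d_model)
        if domain_spec_batchnorm:
            self.dsbn = nn.BatchNorm1d(d_model)   # stub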

        # Transformer
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_hid,
            dropout=dropout, batch_first=True, norm_first=pre_norm,
        )
        self.transformer_encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers)

        # Decoders
        self.decoder     = ExprDecoder(d_model, explicit_zero_prob, use_batch_labels)
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
        if do_mvc:
            self.mvc_decoder = MVCDecoder(d_model, mvc_decoder_style, explicit_zero_prob=explicit_zero_prob,
                                          use_batch_labels=use_batch_labels)
        if do_dab:
            self.grad_reverse_discriminator = AdversarialDiscriminator(d_model, n_cls=num_batch_labels,
                                                                        reverse_grad=True)
        self.sim = nn.CosineSimilarity(dim=-1)

        nn.init.uniform_(self.encoder.embedding.weight, -0.1, 0.1)

    def _encode(self, src, values, src_key_padding_mask, batch_labels=None):
        src_emb = self.encoder(src)           # (B, L, d)
        val_emb = self.value_encoder(values)  # (B, L, d)
        emb     = src_emb + val_emb

        if self.use_batch_labels and batch_labels is not None:
            emb = emb + self.batch_encoder(batch_labels).unsqueeze(1)

        B, L, d = emb.shape
        emb = self.bn(emb.view(B * L, d)).view(B, L, d)

        return self.transformer_encoder(emb, src_key_padding_mask=src_key_padding_mask)

    def _cell_emb(self, layer_out, values=None):
        if self.cell_emb_style == "cls":
            return layer_out[:, 0, :]
        non_pad = (values != self.pad_value).float().unsqueeze(-1) if values is not None \
                  else torch.ones(*layer_out.shape[:2], 1, device=layer_out.device)
        return (layer_out * non_pad).sum(1) / non_pad.sum(1).clamp(min=1)

    def forward(
        self,
        src: Tensor, values: Tensor, src_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,
        CLS: bool = False, CCE: bool = False, MVC: bool = False,
        ECS: bool = False, do_sample: bool = False,
    ) -> Mapping[str, Tensor]:
        h      = self._encode(src, values, src_key_padding_mask, batch_labels)
        output = {}

        mlm = self.decoder(h)
        output["mlm_output"] = mlm["pred"]               # (B, L)  ← used for reconstruction
        if self.explicit_zero_prob and "zero_probs" in mlm:
            output["mlm_zero_probs"] = mlm["zero_probs"]

        cell_emb           = self._cell_emb(h, values)
        output["cell_emb"] = cell_emb

        if CLS:
            output["cls_output"] = self.cls_decoder(cell_emb)
        if MVC and self.do_mvc:
            output.update(self.mvc_decoder(cell_emb, h))
        if self.do_dab:
            output["dab_output"] = self.grad_reverse_discriminator(cell_emb)
        return output


print('TransformerModel + GeneVocab defined (no scgpt library used).')


TransformerModel + GeneVocab defined (no scgpt library used).


In [25]:

# ── Load model from local checkpoint ─────────────────────────────────────────
SCGPT_DIR = Path('../../models/scGPT')
assert SCGPT_DIR.exists(), f'Model folder not found: {SCGPT_DIR.resolve()}'
print('Using local checkpoint:', SCGPT_DIR.resolve())
print('Files:', sorted(f.name for f in SCGPT_DIR.iterdir()))

# ── Vocab ─────────────────────────────────────────────────────────────────────
vocab = GeneVocab.from_file(str(SCGPT_DIR / 'vocab.json'))
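PAD_ID = vocab['<pad>'] if '<pad>' in vocab else 0
# In this checkpoint UNK_ID == PAD_ID (see output), so out-of-vocabulary
# genes are fed to scGPT as <pad> tokens.
UNK_ID = vocab['<unk>'] if '<unk>' in vocab else PAD_ID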
print(f'Vocab size: {len(vocab)}  PAD_ID={PAD_ID}  UNK_ID={UNK_ID}')

# ── Config (args.json) ────────────────────────────────────────────────────────
with open(SCGPT_DIR / 'args.json') as f:
    cfg = json.load(f)

# Key values read from args.json:
#   embsize=512, nheads=8, d_hid=512, nlayers=12, n_layers_cls=3
#   input_emb_style="continuous", pad_value=-2, mask_value=-1, MVC=True
PAD_VALUE  = cfg.get('pad_value',  -2)   # -2 in this checkpoint
MASK_VALUE = cfg.get('mask_value', -1)   # -1 (confirms our masking strategy)
print(f'pad_value={PAD_VALUE}  mask_value={MASK_VALUE}')

# ── Build architecture ────────────────────────────────────────────────────────
scgpt_model = TransformerModel(
    ntoken           = len(vocab),
    d_model          = cfg['embsize'],
    nhead            = cfg['nheads'],
    d_hid            = cfg['d_hid'],
    nlayers          = cfg['nlayers'],
    nlayers_cls      = cfg.get('n_layers_cls', cfg.get('nlayers_cls', 3)),
    n_cls            = 1,
    vocab            = vocab,
    dropout          = 0.0,                  # no dropout at inference
    pad_token        = cfg.get('pad_token', '<pad>'),
    pad_value        = PAD_VALUE,
    do_mvc           = cfg.get('MVC', False),
    do_dab           = cfg.get('do_dab', False),
    use_batch_labels = cfg.get('use_batch_labels', False),
    num_batch_labels = cfg.get('num_batch_labels', 1),
    domain_spec_batchnorm = cfg.get('dsbn', False),
    input_emb_style  = cfg.get('input_emb_style', 'continuous'),
    n_input_bins     = cfg.get('n_bins', 51),
    cell_emb_style   = 'cls' if not cfg.get('no_cls', True) else 'avg-non-pad',
    explicit_zero_prob = cfg.get('explicit_zero_prob', False),
    use_fast_transformer = False,           # disable flash-attn (CPU / Windows)
    pre_norm         = cfg.get('pre_norm', False),
)

# ── Load weights ──────────────────────────────────────────────────────────────
ckpt_path = SCGPT_DIR / 'best_model.pt'
state = torch.load(str(ckpt_path), map_location='cpu', weights_only=False)
if isinstance(state, dict) and 'model_state_dict' in state:
    state = state['model_state_dict']

missing, unexpected = scgpt_model.load_state_dict(state, strict=False)
print(f'Weights loaded — missing: {len(missing)}, unexpected: {len(unexpected)}')
if missing:
    print('  Missing :', missing[:5])

scgpt_model = scgpt_model.to(DEVICE).eval()
print(f'\nscGPT ready on {DEVICE}.')


Using local checkpoint: C:\Users\Daniel\Desktop\GitProjects\Improving-Cell2Sentence-with-Single-Cell-Foundation-Model-Embeddings\models\scGPT
Files: ['args.json', 'best_model.pt', 'vocab.json']
Vocab size: 60697  PAD_ID=60694  UNK_ID=60694
pad_value=-2  mask_value=-1
Weights loaded — missing: 38, unexpected: 25
  Missing : ['bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'transformer_encoder.layers.0.self_attn.in_proj_weight']

scGPT ready on cpu.
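The load report above lists 38 missing and 25 unexpected keys, and the first five missing keys include `transformer_encoder.layers.0.self_attn.in_proj_weight`, so at least that attention projection keeps its random initialisation instead of the pretrained weights. One plausible cause, which printing the `unexpected` list would confirm or rule out, is a checkpoint saved with scGPT's flash-attention layers, which store the fused query/key/value projection as `self_attn.Wqkv.weight` / `.bias` rather than PyTorch's `in_proj_weight` / `in_proj_bias`. If so, renaming the keys before `load_state_dict` should help; as far as we know the scGPT repository does a rename of this kind when flash attention is disabled, but verify against your checkout. `remap_flash_attn_keys` is a hypothetical helper.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

def remap_flash_attn_keys(state: Mapping[str, V]) -> Dict[str, V]:
    """Rename fused flash-attention QKV keys to nn.MultiheadAttention names.

    Assumes `...self_attn.Wqkv.weight/.bias` uses the same stacked (q, k, v)
    layout that nn.MultiheadAttention expects for `in_proj_weight/in_proj_bias`.
    """
    return {k.replace("Wqkv.", "in_proj_"): v for k, v in state.items()}

# Usage in this notebook (not executed here):
#   print('Unexpected:', unexpected[:5])
#   state = remap_flash_attn_keys(state)
#   missing, unexpected = scgpt_model.load_state_dict(state, strict=False)

demo_state = {
    "transformer_encoder.layers.0.self_attn.Wqkv.weight": "W",
    "transformer_encoder.layers.0.self_attn.Wqkv.bias": "b",
    "transformer_encoder.layers.0.self_attn.out_proj.weight": "O",
}
print(sorted(remap_flash_attn_keys(demo_state)))
```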


In [26]:

from tqdm.auto import tqdm

def gene_names_to_ids(names):
    """Map gene name strings → scGPT vocabulary IDs (UNK_ID for out-of-vocab genes)."""
    return np.array([vocab[g] if g in vocab else UNK_ID for g in names], dtype=np.int64)

# X_reconstructed: copy of original; masked gene positions will be overwritten
X_reconstructed = X_orig.copy()
n_in_vocab  = 0
n_recovered = 0

with torch.no_grad():
    for i in tqdm(range(N_CELLS), desc='scGPT reconstruction'):

        cell_expr = X_scgpt_in[i]          # shape (n_genes,); -1 for masked genes

        # ── Select non-zero genes (includes -1 masked ones) ──────────────────
        nonzero_idx = np.where(cell_expr != 0)[0]
        if len(nonzero_idx) == 0:
            continue

        genes_sel  = gene_names[nonzero_idx]                    # gene name strings
        vals_sel   = cell_expr[nonzero_idx].astype(np.float32)  # expr values (incl. -1)
        gene_ids   = gene_names_to_ids(genes_sel)               # vocab IDs

        # ── Build tensors ─────────────────────────────────────────────────────
        src_t   = torch.from_numpy(gene_ids).unsqueeze(0).to(DEVICE)    # (1, L)
        vals_t  = torch.from_numpy(vals_sel).unsqueeze(0).to(DEVICE)    # (1, L)
        # padding mask: False = attend, True = ignore  (no padding here)
        pad_mask = torch.zeros(1, len(gene_ids), dtype=torch.bool, device=DEVICE)

        n_in_vocab += int((gene_ids != UNK_ID).sum())

        # ── scGPT forward pass ────────────────────────────────────────────────
        # mlm_output has shape (1, L) or (1, L, 1) – predicts expression at
        # every position, including the masked (-1) ones.
        out = scgpt_model(
            src                  = src_t,
            values               = vals_t,
            src_key_padding_mask = pad_mask,
            CLS                  = False,
        )

        mlm = out.get('mlm_output')
        if mlm is None:
            print(f'  Cell {i}: mlm_output not in scGPT output '
                  f'(keys: {list(out.keys())}) – skipping')
            continue

        if mlm.ndim == 3:
            mlm = mlm.squeeze(-1)           # (1, L, 1) → (1, L)
        mlm_np = mlm.squeeze(0).cpu().float().numpy()   # (L,)

        # ── Find masked positions and fill with scGPT predictions ─────────────
        masked_in_local = np.where(vals_sel < 0)[0]    # positions where value was -1
        if len(masked_in_local) == 0:
            continue

        masked_orig_idx = nonzero_idx[masked_in_local]  # back to global gene indices
        predicted_vals  = np.clip(mlm_np[masked_in_local], 0.0, None)

        X_reconstructed[i, masked_orig_idx] = predicted_vals
        n_recovered += len(masked_orig_idx)   # counts every filled position, incl. out-of-vocab genes

print(f'\nReconstruction complete.')
print(f'  Avg genes in scGPT vocab / cell : {n_in_vocab / N_CELLS:.0f}')
print(f'  Masked genes recovered          : {n_recovered} / {total_masked}'
      f'  ({n_recovered / total_masked * 100:.1f} %)')


scGPT reconstruction: 100%|██████████| 24/24 [00:53<00:00,  2.22s/it]


Reconstruction complete.
  Avg genes in scGPT vocab / cell : 1669
  Masked genes recovered          : 16337 / 16337  (100.0 %)
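The 100 % figure above only means that every masked position received a value: `n_recovered` counts each filled position, including out-of-vocabulary genes that were fed as `<pad>`. It says nothing about accuracy. A direct check compares `X_reconstructed` with `X_orig` on the masked entries recorded in `mask_record`; the helper below is a hypothetical addition, not part of the original pipeline, and the demo arrays are made up.

```python
import numpy as np

def masked_reconstruction_error(X_true, X_pred, mask_record):
    """MAE and Pearson r between true and predicted values at masked entries only."""
    rows = np.concatenate([np.full(len(idx), i, dtype=int)
                           for i, idx in enumerate(mask_record)])
    cols = np.concatenate([np.asarray(idx, dtype=int) for idx in mask_record])
    t, p = X_true[rows, cols], X_pred[rows, cols]
    mae = float(np.abs(t - p).mean())
    # Pearson r is undefined when either side is constant (e.g. the all-zero baseline).
    r = float(np.corrcoef(t, p)[0, 1]) if t.std() > 0 and p.std() > 0 else float("nan")
    return mae, r

X_true_demo = np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 1.0]])
X_pred_demo = np.array([[1.5, 0.0, 2.0], [0.0, 2.0, 1.0]])
print(masked_reconstruction_error(X_true_demo, X_pred_demo, [np.array([0]), np.array([1])]))
```

In the notebook, `masked_reconstruction_error(X_orig, X_reconstructed, mask_record)` can be set against `masked_reconstruction_error(X_orig, X_masked, mask_record)` (the all-zero baseline) to see whether the scGPT fills beat simply leaving the genes out.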





## 5 · Build three AnnData objects

| Name | Expression matrix | Description |
|---|---|---|
| `adata_original` | `X_orig` | Clean, unmodified cells |
| `adata_masked` | `X_masked` | Corrupted: `MASK_FRACTION` genes zeroed |
| `adata_reconstructed` | `X_reconstructed` | scGPT-repaired: masked genes filled with predicted expression |
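C2S sees each of these matrices only through its cell sentence: gene names ordered by decreasing expression, zero-expression genes dropped, and (for prediction) truncated to the top `TOP_K_GENES`. The sketch below mimics that ordering with NumPy to show why the three versions can yield different prompts. It is a simplified stand-in, not the `cell2sentence` implementation (tie-breaking and formatting may differ), and the genes and values are made up.

```python
import numpy as np

def top_k_sentence(expr, gene_names, k):
    """Gene names of the k highest non-zero values, highest first (simplified)."""
    nz = np.flatnonzero(expr > 0)
    order = nz[np.argsort(-expr[nz], kind="stable")]   # descending; ties keep column order
    return " ".join(gene_names[order[:k]])

genes = np.array(["CD3E", "MS4A1", "LYZ", "CD14", "NKG7"])
x_original = np.array([2.0, 0.0, 3.5, 1.0, 0.5])
x_masked = x_original.copy()
x_masked[2] = 0.0          # LYZ masked: removed from the sentence
x_recon = x_original.copy()
x_recon[2] = 0.8           # a lower fill value moves LYZ down the ranking
for name, x in [("original", x_original), ("masked", x_masked), ("reconstructed", x_recon)]:
    print(f"{name:13s}", top_k_sentence(x, genes, k=3))
```

Because only the rank order reaches C2S, a reconstructed value matters through where it lands relative to the unmasked genes, not through its absolute size.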

In [27]:
import scipy.sparse as sp

def make_adata(X_new, template_adata):
    """Clone the obs/var metadata from template and replace the expression matrix."""
    a = ad.AnnData(
        X=sp.csr_matrix(X_new),
        obs=template_adata.obs.copy(),
        var=template_adata.var.copy(),
    )
    return a

adata_original      = make_adata(X_orig,          adata_small)
adata_masked        = make_adata(X_masked,        adata_small)
adata_reconstructed = make_adata(X_reconstructed, adata_small)

print('AnnData objects created:')
for name, a in [('original', adata_original), ('masked', adata_masked), ('reconstructed', adata_reconstructed)]:
    mean_nonzero = (a.X.toarray() > 0).mean()
    print(f'  {name:15s}  non-zero fraction: {mean_nonzero:.3f}')

AnnData objects created:
  original         non-zero fraction: 0.047
  masked           non-zero fraction: 0.028
  reconstructed    non-zero fraction: 0.047
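The masked fraction is consistent with the masking log from section 3: 16,337 entries were zeroed in a 24 × 36,503 matrix, i.e.

$$\frac{16337}{24 \times 36503} = \frac{16337}{876072} \approx 0.019, \qquad 0.047 - 0.019 = 0.028 .$$

The reconstructed fraction returning to 0.047 means that, at three-decimal resolution, essentially all scGPT fills were still positive after clipping at zero.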


## 6 · C2S Cell-Type Prediction (3×)

> **Runtime note:** each C2S call takes ~9 min on CPU (22 s/cell × 24 cells).  
> Total ≈ 27 min. Grab a coffee ☕

In [28]:
label_cols = [c for c in ['cell_type', DONOR_COLUMN, 'tissue', 'sex', 'organism']
              if c in adata_small.obs.columns]

# Load C2S model once (we'll reuse it across all three predictions)
csmodel = cs.CSModel(
    model_name_or_path=C2S_MODEL,
    save_dir='./tmp_c2s_reconstruction_model',
    save_name='pretrained_c2s_inference',
)
print('C2S model loaded:', C2S_MODEL)


def run_c2s(adata_in, tag, csmodel_in):
    """Run C2S cell-type prediction and return a DataFrame with y_true / y_pred."""
    print(f'\n─── C2S: {tag} ───')
    # C2S gene vocabulary (named apart from the scGPT `vocab` used earlier)
    arrow_ds, c2s_vocab = cs.CSData.adata_to_arrow(
        adata_in,
        random_state=SEED,
        sentence_delimiter=' ',
        label_col_names=label_cols,
    )
    csdata = cs.CSData.csdata_from_arrow(
        arrow_dataset=arrow_ds,
        vocabulary=c2s_vocab,
        save_dir=f'./tmp_c2s_{tag}',
        save_name='data',
        dataset_backend='arrow',
    )
    preds = predict_cell_types_of_data(
        csdata=csdata,
        csmodel=csmodel_in,
        n_genes=TOP_K_GENES,
        max_num_tokens=32,
    )
    return pd.DataFrame({
        'cell_id': adata_in.obs_names.astype(str),
        'y_true' : adata_in.obs['cell_type'].astype(str).values,
        'y_pred' : [str(p).strip() for p in preds],
        'version': tag,
    })

Using device: cpu




C2S model loaded: vandijklab/C2S-Pythia-410m-cell-type-prediction


In [29]:
df_original = run_c2s(adata_original, 'original', csmodel)


─── C2S: original ───


WARN: more variables (36503) than observations (24)... did you mean to transpose the object (e.g. adata.T)?
WARN: more variables (36503) than observations (24), did you mean to transpose the object (e.g. adata.T)?
100%|██████████| 24/24 [00:00<00:00, 344.89it/s]
Saving the dataset (1/1 shards): 100%|██████████| 24/24 [00:00<00:00, 552.12 examples/s]


Reloading model from path on disk: ./tmp_c2s_reconstruction_model\pretrained_c2s_inference
Predicting cell types for 24 cells using CSModel...


100%|██████████| 24/24 [08:47<00:00, 21.99s/it]


In [30]:
df_masked = run_c2s(adata_masked, 'masked', csmodel)

WARN: more variables (36503) than observations (24)... did you mean to transpose the object (e.g. adata.T)?
WARN: more variables (36503) than observations (24), did you mean to transpose the object (e.g. adata.T)?



─── C2S: masked ───


100%|██████████| 24/24 [00:00<00:00, 1081.31it/s]
Saving the dataset (1/1 shards): 100%|██████████| 24/24 [00:00<00:00, 1477.93 examples/s]


Reloading model from path on disk: ./tmp_c2s_reconstruction_model\pretrained_c2s_inference
Predicting cell types for 24 cells using CSModel...


100%|██████████| 24/24 [08:47<00:00, 21.98s/it]


In [31]:
df_reconstructed = run_c2s(adata_reconstructed, 'reconstructed', csmodel)

WARN: more variables (36503) than observations (24)... did you mean to transpose the object (e.g. adata.T)?
WARN: more variables (36503) than observations (24), did you mean to transpose the object (e.g. adata.T)?



─── C2S: reconstructed ───


100%|██████████| 24/24 [00:00<00:00, 797.41it/s]
Saving the dataset (1/1 shards): 100%|██████████| 24/24 [00:00<00:00, 2528.02 examples/s]


Reloading model from path on disk: ./tmp_c2s_reconstruction_model\pretrained_c2s_inference
Predicting cell types for 24 cells using CSModel...


100%|██████████| 24/24 [08:36<00:00, 21.51s/it]


## 7 · 3-Tier Accuracy Evaluation

In [32]:
def normalize(text):
    text = str(text).strip().rstrip('.').lower()
    return re.sub(r'\s+', ' ', text)

def classify(y_true, y_pred):
    """
    2 = Exactly correct  (normalized match)
    1 = Partly correct   (one label contains the other, OR Jaccard overlap ≥ 0.30
                          on words longer than 2 characters)
    0 = Not correct
    """
    t, p = normalize(y_true), normalize(y_pred)
    if t == p:
        return 2
    if p in t or t in p:
        return 1
    t_w = {w for w in re.findall(r'\b\w+\b', t) if len(w) > 2}
    p_w = {w for w in re.findall(r'\b\w+\b', p) if len(w) > 2}
    if t_w and p_w and len(t_w & p_w) / len(t_w | p_w) >= 0.30:
        return 1
    return 0

label_map = {2: 'Exactly correct', 1: 'Partly correct', 0: 'Not correct'}

def score_df(df):
    df = df.copy()
    df['score']   = df.apply(lambda r: classify(r['y_true'], r['y_pred']), axis=1)
    df['verdict'] = df['score'].map(label_map)
    return df

df_original     = score_df(df_original)
df_masked       = score_df(df_masked)
df_reconstructed = score_df(df_reconstructed)

print('Scoring complete.')

Scoring complete.
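Re-running the scoring rules on a few `y_true` / `y_pred` pairs taken from the per-cell comparison further down makes the partial-credit tier concrete. The Jaccard test counts every word longer than two characters, so for short labels the generic word `cell` alone can reach the 0.30 threshold: `plasma cell` vs `germ cell.` share 1 of 3 distinct words (0.33) and score 1. `tier_score` reproduces `classify`; its `ignore` parameter is a hypothetical tweak, shown only for comparison and not used by the notebook's metric.

```python
import re

def normalize(text):
    text = str(text).strip().rstrip('.').lower()
    return re.sub(r'\s+', ' ', text)

def content_words(text, ignore=frozenset()):
    """Words longer than 2 characters, minus any listed in `ignore`."""
    return {w for w in re.findall(r'\b\w+\b', text) if len(w) > 2 and w not in ignore}

def tier_score(y_true, y_pred, ignore=frozenset()):
    """Same three rules as `classify`; `ignore` (hypothetical) drops words from the Jaccard test."""
    t, p = normalize(y_true), normalize(y_pred)
    if t == p:
        return 2
    if p in t or t in p:
        return 1
    t_w, p_w = content_words(t, ignore), content_words(p, ignore)
    if t_w and p_w and len(t_w & p_w) / len(t_w | p_w) >= 0.30:
        return 1
    return 0

pairs = [
    ('mast cell', 'mast cell.'),
    ('plasma cell', 'malignant cell.'),
    ('memory B cell', 'germ cell.'),
    ('T follicular helper cell', 'CD4-positive, alpha-beta T cell.'),
    ('effector memory CD4-positive, alpha-beta T cell',
     'naive thymus-derived CD4-positive, alpha-beta T cell.'),
]
for t, p in pairs:
    as_is = tier_score(t, p)
    strict = tier_score(t, p, ignore=frozenset({'cell'}))
    print(f'{as_is} {strict}  {t} | {p}')
```

With `cell` ignored, the two unrelated predictions drop to 0, while the T-cell pair that shares `cd4`, `positive`, `alpha` and `beta` keeps its partial credit (4/9 ≈ 0.44).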


In [33]:
def summary(df, name):
    n = len(df)
    rows = []
    for s in [2, 1, 0]:
        cnt = (df['score'] == s).sum()
        rows.append({'Version': name, 'Tier': label_map[s],
                     'Count': cnt, 'Pct': cnt / n * 100})
    return rows

rows = []
for df, name in [(df_original, 'Original (clean)'),
                 (df_masked,   f'Masked ({int(MASK_FRACTION*100)}% zeroed)'),
                 (df_reconstructed, 'Reconstructed (scGPT)')]:
    rows.extend(summary(df, name))

summary_df = pd.DataFrame(rows)

print('=' * 65)
print('  3-TIER ACCURACY COMPARISON')
print('=' * 65)
for name, grp in summary_df.groupby('Version', sort=False):
    print(f'\n  {name}')
    for _, row in grp.iterrows():
        bar = '█' * int(row['Pct'] / 5)
        print(f'    {row["Tier"]:<20}: {row["Count"]:>2}/{N_CELLS}  ({row["Pct"]:>5.1f}%)  {bar}')
print('=' * 65)

# Combined metric: Exact + Partial
print('\n  Exact + Partial (score ≥ 1):')
for df, name in [(df_original, 'Original (clean)'),
                 (df_masked,   f'Masked ({int(MASK_FRACTION*100)}% zeroed)'),
                 (df_reconstructed, 'Reconstructed (scGPT)')]:
    ep = (df['score'] >= 1).sum()
    print(f'    {name:<30}: {ep}/{N_CELLS}  ({ep/N_CELLS*100:.1f}%)')

  3-TIER ACCURACY COMPARISON

  Original (clean)
    Exactly correct     :  4/24  ( 16.7%)  ███
    Partly correct      : 13/24  ( 54.2%)  ██████████
    Not correct         :  7/24  ( 29.2%)  █████

  Masked (40% zeroed)
    Exactly correct     :  8/24  ( 33.3%)  ██████
    Partly correct      : 13/24  ( 54.2%)  ██████████
    Not correct         :  3/24  ( 12.5%)  ██

  Reconstructed (scGPT)
    Exactly correct     :  1/24  (  4.2%)  
    Partly correct      :  8/24  ( 33.3%)  ██████
    Not correct         : 15/24  ( 62.5%)  ████████████

  Exact + Partial (score ≥ 1):
    Original (clean)              : 17/24  (70.8%)
    Masked (40% zeroed)           : 21/24  (87.5%)
    Reconstructed (scGPT)         : 9/24  (37.5%)
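With 24 cells, a single cell moves any percentage above by about 4.2 points (1/24), so differences between versions are best read together with their sampling uncertainty. The Wilson score interval is one standard way to bound a proportion at this sample size; the helper below is an added sketch, not part of the original evaluation, applied to the Exact + Partial counts printed above.

```python
import math

def wilson_interval(k, n, z=1.96):
    """Approximate 95% Wilson score interval for k successes out of n."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = k / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return max(0.0, center - half), min(1.0, center + half)

# Exact + Partial counts from the summary above, out of N_CELLS = 24
for name, k in [('Original (clean)', 17), ('Masked (40% zeroed)', 21),
                ('Reconstructed (scGPT)', 9)]:
    lo, hi = wilson_interval(k, 24)
    print(f'{name:<24} {k}/24  95% CI {lo:.2f} to {hi:.2f}')
```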


In [34]:
# ── Per-cell side-by-side view ────────────────────────────────────────────────
compare = df_original[['cell_id', 'y_true']].copy()
compare['pred_original']     = df_original['y_pred']
compare['score_original']    = df_original['score']
compare['pred_masked']       = df_masked['y_pred']
compare['score_masked']      = df_masked['score']
compare['pred_reconstructed']  = df_reconstructed['y_pred']
compare['score_reconstructed'] = df_reconstructed['score']
compare['scgpt_helped'] = compare['score_reconstructed'] > compare['score_masked']
compare['scgpt_hurt']   = compare['score_reconstructed'] < compare['score_masked']

print(f'scGPT improved prediction: {compare["scgpt_helped"].sum()} cells')
print(f'scGPT hurt prediction    : {compare["scgpt_hurt"].sum()} cells')
print(f'No change                : {(~compare["scgpt_helped"] & ~compare["scgpt_hurt"]).sum()} cells')

pd.set_option('display.max_colwidth', 55)
compare[['y_true',
         'pred_masked', 'score_masked',
         'pred_reconstructed', 'score_reconstructed',
         'scgpt_helped']].style.apply(
    # Green where scGPT reconstruction raised the C2S score for that cell.
    lambda col: ['background-color: #c8e6c9' if v else '' for v in col],
    subset=['scgpt_helped'],
)

scGPT improved prediction: 1 cells
scGPT hurt prediction    : 16 cells
No change                : 7 cells


| | y_true | pred_masked | score_masked | pred_reconstructed | score_reconstructed | scgpt_helped |
|---|---|---|---|---|---|---|
| 0 | mast cell | mast cell. | 2 | CD4-positive, alpha-beta thymocyte. | 0 | False |
| 1 | T follicular helper cell | CD4-positive, alpha-beta T cell. | 0 | CD4-positive, alpha-beta T cell. | 0 | False |
| 2 | plasma cell | plasma cell. | 2 | malignant cell. | 1 | False |
| 3 | plasma cell | plasma cell. | 2 | germ cell. | 1 | False |
| 4 | effector memory CD4-positive, alpha-beta T cell | naive thymus-derived CD4-positive, alpha-beta T cell. | 1 | germ cell. | 0 | False |
| 5 | naive thymus-derived CD4-positive, alpha-beta T cell | CD4-positive, alpha-beta memory T cell. | 1 | germ cell. | 0 | False |
| 6 | macrophage | macrophage. | 2 | erythroid lineage cell. | 0 | False |
| 7 | memory B cell | naive B cell. | 1 | germ cell. | 1 | False |
| 8 | memory B cell | B cell. | 1 | malignant cell. | 1 | False |
| 9 | erythroid lineage cell | erythroid progenitor cell. | 1 | germ cell. | 0 | False |
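The helped / hurt / unchanged counts compress a 3 × 3 picture. A `pd.crosstab` of the masked score against the reconstructed score shows where the changes land (for example exact → not correct versus partial → not correct). The sketch uses the ten rows shown above (the other 14 rows are not reproduced here); in the notebook, pass `compare['score_masked']` and `compare['score_reconstructed']` instead.

```python
import pandas as pd

# score_masked / score_reconstructed for the ten rows shown in the per-cell table
toy = pd.DataFrame({
    'score_masked':        [2, 0, 2, 2, 1, 1, 2, 1, 1, 1],
    'score_reconstructed': [0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
})
transitions = pd.crosstab(toy['score_masked'], toy['score_reconstructed'],
                          rownames=['masked'], colnames=['reconstructed'])
print(transitions)

helped = int((toy['score_reconstructed'] > toy['score_masked']).sum())
hurt = int((toy['score_reconstructed'] < toy['score_masked']).sum())
print(f'helped={helped}  hurt={hurt}  unchanged={len(toy) - helped - hurt}')
```

Cells above the diagonal of `transitions` are the "helped" cases and cells below it the "hurt" cases, so the table and the three printed counts always agree.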


In [35]:
# ── Save full results ─────────────────────────────────────────────────────────
out = Path('./scgpt_reconstruction_results.csv')
compare.to_csv(out, index=False)
print('Saved:', out.resolve())

Saved: C:\Users\Daniel\Desktop\GitProjects\Improving-Cell2Sentence-with-Single-Cell-Foundation-Model-Embeddings\notebooks\c2s_donor_new_approach\scgpt_reconstruction_results.csv
