# C2S Tutorial-Style 2: Cell Type Prediction

Dieses Notebook orientiert sich am Workflow von:
- `c2s_tutorial_4_cell_type_prediction.ipynb`

Es nutzt das finegetunte Modell aus Notebook 1 und evaluiert es auf dem Test-Split.

In [None]:
# Optional (bei Bedarf):
# %pip install -q cell2sentence anndata scanpy datasets transformers pandas numpy scipy

In [None]:
from pathlib import Path
import json
import string

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

import cell2sentence as cs
from cell2sentence.tasks import predict_cell_types_of_data

In [None]:
# ---------- Config ----------
# Set this to the run_info.json written by Notebook 1.
RUN_INFO_PATH = Path('./runs/REPLACE_WITH_YOUR_RUN/run_info.json')

# Explicit raise instead of `assert`: assertions are stripped under
# `python -O`, which would let a missing file slip through to json.load.
if not RUN_INFO_PATH.exists():
    raise FileNotFoundError(
        'Bitte zuerst Notebook 1 ausführen und RUN_INFO_PATH auf dessen run_info.json setzen.'
    )

with RUN_INFO_PATH.open() as f:
    run_info = json.load(f)

# Artefact locations recorded in run_info.json; a missing key raises KeyError.
H5AD_PATH = Path(run_info['h5ad_path'])
FINETUNED_MODEL_PATH = Path(run_info['finetuned_model_path'])
SPLIT_PATH = Path(run_info['split_indices_path'])
# Number of genes per cell passed to the prediction task; defaults to 200
# when run_info.json does not record it.
TOP_K_GENES = int(run_info.get('top_k_genes', 200))

# Fail early, naming the first artefact that is missing on disk.
for _required_path in (H5AD_PATH, FINETUNED_MODEL_PATH, SPLIT_PATH):
    if not _required_path.exists():
        raise FileNotFoundError(f'Not found: {_required_path}')

# Outputs of this notebook are written next to run_info.json.
PRED_DIR = RUN_INFO_PATH.parent / 'predictions'
PRED_DIR.mkdir(parents=True, exist_ok=True)

print('using finetuned model:', FINETUNED_MODEL_PATH)

In [None]:
# ---------- Load split indices ----------
# Parse the JSON file that SPLIT_PATH points to into `split_indices`.
split_indices = json.loads(SPLIT_PATH.read_text())

# Print the number of entries stored under each key as a quick sanity check.
split_sizes = {}
for split_name, split_members in split_indices.items():
    split_sizes[split_name] = len(split_members)
print(split_sizes)

In [None]:
# ---------- Build CSData from the same dataset ----------
# Read the AnnData file, make variable names unique, then normalize total
# counts per cell to 1e4 and apply log1p.
# NOTE(review): preprocessing presumably has to match Notebook 1 — confirm.
adata = ad.read_h5ad(H5AD_PATH)
adata.var_names_make_unique()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Keep only the metadata columns that are actually present in adata.obs.
candidate_label_cols = ('cell_type', 'tissue', 'batch_condition', 'organism', 'sex')
label_cols = []
for candidate in candidate_label_cols:
    if candidate in adata.obs.columns:
        label_cols.append(candidate)

# Convert the full dataset to an arrow dataset plus vocabulary.
conversion_options = {
    'random_state': 42,
    'sentence_delimiter': ' ',
    'label_col_names': label_cols,
}
arrow_ds, vocab = cs.CSData.adata_to_arrow(adata, **conversion_options)

# The 'test' entry is used as row indices into the full arrow dataset.
test_ids = split_indices.get('test', [])
if not len(test_ids):
    raise ValueError('No test split found in split_indices.json')

test_arrow = arrow_ds.select(test_ids)

# Wrap the test subset in a CSData object saved under PRED_DIR.
csdata_test = cs.CSData.csdata_from_arrow(
    arrow_dataset=test_arrow,
    vocabulary=vocab,
    save_dir=str(PRED_DIR),
    save_name='test_subset_arrow',
    dataset_backend='arrow',
)

print('test samples:', len(test_arrow))

In [None]:
# ---------- Load finetuned model ----------
# Point the model wrapper at the finetuned checkpoint from run_info.json and
# give it its own directory under PRED_DIR.
_wrapper_dir = PRED_DIR / 'model_wrapper'
csmodel = cs.CSModel(
    model_name_or_path=str(FINETUNED_MODEL_PATH),
    save_dir=str(_wrapper_dir),
    save_name='prediction_wrapper',
)

In [None]:
# ---------- Cell type prediction ----------
# Run the finetuned model over the test subset; each prediction is capped at
# 32 generated tokens.
preds = predict_cell_types_of_data(
    csdata=csdata_test,
    csmodel=csmodel,
    n_genes=TOP_K_GENES,
    max_num_tokens=32,
)

# Read the ground-truth labels as one column instead of indexing row by row,
# which would materialise every full row just to pick a single field.
y_true = list(test_arrow['cell_type'])
# Predictions are coerced to str and stripped of surrounding whitespace.
y_pred = [str(p).strip() for p in preds]

# Guard against a silent misalignment between predictions and labels.
if len(y_pred) != len(y_true):
    raise ValueError(
        f'Got {len(y_pred)} predictions for {len(y_true)} test samples'
    )

df = pd.DataFrame({
    'y_true': y_true,
    'y_pred': y_pred,
})
# 1 where the stripped prediction equals the label exactly, else 0.
df['exact_match'] = (df['y_true'] == df['y_pred']).astype(int)
df.head(20)

In [None]:
# ---------- Normalized exact match (more robust) ----------
# Translation table that deletes every ASCII punctuation character. Built once
# here instead of on every normalize_label call.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def normalize_label(x: str) -> str:
    """Return a canonical form of a label for lenient string comparison.

    The label is lower-cased, stripped of surrounding whitespace, has all
    ASCII punctuation removed, and has internal whitespace runs collapsed to a
    single space. Punctuation is deleted rather than replaced by a space, so
    ``'T-cell'`` becomes ``'tcell'``.

    Args:
        x: Raw label text.

    Returns:
        The normalized label; an empty string if nothing remains.
    """
    lowered = x.lower().strip()
    without_punctuation = lowered.translate(_PUNCTUATION_TABLE)
    return ' '.join(without_punctuation.split())

# Add a normalized copy of each label column, then compare the copies.
for _source_col in ('y_true', 'y_pred'):
    df[f'{_source_col}_norm'] = df[_source_col].map(normalize_label)
df['exact_match_norm'] = df['y_true_norm'].eq(df['y_pred_norm']).astype(int)

# Report both accuracies (mean of the 0/1 match columns), rounded to 4 places.
strict_accuracy = df['exact_match'].mean()
normalized_accuracy = df['exact_match_norm'].mean()
print('Exact match       :', round(strict_accuracy, 4))
print('Exact match (norm):', round(normalized_accuracy, 4))

In [None]:
# ---------- Save predictions ----------
# Write the results frame (without the index) to a CSV inside PRED_DIR and
# print the absolute location.
pred_csv = PRED_DIR.joinpath('cell_type_predictions_test.csv')
df.to_csv(path_or_buf=pred_csv, index=False)
print('saved:', pred_csv.resolve())