# C2S Tutorial-Style 1: Fine-Tuning on a New Dataset

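Dieses Notebook orientiert sich am Workflow von:
- `c2s_tutorial_3_finetuning_on_new_datasets.ipynb`

Ziel: `vandijklab/C2S-Pythia-410m-cell-type-prediction` auf deinem eigenen Datensatz für die Zelltyp-Vorhersage feinjustieren. Alle Artefakte (CSData, Modell/Checkpoints und `run_info.json` für Notebook 2) werden unter `./runs/<Zeitstempel>/` abgelegt.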

In [1]:
# Optional (if needed): install the dependencies into this kernel's environment.
# For reproducible runs, pin exact versions (e.g. `cell2sentence==X.Y.Z`,
# `transformers==X.Y.Z`); the fine-tune cell below carries a compatibility shim
# precisely because newer transformers releases changed the Trainer API.
# %pip install -q cell2sentence anndata scanpy transformers datasets pandas numpy scipy

In [2]:
# Python built-in libraries
import os
from pathlib import Path
from datetime import datetime
import random
from collections import Counter
import json

# Third-party libraries
import numpy as np
from tqdm import tqdm
try:
    import torch
except RuntimeError as e:
    _msg = str(e)
    if ('RpcBackendOptions' in _msg) or ('already has a docstring' in _msg):
        raise RuntimeError(
            'Torch is in a partially initialized state in this kernel. '
            'Restart the Jupyter kernel, then run the notebook from the first cell.'
        ) from e
    raise

from transformers import TrainingArguments
from transformers.utils import logging as hf_logging
hf_logging.enable_progress_bar()

# Single-cell libraries
import anndata
import scanpy as sc

# Cell2Sentence imports
import cell2sentence as cs

In [3]:
# ---------- Config ----------
# Everything a reader might want to tune lives here. Relative paths are
# resolved against the notebook's working directory.
SEED = 1234
# Seed every RNG the notebook touches. The HF Trainer re-seeds itself from
# `TrainingArguments.seed`, so that argument must also be set to SEED.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_PATH = Path('../../data/dominguez_conde_immune_tissue_two_donors.h5ad')
BASE_MODEL = 'vandijklab/C2S-Pythia-410m-cell-type-prediction'
TRAINING_TASK = 'cell_type_prediction'
# Passed to `csmodel.fine_tune(top_k_genes=...)`; presumably the number of
# highest-ranked genes kept per cell sentence — confirm in cell2sentence docs.
TOP_K_GENES = 200

# Validate inputs *before* touching the filesystem, so a wrong path does not
# leave an empty timestamped run directory behind. An explicit raise (rather
# than `assert`) still fires when Python runs with -O.
if not DATA_PATH.exists():
    raise FileNotFoundError(f'Not found: {DATA_PATH.resolve()}')

# One directory per run, named by its start time (second resolution).
RUN_NAME = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
RUN_DIR = Path('./runs') / RUN_NAME
CSDATA_DIR = RUN_DIR / 'csdata'
MODEL_DIR = RUN_DIR / 'model'
for _run_subdir in (RUN_DIR, CSDATA_DIR, MODEL_DIR):
    _run_subdir.mkdir(parents=True, exist_ok=True)

print('RUN_DIR:', RUN_DIR.resolve())

RUN_DIR: /root/AI-Biomedicine/Improving-Cell2Sentence-with-Single-Cell-Foundation-Model-Embeddings/notebooks/c2s_tutorial_style/runs/2026-02-26_16-39-38


In [4]:
# ---------- Load dataset ----------
# Read the AnnData file into memory and report its size and available metadata.
adata = anndata.read_h5ad(DATA_PATH)
for _title, _value in (('shape:', adata.shape), ('obs columns:', list(adata.obs.columns))):
    print(_title, _value)

# 'cell_type' is the supervision target for the cell-type-prediction task,
# so fail fast if the file does not provide it.
has_cell_type = 'cell_type' in adata.obs.columns
if not has_cell_type:
    raise ValueError("adata.obs must contain 'cell_type' for training labels.")

shape: (29773, 36503)
obs columns: ['cell_type', 'tissue', 'batch_condition', 'organism', 'assay', 'sex']


In [5]:
# ---------- Minimal preprocessing (tutorial-style baseline) ----------
# C2S baseline: QC filtering, per-cell normalisation to 1e4 total counts,
# then log10(1 + x).
#
# These steps are NOT idempotent: re-running this cell on an already processed
# `adata` would normalise and log-transform it a second time. `sc.pp.log1p`
# records itself in `adata.uns['log1p']`, so use that marker to refuse a re-run.
if 'log1p' in adata.uns:
    raise RuntimeError(
        "adata is already log-transformed (adata.uns['log1p'] is set). "
        'Re-run the load cell before re-running preprocessing.'
    )

adata = adata.copy()
# Uncomment if var_names contain duplicate gene symbols.
# adata.var_names_make_unique()
sc.pp.filter_cells(adata, min_genes=200)  # drop cells with < 200 expressed genes
sc.pp.filter_genes(adata, min_cells=3)    # drop genes expressed in < 3 cells
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata, base=10)

# obs columns carried into the Arrow dataset as per-cell labels; fail early
# here rather than deep inside the Arrow conversion if one is missing.
adata_obs_cols_to_keep = ["cell_type", "tissue", "batch_condition", "organism", "sex"]
_missing_label_cols = [c for c in adata_obs_cols_to_keep if c not in adata.obs.columns]
if _missing_label_cols:
    raise ValueError(f'adata.obs is missing label columns: {_missing_label_cols}')

In [6]:
# ---------- AnnData -> Arrow + vocabulary ----------
# Convert the processed AnnData into an Arrow dataset (one record per cell,
# carrying the selected obs columns as labels) plus the vocabulary returned
# alongside it.
arrow_kwargs = dict(
    random_state=SEED,
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep,
)
arrow_ds, vocab = cs.CSData.adata_to_arrow(adata, **arrow_kwargs)

for _title, _count in (('n samples in arrow:', len(arrow_ds)), ('vocab size:', len(vocab))):
    print(_title, _count)

100%|██████████| 29773/29773 [00:10<00:00, 2776.66it/s]


n samples in arrow: 29773
vocab size: 23944


In [7]:
# # ---------- Train/val/test split ----------
# NOTE: this cell is intentionally disabled. While it stays commented out, no
# `data_split_indices_dict` is passed to `csmodel.fine_tune(...)`, so the
# train/val/test split is chosen inside cell2sentence. TODO confirm whether and
# where that split is persisted, so a downstream evaluation notebook does not
# score the model on cells it was trained on.
# To use an explicit, saved split instead: uncomment this cell, pass the
# indices via `data_split_indices_dict=` in the fine-tune cell, and record
# `split_path` in run_info.
# _, split_indices = cs.utils.train_test_split_arrow_ds(arrow_ds)

# split_path = RUN_DIR / 'split_indices.json'
# with split_path.open('w') as f:
#     json.dump(split_indices, f, indent=2)

# print('saved:', split_path.resolve())
# print({k: len(v) for k, v in split_indices.items()})

In [8]:
# ---------- Save CSData ----------
# Persist the Arrow dataset and vocabulary under CSDATA_DIR/dataset_arrow and
# wrap them in the CSData handle that the fine-tuning step consumes.
csdata_save_kwargs = {
    'save_dir': str(CSDATA_DIR),
    'save_name': 'dataset_arrow',
    'dataset_backend': 'arrow',
}
csdata = cs.CSData.csdata_from_arrow(
    arrow_dataset=arrow_ds,
    vocabulary=vocab,
    **csdata_save_kwargs,
)
print(csdata)

Saving the dataset (0/1 shards):   0%|          | 0/29773 [00:00<?, ? examples/s]

CSData Object; Path=runs/2026-02-26_16-39-38/csdata/dataset_arrow, Format=arrow


In [9]:
# ---------- Init model ----------
# Wrap the pretrained C2S checkpoint in a CSModel stored at MODEL_DIR/<save_name>.
# `fine_tune()` later reloads the model from that on-disk path.
# NOTE(review): despite its name, this directory starts out holding BASE_MODEL;
# confirm whether fine-tuning writes the trained weights back here or only to
# the Trainer's output_dir.
MODEL_SAVE_NAME = 'finetuned_cell_type_prediction'
csmodel = cs.CSModel(
    model_name_or_path=BASE_MODEL,
    save_dir=str(MODEL_DIR),
    save_name=MODEL_SAVE_NAME,
)
print('model:', csmodel)

Using device: cpu


config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.62G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/133 [00:00<?, ?B/s]

model: CSModel Object; Path=runs/2026-02-26_16-39-38/model/finetuned_cell_type_prediction


In [None]:
# ---------- TrainingArguments ----------
# Hugging Face Trainer configuration. Effective train batch size on the single
# CPU device is per_device_train_batch_size * gradient_accumulation_steps = 32.
HF_OUTPUT_DIR = MODEL_DIR / 'hf_trainer_output'
HF_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

train_args = TrainingArguments(
    output_dir=str(HF_OUTPUT_DIR),
    # Tie the Trainer's seed to the config cell. Without this it silently uses
    # the TrainingArguments default seed (42), not SEED.
    seed=SEED,
    # bf16 autocast on CPU (see use_cpu below). NOTE(review): CPU bf16 is only
    # fast on hardware with native bf16 support; try bf16=False if slow.
    bf16=True,
    fp16=False,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    num_train_epochs=5,
    warmup_steps=1,
    lr_scheduler_type='cosine',
    # Log the training loss after every optimizer step, which serves as a
    # heartbeat for slow CPU runs. Evaluation cadence is set by eval_steps.
    logging_steps=1,
    eval_strategy='steps',
    eval_steps=50,
    # With load_best_model_at_end=True, save_steps must be a multiple of eval_steps.
    save_strategy='steps',
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to='none',
    use_cpu=True,
)
train_args

TrainingArguments(
_n_gpu=0,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=True,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=False,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=50,
eval_strategy=IntervalStrategy.STEPS,
eval_us

: 

In [None]:
# ---------- Fine-tune ----------
import inspect
from transformers import Trainer


def _patch_trainer_tokenizer_kwarg():
    """Let `Trainer(tokenizer=...)` work on transformers versions that renamed it.

    Newer transformers releases replaced `Trainer.__init__`'s `tokenizer`
    argument with `processing_class`; cell2sentence presumably still passes
    `tokenizer=`. When the installed version lacks `tokenizer`, wrap `__init__`
    so the old keyword is forwarded as `processing_class`. Safe to call more
    than once: an already-installed wrapper is detected via a marker attribute.
    """
    current_init = Trainer.__init__
    if getattr(current_init, '_c2s_tokenizer_compat', False):
        return  # already patched (e.g. this cell was re-run)
    params = inspect.signature(current_init).parameters
    if 'processing_class' not in params or 'tokenizer' in params:
        return  # this transformers version still accepts `tokenizer`

    def _trainer_init_compat(self, *args, **kwargs):
        if 'tokenizer' in kwargs and 'processing_class' not in kwargs:
            kwargs['processing_class'] = kwargs.pop('tokenizer')
        return current_init(self, *args, **kwargs)

    _trainer_init_compat._c2s_tokenizer_compat = True
    Trainer.__init__ = _trainer_init_compat


def _prefer_text_logging(args):
    """Switch `args` (mutated in place) to periodic text logs instead of tqdm bars.

    `logging_steps` values in (0, 1) are a ratio of total training steps and
    are kept as-is; values >= 1 are coerced to int; a missing or non-positive
    value falls back to 50.
    """
    args.disable_tqdm = True
    args.logging_strategy = 'steps'
    steps = getattr(args, 'logging_steps', None)
    if not steps or steps <= 0:
        args.logging_steps = 50
    elif steps >= 1:
        args.logging_steps = int(steps)
    args.report_to = []


_patch_trainer_tokenizer_kwarg()
_prefer_text_logging(train_args)
print(f"Training heartbeat every {getattr(train_args, 'logging_steps', 'N/A')} steps.")

# NOTE(review): max_eval_samples=10 keeps evaluation cheap but very noisy; with
# load_best_model_at_end=True the "best" checkpoint is chosen on 10 examples.
# No data_split_indices_dict is passed, so cell2sentence chooses the split
# itself (see the disabled split cell above to supply an explicit one).
csmodel.fine_tune(
    csdata=csdata,
    task=TRAINING_TASK,
    train_args=train_args,
    loss_on_response_only=False,
    top_k_genes=TOP_K_GENES,
    max_eval_samples=10,
    num_proc=1,
)
print('Fine-tuning finished.')

Training heartbeat every 1 steps.
Reloading model from path on disk: runs/2026-02-26_16-39-38/model/finetuned_cell_type_prediction


Map:   0%|          | 0/29773 [00:00<?, ? examples/s]

Starting training. Output directory: runs/2026-02-26_16-39-38/model/hf_trainer_output
Selecting 10 samples of eval dataset to shorten validation loop.




In [None]:
# ---------- Save run metadata for Notebook 2 ----------
# Everything a downstream notebook needs to locate this run's data and model.
# All paths are stored absolute so Notebook 2 works from any working directory.
trainer_output_dir = Path(train_args.output_dir).resolve()


def _checkpoint_step(path):
    """Global step of a `checkpoint-<step>` directory, or None if not numeric."""
    suffix = path.name.split('-', 1)[-1]
    return int(suffix) if suffix.isdigit() else None


# Sort numerically: string order would put checkpoint-1000 before checkpoint-200.
checkpoints = sorted(
    (p for p in trainer_output_dir.glob('checkpoint-*')
     if p.is_dir() and _checkpoint_step(p) is not None),
    key=_checkpoint_step,
)

# The newest checkpoint's trainer_state.json records which checkpoint scored
# best on eval (relevant because load_best_model_at_end=True).
best_checkpoint = None
if checkpoints:
    state_file = checkpoints[-1] / 'trainer_state.json'
    if state_file.exists():
        best_checkpoint = json.loads(state_file.read_text()).get('best_model_checkpoint')

# NOTE(review): assumed filename cell2sentence uses when it creates the split
# itself; only recorded if it actually exists.
split_pkl = trainer_output_dir / 'data_split_indices_dict.pkl'

run_info = {
    'h5ad_path': str(DATA_PATH.resolve()),
    'base_model': BASE_MODEL,
    'training_task': TRAINING_TASK,
    'top_k_genes': TOP_K_GENES,
    'seed': SEED,
    'label_col_names': list(adata_obs_cols_to_keep),
    'run_dir': str(RUN_DIR.resolve()),
    'csdata_dir': str(CSDATA_DIR.resolve()),
    'model_dir': str(MODEL_DIR.resolve()),
    # NOTE(review): this is the CSModel save path, which was initialised from
    # BASE_MODEL; the Trainer writes its checkpoints to `trainer_output_dir`.
    # Confirm which location holds the fine-tuned weights before evaluating.
    'finetuned_model_path': str((MODEL_DIR / 'finetuned_cell_type_prediction').resolve()),
    'trainer_output_dir': str(trainer_output_dir),
    'checkpoints': [str(p) for p in checkpoints],
    'best_checkpoint': str(Path(best_checkpoint).resolve()) if best_checkpoint else None,
    'split_indices_path': str(split_pkl) if split_pkl.exists() else None,
}

run_info_path = RUN_DIR / 'run_info.json'
with run_info_path.open('w') as f:
    json.dump(run_info, f, indent=2)

print('saved:', run_info_path.resolve())
run_info