# C2S Tutorial-Style 1: Finetuning On New Dataset

Dieses Notebook orientiert sich am Workflow von:
- `c2s_tutorial_3_finetuning_on_new_datasets.ipynb`

Ziel: `vandijklab/C2S-Pythia-410m-cell-type-prediction` auf deinem Datensatz feinjustieren.

In [None]:
# Optional (bei Bedarf):
# %pip install -q cell2sentence anndata scanpy transformers datasets pandas numpy scipy

In [None]:
# Python built-in libraries
import os
from pathlib import Path
from datetime import datetime
import random
from collections import Counter
import json

# Third-party libraries
import numpy as np
from tqdm import tqdm
try:
    import torch
except RuntimeError as e:
    _msg = str(e)
    if ('RpcBackendOptions' in _msg) or ('already has a docstring' in _msg):
        raise RuntimeError(
            'Torch is in a partially initialized state in this kernel. '
            'Restart the Jupyter kernel, then run the notebook from the first cell.'
        ) from e
    raise

from transformers import TrainingArguments
from transformers.utils import logging as hf_logging
hf_logging.enable_progress_bar()

# Single-cell libraries
import anndata
import scanpy as sc

# Cell2Sentence imports
import cell2sentence as cs

In [None]:
# ---------- Config ----------
# Seed Python's and NumPy's global RNGs so steps that rely on them are repeatable.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

DATA_PATH = Path('../../data/dominguez_conde_immune_tissue_two_donors.h5ad')
BASE_MODEL = 'vandijklab/C2S-Pythia-410m-cell-type-prediction'
TRAINING_TASK = 'cell_type_prediction'
TOP_K_GENES = 128

# Validate the input before touching the filesystem: a missing dataset should
# not leave an empty timestamped run directory behind. An explicit raise is
# used instead of `assert` because asserts are stripped under `python -O`.
if not DATA_PATH.exists():
    raise FileNotFoundError(f'Not found: {DATA_PATH.resolve()}')

# Every run gets its own timestamped directory holding the converted dataset
# (csdata/) and the model outputs (model/).
RUN_NAME = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
RUN_DIR = Path('./runs') / RUN_NAME
CSDATA_DIR = RUN_DIR / 'csdata'
MODEL_DIR = RUN_DIR / 'model'
for _run_subdir in (RUN_DIR, CSDATA_DIR, MODEL_DIR):
    _run_subdir.mkdir(parents=True, exist_ok=True)

print('RUN_DIR:', RUN_DIR.resolve())

In [None]:
# ---------- Load dataset ----------
# Read the .h5ad file configured in DATA_PATH and show what it contains.
adata = anndata.read_h5ad(DATA_PATH)
print(f'shape: {adata.shape}')
print(f'obs columns: {list(adata.obs.columns)}')

# The fine-tuning labels come from `cell_type`, so stop early when it is absent.
_has_cell_type = 'cell_type' in adata.obs.columns
if not _has_cell_type:
    raise ValueError("adata.obs must contain 'cell_type' for training labels.")

In [None]:
# ---------- Minimal preprocessing (tutorial-style baseline) ----------
# NOTE: this cell is not idempotent. Re-running it without reloading the
# dataset filters and normalizes the already-processed matrix a second time.
adata = adata.copy()
# adata.var_names_make_unique()
sc.pp.filter_cells(adata, min_genes=200)   # drop cells with fewer than 200 detected genes
sc.pp.filter_genes(adata, min_cells=3)     # drop genes detected in fewer than 3 cells
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata, base=10)

# Metadata columns to carry into the Arrow dataset. Only `cell_type` is
# validated when the dataset is loaded, so the optional columns are filtered
# to those actually present instead of failing later on a missing column.
_requested_label_cols = ["cell_type", "tissue", "batch_condition", "organism", "sex"]
adata_obs_cols_to_keep = [c for c in _requested_label_cols if c in adata.obs.columns]

_missing_label_cols = [c for c in _requested_label_cols if c not in adata.obs.columns]
if _missing_label_cols:
    print('label columns not found in adata.obs (skipped):', _missing_label_cols)
print('label columns:', adata_obs_cols_to_keep)

In [None]:
# ---------- AnnData -> Arrow + vocabulary ----------
# Conversion options are collected first so the call itself stays short.
_arrow_options = dict(
    random_state=SEED,
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep,
)
arrow_ds, vocab = cs.CSData.adata_to_arrow(adata, **_arrow_options)

print(f'n samples in arrow: {len(arrow_ds)}')
print(f'vocab size: {len(vocab)}')

In [None]:
# # ---------- Train/val/test split ----------
# _, split_indices = cs.utils.train_test_split_arrow_ds(arrow_ds)

# split_path = RUN_DIR / 'split_indices.json'
# with split_path.open('w') as f:
#     json.dump(split_indices, f, indent=2)

# print('saved:', split_path.resolve())
# print({k: len(v) for k, v in split_indices.items()})

In [None]:
# ---------- Save CSData ----------
# Wrap the Arrow dataset and its vocabulary in a CSData object saved under
# this run's csdata/ directory.
_csdata_options = {
    'arrow_dataset': arrow_ds,
    'vocabulary': vocab,
    'save_dir': str(CSDATA_DIR),
    'save_name': 'dataset_arrow',
    'dataset_backend': 'arrow',
}
csdata = cs.CSData.csdata_from_arrow(**_csdata_options)
print(csdata)

In [None]:
# ---------- Init model ----------
# Wrap the base checkpoint in a CSModel whose outputs go to this run's model/
# directory under the name 'finetuned_cell_type_prediction'.
_model_save_dir = str(MODEL_DIR)
csmodel = cs.CSModel(
    save_name='finetuned_cell_type_prediction',
    save_dir=_model_save_dir,
    model_name_or_path=BASE_MODEL,
)
print('model:', csmodel)

In [None]:
# ---------- TrainingArguments ----------
HF_OUTPUT_DIR = MODEL_DIR / 'hf_trainer_output'
HF_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Pick at most one mixed-precision mode: bf16 when the CUDA device supports it,
# fp16 on any other CUDA device, full precision when CUDA is unavailable.
_cuda_available = torch.cuda.is_available()
_use_bf16 = _cuda_available and torch.cuda.is_bf16_supported()
_use_fp16 = _cuda_available and not _use_bf16

train_args = TrainingArguments(
    output_dir=str(HF_OUTPUT_DIR),
    # Forward the notebook-wide seed; otherwise the trainer falls back to its
    # own library default and SEED would not govern the training run.
    seed=SEED,
    bf16=_use_bf16,
    fp16=_use_fp16,
    # Batch size 1 with 16 accumulation steps: gradients from 16 samples are
    # accumulated per optimizer step on each device.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    learning_rate=1e-5,
    num_train_epochs=5,
    # A single warmup step: effectively no warmup, the cosine schedule starts
    # almost immediately. Increase this if early training is unstable.
    warmup_steps=1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    # Evaluate every 50 steps and checkpoint every 100. save_steps is kept a
    # multiple of eval_steps, as load_best_model_at_end requires.
    eval_strategy='steps',
    eval_steps=50,
    save_strategy='steps',
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to='none',
)
train_args

In [None]:
# ---------- Fine-tune ----------
# Compatibility shim for newer transformers versions where Trainer uses
# `processing_class` instead of `tokenizer`.
import inspect
from transformers import Trainer

# Patch only when Trainer.__init__ accepts `processing_class` but no longer
# accepts `tokenizer`. Once patched, the wrapper's own signature is
# (self, *args, **kwargs), so this condition is False on a re-run of the cell
# and the wrapper is not stacked a second time.
_trainer_sig = inspect.signature(Trainer.__init__).parameters
if ('processing_class' in _trainer_sig) and ('tokenizer' not in _trainer_sig):
    _orig_trainer_init = Trainer.__init__
    def _trainer_init_compat(self, *args, **kwargs):
        """Forward to the original Trainer.__init__, renaming a `tokenizer`
        keyword argument to `processing_class` when only the former is given."""
        if 'tokenizer' in kwargs and 'processing_class' not in kwargs:
            kwargs['processing_class'] = kwargs.pop('tokenizer')
        return _orig_trainer_init(self, *args, **kwargs)
    # NOTE: this replaces Trainer.__init__ for the whole kernel session, not
    # only for the fine_tune call below.
    Trainer.__init__ = _trainer_init_compat

# Prefer periodic text logs over tqdm progress bars
# The settings below mutate the existing `train_args` object in place.
if hasattr(train_args, 'disable_tqdm'):
    train_args.disable_tqdm = True
if hasattr(train_args, 'logging_strategy'):
    train_args.logging_strategy = 'steps'
if hasattr(train_args, 'logging_steps'):
    # Coerce to an int of at least 1; a falsy value (0 or None) falls back to 50.
    train_args.logging_steps = max(1, int(getattr(train_args, 'logging_steps', 50) or 50))
if hasattr(train_args, 'report_to'):
    # Empty list: no reporting integrations are requested.
    train_args.report_to = []
if torch.cuda.is_available():
    # Release cached GPU memory held by PyTorch's allocator before training starts.
    torch.cuda.empty_cache()
print(f"Training heartbeat every {getattr(train_args, 'logging_steps', 'N/A')} steps.")

# Run fine-tuning via cell2sentence. The keyword semantics are defined by
# CSModel.fine_tune and are not visible in this notebook; presumably
# `max_eval_samples` caps the samples used per evaluation and `top_k_genes`
# limits the genes per cell sentence. TODO confirm against the library docs.
csmodel.fine_tune(
    csdata=csdata,
    task=TRAINING_TASK,
    train_args=train_args,
    loss_on_response_only=False,
    top_k_genes=TOP_K_GENES,
    max_eval_samples=10, 
    num_proc=1,
    # Explicit split indices are disabled: `split_indices` is only defined by
    # the commented-out train/val/test split cell.
    # data_split_indices_dict={
    #     'train': split_indices['train'],
    #     'val': split_indices['val'],
    #     'test': split_indices.get('test', []),
    # },
)
print('Fine-tuning finished.')

In [None]:
# ---------- Save run metadata for Notebook 2 ----------
# Collect the paths and settings of this run so a follow-up notebook can find
# the dataset, the converted data and the fine-tuned model.
_finetuned_model_dir = MODEL_DIR / 'finetuned_cell_type_prediction'
run_info = dict(
    h5ad_path=str(DATA_PATH),
    base_model=BASE_MODEL,
    training_task=TRAINING_TASK,
    top_k_genes=TOP_K_GENES,
    run_dir=str(RUN_DIR.resolve()),
    csdata_dir=str(CSDATA_DIR.resolve()),
    model_dir=str(MODEL_DIR.resolve()),
    finetuned_model_path=str(_finetuned_model_dir.resolve()),
    # split_indices_path=str(split_path.resolve()),  # only when the split cell is enabled
)

# Serialize as indented JSON into the run directory.
run_info_path = RUN_DIR / 'run_info.json'
run_info_path.write_text(json.dumps(run_info, indent=2))

print(f'saved: {run_info_path.resolve()}')
run_info