## Geneformer Fine-Tuning for Cell Annotation Application

In [1]:
import os

GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"

In [2]:
# imports
import datetime
import pickle
import subprocess
from collections import Counter

import seaborn as sns

sns.set()
import sys

import numpy as np
from geneformer import DataCollatorForCellClassification
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification, Trainer
from transformers.training_args import TrainingArguments

from datasets import load_from_disk

2024-07-18 06:28:06.300073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Prepare training and evaluation datasets

In [3]:
# load cell type or disease dataset (includes all tissues)


# select fine turning type (ctc: cell type classification or isp: in silico perturbation)
f_type = "isp"

# dataset_name(xxx.dataset path)
dataset_name = "/path/to/your/dataset/to/analysis/xxx.dataset/"

# load dataset
train_dataset = load_from_disk(dataset_name)

# check and remove column names
if f_type == "isp":
    try:
        print(np.unique(train_dataset["disease"]))
    except KeyError as e:
        print("KeyError: {}".format(e))
        print("changing to disease")
        train_dataset = train_dataset.rename_column(
            "column name in diseases infomation", "disease"
        )
        print("change finished")
        print(np.unique(train_dataset["disease"]))

elif f_type == "ctc":
    try:
        print(np.unique(train_dataset["cell_type"]))
    except KeyError as e:
        print("KeyError: {}".format(e))
        print("changing to cell_type")
        train_dataset = train_dataset.rename_column(
            "column name in cell types infomation", "cell_type"
        )
        print("change finished")
        print(np.unique(train_dataset["cell_type"]))

else:
    print("error: select fine turning type (ctc or isp)")
    sys.exit(1)


print(train_dataset)

['10X_KC_24_sum' '10X_KO_24_sum']
KeyError: "Column cell_type not in the dataset. Current columns in the dataset: ['input_ids', 'cell_types', 'organ_major', 'disease', 'individual', 'length']"
changing from cell_types to cell_type
change finished
['endothelial_cell']
Dataset({
    features: ['input_ids', 'cell_type', 'organ_major', 'disease', 'individual', 'length'],
    num_rows: 6831
})
use_norm: 
exchange: 
dataset level: 


In [4]:
# remove cache files in xxx.dataset

import glob
import os

import tqdm
from tqdm import tqdm_notebook as tqdm
from tqdm.notebook import tqdm

rmfiles = glob.glob(dataset_name + "/cache*")

if rmfiles == []:
    print("not exist cache files")
else:
    for tqdm_i2, rmfile in zip(tqdm(rmfiles, desc="remove files loop"), rmfiles):
        os.remove(rmfile)
    print("Finished removeing cache file in it !!")

not exist cache files


## Cell Type Classification

In [5]:
# cell type classification

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []


for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids

    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [
        k for k, v in celltype_counter.items() if v > (0.005 * total_cells)
    ]

    def if_not_rare_celltype(example):
        return example["cell_type"] in cells_to_keep

    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column(
        "cell_type", "label"
    )
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example

    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select(
        [i for i in range(0, round(len(labeled_trainset) * 0.8))]
    )
    labeled_eval_split = labeled_trainset.select(
        [i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))]
    )

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())

    def if_trained_label(example):
        return example["label"] in trained_labels

    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

brain


Filter (num_proc=16):   0%|          | 0/7189 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/7189 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/7147 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1429 [00:00<?, ? examples/s]

## Disease Type Classification

In [5]:
# disease classification

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []


for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)

    organ_ids = [organ]
    organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids

    trainset_organ = train_dataset.filter(if_organ, num_proc=16)

    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["disease"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [
        k for k, v in celltype_counter.items() if v > (0.005 * total_cells)
    ]

    def if_not_rare_celltype(example):
        return example["disease"] in cells_to_keep

    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("disease", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example

    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select(
        [i for i in range(0, round(len(labeled_trainset) * 0.8))]
    )
    labeled_eval_split = labeled_trainset.select(
        [i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))]
    )

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())

    def if_trained_label(example):
        return example["label"] in trained_labels

    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]

brain


Filter (num_proc=16):   0%|          | 0/6831 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/6831 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/6831 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/1366 [00:00<?, ? examples/s]

In [6]:
trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))

evalset_dict = dict(zip(organ_list, evalset_list))


print(trainset_dict)
print(traintargetdict_dict)

print(evalset_dict)

{'brain': Dataset({
    features: ['input_ids', 'cell_type', 'label', 'individual', 'length'],
    num_rows: 5465
})}
{'brain': {'10X_KC_24_sum': 0, '10X_KO_24_sum': 1}}
{'brain': Dataset({
    features: ['input_ids', 'cell_type', 'label', 'individual', 'length'],
    num_rows: 1366
})}


## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [7]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average="macro")
    return {"accuracy": acc, "macro_f1": macro_f1}

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

In [8]:
# set model parameters
# max input size
max_input_size = 2**11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 12
# learning schedule
lr_schedule_fn = "linear"  # "polynomial", "linear", "cosine"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 20
# optimizer
optimizer = "adamW"

'\n# set model parameters\n# max input size\nmax_input_size = 2 ** 11  # 2048\n\n# set training hyperparameters\n# max learning rate\nmax_lr = 9.7e-5\n# how many pretrained layers to freeze\nfreeze_layers = 0\n# number gpus\nnum_gpus = 1\n# number cpu cores\nnum_proc = 16\n# batch size for training and eval\ngeneformer_batch_size = 12\n# learning schedule\nlr_schedule_fn = "polynomial" #choice the "polynomial" or "linear" or "cosine"\n# warmup steps\nwarmup_steps = 1_825\n# number of epochs\nepochs = 10\n# optimizer\noptimizer = "adamW"\n'

In [9]:
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    print(organ_label_dict)

    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)

    pretrain_model = "your mouse-Geneformer name"

    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/your/mouse-Geneformer/model/{}/models/".format(pretrain_model),
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = (
        f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    )
    if f_type == "isp":
        output_dir = f"/path/to/your/fine-tuning/model/to/save/in_silico_pretraining/{datestamp}_mouse-geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}_ISP-{organ}/"
    elif f_type == "ctc":
        output_dir = f"/path/to/your/fine-tuning/model/to/save/cell_type_classification/{datestamp}_mouse-geneformer_DiseaseClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}_CTC-{organ}/"
    else:
        print("error: select fine turining type (ctc or isp)")
        sys.exit(1)
    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test) == True:
        raise Exception("Model already saved to this directory.")

    # make output directory
    subprocess.call(f"mkdir {output_dir}", shell=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
        # "max_position_embeddings": 2**11,
    }

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)

brain
{'10X_KC_24_sum': 0, '10X_KO_24_sum': 1}
mouse-Geneformer


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /mnt/keita/data/prog/jupyter/Geneformer/models/240628_155329_mouse-geneformer_20M_DV-n1_PTTMLM_L6_emb256_SL2048_E10_B12_LR0.001_LScosine_WU10000_ACTsilu_Oadamw_DS8/models/ and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.6807,0.68503,0.558565,0.551703
2,0.6667,0.655963,0.611274,0.605085
3,0.6422,0.623628,0.65959,0.631523
4,0.5011,0.552069,0.725476,0.722686
5,0.4005,0.523209,0.782577,0.782551
6,0.3325,0.559892,0.811859,0.81135
7,0.1627,0.557681,0.830161,0.829353
8,0.2028,0.599117,0.852123,0.851193
9,0.1582,0.627147,0.849195,0.848751
10,0.058,0.656693,0.860176,0.859606


  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch = {k

In [10]:
# remove cache files in xxx.dataset

rmfiles = glob.glob(dataset_name + "/cache*")
# print(rmfiles)
for tqdm_i2, rmfile in zip(tqdm(rmfiles, desc="remove files loop"), rmfiles):
    os.remove(rmfile)
print("Finished removing cache files in this dataset!!")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for tqdm_i2, rmfile in zip(tqdm(rmfiles, desc='remove files loop'), rmfiles) :


remove files loop:   0%|          | 0/65 [00:00<?, ?it/s]

cacheファイルの削除完了!!
