## Geneformer Fine-Tuning for Cell Annotation Application

In [11]:
import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
print(os.environ['CUDA_VISIBLE_DEVICES'])
os.environ["NCCL_DEBUG"] = "INFO"
os.chdir("/Users/clark04/Geneformer")

0


In [12]:
# imports
from collections import Counter
import torch
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, f1_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
import matplotlib.pyplot as plt

from geneformer import DataCollatorForCellClassification

## Prepare training and evaluation datasets

In [13]:
# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("/Users/clark04/toby/Mesenchyme1_PSC1.dataset")

In [14]:
dataset_list = []
evalset_list = []
cell_list = []
target_dict_list = []


celltype_counter = Counter(train_dataset["cell_type"])
total_cells = sum(celltype_counter.values())
print(total_cells)

# shuffle datasets and rename columns
trainset_cell_shuffled = train_dataset.shuffle(seed=42)
trainset_cell_shuffled = trainset_cell_shuffled.rename_column("cell_type","label")

# create dictionary of cell types : label ids
target_names = list(Counter(trainset_cell_shuffled["label"]).keys())
print(target_names)
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
print(target_name_id_dict)
# change labels to numerical ids
def classes_to_ids(example):
    example["label"] = target_name_id_dict[example["label"]]
    return example
labeled_trainset = trainset_cell_shuffled.map(classes_to_ids, num_proc=16)

# create 80/20 train/eval splits
labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])

# filter dataset for cell types in corresponding training set
trained_labels = list(Counter(labeled_train_split["label"]).keys())
def if_trained_label(example):
    return example["label"] in trained_labels
labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)

trainset = labeled_train_split
evalset = labeled_eval_split_subset

21341
['PSC1', 'Mesenchyme1']
{'PSC1': 0, 'Mesenchyme1': 1}


## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

In [15]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
      'accuracy': acc,
      'macro_f1': macro_f1
    }

### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.

In [18]:
# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 2
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 32
# batch size for training and eval
geneformer_batch_size = 4
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 4
# optimizer
optimizer = "adamw"

In [19]:
import gc
import torch
torch.cuda.empty_cache()
gc.collect()

0

In [None]:
cell_trainset = trainset
cell_evalset = evalset
cell_label_dict = target_name_id_dict

# set logging steps
logging_steps = round(len(cell_trainset)/geneformer_batch_size/10)

# reload pretrained model
model = BertForSequenceClassification.from_pretrained("/mnt/scratchc/ghlab/toby/Geneformer/geneformer-12L-30M", 
                                                  num_labels=2,
                                                  output_attentions = False,
                                                  output_hidden_states = False).to("cuda")

# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_dir = f"/mnt/scratchc/ghlab/toby/models/{datestamp}_geneformer_CellClassifier_Mesenchyme_PSC_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
if os.path.isfile(saved_model_test) == True:
    raise Exception("Model already saved to this directory.")

# make output directory
subprocess.call(f'mkdir {output_dir}', shell=True)

# set training arguments
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=cell_trainset,
    eval_dataset=cell_evalset,
    compute_metrics=compute_metrics
)
# train the cell type classifier
trainer.train()
predictions = trainer.predict(cell_evalset)
with open(f"{output_dir}predictions.pickle", "wb") as fp:
    pickle.dump(predictions, fp)
with torch.no_grad():
    trainer.save_metrics("eval",predictions.metrics)
    trainer.save_model(output_dir)
trainer = None
model = None
torch.cuda.empty_cache()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /mnt/scratchc/ghlab/toby/Geneformer/geneformer-12L-30M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}


Epoch,Training Loss,Validation Loss,Accuracy,Macro F1
1,0.0452,0.023205,0.994377,0.994377
2,0.0693,0.021424,0.99508,0.995079


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [None]:
def plot_confusion_matrix(classes_list, conf_mat, title, output_dir):
    display_labels = []
    i = 0
    for label in classes_list:
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
        i = i + 1
    display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"), 
                                     display_labels=display_labels)
    display.plot(cmap="Blues",values_format=".3g")
    plt.title(title)
    plt.savefig(f'{output_dir}/conf_mat.png')

In [None]:
predicted_labels = predictions.predictions.argmax(axis=1)
true_labels = predictions.label_ids
# Create a confusion matrix
conf_matrix = confusion_matrix(true_labels, predicted_labels)

classes_list = ['PSC1','Mesenchyme1']
plot_confusion_matrix(classes_list, conf_matrix, 'Confusion Matrix for Cell Classifier of PSC1 vs Mesenchyme1',output_dir)