# GeneCompass Fine-Tuning for Cell-type Annotation

### For human-specific and mouse-specific tasks, we compared pre-trained GeneCompass with GeneCompass without pre-training and Geneformer on human multiple sclerosis (hMS), lung (hLung) and liver (hLiver) datasets, and mouse brain (mBrain), lung (mLung) and pancreas (mPancreas) datasets.

### Fine-tune the model for cell-type annotation

In [1]:
# imports
import os
# Choose the GPU to use
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import sys
sys.path.append("../../")
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns
sns.set()
from datasets import load_from_disk, concatenate_datasets
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import Trainer
from genecompass import BertForSequenceClassification
from transformers.training_args import TrainingArguments
from genecompass import DataCollatorForCellClassification
from genecompass.utils import load_prior_embedding
import argparse
import numpy as np
import random
import torch

[2024-07-10 20:51:49,280] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# token dict path
token_dictionary_path='../../prior_knowledge/human_mouse_tokens.pickle'

# load knowledges
knowledges = dict()
out = load_prior_embedding(token_dictionary_or_path=token_dictionary_path)
knowledges['promoter'] = out[0]
knowledges['co_exp'] = out[1]
knowledges['gene_family'] = out[2]
knowledges['peca_grn'] = out[3]
knowledges['homologous_gene_human2mouse'] = out[4]

In [5]:
# data path
train_path = '../../data/cell_type_annotation/hMS/train'
test_path =  '../../data/cell_type_annotation/hMS/test'

# load datasets
train_set = load_from_disk(train_path)
test_set = load_from_disk(test_path)

# rename columns
train_set = train_set.rename_column("celltype", "label")
test_set = test_set.rename_column("celltype", "label")

# create dictionary of cell types : label ids
target_names = set(list(Counter(train_set["label"]).keys()) + list(Counter(test_set["label"]).keys()))
target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
print(target_name_id_dict)

# change labels to numerical ids
def classes_to_ids(example):
    example["label"] = target_name_id_dict[example["label"]]
    return example
train_set = train_set.map(classes_to_ids, num_proc=16)
test_set = test_set.map(classes_to_ids, num_proc=16)

# filter dataset for cell types in corresponding training set
trained_labels = list(Counter(train_set['label']).keys())
def if_trained_label(example):
    return example['label'] in trained_labels
test_set = test_set.filter(if_trained_label, num_proc=16)

{'SV2C-expressing interneuron': 0, 'mixed glial cell?': 1, 'VIP-expressing interneuron': 2, 'PVALB-expressing interneuron': 3, 'microglial cell': 4, 'astrocyte': 5, 'SST-expressing interneuron': 6, 'oligodendrocyte precursor cell': 7, 'mixed excitatory neuron': 8, 'pyramidal neuron?': 9, 'cortical layer 2-3 excitatory neuron B': 10, 'phagocyte': 11, 'oligodendrocyte C': 12, 'endothelial cell': 13, 'cortical layer 2-3 excitatory neuron A': 14, 'cortical layer 4 excitatory neuron': 15, 'cortical layer 5-6 excitatory neuron': 16, 'oligodendrocyte A': 17}


Map (num_proc=16):   0%|          | 0/7697 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/13244 [00:00<?, ? examples/s]

Filter (num_proc=16):   0%|          | 0/13244 [00:00<?, ? examples/s]

In [6]:
# compute metrics for cell-type annotation
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # calculate accuracy and macro f1 using sklearn's function
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="macro")
    recall = recall_score(labels, preds, average="macro")
    macro_f1 = f1_score(labels, preds, average="macro")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'macro_f1': macro_f1
    }

In [7]:
# pretrain checkpoint path
checkpoint_path='../../pretrained_models/GeneCompass_Base'

# set freeze layer
freeze_layers = 12

# reload pretrained model
model = BertForSequenceClassification.from_pretrained(
    checkpoint_path,
    num_labels=len(target_name_id_dict.keys()),
    output_attentions=False,
    output_hidden_states=False,
    knowledges=knowledges,
)

if freeze_layers > 0:
    modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
    for module in modules_to_freeze:
        for param in module.parameters():
            param.requires_grad = False

model = model.to("cuda")
print(model)

Some weights of the model checkpoint at ../pretrained_models/GeneCompass_Base were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls4value.predictions.decoder.bias', 'cls4value.predictions.transform.LayerNorm.weight', 'cls4value.predictions.bias', 'cls4value.predictions.transform.LayerNorm.bias', 'cls4value.predictions.decoder.weight', 'emb_warmup.steps', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls4value.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'emb_warmup.alpha', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls4value.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- Th

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): KnowledgeBertEmbeddings(
      (word_embeddings): Embedding(50558, 768, padding_idx=0)
      (promoter_embeddings): PriorEmbedding(
        (linear1): Linear(in_features=768, out_features=768, bias=True)
      )
      (co_exp_embeddings): PriorEmbedding(
        (linear1): Linear(in_features=768, out_features=768, bias=True)
      )
      (gene_family_embeddings): PriorEmbedding(
        (linear1): Linear(in_features=768, out_features=768, bias=True)
      )
      (peca_grn_embeddings): PriorEmbedding(
        (linear1): Linear(in_features=768, out_features=768, bias=True)
      )
      (concat_embeddings): Sequential(
        (cat_fc): Linear(in_features=3841, out_features=768, bias=True)
        (cat_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (cat_gelu): QuickGELU()
        (cat_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (position_embeddings): Embedding(2048, 7

In [8]:
# set output dir
output_dir='../../down_stream_outputs'
# make output directory
subprocess.call(f'mkdir {output_dir}', shell=True)

# set training arguments
training_args = {
    # "run_name": wandb_name,
    "dataloader_num_workers": 2,
    "learning_rate": 5e-5, 
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch", 
    "logging_steps": 10,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": "linear", 
    "warmup_steps": 100,
    "weight_decay": 0.001,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "num_train_epochs": 30,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
    "metric_for_best_model": "macro_f1",
    "greater_is_better": True,
}
training_args_init = TrainingArguments(**training_args)

In [9]:
# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=train_set,
    eval_dataset=test_set,
    compute_metrics=compute_metrics
)
# train the cell type classifier
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

In [None]:
# test
predictions = trainer.predict(test_set)
with open(f"{output_dir}predictions.pickle", "wb") as fp:
    pickle.dump(predictions, fp)
trainer.save_metrics("eval", predictions.metrics)
trainer.save_model(output_dir)

17%|█████████████                                      | 34/200 [00:06<00:34,  4.82it/s]│
{'eval_loss': 0.9185497760772705, 'eval_accuracy': 0.8920614998431127, 'eval_precision': 0.7067597099376611, 'eval_recall': 0.702980477684731, 'eval_macro_f1': 0.693570637266355, 'eval_runtime': 41.6667, 'eval_samples_per_second': 76.488, 'eval_steps_per_second': 4.8, 'epoch': 0.0}