In [None]:
# Name of the project folder; joined onto the Drive projects directory to form PROJECT_HOME.
PROJECT_NAME = "reverse-gene-finder"

In [None]:
import os
import sys

# Root of this project on Google Drive; the data, model, and result paths used
# throughout the notebook are all built from it.
PROJECT_HOME = os.path.join("/content/drive/My Drive/Projects", PROJECT_NAME)

# Make the project's local packages (e.g. `libs`) importable. The membership
# check keeps this cell idempotent: re-running it never appends a duplicate.
if PROJECT_HOME not in sys.path:
    sys.path.append(PROJECT_HOME)

In [None]:
# Mount Google Drive so the project folder under "My Drive" becomes reachable.
from google.colab import drive

# force_remount=True asks Colab to mount again even if a mount already exists.
drive.mount(mountpoint="/content/drive", force_remount=True)

In [None]:
%pip install -U tdigest anndata scanpy loompy > /dev/null 2> /dev/null
%pip install -U transformers[torch] ray[data,train,tune,serve] datasets > /dev/null 2> /dev/null

In [None]:
import json
import shutil
import warnings

import joblib

from libs.classifier import Classifier

In [None]:
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Cross-validation fold (0 to 4) held out for testing; the model is fine-tuned
# on the remaining folds.
CV_FOLD = 0

# Fail fast on an invalid fold: this value is formatted into the split file
# name and the output directory name, so a typo should stop the run here.
assert isinstance(CV_FOLD, int) and 0 <= CV_FOLD <= 4, (
    "CV_FOLD must be an integer from 0 to 4, got %r" % (CV_FOLD,)
)

In [None]:
# Location of the pretrained Geneformer (12L-30M) model inside the project folder;
# passed as `model_directory` when the classifier is tuned and trained.
pretrained_model_path = os.path.join(PROJECT_HOME, "models/pretrained_models/geneformer-12L-30M/")

In [None]:
# Output locations for this fold's fine-tuning run.
output_prefix = "ad_cell_classifier"
output_dir = os.path.join(PROJECT_HOME, "models", "finetuned_models", "cv_%d" % CV_FOLD)
tmp_output_dir = "/tmp"

# Start from an empty output directory.
# WARNING: this deletes any earlier results stored for this fold.
# Pure-Python calls are used instead of shell commands so the path (which
# contains a space) never has to survive shell quoting. ignore_errors=True
# keeps the "do not fail if the directory is missing" behaviour of `rm -rf`.
shutil.rmtree(output_dir, ignore_errors=True)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Disease labels handed to the classifier as the cell states to distinguish.
selected_labels = ['nonAD', 'earlyAD']

# Load the sample-ID split for the chosen CV fold. The joblib file unpacks
# into three collections: training, validation, and test IDs.
_split_path = os.path.join(PROJECT_HOME, "data", "id_splits", "split_%s.joblib" % CV_FOLD)
train_ids, valid_ids, test_ids = joblib.load(_split_path)

In [None]:
# Hyperparameter search space
from ray import tune

# Each entry maps a training-argument name to the Ray Tune sampling domain it
# is drawn from during the hyperparameter search.
ray_config = dict(
    num_train_epochs=tune.choice([1, 2, 4]),
    learning_rate=tune.loguniform(1e-6, 1e-3),
    weight_decay=tune.uniform(0.0, 0.3),
    lr_scheduler_type=tune.choice(["linear", "cosine", "polynomial"]),
    warmup_steps=tune.randint(100, 2000),
    seed=tune.randint(0, 100),
    per_device_train_batch_size=tune.choice([4, 8]),
)

In [None]:
# Build the cell classifier in search mode: `training_args` is left unset and
# the hyperparameter search space is supplied through `ray_config`.
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": selected_labels},
    filter_data=None,
    training_args=None,
    ray_config=ray_config,
    freeze_layers=4,
    num_crossval_splits=1,
    forward_batch_size=200,
    nproc=8,
)

In [None]:
# Tokenized input dataset stored in the project folder.
input_data_path = os.path.join(PROJECT_HOME, "data", "tokenized_data", "rosmap.dataset")

# Samples are assigned by their "joinid" attribute. The "train" side here
# holds both the training and the validation IDs; only the test IDs are
# kept apart at this stage.
train_test_id_split_dict = {
    "attr_key": "joinid",
    "train": train_ids + valid_ids,
    "test": test_ids,
}

# Write the prepared datasets into the output directory for this fold.
cc.prepare_data(
    input_data_file=input_data_path,
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_test_id_split_dict,
)

In [None]:
# Hyperparameter tuning

# Put the classifier back into search mode before tuning. On a fresh run these
# are the values it was constructed with; the reset matters when this cell is
# re-run, because the end of the cell switches `cc` to fixed training arguments
# and would otherwise leave no search space for a second search.
# NOTE(review): assumes Classifier reads `ray_config` / `training_args` as plain
# attributes, as the assignments at the end of this cell already rely on.
cc.ray_config = ray_config
cc.training_args = None

# Search trains on the training IDs and evaluates on the validation IDs.
train_valid_id_split_dict = {"attr_key": "joinid", "train": train_ids, "eval": valid_ids}

# Run 10 search trials; intermediate artefacts go to the temporary directory
# rather than the project's output directory.
best_training_args = cc.validate(model_directory=pretrained_model_path,
                                 prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",
                                 id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
                                 output_directory=tmp_output_dir,
                                 output_prefix=output_prefix,
                                 split_id_dict=train_valid_id_split_dict,
                                 lib_dir_path=os.path.join(PROJECT_HOME, "libs"),
                                 n_hyperopt_trials=10)

# Switch the classifier from search mode to the selected hyperparameters so
# that the next training run uses them directly.
best_hyperparameters = best_training_args.hyperparameters
cc.ray_config = None
cc.training_args = best_hyperparameters
print("Best hyperparameters: %s" % best_hyperparameters)

In [None]:
# Fine-tune once with the selected hyperparameters.

# Train on the training IDs, evaluate on the validation IDs (matched by "joinid").
train_valid_id_split_dict = {
    "attr_key": "joinid",
    "train": train_ids,
    "eval": valid_ids,
}

# Inputs written earlier into this fold's output directory.
prepared_train_file = f"{output_dir}/{output_prefix}_labeled_train.dataset"
id_class_dict_file = f"{output_dir}/{output_prefix}_id_class_dict.pkl"

# No trial count is passed here, and results are written to the project's
# output directory instead of the temporary one.
all_metrics = cc.validate(
    model_directory=pretrained_model_path,
    prepared_input_data_file=prepared_train_file,
    id_class_dict_file=id_class_dict_file,
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_valid_id_split_dict,
)

In [None]:
# Evaluate the fine-tuned model on the held-out test set.

# A fresh classifier is created for evaluation; it only needs the cell-state
# definition and the batch/process settings.
cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": selected_labels},
    forward_batch_size=200,
    nproc=8,
)

# Saved model and prepared test data, both inside this fold's output directory.
finetuned_model_dir = f"{output_dir}/geneformer_cellClassifier_{output_prefix}/ksplit1/"
test_dataset_file = f"{output_dir}/{output_prefix}_labeled_test.dataset"

all_metrics_test = cc.evaluate_saved_model(
    model_directory=finetuned_model_dir,
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    test_data_file=test_dataset_file,
    output_directory=output_dir,
    output_prefix=output_prefix,
)

In [None]:
# Save evaluation results

def _to_jsonable(value):
    """Convert a value the JSON encoder rejects into built-in Python types.

    Used as the `default` hook of `json.dumps`, so it is only called for
    objects the encoder cannot handle itself. Objects exposing `tolist()`
    (e.g. NumPy scalars and arrays) are converted through that method.

    Args:
        value: The object the encoder could not serialise.

    Returns:
        The result of `value.tolist()`.

    Raises:
        TypeError: If `value` has no `tolist()` method.
    """
    if hasattr(value, "tolist"):
        return value.tolist()
    raise TypeError("Object of type %s is not JSON serializable" % type(value).__name__)


results_dir = os.path.join(PROJECT_HOME, "results")
os.makedirs(results_dir, exist_ok=True)

# Start from a copy of the selected hyperparameters, then add the test-set
# ROC metrics next to them.
result_output = dict(best_hyperparameters)
roc_metrics = all_metrics_test['all_roc_metrics']
result_output["sensitivity"] = roc_metrics['sensitivity']
result_output["specificity"] = roc_metrics['specificity']
result_output["roc_auc"] = roc_metrics['all_roc_auc']

# Serialise before opening the file so that an encoding error cannot leave a
# truncated JSON file behind.
result_json = json.dumps(result_output, indent=4, default=_to_jsonable)
with open(os.path.join(results_dir, "eval_cv_%d.json" % CV_FOLD), "w") as json_f:
    json_f.write(result_json)

In [None]:
# Release the Colab runtime once every cell above has finished.
from google.colab import runtime
runtime.unassign()