## Import necessary packages

In [1]:
import warnings
warnings.filterwarnings("ignore")

import hdf5plugin
import numpy as np
import anndata as ad
from scipy.sparse import csr_matrix
from CellPLM.utils import set_seed
from CellPLM.pipeline.cell_type_annotation import CellTypeAnnotationPipeline, CellTypeAnnotationDefaultPipelineConfig, CellTypeAnnotationDefaultModelConfig

## Specify important parameters before getting started

In [2]:
# Which downstream dataset to load; the loading cell below branches on 'MS' and 'hPancreas'.
DATASET = 'MS' # alternative: 'hPancreas'
# Checkpoint name passed to the pipeline as `pretrain_prefix`.
PRETRAIN_VERSION = '20230926_85M'
# NOTE(review): DEVICE is not referenced by any later cell in this notebook as shown,
# so this setting currently has no effect -- confirm how the pipeline selects its device.
DEVICE = 'cuda:3'

## Load Downstream Dataset

The MS dataset is contributed by [scGPT](https://github.com/bowang-lab/scGPT/blob/main/tutorials/Tutorial_Annotation.ipynb). The hPancreas dataset is contributed by [TOSICA](https://github.com/JackieHanLab/TOSICA/blob/main/test/tutorial.ipynb).


In [3]:
# Seed first so the random train/valid permutation below is reproducible
# (presumably set_seed also seeds NumPy's global RNG -- confirm in CellPLM.utils).
set_seed(42)

if DATASET == 'hPancreas':
    data_train = ad.read_h5ad('../data/demo_train.h5ad')
    data_test = ad.read_h5ad('../data/demo_test.h5ad')
    train_num = data_train.shape[0]
    data = ad.concat([data_train, data_test])
    data.X = csr_matrix(data.X)
    # Later cells read the labels from .obs['celltype']; these files store them as 'Celltype'.
    data.obs['celltype'] = data.obs['Celltype']

elif DATASET == 'MS':
    data_train = ad.read_h5ad('../data/c_data.h5ad')
    data_test = ad.read_h5ad('../data/filtered_ms_adata.h5ad')
    # Re-index the genes by 'index_column' in both parts before concatenating them.
    data_train.var = data_train.var.set_index('index_column')
    data_test.var = data_test.var.set_index('index_column')
    train_num = data_train.shape[0]
    data = ad.concat([data_train, data_test])
    data.var_names_make_unique()

else:
    # Fail with a clear message instead of a NameError on `data` further down.
    raise ValueError(f"Unsupported DATASET {DATASET!r}; expected 'hPancreas' or 'MS'.")

# `data_train` is the first element given to ad.concat, so positions [0, train_num)
# of `data` are the training pool and the remaining rows are the held-out test cells.
# 90% of the pool (in random order) becomes 'train', the other 10% becomes 'valid'.
train_cutoff = int(train_num * 0.9)
tr = np.random.permutation(train_num)

# Build the labels positionally in a NumPy array and assign the column once.
# The previous chained form `data.obs['split'][positions] = ...` wrote through a
# temporary Series using positional integer keys, which pandas deprecates and which
# silently has no effect under copy-on-write (every cell would stay 'test').
split = np.full(data.shape[0], 'test', dtype=object)
split[tr[:train_cutoff]] = 'train'
split[tr[train_cutoff:train_num]] = 'valid'
data.obs['split'] = split

## Overwrite parts of the default config
These hyperparameters are recommended for general-purpose use. We did not tune them for individual datasets. You may update them if needed.

In [4]:
# Work on copies so the edits below do not modify the imported default objects at the top level.
model_config = CellTypeAnnotationDefaultModelConfig.copy()
# out_dim is set to the number of distinct values found in .obs['celltype'].
model_config['out_dim'] = data.obs['celltype'].nunique()

pipeline_config = CellTypeAnnotationDefaultPipelineConfig.copy()

# Show both configs as the cell output.
pipeline_config, model_config

## Fine-tuning

Efficient data setup and fine-tuning can be seamlessly conducted using the CellPLM built-in `pipeline` module.

First, initialize a `CellTypeAnnotationPipeline`. This pipeline will automatically load a pretrained model.

In [5]:
pipeline = CellTypeAnnotationPipeline(
    pretrain_prefix=PRETRAIN_VERSION,  # Which pretrained checkpoint to load
    overwrite_config=model_config,     # Entries that overwrite part of the pretrain config
    pretrain_directory='../ckpt',
)
# Display the model as the cell output.
pipeline.model

Next, employ the `fit` function to fine-tune the model on your downstream dataset. This dataset should be in the form of an AnnData object, where `.X` is a csr_matrix, and `.obs` includes information for train-test splitting and cell type labels.

Typically, a dataset containing approximately 20,000 cells can be trained in under 10 minutes using a V100 GPU card, with an expected GPU memory consumption of around 8GB.

In [6]:
# NOTE(review): DEVICE from the parameters cell is not passed to this call;
# check the signature of `fit` for how the training device is chosen.
pipeline.fit(
    data,                       # An AnnData object
    pipeline_config,            # The config dictionary created earlier (optional)
    split_field='split',        # Column in .obs that contains the split information
    train_split='train',
    valid_split='valid',
    label_fields=['celltype'],  # Column in .obs that contains the cell type labels
)

## Inference and evaluation
Once the pipeline has been fitted to the downstream datasets, performing inference or evaluation on new datasets can be easily accomplished using the built-in `predict` and `score` functions.

In [7]:
pipeline.predict(
    data,             # An AnnData object
    pipeline_config,  # The config dictionary created earlier (optional)
)

In [8]:
pipeline.score(
    data,                       # An AnnData object
    pipeline_config,            # The config dictionary created earlier (optional)
    split_field='split',        # Column in .obs that holds the split assignment (optional)
    target_split='test',        # Which split to evaluate on (optional)
    label_fields=['celltype'],  # Column in .obs that contains the cell type labels
)