# Classification fine-tuning using Helical

## Cell type classification task

In [None]:
from helical.utils import get_anndata_from_hf_dataset
from helical import Geneformer, GeneformerConfig, GeneformerFineTuningModel, scGPT, scGPTConfig, scGPTFineTuningModel, UCE, UCEConfig, UCEFineTuningModel
import torch
import numpy as np

In [3]:
# Run on the GPU when PyTorch can see a CUDA device; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Load the dataset

In [5]:
from datasets import load_dataset

# Fetch the human yolk sac dataset, reusing a cached copy when one already exists.
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally - confirm the source is trusted before executing.
ds = load_dataset(
    "helical-ai/yolksac_human",
    trust_remote_code=True,
    download_mode="reuse_cache_if_exists",
)

Generating train split:   0%|          | 0/25344 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6336 [00:00<?, ? examples/s]

In [None]:
# Convert the "train" and "test" splits with the Helical helper, in that order.
train_dataset, test_dataset = (
    get_anndata_from_hf_dataset(ds[split_name])
    for split_name in ("train", "test")
)

## Prepare training labels

- For this classification task we want to predict cell type classes
- So we save the cell types as a list

In [29]:
# Extract the "LVL1" column of each dataset's obs table as a flat Python list
# (train first, then test). The conversion chain is the same for both splits.
cell_types_train, cell_types_test = (
    list(np.array(split_data.obs["LVL1"].tolist()))
    for split_data in (train_dataset, test_dataset)
)

- We convert these string labels into unique integer classes for training

In [30]:
# Build the label -> integer class id mapping from the training labels.
# The ids are assigned over the *sorted* labels: iterating a set of strings
# directly yields an order that can change between interpreter sessions
# (string hash randomization), which would silently change which integer each
# cell type receives after a kernel restart.
label_set = set(cell_types_train)
class_id_dict = {
    label: class_id for class_id, label in enumerate(sorted(label_set))
}

# Replace every label with its integer class id. A test label that never
# occurs in the training split raises KeyError here.
cell_types_train = [class_id_dict[label] for label in cell_types_train]
cell_types_test = [class_id_dict[label] for label in cell_types_test]



## Geneformer Fine-Tuning

Load the desired pretrained Geneformer model and desired configs

In [None]:
# Configure Geneformer (target device, batch size and model variant),
# then instantiate the model from that configuration.
geneformer_config = GeneformerConfig(
    device=device,
    batch_size=5,
    model_name="gf-6L-30M-i2048",
)
geneformer = Geneformer(configurer=geneformer_config)

Process the data so it is in the correct form for Geneformer

In [None]:
# Run Geneformer's preprocessing on the training data first, then the test data.
geneformer_train_dataset, geneformer_test_dataset = (
    geneformer.process_data(split_data)
    for split_data in (train_dataset, test_dataset)
)

Geneformer makes use of the Hugging Face dataset class and so we need to add the labels as a column to this dataset

In [33]:
# Attach the integer class labels to each processed dataset as an "LVL1" column.
geneformer_train_dataset = geneformer_train_dataset.add_column(
    "LVL1",
    cell_types_train,
)
geneformer_test_dataset = geneformer_test_dataset.add_column(
    "LVL1",
    cell_types_test,
)

Define the Geneformer Fine-Tuning Model from the Helical package which appends a fine-tuning head automatically from the list of available heads
- Define the task type, which in this case is classification
- Define the output size, which is the number of unique labels for classification

In [34]:
# Wrap the pretrained Geneformer model with a classification head that has
# one output per distinct training label.
geneformer_fine_tune = GeneformerFineTuningModel(
    geneformer_model=geneformer,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

Fine-tune the model

In [35]:
# Fine-tune on the training dataset, validating on the test dataset;
# the labels are read from the "LVL1" column of each dataset.
geneformer_fine_tune.train(
    train_dataset=geneformer_train_dataset,
    validation_dataset=geneformer_test_dataset,
    label="LVL1",
)

INFO:helical.models.geneformer.fine_tuning_model:Freezing the first 2 encoder layers of the Geneformer model during fine-tuning.
INFO:helical.models.geneformer.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 5069/5069 [06:17<00:00, 13.44it/s, loss=0.061] 
Fine-Tuning Validation: 100%|██████████| 1268/1268 [00:44<00:00, 28.38it/s, accuracy=0.99]
INFO:helical.models.geneformer.fine_tuning_model:Fine-Tuning Complete. Epochs: 1


## scGPT Fine-Tuning

Now the same procedure with scGPT
- Loading the model and setting desired configs

In [None]:
# Configure scGPT (batch size and target device), then instantiate the model.
scgpt_config = scGPTConfig(
    batch_size=10,
    device=device,
)
scgpt = scGPT(configurer=scgpt_config)

scGPT does not use the Hugging Face Dataset class, so its data is prepared slightly differently
- Process the train and test splits separately; the test split is used as the validation set

In [37]:
# Run scGPT's preprocessing on the training data first, then the test data,
# passing "gene_name" as the gene_names argument both times.
# NOTE: `dataset` and `validation_dataset` are reassigned in the UCE section.
dataset, validation_dataset = (
    scgpt.process_data(split_data, gene_names="gene_name")
    for split_data in (train_dataset, test_dataset)
)

INFO:helical.models.scgpt.model:Filtering out 11163 genes to a total of 26155 genes with an id in the scGPT vocabulary.


INFO:helical.models.scgpt.model:Filtering out 11163 genes to a total of 26155 genes with an id in the scGPT vocabulary.


Define the scGPT fine-tuning model with the desired head and number of classes

In [38]:
# Wrap the pretrained scGPT model with a classification head that has
# one output per distinct training label.
scgpt_fine_tune = scGPTFineTuningModel(
    scGPT_model=scgpt,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

For scGPT fine tuning we have to pass in the labels as a separate list
- This is the same for the validation and training sets

In [39]:
# Fine-tune scGPT; the labels are passed as separate lists alongside
# the processed training and validation inputs.
scgpt_fine_tune.train(
    train_input_data=dataset,
    train_labels=cell_types_train,
    validation_input_data=validation_dataset,
    validation_labels=cell_types_test,
)

INFO:helical.models.scgpt.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 2535/2535 [02:03<00:00, 20.52it/s, loss=0.227]
Fine-Tuning Validation: 100%|██████████| 634/634 [00:10<00:00, 59.94it/s, accuracy=0.986]
INFO:helical.models.scgpt.fine_tuning_model:Fine-Tuning Complete. Epochs: 1


## UCE Fine-Tuning

In [None]:
# Configure UCE (batch size and target device), then instantiate the model.
uce_config = UCEConfig(
    batch_size=5,
    device=device,
)
uce = UCE(configurer=uce_config)

Prepare data the same way as for scGPT
- Add names for each dataset, as datasets are stored as .npz files and separate files are needed

In [None]:
# Run UCE's preprocessing on each split, giving each one its own name
# ("train" / "validation"). This rebinds `dataset` and `validation_dataset`.
dataset = uce.process_data(
    train_dataset,
    name="train",
    gene_names="gene_name",
)
validation_dataset = uce.process_data(
    test_dataset,
    name="validation",
    gene_names="gene_name",
)

Define the fine-tuning model

In [42]:
# Wrap the pretrained UCE model with a classification head that has
# one output per distinct training label.
uce_fine_tune = UCEFineTuningModel(
    uce_model=uce,
    fine_tuning_head="classification",
    output_size=len(label_set),
)

Fine-tune the model

In [43]:
# Fine-tune UCE; the labels are passed as separate lists alongside
# the processed training and validation inputs.
uce_fine_tune.train(
    train_input_data=dataset,
    train_labels=cell_types_train,
    validation_input_data=validation_dataset,
    validation_labels=cell_types_test,
)

INFO:helical.models.uce.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/1:   0%|          | 1/5069 [00:00<14:38,  5.77it/s, loss=1.79]

Fine-Tuning: epoch 1/1: 100%|██████████| 5069/5069 [12:52<00:00,  6.56it/s, loss=1.12]
Fine-Tuning Validation: 100%|██████████| 1268/1268 [01:16<00:00, 16.49it/s, accuracy=0.473]
INFO:helical.models.uce.fine_tuning_model:Fine-Tuning Complete. Epochs: 1
