# 3. Cell Type Annotation using a Pre-Trained Cell2Sentence Model

In this notebook, we'll:
1. Load a pre-trained C2S model from the HuggingFace Hub.
2. Use the model to predict cell types based on our generated cell sentences.
3. Integrate these predictions back into AnnData.

## Learning Objectives
- Learn how to load a pre-trained LLM for single-cell data.
- Automatically annotate PBMC cells.
- Assess annotation quality by comparing each cell's top genes with known marker genes.

In [None]:
import torch
import os
import random
import numpy as np

# Cell2Sentence imports
import cell2sentence as cs

# Single-cell libraries
import anndata
import scanpy as sc

import tqdm as notebook_tqdm

In [None]:
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seed PyTorch, which runs the model

In [None]:
DATA_PATH = "./data/pbmc3k_final.h5ad"
adata = anndata.read_h5ad(DATA_PATH)

In [None]:
adata

### Converting to Cell2Sentence (CSData)

In [None]:
adata_obs_cols_to_keep = ["cell_type","organism"]

In [None]:
# Create CSData object
arrow_ds, vocabulary = cs.CSData.adata_to_arrow(
    adata=adata, 
    random_state=SEED, 
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep
)

In [None]:
arrow_ds
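For intuition, a cell sentence is a list of gene names ordered from highest to lowest expression. The sketch below illustrates that idea on a toy cell with made-up counts. It is not the `adata_to_arrow` implementation, which may handle normalization, ties, and unexpressed genes differently; the helper `to_cell_sentence` is a hypothetical name used only here.

```python
# Toy illustration only: the counts below are made up, and to_cell_sentence
# is a hypothetical helper, not part of the cell2sentence API.
def to_cell_sentence(expression, delimiter=" "):
    """Order gene names by decreasing expression, dropping unexpressed genes."""
    # Assumption: genes with zero counts do not appear in the sentence.
    expressed = [(gene, count) for gene, count in expression.items() if count > 0]
    # Sort by count (descending); break ties alphabetically so the result is deterministic.
    ranked = sorted(expressed, key=lambda pair: (-pair[1], pair[0]))
    return delimiter.join(gene for gene, _ in ranked)

toy_cell = {"CD3D": 4, "NKG7": 12, "MS4A1": 0, "GZMB": 7, "PRF1": 7}
toy_sentence = to_cell_sentence(toy_cell)
print(toy_sentence)
```

The `sentence_delimiter=' '` argument used above plays the same role as `delimiter` in this sketch.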

For this exercise, we keep only the first 100 genes of each cell sentence, i.e. the 100 most highly expressed genes in each cell.

In [None]:
k = 100  # replace with your desired number of genes

arrow_ds = arrow_ds.map(lambda x: {"cell_sentence": " ".join(x["cell_sentence"].split()[:k])})

In [None]:
sample_idx = 2000
len(arrow_ds[sample_idx]['cell_sentence'].split())
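The truncation applied above can be illustrated on a toy sentence. This is a minimal sketch of the same split, slice, and join logic; `truncate_sentence` is a hypothetical helper name. Note that a cell expressing fewer than `k` genes keeps all of them, so the length check above can return a value smaller than `k`.

```python
def truncate_sentence(sentence, k):
    """Keep only the first k space-separated gene names."""
    return " ".join(sentence.split()[:k])

toy_sentence = "NKG7 GZMB PRF1 CD3D LYZ"
top3 = truncate_sentence(toy_sentence, k=3)

# A sentence with fewer than k genes is returned unchanged.
short = truncate_sentence("NKG7 GZMB", k=3)

print(top3)
print(short)
```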

In [None]:
c2s_save_dir = "./c2s_api_testing"  # C2S dataset will be saved into this directory
c2s_save_name = "PBMC_3K_tutorial2"  # This will be the name of our C2S dataset on disk

In [None]:
cs_data = cs.CSData.csdata_from_arrow(
    arrow_dataset=arrow_ds, 
    vocabulary=vocabulary,
    save_dir=c2s_save_dir,
    save_name=c2s_save_name,
    dataset_backend="arrow"
)

In [None]:
#dir(cs_data)

In [None]:
#cs_data.get_sentence_strings()

## 3.1. Load a Pre-trained Model
We can specify a model from the Hugging Face Hub. For a smaller/faster model:
```
model_name = "vandijklab/pythia-160m-c2s"
```
For a more powerful ~410M parameter model:
```
model_name = "vandijklab/C2S-Pythia-410m-cell-type-prediction"
```

In [None]:
#model_name = "vandijklab/pythia-160m-c2s"
model_name = "vandijklab/C2S-Pythia-410m-cell-type-prediction"

save_dir = "models"  # local folder to store the downloaded model

cs_model = cs.csmodel.CSModel(
    model_name,
    save_dir=save_dir,
    save_name="cs_model"
)
print("Model loaded successfully.")

## 3.2. Predict Cell Types
We'll use a built-in function to predict labels for our dataset. This function will take each cell's top genes (as a sentence) and produce a text label for the cell type.

Note: This step can take a few minutes on CPU. The 160M model is typically manageable; the 410M model selected above will be slower.

In [None]:
# Ensure the 'organism' key is provided if required by the function
pred_labels = cs.tasks.predict_cell_types_of_data(
    cs_data,
    cs_model,
    n_genes=100,  # must match how we created the sentences
)
pred_labels[:10]

In [None]:
adata

We have a list of cell type predictions for each cell. Let's store them in the AnnData object.

In [None]:
# Use the existing AnnData object loaded earlier (from DATA_PATH)

adata.obs['C2S_predicted_celltype'] = pred_labels
adata.obs['C2S_predicted_celltype'].value_counts()

Check the distribution of predicted cell types. For PBMC data, you might see T cell subsets, B cells, monocytes, etc.
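When reference annotations are available, one way to quantify agreement is to tally (annotated, predicted) label pairs. The sketch below uses hypothetical toy labels; in this notebook the two lists would come from `adata.obs["cell_type"]` and `adata.obs["C2S_predicted_celltype"]`. It assumes both columns use the same naming scheme. If the model emits different names for the same population, the labels need to be mapped to a common vocabulary first.

```python
from collections import Counter

# Hypothetical labels for illustration only.
annotated = ["B cell", "NK cell", "NK cell", "CD14 monocyte", "B cell"]
predicted = ["B cell", "NK cell", "CD8 T cell", "CD14 monocyte", "B cell"]

def normalize(label):
    """Lowercase and strip whitespace so trivial formatting differences are ignored."""
    return label.strip().lower()

# Count how often each (annotated, predicted) combination occurs.
pairs = Counter((normalize(a), normalize(p)) for a, p in zip(annotated, predicted))

n_match = sum(count for (a, p), count in pairs.items() if a == p)
agreement = n_match / len(annotated)
print(f"Exact-match agreement: {agreement:.2f}")

# List the disagreements, most frequent first.
for (a, p), count in pairs.most_common():
    if a != p:
        print(f"annotated={a!r} predicted={p!r}: {count}")
```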

## 3.3. Inspect Example Predictions
We can pick a few cells of a predicted type and look at their top genes. For example, if the model predicts `NK cell` for a given cell, we can check whether `NKG7`, `GZMB`, and `PRF1` appear near the top of its cell sentence.

In [None]:
cell_idx = 1773  # arbitrary example (positional index)
# adata.obs is indexed by cell barcode, so use .iloc for positional access
predicted_type = adata.obs['C2S_predicted_celltype'].iloc[cell_idx]
print("Cell index:", cell_idx)
print("Predicted type:", predicted_type)
cell_type = adata.obs["cell_type"].iloc[cell_idx]
print("Annotated type:", cell_type)

In [None]:
SAVE_PATH = "./data/pbmc3k_410m_predictions.h5ad"

In [None]:
adata.write_h5ad(SAVE_PATH)

In [None]:
adata.obs.head(50)

Compare the top genes in the sentence with known markers for that predicted type. This is a quick validation that the model is leveraging real biological signals.
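The comparison above can be sketched as a small helper that reports where each marker falls in a cell sentence. `marker_ranks` is a hypothetical name and the sentence is a toy example; in this notebook the sentence could be taken from `arrow_ds[cell_idx]['cell_sentence']`. The marker list uses the NK genes mentioned earlier plus `KLRD1`, another commonly cited NK marker that is absent here to show the missing case.

```python
def marker_ranks(cell_sentence, markers):
    """Return {marker: 1-based rank in the sentence, or None if absent}."""
    rank_of = {gene: i + 1 for i, gene in enumerate(cell_sentence.split())}
    return {marker: rank_of.get(marker) for marker in markers}

# Toy sentence for illustration only.
toy_sentence = "NKG7 B2M GZMB MALAT1 PRF1 CD3D"
nk_markers = ["NKG7", "GZMB", "PRF1", "KLRD1"]

ranks = marker_ranks(toy_sentence, nk_markers)
print(ranks)
```

Low ranks for several markers of the predicted type are consistent with the prediction; markers that are missing or rank far down are a reason to look more closely at that cell.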

## Exercise:
1. Select different cells from `adata.obs['C2S_predicted_celltype']` categories.
2. Check if the top genes match known marker genes for that type.
3. Generate a UMAP (using the standard scRNA-seq workflow) and color cells by predicted label to visualize the result.


### (Optional) UMAP Visualization
Let's run a quick, standard scRNA-seq analysis to see how the predicted labels cluster.


In [None]:
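# Filtered data is already in adata.
# Standard pipeline: select highly variable genes, scale, PCA, neighbors, UMAP.
# Normalization and log-transform are commented out on the assumption that the
# loaded data is already processed; uncomment them if adata.X holds raw counts.
#sc.pp.normalize_total(adata, target_sum=1e4)
#sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
# Copy the subset so later in-place steps do not operate on a view
adata = adata[:, adata.var.highly_variable].copy()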

sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
sc.tl.umap(adata)

sc.pl.umap(adata, color=["C2S_predicted_celltype"], wspace=0.4, ncols=1)


You should see clusters that (hopefully) align with typical PBMC subsets. The predicted labels from the LLM can be visually inspected on the UMAP plot.

## Next Steps
[Go to Notebook 4 →](./4_Finetuning_on_New_Datasets.ipynb) to learn about generating synthetic cells, fine-tuning, and more advanced use cases.