# Environment setup

In [None]:
import warnings
warnings.filterwarnings("ignore")

import json
import os
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.io import mmread # For reading .mtx files
from pathlib import Path # For better path handling

# Geneformer specific imports
import datasets
from geneformer import DataCollatorForCellClassification
from geneformer import EmbExtractor
from geneformer import TranscriptomeTokenizer
from transformers import BertForSequenceClassification, Trainer

# PyTorch optimization
from torch import set_float32_matmul_precision
set_float32_matmul_precision('medium')

print("--- Step 1: Define Data Paths and Download External Models ---")

# Define your raw data path
raw_data_path = Path("/projects/bioinformatics/DB/IMMUCan/data_raw")

# --- Download and extract geneformer model (if not already local) ---
print("Downloading and extracting Geneformer model...")
geneformer_model_url = 'https://www.dropbox.com/scl/fi/4edmbf7fik0q8kzyq2pef/fine_tuned_geneformer.tar.gz?rlkey=v0ux8v9a3qe8il6o7bowxep8c&st=6ar0ptjg&dl=0'
!wget '{geneformer_model_url}' -O fine_tuned_geneformer.tar.gz
!tar -xzf fine_tuned_geneformer.tar.gz

# --- Download gene list (if not already local) ---
print("Downloading gene list...")
gene_list_url = 'https://www.dropbox.com/scl/fi/brauikmesjfworl67cxov/cpdb_genelist.csv?rlkey=55ankib03njbf9tkci8tgzqc6&st=ezcv94sg&dl=0'
!wget '{gene_list_url}' -O cpdb_genelist.csv

print("\n--- Step 2: Load Your Raw Data ---")

# matrix.mtx is read as (genes x cells) and converted to CSR.
matrix = mmread(raw_data_path / 'matrix.mtx').tocsr()
print("Original Matrix shape (genes x cells):", matrix.shape) # Confirm this

# genes.tsv: column 0 = gene ID, optional column 1 = gene symbol.
# barcodes.tsv: column 0 = cell barcode.
genes = pd.read_csv(raw_data_path / 'genes.tsv', sep='\t', header=None)
barcodes = pd.read_csv(raw_data_path / 'barcodes.tsv', sep='\t', header=None)

# There must be one barcode per matrix column and one gene per matrix row.
# Check before transposing so a mismatch fails with a clear message.
if matrix.shape[1] != len(barcodes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[1]} columns (cells), but barcodes.tsv has {len(barcodes)} rows. Please check your data files.")
if matrix.shape[0] != len(genes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[0]} rows (genes), but genes.tsv has {len(genes)} rows. Please check your data files.")

# Transpose to (cells x genes), the orientation AnnData expects.
adata = sc.AnnData(matrix.T,
                   obs=pd.DataFrame(index=barcodes[0].values),
                   var=pd.DataFrame(index=genes[0].values))

# Add gene symbols if available (second column of genes.tsv).
if len(genes.columns) > 1:
    adata.var['gene_symbol'] = genes[1].values

# Geneformer expects Ensembl IDs *without* a version suffix
# (ENSG00000139618, not ENSG00000139618.12), so strip any trailing ".<digits>"
# for the column the tokenizer reads. The var index keeps the original IDs.
adata.var["ensembl_id"] = adata.var.index.str.replace(r"\.\d+$", "", regex=True)

# A sparse row sum comes back as an (n_cells, 1) matrix; flatten it to 1-D so
# it is stored as an ordinary obs column.
adata.obs["n_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()
# Row position of each cell; carried through tokenization to join predictions back.
adata.obs["joinid"] = list(range(adata.n_obs))

print(f"AnnData object created with shape (cells x genes): {adata.shape}")
print("AnnData .obs head:")
print(adata.obs.head())
print("\nAnnData .var head:")
print(adata.var.head())

print("\n--- Step 3: Save AnnData and Tokenize Data for Geneformer ---")

h5ad_dir = "./data/h5ad/" # Use a relative path for output in your current working directory
os.makedirs(h5ad_dir, exist_ok=True)

# tokenize_data() below is given the whole directory, so any other .h5ad file
# left in it (an earlier run, or another cell writing to ./data/h5ad/) would be
# tokenized into the same dataset. Remove stale .h5ad files first.
for stale_file in list(Path(h5ad_dir).glob("*.h5ad")):
    stale_file.unlink()

adata.write(h5ad_dir + "my_immu_can_data.h5ad")
print(f"AnnData saved to {h5ad_dir}my_immu_can_data.h5ad")

token_dir = "data/tokenized_data/" # Use a relative path for output
os.makedirs(token_dir, exist_ok=True)

# Carry obs["joinid"] into the tokenized dataset so predictions can be mapped
# back onto the cells of `adata`.
tokenizer = TranscriptomeTokenizer(custom_attr_name_dict={"joinid": "joinid"})
print(f"Tokenizing data from {h5ad_dir} to {token_dir}...")
tokenizer.tokenize_data(
    data_directory=h5ad_dir,
    output_directory=token_dir,
    output_prefix="my_immu_can", # Use a unique prefix for your data
    file_format="h5ad",
)
print("Data tokenization complete.")


print("\n--- Step 4: Load Geneformer Model and Make Predictions ---")

model_dir = "./fine_tuned_geneformer/" # Path where the downloaded model was extracted
# Maps class id (as a string) -> cell subclass name; ships with the model.
label_mapping_dict_file = os.path.join(model_dir, "label_to_cell_subclass.json")

if not os.path.exists(label_mapping_dict_file):
    raise FileNotFoundError(f"Label mapping file not found: {label_mapping_dict_file}. "
                            "Please ensure the Geneformer model was extracted correctly "
                            "and contains this file.")

with open(label_mapping_dict_file) as fp:
    label_mapping_dict = json.load(fp)

print("First 5 entries of label mapping:")
for k in list(label_mapping_dict.keys())[:5]:
    print(k, ': ', label_mapping_dict[k])

dataset = datasets.load_from_disk(token_dir + "my_immu_can.dataset") # Load your tokenized data
print(f"Loaded tokenized dataset with {len(dataset)} cells.")

# Geneformer's trainer expects a "label" column; the dummy 0s do not influence
# the predicted logits.
dataset = dataset.add_column("label", [0] * len(dataset))

print("Loading fine-tuned Geneformer model...")
model = BertForSequenceClassification.from_pretrained(model_dir)

trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification())

print("Making predictions with Geneformer...")
predictions = trainer.predict(dataset)
print("Predictions complete.")

# Logits: one row per tokenized cell, one column per class.
logits = np.asarray(predictions.predictions)
rows = np.arange(logits.shape[0])
predicted_label_ids = np.argmax(logits, axis=1)
predicted_logits = logits[rows, predicted_label_ids]
predicted_labels = [label_mapping_dict[str(i)] for i in predicted_label_ids]

# Softmax over all classes (max-subtracted for numerical stability) gives the
# probability of the predicted class. A sigmoid of a single logit ignores the
# other classes and is not a probability for a multi-class classifier.
class_probs = np.exp(logits - logits.max(axis=1, keepdims=True))
class_probs /= class_probs.sum(axis=1, keepdims=True)
predicted_probs = class_probs[rows, predicted_label_ids]

# Join predictions back through "joinid" instead of assuming the tokenized
# rows line up 1:1 with adata.obs; cells missing from the dataset stay None/NaN.
joinids = np.array(list(dataset["joinid"]), dtype=int)
labels_per_cell = np.full(adata.n_obs, None, dtype=object)
labels_per_cell[joinids] = predicted_labels
probs_per_cell = np.full(adata.n_obs, np.nan)
probs_per_cell[joinids] = predicted_probs

adata.obs["predicted_cell_subclass"] = labels_per_cell
adata.obs["predicted_cell_subclass_probability"] = probs_per_cell

print("\n--- Step 5: Standard Single-Cell Data Preprocessing and Analysis (Scanpy) ---")

# Library-size normalisation to 10,000 counts per cell, then log(1 + x).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Flag highly variable genes with the dispersion thresholds below.
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)

hvg_mask = adata.var["highly_variable"]
if not hvg_mask.any():
    print("No highly variable genes found with current parameters. Skipping subsetting.")
else:
    # Keep only the flagged genes for the embedding steps that follow.
    adata = adata[:, hvg_mask]
    print(f"Subsetted to {adata.shape[1]} highly variable genes.")

# Scale (clipped at 10) -> PCA -> kNN graph -> UMAP embedding.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver="arpack")
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)  # tune n_pcs for your data
sc.tl.umap(adata)

print("Running Leiden clustering...")
sc.tl.leiden(adata)

print("Scanpy preprocessing and dimensionality reduction complete.")

print("\n--- Step 6: Visualize Results ---")

print("Generating UMAP plots...")
# With save="<suffix>", scanpy writes "umap<suffix>" into sc.settings.figdir
# (./figures/ by default), not into the working directory.

# UMAP colored by Leiden clusters
sc.pl.umap(adata, color="leiden", title="Leiden Clustering of IMMUCan Data", show=False, save="_leiden.png")

# UMAP colored by Geneformer predicted cell types and their probabilities
sc.pl.umap(
    adata,
    color=["predicted_cell_subclass_probability", "predicted_cell_subclass"],
    title="Predicted Geneformer Annotations for IMMUCan Data",
    show=False, save="_geneformer_predictions.png"
)

# Leiden clusters and Geneformer predictions side by side for comparison
sc.pl.umap(
    adata,
    color=["leiden", "predicted_cell_subclass"],
    legend_loc='on data',
    title="Comparison: Leiden vs. Geneformer (IMMUCan Data)",
    show=True, save="_comparison.png"
)

print(f"Analysis complete. UMAP plots saved as PNG files in {sc.settings.figdir}.")

In [None]:
import warnings
warnings.filterwarnings("ignore")

import json
import os
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.io import mmread # For reading .mtx files
from pathlib import Path # For better path handling

# Geneformer specific imports
import datasets
from geneformer import DataCollatorForCellClassification
from geneformer import EmbExtractor
from geneformer import TranscriptomeTokenizer
from transformers import BertForSequenceClassification, Trainer

# PyTorch optimization
from torch import set_float32_matmul_precision
set_float32_matmul_precision('medium')

print("--- Step 1: Define Data Paths and Download External Models ---")

# Define your raw data path
raw_data_path = Path("/projects/bioinformatics/DB/IMMUCan/data_raw")

# --- Download and extract geneformer model (if not already local) ---
print("Downloading and extracting Geneformer model...")
geneformer_model_url = 'https://www.dropbox.com/scl/fi/4edmbf7fik0q8kzyq2pef/fine_tuned_geneformer.tar.gz?rlkey=v0ux8v9a3qe8il6o7bowxep8c&st=6ar0ptjg&dl=0'
if not os.path.exists("./fine_tuned_geneformer"):
    !wget '{geneformer_model_url}' -O fine_tuned_geneformer.tar.gz
    !tar -xzf fine_tuned_geneformer.tar.gz
else:
    print("Geneformer model directory already exists, skipping download/extract.")


# --- Download gene list (if not already local) ---
print("Downloading gene list...")
gene_list_url = 'https://www.dropbox.com/scl/fi/brauikmesjfworl67cxov/cpdb_genelist.csv?rlkey=55ankib03njbf9tkci8tgzqc6&st=ezcv94sg&dl=0'
if not os.path.exists("cpdb_genelist.csv"):
    !wget '{gene_list_url}' -O cpdb_genelist.csv
else:
    print("cpdb_genelist.csv already exists, skipping download.")

print("\n--- Step 2: Load Your Raw Data ---")

# matrix.mtx is read as (genes x cells) and converted to CSR.
matrix = mmread(raw_data_path / 'matrix.mtx').tocsr()
print("Original Matrix shape (genes x cells):", matrix.shape)

# genes.tsv: column 0 = gene ID, optional column 1 = gene symbol.
# barcodes.tsv: column 0 = cell barcode.
genes = pd.read_csv(raw_data_path / 'genes.tsv', sep='\t', header=None)
barcodes = pd.read_csv(raw_data_path / 'barcodes.tsv', sep='\t', header=None)

# --- CRITICAL FIXES FOR AnnData CREATION ---
# One barcode per matrix column and one gene per matrix row are required.
if matrix.shape[1] != len(barcodes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[1]} columns (cells), but barcodes.tsv has {len(barcodes)} rows. Please check your data files.")
if matrix.shape[0] != len(genes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[0]} rows (genes), but genes.tsv has {len(genes)} rows. Please check your data files.")

# The var index keeps the IDs exactly as listed in genes.tsv. The 'ensembl_id'
# column read by Geneformer has any version suffix removed
# (ENSG00000139618.12 -> ENSG00000139618) because Geneformer expects IDs
# without one.
gene_ids = genes[0].astype(str)
var_df = pd.DataFrame(index=gene_ids.values)
var_df['ensembl_id'] = gene_ids.str.replace(r'\.\d+$', '', regex=True).values

if len(genes.columns) > 1:
    var_df['gene_symbol'] = genes[1].values
else:
    print("Warning: genes.tsv only has one column. 'gene_symbol' will not be set.")

adata = sc.AnnData(matrix.T, # Transpose to (cells x genes)
                   obs=pd.DataFrame(index=barcodes[0].values), # Cell barcodes as obs index
                   var=var_df)

# Row sums of a sparse matrix come back as an (n_cells, 1) matrix.
# np.asarray(...).ravel() yields a 1-D array whether X is sparse or dense
# (`.A` only exists on np.matrix).
adata.obs["n_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()
# Row position of each cell; carried through tokenization to join predictions back.
adata.obs["joinid"] = list(range(adata.n_obs))


print("\n--- DEBUGGING: Check adata.var and adata.obs structure before tokenization ---")
print("adata.var columns:", adata.var.columns.tolist())
print("adata.var head:")
print(adata.var.head())
if 'ensembl_id' in adata.var.columns:
    print("\nFirst 5 ensembl_ids in adata.var['ensembl_id']:")
    print(adata.var['ensembl_id'].head().tolist())
else:
    print("\nERROR: 'ensembl_id' column is NOT in adata.var. Geneformer cannot proceed.")
    raise KeyError("'ensembl_id' column is missing from adata.var. Geneformer cannot proceed.")

print(f"Data type of adata.var['ensembl_id']: {adata.var['ensembl_id'].dtype}")

print("\nadata.obs columns:", adata.obs.columns.tolist())
print("adata.obs head:")
print(adata.obs.head())
if 'n_counts' in adata.obs.columns:
    print("\nFirst 5 n_counts in adata.obs['n_counts']:")
    print(adata.obs['n_counts'].head().tolist())
else:
    print("\nERROR: 'n_counts' column is NOT in adata.obs. Geneformer cannot proceed.")
    raise KeyError("'n_counts' column is missing from adata.obs. Geneformer cannot proceed.")

print(f"Data type of adata.obs['n_counts']: {adata.obs['n_counts'].dtype}")


print(f"AnnData object created with shape (cells x genes): {adata.shape}")


print("\n--- Step 3: Save AnnData and Tokenize Data for Geneformer ---")

h5ad_dir = "./data/h5ad/"
token_dir = "data/tokenized_data/"

os.makedirs(h5ad_dir, exist_ok=True)
os.makedirs(token_dir, exist_ok=True)

# tokenize_data() is given the whole h5ad_dir, so any other .h5ad file left in
# it would be tokenized into the same dataset. Only *.h5ad files are removed.
# Calling os.remove() on every entry fails on subdirectories and on lingering
# NFS ".nfs*" placeholder files.
stale_files = list(Path(h5ad_dir).glob("*.h5ad"))
if stale_files:
    print(f"Clearing contents of {h5ad_dir} to avoid tokenizing old files...")
    for stale_file in stale_files:
        stale_file.unlink()


adata.write(h5ad_dir + "my_immu_can_data.h5ad")
print(f"AnnData saved to {h5ad_dir}my_immu_can_data.h5ad")

# Carry obs["joinid"] into the tokenized dataset so predictions can be mapped back.
tokenizer = TranscriptomeTokenizer(custom_attr_name_dict={"joinid": "joinid"})
print(f"Tokenizing data from {h5ad_dir} (specifically, 'my_immu_can_data.h5ad')...")

tokenizer.tokenize_data(
    data_directory=h5ad_dir,
    output_directory=token_dir,
    output_prefix="my_immu_can",
    file_format="h5ad",
)
print("Data tokenization complete.")


print("\n--- Step 4: Load Geneformer Model and Make Predictions ---")

model_dir = "./fine_tuned_geneformer/"
# Maps class id (as a string) -> cell subclass name; ships with the model.
label_mapping_dict_file = os.path.join(model_dir, "label_to_cell_subclass.json")

if not os.path.exists(label_mapping_dict_file):
    raise FileNotFoundError(f"Label mapping file not found: {label_mapping_dict_file}. "
                            "Please ensure the Geneformer model was extracted correctly "
                            "and contains this file.")

with open(label_mapping_dict_file) as fp:
    label_mapping_dict = json.load(fp)

print("First 5 entries of label mapping:")
for k in list(label_mapping_dict.keys())[:5]:
    print(k, ': ', label_mapping_dict[k])

dataset = datasets.load_from_disk(token_dir + "my_immu_can.dataset")
print(f"Loaded tokenized dataset with {len(dataset)} cells.")

# Geneformer's trainer expects a "label" column; the dummy 0s do not influence
# the predicted logits.
dataset = dataset.add_column("label", [0] * len(dataset))

print("Loading fine-tuned Geneformer model...")
model = BertForSequenceClassification.from_pretrained(model_dir)

trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification())

print("Making predictions with Geneformer...")
predictions = trainer.predict(dataset)
print("Predictions complete.")

# Logits: one row per tokenized cell, one column per class.
logits = np.asarray(predictions.predictions)
rows = np.arange(logits.shape[0])
predicted_label_ids = np.argmax(logits, axis=1)
predicted_logits = logits[rows, predicted_label_ids]
predicted_labels = [label_mapping_dict[str(i)] for i in predicted_label_ids]

# Softmax over all classes (max-subtracted for numerical stability) gives the
# probability of the predicted class. A sigmoid of a single logit ignores the
# other classes and is not a probability for a multi-class classifier.
class_probs = np.exp(logits - logits.max(axis=1, keepdims=True))
class_probs /= class_probs.sum(axis=1, keepdims=True)
predicted_probs = class_probs[rows, predicted_label_ids]

# Join predictions back through "joinid" instead of assuming the tokenized
# rows line up 1:1 with adata.obs; cells missing from the dataset stay None/NaN.
joinids = np.array(list(dataset["joinid"]), dtype=int)
labels_per_cell = np.full(adata.n_obs, None, dtype=object)
labels_per_cell[joinids] = predicted_labels
probs_per_cell = np.full(adata.n_obs, np.nan)
probs_per_cell[joinids] = predicted_probs

adata.obs["predicted_cell_subclass"] = labels_per_cell
adata.obs["predicted_cell_subclass_probability"] = probs_per_cell

print("\n--- Step 5: Standard Single-Cell Data Preprocessing and Analysis (Scanpy) ---")

# Library-size normalisation to 10,000 counts per cell, then log(1 + x).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Flag highly variable genes with the dispersion thresholds below.
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)

hvg_mask = adata.var["highly_variable"]
if not hvg_mask.any():
    print("No highly variable genes found with current parameters. Skipping subsetting.")
else:
    # Keep only the flagged genes for the embedding steps that follow.
    adata = adata[:, hvg_mask]
    print(f"Subsetted to {adata.shape[1]} highly variable genes.")

# Scale (clipped at 10) -> PCA -> kNN graph -> UMAP embedding.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver="arpack")
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)  # tune n_pcs for your data
sc.tl.umap(adata)

print("Running Leiden clustering...")
sc.tl.leiden(adata)

print("Scanpy preprocessing and dimensionality reduction complete.")

print("\n--- Step 6: Visualize Results ---")

print("Generating UMAP plots...")
# With save="<suffix>", scanpy writes "umap<suffix>" into sc.settings.figdir
# (./figures/ by default), not into the working directory.
sc.pl.umap(adata, color="leiden", title="Leiden Clustering of IMMUCan Data", show=False, save="_leiden.png")

sc.pl.umap(
    adata,
    color=["predicted_cell_subclass_probability", "predicted_cell_subclass"],
    title="Predicted Geneformer Annotations for IMMUCan Data",
    show=False, save="_geneformer_predictions.png"
)

sc.pl.umap(
    adata,
    color=["leiden", "predicted_cell_subclass"],
    legend_loc='on data',
    title="Comparison: Leiden vs. Geneformer (IMMUCan Data)",
    show=True, save="_comparison.png"
)

print(f"Analysis complete. UMAP plots saved as PNG files in {sc.settings.figdir}.")

# Models download

Models hosted on https://cellxgene.cziscience.com/census-models are normally downloaded from Amazon Web Services (AWS). Since aws-cli is not installed on Google Colab, I have temporarily uploaded them to Dropbox.

In [None]:
# Download and extract geneformer model
!wget 'https://www.dropbox.com/scl/fi/4edmbf7fik0q8kzyq2pef/fine_tuned_geneformer.tar.gz?rlkey=v0ux8v9a3qe8il6o7bowxep8c&st=6ar0ptjg&dl=0' -O fine_tuned_geneformer.tar.gz
!tar -xzf fine_tuned_geneformer.tar.gz

In [None]:
# List the extracted model directory to confirm the download worked; the
# pipeline later loads label_to_cell_subclass.json and the model from here.
!ls fine_tuned_geneformer

# Data download

In [None]:

# Create ./data, the parent of the data/h5ad/ and data/tokenized_data/ output
# folders used by the pipeline cells; -p makes this a no-op if it exists.
!mkdir -p data


# Geneformer for cell class prediction and data projection

This notebook provides examples to utilize the CELLxGENE collaboration fine-tuned Geneformer model with user data. For more information on the model please refer to the [Census model page](https://cellxgene.cziscience.com/census-models).

**IMPORTANT:** This tutorial requires cellxgene-census package version 1.9.1 or later.

**Contents**

1. Requirements.
1. Preparing data and model.
1. Using the Geneformer fine-tuned model for **cell subclass inference**.
1. Using the Geneformer fine-tuned model for **data projection**.

> ⚠️ Note: "cell subclass" is a high-level grouping of cell types as annotated in CELLxGENE Discover via the CL ontology; see [https://cellxgene.cziscience.com/collections](https://cellxgene.cziscience.com/collections).

> ⚠️ Note that the Census RNA data includes duplicate cells present across multiple datasets. Duplicate cells can be filtered in or out using the cell metadata variable `is_primary_data` which is described in the [Census schema](https://github.com/chanzuckerberg/cellxgene-census/blob/main/docs/cellxgene_census_schema.md#repeated-data).

### Downloading the fine-tuned Geneformer model

### Importing required packages

Finally, all the required packages are loaded.

In [None]:
# Silence library warnings before the imports below can emit any.
import warnings

warnings.filterwarnings("ignore")

# Standard library
import json
import os
from pathlib import Path

# Scientific stack
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.io import mmread

# Geneformer / Hugging Face (cellxgene_census import intentionally disabled)
import datasets
from geneformer import DataCollatorForCellClassification, EmbExtractor, TranscriptomeTokenizer
from transformers import BertForSequenceClassification, Trainer

# Allow PyTorch to trade float32 matmul precision for speed ("medium").
from torch import set_float32_matmul_precision

set_float32_matmul_precision("medium")

## Preparing data and model

### Preparing single-cell data

Let's load the test data. In preparation to use with Geneformer we do the following:

- Set the `var` index to the ENSEMBL gene ID and store it in the `var` column `"ensembl_id"`
  - e.g. `ENSG00000139618` (*without* a version number suffix)
- Add read counts to the `obs` column `"n_counts"`
- Add an ID column, used later to join predictions back to the cells, as the `obs` column `"joinid"`

Then we write the resulting H5AD file to disk.

Now we can tokenize the test data using Geneformer's tokenizer, while keeping track of `"joinid"` for future joining.

In [None]:
import warnings
warnings.filterwarnings("ignore")

import json
import os
import shutil # Import shutil for rmtree
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.io import mmread # For reading .mtx files
from pathlib import Path # For better path handling

# Geneformer specific imports
import datasets
from geneformer import DataCollatorForCellClassification
from geneformer import EmbExtractor
from geneformer import TranscriptomeTokenizer
from transformers import BertForSequenceClassification, Trainer

# PyTorch optimization
from torch import set_float32_matmul_precision
set_float32_matmul_precision('medium')

print("--- Step 1: Define Data Paths and Download External Models ---")

# Define your raw data path
raw_data_path = Path("/projects/bioinformatics/DB/IMMUCan/data_raw")

# --- Download and extract geneformer model (if not already local) ---
print("Downloading and extracting Geneformer model...")
geneformer_model_url = 'https://www.dropbox.com/scl/fi/4edmbf7fik0q8kzyq2pef/fine_tuned_geneformer.tar.gz?rlkey=v0ux8v9a3qe8il6o7bowxep8c&st=6ar0ptjg&dl=0'
if not os.path.exists("./fine_tuned_geneformer"):
    !wget '{geneformer_model_url}' -O fine_tuned_geneformer.tar.gz
    !tar -xzf fine_tuned_geneformer.tar.gz
else:
    print("Geneformer model directory already exists, skipping download/extract.")


# --- Download gene list (if not already local) ---
print("Downloading gene list...")
gene_list_url = 'https://www.dropbox.com/scl/fi/brauikmesjfworl67cxov/cpdb_genelist.csv?rlkey=55ankib03njbf9tkci8tgzqc6&st=ezcv94sg&dl=0'
if not os.path.exists("cpdb_genelist.csv"):
    !wget '{gene_list_url}' -O cpdb_genelist.csv
else:
    print("cpdb_genelist.csv already exists, skipping download.")

print("\n--- Step 2: Load Your Raw Data ---")

# matrix.mtx is read as (genes x cells) and converted to CSR.
matrix = mmread(raw_data_path / 'matrix.mtx').tocsr()
print("Original Matrix shape (genes x cells):", matrix.shape)

# genes.tsv: column 0 = gene ID, optional column 1 = gene symbol.
# barcodes.tsv: column 0 = cell barcode.
genes = pd.read_csv(raw_data_path / 'genes.tsv', sep='\t', header=None)
barcodes = pd.read_csv(raw_data_path / 'barcodes.tsv', sep='\t', header=None)

# --- CRITICAL FIXES FOR AnnData CREATION ---
# One barcode per matrix column and one gene per matrix row are required.
if matrix.shape[1] != len(barcodes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[1]} columns (cells), but barcodes.tsv has {len(barcodes)} rows. Please check your data files.")
if matrix.shape[0] != len(genes):
    raise ValueError(f"Mismatch: Matrix has {matrix.shape[0]} rows (genes), but genes.tsv has {len(genes)} rows. Please check your data files.")

# The var index keeps the IDs exactly as listed in genes.tsv. The 'ensembl_id'
# column read by Geneformer has any version suffix removed
# (ENSG00000139618.12 -> ENSG00000139618) because Geneformer expects IDs
# without one.
gene_ids = genes[0].astype(str)
var_df = pd.DataFrame(index=gene_ids.values)
var_df['ensembl_id'] = gene_ids.str.replace(r'\.\d+$', '', regex=True).values

if len(genes.columns) > 1:
    var_df['gene_symbol'] = genes[1].values
else:
    print("Warning: genes.tsv only has one column. 'gene_symbol' will not be set.")

adata = sc.AnnData(matrix.T, # Transpose to (cells x genes)
                   obs=pd.DataFrame(index=barcodes[0].values), # Cell barcodes as obs index
                   var=var_df)

# Row sums of a sparse matrix come back as an (n_cells, 1) matrix.
# np.asarray(...).ravel() yields a 1-D array whether X is sparse or dense
# (`.A` only exists on np.matrix).
adata.obs["n_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()
# Row position of each cell; carried through tokenization to join predictions back.
adata.obs["joinid"] = list(range(adata.n_obs))


print("\n--- DEBUGGING: Check adata.var and adata.obs structure before tokenization ---")
print("adata.var columns:", adata.var.columns.tolist())
print("adata.var head:")
print(adata.var.head())
if 'ensembl_id' in adata.var.columns:
    print("\nFirst 5 ensembl_ids in adata.var['ensembl_id']:")
    print(adata.var['ensembl_id'].head().tolist())
else:
    print("\nERROR: 'ensembl_id' column is NOT in adata.var. Geneformer cannot proceed.")
    raise KeyError("'ensembl_id' column is missing from adata.var. Geneformer cannot proceed.")

print(f"Data type of adata.var['ensembl_id']: {adata.var['ensembl_id'].dtype}")

print("\nadata.obs columns:", adata.obs.columns.tolist())
print("adata.obs head:")
print(adata.obs.head())
if 'n_counts' in adata.obs.columns:
    print("\nFirst 5 n_counts in adata.obs['n_counts']:")
    print(adata.obs['n_counts'].head().tolist())
else:
    print("\nERROR: 'n_counts' column is NOT in adata.obs. Geneformer cannot proceed.")
    raise KeyError("'n_counts' column is missing from adata.obs. Geneformer cannot proceed.")

print(f"Data type of adata.obs['n_counts']: {adata.obs['n_counts'].dtype}")


print(f"AnnData object created with shape (cells x genes): {adata.shape}")


print("\n--- Step 3: Save AnnData and Tokenize Data for Geneformer ---")

h5ad_dir = "./data/h5ad/"
token_dir = "data/tokenized_data/"

os.makedirs(h5ad_dir, exist_ok=True)
os.makedirs(token_dir, exist_ok=True)

# tokenize_data() is given the whole h5ad_dir, so any other .h5ad file left in
# it would be tokenized into the same dataset. Only *.h5ad files are removed:
# deleting the whole tree also trips over lingering NFS ".nfs*" lock files.
# Failures raise rather than call exit(). exit() would shut down the notebook
# kernel, and in a script it would report success (status 0).
stale_files = list(Path(h5ad_dir).glob("*.h5ad"))
if stale_files:
    print(f"Clearing stale .h5ad files from {h5ad_dir}...")
    try:
        for stale_file in stale_files:
            stale_file.unlink()
    except OSError as e:
        raise RuntimeError(
            f"Could not clear {h5ad_dir} due to an OS error: {e}. "
            "This is often caused by lingering file locks, especially on network drives (NFS). "
            "Restart your Python kernel/session or delete the stale .h5ad files manually, "
            "then re-run this cell; a clean state is required for tokenization."
        ) from e


adata.write(h5ad_dir + "my_immu_can_data.h5ad")
print(f"AnnData saved to {h5ad_dir}my_immu_can_data.h5ad")

# Carry obs["joinid"] into the tokenized dataset so predictions can be mapped back.
tokenizer = TranscriptomeTokenizer(custom_attr_name_dict={"joinid": "joinid"})
print(f"Tokenizing data from {h5ad_dir} (specifically, 'my_immu_can_data.h5ad')...")

tokenizer.tokenize_data(
    data_directory=h5ad_dir,
    output_directory=token_dir,
    output_prefix="my_immu_can",
    file_format="h5ad",
)
print("Data tokenization complete.")


print("\n--- Step 4: Load Geneformer Model and Make Predictions ---")

model_dir = "./fine_tuned_geneformer/"
# Maps class id (as a string) -> cell subclass name; ships with the model.
label_mapping_dict_file = os.path.join(model_dir, "label_to_cell_subclass.json")

if not os.path.exists(label_mapping_dict_file):
    raise FileNotFoundError(f"Label mapping file not found: {label_mapping_dict_file}. "
                            "Please ensure the Geneformer model was extracted correctly "
                            "and contains this file.")

with open(label_mapping_dict_file) as fp:
    label_mapping_dict = json.load(fp)

print("First 5 entries of label mapping:")
for k in list(label_mapping_dict.keys())[:5]:
    print(k, ': ', label_mapping_dict[k])

dataset = datasets.load_from_disk(token_dir + "my_immu_can.dataset")
print(f"Loaded tokenized dataset with {len(dataset)} cells.")

# Geneformer's trainer expects a "label" column; the dummy 0s do not influence
# the predicted logits.
dataset = dataset.add_column("label", [0] * len(dataset))

print("Loading fine-tuned Geneformer model...")
model = BertForSequenceClassification.from_pretrained(model_dir)

trainer = Trainer(model=model, data_collator=DataCollatorForCellClassification())

print("Making predictions with Geneformer...")
predictions = trainer.predict(dataset)
print("Predictions complete.")

# Logits: one row per tokenized cell, one column per class.
logits = np.asarray(predictions.predictions)
rows = np.arange(logits.shape[0])
predicted_label_ids = np.argmax(logits, axis=1)
predicted_logits = logits[rows, predicted_label_ids]
predicted_labels = [label_mapping_dict[str(i)] for i in predicted_label_ids]

# Softmax over all classes (max-subtracted for numerical stability) gives the
# probability of the predicted class. A sigmoid of a single logit ignores the
# other classes and is not a probability for a multi-class classifier.
class_probs = np.exp(logits - logits.max(axis=1, keepdims=True))
class_probs /= class_probs.sum(axis=1, keepdims=True)
predicted_probs = class_probs[rows, predicted_label_ids]

# Join predictions back through "joinid" instead of assuming the tokenized
# rows line up 1:1 with adata.obs; cells missing from the dataset stay None/NaN.
joinids = np.array(list(dataset["joinid"]), dtype=int)
labels_per_cell = np.full(adata.n_obs, None, dtype=object)
labels_per_cell[joinids] = predicted_labels
probs_per_cell = np.full(adata.n_obs, np.nan)
probs_per_cell[joinids] = predicted_probs

adata.obs["predicted_cell_subclass"] = labels_per_cell
adata.obs["predicted_cell_subclass_probability"] = probs_per_cell

print("\n--- Step 5: Standard Single-Cell Data Preprocessing and Analysis (Scanpy) ---")

# Library-size normalisation to 10,000 counts per cell, then log(1 + x).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Flag highly variable genes with the dispersion thresholds below.
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)

hvg_mask = adata.var["highly_variable"]
if not hvg_mask.any():
    print("No highly variable genes found with current parameters. Skipping subsetting.")
else:
    # Keep only the flagged genes for the embedding steps that follow.
    adata = adata[:, hvg_mask]
    print(f"Subsetted to {adata.shape[1]} highly variable genes.")

# Scale (clipped at 10) -> PCA -> kNN graph -> UMAP embedding.
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver="arpack")
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)  # tune n_pcs for your data
sc.tl.umap(adata)

print("Running Leiden clustering...")
sc.tl.leiden(adata)

print("Scanpy preprocessing and dimensionality reduction complete.")

print("\n--- Step 6: Visualize Results ---")

print("Generating UMAP plots...")
# With save="<suffix>", scanpy writes "umap<suffix>" into sc.settings.figdir
# (./figures/ by default), not into the working directory.
sc.pl.umap(adata, color="leiden", title="Leiden Clustering of IMMUCan Data", show=False, save="_leiden.png")

sc.pl.umap(
    adata,
    color=["predicted_cell_subclass_probability", "predicted_cell_subclass"],
    title="Predicted Geneformer Annotations for IMMUCan Data",
    show=False, save="_geneformer_predictions.png"
)

sc.pl.umap(
    adata,
    color=["leiden", "predicted_cell_subclass"],
    legend_loc='on data',
    title="Comparison: Leiden vs. Geneformer (IMMUCan Data)",
    show=True, save="_comparison.png"
)

print(f"Analysis complete. UMAP plots saved as PNG files in {sc.settings.figdir}.")

In [None]:
import pandas as pd
import anndata as ad
from scipy.io import mmread
from pathlib import Path
import re # For regular expressions to clean ENSEMBL IDs
import os # For directory operations

# --- Configuration ---
# Define the path to your raw data directory
# IMPORTANT: Replace '/projects/bioinformatics/DB/IMMUCan/data_raw' with your actual path
raw_data_path = Path('/projects/bioinformatics/DB/IMMUCan/data_raw')

# Define the output directory and file name for the prepared H5AD file
h5ad_dir = "./data/h5ad/"
output_file_name = "pbmcs.h5ad" # Using the name from your example
output_file_path = Path(h5ad_dir) / output_file_name
# --- End Configuration ---


print(f"Loading data from: {raw_data_path}")

# Load the data matrix (stored on disk as genes x cells).
# Errors are reported and then re-raised. Calling exit() here would instead shut
# down a Jupyter kernel, or end a plain script with a *success* status.
try:
    matrix = mmread(raw_data_path / 'matrix.mtx').tocsr()
    print("Matrix shape:", matrix.shape)
except FileNotFoundError:
    print(f"Error: matrix.mtx not found at {raw_data_path / 'matrix.mtx'}")
    raise

# Loading gene names and cell barcodes
try:
    # Assuming genes.tsv has two columns: [0] ENSEMBL ID with version, [1] Gene Symbol
    genes = pd.read_csv(raw_data_path / 'genes.tsv', sep='\t', header=None)
    # Assuming barcodes.tsv has one column: [0] Cell Barcode
    barcodes = pd.read_csv(raw_data_path / 'barcodes.tsv', sep='\t', header=None)
    print(f"Loaded {len(genes)} genes and {len(barcodes)} barcodes.")
except FileNotFoundError as e:
    print(f"Error loading genes.tsv or barcodes.tsv: {e}")
    raise

# --- Data Preparation Steps ---

# Fail early with a clear message if the matrix does not line up with the gene
# and barcode tables (matrix.mtx is read as genes x cells).
n_genes_in_matrix, n_cells_in_matrix = matrix.shape
if n_genes_in_matrix != len(genes) or n_cells_in_matrix != len(barcodes):
    raise ValueError(
        f"matrix.mtx has shape {matrix.shape} (genes x cells), but genes.tsv has "
        f"{len(genes)} rows and barcodes.tsv has {len(barcodes)} rows."
    )

# 1. Prepare 'var' (genes) DataFrame
# Strip the trailing ENSEMBL version suffix to get the primary gene identifier,
# e.g. ENSG00000139618.12 -> ENSG00000139618. This becomes the .var index.
cleaned_ensembl_ids = genes[0].str.replace(r'\.\d+$', '', regex=True)

# Create 'var_df' with the cleaned ENSEMBL IDs as the index
var_df = pd.DataFrame(index=cleaned_ensembl_ids)
var_df['gene_symbol'] = genes[1].values # Add gene symbol
var_df['ensembl_id_full'] = genes[0].values # Store the full ENSEMBL ID for reference

# Prepare 'obs' (barcodes/cells) DataFrame, indexed by cell barcode
obs_df = pd.DataFrame(index=barcodes[0].values)

# Create the AnnData object (cells x genes). Transposing a CSR matrix yields CSC,
# so convert back to CSR to keep per-cell (row) access efficient.
adata = ad.AnnData(X=matrix.transpose().tocsr(), obs=obs_df, var=var_df)
print("Initial AnnData object created with shape:", adata.shape)

# Name the gene index 'ensembl_id'.
# NOTE(review): the tokenizer appears to rely on the 'ensembl_id' *column* set
# below; naming the index identically is presumably redundant — confirm.
adata.var.index.name = 'ensembl_id'
print(f"Set adata.var.index.name to '{adata.var.index.name}'.")


# 2. Set the 'ensembl_id' column in adata.var from the (version-stripped) index
adata.var["ensembl_id"] = adata.var.index
print(f"Set 'ensembl_id' column in adata.var. Example: {adata.var['ensembl_id'].head()}")

# 3. Add per-cell total counts to obs["n_counts"] (row sums of cells x genes X).
# np.asarray(...).ravel() flattens the result whether X.sum returns an np.matrix
# (scipy sparse matrices) or a plain ndarray (sparse arrays / dense input).
adata.obs["n_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()
print(f"Added 'n_counts' to obs. Example counts: {adata.obs['n_counts'].head()}")

# 4. Add a positional ID (0..n_obs-1) in obs["joinid"] for joining results back later
adata.obs["joinid"] = list(range(adata.n_obs))
print(f"Added 'joinid' to obs. Example joinids: {adata.obs['joinid'].head()}")


# Print a quick overview (headline object, obs/var previews, var columns) before writing.
summary_sections = [
    ("\n--- AnnData Object Summary ---", adata),
    ("\nFirst 5 rows of obs:", adata.obs.head()),
    ("\nFirst 5 rows of var:", adata.var.head()),
    ("\nColumns in adata.var before saving:", adata.var.columns),
]
for heading, content in summary_sections:
    print(heading)
    print(content)


# 5. Create the output directory if it doesn't exist
print(f"\nChecking and creating output directory: {h5ad_dir}...")
if not os.path.exists(h5ad_dir):
    os.makedirs(h5ad_dir)
    print(f"Created directory: {h5ad_dir}")
else:
    print(f"Directory already exists: {h5ad_dir}")

# 6. Write the resulting H5AD file to disk
print(f"\nWriting prepared data to {output_file_path}...")
try:
    # Remove any previous file first to prevent 'file already open' errors
    if os.path.exists(output_file_path):
        print(f"Removing existing file: {output_file_path}")
        os.remove(output_file_path)
    adata.write(output_file_path)
except Exception as e:
    print(f"Error saving H5AD file: {e}")
    # Re-raise so execution stops here: the tokenization step reads h5ad_dir and
    # would otherwise run on a missing or stale file.
    raise
else:
    print("Data preparation complete and file saved successfully!")



In [None]:
token_dir = "data/tokenized_data/"
token_dir_missing = not os.path.exists(token_dir)
if token_dir_missing:
    os.makedirs(token_dir)

# Carry the obs column "joinid" into the tokenized dataset under the same name,
# so each tokenized cell can be matched back to its row in adata later.
joinid_attr_map = {"joinid": "joinid"}
tokenizer = TranscriptomeTokenizer(custom_attr_name_dict=joinid_attr_map)
# NOTE(review): tokenize_data is given a directory; presumably every .h5ad file in
# h5ad_dir is tokenized — make sure it only holds the intended file.
tokenizer.tokenize_data(
    data_directory=h5ad_dir,
    file_format="h5ad",
    output_directory=token_dir,
    output_prefix="pbmc",
)

### Preparing data from the model

Then let's fetch the mapping dictionary between Geneformer label IDs and the associated cell subclass labels. This information is stored alongside the fine-tuned model.

In [None]:
model_dir = "./fine_tuned_geneformer/"
# Mapping from the classifier's class index (as a string key) to the
# cell-subclass name; it ships with the fine-tuned model.
label_mapping_dict_file = os.path.join(model_dir, "label_to_cell_subclass.json")

# JSON is UTF-8 by specification. Stating the encoding keeps the read from
# depending on the platform's default locale encoding.
with open(label_mapping_dict_file, encoding="utf-8") as fp:
    label_mapping_dict = json.load(fp)

This dictionary contains all the cell labels the model can predict; the predictions in the section below use these labels.

In [None]:
# Show the first five entries of the class-index -> subclass-name mapping.
for label_id, label_name in list(label_mapping_dict.items())[:5]:
    print(label_id, ': ', label_name)

## Using the Geneformer fine-tuned model for cell subclass inference

### Loading tokenized data

Let's load the tokenized test data.

In [None]:
# Load the tokenized cells written by the tokenizer (output_prefix "pbmc").
dataset = datasets.load_from_disk(os.path.join(token_dir, "pbmc.dataset"))
dataset

We add a dummy cell metadata column `"label"` needed for Geneformer to make predictions.

In [None]:
# Geneformer's prediction workflow needs a "label" column, so fill it with a
# dummy value. Skip if it already exists: `add_column` rejects duplicate column
# names, so re-running this cell would otherwise fail.
if "label" not in dataset.column_names:
    dataset = dataset.add_column("label", [0] * len(dataset))

### Performing inference of cell subclass

Now we can load the model and run the inference workflow.

> ⚠️ Note: this step is slow on CPUs; a machine with a GPU is recommended.

In [None]:
# Load the fine-tuned classifier from model_dir.
model = BertForSequenceClassification.from_pretrained(model_dir)
collator = DataCollatorForCellClassification()
trainer = Trainer(data_collator=collator, model=model)
# Run prediction over every tokenized cell; predictions.predictions holds the
# per-cell class scores (logits).
predictions = trainer.predict(dataset)

Finally, for each cell in our test data we select the most likely cell class from its vector of predicted logits (the class with the highest logit is also the class with the highest softmax probability).

In [None]:
# Index of the highest-scoring class for every cell.
logit_matrix = predictions.predictions
predicted_label_ids = np.argmax(logit_matrix, axis=1)
# Score of that winning class, one entry per cell.
predicted_logits = [row[best] for row, best in zip(logit_matrix, predicted_label_ids)]
# Translate class indices into subclass names (the mapping's keys are strings).
predicted_labels = [label_mapping_dict[str(label_id)] for label_id in predicted_label_ids]

### Inspecting inference results

Then we add the predictions back to our loaded AnnData test dataset.

In [None]:
adata.obs["predicted_cell_subclass"] = predicted_labels
# Confidence = softmax probability of the predicted class over ALL classes.
# The sigmoid of the winning logit alone ignores the competing classes and is
# not a probability for this multi-class model. Subtracting each row's max
# before exp() avoids overflow without changing the result.
all_logits = np.asarray(predictions.predictions)
stable_logits = all_logits - all_logits.max(axis=1, keepdims=True)
softmax_probs = np.exp(stable_logits)
softmax_probs /= softmax_probs.sum(axis=1, keepdims=True)
adata.obs["predicted_cell_subclass_probability"] = softmax_probs.max(axis=1)

The predictions are now ready to inspect. To visualize them in UMAP space, the following basic processing workflow derives a UMAP representation of the data.

In [None]:
# Normalize -> log1p -> highly variable genes -> scale -> PCA -> kNN -> UMAP.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
# .copy() turns the sliced view into a real AnnData, so sc.pp.scale does not
# have to convert the view implicitly (with a warning).
adata = adata[:, adata.var.highly_variable].copy()
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver="arpack")
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)
sc.tl.umap(adata)

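Let's also add the original cell type annotations as obtained in [Scanpy's annotation tutorial](https://scanpy-tutorials.readthedocs.io/en/latest/pbmc3k.html) of the 10X PBMC 3K data. These labels describe the ten Leiden clusters of that tutorial, so they only apply when the clustering reproduces those clusters.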

In [None]:
sc.tl.leiden(adata)
# Cell type names for the ten Leiden clusters of scanpy's PBMC 3K tutorial,
# listed in cluster order.
original_cell_types = [
    "CD4-positive, alpha-beta T cell (1)",
    "CD4-positive, alpha-beta T cell (2)",
    "CD14-positive, monocyte",
    "B cell (1)",
    "CD8-positive, alpha-beta T cell",
    "FCGR3A-positive, monocyte",
    "natural killer cell",
    "dendritic cell",
    "megakaryocyte",
    "B cell (2)",
]
n_leiden_clusters = len(adata.obs["leiden"].cat.categories)
if n_leiden_clusters == len(original_cell_types):
    # NOTE(review): these names only mean something if this clustering reproduces
    # the PBMC 3K tutorial's clusters; for other data they are placeholders.
    adata.rename_categories("leiden", original_cell_types)
else:
    # rename_categories needs exactly one new name per category, so skip it.
    print(
        f"Leiden produced {n_leiden_clusters} clusters but {len(original_cell_types)} "
        "reference labels are defined; keeping numeric cluster labels."
    )

These are the original annotations.

In [None]:
# UMAP colored by the (possibly renamed) Leiden clusters.
sc.pl.umap(adata, title="Original Annotations", color="leiden")

And these are the predicted annotations.

In [None]:
sc.pl.umap(
    adata,
    color=["predicted_cell_subclass_probability", "predicted_cell_subclass"],
    # One title per panel: a single string only titles the first panel and scanpy
    # falls back to the column name (with a warning) for the second.
    title=["Predicted Geneformer Probability", "Predicted Geneformer Annotations"],
)

In [None]:
# Side-by-side panels: Leiden clusters vs. Geneformer predictions, legends drawn on the data.
comparison_keys = ["leiden", "predicted_cell_subclass"]
sc.pl.umap(adata, color=comparison_keys, legend_loc="on data")

## Using the Geneformer fine-tuned model for data projection

### Generating Geneformer embeddings for 10X PBMC 3K data

To project new data, for example the 10X PBMC 3K data, into the Census embedding space of Geneformer's fine-tuned model, we can use `EmbExtractor` from the [Geneformer](https://huggingface.co/ctheodoris/Geneformer) package as follows.

We first need to get the number of categories (cell subclasses) present in the model.

In [None]:
# Number of cell subclasses the classifier was trained on (one mapping entry per class).
n_classes = len(label_mapping_dict.keys())

Then we can run the `EmbExtractor`, which may reorder the cells during processing, so we keep track of each cell's `"joinid"`.

> ⚠️ Note: this step is slow on CPUs; a machine with a GPU is recommended.

In [None]:
output_dir = "data/geneformer_embeddings"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Embedding extractor for the fine-tuned cell classifier. "joinid" is attached
# to every embedding row so the results can be re-aligned with adata afterwards.
extractor_settings = dict(
    model_type="CellClassifier",
    num_classes=n_classes,
    max_ncells=None,
    emb_label=["joinid"],
    emb_layer=0,
    forward_batch_size=30,
    nproc=8,
)
embex = EmbExtractor(**extractor_settings)

tokenized_dataset_path = os.path.join(token_dir, "pbmc.dataset")
embs = embex.extract_embs(
    model_directory=model_dir,
    input_data_file=tokenized_dataset_path,
    output_directory=output_dir,
    output_prefix="emb",
)

In [None]:
# IPython shell escape: list the files written to the embeddings output directory.
!ls data/geneformer_embeddings



Then we re-order the embeddings by `"joinid"` and add them to the original AnnData object.

In [None]:
# Align embeddings to adata's cells by "joinid" instead of by position. This
# avoids assuming that every cell was embedded and that joinids are exactly
# 0..n-1 in adata's order.
embs = embs.sort_values("joinid")
embs_by_joinid = embs.set_index("joinid")
missing_joinids = set(adata.obs["joinid"]) - set(embs_by_joinid.index)
if missing_joinids:
    raise ValueError(
        f"{len(missing_joinids)} cells in adata have no Geneformer embedding "
        "(matched on 'joinid')."
    )
adata.obsm["geneformer"] = embs_by_joinid.loc[adata.obs["joinid"].to_numpy()].to_numpy()

Let's take a look at these Geneformer embeddings in a UMAP representation

In [None]:
# Build the kNN graph on the full Geneformer embedding. Combined with use_rep,
# n_pcs would keep only the first 40 embedding dimensions, and those dimensions
# are not ordered by variance the way PCA components are.
sc.pp.neighbors(adata, n_neighbors=10, use_rep="geneformer")
sc.tl.umap(adata)

In [None]:
# The query cells are the IMMUCan data loaded from raw_data_path, so label them as such.
sc.pl.umap(adata, color="predicted_cell_subclass", title="IMMUCan Data in Geneformer")

### Joining Geneformer embeddings from 10X PBMC 3K data with other Census datasets

Census contains multiple PBMC datasets, and all human Census data have pre-calculated Geneformer embeddings, so we can now join the embeddings generated above for the 10X PBMC 3K dataset with Census data.

Let's grab a few PBMC datasets from Census and request the Geneformer embeddings.

In [None]:
import cellxgene_census

In [None]:
# Some PBMC data from these collections
# 1. https://cellxgene.cziscience.com/collections/c697eaaf-a3be-4251-b036-5f9052179e70
# 2. https://cellxgene.cziscience.com/collections/f2a488bf-782f-4c20-a8e5-cb34d48c1f7e

dataset_ids = [
    "fa8605cf-f27e-44af-ac2a-476bee4410d3",
    "3c75a463-6a87-4132-83a8-c3002624394d",
]
# Value filter that keeps only cells from the datasets listed above.
census_obs_filter = f"dataset_id in {dataset_ids}"

# Pinned Census release; also fetch its precomputed "geneformer" cell embeddings.
with cellxgene_census.open_soma(census_version="2023-12-15") as census:
    adata_census = cellxgene_census.get_anndata(
        census=census,
        organism="Homo sapiens",
        measurement_name="RNA",
        obs_value_filter=census_obs_filter,
        obs_embeddings=["geneformer"],
    )

To simplify, let's select the genes that are also present in the 10X PBMC 3K dataset.

In [None]:
# Use Census feature IDs as var_names so genes can be matched against adata's
# (Ensembl ID) var_names.
adata_census.var_names = adata_census.var["feature_id"]
# Keep the Census genes that also occur in adata, in Census order. A list built
# from a set intersection has an order that changes between Python sessions
# (string hashing is randomized), which makes the gene order irreproducible.
adata_genes = set(adata.var_names)
shared_genes = [gene for gene in adata_census.var_names if gene in adata_genes]
# .copy() materializes the sliced view into an independent AnnData.
adata_census = adata_census[:, shared_genes].copy()

Next, let's take a subset of 3,000 of these cells to match the size of the test data.

In [None]:
# Subsample up to 3,000 Census cells. A seeded generator makes the subset
# reproducible. Capping the size avoids the ValueError that choice(replace=False)
# raises when fewer cells are available.
rng = np.random.default_rng(0)
n_subset = min(3000, adata_census.n_obs)
index_subset = rng.choice(adata_census.n_obs, size=n_subset, replace=False)
adata_census = adata_census[index_subset, :].copy()

Now we can join these Census data with the 10X PBMC 3K data.

In [None]:
adata_census.obs["dataset"] = "Census - " + adata_census.obs["dataset_id"].astype(str)
# The query cells are the IMMUCan data loaded from raw_data_path, not 10X PBMC 3K.
adata.obs["dataset"] = "IMMUCan"
adata.obs["cell_type"] = "Predicted - " + adata.obs["predicted_cell_subclass"].astype(str)

# Outer join keeps genes present in either object; obs["batch"] records the source object.
adata_joined = sc.concat([adata, adata_census], join="outer", label="batch")

Let's now inspect all of the cells in the UMAP space.

In [None]:
# kNN graph on the full Geneformer embedding of the joined cells. n_pcs with
# use_rep would silently keep only the first 40 embedding dimensions.
sc.pp.neighbors(adata_joined, n_neighbors=10, use_rep="geneformer")
sc.tl.umap(adata_joined)

In [None]:
# Joined UMAP colored by source dataset.
sc.pl.umap(adata=adata_joined, color="dataset")

In [None]:
# Joined UMAP colored by (Census or predicted) cell type.
sc.pl.umap(adata=adata_joined, color="cell_type")