# Zero-Shot Evaluation of scGPT
Following https://scgpt.readthedocs.io/en/latest/tutorial_annotation.html

In [1]:
from os.path import join

import anndata
import scanpy as sc
import numpy as np
import pandas as pd
import dask.dataframe as dd

from scipy.sparse import csr_matrix
from tqdm.auto import tqdm

In [6]:
DATA_PATH = "/lustre/groups/ml01/workspace/till.richter/merlin_cxg_2023_05_15_sf-log1p"
SAVE_PATH = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/classification/scGPT"
ADATA_PATH = join(DATA_PATH, 'adata')

In [3]:
def get_count_matrix(ddf):
    x = (
        ddf['X']
        .map_partitions(
            lambda xx: pd.DataFrame(np.vstack(xx.tolist())), 
            meta={col: 'f4' for col in range(19331)}
        )
        .to_dask_array(lengths=[1024] * ddf.npartitions)
    )
    
    return x

In [4]:
n_cells_train = 1_500_000

ddf_train = dd.read_parquet(join(DATA_PATH, 'train'), split_row_groups=True)
x_train = get_count_matrix(ddf_train)[:n_cells_train, :]
y_train = dd.read_parquet(join(DATA_PATH, 'train'), columns=['cell_type']).compute().iloc[:n_cells_train]

ddf_test = dd.read_parquet(join(DATA_PATH, 'test'), split_row_groups=True)
x_test = get_count_matrix(ddf_test)
y_test = dd.read_parquet(join(DATA_PATH, 'test'), columns=['cell_type']).compute()

var = pd.read_parquet(join(DATA_PATH, 'var.parquet'))

### Training data

In [9]:
# run inference in batches to save memory

for i, idxs in tqdm(enumerate(np.array_split(np.arange(x_train.shape[0]), 10))):
    # data is already normalized
    anndata.AnnData(
        X=x_train[idxs, :].map_blocks(csr_matrix).compute(), 
        var=var.set_index('feature_name'),
        obs=y_train.iloc[idxs]
    ).write_h5ad(join(ADATA_PATH, 'adata_train', f'{i}.h5ad'))

0it [00:00, ?it/s]

### Test data

In [8]:
for i, idxs in tqdm(enumerate(np.array_split(np.arange(x_test.shape[0]), 30))):
    # data is already normalized
    anndata.AnnData(
        X=x_test[idxs, :].map_blocks(csr_matrix).compute(), 
        var=var.set_index('feature_name'),
        obs=y_test.iloc[idxs]
    ).write_h5ad(join(ADATA_PATH, 'adata_test', f'{i}.h5ad'))

0it [00:00, ?it/s]

### Get embeddings from scGPT

In [None]:
from pathlib import Path
from os.path import join

import scgpt as scg
import anndata
import scanpy as sc

In [None]:
model_dir = Path(SAVE_PATH)
cell_type_key = "cell_type"
gene_col = "index"

### Training data

In [None]:
for i in range(10):
    adata = sc.read_h5ad(join(SAVE_PATH, 'train', f'{i}.h5ad'))
    adata = scg.tasks.embed_data(
        adata,
        model_dir,
        cell_type_key=cell_type_key,
        gene_col=gene_col,
        batch_size=64,
        return_new_adata=True,
    ).write_h5ad(join(SAVE_PATH, 'train', f'{i}_embed.h5ad'))

### Test data

In [None]:
for i in range(30):
    adata = sc.read_h5ad(join(SAVE_PATH, 'test', f'{i}.h5ad'))
    adata = scg.tasks.embed_data(
        adata,
        model_dir,
        cell_type_key=cell_type_key,
        gene_col=gene_col,
        batch_size=64,
        return_new_adata=True,
    ).write_h5ad(join(SAVE_PATH, 'test', f'{i}_embed.h5ad'))

# Evaluate scGPT embeddings

### Train Linear model

In [None]:
from os.path import join

import anndata
import scanpy as sc
import numpy as np

from cuml.linear_model import LogisticRegression

### Evaluate on test data

In [None]:
import pandas as pd
from sklearn.metrics import classification_report
from statistics import mean, stdev

from utils import correct_labels

In [None]:
cell_type_mapping = pd.read_parquet(join(DATA_PATH, 'categorical_lookup/cell_type.parquet'))
cell_type_hierarchy = np.load(join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy'))

In [None]:
import scgpt as scg
import anndata
import numpy as np
from os.path import join
from pathlib import Path

# Define paths
cell_type_key = "cell_type"
gene_col = "index"

# Process test data to generate embeddings
for i in range(30):  # Assuming you have 30 test splits as in the original code
    # Load the h5ad file
    adata = sc.read_h5ad(join(SAVE_PATH, 'test', f'{i}.h5ad'))
    
    # Generate embeddings using the scGPT model
    embeddings = scg.tasks.embed_data(
        adata,
        model_dir,
        cell_type_key=cell_type_key,
        gene_col=gene_col,
        batch_size=64,
        return_new_adata=True,
    ).X  # Extract the X (embedding) part of the Anndata object
    
    # Save the embeddings as a .npy file
    np.save(join(SAVE_PATH, 'test', f'{i}_embed.npy'), embeddings)

In [None]:
clf_reports = []

for clf in clf_list:
    preds = clf.predict(X_test)
    preds_corr = correct_labels(y_test, preds, cell_type_hierarchy)
    clf_reports.append(pd.DataFrame(classification_report(y_test, preds_corr, output_dict=True)).T)

In [None]:
f1_scores_scgpt = [clf_report.loc['macro avg', 'f1-score'] for clf_report in clf_reports]
print(f'{mean(f1_scores_scgpt):.4f}±{stdev(f1_scores_scgpt):.4f}')

In [6]:
# Load the model
model_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/scGPT/best_model.pt"
model = torch.load(model_path)
model.eval()

AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

In [None]:
# Define function to extract embeddings
def get_embeddings(dataloader, model):
    embeddings = []
    labels = []
    for batch in dataloader:
        # Obtain the input tensor from the batch. Adjust this line according to your dataloader's structure.
        inputs = batch['input_ids']
        with torch.no_grad():
            output = model(input_ids=inputs)
        
        # Assuming the output of your model provides the embeddings directly. If not, adjust accordingly.
        embeddings.append(output.embeddings)
        labels.append(batch['labels'])
    
    embeddings = torch.cat(embeddings)
    labels = torch.cat(labels)
    return embeddings, labels

# Get embeddings and labels from the validation set
val_embeddings, val_labels = get_embeddings(estim.datamodule.val_dataloader(), model)
