# Model Developer Workflow Using the CZ Benchmarks Framework

This Jupyter notebook walks through the model developer workflow using the CZ Benchmarks framework. The framework streamlines model benchmarking and evaluation, enabling efficient and reproducible assessments across a variety of tasks. In this example, we use **Geneformer**, a powerful model for gene expression analysis.

## Build & Run Instructions

Make sure the following variables are set at the beginning of the Docker run script (`scripts/run_docker.sh`):

```bash
BUILD_DEV_CONTAINER=true
EVAL_CMD="jupyter-lab --notebook-dir=/app/examples --port=8888 --no-browser --allow-root"
```

The following command launches the container and starts JupyterLab:

```bash
bash scripts/run_docker.sh -m geneformer
```

Open the appropriate URL (`http://127.0.0.1:8888/lab?token=<TOKEN>`) in a browser, then open this notebook and execute it. If the notebook is being run remotely, substitute the correct IP address and either use an SSH tunnel (more secure) or add `--ip 0.0.0.0` (insecure) to the `jupyter-lab` command in `EVAL_CMD`.

### User Pre-Defined Paths

In [None]:
import os
from pathlib import Path

# Setting the predefined paths.
# Container-side locations read by the benchmark tooling: raw datasets live
# under /raw; both the user checkpoint path and the weights cache point at
# /weights.
os.environ.update(
    {
        "DATASETS_CACHE_PATH": "/raw",
        "MODEL_WEIGHTS_PATH_DOCKER": "/weights",  # user checkpoint path
        "MODEL_WEIGHTS_CACHE_PATH": "/weights",
    }
)

### Setup Benchmark Dataset

In [None]:
from czbenchmarks.datasets.utils import load_dataset

In [None]:
# - load benchmark dataset
# Resolve the benchmark dataset by name through czbenchmarks and load its data.
# `dataset` is reused by the model-preprocessing and task-evaluation cells below.
dataset_name = "tsv2_bladder"
dataset = load_dataset(dataset_name=dataset_name)
# Called before reading `dataset.adata` (presumably this populates it — confirm).
dataset.load_data()
# Handle kept for interactive inspection; not referenced by later cells here.
adata = dataset.adata

### Setup User Model

In [None]:
from model import Geneformer
import geneformer.perturber_utils as pu

In [None]:
# Create the Geneformer wrapper (variant name suggests 12 layers / ~30M
# parameters — confirm), download the weights for this dataset, then replace
# `model.model` with the pretrained network loaded in eval mode via
# Geneformer's perturber utilities.
model = Geneformer(model_variant="gf_12L_30M")
model.download_model_weights(dataset)
# NOTE(review): positional args look like (model_type, num_classes,
# model_directory) — confirm against geneformer.perturber_utils.load_model.
model.model = pu.load_model("Pretrained", 0, model.model_weights_dir, mode="eval")

#### Model Preprocessing Steps

In [None]:
# Validate the dataset for this model, then run the wrapper's preprocessing
# steps in order: prepare metadata, save the dataset to a temporary location,
# and tokenize it. `_tokenize_dataset` consumes the path returned by
# `_save_dataset_temp`.
# NOTE(review): these are underscore-prefixed (non-public) wrapper methods;
# their contract may change without notice.
model.validate_dataset(dataset)
model._prepare_metadata(dataset)
data_path = model._save_dataset_temp(dataset)
tokenized_dataset_path = model._tokenize_dataset(data_path)

In [None]:
# Load the tokenized dataset written above; it is wrapped by the DataLoader below.
tokenized_dataset = model._load_tokenized_dataset(tokenized_dataset_path)

### User Defined DataLoader

In [None]:
from torch.utils.data import DataLoader

In [None]:
import pickle
import torch

# Locate the token dictionary stored next to the model weights directory and
# read the padding token id and model input size used by the collate function.
token_file = (
    Path(model.model_weights_dir).parent / model.token_config.token_dictionary_file
)
# Use a context manager so the file handle is closed deterministically
# (the previous `pickle.load(open(...))` leaked it).
# NOTE: pickle.load can execute arbitrary code — only load token dictionaries
# from a trusted source such as the official model weights download.
with open(token_file, "rb") as token_fh:
    gene_token_dict = pickle.load(token_fh)
pad_token_id = gene_token_dict.get("<pad>")
if pad_token_id is None:
    # Fail fast with a clear message instead of padding with None later.
    raise KeyError(f"'<pad>' token not found in token dictionary {token_file}")
model_input_size = model.token_config.input_size

In [None]:
# - define collate_fn
def prepare_pad_tensor_data(dict_list):
    """Collate tokenized examples into a padded mini-batch.

    Used as the ``collate_fn`` of the DataLoader. Relies on the notebook-level
    ``pu``, ``pad_token_id`` and ``model_input_size`` defined in earlier cells.

    Args:
        dict_list: Examples from the tokenized dataset; each must provide
            ``"input_ids"``, ``"length"`` and ``"cell_idx"``.

    Returns:
        dict with:
          - ``"input_ids"``: token ids padded with ``pad_token_id`` to the
            largest ``"length"`` in the batch by ``pu.pad_tensor_list``.
          - ``"cell_idxs"``: int64 tensor of cell indices.
          - ``"lengths"``: int64 tensor of unpadded lengths.
          - ``"attention_mask"``: int64 tensor of shape (batch, max_len) with
            1 at positions before each length and 0 elsewhere.
    """
    lengths, cell_idx, input_ids = zip(
        *[(d["length"], d["cell_idx"], torch.tensor(d["input_ids"])) for d in dict_list]
    )
    lengths_tensor = torch.tensor(lengths, dtype=torch.int64)
    cell_idx_tensor = torch.tensor(cell_idx, dtype=torch.int64)
    max_len = max(lengths)

    # - pad to max_len (the longest sequence in this batch)
    input_data = pu.pad_tensor_list(
        list(input_ids), max_len, pad_token_id, model_input_size
    )
    # Row i is 1 for positions < lengths[i] and 0 for padding. Since max_len is
    # the batch maximum, every length fits, so no truncation branch is needed
    # (the former `else [1] * max_len` branch was unreachable).
    attention_mask = (
        torch.arange(max_len).unsqueeze(0) < lengths_tensor.unsqueeze(1)
    ).to(torch.int64)
    return {
        "input_ids": input_data,
        "cell_idxs": cell_idx_tensor,
        "lengths": lengths_tensor,
        "attention_mask": attention_mask,
    }

In [None]:
# Iterate the tokenized dataset in order (shuffle=False) in batches of 64,
# collated by prepare_pad_tensor_data. The worker count is capped at the number
# of available CPUs so small machines are not oversubscribed (PyTorch warns
# when num_workers exceeds its suggested maximum); behavior is unchanged on
# machines with 16+ CPUs.
dataloader = DataLoader(
    tokenized_dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=prepare_pad_tensor_data,
    num_workers=min(16, os.cpu_count() or 1),
)

### Setup Benchmark Evaluation Task

#### Model Predictions

In [None]:
import numpy as np
from tqdm import tqdm

# Run the model over every batch and mean-pool the last hidden layer over
# non-padding tokens (pu.mean_nonpadding_embs) to get one embedding per cell.
# Take the device from the model's own parameters so inputs are always moved
# to where the weights live, rather than assuming "cuda".
device = next(model.model.parameters()).device
model_embeddings = []
cell_idxs = []
# - user custom loop for extracting embeddings
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Extracting Embeddings.."):
        original_lens = batch["lengths"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        outputs = model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Request hidden states explicitly instead of relying on the
            # config the weights were loaded with.
            output_hidden_states=True,
        )
        embs_i = outputs.hidden_states[-1]
        mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens).cpu().numpy()
        model_embeddings.append(mean_embs)
        cell_idxs += batch["cell_idxs"].cpu().tolist()

# Reorder rows into ascending cell_idx order (presumably cell_idx enumerates
# the dataset's cells 0..N-1, so row i then matches cell i — confirm).
model_embeddings = np.concatenate(model_embeddings, axis=0)[np.argsort(cell_idxs)]

#### Task Evaluation

In [None]:
from czbenchmarks.tasks import ClusteringTask
from czbenchmarks.datasets import DataType

In [None]:
# Configure czbenchmarks' clustering task with label_key="cell_type".
task = ClusteringTask(label_key="cell_type")
# Attach the extracted embeddings to the dataset as its EMBEDDING output.
# NOTE(review): the first argument is passed as None — confirm what it
# identifies (e.g. a model type) in Dataset.set_output.
dataset.set_output(None, DataType.EMBEDDING, model_embeddings)
result = task.run(dataset)

In [None]:
# Display the result returned by task.run().
print(result)