# Immune Data Tokenization

## Convert to a Hugging Face (HF) dataset


In [34]:
from tokenizer import ScImmuneTokenizer # refactored version

import scanpy as sc

In [35]:
# Folder that holds the input data files (relative to this notebook)
data_path = "../data/cellxgene_data"

# Read the merged immune dataset (.h5ad) from that folder with scanpy
h5ad_file = f"{data_path}/immune_1M_merged.h5ad"
adata_immune = sc.read_h5ad(h5ad_file)


  utils.warn_names_duplicates("obs")


In [36]:
# Set the disease annotation of every cell to "DOID:4" (noted by the author as the
# root ID of the disease ontology). This mutates `adata_immune.obs` in place: any
# per-cell disease labels that were present in this column are overwritten.
adata_immune.obs["disease_ontology_term_id"] = "DOID:4" # to disease ontology root ID

In [37]:
# Initialize the tokenizer from a vocabulary file. Judging by the file name, the
# vocabulary presumably already includes the metadata tokens -- TODO confirm.
tokenizer = ScImmuneTokenizer(vocab_file="vocab_with_metadata.json") # initialize tokenizer
# Display the tokenizer length (presumably the vocabulary size -- verify against
# ScImmuneTokenizer.__len__).
len(tokenizer)

61048

In [38]:
import numpy as np
import torch
from tqdm import tqdm

# obs columns whose values are turned into metadata tokens; tokens are emitted
# in the order the fields are listed here.
metadata_fields = [
    "cell_type_ontology_term_id",
    "self_reported_ethnicity_ontology_term_id",
    "tissue_general_ontology_term_id",
    "development_stage_ontology_term_id",
    "sex_ontology_term_id",
    "disease_ontology_term_id",
]

# Gene names are taken from the `feature_name` column of `var`
gene_names = adata_immune.var["feature_name"].values

# Accumulators for the per-cell tokenization results (one entry per cell)
tokenized_input_ids, tokenized_values = [], []

# Tokenize every cell (one row of `adata_immune.X`) together with its metadata
# tokens and collect the results in `tokenized_input_ids` / `tokenized_values`.
#
# Loop-invariant work is hoisted out of the per-cell loop:
#   * the token prefix of each metadata field is derived once, e.g.
#     "cell_type_ontology_term_id" -> "cell_type";
#   * each metadata column is pulled out of `obs` once as a plain object array,
#     instead of materializing a full `obs.iloc[i]` Series for every cell.
# Fields that are absent from `obs` are skipped, exactly as `Series.get` did
# (it returned None, which failed the `isinstance(val, str)` check).
metadata_columns = [
    (
        field.split("_ontology_term_id")[0],
        adata_immune.obs[field].to_numpy(dtype=object),
    )
    for field in metadata_fields
    if field in adata_immune.obs.columns
]

for i in tqdm(range(adata_immune.n_obs)):

    # 1. Get dense expression vector (sparse rows are densified and flattened)
    row = adata_immune.X[i]
    if not isinstance(row, np.ndarray):
        row = row.toarray().squeeze()

    # 2. Get metadata tokens of the form "<prefix=value>", in field order.
    #    Only string values are used; "NA" and values containing "=" are skipped.
    metadata_tokens = []
    for prefix, column in metadata_columns:
        val = column[i]
        if isinstance(val, str) and val != "NA" and "=" not in val:
            metadata_tokens.append(f"<{prefix}={val}>")

    # 3. Tokenize this single cell as a batch of size one
    tokenized = tokenizer.tokenize_cell_batch(
        data=np.expand_dims(row, axis=0),
        gene_ids=gene_names,
        metadata_tokens=metadata_tokens,
        append_cls=True,
        include_zero_gene=False
    )

    # The first (and only) batch entry unpacks into token ids and their values
    input_ids, values = tokenized[0]
    tokenized_input_ids.append(input_ids)
    tokenized_values.append(values)

100%|██████████| 1000000/1000000 [1:11:40<00:00, 232.55it/s]


In [39]:
from datasets import Dataset

# Build the dataset from the raw, variable-length token lists.
# Padding and truncation happen in DataCollator, not here.
token_columns = {
    "genes": tokenized_input_ids,
    "values": tokenized_values,
}
hf_dataset = Dataset.from_dict(token_columns)

In [40]:
# Switch the output format of the "genes" and "values" columns to torch.
# Note: `set_format` modifies `hf_dataset` in place rather than returning a copy.
hf_dataset.set_format(type="torch", columns=["genes", "values"])

In [42]:
# Write the tokenized dataset to the "scimmune-model/" directory (path is
# relative to the notebook's working directory).
hf_dataset.save_to_disk("scimmune-model/")

Saving the dataset (44/44 shards): 100%|██████████| 1000000/1000000 [00:07<00:00, 136979.16 examples/s]
