In [1]:
import gzip
import html
import os
import random
import string
from functools import lru_cache, partial
from typing import Callable, Optional, List, Union, Literal
import pickle
from pathlib import Path
import logging

import anndata as ad
import scanpy as sc
import scipy.sparse as sp

import ftfy
import numpy as np
import regex as re
import torch


GENE_MEDIAN_FILE = Path('/home/ubuntu/Geneformer/geneformer') / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path('/home/ubuntu/Geneformer/geneformer') / "token_dictionary.pkl"

In [2]:
def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector.
    """
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]

def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # create array of gene vector with token indices
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank by median-scaled gene values
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])

In [15]:
class GeneformerTokenizer(object):
    def __init__(self, 
                 nproc=1,
                 gene_median_file=GENE_MEDIAN_FILE,
                 token_dictionary_file=TOKEN_DICTIONARY_FILE
    ):

        """
            Initialize tokenizer.
            Parameters
            ----------
            custom_attr_name_dict : None, dict
                Dictionary of custom attributes to be added to the dataset.
                Keys are the names of the attributes in the loom file.
                Values are the names of the attributes in the dataset.
            nproc : int
                Number of processes to use for dataset mapping.
            gene_median_file : Path
                Path to pickle file containing dictionary of non-zero median
                gene expression values across Genecorpus-30M.
            token_dictionary_file : Path
                Path to pickle file containing token dictionary (Ensembl IDs:token).
        """

        # number of processes for dataset mapping
        self.nproc = nproc

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
    
    def tokenize_anndata(self, gexp, target_sum=10_000, chunk_size=512):
        
        expression = gexp.X.todense() # If needed, convert to dense matrix. Try to avoid this

        coding_miRNA_loc = np.where(
            [self.genelist_dict.get(i, False) for i in gexp.var["ensembl_id"]]
        )[0]
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
                for i in gexp.var["ensembl_id"][coding_miRNA_loc]
            ]
        )
        coding_miRNA_ids = gexp.var["ensembl_id"][coding_miRNA_loc]
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[i] for i in coding_miRNA_ids]
        )

        filter_pass_loc = np.array([i for i in range(gexp.shape[0])])

        tokenized_cells = []

        for i in range(0, len(filter_pass_loc), chunk_size):
            idx = filter_pass_loc[i:i+chunk_size]

            n_counts = gexp[idx].obs['n_counts'].values[:, None]
            X_view = gexp[idx, coding_miRNA_loc].X
            X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
            X_norm = sp.csr_matrix(X_norm)

            tokenized_cells += [
                rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
                for i in range(X_norm.shape[0])
            ]

        gexp.obsm['geneformer_emb'] = np.array(tokenized_cells)

        # save anndata object to datasets/dataset_name/tokenized_feats/name.h5ad
        # gexp.write_h5ad()

        return gexp
    
    def __call__(self, gexp, context_length: Optional[int] = None) -> torch.Tensor:
        """
        Returns the tokenized representation of given input gexp(s) from GeneFormer paper
        Parameters
        ----------
        gexp : Union[numpy/list, List[numpy/list]]
            An input gexpr or a list of input gexpr to tokenize
        context_length : int
            The context length to use as input to GeneFormer; 

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input gexp, context_length]
        """
        gexp = self.tokenize_anndata(gexp)


        return gexp

In [4]:
anndata_processed_files_path = "/home/ubuntu/SpatialCLIP/datasets/HumanPanVisium/Processed"
adata_sample_dict = {}
for file in os.listdir(anndata_processed_files_path):
    if file.endswith(".h5ad"):
        adata_sample_dict[file.split(".")[0]] = sc.read_h5ad(os.path.join(anndata_processed_files_path, file))

In [5]:
adata_sample_dict.keys()

dict_keys(['V10L13-019-A1', 'V11J26-002-B1', 'V19L29-095-A1', 'V19B23-014-A1', 'V52Y10-365-A1', 'V52Y09-003-B1', 'V52Y10-317-B1', 'V19L29-097-B1', 'V19L01-033-A1', 'V19S16-046-C1', 'V10A13-206-C1', 'V52Y10-310-A1', 'V10A13-167-C1', 'V52Y10-286-B1', 'V10L13-020-D1', 'V19L29-098-B1', 'V42A20-354-D1', 'V10L13-021-B1', 'V42U20-030-A1', 'V10M18-056-C1'])

In [16]:
geneformer_tokenizer = GeneformerTokenizer(nproc=4)
V11J26_adata = geneformer_tokenizer(adata_sample_dict['V11J26-002-B1'])

  for i in gexp.var["ensembl_id"][coding_miRNA_loc]
  coding_miRNA_ids = gexp.var["ensembl_id"][coding_miRNA_loc]


In [7]:
# create an empty numpy array of name "accum_V11J26_adata"
accum_V11J26_adata = []
for i in range(adata_sample_dict['V11J26-002-B1'].shape[0]):
    accum_V11J26_adata += geneformer_tokenizer(adata_sample_dict['V11J26-002-B1'][i,:])

  for i in gexp.var["ensembl_id"][coding_miRNA_loc]
  coding_miRNA_ids = gexp.var["ensembl_id"][coding_miRNA_loc]


In [17]:
V11J26_adata = np.array(V11J26_adata)
V11J26_adata.shape

  V11J26_adata = np.array(V11J26_adata)


(1, 3043)

In [18]:
V11J26_adata

array([[array([  351, 12103, 20387, ...,  6850,  9623, 10950], dtype=int16),
        array([  351, 12103,  2218, ...,  5433,  3098, 13393], dtype=int16),
        array([12103, 16683,   351, ...,  6850, 13393, 12938], dtype=int16),
        ...,
        array([  351,  6556, 12103, ..., 13393,  6850,  8385], dtype=int16),
        array([  351, 12103, 16683, ..., 14204, 15711,  8385], dtype=int16),
        array([  351, 12103,   317, ..., 16224, 13393,  4384], dtype=int16)]],
      dtype=object)

In [33]:
accum_V11J26_adata = np.array(accum_V11J26_adata).T
accum_V11J26_adata.shape

(1, 3043)

In [34]:
accum_V11J26_adata

array([[array([  351, 12103, 20387, ...,  6850,  9623, 10950], dtype=int16),
        array([  351, 12103,  2218, ...,  5433,  3098, 13393], dtype=int16),
        array([12103, 16683,   351, ...,  6850, 13393, 12938], dtype=int16),
        ...,
        array([  351,  6556, 12103, ..., 13393,  6850,  8385], dtype=int16),
        array([  351, 12103, 16683, ..., 14204, 15711,  8385], dtype=int16),
        array([  351, 12103,   317, ..., 16224, 13393,  4384], dtype=int16)]],
      dtype=object)

In [39]:
# where is the difference between V11J26_adata[0][0] and accum_V11J26_adata[0][0]?
q = V11J26_adata[0][0] == accum_V11J26_adata[0][0]
# get index of False in q
np.where(q == False)

(array([], dtype=int64),)

(5354,)

In [27]:
V11J26_adata[0][0] == accum_V11J26_adata[0]

  V11J26_adata[0][0] == accum_V11J26_adata[0]


False