In [1]:
import geneformer
from geneformer import TranscriptomeTokenizer
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import loompy

In [2]:
print(geneformer.__file__)

/u/home/s/schhina/.conda/envs/geneformer_env/lib/python3.11/site-packages/geneformer/__init__.py


In [22]:
data_path = "/u/scratch/s/schhina/geneformer_raw_data"
output_path = "/u/scratch/s/schhina/geneformer_tokenized_data"

In [23]:
def remove_char(s):
    for i in range(len(s)):
        if s[i].isnumeric():
            return s[i:]
    return s[4:]

In [24]:
adata = ad.read_h5ad("/u/scratch/s/schhina/labeled_t_cell_data.h5ad")
# adata.write_loom("/u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom")
adata

AnnData object with n_obs × n_vars = 47726 × 60725
    obs: 'stimulation', 'cd_status'

In [5]:
ds = loompy.connect("/u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom")
ds.shape

(60725, 47726)

In [6]:
ds.ca['n_counts'] = [np.sum(ds[:, i]) for i in np.arange(ds.shape[1])]
ds.ra['ensembl_id'] = [(s.split('.')[0]) for s in ds.ra.Accession]
del ds.ra['Accession']
ds.close()

In [7]:
tk = TranscriptomeTokenizer({'stimulation': 'stimulation'}, nproc=16)

In [8]:
tk.tokenize_data(data_path, output_path, "Q_data_", file_format="loom")

Tokenizing /u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom
/u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom has no column attribute 'filter_pass'; tokenizing all cells.
Creating dataset.


Map (num_proc=16):   0%|          | 0/47726 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/47726 [00:00<?, ? examples/s]

In [1]:
from datasets import load_from_disk

In [4]:
import torch

In [28]:
d = load_from_disk("/u/scratch/s/schhina/geneformer_tokenized_data/Q_data_.dataset")

In [29]:
d.shape

(47726, 3)

In [30]:
d[1].keys()

dict_keys(['input_ids', 'stimulation', 'length'])

In [8]:
ids = d[0]['input_ids']

In [12]:
len(ids)

2048

In [18]:
ds[:, :10]

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [11]:
[i for i in ds.ca]

['n_counts', 'obs_names', 'stimulation']

In [9]:
from pathlib import Path
import os
import pickle
import loompy as lp

In [17]:
loom_file_path = "/u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom"
GENE_MEDIAN_FILE = "/u/home/s/schhina/.conda/envs/geneformer_env/lib/python3.11/site-packages/geneformer/gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = "/u/home/s/schhina/.conda/envs/geneformer_env/lib/python3.11/site-packages/geneformer/token_dictionary.pkl"

with open(GENE_MEDIAN_FILE, "rb") as f:
    gene_median_dict = pickle.load(f)

# load token dictionary (Ensembl IDs:token)
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    gene_token_dict = pickle.load(f)

# gene keys for full vocabulary
gene_keys = list(gene_token_dict.keys())

# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
genelist_dict = dict(zip(gene_keys, [True] * len(gene_keys)))
target_sum = 10_000
custom_attr_name_dict = None

In [18]:
def rank_genes(gene_vector, gene_tokens):
    """
    Rank gene expression vector.
    """
    # sort by median-scaled gene values
    sorted_indices = np.argsort(-gene_vector)
    return gene_tokens[sorted_indices]

def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # create array of gene vector with token indices
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank by median-scaled gene values
    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])

In [None]:
with lp.connect(str(loom_file_path)) as data:
    # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
    coding_miRNA_loc = np.where(
        [genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
    )[0]
    print(data.ra["ensembl_id"][:10])
    print(coding_miRNA_loc)
    norm_factor_vector = np.array(
        [
            gene_median_dict[i]
            for i in data.ra["ensembl_id"][coding_miRNA_loc]
        ]
    )
    coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
    coding_miRNA_tokens = np.array(
        [gene_token_dict[i] for i in coding_miRNA_ids]
    )

    # define coordinates of cells passing filters for inclusion (e.g. QC)
    try:
        data.ca["filter_pass"]
    except AttributeError:
        var_exists = False
    else:
        var_exists = True

    if var_exists:
        filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
    elif not var_exists:
        print(
            f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
        )
        filter_pass_loc = np.array([i for i in range(data.shape[1])])
    print(len(filter_pass_loc))
    print(filter_pass_loc[:20])
    print(data.shape)

    # scan through .loom files and tokenize cells
    tokenized_cells = []
    for _ix, _selection, view in data.scan(
        items=filter_pass_loc, axis=1, batch_size=10
    ):
        # select subview with protein-coding and miRNA genes
        if _ix == 0:
            print(view.shape)
            print([i for i in view])
        subview = view.view[coding_miRNA_loc, :]
        if _ix == 0:
            print(subview.shape)
            print([i for i in subview])

        # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
        # and normalize by gene normalization factors
        subview_norm_array = (
            subview[:, :]
            / subview.ca.n_counts
            * target_sum
            / norm_factor_vector[:, None]
        )
        # tokenize subview gene vectors
        if _ix == 0:
            print(subview_norm_array[:, 0])
            print(coding_miRNA_tokens)
            print(tokenize_cell(subview_norm_array[:, 0], coding_miRNA_tokens))

        tokenized_cells += [
            tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
            for i in range(subview_norm_array.shape[1])
        ]

        # add custom attributes for subview to dict
        if custom_attr_name_dict is not None:
            for k in file_cell_metadata.keys():
                file_cell_metadata[k] += subview.ca[k].tolist()
        else:
            file_cell_metadata = None

['00000180346' '00000185800' '00000255389' '00000147059' '00000238045'
 '00000056972' '00000198920' '00000213937' '00000244113' '00000265720']
[]
/u/scratch/s/schhina/geneformer_raw_data/labeled_t_cell_data.loom has no column attribute 'filter_pass'; tokenizing all cells.
52648
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
(60725, 52648)
(60725, 10)
[array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), array([

In [5]:
adata

AnnData object with n_obs × n_vars = 47726 × 60725
    obs: 'stimulation', 'cd_status'

In [7]:
import pickle

In [66]:
with open("/u/home/s/schhina/scratch_backup/Geneformer/geneformer/gene_median_dictionary.pkl", "rb") as f:
    pkl = pickle.load(f)

In [67]:
pkl

{'ENSG00000000003': 2.001186019549122,
 'ENSG00000000005': 3.2282132031640383,
 'ENSG00000000419': 2.218873777678385,
 'ENSG00000000457': 3.7923348732758737,
 'ENSG00000000460': 2.348520948900473,
 'ENSG00000000938': 3.6523609233116896,
 'ENSG00000000971': 4.125091814103064,
 'ENSG00000001036': 1.7514524716218405,
 'ENSG00000001084': 2.4156348393662945,
 'ENSG00000001167': 2.1358156567181648,
 'ENSG00000001460': 2.061695265270317,
 'ENSG00000001461': 2.6187784630312465,
 'ENSG00000001497': 1.9663485923180597,
 'ENSG00000001561': 2.4441504114207344,
 'ENSG00000001617': 3.309488427821417,
 'ENSG00000001626': 4.027858975050529,
 'ENSG00000001629': 5.861882250538956,
 'ENSG00000001630': 3.204867402355354,
 'ENSG00000001631': 2.6897286606762094,
 'ENSG00000002016': 3.117751364396336,
 'ENSG00000002330': 2.0478968854073893,
 'ENSG00000002549': 2.4413285851852984,
 'ENSG00000002586': 5.126741841744427,
 'ENSG00000002587': 2.711258988917515,
 'ENSG00000002726': 3.9698528284408083,
 'ENSG000000

In [5]:
with open("/u/home/s/schhina/scratch_backup/Geneformer/geneformer/token_dictionary.pkl", "rb") as f:
    pkl = pickle.load(f)

In [6]:
len(pkl)

25426

In [7]:
inverse_pkl = {ind: name for name, ind in pkl.items()}

In [9]:
cell_1_ensembl_ids = [inverse_pkl[i] for i in ids]

In [10]:
len(cell_1_ensembl_ids)

2048

In [20]:
cell_1_ensembl_ids

['ENSG00000177989',
 'ENSG00000105519',
 'ENSG00000087086',
 'ENSG00000187608',
 'ENSG00000188010',
 'ENSG00000080824',
 'ENSG00000197956',
 'ENSG00000126432',
 'ENSG00000161011',
 'ENSG00000147400',
 'ENSG00000178980',
 'ENSG00000130066',
 'ENSG00000106211',
 'ENSG00000120306',
 'ENSG00000272196',
 'ENSG00000167996',
 'ENSG00000186973',
 'ENSG00000230989',
 'ENSG00000105258',
 'ENSG00000006327',
 'ENSG00000204387',
 'ENSG00000151632',
 'ENSG00000160213',
 'ENSG00000160472',
 'ENSG00000170315',
 'ENSG00000166595',
 'ENSG00000088986',
 'ENSG00000090273',
 'ENSG00000160345',
 'ENSG00000115415',
 'ENSG00000104904',
 'ENSG00000172183',
 'ENSG00000100380',
 'ENSG00000115963',
 'ENSG00000185043',
 'ENSG00000146425',
 'ENSG00000196072',
 'ENSG00000226976',
 'ENSG00000125868',
 'ENSG00000196141',
 'ENSG00000096384',
 'ENSG00000151929',
 'ENSG00000115541',
 'ENSG00000247077',
 'ENSG00000180921',
 'ENSG00000117228',
 'ENSG00000197170',
 'ENSG00000125534',
 'ENSG00000205155',
 'ENSG00000166592',


In [60]:
cell_2_ensembl_ids = [inverse_pkl[i] for i in ids_2]

In [62]:
inter = (set(cell_1_ensembl_ids).intersection(set(cell_2_ensembl_ids)))

In [63]:
len(inter)

859

In [46]:
raw_genes = adata.to_df().head().iloc[0]

In [52]:
raw_genes = {e_id.split(".")[0]: v for e_id, v in raw_genes.to_dict().items()}

In [55]:
expr_lvls = [raw_genes[i] for i in cell_1_ensembl_ids]

In [11]:
fp = "/u/scratch/s/schhina/geneformer_embed/output_prefix_live.csv"
df = pd.read_csv(fp, header=0, index_col=[0])
ce_fp = "/u/scratch/s/schhina/geneformer_embed/output_prefix_all.csv"
df2 = pd.read_csv(ce_fp, header=[0], index_col=[0])

In [12]:
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
ENSG00000000003,-0.323353,1.814615,0.145999,-0.71589,3.115681,-0.486772,0.705424,-0.08626,0.236463,-0.526529,...,0.432412,-0.047683,-0.053445,-0.470638,-0.074031,2.894128,0.25929,-1.216643,1.18334,-2.722787
ENSG00000000419,0.072225,2.895637,0.392951,0.155172,-0.710113,0.30601,0.609706,-0.01828,1.60332,-0.713971,...,0.256384,0.203404,-0.271855,-0.048602,-0.551704,1.119836,-0.567077,1.355696,0.918637,-0.822258
ENSG00000000457,-0.714346,0.606344,-0.518253,0.581587,1.248615,-0.360164,0.29711,0.159605,0.776134,0.098848,...,-0.320856,1.157132,0.192773,-0.338647,0.293212,-0.953268,-0.686376,-0.813667,0.225207,-0.730619
ENSG00000000460,-1.265316,-0.569788,-1.335897,0.150945,-0.758533,-0.100282,0.52721,-0.103953,0.020362,-1.286357,...,0.52565,0.241151,-0.317504,0.187679,-0.23571,0.216157,-0.954598,0.54529,1.326163,-0.821527
ENSG00000000938,-0.367541,0.061623,-0.461874,-0.189678,0.261861,-0.295954,0.689981,-0.319234,-0.285984,-0.150739,...,-0.398672,-0.106047,0.410535,-1.01705,0.438342,3.214784,-0.714355,-0.640027,0.236771,-1.210339


In [13]:
df2.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,247,248,249,250,251,252,253,254,255,stimulation
0,0.131183,1.254158,0.03485,-0.538065,-0.763984,0.700918,0.595461,-0.55062,1.217002,-1.047136,...,-0.257504,-1.002568,0.501713,0.150006,-1.015681,-0.579915,-0.326472,-0.147749,-0.497381,act
1,0.271865,1.078847,-0.047733,-0.610135,-0.896876,0.740286,0.404384,-0.694898,1.527199,-1.226817,...,-0.306675,-1.090134,0.716873,0.235993,-1.39313,-0.672959,-0.280913,-0.431451,-0.275925,act
2,0.217396,1.046113,0.066718,-0.547242,-0.757857,0.844193,0.539904,-0.607289,1.420038,-1.333719,...,-0.309625,-0.951596,0.608716,0.201681,-1.196268,-0.605729,-0.414629,-0.340751,-0.292053,act
3,0.253963,0.979613,0.034751,-0.437916,-0.921687,0.84559,0.59659,-0.482216,1.327709,-1.254548,...,-0.358469,-0.952724,0.65953,0.13448,-1.170608,-0.455934,-0.350335,-0.269506,-0.539805,act
4,0.23613,0.964443,0.157049,-0.393103,-0.900009,0.73925,0.655262,-0.672821,1.219045,-1.131613,...,-0.138827,-0.716785,0.400466,0.126782,-0.915382,-0.38073,-0.28643,0.030887,-0.71408,act


In [79]:
type(df.loc)

pandas.core.indexing._LocIndexer

In [87]:
total = 0
for i, e_id in enumerate(cell_1_ensembl_ids):
    if e_id not in df.index: 
#         print(f"skipped: {e_id}")
        continue
    if i == 0:
        cell_1_embed =  df.loc[e_id]
    else:
        cell_1_embed += df.loc[e_id]
    total += 1

In [89]:
cell_1_embed/total

0     -0.103242
1      0.866334
2     -0.061240
3     -0.346104
4     -0.121889
         ...   
251    0.012666
252   -0.272956
253   -0.458098
254    0.372274
255   -1.163650
Name: ENSG00000177989, Length: 256, dtype: float64

In [73]:
df.loc['ENSG00000000003'] + df.loc['ENSG00000000003']

0     -0.646705
1      3.629230
2      0.291998
3     -1.431779
4      6.231362
         ...   
251    5.788257
252    0.518579
253   -2.433286
254    2.366681
255   -5.445575
Name: ENSG00000000003, Length: 256, dtype: float64

In [6]:
from geneformer import perturber_utils as pu

In [7]:
model_type = "CellClassifier"
num_classes = 2
model_directory = "/u/scratch/s/schhina/geneformer_output/240509125208/240509_geneformer_cellClassifier_cm_classifier_test/ksplit1"

model = pu.load_model(model_type, num_classes, model_directory, mode="eval")

In [8]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12

In [22]:
input_id_tensor = torch.load("/u/scratch/s/schhina/temp_files/input_ids.pt")

In [23]:
atten_mask_tensor = torch.load("/u/scratch/s/schhina/temp_files/atten_mask.pt")

In [31]:
labels_tensor = torch.load("/u/scratch/s/schhina/temp_files/labels.pt")

AttributeError: 'list' object has no attribute 'grad'

In [26]:
atten_mask_tensor.grad

In [16]:
atten_mask_tensor[0]

tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0')

In [None]:
output = model(input_ids=input_id_tensor[:2].to("cuda"), attention_mask=atten_mask_tensor[:2],)

OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacity of 10.91 GiB of which 118.06 MiB is free. Including non-PyTorch memory, this process has 794.00 MiB memory in use. Process 29778 has 10.02 GiB memory in use. Of the allocated memory 555.15 MiB is allocated by PyTorch, and 76.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [11]:
data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)

padded_batch = preprocess_classifier_batch(
            batch_evalset, max_evalset_len, label_name
        )
padded_batch.set_format(type="torch")

input_data_batch = padded_batch["input_ids"]
attn_msk_batch = padded_batch["attention_mask"]
label_batch = padded_batch[label_name]

outputs = model(input_ids=input_data_batch.to("cuda"), attention_mask=attn_msk_batch.to("cuda"), labels=label_batch.to("cuda"))

TypeError: unhashable type: 'slice'