In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from nucleotide_transformer.pretrained import get_pretrained_model

In [2]:
model_name = '500M_human_ref'

In [3]:
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=32,
    # If the progress bar gets stuck at the start of the model wieghts download, 
    # you can set verbose=False to download without the progress bar.
    verbose=False
)
forward_fn = hk.transform(forward_fn)

In [4]:
token_dict = tokenizer.vocabulary

In [5]:
listA = []
listC = []
listG = []
listT = []

lists = {'A': listA, 'C': listC, 'G': listG, 'T': listT}

for i in range(4107):
    if len(token_dict[i]) < 5:
        continue  

    char = token_dict[i][3]
    if char in lists:
        lists[char].append(i)


In [6]:
import numpy as np
import pandas as pd
from Bio import SeqIO
import os

#file_path = '/blue/xiaofan/chenyuanhan/private/gnomad.v4.1.proximity.txt' 
#file_path = '/blue/xiaofan/chenyuanhan/private/gnomad.v4.1.intergenic.txt' 
file_path = '/blue/xiaofan/chenyuanhan/private/gnomad.v4.1.exon.txt' 
    
data = pd.read_csv(file_path, sep='\t',header=None)
data.columns = ['CHROM', 'POS', 'REF', 'ALT','INFO']
data['ROWID'] = data.index

filtered_df = data[data['REF'].str.len() == 1]
filtered_df.loc[:, 'CHROM'] = filtered_df['CHROM'].str.replace('chr', '', regex=False)

os.chdir('/blue/xiaofan/chenyuanhan/data')

sequences = list(SeqIO.parse("genome.hg38rg.fa", "fasta"))

def extract_sequence_segment(seq_id, start, end, sequences):
    for seq_record in sequences:
        if seq_record.id == seq_id:
            segment = str(seq_record.seq[start:end])
            return segment
    return None

for i in range(1,len(filtered_df)+1):
    seq_id = str(filtered_df.loc[i, 'CHROM'])  
    start = int(filtered_df.loc[i, 'POS'])-64  
    end = start + 128
    
    segment = extract_sequence_segment(seq_id, start, end, sequences)
    
    filtered_df.loc[i, 'SEQUENCE'] = segment
    
dataframe = filtered_df.reset_index(drop=True)
dataframe

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df.loc[i, 'SEQUENCE'] = segment


Unnamed: 0,CHROM,POS,REF,ALT,INFO,ROWID,SEQUENCE
0,1,962377,A,G,missense_variant,1,GACCCTCCCCAGATCTCAGGTCTGAGGACCCCCACTCCCAGGTTCT...
1,1,973540,A,G,missense_variant,2,TTTCTCCCACCTCTGCCCTGCAGCTGCACAGGCTGAGCCTGGAGAG...
2,1,1040653,T,C,splice_polypyrimidine_tract_variant&intron_var...,3,GGCAAGGTCTCTCAGGCTTGTGGACGTGGGTACGGGCGTCTCGGCA...
3,1,1063066,C,G,splice_polypyrimidine_tract_variant&intron_var...,4,ACCGTGGAAACAATGAGGGAGGTTTGTGTGGGGCCAGATTCCTCCT...
4,1,1091518,A,G,missense_variant,5,GTGAGGGGGGCACCTACCGTGTTCTCCATGGACTTGCTGGCGACTC...
...,...,...,...,...,...,...,...
9995,22,50546776,A,C,missense_variant,9996,GGCCAGGCACTGCCCTCCTGGGCAGGAGCGAAGCAGGGGGGATGTC...
9996,22,50547003,G,C,missense_variant,9997,CGGATGGACGCTGGCTCGGGAGACAGAGCCCGCCGCCCCCGGAAAC...
9997,22,50573625,A,G,missense_variant,9998,TGAGGTGGGGAGGGGTCGTCCAGGATCCTCTGGAACTGCATCTCCA...
9998,22,50595180,C,T,splice_polypyrimidine_tract_variant&intron_var...,9999,TAACAATGTTGTAAATGCCATGATTTTGGATGTCCTGTACCATCAG...


In [7]:
database = pd.DataFrame(columns=[f"token_{i}" for i in range(32)] + ["REF"] + ["ROWID"])

for i in range(len(dataframe)):
    ref = str(dataframe.loc[i, 'REF'])
    ID = dataframe.loc[i, 'ROWID']
    sequences = str(dataframe.loc[i,'SEQUENCE'])
    
    tokens_ids = [b[1] for b in tokenizer.batch_tokenize([sequences])]
    #print(len(tokens_ids)) 
    
    database.loc[i] = tokens_ids[0] + [ref] + [ID] 

In [8]:
database

Unnamed: 0,token_0,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,...,token_24,token_25,token_26,token_27,token_28,token_29,token_30,token_31,REF,ROWID
0,3,3245,2727,412,3947,974,2701,2707,1443,163,...,1,1,1,1,1,1,1,1,A,1
1,3,1386,2605,2542,1938,1932,4000,3747,828,3751,...,1,1,1,1,1,1,1,1,A,2
2,3,3975,3485,2301,1920,2947,1219,2922,3982,2514,...,1,1,1,1,1,1,1,1,T,3
3,3,699,3084,467,3321,1915,4076,3166,1693,4070,...,1,1,1,1,1,1,1,1,C,4
4,3,3539,4092,2638,3546,1675,3227,2559,622,720,...,1,1,1,1,1,1,1,1,A,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,3,4007,3627,2718,2044,3903,231,4096,1903,2734,...,1,1,1,1,1,1,1,1,A,9996
9996,3,3019,3261,3999,3894,830,2994,2739,43,3247,...,1,1,1,1,1,1,1,1,G,9997
9997,3,1857,4087,4063,1683,426,1990,1930,1680,1696,...,1,1,1,1,1,1,1,1,A,9998
9998,3,1060,1889,34,2165,1408,1901,3372,1601,3269,...,1,1,1,1,1,1,1,1,C,9999


In [9]:
mask_token_id = tokenizer.mask_token_id
database['token_11'] = mask_token_id

In [10]:
token_columns = [f"token_{i}" for i in range(32)]
tokens_df = database[token_columns]

tokens = jnp.asarray(tokens_df.to_numpy(), dtype=jnp.int32)

2024-11-05 16:39:34.515649: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.77). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [11]:
%%time
random_key = jax.random.PRNGKey(0)

outs = forward_fn.apply(parameters, random_key, tokens)

Before attention blocks: [[[-0.26703775 -0.3988906   0.9669867  ... -1.4870129   0.34861523
   -0.82291   ]
  [-0.49522611  1.0295012  -2.1168013  ... -1.032915    0.42739213
   -0.14097512]
  [ 0.33072022  2.6063714   1.4318283  ...  0.67101604 -0.36057836
   -1.3567348 ]
  ...
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847  -1.9633977
   -0.22115314]
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847  -1.9633977
   -0.22115314]
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847  -1.9633977
   -0.22115314]]

 [[-0.26703775 -0.3988906   0.9669867  ... -1.4870129   0.34861523
   -0.82291   ]
  [ 0.34997982  0.58328557 -1.8483988  ... -0.7058597   0.2408331
   -1.5861366 ]
  [-0.22434735 -0.056517    0.6739899  ...  1.6259072   0.57565045
   -2.2825513 ]
  ...
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847  -1.9633977
   -0.22115314]
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847  -1.9633977
   -0.22115314]
  [ 1.7735257  -0.22743307  1.6646428  ...  1.5046847 

In [12]:
logits = outs["logits"]

In [13]:
import torch
import torch.nn.functional as F

probabilities = jax.nn.softmax(logits, axis=-1)

In [14]:
lists_jax = {
    'A': jnp.array(lists['A']),
    'C': jnp.array(lists['C']),
    'G': jnp.array(lists['G']),
    'T': jnp.array(lists['T'])
}

pro_list = []

In [15]:
for k in range(len(database)):
    char = database.loc[k, 'REF']
    indices = lists_jax[char]  

    probs = probabilities[k, 11, indices]
    pro = jnp.sum(probs)  

    pro_list.append(pro)

pro_array = jnp.array(pro_list)

In [16]:
log_likelihood = jnp.log(pro_array)

nll = -log_likelihood.mean()

ppl = jnp.exp(nll)
print(f"Perplexity (PPL): {ppl}")

Perplexity (PPL): 4.455536365509033


In [17]:
#proximity:4.1595635414123535
#intergenic:4.128282070159912
#exon:4.455536365509033