In [1]:
# !pip install --upgrade jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# !pip install jax jaxlib==0.1.87+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# using Kernel PyTorch-2.2.0

In [5]:
import os
import csv
import jax
import haiku as hk
import numpy as np
import pandas as pd
import pandas as pd
import jax.numpy as jnp
import matplotlib.pyplot as plt
import multiprocessing as mp
from datasets import Dataset, DatasetDict
from nucleotide_transformer.pretrained import get_pretrained_model

# Load helper functions from the local script into this namespace. Presumably
# this provides read_fasta and preprocess_home_sapiens_datafile, which are used
# below but are not imported anywhere in this notebook — TODO confirm.
%run preprocess_utility.py

# List the devices JAX can use (the captured output shows a single CUDA device).
print(jax.devices())

# NOTE(review): leftover PyTorch device selection; torch is not imported in this
# notebook, so these lines are dead and can be deleted.
# device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device

[CudaDevice(id=0)]


In [6]:
try:
    import nucleotide_transformer
except:
    !pip install numpy==1.23.5
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [7]:
# NOTE(review): `comp` (nucleotide -> integer map) is not referenced anywhere else
# in this notebook; confirm it is unused before deleting it.
comp = {'A':1, 'C':2, 'G':3, 'T':4}

#@title Select a model
#@markdown ---
model_name = '500M_human_ref'
#@markdown ---

# Get pretrained model.
# - embeddings_layers_to_save=(20,): makes outs["embeddings_20"] available; it is
#   read by get_embeddings().
# - attention_maps_to_save: not read anywhere in this notebook — TODO confirm
#   whether it can be dropped.
# - max_positions=32: presumably the per-sequence token budget; with 6-mer
#   tokens plus one leading special token this would match max_length = 186
#   (31 * 6) used below — TODO confirm.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=32,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=False
)
# Wrap with Haiku so the model can be run as
# forward_fn.apply(parameters, rng, tokens) in get_embeddings().
forward_fn = hk.transform(forward_fn)

In [8]:
# Load the reference genome. `read_fasta` is not defined in this notebook —
# presumably it comes from preprocess_utility.py (%run above). The result is
# used as a chromosome-name -> sequence mapping by get_subsequence().
fasta_file = "../genome.hg38rg.fa"
chrom_sequences = read_fasta(fasta_file)

In [9]:
def get_subsequence(sequences, chrom_name, start_pos, length):
    """Slice ``length`` characters starting at ``start_pos`` from one chromosome.

    Parameters
    ----------
    sequences : mapping
        Chromosome name -> sequence string.
    chrom_name : hashable
        Key to look up in ``sequences``.
    start_pos, length : int
        Slice bounds; the slice is ``[start_pos, start_pos + length)`` and is
        silently shorter if it runs past the end of the chromosome.

    Raises
    ------
    ValueError
        If ``chrom_name`` is not a key of ``sequences``.
    """
    if chrom_name not in sequences:
        raise ValueError(f"Chromosome '{chrom_name}' not found in the FASTA file.")

    end_pos = start_pos + length
    return sequences[chrom_name][start_pos:end_pos]

In [11]:
def append_data(final_df, sub_df, sub_embedding_df):
    """Append one batch (embeddings | metadata) to the accumulated result.

    The embedding and metadata frames are paired row by row by position: both
    indexes are discarded before a column-wise concat. Embedding columns come
    first. Column labels are replaced by integers 0..k-1 (``ignore_index``),
    and the returned frame has a fresh 0..n-1 row index.
    """
    batch_df = pd.concat(
        [frame.reset_index(drop=True) for frame in (sub_embedding_df, sub_df)],
        axis=1,
        ignore_index=True,
    )
    return pd.concat([final_df, batch_df], axis=0, ignore_index=True)

In [12]:
def get_tokens(df):
    """Tokenize the genomic region of every row of ``df`` as one batch.

    Each region starts at ``START`` on chromosome ``CHROM``. Its length is
    ``SIZE`` rounded to the nearest multiple of 6, presumably to match the
    model's 6-mer tokenization (TODO confirm). Sequences are read from the
    module-level ``chrom_sequences`` via ``get_subsequence`` and tokenized with
    the module-level ``tokenizer``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``CHROM``, ``START`` and ``SIZE`` columns.

    Returns
    -------
    jax.Array or None
        ``int32`` token ids, one row per input row. Returns ``None`` if
        tokenization fails; the error is printed in that case.
    """
    sequences = []
    for chrom, pos_start, size in zip(df['CHROM'], df['START'], df['SIZE']):
        # Round to a whole number of 6-mers (a no-op when SIZE is already a
        # multiple of 6). Python's round() rounds halves to even. int() guards
        # against round() returning a NumPy float for NumPy scalar input.
        length = int(6 * round(size / 6))
        subsequence = get_subsequence(chrom_sequences, chrom, int(pos_start), length)
        if 'N' in subsequence:
            print(f"The character 'N' is present in the string. ({chrom}-{pos_start}-{size})")
        sequences.append(subsequence)

    try:
        # Each batch_tokenize entry holds the token strings at [0] and the
        # token ids at [1]; tokenize once and keep only the ids.
        tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
        tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
    except Exception as e:
        # The tokenizer error does not identify the offending sequence, so
        # report the whole batch rather than a stale per-row loop variable.
        if len(df):
            where = f"rows {df.index[0]}..{df.index[-1]}"
        else:
            where = "empty batch"
        print(f"exception caught: {e} ({len(sequences)} sequences, {where})")
        tokens = None

    return tokens

In [13]:
def get_embeddings(tokens):
    """Return the layer-20 embedding at token position 0 for each sequence.

    Runs the module-level Haiku-transformed ``forward_fn`` with the module-level
    ``parameters`` and a fixed PRNG key (``PRNGKey(0)``) on ``tokens``, the
    token-id array produced by ``get_tokens``.

    Returns a DataFrame with one row per sequence and string column labels
    '0', '1', ..., one per embedding dimension.
    """
    model_outputs = forward_fn.apply(parameters, jax.random.PRNGKey(0), tokens)

    # Position 0 is presumably the special (CLS) token the tokenizer prepends
    # — TODO confirm against the tokenizer config.
    first_token_embeddings = model_outputs["embeddings_20"][:, 0, :]
    n_dims = first_token_embeddings.shape[1]
    return pd.DataFrame(first_token_embeddings, columns=[str(d) for d in range(n_dims)])

### Load the genomic-region table and truncate long regions

In [14]:
# Regions longer than this are cut to their first max_length bases. 186 = 31 * 6,
# presumably chosen to fit max_positions=32 tokens (31 six-mers + one special
# token) — TODO confirm.
max_length = 186

datafile_path = '../../datasets/task03-genomic-regions/Homo_sapiens.GRCh38.109.txt.gz'
df = preprocess_home_sapiens_datafile(datafile_path)

# Truncate over-long regions, keeping START and moving END/SIZE accordingly.
too_long = df['SIZE'] > max_length
df.loc[too_long, 'END'] = df['START'] + max_length
df.loc[too_long, 'SIZE'] = max_length

# END/TYPE/CLUSTER are not needed for embedding; CHROM/START/SIZE locate the
# sequence, and ROWID/y are carried through to the output.
df = df.drop(columns=['END', 'TYPE', 'CLUSTER'])

df

Unnamed: 0,CHROM,START,SIZE,ROWID,y
0,7,116953541,150,0,0
1,12,54241755,150,1,0
2,20,38033948,150,2,0
3,15,82659560,150,3,0
4,9,131579497,150,4,0
...,...,...,...,...,...
33713,9,124692442,109,33713,6
33714,9,128244721,109,33714,6
33715,X,45746157,109,33715,6
33716,X,45747015,109,33716,6


In [16]:
%%time

# Embed every region of `df` in batches of `segment` rows and write a single CSV
# with the embedding columns ('0', '1', ...) followed by ROWID and y.
segment = 2000

csv_Filename = './homo_sapiens_nt_embedding.csv'
# Delete any previous output up front so a crashed run cannot leave a stale file
# that looks current.
if os.path.exists(csv_Filename):
    os.remove(csv_Filename)

if len(df) == 0:
    raise ValueError("df has no rows to embed")

# Iterate over batch start positions directly. This never produces an empty
# trailing batch when len(df) is a multiple of `segment`; an empty batch would
# crash the model call.
final_df = pd.DataFrame()
for batch_start in range(0, len(df), segment):
    sub_df = df.iloc[batch_start:batch_start + segment]
    sub_tokens = get_tokens(sub_df)
    if sub_tokens is None:
        # get_tokens() has already printed the tokenizer error.
        raise RuntimeError(
            f"tokenization failed for rows {batch_start}..{batch_start + len(sub_df) - 1}"
        )
    sub_embedding_df = get_embeddings(sub_tokens)
    # append_data copies the accumulated frame on every call; that is
    # negligible here (~len(df)/segment batches) next to model inference.
    final_df = append_data(final_df, sub_df, sub_embedding_df)
    print(f"complete batch...... {batch_start + len(sub_df)}")

print(f"last index...... {len(df)}")

# append_data labels columns 0..k-1: embedding dims first, then df's own columns.
meta_columns = list(df.columns)
n_embedding = final_df.shape[1] - len(meta_columns)
final_df.columns = [f'{i}' for i in range(n_embedding)] + meta_columns
final_df = final_df.drop(columns=['CHROM', 'START', 'SIZE'])

final_df.to_csv(csv_Filename, sep=',', index=False, header=True, na_rep='NaN')

final_df

Before attention blocks: [[[-0.2133187  -0.37896746  0.9894879  ... -1.4485674   0.29604977
   -0.8132925 ]
  [-0.12643832  0.8992672  -2.0135226  ... -0.77297044  0.0352236
    0.16593635]
  [-0.15411776  0.65062594  0.03419471 ...  0.10655487 -0.7579897
   -1.9504462 ]
  ...
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -1.9107313
   -0.23773763]
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -1.9107313
   -0.23773763]
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -1.9107313
   -0.23773763]]

 [[-0.2133187  -0.37896746  0.9894879  ... -1.4485674   0.29604977
   -0.8132925 ]
  [-2.7441926   0.48109138 -0.6968985  ...  0.55498445 -0.17344053
   -0.36260185]
  [-1.3409749   0.749942    1.5137826  ...  0.5966591   0.6111201
   -0.8521586 ]
  ...
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -1.9107313
   -0.23773763]
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -1.9107313
   -0.23773763]
  [ 1.7484777  -0.23501392  1.5967557  ...  1.4864607  -

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1272,1273,1274,1275,1276,1277,1278,1279,ROWID,y
0,-0.324386,10.222736,0.472543,8.665915,-9.160305,3.684612,-9.674951,9.519714,1.577347,39.271915,...,-8.166821,-4.011395,6.786868,-5.835679,-18.031162,-12.672304,-4.555680,0.773324,0,0
1,0.215155,13.518826,3.466785,9.049223,-9.875353,-0.560636,-9.800148,9.256715,1.785120,38.618298,...,-5.242308,-1.541120,4.458237,-3.372420,-20.867760,-16.494043,-8.260738,3.139253,1,0
2,-2.372187,13.168259,2.926776,11.876443,-8.724695,-0.121530,-5.909375,7.047293,3.658315,40.302120,...,-5.853244,-3.642528,6.743409,-3.631339,-20.687992,-14.483265,-3.643872,5.543061,2,0
3,-1.159037,17.319054,1.178315,5.559775,-7.308437,3.267867,-8.044428,12.884748,2.865854,38.994896,...,-6.229579,2.583466,3.808581,-1.381266,-22.413837,-11.401751,-0.634754,1.241418,3,0
4,-0.237370,14.949961,2.041708,7.604409,-8.435667,-2.641743,-9.003962,9.609196,2.431862,44.436890,...,-5.326154,-0.373121,6.259539,-1.734646,-21.528700,-13.961947,-7.830993,-0.421597,4,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
33503,-1.043825,16.685047,4.881373,9.886226,-9.486728,3.509766,-5.895805,13.012412,4.322705,41.282146,...,-6.551616,1.186344,2.800662,-3.488761,-20.997541,-11.911755,-5.680223,-0.126435,33713,6
33504,0.734016,14.266375,4.807714,11.782293,-8.877528,7.116413,-4.754360,9.617087,5.737248,42.193779,...,-0.346061,3.357471,3.357189,-1.302580,-23.649208,-10.654892,-1.484709,1.113938,33714,6
33505,0.622062,15.830442,2.136884,10.288782,-7.188518,1.998456,-7.231447,10.671149,1.026647,42.652550,...,-5.106775,-2.922949,5.629794,-5.241525,-22.151175,-12.723240,-3.804622,-0.569391,33715,6
33506,3.573058,14.700028,0.162089,8.332699,-9.245561,5.044686,-6.244898,13.499894,2.904738,41.875347,...,-1.514194,2.692098,5.354314,-2.944794,-20.697811,-5.762572,-2.856724,1.189772,33716,6


### Reload the saved embedding CSV

In [26]:
def load_embedding_file(csv_filename):
    """Read an embedding CSV (as written by the embedding cell) into a DataFrame.

    Uses ``pandas.read_csv`` defaults: the first line is the header, and no
    column is used as the index.
    """
    return pd.read_csv(csv_filename)

# Reload the embeddings written by the batch-embedding cell (same path as
# csv_Filename there). Note: this rebinds `df`, replacing the region table.
df = load_embedding_file('./homo_sapiens_nt_embedding.csv')
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1272,1273,1274,1275,1276,1277,1278,1279,ROWID,y
0,-0.324386,10.222736,0.472543,8.665916,-9.160305,3.684612,-9.674951,9.519714,1.577347,39.271915,...,-8.166821,-4.011395,6.786868,-5.835679,-18.031162,-12.672304,-4.555680,0.773324,0,0
1,0.215155,13.518826,3.466785,9.049223,-9.875353,-0.560636,-9.800148,9.256715,1.785120,38.618298,...,-5.242308,-1.541120,4.458237,-3.372420,-20.867760,-16.494043,-8.260738,3.139253,1,0
2,-2.372187,13.168259,2.926776,11.876443,-8.724695,-0.121530,-5.909375,7.047293,3.658315,40.302120,...,-5.853244,-3.642528,6.743409,-3.631339,-20.687992,-14.483265,-3.643872,5.543061,2,0
3,-1.159037,17.319054,1.178315,5.559775,-7.308437,3.267867,-8.044428,12.884747,2.865854,38.994896,...,-6.229579,2.583466,3.808581,-1.381266,-22.413837,-11.401751,-0.634754,1.241418,3,0
4,-0.237370,14.949961,2.041708,7.604409,-8.435667,-2.641743,-9.003962,9.609196,2.431862,44.436890,...,-5.326154,-0.373121,6.259539,-1.734646,-21.528700,-13.961947,-7.830993,-0.421597,4,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34786,0.592047,12.480164,-2.207586,8.692018,-5.893743,7.420892,-12.051699,13.012270,3.701627,38.131466,...,-2.668526,-5.377493,3.924320,-3.992713,-20.384758,-14.261722,-1.868958,-1.269481,34832,6
34787,-1.070907,10.931201,3.266546,8.628684,-5.165432,6.097012,-6.288345,11.052012,5.249962,39.650800,...,-3.123629,0.197064,4.605883,-1.002314,-21.945362,-11.409150,-4.581974,2.272008,34833,6
34788,-0.759760,9.573089,7.653216,-0.243169,-8.327825,6.982180,-5.615465,10.963967,4.588292,42.244380,...,-6.250230,-1.333322,2.300076,-2.416507,-21.926449,-10.241255,-3.423506,1.949625,34834,6
34789,-0.543251,11.539721,3.129790,9.273889,-4.417173,5.724512,-6.320546,11.377898,4.870200,39.562275,...,-3.272682,0.535543,4.866612,-0.382214,-21.396807,-11.046916,-5.265911,2.063335,34835,6
