In [None]:
import os
import csv
import jax
import haiku as hk
import numpy as np
import pandas as pd
import jax.numpy as jnp
import matplotlib.pyplot as plt

import multiprocessing as mp

from datetime import datetime, timedelta

from nucleotide_transformer.pretrained import get_pretrained_model

# device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device

from datasets import Dataset, DatasetDict

print(jax.devices())

In [None]:
try:
    import nucleotide_transformer
except:
    !pip install numpy==1.23.5
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [None]:
# Nucleotide -> integer code table.
# NOTE(review): `comp` is not referenced anywhere in the visible notebook -- confirm it is still needed.
comp = {'A':1, 'C':2, 'G':3, 'T':4}

#@title Select a model
#@markdown ---
# The second assignment overrides the first, so '500M_human_ref' is the model
# actually loaded; swap the order of the two lines to select the other model.
model_name = '50M_multi_species_v2'
model_name = '500M_human_ref'
#@markdown ---

# Get pretrained model
# Only layer 20 embeddings are requested; get_embeddings() reads the
# "embeddings_20" output, so the two must stay in sync.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=32,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=False
)
# Wrap the forward function with Haiku so it can be called later as
# forward_fn.apply(parameters, rng_key, tokens).
forward_fn = hk.transform(forward_fn)

In [None]:
def append_data(final_df, sub_df, sub_embedding_df):
    """Attach one batch of embeddings to its metadata and add it to the running result.

    Args:
        final_df: Rows accumulated so far (may be an empty DataFrame).
        sub_df: Batch of input rows; its ``sequence`` column is dropped.
        sub_embedding_df: Embedding rows for the same batch, in the same order.

    Returns:
        A new DataFrame with the batch rows appended below ``final_df``.
        Column labels are replaced by integer positions (embedding columns
        first, then the remaining ``sub_df`` columns) and the row index is
        renumbered from zero.
    """
    # Align both halves on a fresh 0..n-1 index so the column-wise concat pairs
    # rows by position rather than by their original index labels.
    metadata = sub_df.drop(columns=['sequence']).reset_index(drop=True)
    embeddings = sub_embedding_df.reset_index(drop=True)

    batch_rows = pd.concat([embeddings, metadata], axis=1, ignore_index=True)
    return pd.concat([final_df, batch_rows], axis=0, ignore_index=True)

In [None]:
def get_tokens(df):
    """Tokenize the ``sequence`` column of ``df`` into an int32 token-id array.

    Prints a notice for every sequence containing the character ``N`` and then
    the number of sequences collected.

    Args:
        df: DataFrame with a ``sequence`` column of strings. If tokenization
            fails, the ``CHROM``, ``START`` and ``SIZE`` values of the last
            row (when those columns exist) are appended to the error message.

    Returns:
        A ``jnp`` int32 array built from the token ids of every sequence, or
        ``None`` if tokenization or the array conversion raised.
    """
    sequences = df['sequence'].tolist()
    for subsequence in sequences:
        if 'N' in subsequence:
            print("The character 'N' is present in the string.")

    print(len(sequences))

    try:
        # Tokenize once; each result is indexed as (token strings, token ids)
        # and only the ids are needed downstream.
        tokenized = tokenizer.batch_tokenize(sequences)
        tokens_ids = [pair[1] for pair in tokenized]
        tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

    except Exception as e:
        # Label the failure with the last row of the batch, but never let the
        # error report itself fail on an empty frame or missing columns.
        location = ''
        if len(df) > 0 and {'CHROM', 'START', 'SIZE'}.issubset(df.columns):
            last_row = df.iloc[-1]
            location = (
                str(last_row['CHROM']) + '-'
                + str(last_row['START']) + '-'
                + str(last_row['SIZE'])
            )
        print(f"exception caught: {e}" + location)
        tokens = None

    return tokens

In [None]:
def get_embeddings(tokens, layer=20, token_position=0):
    """Run the model on ``tokens`` and return one embedding vector per sequence.

    Args:
        tokens: Token-id array as produced by ``get_tokens``.
        layer: Which saved layer to read; the output key is
            ``embeddings_<layer>``. It must be one of the layers requested via
            ``embeddings_layers_to_save`` when the model was loaded.
        token_position: Index along the second axis of the layer output whose
            vector is kept for each sequence.
            NOTE(review): position 0 is presumably the CLS token -- confirm
            against the tokenizer.

    Returns:
        DataFrame with one row per input sequence and string column names
        ``'0'``, ``'1'``, ... for the embedding dimensions.

    Raises:
        ValueError: If ``tokens`` is ``None`` (``get_tokens`` returns ``None``
            when tokenization fails).
    """
    if tokens is None:
        raise ValueError(
            "tokens is None; tokenization must succeed before computing embeddings"
        )

    # Fixed key so repeated runs use the same random state.
    random_key = jax.random.PRNGKey(0)

    # Infer
    outs = forward_fn.apply(parameters, random_key, tokens)

    # Keep a single position per sequence and hand pandas a plain NumPy array.
    layer_embeddings = outs[f"embeddings_{layer}"]
    my_embedding = np.asarray(layer_embeddings[:, token_position, :])

    column_names = [str(i) for i in range(my_embedding.shape[1])]
    return pd.DataFrame(my_embedding, columns=column_names)

### Load the DNA sequence data file

In [None]:
import pandas as pd

# Variant category to process. It selects the input file below and is embedded
# in the output file name; 'noncoding' is the alternative left commented out
# by the author.
pathogenecity_type = 'missense'

# Input table for the selected category; get_tokens() reads its 'sequence' column.
df = pd.read_csv('dna_segment_' + pathogenecity_type + '.csv')

In [None]:
%%time

# Embed the input table in fixed-size batches and write the result to a
# timestamped CSV. Each batch is tokenized, run through the model, joined back
# to its metadata columns and appended to `final_df`.

segment = 2000    # number of rows embedded per model call
max_length = 186  # NOTE(review): not used in this cell; kept in case another cell reads it

now = datetime.now()
formatted_time = now.strftime("%y-%m-%d-%H-%M-%S")
csv_filename = './pathogenecity_nt_' + pathogenecity_type + '_' + formatted_time + '.csv'

sub_df = pd.DataFrame()
final_df = pd.DataFrame()
total_rows = len(df)

cnt = 0
# Step through the frame by slice instead of iterating every row: only batch
# boundaries matter, and a trailing partial batch exists only when the row
# count is not a multiple of `segment` (an empty tail is never sent to the model).
for batch_start in range(0, total_rows, segment):
    cnt = min(batch_start + segment, total_rows)
    sub_df = df.iloc[batch_start:cnt]

    sub_tokens = get_tokens(sub_df)
    if sub_tokens is None:
        # get_tokens already printed the cause; skip the batch rather than
        # passing None to the model.
        print(f"skipping rows {batch_start}-{cnt}: tokenization failed")
        continue

    sub_embedding_df = get_embeddings(sub_tokens)
    final_df = append_data(final_df, sub_df, sub_embedding_df)

    if cnt % segment == 0:
        print(f"complete batch...... {cnt}")

print(f"last index...... {(cnt)}")

final_df.to_csv(csv_filename, sep=',', index=False,  header=True, na_rep='NaN')

### Load CSV File

In [None]:
import pandas as pd
# csv_filename='./pathogenecity_nt_missense.csv'
def load_embedding_file(csv_filename):
    """Read an embeddings CSV and give its columns standard names.

    Args:
        csv_filename: Path of the CSV file to read.

    Returns:
        DataFrame whose columns are renamed by position: every column except
        the last becomes ``'0'``, ``'1'``, ... and the last becomes ``'y'``.
    """
    loaded = pd.read_csv(csv_filename)

    # All columns but the last are numbered; the last one is labelled 'y'.
    feature_count = loaded.shape[1] - 1
    loaded.columns = [str(i) for i in range(feature_count)] + ['y']
    return loaded

# Reload the embeddings CSV written by the batch cell.
# Note: this rebinds `df`, replacing the input sequence table loaded earlier,
# so the batch cell cannot be re-run afterwards without reloading the input.
df = load_embedding_file(csv_filename)
# Preview the first five rows as the cell output.
df.head(5)