In [None]:
import os
import csv
import jax
import haiku as hk
import numpy as np
import pandas as pd

import jax.numpy as jnp
import matplotlib.pyplot as plt

import multiprocessing as mp

from nucleotide_transformer.pretrained import get_pretrained_model

from datasets import Dataset, DatasetDict

# Execute the local helper script inside this notebook's namespace.
# NOTE(review): read_fasta, preprocess_datafile and swapfirst2last are called in
# later cells but are not defined in this notebook — presumably this script
# provides them; confirm.
%run preprocess_utility.py

# Print the devices JAX reports.
print(jax.devices())


In [None]:
try:
    import nucleotide_transformer
except:
    !pip install numpy==1.23.5
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [None]:
# Nucleotide-to-integer lookup table.
# NOTE(review): `comp` is not referenced anywhere else in this notebook — confirm
# whether it is still needed.
comp = {'A':1, 'C':2, 'G':3, 'T':4}

#@title Select a model
#@markdown ---
model_name = '500M_human_ref'
#@markdown ---

# Get pretrained model
# Layer 20 is saved because the embedding cell reads outs["embeddings_20"].
# NOTE(review): max_positions=32 here, while process_batch clips each input
# string to 32 characters — confirm both limits are meant to use the same unit.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=32,
    # If the progress bar gets stuck at the start of the model weights download,
    # you can set verbose=False to download without the progress bar.
    verbose=False
)
# Wrap with Haiku so the model is invoked as forward_fn.apply(parameters, rng, tokens).
forward_fn = hk.transform(forward_fn)


In [None]:
# Reference genome FASTA; the path is relative to the notebook's working directory.
fasta_file = "../genome.hg38rg.fa"
# Later cells index the result as chrom_sequences[str(chrom)][start:end], i.e. it
# is used as a mapping from chromosome name to a sliceable sequence.
chrom_sequences = read_fasta(fasta_file)
# print(chrom_sequences.keys()) 

In [None]:
# Prefix of the output file written and re-read later (datafile + '_nt_embedding.csv').
datafile='methylation'

### Load the data file

In [None]:
# Gzipped BED input for the methylation task; path is relative to the notebook.
data_filename = '../../datasets/task05-methylation/GSM6637962_CpG_coverage20_GRCh38.bed.gz'     
# preprocess_datafile is not defined in this notebook. The cells below read the
# CHROM, START, SIZE and y columns from the frame it returns.
df = preprocess_datafile(data_filename)
df

## Step 1: Obtain sequence section for all data and attach to dataframe df

In [None]:
%%time

import multiprocessing as mp


def get_sequencesection(row):
    """Return the DNA subsequence that one dataframe row points at.

    Args:
        row: Mapping-like row (e.g. a ``pandas.Series``) providing the fields
            ``CHROM``, ``START`` and ``SIZE``.

    Returns:
        The slice of ``chrom_sequences[str(CHROM)]`` that begins at ``START``.
        The slice length is ``SIZE`` when that is already a multiple of 6,
        otherwise ``SIZE`` rounded to the nearest multiple of 6, which can be
        shorter or longer than ``SIZE``.

    Raises:
        KeyError: If the row lacks a required field or the chromosome is not
            present in ``chrom_sequences``.
    """
    start = row['START']
    size = row['SIZE']
    # Presumably the multiple-of-6 constraint matches a 6-mer tokenizer — TODO confirm.
    # round() rounds halves to the nearest even integer, so SIZE=9 gives 12
    # while SIZE=3 gives 0.
    length = size if size % 6 == 0 else 6 * round(size / 6)
    return chrom_sequences[str(row['CHROM'])][start:start + length]


def parallelize_dataframe(df, func, num_partitions=100):
    """Apply ``func`` to row-wise chunks of ``df`` in worker processes.

    Args:
        df: DataFrame to split along its rows.
        func: Picklable callable that receives one chunk (a DataFrame) and
            returns a DataFrame.
        num_partitions: Number of chunks to create. Chunk sizes differ by at
            most one row; when ``df`` has fewer rows than ``num_partitions``
            the trailing chunks are empty.

    Returns:
        The per-chunk results concatenated in chunk order.

    Raises:
        ValueError: If ``num_partitions`` is smaller than 1.
    """
    if num_partitions < 1:
        raise ValueError("num_partitions must be at least 1")

    # Slice with iloc rather than np.array_split(df): recent pandas versions emit
    # a deprecation warning for the DataFrame.swapaxes call NumPy makes internally.
    # The first `extra` chunks get one additional row, as array_split would do.
    base, extra = divmod(len(df), num_partitions)
    chunks = []
    start = 0
    for i in range(num_partitions):
        stop = start + base + (1 if i < extra else 0)
        chunks.append(df.iloc[start:stop])
        start = stop

    # More worker processes than cores only adds start-up and scheduling cost.
    n_workers = min(num_partitions, mp.cpu_count())
    # The context manager shuts the pool down even when a worker raises.
    with mp.Pool(processes=n_workers) as pool:
        results = pool.map(func, chunks)
    return pd.concat(results)

# Define a wrapper function to apply the get_sequencesection function
def apply_get_sequencesection(df):
    """Return a copy of ``df`` with an added ``dna`` column.

    Each value of ``dna`` is ``get_sequencesection(row)`` for the corresponding
    row. The input frame is left unmodified, which avoids writing into a slice
    of a larger frame.

    Args:
        df: DataFrame chunk whose rows are accepted by ``get_sequencesection``.

    Returns:
        A new DataFrame with the same rows and columns as ``df`` plus ``dna``.
    """
    if len(df) == 0:
        # A zero-row chunk (possible when there are fewer rows than partitions)
        # has nothing to apply the row function to; add an empty column directly.
        return df.assign(dna=pd.Series(index=df.index, dtype=object))
    return df.assign(dna=df.apply(get_sequencesection, axis=1))

# Attach the 'dna' column in parallel, then discard the coordinate columns that
# were only needed to locate each subsequence.
df_with_dna = (
    parallelize_dataframe(df, apply_get_sequencesection)
    .drop(columns=['CHROM', 'START', 'SIZE'])
)
df_with_dna

In [None]:
# `!export ...` runs in a throwaway subshell, so it never changed this kernel's
# environment, and it passed the literal text "num_parallel" rather than its value.
# Set the variable in-process instead.
# NOTE(review): XLA may only read XLA_FLAGS when it first starts up, and the first
# cell has already called jax.devices() — confirm this takes effect here, or move
# the assignment ahead of the first JAX call.
num_parallel = 10
os.environ["XLA_FLAGS"] = f"--xla_gpu_force_compilation_parallelism={num_parallel}"

## The version below completes noncoding (95760) in 1 min 29 seconds

In [None]:
%%time

import dask.dataframe as dd
import pandas as pd
import jax.numpy as jnp
import numpy as np
from dask.diagnostics import ProgressBar

# Number of Dask partitions the dataframe is split into
num_parallel = 10
batch_size = 10000  # Rows passed to process_batch per call within a partition

# Vectorized tokenization function
def vectorized_tokenizer(subsequences):
    """Tokenize a batch of sequences into an int32 array of token ids.

    ``tokenizer.batch_tokenize`` yields one item per sequence; the element at
    index 1 of each item is taken as that sequence's token ids.
    """
    tokenized = tokenizer.batch_tokenize(subsequences)
    ids_per_sequence = []
    for item in tokenized:
        ids_per_sequence.append(item[1])
    return jnp.asarray(ids_per_sequence, dtype=jnp.int32)

# Vectorized embedding function
def vectorized_embedding(tokens):
    """Run the model on a token batch and return one vector per sequence.

    The model is applied with a fixed PRNG key (seed 0). From the saved
    layer-20 embeddings, only index 0 along the second axis is kept
    (presumably the leading special-token position — TODO confirm).
    """
    outputs = forward_fn.apply(parameters, jax.random.PRNGKey(0), tokens)
    layer_embeddings = outputs["embeddings_20"]
    return layer_embeddings[:, 0, :]

# Helper that clips sequences before tokenization (process_batch below combines tokenization and embedding)
def truncate_sequences(sequences, max_length=32):
    """Clip every sequence to at most ``max_length`` leading elements.

    Returns a new list; sequences that are already short enough are kept whole.
    """
    clipped = []
    for sequence in sequences:
        clipped.append(sequence[:max_length])
    return clipped

def process_batch(subsequences):
    """Clip, tokenize and embed one batch of sequences.

    NOTE(review): sequences are clipped to 32 *characters* before tokenization,
    while the model was loaded with max_positions=32 — confirm that clipping by
    characters rather than by tokens is intended.
    """
    clipped = truncate_sequences(subsequences, max_length=32)
    return vectorized_embedding(vectorized_tokenizer(clipped))




# Function to apply on each Dask partition
def apply_get_tokens_dask(df):
    """Append the model embeddings as extra columns to one partition.

    Args:
        df: Partition (a pandas DataFrame) with a ``dna`` column of sequences.
            Assumed to contain at least one row.

    Returns:
        ``df`` with its index reset, followed by one column per embedding
        dimension named "1", "2", ... in order.
    """
    sequences = df['dna'].values

    # Feed the model batch_size rows at a time; slicing clips the final batch.
    batches = [
        process_batch(sequences[start:start + batch_size])
        for start in range(0, len(df), batch_size)
    ]

    # Convert to a plain NumPy array before handing the data to pandas.
    all_embeddings = np.asarray(jnp.concatenate(batches, axis=0))

    # Name the columns from the actual embedding width instead of assuming 1280.
    n_features = all_embeddings.shape[1]
    embedding_df = pd.DataFrame(
        all_embeddings, columns=[str(i) for i in range(1, n_features + 1)]
    )

    # Both frames carry a fresh 0..n-1 index, so they align row by row.
    return pd.concat([df.reset_index(drop=True), embedding_df], axis=1)

# Input frame for the embedding step (alternatively reload it from CSV):
# df = pd.read_csv(datafile + '_with_dna.csv')
df = df_with_dna

# Convert the Pandas DataFrame to a Dask DataFrame
ddf = dd.from_pandas(df, npartitions=num_parallel)  # Adjust 'npartitions' based on resources

# Describe the output schema to Dask with an EMPTY frame: the input columns plus
# the 1280 embedding columns "1".."1280". Building it in one concat avoids copying
# the whole dataset and inserting 1280 columns one at a time.
# Assumes the model emits float32 embeddings — TODO confirm.
embedding_columns = [str(i) for i in range(1, 1281)]
meta = pd.concat(
    [
        df.iloc[:0].reset_index(drop=True),
        pd.DataFrame(columns=embedding_columns, dtype=np.float32),
    ],
    axis=1,
)

# Apply the function in parallel using Dask
ddf = ddf.map_partitions(apply_get_tokens_dask, meta=meta)

# Compute the result with progress tracking
with ProgressBar():
    df_top = ddf.compute()

# The raw sequence is no longer needed once it has been embedded.
df_top = df_top.drop(columns=['dna'])
# swapfirst2last is not defined in this notebook.
df_top = swapfirst2last(df_top)

# Save the results to CSV
df_top.to_csv(datafile + '_nt_embedding.csv', index=False)

# Display the DataFrame
print(df_top.head())

In [None]:
# Reload the embeddings CSV written by the previous cell and display it.
df=pd.read_csv(datafile+'_nt_embedding.csv')
df