### Generate tokenizer and model

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load pre-trained Caduceus model and tokenizer
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

# Move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

CaduceusForMaskedLM(
  (caduceus): Caduceus(
    (backbone): CaduceusMixerModel(
      (embeddings): CaduceusEmbeddings(
        (word_embeddings): Embedding(16, 256)
      )
      (layers): ModuleList(
        (0-15): 16 x Block(
          (mixer): BiMambaWrapper(
            (mamba_fwd): Mamba(
              (in_proj): Linear(in_features=256, out_features=1024, bias=False)
              (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
              (act): SiLU()
              (x_proj): Linear(in_features=512, out_features=48, bias=False)
              (dt_proj): Linear(in_features=16, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=256, bias=False)
            )
            (mamba_rev): Mamba(
              (in_proj): Linear(in_features=256, out_features=1024, bias=False)
              (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
              (act): SiLU()
          

In [3]:
from Bio import SeqIO

sequences = list(SeqIO.parse("../genome.hg38rg.fa", "fasta"))

def extract_sequence_segment(seq_id, start, end, sequences):
    for seq_record in sequences:
        if seq_record.id == seq_id:
            segment = str(seq_record.seq[start:end])
            return segment
    return None

In [4]:
%run preprocess_utility.py

datafile_path = '../../datasets/task03-genomic-regions/Homo_sapiens.GRCh38.109.txt.gz'  
dataframe = preprocess_home_sapiens_datafile(datafile_path)
dataframe

rows = []
for index, row in dataframe.iterrows():
    chrom=str(row['CHROM'])
    start=int(row['START'])
    end=int(row['END'])
    y=row['y']
    rowid=row['ROWID']
    segment = extract_sequence_segment(chrom, start, end, sequences)
    rows.append([segment, rowid, y])

columns=['sequence','ROWID','y']
df = pd.DataFrame(rows, columns=columns)
df_origin = df
df_origin

Unnamed: 0,sequence,ROWID,y
0,TGGCTGAAGCGGCCACGGGCTTTCTGGAGCAGCTCAAGTCCTGCAT...,0,0
1,TAGCTCTTTGCTGTTTCTTTCTCTTTGTTTTCCGCATCCTCAGGAT...,1,0
2,TGTCCTCCTTCTCTGAGTCGGCGCTGGAGAAGAAGCTCTCGGAGCT...,2,0
3,CACTGGGTCAGAGCCTGTATCACATCCTTTACCAGCATGGTGCCAA...,3,0
4,TAGGTCTTCTCTTCCCGGTCTGTTTTTCTCCTTGTTATGTTCCTGG...,4,0
...,...,...,...
33503,GAAGGGCTATCAGGCCAGCCTTCAGAGGACTCCAAGGAACATTCAA...,33713,6
33504,CGAGGGTCTAACCCAGCCCAGCCTAACCAATGTGCAGACTACTGTA...,33714,6
33505,AGAACATGTTTCCAGGTAGCCTGAAACCCAGCAGACAATGTAGCTG...,33715,6
33506,GCTAGAAGATGCCATCAGAGACCCAGTAGCCAGATGTAGCTGCTGA...,33716,6


### Mathod One: compute embeddings one by one. (works but slow, takes 10+ min)

In [10]:
# import dask.dataframe as dd
# from dask.diagnostics import ProgressBar
# import numpy as np

# # 设置并行参数和批处理大小
# num_parallel = 10

# # 创建一个DataFrame，包含DNA序列
# segments_df = pd.DataFrame(segments, columns=['sequence'])

# # 使用Dask对DataFrame进行并行化处理
# segments_ddf = dd.from_pandas(segments_df, npartitions=num_parallel)

# # 定义处理嵌入的函数
# def process_embedding(df):
#     embeddings = []
#     for dna in df['sequence']:
#         tokens = tokenizer(dna, return_tensors='pt', padding='max_length', max_length=512, truncation=True)
#         tokens = {key: val.to(device) for key, val in tokens.items()}

#         with torch.no_grad():
#             outputs = model(**tokens, output_hidden_states=True)
#             hidden_states = outputs.hidden_states
#             last_layer_embeddings = hidden_states[-1]  # 获取最后一层的嵌入
#             mean_embeddings = torch.mean(last_layer_embeddings, dim=1)  # 计算平均嵌入
#             mean_embeddings = mean_embeddings.view(mean_embeddings.shape[0], -1)
#             embeddings.append(mean_embeddings.cpu().numpy())

#     # 将嵌入结果转换为DataFrame
#     embeddings = np.vstack(embeddings)
#     return embeddings

# # 显示进度条并进行并行计算
# with ProgressBar():
#     ddf_embeddings = segments_ddf.map_partitions(process_embedding).compute()

# df_embeddings = pd.DataFrame(ddf_embeddings)
# df_embeddings

### Method Two: compute in batches (fast take 1+ min)

In [5]:
import dask.dataframe as dd
import pandas as pd
import jax.numpy as jnp
from dask.diagnostics import ProgressBar

# Vectorized tokenization function
def vectorized_tokenizer(subsequences):
    # Tokenize the batch of sequences
    tokens = tokenizer(subsequences, return_tensors='pt', padding=True, truncation=True, max_length=512)
    
    # Move tokens to GPU
    tokens = {key: val.to(device) for key, val in tokens.items()}
    return tokens

# Vectorized embedding function
def vectorized_embedding(tokens):
    # Forward pass to compute last layer embeddings for the batch
    with torch.no_grad():
        outputs = model(**tokens, output_hidden_states=True)  # Enable output of hidden states
        hidden_states = outputs.hidden_states  # Access all hidden states
        last_layer_embeddings = hidden_states[-1]  # Get the last layer embeddings (batch_size, seq_len, hidden_size)
    
    # Compute the mean of the last layer embeddings across the token (sequence) dimension for each sequence in the batch
    # Dimension 1 corresponds to the token/sequence length, so we compute the mean along this axis
    mean_embeddings = torch.mean(last_layer_embeddings, dim=1)  # (batch_size, hidden_size)
    
    # If needed, squeeze out any extra dimensions (though this shouldn't be necessary after mean calculation)
    mean_embeddings_squeezed = mean_embeddings.squeeze(dim=1)

    return mean_embeddings_squeezed

# Tokenization and embedding combined in a batch-wise function
def process_batch(subsequences):
    tokens = vectorized_tokenizer(subsequences)
    embeddings = vectorized_embedding(tokens)
    return embeddings


def apply_get_embeddings_dask(df):
    subsequences = df['sequence'].tolist() 
    embeddings = process_batch(subsequences)  # Process in a vectorized manner
    embeddings_cpu = embeddings.cpu().numpy()
    
    # df['embedding'] = list(embeddings_cpu)  # Assign embeddings back to the DataFrame
    df2 = pd.DataFrame(embeddings_cpu, columns=[f'{i+1}' for i in range(embeddings_cpu.shape[1])])
    df = pd.concat([df.reset_index(drop=True), df2.reset_index(drop=True)], axis=1)
    return df

In [6]:
%%time

import numpy as np
import pandas as pd
import dask.dataframe as dd
from datasets import load_dataset, load_from_disk

%run preprocess_utility.py

typename="homo_sapiens"


for chunkid in range(0,1):
    
    #================df -> df's chunks================
    # Define the number of rows per chunk
    chunk_size = 10000  
    num_parallel = 10

    # Calculate the number of chunks
    num_chunks = int(np.ceil(len(df) / chunk_size))  
    
    # Split the DataFrame into chunks using array_split
    chunks = np.array_split(df, num_chunks)
    
    # Initialize an empty list to store the processed chunks
    processed_chunks = []
    
    #================process each chunk with dask's ddf================
    # Iterate over each chunk
    for chunk in chunks:

        ddf = dd.from_pandas(chunk, npartitions=num_parallel) 
    
        num_embedding_columns = 256
        
        meta = chunk.copy()
        meta = meta.drop(columns=['embedding'], errors='ignore')  # Drop 'embedding' if it exists
        # Add new embedding columns to the metadata
        for i in range(num_embedding_columns):
            # Adjust type, float is common for embeddings
            meta[f'{i+1}'] = float  

    
        # Apply the function in parallel using Dask
        ddf = ddf.map_partitions(apply_get_embeddings_dask, meta=meta)
    
        # Compute the result with progress tracking
        with ProgressBar():
            processed_chunk = ddf.compute()
    
        # Append processed chunk to list
        processed_chunks.append(processed_chunk)

        
    # Concatenate all processed chunks into a final DataFrame
    final_df = pd.concat(processed_chunks, ignore_index=True)

    final_df = final_df.drop(columns=['sequence'])
    final_df = swapfirst2last(final_df)
    final_df = swapfirst2last(final_df)

    final_df.to_csv(typename+'_caduceus_embedding_'+str(chunkid)+'.csv', index=False)
    print(f"{typename}_caduceus_embedding_{chunkid}.csv is created.")

final_df



[########################################] | 100% Completed | 10.24 s




[########################################] | 100% Completed | 10.98 s




[########################################] | 100% Completed | 11.98 ss




[########################################] | 100% Completed | 8.79 sms
homo_sapiens_caduceus_embedding_0.csv is created.
CPU times: user 58.6 s, sys: 1.63 s, total: 1min
Wall time: 1min 4s


Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,249,250,251,252,253,254,255,256,ROWID,y
0,5.658448e-06,-0.000197,0.000309,0.002955,0.000686,0.000928,-0.000554,0.001058,-0.006210,-0.000159,...,-0.005134,-0.005686,0.000118,0.057623,-0.000550,-0.000044,0.000062,-0.000384,0,0
1,7.889087e-06,-0.000224,0.000382,0.076692,0.001234,0.000755,-0.000426,0.000442,-0.006937,-0.000082,...,-0.006732,-0.006180,0.000068,0.092767,-0.000501,-0.000120,0.000090,-0.000322,1,0
2,-2.406500e-06,-0.000200,0.000487,-0.029977,-0.000267,0.001213,-0.001109,0.001319,-0.005158,-0.000188,...,-0.003747,-0.007372,0.000092,0.052131,-0.000377,0.000010,-0.000015,-0.000267,2,0
3,-3.459091e-07,-0.000194,0.000425,0.005609,0.000216,0.001014,-0.000858,0.000856,-0.004863,-0.000138,...,-0.003818,-0.006700,0.000086,0.055448,-0.000468,-0.000029,0.000004,-0.000221,3,0
4,7.957618e-06,-0.000245,0.000295,0.051375,0.000720,0.000784,-0.000521,0.000439,-0.006848,-0.000051,...,-0.007073,-0.005436,0.000120,0.078436,-0.000539,-0.000112,0.000093,-0.000388,4,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
33503,8.310849e-06,-0.000209,0.000204,0.015854,0.000435,0.001157,-0.000379,0.001207,-0.006454,0.000065,...,-0.004690,-0.005454,0.000183,0.031822,-0.000642,-0.000147,0.000115,-0.000538,33713,6
33504,6.904099e-06,-0.000095,0.000177,0.010541,0.000384,0.000972,-0.000720,0.001301,-0.005753,0.000087,...,-0.004128,-0.005287,0.000184,0.022418,-0.000600,-0.000128,0.000067,-0.000504,33714,6
33505,6.175985e-06,-0.000112,0.000301,0.048167,0.000220,0.001166,-0.000662,0.001201,-0.005828,0.000056,...,-0.003900,-0.006533,0.000165,0.032318,-0.000622,-0.000130,0.000084,-0.000408,33715,6
33506,5.151878e-06,-0.000127,0.000263,0.022243,0.000250,0.001052,-0.000785,0.001441,-0.005710,0.000016,...,-0.003703,-0.006233,0.000197,0.026699,-0.000620,-0.000125,0.000103,-0.000441,33716,6
