In [None]:
import os
import csv
import torch
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from transformers import AutoModel #, AutoModelForMaskedLM

import warnings
warnings.filterwarnings('ignore')


%run hyena_utility.py
%run preprocess_utility.py

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

### Load human chromosome sequences from the FASTA (.fa) file

In [None]:
# Reference genome FASTA, relative to the notebook's working directory.
fasta_file = "../genome.hg38rg.fa"

# Mapping consulted by get_subsequence via chrom_name lookups.
# read_fasta is not defined in this notebook — presumably provided by a %run helper.
chrom_sequences = read_fasta(fasta_file)

def get_subsequence(chrom_name, start_pos, length, sequences=None):
    """Return ``length`` characters of chromosome ``chrom_name`` starting at ``start_pos``.

    The window is taken with Python slicing, so ``start_pos`` indexes the stored
    sequence 0-based (assuming each stored value is a string). A window running
    past the end of the chromosome is silently truncated, possibly to ``""``.

    Args:
        chrom_name: key of the chromosome in ``sequences``.
        start_pos: first position of the window; must be >= 0.
        length: number of characters requested.
        sequences: mapping of chromosome name -> sequence. Defaults to the
            notebook-level ``chrom_sequences`` loaded from the FASTA file.

    Returns:
        The requested slice of the chromosome sequence.

    Raises:
        ValueError: if ``chrom_name`` is not in ``sequences``, or if ``start_pos``
            is negative (Python slicing would otherwise count it from the end of
            the chromosome and return the wrong bases).
    """
    if sequences is None:
        sequences = chrom_sequences
    if chrom_name not in sequences:
        raise ValueError(f"Chromosome '{chrom_name}' not found in the FASTA file.")
    if start_pos < 0:
        raise ValueError(f"start_pos must be non-negative, got {start_pos}.")
    return sequences[chrom_name][start_pos:start_pos + length]

In [None]:
def Subsequence2Embedding(subsequence, embed_model=None, embed_tokenizer=None, target_device=None):
    """Embed one DNA subsequence as the mean of the model's per-token outputs.

    The subsequence is tokenized, given a batch axis of size 1, run through the
    model under ``torch.inference_mode()``, averaged over dim 1 (the token axis),
    and the batch axis is squeezed away. No truncation arguments are passed to the
    tokenizer, so inputs longer than the model's ``max_length`` are not clipped here.

    Args:
        subsequence: nucleotide string to embed; must be non-empty.
        embed_model: callable model mapping a (1, n_tokens) LongTensor to a tensor
            whose dim 1 is the token axis. Defaults to the notebook-level ``model``.
        embed_tokenizer: tokenizer whose call result has an ``"input_ids"`` entry.
            Defaults to the notebook-level ``tokenizer``.
        target_device: device for the token tensor. Defaults to the notebook-level
            ``device``.

    Returns:
        A 1-D tensor of the model's per-token outputs averaged over the sequence.
        Its length is the model's output width (presumably the hidden size of the
        HyenaDNA checkpoint — confirm). It is left on whatever device the model
        produced it on.

    Raises:
        ValueError: if ``subsequence`` is empty (the mean would otherwise cover
            only special tokens, if any, and silently produce a meaningless row).
    """
    if not subsequence:
        raise ValueError("Cannot embed an empty subsequence.")

    if embed_model is None:
        embed_model = model
    if embed_tokenizer is None:
        embed_tokenizer = tokenizer
    if target_device is None:
        target_device = device

    token_ids = embed_tokenizer(subsequence)["input_ids"]
    # Wrap in an outer list to add the batch dimension of size 1.
    tokens = torch.tensor([token_ids], dtype=torch.long, device=target_device)

    with torch.inference_mode():
        embeddings = embed_model(tokens)

    # Average over the token axis, then drop the batch axis.
    return embeddings.mean(dim=1).squeeze(0)

### Main process: load the HyenaDNA model and embed each record

In [None]:
# HyenaDNA checkpoint used for the embeddings. Alternative checkpoint names:
#   'hyenadna-small-32k-seqlen', 'hyenadna-medium-450k-seqlen', 'hyenadna-large-1m-seqlen'
pretrained_model_name = 'hyenadna-medium-160k-seqlen'

# get_model_tokenizer_maxlen is not defined in this notebook — presumably from a %run helper.
model, tokenizer, max_length = get_model_tokenizer_maxlen(pretrained_model_name)
model.to(device)
# Evaluation mode for inference; the trailing ';' suppresses the large module repr.
model.eval();

In [None]:
# Prefix for the embedding CSV written by the main loop (csv_Filename).
datafile = "methylation"

In [None]:
# Input BED file (filename suggests GEO sample GSM6637962, CpGs with coverage >= 20, GRCh38).
data_filename = '../../datasets/task05-methylation/GSM6637962_CpG_coverage20_GRCh38.bed.gz'
# preprocess_datafile is not defined in this notebook — presumably from a %run helper.
df = preprocess_datafile(data_filename)

# The embedding loop reads these columns; fail here rather than partway through it.
required_columns = {'CHROM', 'START', 'SIZE', 'y'}
missing_columns = required_columns.difference(df.columns)
assert not missing_columns, f"preprocess_datafile output is missing columns: {sorted(missing_columns)}"

df

In [None]:
%%time

csv_Filename =datafile + '_hyena_embedding.csv'

if os.path.exists(csv_Filename):
    os.remove(csv_Filename)


rows=[]
for index, row in df.iterrows():      
    chrom=row['CHROM']
    pos_start=row['START']

    if pos_start<=1:
        pos_start=1
    y=row['y']
    length = row['SIZE'] # max_length
    
    subsequence = get_subsequence(chrom, pos_start, length)
    if 'N' in subsequence:
        print("The character 'N' is present in the string.")
        
    embedding = Subsequence2Embedding(subsequence)
    # print(embedding.shape)

    # feature=np.array(embedding_df.iloc[64])
    rows.append(np.append(embedding.cpu().numpy(),  [y])) # chrom,  length,  comp[ref],comp[alt],

    if index > 0 and (index % 5000) == 0:
        append_rows_to_csv(csv_Filename, rows)
        rows=[]
        print (f"index = {index} completed")
        
append_rows_to_csv(csv_Filename, rows)

print(f"Create File: "+csv_Filename)

### Reload the embedding CSV file

In [None]:
# Reload the embeddings from disk and display them; the walrus both rebinds df
# and makes the expression the cell's displayed result.
# load_embedding_file is not defined in this notebook — presumably from a %run helper.
(df := load_embedding_file(csv_Filename))