### This notebook computes DNABERT-2 embeddings for bacterial DNA sequences

In [None]:
import gc
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers.models.bert.configuration_bert import BertConfig
# Register tqdm's pandas integration so Series.progress_apply (used for the
# embedding step) displays a progress bar.
tqdm.pandas() 

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
def clean_gpu():
    """Free unreferenced Python objects, then release PyTorch's cached GPU memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are destroyed and their blocks return to PyTorch's caching
    allocator. ``torch.cuda.empty_cache()`` then releases the unused cached
    blocks. In the reverse order, memory freed by the collection would stay
    in the cache. ``empty_cache`` does nothing if CUDA was never initialized,
    so this is safe on CPU-only machines.
    """
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Collect unreferenced objects and release cached GPU memory before loading the model.
clean_gpu()

In [None]:
# Single source of truth for the Hugging Face checkpoint, so the tokenizer,
# config and model cannot drift apart when the checkpoint is changed.
model_name = "zhihan1996/DNABERT-2-117M"

# trust_remote_code=True allows the checkpoint's own modeling/tokenizer code to run.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the config explicitly as a BertConfig and pass it to AutoModel.
# NOTE(review): this appears to follow the DNABERT-2 README's loading snippet —
# confirm it is still needed for the installed transformers version.
config = BertConfig.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, config=config)

In [None]:
# Move the model's parameters to the selected device (cuda:0 if available, else cpu).
model.to(device)

In [None]:
base_dir = '/sternadi/home/volume1/yuval/courses/NLP/NLP_Proj'
data_dir = f'{base_dir}/data'
# glob returns matches in filesystem-dependent order; sort them so the
# concatenated DataFrame (and the saved CSV) has a reproducible row order.
data_files = sorted(glob.glob(f'{data_dir}/*.csv'))

In [None]:
# Create a df containing DNA sequences from all genes
dna_df = pd.DataFrame()
for file in data_files:
    data = pd.read_csv(file)
    gene_name = file.split('/')[-1].split('.')[0]
    data = data.drop(columns='protein_seq')
    data['gene'] = gene_name
    dna_df = pd.concat([dna_df, data], ignore_index=True)

# Cleanup
dna_df.rename(columns={' organism': 'organism', ' strain': 'strain'}, inplace=True)
dna_df = dna_df[~dna_df['organism'].isna()]
dna_df = dna_df[~dna_df['strain'].isna()]

In [None]:
def get_embedding(dna_seq):
    """Return the mean-pooled DNABERT-2 embedding of a single DNA sequence.

    The sequence is tokenized, run through the model without gradient
    tracking, and the last hidden states are averaged over all token
    positions (including any special tokens the tokenizer adds).

    Args:
        dna_seq: DNA sequence string.

    Returns:
        1-D float tensor of length hidden_size (768 per the original shape
        comment), on the CPU.
    """
    with torch.no_grad():
        input_ids = tokenizer(dna_seq, return_tensors='pt')["input_ids"].to(device)
        hidden_states = model(input_ids)[0]  # [1, sequence_length, 768]
        embedding_mean = torch.mean(hidden_states[0], dim=0)
    # Move the result off the GPU immediately. Storing one device tensor per
    # DataFrame row would make GPU memory grow with the number of sequences.
    return embedding_mean.cpu()

dna_df['embeddings_tensor'] = dna_df['dna_seq'].progress_apply(get_embedding)

In [None]:
def _tensor_to_numpy(value):
    """Return ``value`` as a NumPy array when it is a torch tensor; otherwise return it unchanged."""
    if not isinstance(value, torch.Tensor):
        return value
    return value.cpu().numpy()


# NumPy copy of each embedding, plus a JSON list string that survives a CSV round trip.
dna_df['embeddings_np'] = dna_df['embeddings_tensor'].apply(_tensor_to_numpy)
dna_df['embeddings_json'] = dna_df['embeddings_np'].apply(
    lambda arr: json.dumps(arr.tolist())
)

In [None]:
# Save only serializable columns. 'embeddings_tensor' and 'embeddings_np' would
# be written via their str() repr, a reduced-precision duplicate of
# 'embeddings_json' that cannot be reliably parsed back into numbers.
dna_df.drop(columns=['embeddings_tensor', 'embeddings_np']).to_csv('dna_embeddings.csv')