In [1]:
import torch
import pickle
from Bio import SeqIO
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# Load the gLM2 (650M) tokenizer and model from the Hugging Face hub.
# trust_remote_code=True lets transformers execute the modelling/tokenizer
# code hosted in the hub repository instead of a built-in class.
model_name = 'tattabio/gLM2_650M'
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
pretrained_model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
)
# Move all weights onto the default CUDA device (Module.cuda returns self).
pretrained_model = pretrained_model.cuda()

def embed_sequence(sequence):
    """Return gLM2 final-layer token embeddings for one sequence.

    Args:
        sequence: a single string in gLM2 input format; in this notebook a
            strand token ('<+>' or '<->') followed by lowercase nucleotides.

    Returns:
        The model's ``last_hidden_state`` for a batch of one, as a NumPy array
        on the CPU. Presumably shaped (1, num_tokens, hidden_dim) -- confirm
        against the gLM2 model card.
    """
    encodings = tokenizer([sequence], return_tensors='pt')
    # Send inputs to wherever the model lives rather than hard-coding CUDA,
    # so the function keeps working if the model is loaded on another device.
    input_ids = encodings.input_ids.to(pretrained_model.device)
    with torch.no_grad():
        # Only the final layer is used, so do not request (and keep in GPU
        # memory) the hidden states of every intermediate layer.
        output = pretrained_model(input_ids)
    return output.last_hidden_state.cpu().numpy()

# Embed every contig twice with gLM2: once with the '<+>' strand token
# prepended and once with '<->'. Each pass is pickled to its own file
# ({record_id: embedding array}) and also kept in embedding_dict[direction].
direction_dict = {'fwd': '+', 'rev': '-'}
embedding_dict = {}

# Parse the FASTA once and reuse it for both directions. A later duplicate
# record id overwrites an earlier one, as with plain dict assignment.
raw_seqs = {
    record.id: str(record.seq).lower()
    for record in SeqIO.parse('contigs_coverage_95.fn', 'fasta')
}

for direction, symbol in direction_dict.items():
    # NOTE(review): the 'rev' pass only changes the strand token; the
    # nucleotides are not reverse-complemented -- confirm this matches
    # the input format gLM2 expects.
    seq_dict = {rec_id: f'<{symbol}>{seq}' for rec_id, seq in raw_seqs.items()}

    embeddings = {
        rec_id: embed_sequence(seq) for rec_id, seq in tqdm(seq_dict.items())
    }

    with open(f'contigs_coverage_95_{direction}_embeddings_gLM2_650M.pkl', 'wb') as file:
        pickle.dump(embeddings, file)

    embedding_dict[direction] = embeddings

100%|██████████| 405/405 [02:27<00:00,  2.75it/s]
100%|██████████| 405/405 [02:27<00:00,  2.75it/s]
