In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig, logging
import numpy as np
from tqdm import tqdm
from biom import load_table, Table

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Select the compute device: prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Load the merged BIOM table from disk.
table = load_table("data/input/merged_biom_table.biom")
# Display the observation-axis IDs; the captured output shows them to be
# 150-character nucleotide strings (dtype '<U150').
table.ids(axis="observation")

array(['AAAAAGAGAGATGAGATTGAGGCTGGGAAAAGTTACTGTAGCCGACGTTTTGGCGGCGCAACCTGTGACGACAAATCTGCTCAAATTTATGCGCGCTTCGATAAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGAC',
       'AAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATCTTGATAAAGCAGGAATTACTACTGCTTGTTTACGAATTAAATCGAAGTGGACTGC',
       'AAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATCTTGATAAAGCAGGAATTACTACTGCTTGTTTACGAATTAAATCGAAGTGGACTGCT',
       ...,
       'TTTTCTCATTTTCCGCCAGCAGTCCACTTCGATTTAATTCGTAAACAAGCAGTAGTAATTCCTGCTTTATCAAGATAATTTTTCGACTCATCAGAAATATCCGAAAGTGTTAACTTCTGCGTCATGGAAGCGATAAAACTCTGCAGGTTG',
       'TTTTGGCGGCGCAACCTGTGACGACAAATCTGCTCAAATTTATGCGCGCTTCGATAAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATC',
       'TTTTTCGACTCATCAGAAATATCCGAAAGTGTTAACTTCTGCGTCATGGAAGCGATAAAACTCTGCAGGTTGGATACGCCAATCATTTTTATCGAAGCGCGCATAAATTTGAGCAGATTTGTCGTCACAGGTTGCGCCGCCAAAACGTCG'],
      dtype='<U150')

In [4]:
# Hugging Face checkpoint shared by the tokenizer, config and model; kept in one
# place so the three loads cannot drift apart.
MODEL_NAME = "zhihan1996/DNABERT-2-117M"

# NOTE(security): trust_remote_code=True executes Python code shipped in the model
# repository. Consider pinning a reviewed `revision=` so reruns load the same code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
config = BertConfig.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True, config=config).to(device)
# The model is only used for inference here; make evaluation mode (dropout off) explicit.
model.eval()

def calc_embedding_mean(asv):
    '''
    Embed ASV sequences with the loaded model and mean-pool over the token axis.

    input: asv - a list of sequence strings (a single string is also accepted
           and treated as a one-element list)
    returns: numpy array of shape [B, E], one mean-pooled embedding per
             sequence (B = number of sequences, E = embedding width)
    raises: ValueError if asv contains no sequences
    '''
    sequences = [asv] if isinstance(asv, str) else list(asv)
    if not sequences:
        raise ValueError("asv must contain at least one sequence")

    pooled = []
    # Inference only: skip autograd bookkeeping so no graph is retained per sequence.
    with torch.no_grad():
        for sequence in sequences:
            # Tokenize one sequence at a time so no padding tokens enter the mean.
            input_ids = tokenizer(sequence, return_tensors='pt')["input_ids"].to(device)  # expected [1, N]
            # The first model output holds the per-token hidden states; the second
            # (pooled/class-token) output is not used.
            token_embeddings = model(input_ids)[0]  # expected [1, N, E]
            pooled.append(torch.mean(token_embeddings, dim=1))  # mean pooling over tokens -> [1, E]
    return torch.concat(pooled, dim=0).cpu().detach().numpy()

In [6]:
asv_ids = table.ids(axis="observation")
# Each call is expected to return a single [1, E] row for the one ASV passed in;
# concatenating along axis 0 yields [n_asvs, E] without hard-coding the table
# size or the embedding width.
embeddings = np.concatenate(
    [calc_embedding_mean([asv]) for asv in tqdm(asv_ids)],
    axis=0,
)
assert embeddings.shape[0] == len(asv_ids), "expected exactly one embedding row per ASV"

100%|██████████| 61974/61974 [12:08<00:00, 85.08it/s]


In [26]:
# Drop the singleton middle axis only when it is present, i.e. shape (n, 1, E).
# An unconditional squeeze raises ValueError on an already 2-D (n, E) array,
# which would break a top-to-bottom run of the notebook.
if embeddings.ndim == 3 and embeddings.shape[1] == 1:
    embeddings = np.squeeze(embeddings, axis=1)

In [27]:
# Sanity check: expect one row per ASV and one column per embedding dimension.
embeddings.shape # should be (61974, 768)

(61974, 768)

In [29]:
# Display the embedding matrix (NumPy elides the middle of large arrays in the repr).
embeddings

array([[-0.08678621,  0.08558369,  0.09969272, ...,  0.04015661,
         0.05117398,  0.07039835],
       [-0.04607328, -0.0122479 , -0.00292744, ...,  0.0454605 ,
         0.14943917,  0.07452027],
       [-0.04895976, -0.03211617,  0.03563015, ...,  0.02124203,
         0.14967498,  0.0575194 ],
       ...,
       [-0.01771752, -0.0191946 ,  0.00114739, ...,  0.08808845,
         0.17157225,  0.0335262 ],
       [-0.0484022 ,  0.11651459,  0.04410456, ...,  0.06006808,
         0.05627226,  0.07800303],
       [-0.04873918,  0.14298637,  0.08974645, ...,  0.04729619,
         0.1111041 ,  0.07895355]], dtype=float32)

In [30]:
# Persist the embeddings in NumPy's binary .npy format.
# NOTE(review): np.save appends ".npy" when the file name lacks that extension, so this
# presumably writes "asv_embeddings.txt.npy" rather than a text file — confirm what
# downstream loaders expect before renaming.
np.save("asv_embeddings.txt", embeddings)