In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig, logging
import numpy as np
from tqdm import tqdm
from biom import load_table, Table

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
table = load_table("data/input/merged_biom_table.biom")
table.ids(axis="observation")

array(['AAAAAGAGAGATGAGATTGAGGCTGGGAAAAGTTACTGTAGCCGACGTTTTGGCGGCGCAACCTGTGACGACAAATCTGCTCAAATTTATGCGCGCTTCGATAAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGAC',
       'AAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATCTTGATAAAGCAGGAATTACTACTGCTTGTTTACGAATTAAATCGAAGTGGACTGC',
       'AAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATCTTGATAAAGCAGGAATTACTACTGCTTGTTTACGAATTAAATCGAAGTGGACTGCT',
       ...,
       'TTTTCTCATTTTCCGCCAGCAGTCCACTTCGATTTAATTCGTAAACAAGCAGTAGTAATTCCTGCTTTATCAAGATAATTTTTCGACTCATCAGAAATATCCGAAAGTGTTAACTTCTGCGTCATGGAAGCGATAAAACTCTGCAGGTTG',
       'TTTTGGCGGCGCAACCTGTGACGACAAATCTGCTCAAATTTATGCGCGCTTCGATAAAAATGATTGGCGTATCCAACCTGCAGAGTTTTATCGCTTCCATGACGCAGAAGTTAACACTTTCGGATATTTCTGATGAGTCGAAAAATTATC',
       'TTTTTCGACTCATCAGAAATATCCGAAAGTGTTAACTTCTGCGTCATGGAAGCGATAAAACTCTGCAGGTTGGATACGCCAATCATTTTTATCGAAGCGCGCATAAATTTGAGCAGATTTGTCGTCACAGGTTGCGCCGCCAAAACGTCG'],
      dtype='<U150')

In [11]:
config = BertConfig.from_pretrained('https://raw.githubusercontent.com/jerryji1993/DNABERT/master/src/transformers/dnabert-config/bert-config-5/config.json', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_5", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNA_bert_5", trust_remote_code=True, config=config).to(device)

def calc_embedding_mean(asvs):
    '''
    input: asv
    returns: [B, A, E]
    '''
    inputs = [tokenizer(asv, return_tensors = 'pt')["input_ids"].to(device) for asv in asvs]
    hidden_states = [model(input).last_hidden_state for input in inputs] # shape: [B x A, N, E]
    embedding_mean = [torch.mean(byte_pair, dim=1) for byte_pair in hidden_states] # embedding with mean pooling
    return torch.concat(embedding_mean, dim=0).cpu().detach().numpy()

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Some weights of the model checkpoint at zhihan1996/DNA_bert_5 were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
embeddings = []
for asv in tqdm(table.ids(axis="observation")):
    embeddings.append(calc_embedding_mean([asv]))
embeddings = np.array(embeddings).reshape((61974, 768))

100%|██████████| 61974/61974 [11:25<00:00, 90.46it/s]


In [13]:
# run if shape (61974, 1, 768)
embeddings = np.squeeze(embeddings, axis=1)

In [14]:
embeddings.shape # should be (61974, 768)

(61974, 768)

In [15]:
embeddings

array([[ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005],
       [ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005],
       [ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005],
       ...,
       [ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005],
       [ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005],
       [ 0.09152081, -0.17413566,  1.4005626 , ...,  0.2815851 ,
         0.24107593,  0.25032005]], dtype=float32)

In [16]:
np.save("asv_embeddings.txt", embeddings)