In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig, logging
import numpy as np
from tqdm import tqdm
from biom import load_table, Table

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the merged BIOM table from disk; the path is relative to the
# current working directory.
table = load_table("data/input/merged_biom_table.biom")
# Bare expression: when this is the last statement of a notebook cell its
# value (the ids along the "observation" axis) is displayed. It has no other
# effect.
table.ids(axis="observation")

In [None]:
# Load the configuration, tokenizer and weights published under the
# "PoetschLab/GROVER" identifier.
# NOTE(review): from_pretrained presumably downloads from the Hugging Face Hub
# on first use and reads a local cache afterwards — confirm for offline runs.
config = BertConfig.from_pretrained("PoetschLab/GROVER")
tokenizer = AutoTokenizer.from_pretrained("PoetschLab/GROVER")
# Move the model onto `device` so it sits on the same device as the input ids
# that calc_embedding_mean sends there.
model = AutoModel.from_pretrained("PoetschLab/GROVER", config=config).to(device)

def calc_embedding_mean(asvs):
    """Return one mean-pooled model embedding per input sequence.

    Each sequence is tokenized and passed through the module-level ``model``
    on its own (one sequence per forward pass), so no padding is involved.
    The first element of the model output is averaged over axis 1 and the
    per-sequence results are stacked row-wise.

    Args:
        asvs: Iterable of strings accepted by the module-level ``tokenizer``
            (presumably DNA sequences of the ASVs — confirm against caller).

    Returns:
        numpy array with one row per input sequence, i.e. shape
        ``[len(asvs), E]`` assuming the first model output has shape
        ``[1, N, E]`` (tokens x hidden size) — TODO confirm.
        NOTE(review): the mean runs over every token id the tokenizer
        emits, which presumably includes special tokens such as
        [CLS]/[SEP].

    Raises:
        ValueError: If ``asvs`` is empty (``np.vstack`` needs at least one
            array).
    """
    pooled = []
    # Inference only: without no_grad every forward pass would record an
    # autograd graph that is thrown away immediately, wasting memory.
    with torch.no_grad():
        # Tokenize and embed one sequence at a time so only a single input
        # tensor has to live on the device at any moment.
        for asv in asvs:
            input_ids = tokenizer(asv, return_tensors="pt")["input_ids"].to(device)
            token_states = model(input_ids)[0].cpu().numpy()
            pooled.append(np.mean(token_states, axis=1))
    return np.vstack(pooled)

In [None]:
# Fetch the observation-axis ids once; the same array is embedded and then
# stored next to the embeddings so rows can be matched back to their ids.
asv_ids = table.ids(axis="observation")
embeddings = calc_embedding_mean(asv_ids)
print(embeddings.shape)

# Write both arrays to the current working directory.
np.save("asv_embeddings.npy", embeddings)
np.save("asv_embedding_ids.npy", asv_ids)