In [2]:
import pandas as pd
import numpy as np
import torch
import warnings
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [15]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Load tokenizer and model from the same checkpoint, then move the model
# onto the selected device.
grover_checkpoint = "PoetschLab/GROVER"
tokenizer = AutoTokenizer.from_pretrained(grover_checkpoint)
model = AutoModel.from_pretrained(grover_checkpoint)
model = model.to(device)
print("Model successfully loaded.")

# Read the CSV and keep every column except 'sample' as a NumPy array.
data_df = pd.read_csv('data/fecal_data.csv')
data = (
    data_df
    .drop(columns=['sample'])
    .to_numpy()
)
print("Data successfully loaded.")

Using device: cuda
Model successfully loaded.
Data successfully loaded.


In [18]:
def calc_embedding_mean(seq):
    """Return the mean-pooled embedding of a single sequence.

    The sequence is tokenized (truncated to at most 512 tokens), passed
    through the module-level ``model``, and the first element of the model
    output is averaged over its leading (token) axis.

    Args:
        seq: Input handed to the module-level ``tokenizer``; in this
            notebook it is one entry of ``data``.

    Returns:
        torch.Tensor: The averaged vector, located on ``device`` and not
        attached to any autograd graph.
    """
    input_ids = tokenizer(
        seq,
        return_tensors='pt',
        max_length=512,
        truncation=True,
    )["input_ids"].to(device)

    # Inference only: without no_grad every call would record an autograd
    # graph and keep its activations alive until the result is detached.
    with torch.no_grad():
        output = model(input_ids)

    # presumably the last hidden state (Hugging Face convention) — confirm
    hidden_states = output[0]

    # hidden_states[0] selects the single batch item; dim=0 then averages
    # over all of its positions.
    # NOTE(review): any special tokens the tokenizer inserts are part of
    # this average — confirm that is intended.
    embedding_mean = torch.mean(hidden_states[0], dim=0)

    return embedding_mean

In [19]:
# Embed every entry of every row of `data`; embeddings_list[i][j] holds the
# pooled vector for data[i][j] as a NumPy array.
embeddings_list = []
for sample in tqdm(data, desc="samples"):
    sample_embeddings = []
    # leave=False clears each finished inner bar, so the cell output is not
    # flooded with one completed progress line per sample.
    for seq in tqdm(sample, desc="sequences", leave=False):
        sample_embeddings.append(calc_embedding_mean(seq).detach().cpu().numpy())
    embeddings_list.append(sample_embeddings)

# Build one contiguous ndarray first: torch.tensor() on nested lists of
# ndarrays is very slow and PyTorch emits a warning about it.
embeddings = torch.from_numpy(np.asarray(embeddings_list))
print("Embeddings successfully loaded.")

100%|██████████| 562/562 [00:04<00:00, 114.29it/s]
100%|██████████| 562/562 [00:04<00:00, 115.87it/s]
100%|██████████| 562/562 [00:05<00:00, 107.76it/s]
100%|██████████| 562/562 [00:05<00:00, 102.47it/s]
100%|██████████| 562/562 [00:05<00:00, 105.96it/s]
100%|██████████| 562/562 [00:05<00:00, 107.34it/s]
100%|██████████| 562/562 [00:04<00:00, 112.47it/s]
100%|██████████| 562/562 [00:05<00:00, 102.97it/s]
100%|██████████| 562/562 [00:05<00:00, 106.37it/s]
100%|██████████| 562/562 [00:05<00:00, 110.97it/s]
100%|██████████| 562/562 [00:04<00:00, 114.17it/s]
100%|██████████| 562/562 [00:05<00:00, 102.92it/s]
100%|██████████| 562/562 [00:05<00:00, 101.05it/s]
100%|██████████| 562/562 [00:05<00:00, 103.58it/s]
100%|██████████| 562/562 [00:04<00:00, 114.74it/s]
100%|██████████| 562/562 [00:05<00:00, 99.49it/s] 
100%|██████████| 562/562 [00:05<00:00, 108.86it/s]
100%|██████████| 562/562 [00:05<00:00, 108.60it/s]
100%|██████████| 562/562 [00:05<00:00, 110.41it/s]
100%|██████████| 562/562 [00:05

Embeddings successfully loaded.


In [20]:
# Sanity check on the stacked result; the recorded run printed
# torch.Size([60, 562, 768]).
# presumably (rows of `data`, entries per row, embedding width) — confirm
print(embeddings.shape)

torch.Size([60, 562, 768])


In [22]:
# Show the tensor values; PyTorch abbreviates large tensors with "...".
# NOTE(review): the recorded output shows identical trailing rows within
# each sample — possibly placeholder entries in the input CSV; confirm.
print(embeddings)

tensor([[[-0.3025,  0.6759, -0.5520,  ..., -0.3827,  0.6943,  0.7613],
         [-0.4812,  0.6448, -0.9720,  ..., -0.2833,  0.7118,  0.4289],
         [-0.3506,  0.8393, -0.8018,  ..., -0.2215,  0.6070,  0.3619],
         ...,
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517],
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517],
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517]],

        [[-0.3506,  0.8393, -0.8018,  ..., -0.2215,  0.6070,  0.3619],
         [-0.2762,  0.5872, -0.7504,  ..., -0.1395,  0.6457,  0.8284],
         [-0.3451,  0.5209, -0.8811,  ..., -0.3598,  0.7210,  1.0067],
         ...,
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517],
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517],
         [-0.8737,  0.0957,  0.0357,  ..., -0.3207, -0.3994, -0.1517]],

        [[-0.4812,  0.6448, -0.9720,  ..., -0.2833,  0.7118,  0.4289],
         [-0.3506,  0.8393, -0.8018,  ..., -0