# Transformer for OTU tables
> Let's use NeuroSEED embeddings for classification, for instance

Written partially by ChatGPT: https://chat.openai.com/share/3c087924-3c9f-4d42-993f-69657a4afbfd

In [8]:
# Get torch to use GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [1]:
import torch
from torch import nn

class OTUTransformerClassifier(nn.Module):
    def __init__(self, nhead, nhid, nlayers, dropout=0.5, embedding_dim=128):
        super(OTUTransformerClassifier, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=nhead, dim_feedforward=nhid, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=nlayers
        )

        # Binary classification output layer
        self.fc = nn.Linear(embedding_dim, 1)

    def forward(self, X):
        # Transformer expects input of shape [seq_len, batch_size, embedding_dim].
        transformer_output = self.transformer_encoder(X.permute(1, 0, 2))

        # Use the last output in the sequence for each item in the batch for classification.
        transformer_output = transformer_output[-1, :, :]

        output = torch.sigmoid(self.fc(transformer_output))

        return output

In [2]:
# How to get embeddings working?

# nn.Embedding.from_pretrained() expects a tensor of shape [num_embeddings, embedding_dim] as input.

import pandas as pd

embeddings_path = "/home/phil/mixture_embeddings/data/processed/otu_embeddings/yatsunenko/cnn_hyperbolic_128_otu_embeddings.csv"
embeddings_df = pd.read_csv(embeddings_path, index_col=0)

embed = nn.Embedding.from_pretrained(torch.tensor(embeddings_df.values), freeze=True)
embed

Embedding(15783, 128)

In [3]:
def embed_by_name(name, embeddings=embeddings_df):
    try:
        return embed(torch.tensor(embeddings.index.get_loc(name)))
    except KeyError:
        return embed(torch.tensor(embeddings.index.get_loc(int(name))))

embed_by_name(179499)

tensor([ 1.3231e-03,  2.8251e-03, -6.2945e-03,  5.3259e-03, -2.5479e-03,
         3.6997e-03,  2.0449e-03, -6.1082e-04, -2.6830e-03, -1.7612e-03,
         1.8971e-04,  4.4580e-03,  9.2340e-04, -1.8812e-03, -2.9178e-03,
         1.4681e-02, -4.3525e-03, -5.4849e-03, -3.7471e-03,  4.0354e-03,
         6.3417e-03, -1.8050e-03, -2.4388e-03, -4.5531e-03, -2.9844e-04,
         2.0589e-03,  7.0418e-04,  4.9733e-04, -1.4950e-03, -9.9802e-04,
        -5.5532e-03, -2.6107e-03,  2.3284e-03, -3.1904e-03,  2.5297e-03,
        -2.7868e-03, -5.4514e-03, -1.7141e-02, -3.2478e-04,  3.7553e-03,
        -3.1140e-03, -1.1644e-03, -1.3271e-03, -7.3676e-03,  5.2530e-03,
         1.8106e-03, -2.7933e-04,  2.8870e-03, -1.6990e-03,  1.1634e-03,
         2.0539e-03,  5.3673e-03, -6.8478e-04, -6.0461e-03,  2.3066e-03,
         8.5648e-04,  1.4091e-03,  1.8466e-03,  1.5592e-03, -5.6731e-03,
        -7.9304e-04, -7.1623e-04,  1.0376e-02, -8.2456e-04, -1.3137e-03,
         1.7177e-03, -2.8253e-04,  5.4399e-04,  3.9

In [4]:
otu_table_path = "/home/phil/mixture_embeddings/data/interim/knight/yatsunenko_data.csv"
otu_table = pd.read_csv(otu_table_path, index_col=0)

otu_table.head()

Unnamed: 0_level_0,298804,179499,114400,1099710,1078207,1016598,1010876,1000592,1000113,100198,...,998383,997777,99793,998719,998524,999784,998905,998869,99960,99981
Sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Amz4adltF.418711,9,0,267,0,6,1,0,1,0,0,...,0,1,0,0,0,0,0,0,0,0
USygt25.F.418747,0,7,0,0,146,0,13,0,0,8,...,0,0,0,0,0,0,0,0,0,0
USygt27.M.418861,3,87,1757,0,77,0,26,0,0,2,...,0,0,0,0,0,0,0,0,0,0
Amz5eldF.418421,43,905,252,0,126,317,28,0,0,0,...,0,47,0,0,0,0,0,0,0,0
USygt52.M.418736,0,25,0,2,32,0,18,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
def embed_otu_vector(otu_table_row):
    # # Drop zero indices
    otu_table_row = otu_table_row[otu_table_row != 0]
    names = otu_table_row.index

    # Normalize otu_vector, return as tensor
    count_vector = torch.tensor(otu_table_row, dtype=torch.float32)
    count_vector = count_vector.flatten()
    count_vector /= count_vector.sum()

    # Get embeddings
    embeddings = [embed_by_name(name) for name in names]
    embeddings = torch.stack(embeddings)

    return embeddings, count_vector

otu_embeddings, count_vector = embed_otu_vector(otu_table.iloc[0, :])

In [6]:
# # Embed all rows of otu_table
# otu_tensor = torch.tensor(otu_table.values).float() # (n_samples, n_features)
# otu_tensor /= otu_tensor.sum(dim=1, keepdim=True) # Normalize by total counts in each sample
# embed_tensor = torch.stack([embed_by_name(name) for name in otu_table.columns]).float() # (n_features, embedding_dim)
# otu_tensor = otu_tensor.unsqueeze(-1) # (n_samples, n_features, 1)
# X = otu_tensor * embed_tensor # (n_samples, n_features, embedding_dim)
# # X = torch.einsum("ij,jk->ijk", [otu_tensor, embed_tensor]) # (n_samples, n_features, embedding_dim)

# The above made tensors that were too big, with sparse values. Instead, we can use variable-length sequences:
X = []
for sample in otu_table.index:
    sample_embeddings, sample_counts = embed_otu_vector(otu_table.loc[sample, :])
    X.append(sample_embeddings * sample_counts.unsqueeze(-1)) # Need to unsqueeze to make Hadamard product work
# Pad
X = nn.utils.rnn.pad_sequence(X, batch_first=True).float()

# Get labels
y = pd.read_csv("/home/phil/mixture_embeddings/data/interim/knight/yatsunenko_metadata.csv", index_col=0)["Sex"]
y = y.loc[otu_table.index]
labeler = {"female": 0, "male": 1, "unknown": -1}
y = torch.tensor(y.map(labeler).values).float()

# Drop samples without labels
X = X[y != -1]
y = y[y != -1]

In [13]:
# Assume otu_embeddings is a 2D tensor of shape [num_otus, embedding_dim].
from tqdm.notebook import tqdm
from torch.utils.data import TensorDataset, DataLoader

model = OTUTransformerClassifier(nhead=4, nhid=64, nlayers=3, embedding_dim=128).float()

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# Get everything on the GPU
model = model.to("cuda")
criterion = criterion.to("cuda")
X = X.to("cuda")
y = y.to("cuda")

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
dataset = TensorDataset(X, y)

# Example training loop
for epoch in range(10):
    losses = []
    pbar = tqdm(dataloader)
    for X_batch, y_batch in pbar:  # Assume dataloader yields count_matrix and labels
        optimizer.zero_grad()
        output = model(X_batch)
        loss = criterion(output.flatten(), y_batch)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        pbar.set_description(f"Epoch {epoch} loss: {sum(losses) / len(losses):.6f}")
        pbar.refresh()
    print(f"Epoch {epoch} loss: {sum(losses) / len(losses)}")

AttributeError: 'TensorDataset' object has no attribute 'to'

In [None]:
# Evaluate model - predict in batches

outputs = []
for X_batch, y_batch in tqdm(dataloader):
    output = model(X_batch)
    outputs.append(output)

outputs = torch.cat(outputs).flatten().numpy()
outputs

  0%|          | 0/238 [00:00<?, ?it/s]

: 

: 