In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Tuple
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.utils.data import TensorDataset
from scipy.spatial.distance import pdist, squareform
from skbio.stats.ordination import pcoa    

from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
from torchmetrics.regression import MeanSquaredError, MeanAbsolutePercentageError

import pandas as pd
import skbio
from skbio import TreeNode
from io import StringIO
from ete3 import Tree
from skbio import diversity 
from skbio.diversity.beta import unweighted_unifrac, weighted_unifrac
from scipy.spatial.distance import braycurtis
import warnings
from collections import Counter
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from io import StringIO
from ete3 import Tree
import phylodm
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray import train, tune
from ray.tune import ResultGrid
from scipy.spatial import procrustes
from torch.distributions.normal import Normal


#https://medium.com/@hunter-j-phillips/the-embedding-layer-27d9c980d124

In [3]:
%cd ..

/Users/zkarwowska/TomaszLab Dropbox/Zuzanna Karwowsk/My Mac (zkarwowska’s MacBook Pro)/Desktop/microbiome_gpt


In [4]:
taxonomy = pd.read_csv('inputs/taxonomy.csv', index_col= [0], low_memory=False).fillna(0).sort_index()

metadata = pd.read_csv('inputs/metadata.csv', index_col= [0], low_memory=False).sort_index()
metadata = metadata[metadata['sample_id'].isin(taxonomy.index)]

metadata['SICK'] = np.where(metadata.disease == 'healthy', 0, 1)

In [5]:
## Filter low abundance bacteria

def filter_prevalence(df, treshold = 0.5):
    '''features as columns'''
    df_binary = df.copy()
    df_binary[df_binary>0]=1
    df_binary_sum = df_binary.sum(axis=0)
    
    keep_features = df_binary_sum[df_binary_sum > df.shape[0]*treshold].index
    filtered_df = df[keep_features]
    
    return filtered_df

def filter_on_abundance(df, abundance_treshold = 1e-2):
    '''features as columns'''
    df_relab = df.div(df.sum(axis=1), axis=0)
    df_relab_mean = df_relab.mean()

    keep_features = df_relab_mean[df_relab_mean > abundance_treshold].index
    filtered_df = df[keep_features]
    
    return filtered_df

filtered_taxonomy = filter_on_abundance(filter_prevalence(taxonomy))

In [6]:
filtered_taxonomy = filtered_taxonomy.sample(50)
filtered_taxonomy.shape

(50, 30)

In [47]:
embedding_dim = 8
num_bacteria = filtered_taxonomy.shape[1]
batch_size = filtered_taxonomy.shape[0]

bacteria_tensor = torch.tensor(filtered_taxonomy.values).float()

# Krok 1: Zmieniamy kształt tensora na (B, #Bac, 1)
bacteria_tensor = bacteria_tensor.unsqueeze(-1)
print('bacteria_tensor', bacteria_tensor.shape)

# Krok 2: Tworzenie "bacteria encoding" jako `(Bact, D)`
bacteria_embedding = nn.Embedding(num_bacteria, embedding_dim)
bacteria_indices = torch.arange(num_bacteria)
bacteria_encoded = bacteria_embedding(bacteria_indices)
print('bacteria_encoded', bacteria_encoded.shape)

# Krok 3
linear_layer = nn.Linear(1, embedding_dim)
# Przekształcenie tensora (B, #Bac, 1) do (B, #Bac, D)
bacteria_transformed = linear_layer(bacteria_tensor)
print('bacteria_transformed', bacteria_transformed.shape)


# Krok 4: Dodanie nowego wymiaru na początku, aby uzyskać `(1, #Bac, D)`
bacteria_encoded_expanded = bacteria_encoded.unsqueeze(0)
print('bacteria_encoded_expanded', bacteria_encoded_expanded.shape)

# Dodaj krok 4 do sekwencji z 3
bacteria_encoded_broadcasted = bacteria_encoded_expanded.expand(bacteria_transformed.shape)
result = bacteria_encoded_broadcasted + bacteria_transformed
print('result', result.shape)

# Krok 5: wrzuć do transformera
encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

# Transponowanie result do kształtu (sequence_length, batch_size, embedding_dim)
# Ponieważ Transformer oczekuje danych w formacie (S, N, E)
result_transposed = result.permute(1, 0, 2)
# Przejście przez 3 warstwy transformera
output = transformer_encoder(result_transposed)
print('output', output.shape)

# Uśrednij
mean_vector = output.mean(dim=0)
print('mean_vector', mean_vector.shape)

# Mean and Logvar
mean_layer = nn.Linear(embedding_dim, embedding_dim)
mu_layer = nn.Linear(embedding_dim, embedding_dim)

# Obliczanie mean i mu
mean = mean_layer(mean_vector)
logvar = mu_layer(mean_vector)
print('mean', mean.shape)

# Reparametrize
def reparameterize(mean, log_var):
    batch, dim = mean.shape
    epsilon = Normal(0, 1).sample((batch, dim)).to(mean.device)
    return mean + torch.exp(0.5 * log_var) * epsilon

z = reparameterize(mean, logvar)
print('z', z.shape)

#Weź sekwencję z latentu i przy pomocy Linear przekształć ją do wektora D-wymiarowego
linear_transform = nn.Linear(embedding_dim, embedding_dim)
D_vector = linear_transform(z)
print('D_vector', D_vector.shape)

# Powtórz ten wektor #Bac razy dostając sekwencję wymiaru (B, #Bac, D)
D_vector_expanded = D_vector.unsqueeze(1)
D_vector_repeated = D_vector_expanded.repeat(1, num_bacteria, 1)
print('D_vector_repeated', D_vector_repeated.shape)

#Dodaj “bacteria encoding”
result = D_vector_repeated + bacteria_encoded
print('result', result.shape)

# Użycie 3 warstw transformera (decoder)
decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=4)

transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

memory = torch.randn(batch_size, num_bacteria, embedding_dim)
memory = memory.permute(1, 0, 2)
result_with_encoding = result.permute(1, 0, 2)

print(memory.shape, result.shape)

output = transformer_decoder(result_with_encoding, memory)
print("output", output.shape)

output = output.permute(1, 0, 2)
print("output", output.shape)

linear_transform = nn.Linear(embedding_dim, 1)
output_transformed = linear_transform(output)

print(output_transformed.shape)

true_values = bacteria_tensor

criterion = nn.MSELoss()
optimizer = optim.Adam(linear_transform.parameters(), lr=0.001)
loss = criterion(output_transformed, true_values)

# Backpropagation i optymalizacja
optimizer.zero_grad()
loss.backward()
optimizer.step()

loss.item()

bacteria_tensor torch.Size([10, 30, 1])
bacteria_encoded torch.Size([30, 8])
bacteria_transformed torch.Size([10, 30, 8])
bacteria_encoded_expanded torch.Size([1, 30, 8])
result torch.Size([10, 30, 8])
output torch.Size([30, 10, 8])
mean_vector torch.Size([10, 8])
mean torch.Size([10, 8])
z torch.Size([10, 8])
D_vector torch.Size([10, 8])
D_vector_repeated torch.Size([10, 30, 8])
result torch.Size([10, 30, 8])
torch.Size([30, 10, 8]) torch.Size([10, 30, 8])
output torch.Size([30, 10, 8])
output torch.Size([10, 30, 8])
torch.Size([10, 30, 1])


2017444102144.0

In [48]:
true_values.shape, output_transformed.shape

(torch.Size([10, 30, 1]), torch.Size([10, 30, 1]))

In [59]:
ytrue = true_values.reshape((true_values.shape[0], true_values.shape[1]))
ypred = output_transformed.reshape((output_transformed.shape[0], output_transformed.shape[1]))

In [65]:
ytrue[0]

tensor([ 229247., 2239335., 4142136.,  319497., 3837933.,  268772.,  137676.,
        1711264.,  215320.,  660408., 1624201., 1260792., 1965534.,  814927.,
         216326., 1865244.,  474209.,    9577.,   63730., 3128068.,  117583.,
         554078.,       0.,  148992.,  737692.,   45610.,  577982.,  230069.,
         351512.,  789446.])

In [66]:
ypred[0]

tensor([-1.0118, -0.4562, -0.1482, -0.2330, -0.9547, -0.8142, -0.6284, -0.8120,
        -0.2323,  0.2082, -0.5217, -0.5674, -0.1278,  0.4950, -0.1186, -1.3166,
        -0.0483, -0.8770, -0.3322,  0.3068, -0.0199,  0.6789, -0.2523, -0.8710,
        -0.3243, -0.2478, -0.2636,  0.4446, -0.8622, -0.4481],
       grad_fn=<SelectBackward0>)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

class BacteriaModel(nn.Module):
    def __init__(self, num_bacteria, embedding_dim):
        super(BacteriaModel, self).__init__()
        self.embedding_dim = embedding_dim
        
        # Bacteria encoding
        self.bacteria_embedding = nn.Embedding(num_bacteria, embedding_dim)
        
        # Linear layer to transform (B, #Bac, 1) to (B, #Bac, D)
        self.linear_layer = nn.Linear(1, embedding_dim)
        
        # Transformer encoder
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        
        # Mean and Logvar layers
        self.mean_layer = nn.Linear(embedding_dim, embedding_dim)
        self.mu_layer = nn.Linear(embedding_dim, embedding_dim)
        
        # Linear transformation for latent vector
        self.linear_transform = nn.Linear(embedding_dim, embedding_dim)
        
        # Transformer decoder
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=4)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        
        # Final linear layer to transform (B, #Bac, D) to (B, #Bac, 1)
        self.output_transform = nn.Linear(embedding_dim, 1)
    
    def reparameterize(self, mean, log_var):
        batch, dim = mean.shape
        epsilon = Normal(0, 1).sample((batch, dim)).to(mean.device)
        return mean + torch.exp(0.5 * log_var) * epsilon
    
    def forward(self, bacteria_tensor):
        batch_size = bacteria_tensor.size(0)
        num_bacteria = bacteria_tensor.size(1)
        
        # Step 1: Transform (B, #Bac, 1) to (B, #Bac, D)
        bacteria_transformed = self.linear_layer(bacteria_tensor)
        
        # Step 2: Create bacteria encoding
        bacteria_indices = torch.arange(num_bacteria).to(bacteria_tensor.device)
        bacteria_encoded = self.bacteria_embedding(bacteria_indices)
        
        # Expand bacteria encoding to match batch size and add to bacteria_transformed
        bacteria_encoded_expanded = bacteria_encoded.unsqueeze(0).expand(batch_size, -1, -1)
        result = bacteria_transformed + bacteria_encoded_expanded
        
        # Step 3: Pass through Transformer encoder
        result_transposed = result.permute(1, 0, 2)  # (S, N, E)
        output = self.transformer_encoder(result_transposed)
        
        # Step 4: Average over sequence (mean along the sequence dimension)
        mean_vector = output.mean(dim=0)
        
        # Step 5: Compute mean and logvar
        mean = self.mean_layer(mean_vector)
        logvar = self.mu_layer(mean_vector)
        
        # Step 6: Reparameterize to get latent vector z
        z = self.reparameterize(mean, logvar)
        
        # Step 7: Transform z to (B, D)
        D_vector = self.linear_transform(z)
        
        # Step 8: Repeat D_vector for each bacterium to match (B, #Bac, D)
        D_vector_expanded = D_vector.unsqueeze(1).repeat(1, num_bacteria, 1)
        
        # Add bacteria encoding again
        result_with_encoding = D_vector_expanded + bacteria_encoded_expanded
        
        # Step 9: Pass through Transformer decoder
        memory = torch.randn(batch_size, num_bacteria, self.embedding_dim).to(bacteria_tensor.device)
        memory = memory.permute(1, 0, 2)  # (S, N, E)
        result_with_encoding = result_with_encoding.permute(1, 0, 2)  # (S, N, E)
        output = self.transformer_decoder(result_with_encoding, memory)
        
        # Step 10: Transform output to (B, #Bac, 1)
        output = output.permute(1, 0, 2)  # Back to (B, #Bac, D)
        output_transformed = self.output_transform(output)
        
        return output_transformed, mean, logvar

In [None]:
# Model instantiation
filtered_taxonomy = filtered_taxonomy.div(filtered_taxonomy.sum(axis=1), axis=0)
embedding_dim = 8
num_bacteria = filtered_taxonomy.shape[1]  # replace with filtered_taxonomy.shape[1]
batch_size = filtered_taxonomy.shape[0]  # replace with filtered_taxonomy.shape[0]

model = BacteriaModel(num_bacteria, embedding_dim)
bacteria_tensor = torch.tensor(filtered_taxonomy.values).float()
bacteria_tensor = bacteria_tensor.unsqueeze(-1)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    output_transformed, mean, logvar = model(bacteria_tensor)
    
    loss = criterion(output_transformed, bacteria_tensor)
    
    loss.backward()
    optimizer.step()
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')


In [86]:
ytrue = bacteria_tensor.reshape((bacteria_tensor.shape[0], bacteria_tensor.shape[1]))
ypred = output_transformed.reshape((output_transformed.shape[0], output_transformed.shape[1]))

In [87]:
ytrue[0]

tensor([0.0080, 0.0779, 0.1441, 0.0111, 0.1336, 0.0094, 0.0048, 0.0595, 0.0075,
        0.0230, 0.0565, 0.0439, 0.0684, 0.0284, 0.0075, 0.0649, 0.0165, 0.0003,
        0.0022, 0.1089, 0.0041, 0.0193, 0.0000, 0.0052, 0.0257, 0.0016, 0.0201,
        0.0080, 0.0122, 0.0275])

In [88]:
ypred[0]

tensor([ 0.2368,  0.3632,  0.3669,  0.1195,  0.6309,  0.3628,  0.2577, -0.1238,
         0.5537,  0.2589,  0.3329,  0.3159,  0.6331,  0.4046,  0.3246, -0.0170,
         0.2125,  0.2085,  0.2566,  0.5124,  0.5993,  0.1942,  0.3124,  0.5302,
         0.2581,  0.2395,  0.4017,  0.4363,  0.4361,  0.4310],
       grad_fn=<SelectBackward0>)