# data loading

In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


from xgboost import XGBRegressor
from sklearn.ensemble import RandomForestRegressor

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Select the compute device: the CUDA GPU when torch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device


In [0]:
# Read the headerless, tab-separated target table; the first column becomes the index.
df_age = pd.read_csv("../data/age.csv", sep='\t', header=None, index_col=0)
# Collapse all remaining values into a flat 1-D target array.
y = df_age.to_numpy().flatten()

y.shape

In [0]:
# Stratification labels: bucket the target into bins of width 10,
# then fold bin 9 into bin 8.
y_class = y // 10
y_class = np.where(y_class == 9, 8, y_class)

In [0]:
# Visual check of how many samples fall into each stratification class.
sns.histplot(y_class)

In [0]:

# Split row positions (not the data itself) 80/20, stratified on y_class with a
# fixed random_state; get_data() later indexes every X representation with
# these same positions.
train_idx, test_idx = train_test_split(range(len(y)), test_size=0.2, stratify=y_class, random_state=42)  # split the data once so that index keeps the same for different types of X
# Peek at the first few held-out positions.
test_idx[:10]

[4436, 2292, 4448, 4903, 2378, 842, 2625, 3097, 4898, 1911]

# model and training functions

In [0]:
# data type specific parameters
def get_params(X_type):
    """Return the data-type specific settings as ``(scale_data, ae_loss_function)``.

    Args:
        X_type: one of 'abundance', 'log' or 'presence'.

    Returns:
        scale_data: bool flag that is passed on to get_data().
        ae_loss_function: a freshly constructed torch loss module.

    Raises:
        ValueError: if X_type is not one of the three known data types.
    """
    if X_type == 'abundance':
        # Summed (not averaged) squared error.
        return True, nn.MSELoss(reduction='sum')

    if X_type == 'log':
        # TODO: switch to reduction='sum'; 'none' yields an element-wise loss.
        return True, nn.MSELoss(reduction='none')

    if X_type == 'presence':
        # Presence data is already 0/1, so it is not scaled; use a binary loss.
        # TODO: try focal loss.
        return False, nn.BCELoss(reduction='sum')

    raise ValueError("Invalid character for data type")


In [0]:
def get_data(x_type, y, scale_data):
    """Load the feature matrix for ``x_type`` and split it with the global indices.

    Rows are aligned to ``df_age.index`` and then split using the module-level
    ``train_idx`` / ``test_idx`` positions.

    Args:
        x_type: 'log', 'abundance' or 'presence'. 'presence' is derived from the
            abundance table by thresholding at > 0.
        y: target array indexed with the same positions as the feature rows.
        scale_data: when True and x_type is 'abundance', features are min-max
            scaled with a scaler fitted on the training rows only. For 'log'
            the flag currently has no effect (scaling is disabled).

    Returns:
        (X_train, X_test, y_train, y_test)

    Raises:
        ValueError: if x_type is not one of the three known data types.
    """
    source_files = {
        'log': "../data/processed_log_abundance.csv",
        'abundance': "../data/processed_abundance.csv",
        'presence': "../data/processed_abundance.csv",
    }
    if x_type not in source_files:
        raise ValueError("Invalid character for data type")

    table = pd.read_csv(source_files[x_type], sep='\t', header=0, index_col=0)
    X = table.loc[df_age.index, :].to_numpy()
    if x_type == 'presence':
        # Binarise: 1 where the abundance is positive, 0 elsewhere.
        X = (X > 0).astype(int)

    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    if scale_data and x_type == 'abundance':
        # Fit on train only, then apply to test.
        # TODO: check if this is the right scaling.
        scaler = MinMaxScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

    return X_train, X_test, y_train, y_test


def get_dataloader(X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor, batch_size=64):
    """Wrap the train/test tensors in DataLoaders.

    Args:
        X_train_tensor, y_train_tensor: training features and targets.
        X_test_tensor, y_test_tensor: held-out features and targets.
        batch_size: number of samples per batch for both loaders.

    Returns:
        (train_loader, test_loader): the training loader shuffles its samples,
        the test loader keeps them in order.
    """
    train_loader = DataLoader(
        TensorDataset(X_train_tensor, y_train_tensor),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        TensorDataset(X_test_tensor, y_test_tensor),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


In [0]:
# Define models

# todo  add noise (e.g., zero out random values) to the input during training and train the network to reconstruct the original data. This encourages the model to learn robust features despite the sparse noise.


# # todo track reconstruction error (e.g., MSE, BCE) only on the non-zero entries
# reconstruction_loss = nn.BCELoss(reduction='none')(decoded, inputs)
# non_zero_mask = inputs > 0  # Mask to focus only on non-zero entries
# loss = reconstruction_loss * non_zero_mask

class ShallowAutoencoder(nn.Module):
    """Single-layer autoencoder with two decoders and a regression head.

    All three heads read the same latent code:
      * decoder1: linear reconstruction of the input values,
      * decoder2: sigmoid output in [0, 1], used as a per-feature presence probability,
      * regression_head: non-negative scalar prediction (ReLU output).
    """

    def __init__(self, input_dim, latent_dim):
        super(ShallowAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.LeakyReLU()
        )
        self.decoder1 = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
        )
        # Maps latent -> input space. The previous Linear(input_dim, latent_dim)
        # had its dimensions swapped and failed on the latent code whenever
        # input_dim != latent_dim.
        self.decoder2 = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
            nn.Sigmoid()  # presence
        )
        self.regression_head = nn.Sequential(
            nn.Linear(latent_dim, 1),
            nn.ReLU()
        )

    def forward(self, x):
        """Return ``(encoded, decoded1, decoded2, regression_output)`` for a batch ``x``."""
        encoded = self.encoder(x)
        decoded1 = self.decoder1(encoded)
        decoded2 = self.decoder2(encoded)
        regression_output = self.regression_head(encoded)
        return encoded, decoded1, decoded2, regression_output

class DeepAutoencoder(nn.Module):
    """Three-layer autoencoder with two mirrored decoders and a regression head.

    Layer widths shrink input_dim -> input_dim//2 -> input_dim//4 -> latent_dim
    in the encoder and grow back in each decoder. decoder1 ends in a plain
    linear layer, decoder2 ends in a sigmoid, and the regression head maps the
    latent code to one ReLU-clipped value.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        half = input_dim // 2
        quarter = input_dim // 4

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, half), nn.LeakyReLU(),
            nn.Linear(half, quarter), nn.LeakyReLU(),
            nn.Linear(quarter, latent_dim), nn.LeakyReLU(),
        )
        # Value reconstruction (unbounded output).
        self.decoder1 = nn.Sequential(
            nn.Linear(latent_dim, quarter), nn.LeakyReLU(),
            nn.Linear(quarter, half), nn.LeakyReLU(),
            nn.Linear(half, input_dim),
        )
        # Same shape as decoder1, squashed to [0, 1] by the final sigmoid.
        self.decoder2 = nn.Sequential(
            nn.Linear(latent_dim, quarter), nn.LeakyReLU(),
            nn.Linear(quarter, half), nn.LeakyReLU(),
            nn.Linear(half, input_dim),
            nn.Sigmoid(),
        )
        self.regression_head = nn.Sequential(nn.Linear(latent_dim, 1), nn.ReLU())

    def forward(self, x):
        """Return ``(encoded, decoded1, decoded2, regression_output)`` for a batch ``x``."""
        z = self.encoder(x)
        values = self.decoder1(z)
        probabilities = self.decoder2(z)
        prediction = self.regression_head(z)
        return z, values, probabilities, prediction
    

class ShallowVAE(nn.Module):
    """Variational autoencoder with one hidden layer, two decoders and a regression head.

    The encoder produces ``mu`` and ``logvar``; the latent code is sampled with
    the reparameterisation trick and fed to:
      * decoder1: linear reconstruction of the input values,
      * decoder2: sigmoid output in [0, 1], used as a per-feature presence probability,
      * regression_head: non-negative scalar prediction (ReLU output).
    """

    def __init__(self, input_dim, latent_dim):
        super(ShallowVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, input_dim//2),
            nn.LeakyReLU(),
        )
        self.mu = nn.Linear( input_dim//2, latent_dim)
        self.logvar = nn.Linear( input_dim//2, latent_dim)
        self.decoder1 = nn.Sequential(
            nn.Linear(latent_dim, input_dim//2),
            nn.LeakyReLU(),
            nn.Linear(input_dim//2, input_dim),
        )
        self.decoder2 = nn.Sequential(
            nn.Linear(latent_dim, input_dim//2),
            nn.LeakyReLU(),
            nn.Linear(input_dim//2, input_dim),
            # Bound the output to [0, 1]: it is consumed as a probability by
            # binary cross-entropy, which rejects values outside that range.
            # This matches decoder2 of the other models in this notebook.
            nn.Sigmoid()
        )
        self.regression_head = nn.Sequential(
            nn.Linear(latent_dim, 1),
            nn.ReLU()
        )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(encoded, decoded1, decoded2, regression_output, mu, logvar)`` for a batch ``x``."""
        h = self.encoder(x)
        mu, logvar = self.mu(h), self.logvar(h)
        encoded = self.reparameterize(mu, logvar)
        decoded1 = self.decoder1(encoded)
        decoded2 = self.decoder2(encoded)
        regression_output = self.regression_head(encoded)
        return encoded, decoded1, decoded2, regression_output, mu, logvar




class DeepVAE(nn.Module):
    """Variational autoencoder with two hidden layers, two decoders and a regression head.

    The encoder narrows input_dim -> input_dim//2 -> input_dim//4; two linear
    layers then give ``mu`` and ``logvar`` of size latent_dim. decoder1 ends in
    a plain linear layer, decoder2 ends in a sigmoid, and the regression head
    maps the sampled latent code to one ReLU-clipped value.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        half = input_dim // 2
        quarter = input_dim // 4

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, half), nn.LeakyReLU(),
            nn.Linear(half, quarter), nn.LeakyReLU(),
        )
        self.mu = nn.Linear(quarter, latent_dim)
        self.logvar = nn.Linear(quarter, latent_dim)
        # Value reconstruction (unbounded output).
        self.decoder1 = nn.Sequential(
            nn.Linear(latent_dim, quarter), nn.LeakyReLU(),
            nn.Linear(quarter, half), nn.LeakyReLU(),
            nn.Linear(half, input_dim),
        )
        # Same shape as decoder1, squashed to [0, 1] by the final sigmoid.
        self.decoder2 = nn.Sequential(
            nn.Linear(latent_dim, quarter), nn.LeakyReLU(),
            nn.Linear(quarter, half), nn.LeakyReLU(),
            nn.Linear(half, input_dim),
            nn.Sigmoid(),
        )
        self.regression_head = nn.Sequential(nn.Linear(latent_dim, 1), nn.ReLU())

    def reparameterize(self, mu, logvar):
        """Draw a latent sample ``mu + noise * sigma`` with ``sigma = exp(0.5 * logvar)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(encoded, decoded1, decoded2, regression_output, mu, logvar)`` for a batch ``x``."""
        hidden = self.encoder(x)
        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        z = self.reparameterize(mu, logvar)
        values = self.decoder1(z)
        probabilities = self.decoder2(z)
        prediction = self.regression_head(z)
        return z, values, probabilities, prediction, mu, logvar




In [0]:
from torch.nn.functional import binary_cross_entropy, mse_loss
def get_losses(y_true, X_true, regression_output, latent, presence, reconstructed, ae_loss_function, reg_loss_function):
    """Compute the latent L1 penalty and the two reconstruction losses for one batch.

    Args:
        y_true, regression_output, ae_loss_function, reg_loss_function:
            accepted for interface compatibility; not used here.
        X_true: input batch; entries equal to 0 are treated as "absent".
        latent: latent code of the batch.
        presence: per-entry presence probabilities in [0, 1], same shape as X_true.
        reconstructed: reconstructed values, same shape as X_true.

    Returns:
        (loss_l1, loss_non_zero, loss_presence) where
          * loss_l1 is the mean absolute latent activation,
          * loss_non_zero is the MSE restricted to the non-zero entries of
            X_true (0 when the batch has no non-zero entry),
          * loss_presence is the BCE between ``presence`` and the 0/1
            indicator ``X_true != 0``.
    """
    loss_l1 = torch.mean(torch.abs(latent))  # regularization term

    mask_non_zero = X_true != 0
    if mask_non_zero.any():
        loss_non_zero = mse_loss(reconstructed[mask_non_zero], X_true[mask_non_zero])
    else:
        # mse_loss over an empty selection is NaN and would poison the total loss.
        loss_non_zero = reconstructed.new_zeros(())

    # BCE needs a target in [0, 1]: compare against the presence indicator, not
    # the raw values (which can be any real number for log/abundance data).
    presence_target = mask_non_zero.to(presence.dtype)
    loss_presence = binary_cross_entropy(presence, presence_target)
    return loss_l1, loss_non_zero, loss_presence


    

In [0]:
# Training function

def _forward_losses(model, model_name, X_batch, y_batch, ae_loss_function, reg_loss_function, lambda_ae, lambda_reg, alpha_l1):
    """Run one forward pass; return (loss_non_zero, loss_presence, loss_reg, combined_loss, y_pred)."""
    if model_name.endswith('VAE'):
        latent, reconstructed, presence, regression_output, mu, logvar = model(X_batch)
        # NOTE(review): the KL term is summed over the batch while the other
        # terms are means, so their scales differ — confirm this is intended.
        loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    else:
        latent, reconstructed, presence, regression_output = model(X_batch)
        loss_kl = 0.0

    loss_l1, loss_non_zero, loss_presence = get_losses(y_batch, X_batch, regression_output, latent, presence, reconstructed, ae_loss_function, reg_loss_function)
    # Same formula for training and validation (alpha_l1 defaults to 0).
    loss_ae = 0.5 * loss_non_zero + 0.5 * loss_presence + loss_kl + alpha_l1 * loss_l1

    # Flatten both sides: a (B, 1) prediction against a (B,) target would
    # broadcast to (B, B) inside the loss and give a wrong value.
    y_pred = regression_output.reshape(-1)
    loss_reg = reg_loss_function(y_pred, y_batch.reshape(-1))

    combined_loss = lambda_ae * loss_ae + lambda_reg * loss_reg
    return loss_non_zero, loss_presence, loss_reg, combined_loss, y_pred


def _run_epoch(model, model_name, loader, ae_loss_function, reg_loss_function, lambda_ae, lambda_reg, alpha_l1, optimizer=None):
    """Run one pass over ``loader`` (with weight updates when ``optimizer`` is given); return per-batch averages."""
    totals = {'non_zero': 0.0, 'presence': 0.0, 'reg': 0.0, 'combined': 0.0, 'r2': 0.0}
    for X_batch, y_batch in loader:
        if optimizer is not None:
            optimizer.zero_grad()

        loss_non_zero, loss_presence, loss_reg, combined_loss, y_pred = _forward_losses(
            model, model_name, X_batch, y_batch, ae_loss_function, reg_loss_function, lambda_ae, lambda_reg, alpha_l1)

        if optimizer is not None:
            combined_loss.backward()
            optimizer.step()

        totals['non_zero'] += loss_non_zero.item()
        totals['presence'] += loss_presence.item()
        totals['reg'] += loss_reg.item()
        totals['combined'] += combined_loss.item()
        # Per-batch R2, later averaged over batches (not the R2 of the full set).
        totals['r2'] += r2_score(y_batch.cpu().detach().numpy(), y_pred.cpu().detach().numpy())

    n_batches = max(len(loader), 1)  # avoid dividing by zero on an empty loader
    return {key: total / n_batches for key, total in totals.items()}


def train_model(model, model_name, train_loader, test_loader, ae_loss_function, optimizer, reg_loss_function, lambda_ae, lambda_reg, alpha_l1 = 0, num_epochs=50, patience=10):
    """Train ``model`` for ``num_epochs`` epochs and record per-epoch metrics.

    Args:
        model: network whose forward returns 4 values, or 6 (with mu, logvar)
            when ``model_name`` ends with 'VAE'.
        model_name: label used in the progress message and to detect VAEs.
        train_loader, test_loader: loaders yielding (X_batch, y_batch).
        ae_loss_function: forwarded to get_losses().
        optimizer: optimizer over the model's parameters.
        reg_loss_function: loss applied to the flattened regression output.
        lambda_ae, lambda_reg: weights of the autoencoder and regression terms.
        alpha_l1: weight of the latent L1 penalty inside the autoencoder term.
        num_epochs: number of epochs to run.
        patience: accepted for interface compatibility; early stopping is
            currently disabled, so the value is not used.

    Returns:
        Eight lists with one entry per epoch, in this order: train non-zero
        loss, train presence loss, train regression loss, train R2, then the
        same four for validation. Each entry is the mean over the batches.
    """
    lst_train_loss_non_zero = []
    lst_train_loss_presence = []
    lst_train_loss_reg = []
    lst_train_r2 = []

    lst_val_loss_non_zero = []
    lst_val_loss_presence = []
    lst_val_loss_reg = []
    lst_val_r2 = []

    for epoch in range(num_epochs):
        model.train()
        train_stats = _run_epoch(model, model_name, train_loader, ae_loss_function, reg_loss_function, lambda_ae, lambda_reg, alpha_l1, optimizer=optimizer)
        lst_train_loss_non_zero.append(train_stats['non_zero'])
        lst_train_loss_presence.append(train_stats['presence'])
        lst_train_loss_reg.append(train_stats['reg'])
        lst_train_r2.append(train_stats['r2'])

        # Validation pass: no gradient tracking, no weight updates.
        model.eval()
        with torch.no_grad():
            val_stats = _run_epoch(model, model_name, test_loader, ae_loss_function, reg_loss_function, lambda_ae, lambda_reg, alpha_l1, optimizer=None)
        lst_val_loss_non_zero.append(val_stats['non_zero'])
        lst_val_loss_presence.append(val_stats['presence'])
        lst_val_loss_reg.append(val_stats['reg'])
        lst_val_r2.append(val_stats['r2'])

        if epoch % 5 == 0:
            train_loss = train_stats['combined']
            val_loss = val_stats['combined']
            print(f'{model_name} Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    return lst_train_loss_non_zero, lst_train_loss_presence, lst_train_loss_reg, lst_train_r2, lst_val_loss_non_zero, lst_val_loss_presence, lst_val_loss_reg, lst_val_r2

# Run experiments: configuration, training and evaluation

In [0]:
# parameters to define before each experiment

x_type = 'log'  # 'abundance' or 'presence' or 'log'
latent_dim = 20  # García-Jiménez et al. 2021 used latent_dim 10 to represent 717 taxa https://academic.oup.com/bioinformatics/article/37/10/1444/5988714
# Weights of the autoencoder and regression terms in train_model's combined loss.
lambda_ae = 1
# With lambda_ae = 1 this evaluates to 0, so the regression loss does not
# contribute to the combined training loss.
lambda_reg = 1 - lambda_ae
# NOTE(review): the training cell does not pass alpha_l1 to train_model, so
# train_model falls back to its own default — confirm whether that is intended.
alpha_l1 = 0.0
num_epochs= 200
# True division: this is a float (50.0 for 200 epochs). Only relevant to early
# stopping, which is disabled in train_model.
patience = num_epochs/4

In [0]:
# Resolve the data-type settings and load the fixed train/test split.
scale_data, ae_loss_function = get_params(x_type)
X_train, X_test, y_train, y_test = get_data(x_type, y, scale_data)
# Turn every split into a float32 tensor living on the selected device.
X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor = (
    torch.tensor(split, dtype=torch.float32).to(device)
    for split in (X_train, X_test, y_train, y_test)
)
train_loader, test_loader = get_dataloader(X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor, batch_size=64)
# Number of input features, used to size the models.
input_dim = X_train.shape[1]

In [0]:
# Inspect the shape of the held-out target tensor.
y_test_tensor.shape

In [0]:

# Left panel: distribution of every test-set value.
plt.subplot(1, 2, 1)
sns.histplot(X_test.ravel())
plt.title('Test Set - Original Distribution')

# Right panel: the same distribution restricted to strictly positive values.
plt.subplot(1, 2, 2)
sns.histplot(X_test[X_test > 0])
plt.title('Test Set - Original Distribution > 0')


In [0]:
# Number of batches per training epoch.
len(train_loader)

In [0]:
# Sanity check: one forward pass of an untrained ShallowVAE on a single training
# batch, printing every loss component.
#
# The cell is self-contained so it survives Restart & Run All. It previously
# used `optimizer`, `model_name` and `reg_loss_function` before any cell defined
# them, never moved the model to `device`, and looped over all epochs without
# ever updating the weights.
model_name = "ShallowVAE"
model = ShallowVAE(input_dim, latent_dim).to(device)
reg_loss_function = nn.MSELoss()

model.eval()
with torch.no_grad():
    X_batch, y_batch = next(iter(train_loader))
    latent, reconstructed, presence, regression_output, mu, logvar = model(X_batch)
    loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss_l1, loss_non_zero, loss_presence = get_losses(y_batch, X_batch, regression_output, latent, presence, reconstructed, ae_loss_function, reg_loss_function)

print("!!!!!!!!!!!!!check loss !!!!!!!!!!!!!!!!!!")
print('loss_kl', loss_kl)
print('loss_presence',loss_presence)
print('loss_non_zero',loss_non_zero)
print('loss_l1',loss_l1)

In [0]:



# Train each architecture in turn, plot its reconstruction of the test set, and
# fit an XGBoost regressor on its latent features.
# `models` and `model_names` are parallel lists; the names ending in 'VAE' make
# train_model expect the 6-value forward output.
models = [

    ShallowVAE(input_dim, latent_dim),
    DeepVAE(input_dim, latent_dim),
    ShallowAutoencoder(input_dim, latent_dim),
    DeepAutoencoder(input_dim, latent_dim),

]

model_names = [

    "ShallowVAE",
    "DeepVAE",
    "ShallowAutoencoder",
    "DeepAutoencoder",
]


# Per-model training curves and test-set predictions, keyed by model name.
dct_history = dict()
dct_y_pred = dict()

# 5 x 2 grid: the first row shows the original test data, then one row per model.
# `i` is the running 1-based subplot index.
plt.figure(figsize=(18, 15))
i = 1
plt.subplot(5, 2, i)
sns.histplot(X_test.flatten())
plt.title('Test Set - Original Distribution')
i+=1
plt.subplot(5, 2, i )
sns.histplot(X_test.flatten()[X_test.flatten()>0])
plt.title('Test Set - Original Distribution > 0')
i+=1

# Note: `model` and `model_name` are rebound at notebook scope by this loop.
for model, model_name in zip(models, model_names):
    
    print(f"Training {model_name}")

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # training
    # alpha_l1 is not passed here, so train_model uses its default for it.
    model_history = train_model(model, model_name, train_loader, test_loader, ae_loss_function, optimizer=optimizer,reg_loss_function=nn.MSELoss(), num_epochs=num_epochs, patience=patience, lambda_ae=lambda_ae, lambda_reg=lambda_reg)

    # # Save the best model
    # torch.save(best_model, f"model/{model_name}_best_model.pth")

    ## Save latent representations
    # model.load_state_dict(best_model)
    print('model evaluating')
    model.eval()
    with torch.no_grad():
        # Forward output order: [0] latent code, [1] decoder1 reconstruction,
        # [2] decoder2 output, [3] regression output.
        results = model(X_test_tensor)
        X_train_latent = model(X_train_tensor)[0].cpu().detach().numpy()
        X_val_latent = results[0].cpu().detach().numpy()

    # Distribution of the decoder1 reconstruction of the test set: all values,
    # then only the strictly positive ones.
    plt.subplot(5, 2, i)
    sns.histplot(results[1].cpu().numpy().flatten())
    plt.title(f'{model_name} - Reconstructed Distribution')
    i+=1
    plt.subplot(5, 2, i )
    sns.histplot(results[1].cpu().numpy().flatten()[results[1].cpu().numpy().flatten()>0])
    plt.title(f'{model_name} - Reconstructed Distribution > 0')
    i+=1

    # Unpack in the exact order train_model returns its eight lists.
    lst_train_loss_non_zero, lst_train_loss_presence, lst_train_loss_reg, lst_train_r2, lst_val_loss_non_zero, lst_val_loss_presence, lst_val_loss_reg, lst_val_r2 = model_history
    
    # Log losses
    dct_history[model_name] = {
            "train_loss_non_zero": np.array(lst_train_loss_non_zero),
            "train_loss_presence": np.array(lst_train_loss_presence),
            "train_loss_reg": np.array(lst_train_loss_reg),
            "train_r2": np.array(lst_train_r2),
            "val_loss_non_zero": np.array(lst_val_loss_non_zero),
            "val_loss_presence": np.array(lst_val_loss_presence),
            "val_loss_reg": np.array(lst_val_loss_reg),
            "val_r2": np.array(lst_val_r2)
        }
    
    # log predicted values
    # Regression-head output for the test set.
    dct_y_pred[model_name] = results[3].cpu().detach().numpy()

    print("prediction using embedding by", model_name)

    # Train XGBoost model on latent features
    xgb_model = XGBRegressor(n_estimators=500, learning_rate=0.05, max_depth=6)
    xgb_model.fit(X_train_latent, y_train)  # Use the latent features as input for regression

    # Predict on the validation set
    dct_y_pred[model_name + '_xgb'] =  xgb_model.predict(X_val_latent)

plt.tight_layout();

In [0]:
# Debug inspection of the `loss_l1` left at notebook scope. train_model keeps
# its own loss_l1 local, so this shows the value assigned by an earlier
# top-level cell, not a value from the training runs above.
loss_l1.cpu()

In [0]:
# Training curves: one row per model, reconstruction losses on the left and the
# regression loss on the right.
# dct_history stores the two reconstruction components separately
# ('*_loss_non_zero' and '*_loss_presence'). The 'train_loss_ae' / 'val_loss_ae'
# keys read here before are never stored and raised KeyError.
plt.figure(figsize=(14, 10))
n_rows = len(dct_history)
for i, (model_name, history) in enumerate(dct_history.items()):
    plt.subplot(n_rows, 2, 2 * i + 1)
    plt.plot(history['train_loss_non_zero'], '-', label='Train non-zero MSE', color='blue', alpha=0.5)
    plt.plot(history['val_loss_non_zero'], '--', label='Validation non-zero MSE', color='red', alpha=0.5)
    plt.plot(history['train_loss_presence'], '-', label='Train presence BCE', color='green', alpha=0.5)
    plt.plot(history['val_loss_presence'], '--', label='Validation presence BCE', color='orange', alpha=0.5)
    plt.title(f'{model_name} - Reconstruction Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Reconstruction Loss')
    plt.legend(loc='best')
    plt.grid(True)

    plt.subplot(n_rows, 2, 2 * i + 2)
    plt.plot(history['train_loss_reg'], '-', label='Train', color='blue', alpha=0.5)
    plt.plot(history['val_loss_reg'], '--', label='Validation', color='red', alpha=0.5)
    plt.title(f'{model_name} - Regression Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Regression Loss')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid(True)

# Lay out once, after every panel has been drawn.
plt.tight_layout();

In [0]:
# Predicted-vs-true scatter plots on the test set, one panel per entry of
# dct_y_pred (each model's regression head and the XGBoost model fitted on its
# latent features), annotated with R^2 and MSE.
plt.figure(figsize=(10, 12))


for i, (model_name, y_pred) in enumerate(dct_y_pred.items()):
    
    # Regression-head predictions are stored as (n, 1); flatten them to 1-D.
    y_pred = y_pred.squeeze()

    plt.subplot(4, 2, i + 1)

    plt.scatter(y_test, y_pred, alpha=0.5, s=6)
    # Dashed red identity line: where a perfect prediction would fall.
    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)

    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title(f'{model_name} lambdaAE={lambda_ae}, lambdaReg={lambda_reg}')


    # mask = (y_pred <= 100) & (y_pred >= 0)
    # r2 = 1 - (np.sum((y_test[mask] - y_pred[mask]) ** 2) / np.sum((y_test[mask] - np.mean(y_test[mask])) ** 2))
    r2 = r2_score(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    # Place the metrics in the top-left corner, in axes-relative coordinates.
    plt.text(0.05, 0.95, f'R^2: {r2:.2f}\nMSE: {mse:.2f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top')

    plt.tight_layout()