In [0]:
# from torch.utils.tensorboard import SummaryWriter
# log_dir = "runs/test"
# writer = SummaryWriter(log_dir = log_dir)

In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


from xgboost import XGBRegressor
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device


## Load data

In [0]:
# The age table is a headerless, tab-separated file whose first column is the row index.
df_age = pd.read_csv(
    "../data/age.csv",
    sep='\t',
    header=None,
    index_col=0,
)
# Collapse the remaining value column(s) into a flat 1-D target vector.
y = df_age.to_numpy().flatten()

y.shape

In [0]:
# Bucket the targets by integer division by 10, then merge bucket 9 into bucket 8.
# y_class is only a stratification label; y itself is left untouched.
y_class = y // 10
y_class = np.where(y_class == 9, 8, y_class)

In [0]:
# Show how many samples fall into each bucket of y_class.
sns.histplot(y_class)

In [0]:
# Split row positions (not the data itself) exactly once, stratified by y_class,
# so that every variant of X can be indexed with the same train/test partition.
train_idx, test_idx = train_test_split(
    range(len(y)),
    test_size=0.2,
    random_state=42,
    stratify=y_class,
)
test_idx[:10]

[4436, 2292, 4448, 4903, 2378, 842, 2625, 3097, 4898, 1911]

In [0]:
# Load the feature table and reorder its rows to follow the index of the age table.
X = (
    pd.read_csv("../data/processed_log_drop08_scaled.csv", sep='\t', header=0, index_col=0)
    .loc[df_age.index, :]
    .to_numpy()
)
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]


def _to_device_tensor(array):
    """Return `array` as a float32 tensor placed on the notebook-wide `device`."""
    return torch.tensor(array, dtype=torch.float32).to(device)


# Whole train/test arrays are converted once and kept on `device`.
X_train_tensor = _to_device_tensor(X_train)
X_test_tensor = _to_device_tensor(X_test)
y_train_tensor = _to_device_tensor(y_train)
y_test_tensor = _to_device_tensor(y_test)

# Pair features with targets.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Mini-batch iterators: training batches are reshuffled, test batches keep their order.
batch_size = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

# Number of input features, used later to size the models.
input_dim = X.shape[1]
print(X_train.max())
print(input_dim)

In [0]:
# Fraction of exactly-zero entries in the training matrix, followed by three
# ratios derived from it (zero:non-zero odds, 1/non-zero fraction, 1/zero fraction).
p_zero = np.mean(X_train == 0)
p_nonzero = 1 - p_zero
for ratio in (p_zero, p_zero / p_nonzero, 1 / p_nonzero, 1 / p_zero):
    print(ratio)

## Models

In [0]:
from models_3_heads import *

# TODO: add noise to the input during training (e.g., randomly zero out values) and train the network to reconstruct the original, clean data (denoising autoencoder). This encourages the model to learn features that stay robust under sparse noise.

## Loss Functions

In [0]:
# TODO: track the reconstruction error (e.g., MSE, BCE) only on the non-zero entries, e.g.:
# reconstruction_loss = nn.BCELoss(reduction='none')(decoded, inputs)
# non_zero_mask = inputs > 0  # mask to focus only on non-zero entries
# loss = reconstruction_loss * non_zero_mask

# POS_WEIGHT is handed to nn.BCEWithLogitsLoss as `pos_weight` inside get_losses.
# Alternatives that were tried are kept below for reference:
# POS_WEIGHT = torch.tensor([p_zero/(1-p_zero)]).to(device)  # theoretical weight 15.44
# POS_WEIGHT = torch.tensor([1000]).to(device)  # try an extreme value
# POS_WEIGHT = torch.tensor([(1-p_zero)/p_zero]).to(device)  # in case the ratio was inverted
POS_WEIGHT = torch.tensor([1]).to(device)  # no weight
print(POS_WEIGHT)

# Pair (zero fraction, non-zero fraction); only referenced by the commented-out
# nn.BCELoss(weight=WEIGHTS) variant in get_losses.
WEIGHTS = torch.tensor([p_zero, 1-p_zero]).to(device)

def get_losses(y_true, X_true, latent, presence, reconstructed):
    """Compute the auto-encoder loss terms for one batch.

    Args:
        y_true: Regression targets of the batch. Unused; kept so callers need not change.
        X_true: Input feature batch that the decoders should reproduce.
        latent: Latent representation. Unused while the L1 penalty is disabled.
        presence: Raw logits (pre-sigmoid) of the presence/absence head; must be
            broadcast-compatible with X_true (presumably the same shape).
        reconstructed: Output of the value head; must be broadcast-compatible
            with X_true (presumably the same shape).

    Returns:
        Tuple ``(loss_l1, loss_non_zero, loss_presence)``:
        - ``loss_l1``: constant 0, placeholder for the disabled L1 latent penalty.
        - ``loss_non_zero``: squared reconstruction error with zero entries of X_true
          masked out, averaged over *all* entries (zeros contribute 0 to the sum).
        - ``loss_presence``: mean BCE-with-logits between `presence` and the 0/1
          non-zero mask of X_true, weighted by the module-level POS_WEIGHT.
    """
    # 1.0 where the feature is present (non-zero), 0.0 where it is absent.
    mask_non_zero = (X_true != 0).float()

    squared_error = nn.MSELoss(reduction='none')(reconstructed, X_true)
    loss_non_zero = (squared_error * mask_non_zero).mean()

    # The presence head is a binary classifier, so its target must be the 0/1 mask.
    # Using the raw feature values as BCE targets (as before) trains the head to
    # regress magnitudes instead of predicting presence/absence.
    # TODO: consider focal loss instead of BCE for the class imbalance.
    presence_criterion = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT, reduction='none')
    loss_presence = presence_criterion(presence, mask_non_zero).mean()

    return 0, loss_non_zero, loss_presence


## Training Function

In [0]:
# Training function

def _run_epoch(model, model_name, loader, reg_loss_function, lambda_ae, lambda_reg, optimizer=None):
    """Run one pass over `loader` and return the per-batch averages of all tracked metrics.

    When `optimizer` is given, each batch is followed by a backward pass and an
    optimizer step; otherwise the pass is evaluation-only. The caller is responsible
    for `model.train()` / `model.eval()` and for wrapping evaluation in `torch.no_grad()`.

    Returns:
        dict with keys 'non_zero', 'presence', 'reg', 'combined', 'r2'.
    """
    totals = {'non_zero': 0.0, 'presence': 0.0, 'reg': 0.0, 'combined': 0.0, 'r2': 0.0}

    for X_batch, y_batch in loader:
        if optimizer is not None:
            optimizer.zero_grad()

        # Models whose name ends in 'VAE' additionally return (mu, logvar).
        if model_name.endswith('VAE'):
            latent, presence, reconstructed, regression_output, mu, logvar = model(X_batch)
            # KL term is summed over batch and latent dimensions (not averaged).
            loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        else:
            latent, presence, reconstructed, regression_output = model(X_batch)
            loss_kl = 0.0

        _, loss_non_zero, loss_presence = get_losses(y_batch, X_batch, latent, presence, reconstructed)
        # ALPHA / BETA are notebook-level globals weighting the two decoder losses.
        loss_ae = ALPHA * loss_presence + BETA * loss_non_zero + loss_kl

        # NOTE(review): the model definitions are not visible here. If the regression
        # head returns (batch, 1) while y_batch is (batch,), an element-wise loss such
        # as nn.MSELoss would broadcast to (batch, batch) and report a wrong value.
        # Aligning the target shape is a no-op when the shapes already agree.
        y_target = y_batch
        if y_batch.shape != regression_output.shape and y_batch.numel() == regression_output.numel():
            y_target = y_batch.reshape(regression_output.shape)
        loss_reg = reg_loss_function(regression_output, y_target)

        combined_loss = lambda_ae * loss_ae + lambda_reg * loss_reg

        if optimizer is not None:
            combined_loss.backward()
            optimizer.step()

        totals['non_zero'] += loss_non_zero.item()
        totals['presence'] += loss_presence.item()
        totals['reg'] += loss_reg.item()
        totals['combined'] += combined_loss.item()
        # R^2 is computed per batch and averaged below; this is not identical to the
        # R^2 over the whole split (small final batches weigh as much as full ones).
        totals['r2'] += r2_score(y_batch.detach().cpu().numpy(), regression_output.detach().cpu().numpy())

    n_batches = max(len(loader), 1)  # avoid dividing by zero for an empty loader
    return {name: total / n_batches for name, total in totals.items()}


def train_model(model, model_name, train_loader, test_loader, optimizer, reg_loss_function, lambda_ae, lambda_reg, alpha_l1 = 0, num_epochs=50, patience=10):
    """Train `model` for `num_epochs` epochs and record per-epoch train/validation metrics.

    The optimized objective is ``lambda_ae * loss_ae + lambda_reg * loss_reg`` where
    ``loss_ae = ALPHA * loss_presence + BETA * loss_non_zero + KL`` (KL only for models
    whose name ends in 'VAE'). ALPHA and BETA are read from notebook globals at call time.

    Args:
        model: Network returning (latent, presence, reconstructed, regression_output)
            and, for VAE variants, additionally (mu, logvar).
        model_name: Label used in progress messages; a trailing 'VAE' selects the VAE branch.
        train_loader: Iterable of (X_batch, y_batch) used for optimization.
        test_loader: Iterable of (X_batch, y_batch) used for validation metrics.
        optimizer: Optimizer already bound to `model`'s parameters.
        reg_loss_function: Callable ``(prediction, target) -> scalar tensor`` for the regression head.
        lambda_ae: Weight of the auto-encoder loss in the combined objective.
        lambda_reg: Weight of the regression loss in the combined objective.
        alpha_l1: Unused; kept for interface compatibility (L1 latent penalty is disabled).
        num_epochs: Number of full passes over `train_loader`.
        patience: Unused; kept for interface compatibility (early stopping is disabled).

    Returns:
        Tuple of eight lists with one entry per epoch, in this order:
        train non-zero loss, train presence loss, train regression loss, train R^2,
        validation non-zero loss, validation presence loss, validation regression loss,
        validation R^2.
    """
    lst_train_loss_non_zero = []
    lst_train_loss_presence = []
    lst_train_loss_reg = []
    lst_train_r2 = []

    lst_val_loss_non_zero = []
    lst_val_loss_presence = []
    lst_val_loss_reg = []
    lst_val_r2 = []

    for epoch in range(num_epochs):
        model.train()
        train_stats = _run_epoch(model, model_name, train_loader, reg_loss_function,
                                 lambda_ae, lambda_reg, optimizer=optimizer)

        model.eval()
        with torch.no_grad():
            val_stats = _run_epoch(model, model_name, test_loader, reg_loss_function,
                                   lambda_ae, lambda_reg, optimizer=None)

        lst_train_loss_non_zero.append(train_stats['non_zero'])
        lst_train_loss_presence.append(train_stats['presence'])
        lst_train_loss_reg.append(train_stats['reg'])
        lst_train_r2.append(train_stats['r2'])

        lst_val_loss_non_zero.append(val_stats['non_zero'])
        lst_val_loss_presence.append(val_stats['presence'])
        lst_val_loss_reg.append(val_stats['reg'])
        lst_val_r2.append(val_stats['r2'])

        train_loss_combined = train_stats['combined']
        val_combined_loss = val_stats['combined']
        if epoch % 5 == 0:
            print(f'{model_name} Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss_combined:.4f}, Val Loss: {val_combined_loss:.4f}')

        # TODO: early stopping on val_combined_loss using `patience` (keep a copy of the
        # best state_dict and stop after `patience` epochs without improvement).

    return lst_train_loss_non_zero, lst_train_loss_presence, lst_train_loss_reg, lst_train_r2, lst_val_loss_non_zero, lst_val_loss_presence, lst_val_loss_reg, lst_val_r2

# Start

In [0]:
# Hyper-parameters to set before each experiment.

x_type = 'log'  # 'abundance' or 'presence' or 'log'  # deprecated
latent_dim = 20 # García-Jiménez et al. 2021 used latent_dim 10 to represent 717 taxa https://academic.oup.com/bioinformatics/article/37/10/1444/5988714

num_epochs= 50
patience = num_epochs//4  # forwarded to train_model, which currently does not act on it (early stopping is disabled)

# Weights of the two decoder losses inside the auto-encoder loss
# (loss_ae = ALPHA * loss_presence + BETA * loss_non_zero + KL term).
# train_model reads both as globals at call time.
ALPHA = 0 # decoder1 - presence (0 removes the presence loss from the objective)
BETA = 1  # decoder2 - nonzero values

# Alternative: weight the two decoders by the zero fraction of the training data.
# ALPHA = 1 - p_zero # balanced by zero probability
# BETA = p_zero  

# combined_loss = lambda_ae * loss_ae + lambda_reg * loss_reg; with lambda_ae = 1 the
# regression loss gets weight 0 in the optimized objective.
lambda_ae = 1
lambda_reg = 1 - lambda_ae

ALPHA, BETA

In [0]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from sklearn.metrics import f1_score
import time

# Architectures under comparison. A name ending in 'VAE' makes train_model expect
# the two extra (mu, logvar) outputs from the model.
models = [
    ShallowVAE(input_dim, latent_dim),
    DeepVAE(input_dim, latent_dim),
    ShallowAutoencoder(input_dim, latent_dim),
    DeepAutoencoder(input_dim, latent_dim),
]
model_names = [
    "ShallowVAE",
    "DeepVAE",
    "ShallowAutoencoder",
    "DeepAutoencoder",
]

# Keys under which the eight per-epoch series returned by train_model are stored,
# listed in the same order as train_model returns them.
history_keys = (
    "train_loss_non_zero",
    "train_loss_presence",
    "train_loss_reg",
    "train_r2",
    "val_loss_non_zero",
    "val_loss_presence",
    "val_loss_reg",
    "val_r2",
)

dct_history = dict()
dct_y_pred = dict()

# 5 x 3 grid: the first row shows the test-set reference distributions, every
# following row shows one model (reconstruction, reconstruction > 0, presence).
x_test_flat = X_test.flatten()

plt.figure(figsize=(18, 15))
panel = 1
plt.subplot(5, 3, panel)
sns.histplot(x_test_flat, bins=100)
plt.xlim(-0.35, 1.5)
plt.ylim(0, 1.35e6)
plt.title('Test Set - Original Distribution')
panel += 1

plt.subplot(5, 3, panel)
plt.xlim(0, 1.5)
plt.ylim(0, 12000)
sns.histplot(x_test_flat[x_test_flat > 0])
plt.title('Test Set - Original Distribution > 0')
panel += 1

plt.subplot(5, 3, panel)
sns.histplot(x_test_flat > 0, bins=100)
plt.title('Test Set - Presence Distribution')
panel += 1


for model_name, model in zip(model_names, models):
    t0 = time.time()
    print(f"Training {model_name}")

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # --- training ---
    model_history = train_model(
        model,
        model_name,
        train_loader,
        test_loader,
        optimizer=optimizer,
        reg_loss_function=nn.MSELoss(),
        lambda_ae=lambda_ae,
        lambda_reg=lambda_reg,
        num_epochs=num_epochs,
        patience=patience,
    )

    # Saving / reloading a best checkpoint is disabled because train_model does
    # not return one:
    # torch.save(best_model, f"model/{model_name}_best_model.pth")
    # model.load_state_dict(best_model)
    t1 = time.time()
    print(f'\t\tmodel trained, time used = {t1 - t0} seconds')
    print('model evaluating')

    # --- inference on the full test and train tensors ---
    model.eval()
    with torch.no_grad():
        results = model(X_test_tensor)
        X_train_latent = model(X_train_tensor)[0].cpu().detach().numpy()

    t2 = time.time()
    print(f'\t\tmodel applied on test set, time used = {t2 - t1} seconds')

    # Convert the decoder outputs once and reuse them for all three panels.
    reconstructed_flat = results[2].cpu().numpy().flatten()
    presence_pred = nn.Sigmoid()(results[1]).cpu().numpy().flatten() > 0.5

    plt.subplot(5, 3, panel)
    sns.histplot(reconstructed_flat)
    plt.xlim(-0.35, 1.5)
    plt.title(f'{model_name} - Reconstructed Distribution')
    panel += 1

    plt.subplot(5, 3, panel)
    plt.xlim(0, 1.5)
    plt.ylim(0, 12000)
    sns.histplot(reconstructed_flat[reconstructed_flat > 0])
    plt.title(f'{model_name} - Reconstructed Distribution > 0')
    panel += 1

    plt.subplot(5, 3, panel)
    sns.histplot(presence_pred, bins=100)
    plt.title(f'{model_name} - Presence Distribution')
    plt.text(0.5, 10000, f'F1 score = {f1_score(x_test_flat > 0, presence_pred):.2f}', fontsize=14)
    panel += 1

    # --- bookkeeping: per-epoch curves and the regression-head predictions ---
    dct_history[model_name] = {
        key: np.array(series) for key, series in zip(history_keys, model_history)
    }
    dct_y_pred[model_name] = results[3].cpu().detach().numpy()

    # --- XGBoost regressor fitted on the latent features ---
    print("prediction using embedding by", model_name)
    xgb_model = XGBRegressor(n_estimators=200, learning_rate=0.05, max_depth=4)
    xgb_model.fit(X_train_latent, y_train)

    X_val_latent = results[0].cpu().detach().numpy()
    dct_y_pred[model_name + '_xgb'] = xgb_model.predict(X_val_latent)

    t3 = time.time()
    print(f'\t\tXGB trained and tested, time used = {t3 - t2} seconds')

plt.tight_layout();

In [0]:
# F1 of the thresholded (sigmoid > 0.5) presence head against X_test > 0.
# `results` still holds the outputs of the model trained last in the loop above.
f1_score(X_test.flatten()>0, nn.Sigmoid()(results[1]).cpu().numpy().flatten()>0.5)

In [0]:
# Presence-head outputs of the last trained model after a sigmoid.
nn.Sigmoid()(results[1])

In [0]:
# Histogram of the sigmoid-transformed presence outputs of the last trained model.
sns.histplot(nn.Sigmoid()(results[1]).cpu().numpy().flatten(), bins=100)

In [0]:
# Display the current positive-class weight that get_losses passes to BCEWithLogitsLoss.
POS_WEIGHT

In [0]:
import gc
# Force a full garbage-collection pass; the cell displays gc.collect()'s return value.
gc.collect()

In [0]:
import gc
import torch

# Repeat three times: collect garbage, release PyTorch's cached CUDA memory,
# collect again. The value returned by each gc.collect() call is printed.
for sweep in range(3):
    collected_before = gc.collect()
    print(collected_before)
    torch.cuda.empty_cache()
    collected_after = gc.collect()
    print(collected_after)

In [0]:
# One row of loss curves per trained model, one column per tracked loss.
# Each entry is (suffix of the history key, panel title).
loss_panels = [
    ('loss_non_zero', 'Non-Zero Reconstruction Loss'),
    ('loss_presence', 'Presence Loss'),
    ('loss_reg', 'Regression Loss'),
]

# Size the grid from the number of recorded models instead of hard-coding 4 rows.
n_rows = max(len(dct_history), 1)
n_cols = len(loss_panels)

plt.figure(figsize=(15, 2.5 * n_rows))
for row, (model_name, history) in enumerate(dct_history.items()):
    for col, (metric, panel_title) in enumerate(loss_panels, start=1):
        plt.subplot(n_rows, n_cols, n_cols * row + col)
        plt.plot(history[f'train_{metric}'], '-', label='Train', color='blue', alpha=0.5)
        plt.plot(history[f'val_{metric}'], '--', label='Validation', color='red', alpha=0.5)
        plt.title(f'{model_name} - {panel_title}')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        # Legend and grid on every panel so all three columns read the same way.
        plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
        plt.grid(True)
plt.tight_layout()

In [0]:
# Predicted-vs-true scatter plot for every entry of dct_y_pred (regression head and
# the XGBoost model fitted on the latent features), two panels per row.
n_cols = 2
n_rows = max((len(dct_y_pred) + n_cols - 1) // n_cols, 1)  # ceil division, at least one row

plt.figure(figsize=(10, 3 * n_rows))

# End points of the identity line (perfect prediction).
identity_line = [y_test.min(), y_test.max()]

for panel, (model_name, y_pred) in enumerate(dct_y_pred.items(), start=1):
    # Flatten (N, 1) predictions to (N,); reshape keeps a 1-D array even when N == 1.
    y_pred = np.asarray(y_pred).reshape(-1)

    plt.subplot(n_rows, n_cols, panel)
    plt.scatter(y_test, y_pred, alpha=0.5, s=6)
    plt.plot(identity_line, identity_line, 'r--', lw=2)

    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title(f'{model_name} lambdaAE={lambda_ae}, lambdaReg={lambda_reg}')

    r2 = r2_score(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    plt.text(0.05, 0.95, f'R^2: {r2:.2f}\nMSE: {mse:.2f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top')

# Lay out once after all panels exist instead of after every panel.
plt.tight_layout()

In [0]:
# TensorBoard log directory; same path as in the commented-out SummaryWriter setup in the first cell.
log_dir = "runs/test"

In [0]:
# !tensorboard --logdir log_dir

In [0]:
# !tensorboard --logdir $log_dir

In [0]:
# %load_ext tensorboard
# %tensorboard --logdir $log_dir --port 6006