In [0]:
# %load_ext tensorboard
# from torch.utils.tensorboard import SummaryWriter


In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


from xgboost import XGBRegressor
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


## Load data

In [0]:
# Tab-separated file without a header row; the first column becomes the index.
# Column 1 is presumably the age (going by the file name) -- rows below 18 are dropped.
df_age = pd.read_csv("../data/age.csv", sep='\t', header=None, index_col=0)
is_adult = df_age[1] >= 18
df_age = df_age.loc[is_adult]  # todo remove after test
# Collapse the remaining values into a flat 1-D target array.
y = df_age.to_numpy().flatten()

y.shape

In [0]:
# Decade bucket of every target value; bucket 9 is merged into bucket 8.
y_class = y // 10
np.putmask(y_class, y_class == 9, 8)

In [0]:
# Distribution of the decade buckets that are used to stratify the split.
sns.histplot(data=y_class)

In [0]:
# Split row positions (not the data itself) once, so that every feature matrix X
# built later can reuse exactly the same train/test partition.
train_idx, test_idx = train_test_split(
    range(len(y)),
    test_size=0.2,
    stratify=y_class,
    random_state=42,
)
test_idx[:10]

[4436, 2292, 4448, 4903, 2378, 842, 2625, 3097, 4898, 1911]

In [0]:
# Feature table: tab-separated with a header row and an index column. Rows are selected
# (and ordered) by df_age.index so that X lines up with y.
X = (
    pd.read_csv("../data/processed_genus_log_drop08_scaled.csv", sep='\t', header=0, index_col=0)
    .loc[df_age.index, :]
    .to_numpy()
)
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# float32 tensors, created directly on the selected device
X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32, device=device)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32, device=device)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32, device=device)

# Pair every feature row with its target
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Mini-batch iterators; only the training batches are shuffled
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Number of feature columns = width of the model input
input_dim = X.shape[1]
print(X_train.max())
print(input_dim)

In [0]:
# Fraction of exactly-zero entries in the training features, followed by three
# ratios derived from it (zero:non-zero odds and the two reciprocals).
p_zero = np.mean(X_train == 0)
p_nonzero = 1 - p_zero
print(p_zero)
print(p_zero / p_nonzero)
print(1 / p_nonzero)
print(1 / p_zero)

In [0]:
# Baseline: predict the target from the zero / non-zero pattern of the features alone.
X_train_zeros = (X_train == 0)
X_test_zeros = (X_test == 0)

xgb_model = XGBRegressor(n_estimators=200, learning_rate=0.05, max_depth=4)
xgb_model.fit(X_train_zeros, y_train)
y_pred = xgb_model.predict(X_test_zeros)


In [0]:
# Goodness of fit of the baseline on the held-out rows
r2 = r2_score(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)

# Predictions against the truth, with the identity line (dashed red) as reference
y_lo, y_hi = y_test.min(), y_test.max()
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.5)
plt.plot([y_lo, y_hi], [y_lo, y_hi], 'r--', lw=2)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title(f'Actual vs Predicted\nR2: {r2:.2f}, MSE: {mse:.2f}')
plt.show()

## Models

In [0]:
from models_vanilla import *
from models_attention import *

# TODO: add noise (e.g., zero out random values) to the input during training and train the network to reconstruct the original data; this encourages the model to learn features that are robust to sparse noise.

In [0]:
# from sklearn.preprocessing import QuantileTransformer

# qt = QuantileTransformer(output_distribution='normal').fit(X_train[X_train!= 0].reshape(-1, 1))

# def postprocess(reconstructed):
#     reconstructed = reconstructed.cpu().numpy()
#     print('reconstructed cpu shape', reconstructed.shape)
#     mask = (reconstructed != 0)
#     print('mask shape', mask.shape)
#     print('masked reconstructed shape', reconstructed[mask].shape)
#     nonzero_transformed = qt.transform(reconstructed[mask].reshape(-1, 1)).reshape(-1)
#     print('nonzero_transformed shape', nonzero_transformed.shape)
#     reconstructed[mask] = nonzero_transformed
#     print("reconstructed shape", reconstructed.shape)
#     return torch.tensor(reconstructed,dtype=torch.float32).to(device)

In [0]:
def postprocess(x):  # todo finish the postprocess function later
    """Placeholder post-processing step: hands the input back unchanged."""
    return x

## Loss Functions

In [0]:
def masked_mse_loss(reconstructed, X_true):
    """Squared error between ``sigmoid(reconstructed)`` and ``X_true``, counted only where ``X_true != 0``.

    Positions where ``X_true`` is zero contribute 0 to the sum but still count in the
    denominator, because the mean is taken over every element.
    """
    nonzero = X_true.ne(0).float()
    squared_error = (torch.sigmoid(reconstructed) - X_true).pow(2)
    return (nonzero * squared_error).mean()

In [0]:
def masked_kl_loss(mu, logvar, X_true):
    """KL term per latent dimension, weighted per sample by how sparse its input row is.

    The per-sample weight is ``0.1 + 0.9 * (fraction of zero entries in the row)``: a row
    without zeros gets weight 0.1, an all-zero row gets weight 1.0. The weighted terms
    are averaged over all samples and latent dimensions.
    """
    zero_fraction = 1 - X_true.ne(0).float().mean(dim=1, keepdim=True)
    per_dim_kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    sample_weight = 0.1 + 0.9 * zero_fraction
    return (sample_weight * per_dim_kl).mean()

## Training Function

In [0]:
# Training function

def train_model(model, model_name, train_loader, test_loader, num_epochs_1=100, num_epochs_2 = 50, patience=10):
    """Train ``model`` in two phases and return the loss curves plus the best weights.

    Phase 1 optimises only the presence/absence of each feature: binary cross-entropy
    between the raw reconstruction (treated as logits) and the mask ``X != 0``.

    Phase 2 reloads the best phase-1 weights and optimises the blend
    ``alpha * presence + (1 - alpha) * non_zero``, where ``non_zero`` is
    ``masked_mse_loss`` (plus ``masked_kl_loss`` for VAE models) and ``alpha`` follows a
    half-cosine schedule that starts at 1 and decays towards 0 over ``num_epochs_2``.

    Each phase stops early once its validation loss has not improved for ``patience``
    consecutive epochs.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(X_batch)``. Must return ``(latent, reconstructed)``, or
        ``(latent, reconstructed, mu, logvar)`` when ``model_name`` ends with ``'VAE'``.
    model_name : str
        Label used in progress messages; a ``'VAE'`` suffix switches on the KL term.
    train_loader, test_loader : iterables of ``(X_batch, y_batch)`` that support ``len()``
        ``y_batch`` is ignored. ``test_loader`` serves as the validation set.
    num_epochs_1, num_epochs_2 : int
        Maximum number of epochs for phase 1 and phase 2.
    patience : int
        Early-stopping patience, applied to each phase separately.

    Returns
    -------
    tuple
        ``(lst_train_loss, lst_val_loss, lst_train_presence_loss, lst_train_nonzero_loss,
        lst_val_presence_loss, lst_val_nonzero_loss, best_model)``. The total and
        presence lists hold one value per executed epoch of both phases; the non-zero
        lists hold phase-2 epochs only. ``best_model`` is an independent copy of the
        state dict with the lowest validation loss of the last phase that ran.
    """
    import copy  # local import: only needed here, to snapshot the weights

    def snapshot():
        # state_dict() returns references to the live parameter tensors, so it must be
        # deep-copied; otherwise the "best" weights follow every later optimizer step
        # and early stopping can never restore the best epoch.
        return copy.deepcopy(model.state_dict())

    def forward(X_batch):
        # VAE-style models return two extra tensors (mu, logvar); plain autoencoders
        # do not, which is signalled to the caller with None.
        if model_name.endswith('VAE'):
            latent, reconstructed, mu, logvar = model(X_batch)
            return reconstructed, mu, logvar
        latent, reconstructed = model(X_batch)
        return reconstructed, None, None

    presence_criterion = nn.BCEWithLogitsLoss()

    lst_train_nonzero_loss = []
    lst_train_presence_loss = []
    lst_val_nonzero_loss = []
    lst_val_presence_loss = []
    lst_train_loss = []
    lst_val_loss = []

    ######################################
    #### training for presence by BCE ####
    ######################################

    min_val_loss = float('inf')
    best_model = snapshot()  # fallback so that num_epochs_1 == 0 still yields valid weights
    early_stopping_counter = 0
    optimizer1 = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(num_epochs_1):
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in train_loader:
            optimizer1.zero_grad()
            reconstructed, _, _ = forward(X_batch)
            # NOTE(review): postprocess is called under no_grad. That only keeps the
            # gradient intact if postprocess returns its input tensor unchanged.
            with torch.no_grad():
                reconstructed = postprocess(reconstructed)
            loss_presence = presence_criterion(reconstructed, (X_batch != 0).float())
            loss_presence.backward()
            optimizer1.step()
            train_loss += loss_presence.item()
        train_loss = train_loss/len(train_loader)
        # In phase 1 the total loss *is* the presence loss, so it is logged to both curves.
        lst_train_loss.append(train_loss)
        lst_train_presence_loss.append(train_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                reconstructed, _, _ = forward(X_batch)
                reconstructed = postprocess(reconstructed)
                loss_presence = presence_criterion(reconstructed, (X_batch != 0).float())
                val_loss += loss_presence.item()
        val_loss = val_loss/len(test_loader)
        lst_val_loss.append(val_loss)
        lst_val_presence_loss.append(val_loss)

        if epoch % 5 == 0:
            print(f'{model_name} Presence Epoch {epoch+1}/{num_epochs_1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = snapshot()
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print("Early stopping at epoch", epoch+1)
                break

    ###################################################
    #### training for nonzero values by masked MSE ####
    ###################################################

    # Continue from the best phase-1 weights, not from the last phase-1 epoch.
    model.load_state_dict(best_model)

    # freeze encoder parameters
    # last_layer_id = [name.split('.')[0] for name, _ in model.named_parameters()][-1]
    # for name, param in model.encoder.named_parameters():
    #     if not last_layer_id in name:  # freeze all but last layer
    #     # if '0' in name:  # freeze first layer
    #         print(f"freeze {name} of encoder")
    #         param.requires_grad = False

    min_val_loss = float('inf')
    early_stopping_counter = 0
    optimizer2 = optim.Adam(model.parameters(), lr=0.0001)
    for epoch in range(num_epochs_2):
        model.train()
        train_loss = 0.0
        train_nonzero_loss = 0.0
        train_presence_loss = 0.0
        val_loss = 0.0
        val_nonzero_loss = 0.0
        val_presence_loss = 0.0
        # alpha = max(0, 1 - epoch / 20)
        # Half-cosine schedule: 1 at epoch 0, decaying towards 0 at the last epoch.
        # NOTE(review): the validation loss below is weighted with this epoch-dependent
        # alpha, so the early-stopping comparison mixes differently weighted losses.
        alpha = 0.5 * (1 + np.cos(np.pi * epoch/num_epochs_2))
        for X_batch, y_batch in train_loader:
            optimizer2.zero_grad()

            reconstructed, mu, logvar = forward(X_batch)
            loss_kl = masked_kl_loss(mu, logvar, X_batch) if mu is not None else 0
            with torch.no_grad():
                reconstructed = postprocess(reconstructed)
            loss_presence = presence_criterion(reconstructed, (X_batch != 0).float())
            loss_non_zero = masked_mse_loss(reconstructed, X_batch) + loss_kl
            loss = alpha * loss_presence + (1 - alpha) * loss_non_zero
            loss.backward()
            optimizer2.step()
            train_loss += loss.item()
            train_nonzero_loss += loss_non_zero.item()
            train_presence_loss += loss_presence.item()

        train_loss = train_loss/len(train_loader)
        lst_train_loss.append(train_loss)
        lst_train_presence_loss.append(train_presence_loss/len(train_loader))
        lst_train_nonzero_loss.append(train_nonzero_loss/len(train_loader))

        model.eval()
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                reconstructed, mu, logvar = forward(X_batch)
                loss_kl = masked_kl_loss(mu, logvar, X_batch) if mu is not None else 0
                reconstructed = postprocess(reconstructed)
                loss_presence = presence_criterion(reconstructed, (X_batch != 0).float())
                loss_non_zero = masked_mse_loss(reconstructed, X_batch) + loss_kl
                loss = alpha * loss_presence + (1 - alpha) * loss_non_zero
                val_loss += loss.item()
                val_nonzero_loss += loss_non_zero.item()
                val_presence_loss += loss_presence.item()

        val_loss = val_loss/len(test_loader)
        lst_val_loss.append(val_loss)
        lst_val_presence_loss.append(val_presence_loss/len(test_loader))
        lst_val_nonzero_loss.append(val_nonzero_loss/len(test_loader))

        if epoch % 5 == 0:
            print(f'{model_name} NonZero Epoch {epoch+1}/{num_epochs_2}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = snapshot()
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print("Early stopping")
                break
    return lst_train_loss, lst_val_loss, lst_train_presence_loss, lst_train_nonzero_loss, lst_val_presence_loss, lst_val_nonzero_loss, best_model

# Start

In [0]:
# Experiment settings -- set these before each run.

# García-Jiménez et al. 2021 used latent_dim 10 to represent 717 taxa https://academic.oup.com/bioinformatics/article/37/10/1444/5988714
latent_dim = 20

# Epoch budgets for the two training phases and the early-stopping patience.
num_epochs_1, num_epochs_2, patience = 50, 200, 10


In [0]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from sklearn.metrics import f1_score
import time

# Architectures to compare. `models` and `model_names` are parallel lists: the name at
# position k labels the model at position k in logs, plots and result dictionaries.
# A name ending in 'VAE' makes train_model unpack (latent, reconstructed, mu, logvar).
models = [
    AttentionAEend(input_dim, latent_dim),
    AttentionAEmid(input_dim, latent_dim),
    AttentionAEbegin(input_dim, latent_dim),
    DeepShallowerAutoencoder(input_dim, latent_dim),
    # DeepShallowAutoencoder(input_dim, latent_dim),
    # # ShallowVAE(input_dim, latent_dim),
    # # DeepVAE(input_dim, latent_dim),
    # ShallowAutoencoder(input_dim, latent_dim),
    # DeepAutoencoder(input_dim, latent_dim),
]

model_names = [
    "AttentionAEend",
    "AttentionAEmid",
    "AttentionAEbegin",
    "DeepShallowerAE",
    # "DeepShallowAE",
    # # "ShallowVAE",
    # # "DeepVAE",
    # "ShallowAutoencoder",
    # "DeepAutoencoder",
]

# zip() would silently drop the surplus entries, so fail loudly when the lists diverge.
if len(models) != len(model_names):
    raise ValueError(f"models ({len(models)}) and model_names ({len(model_names)}) must have the same length")

# log_dirs = [f"run/{model_name}" for model_name in model_names]
# dct_writer = dict()


# Results collected per model name
dct_history = dict()
dct_y_pred = dict()
dct_latent_vectors = dict()
dct_X_reconstructed = dict()

# Figure layout: one row for the original test data plus one row per model, two panels
# per row (all values / values > 0). The row count follows the model list instead of
# being hard-coded, so adding or removing a model keeps the grid valid.
n_rows = len(models) + 1
plt.figure(figsize=(18, 3 * n_rows))
X_test_flat = X_test.flatten()
i = 1
plt.subplot(n_rows, 2, i)
sns.histplot(X_test_flat, bins=100)
plt.xlim(-0.35, 1.5)
plt.ylim(0, 1.35e6)
plt.title('Test Set - Original Distribution')
i+=1
plt.subplot(n_rows, 2, i)
plt.xlim(0, 1.5)
plt.ylim(0, 12000)
sns.histplot(X_test_flat[X_test_flat>0], bins=100)
plt.title('Test Set - Original Distribution > 0')
i+=1


for model, model_name in zip(models, model_names):
    log_dir = f"run/{model_name}"
    # dct_writer[model_name] = SummaryWriter(log_dir=log_dir)

    t0 = time.time()
    print(f"Training {model_name}")

    model.to(device)
    # dct_writer[model_name].add_graph(model, X_train_tensor)
    # dct_writer[model_name].close()

    # training
    (lst_train_loss, lst_val_loss,
     lst_train_presence_loss, lst_train_nonzero_loss,
     lst_val_presence_loss, lst_val_nonzero_loss,
     best_model) = train_model(model, model_name, train_loader, test_loader, num_epochs_1=num_epochs_1, num_epochs_2=num_epochs_2, patience=patience)

    # Evaluate with the weights train_model reports as best
    model.load_state_dict(best_model)
    t1 = time.time()
    print(f'\t\t ------------ model trained, time used = {t1 - t0} seconds ------------')
    print('model evaluating')
    model.eval()
    with torch.no_grad():
        model_output = model(X_test_tensor)
        # The decoder output is a logit; sigmoid maps it to the (0, 1) range.
        X_test_latent, X_test_reconstructed = model_output[0], torch.sigmoid(model_output[1])
        X_train_latent = model(X_train_tensor)[0].cpu().detach().numpy()

    # Convert once; the arrays are reused for plotting, XGBoost and the result dicts.
    X_test_latent_np = X_test_latent.cpu().detach().numpy()
    X_test_reconstructed_np = X_test_reconstructed.cpu().detach().numpy()
    reconstructed_flat = X_test_reconstructed_np.flatten()

    t2 = time.time()
    print(f'\t\t ------------ model applied on test set, time used = {t2 - t1} seconds ------------')

    plt.subplot(n_rows, 2, i)
    sns.histplot(reconstructed_flat, bins=100)
    plt.xlim(-0.35, 1.5)
    # plt.ylim(0, 1.35e6)
    plt.title(f'{model_name} - Reconstructed Distribution')
    i+=1
    plt.subplot(n_rows, 2, i)
    plt.xlim(0, 1.5)
    plt.ylim(0, 12000)
    sns.histplot(reconstructed_flat[reconstructed_flat>0], bins=100)
    plt.title(f'{model_name} - Reconstructed Distribution > 0')
    i+=1
    # plt.subplot(5, 3, i )
    # sns.histplot(nn.Sigmoid()(results[1]).cpu().numpy().flatten()>0.5, bins=100)
    # plt.title(f'{model_name} - Presence Distribution')
    # plt.text(0.5, 10000, f'F1 score = {f1_score(X_test.flatten()>0, nn.Sigmoid()(results[1]).cpu().numpy().flatten()>0.5):.2f}', fontsize=14)
    # i+=1

    # Log losses
    dct_history[model_name] = {
            "train_loss": np.array(lst_train_loss),
            "val_loss": np.array(lst_val_loss),
            "train_presence_loss": np.array(lst_train_presence_loss),
            "train_nonzero_loss": np.array(lst_train_nonzero_loss),
            "val_presence_loss": np.array(lst_val_presence_loss),
            "val_nonzero_loss": np.array(lst_val_nonzero_loss),
        }

    # Train XGBoost model on latent features
    print("prediction using embedding by", model_name)
    xgb_model = XGBRegressor(n_estimators=200, learning_rate=0.05, max_depth=4)
    xgb_model.fit(X_train_latent, y_train)  # Use the latent features as input for regression

    # Predict on the held-out rows
    dct_y_pred[model_name] = xgb_model.predict(X_test_latent_np)

    dct_latent_vectors[model_name] = X_test_latent_np
    dct_X_reconstructed[model_name] = X_test_reconstructed_np

    t3 = time.time()
    print(f'\t\t ------------ XGB trained and tested, time used = {t3 - t2} seconds ------------')

plt.tight_layout();

In [0]:
# for log_dir in log_dirs:
#     %tensorboard --logdir $log_dir

In [0]:
import gc
# Run a full garbage-collection pass; the value shown as cell output is the number of
# unreachable objects the collector found.
gc.collect()

In [0]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# import umap

def plot_latent_space(latent, method='tsne', colorby = None, cbar_label=None):
    """Project latent vectors to 2-D and draw them as a scatter plot on the current axes.

    Parameters
    ----------
    latent : array-like
        Passed unchanged to the reducer's ``fit_transform``.
    method : {'tsne', 'pca', 'umap'}
        Dimensionality-reduction method. ``'umap'`` needs the optional ``umap`` module.
    colorby : array-like or None
        Forwarded as ``c=`` to ``plt.scatter``; a colorbar is added only when given.
    cbar_label : str or None
        Label of the colorbar.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    ImportError
        If ``method == 'umap'`` and the ``umap`` module cannot be imported.
    """
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    elif method == 'pca':
        reducer = PCA(n_components=2)
    elif method == 'umap':
        # The notebook-level `import umap` is commented out, so this branch used to die
        # with a NameError. Import lazily and report a missing package explicitly.
        try:
            import umap
        except ImportError as exc:
            raise ImportError("method='umap' requires the optional 'umap-learn' package") from exc
        reducer = umap.UMAP(n_components=2, random_state=42)
    else:
        raise ValueError("Invalid method")

    reduced = reducer.fit_transform(latent)

    scatter = plt.scatter(reduced[:, 0], reduced[:, 1],
                          c=colorby, cmap='viridis',
                          s=10, alpha=0.6, edgecolors='w', linewidths=0.5)
    # Without colour data there is nothing for a colorbar to describe.
    if colorby is not None:
        plt.colorbar(scatter, label=cbar_label)
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')

In [0]:
# One row per trained model, three panels per row: total, presence and non-zero loss.
# The grid height follows the number of trained models instead of a fixed 4 rows.
n_models = len(dct_history)
loss_panels = [
    ('train_loss', 'val_loss', 'Loss'),
    ('train_presence_loss', 'val_presence_loss', 'Presence Loss'),
    ('train_nonzero_loss', 'val_nonzero_loss', 'Non-zero Loss'),
]

plt.figure(figsize=(15, 3.75 * max(n_models, 1)))
for i, (model_name, history) in enumerate(dct_history.items()):
    for j, (train_key, val_key, panel_title) in enumerate(loss_panels):
        plt.subplot(n_models, 3, 3*i + j + 1)
        plt.plot(history[train_key], '-', label='Train', color='blue', alpha=0.5)
        plt.plot(history[val_key], '--', label='Validation', color='red', alpha=0.5)
        plt.title(f'{model_name} - {panel_title}')
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('Loss')

plt.tight_layout()

In [0]:
# One scatter panel per model: predictions stored in dct_y_pred against the true targets.
# The grid height follows the number of models instead of a fixed 4 rows.
n_models = len(dct_y_pred)
plt.figure(figsize=(5, 3 * max(n_models, 1)))

# The identity line is the same in every panel, so its end points are computed once.
y_lo, y_hi = y_test.min(), y_test.max()

# The loop variable is deliberately not called y_pred, so the baseline predictions
# stored under that name by an earlier cell are not overwritten.
for i, (model_name, model_pred) in enumerate(dct_y_pred.items()):
    model_pred = model_pred.squeeze()

    plt.subplot(n_models, 1, i + 1)

    plt.scatter(y_test, model_pred, alpha=0.5, s=6)
    plt.plot([y_lo, y_hi], [y_lo, y_hi], 'r--', lw=2)

    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title(model_name)

    r2 = r2_score(y_test, model_pred)
    mse = mean_squared_error(y_test, model_pred)
    plt.text(0.05, 0.95, f'R^2: {r2:.2f}\nMSE: {mse:.2f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top')

# Lay out once, after all panels exist, rather than once per panel.
plt.tight_layout()

In [0]:
# Number of non-zero features in each test sample; used to colour the embeddings.
non_zero_counts = np.count_nonzero(X_test, axis=1)

# One row per model: PCA projection on the left, t-SNE projection on the right.
n_models = len(dct_latent_vectors)
plt.figure(figsize=(15, 15))
for row, (model_name, latent_vectors) in enumerate(dct_latent_vectors.items()):
    for col, (method, method_label) in enumerate((('pca', 'PCA'), ('tsne', 'TSNE'))):
        plt.subplot(n_models, 2, 2 * row + col + 1)
        plot_latent_space(latent_vectors, method=method, colorby=non_zero_counts, cbar_label='Number of non-zero Features')
        plt.title(f'{model_name} - {method_label}')
plt.tight_layout()

In [0]:
# Same embedding grid as above, coloured by the target (age) of each test sample:
# one row per model, PCA projection on the left, t-SNE projection on the right.
n_models = len(dct_latent_vectors)
plt.figure(figsize=(15, 15))
for row, (model_name, latent_vectors) in enumerate(dct_latent_vectors.items()):
    for col, (method, method_label) in enumerate((('pca', 'PCA'), ('tsne', 'TSNE'))):
        plt.subplot(n_models, 2, 2 * row + col + 1)
        plot_latent_space(latent_vectors, method=method, colorby=y_test, cbar_label='Age')
        plt.title(f'{model_name} - {method_label}')
plt.tight_layout()

In [0]:
# !tensorboard --logdir log_dir

In [0]:
# !tensorboard --logdir $log_dir

In [0]:
# !load_ext tensorboard
# !tensorboard --logdir $log_dir -- port 6006