Prior knowledge
1. Diversity: increases in young adults, plateaus around age 40, and decreases in older adults.
2. Older adults show a decrease in beneficial bacteria and an increase in harmful bacteria.




De La Cuesta-Zuluaga, J., Kelley, S., Chen, Y., Escobar, J., Mueller, N., Ley, R., McDonald, D., Huang, S., Swafford, A., Knight, R., & Thackray, V. (2019). Age- and Sex-Dependent Patterns of Gut Microbial Diversity in Human Adults. mSystems, 4. https://doi.org/10.1128/mSystems.00261-19.

In [0]:
# %load_ext tensorboard
# from torch.utils.tensorboard import SummaryWriter


In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


from xgboost import XGBRegressor
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


## Load data

In [0]:
# Ages: tab-separated, no header row. Column 0 becomes the index (presumably the
# sample IDs) and column 1 holds the age.
df_age = pd.read_csv("../data/age.csv", sep='\t', header=None, index_col=0)
df_age = df_age.loc[df_age[1] >= 18]  # adults only -- todo remove after test
# y = df_age.to_numpy().reshape(-1, 1).flatten()

# y.shape

In [0]:
# X = pd.read_csv("../data/processed_log_drop08_scaled.csv", header=0, index_col=0, sep='\t').loc[df_age.index, :].to_numpy()
# Genus-level abundance table (per the file name: log-transformed, "drop08"-filtered,
# scaled), restricted to and ordered like the samples kept in df_age.
df_genus = pd.read_csv(
    "../data/processed_genus_log_drop08_scaled.csv", sep='\t', header=0, index_col=0
).loc[df_age.index, :]
X = df_genus.to_numpy()

In [0]:
# Per-sample genus richness: how many genera have a non-zero value in each sample.
diversity_genus = (df_genus != 0).sum(axis=1)

In [0]:
# Genus richness against age, one box per distinct age value.
sns.boxplot(y=diversity_genus, x=df_age[1])

In [0]:
# Display the per-sample count of non-zero genera.
diversity_genus

In [0]:
# Remove samples in which fewer than 10 genera are present. The abundance table and
# the ages are kept aligned, then model inputs X and targets y are rebuilt.
sample_to_drop = diversity_genus.index[(diversity_genus < 10).to_numpy()]
df_genus = df_genus.drop(index=sample_to_drop)
df_age = df_age.drop(index=sample_to_drop)
X, y = df_genus.to_numpy(), df_age.to_numpy().flatten()
print(X.shape, y.shape)

In [0]:
# Stratification label: the age decade (18-19 -> 1, 20-29 -> 2, ...). Every decade
# at or above 80 (including any 100+ ages) is merged into class 8. Sparse top-end
# decades would otherwise form tiny strata, and train_test_split(stratify=...)
# raises when a class has fewer than two members.
y_class = np.minimum(y // 10, 8)

In [0]:
# Split sample *positions* once, stratified by age decade and seeded. The same
# train/test indices can then be reused for every representation of X.
train_idx, test_idx = train_test_split(
    range(len(y)),
    stratify=y_class,
    test_size=0.2,
    random_state=42,
)
test_idx[:10]



[3070, 4428, 2088, 328, 675, 563, 555, 3960, 3624, 2427]

In [0]:

# Train/test matrices and targets from the fixed split positions.
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

# float32 tensors on the selected device, wrapped for mini-batch iteration.
# Training batches are reshuffled every epoch; evaluation batches keep a fixed order.
X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor = (
    torch.tensor(array, dtype=torch.float32).to(device)
    for array in (X_train, X_test, y_train, y_test)
)
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

input_dim = X.shape[1]
print(X_train.max())
print(input_dim)

In [0]:
# Sparsity of the training matrix: the fraction p of exact zeros, the odds
# p / (1 - p), and the inverse fractions 1 / (1 - p) and 1 / p (one per line).
p_zero = np.mean(X_train == 0)
print(*(p_zero, p_zero / (1 - p_zero), 1 / (1 - p_zero), 1 / p_zero), sep='\n')

In [0]:
# Distribution of the scaled test-set values. The left panel shows all entries;
# the right panel shows only the non-zero entries, which the zero spike otherwise dwarfs.
nrow = 1
ncol = 2
i = 1
plt.figure(figsize=(15, 5))
plt.subplot(nrow, ncol, i)
sns.histplot(X_test.flatten(), bins=10)
plt.xlim(-0.05, 1.05)
plt.ylim(0, 5e5)
plt.title('Test Set - Original Distribution')
i += 1

plt.subplot(nrow, ncol, i)
plt.ylim(0, 10000)
# Exclude exact zeros so the panel shows what its title claims.
sns.histplot(X_test[X_test != 0], bins=100)
plt.title('Test Set - Original NonZero Values Distribution')
i += 2

## Models

In [0]:
from models_attention_with_regressor_head import *

# todo  add noise (e.g., zero out random values) to the input during training and train the network to reconstruct the original data. This encourages the model to learn robust features despite the sparse noise.

## Loss Functions

In [0]:
from loss_functions import *

## Training Function

In [0]:
# Training function

def _snapshot_state_dict(model):
    """Return a detached copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors. Keeping
    it directly as a "best" checkpoint would silently follow every later optimizer
    step; cloning freezes the weights as they are right now.
    """
    return {
        key: value.detach().clone() if torch.is_tensor(value) else value
        for key, value in model.state_dict().items()
    }


def _freeze_encoder_except_last_layer(model):
    """Freeze every encoder parameter except those of the encoder's last top-level child.

    The "last layer" is the top-level submodule owning the final entry of
    ``model.encoder.named_parameters()``. Parameters registered directly on the
    encoder (names without a dot) are grouped under the empty prefix.
    """
    def top_level(name):
        return name.split('.')[0] if '.' in name else ''

    named_params = list(model.encoder.named_parameters())
    if not named_params:
        return
    last_layer_id = top_level(named_params[-1][0])
    for name, param in named_params:
        # Exact prefix comparison: a substring test would match e.g. "13" inside "3".
        if top_level(name) != last_layer_id:
            print(f"freeze {name} of encoder")
            param.requires_grad = False


def _presence_pass(model, loader, optimizer=None):
    """Run one stage-1 epoch: BCE between the presence logits and the ``X != 0`` mask.

    When ``optimizer`` is given this trains (``model.train()`` plus optimizer steps).
    Otherwise it evaluates in eval mode with gradients disabled.

    Returns:
        ``(bce_sum, f1_sum)``: BCE accumulated per sample (batch mean x batch size)
        and the sum of the per-batch ``get_f1_score`` values.
    """
    training = optimizer is not None
    model.train(training)
    bce = nn.BCEWithLogitsLoss(reduction='mean')
    bce_sum, f1_sum = 0.0, 0.0
    with torch.set_grad_enabled(training):
        for X_batch, _ in loader:
            if training:
                optimizer.zero_grad()
            _, decoded1, _, _ = model(X_batch)
            loss_presence = bce(decoded1, (X_batch != 0).float())
            if training:
                loss_presence.backward()
                optimizer.step()
            bce_sum += loss_presence.item() * len(X_batch)
            f1_sum += get_f1_score(decoded1, X_batch)
    return bce_sum, f1_sum


def _masked_value_pass(model, loader, alpha, gamma, mask_x_true, optimizer=None):
    """Run one stage-2 epoch; train when ``optimizer`` is given, otherwise evaluate.

    Per-batch objective:
    ``(alpha * BCE(presence) + (1 - alpha) * soft-masked MSE(values)) * (1 - gamma)
    + gamma * MSE(prediction, y)``.

    Returns:
        Dict of epoch accumulators:
        - ``loss`` and ``presence``: batch means x batch size.
        - ``nonzero``: soft-masked MSE x sum of the soft mask.
        - ``pred`` and ``f1``: sums of per-batch values.
        - ``n_nonzero``: number of entries with ``X != 0``.
    """
    training = optimizer is not None
    model.train(training)
    bce = nn.BCEWithLogitsLoss(reduction='mean')
    mse = nn.MSELoss()
    sums = {'loss': 0.0, 'presence': 0.0, 'nonzero': 0.0, 'pred': 0.0, 'f1': 0.0, 'n_nonzero': 0}
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            if training:
                optimizer.zero_grad()
            _, decoded1, decoded2, predicted = model(X_batch)
            loss_presence = bce(decoded1, (X_batch != 0).float())
            # Presence probabilities act as a soft mask for the value loss. They are
            # detached so the value loss does not back-propagate through the mask.
            mask = torch.sigmoid(decoded1).detach()
            loss_non_zero = softly_masked_mse_loss(torch.sigmoid(decoded2), X_batch, mask, mask_x_true=mask_x_true)
            # If the regressor head returns (batch, 1), MSELoss against (batch,) targets
            # would broadcast to a (batch, batch) matrix. Align shapes first; this
            # assumes one prediction per sample.
            loss_pred = mse(predicted.reshape(y_batch.shape), y_batch)
            loss = (alpha * loss_presence + (1 - alpha) * loss_non_zero) * (1 - gamma) + gamma * loss_pred
            if training:
                loss.backward()
                optimizer.step()
            n = len(X_batch)
            sums['loss'] += loss.item() * n
            sums['presence'] += loss_presence.item() * n
            sums['nonzero'] += (loss_non_zero * torch.sum(mask)).item()
            sums['pred'] += loss_pred.item()
            sums['f1'] += get_f1_score(decoded1, X_batch)
            sums['n_nonzero'] += int((X_batch != 0).sum().item())
    return sums


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def train_model(model, model_name, train_loader, test_loader, num_epochs_1=100, num_epochs_2 = 50, patience=10, gamma=0.1, mask_x_true=False):
    """Train a presence/value autoencoder with a regressor head in two stages.

    Stage 1 fits the presence head (``decoded1`` logits) with BCE against ``X != 0``.
    It early-stops on validation BCE, and the best stage-1 weights are restored.

    Stage 2 freezes all encoder parameters except its last top-level submodule.
    It then trains on a blend of the presence BCE and a soft-masked MSE on the value
    head (``decoded2``), plus the regressor MSE weighted by ``gamma``. The blend moves
    linearly from presence to values over the first 20 epochs (``alpha``), so stage-2
    early stopping (patience ``2 * patience``) only starts after epoch 21.

    Args:
        model: module whose forward returns ``(latent, decoded1, decoded2, predicted)``
            and which exposes an ``encoder`` submodule.
        model_name: label used in progress messages.
        train_loader, test_loader: DataLoaders yielding ``(X_batch, y_batch)``.
        num_epochs_1, num_epochs_2: epoch budgets for stages 1 and 2.
        patience: stage-1 early-stopping patience (doubled in stage 2).
        gamma: weight of the regression loss in stage 2, between 0 and 1.
        mask_x_true: forwarded to ``softly_masked_mse_loss``.

    Returns:
        Tuple ``(train_loss, val_loss, train_presence_loss, val_presence_loss,
        train_nonzero_loss, val_nonzero_loss, train_pred_loss, val_pred_loss,
        train_f1, val_f1, best_state_dict)``, with one list entry per epoch.
        The total, presence and F1 lists span both stages. The non-zero and
        prediction lists cover stage 2 only. ``best_state_dict`` is a detached copy
        of the weights at the best stage-2 validation loss, or of the final weights
        if stage 2 never reached its early-stopping window.
    """
    lst_train_nonzero_loss = []
    lst_train_presence_loss = []
    lst_train_pred_loss = []
    lst_train_loss = []
    lst_train_f1 = []
    lst_val_nonzero_loss = []
    lst_val_presence_loss = []
    lst_val_pred_loss = []
    lst_val_loss = []
    lst_val_f1 = []
    n_train = len(train_loader.dataset)
    n_val = len(test_loader.dataset)

    ######################################
    #### training for presence by BCE ####
    ######################################
    min_val_loss = float('inf')
    best_model = None
    early_stopping_counter = 0
    optimizer1 = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(num_epochs_1):
        bce_sum, f1_sum = _presence_pass(model, train_loader, optimizer1)
        train_loss = bce_sum / n_train
        lst_train_loss.append(train_loss)
        lst_train_presence_loss.append(train_loss)
        lst_train_f1.append(f1_sum / len(train_loader))

        bce_sum, f1_sum = _presence_pass(model, test_loader)
        val_loss = bce_sum / n_val
        lst_val_loss.append(val_loss)
        lst_val_presence_loss.append(val_loss)
        lst_val_f1.append(f1_sum / len(test_loader))

        if epoch % 5 == 0:
            print(f'{model_name} Presence Epoch {epoch+1}/{num_epochs_1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = _snapshot_state_dict(model)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print("Early stopping at epoch", epoch+1)
                break

    # Restore the best stage-1 weights. This is skipped if no epoch ever improved
    # (e.g. zero epochs or NaN losses).
    if best_model is not None:
        model.load_state_dict(best_model)

    ###################################################
    #### training for nonzero values by masked MSE ####
    ###################################################
    _freeze_encoder_except_last_layer(model)

    min_val_loss = float('inf')
    best_model = None
    early_stopping_counter = 0
    optimizer2 = optim.Adam(model.parameters(), lr=0.0001)
    for epoch in range(num_epochs_2):
        # Linear warm-up from the presence loss (alpha=1) to the value loss (alpha=0).
        alpha = max(0, 1 - epoch / 20)

        train_sums = _masked_value_pass(model, train_loader, alpha, gamma, mask_x_true, optimizer2)
        train_loss = train_sums['loss'] / n_train
        lst_train_loss.append(train_loss)
        lst_train_presence_loss.append(train_sums['presence'] / n_train)
        # NOTE(review): numerator is weighted by the *soft* mask sum but the denominator
        # counts true non-zeros; confirm this normalisation is intended.
        lst_train_nonzero_loss.append(_safe_ratio(train_sums['nonzero'], train_sums['n_nonzero']))
        lst_train_pred_loss.append(train_sums['pred'] / len(train_loader))
        lst_train_f1.append(train_sums['f1'] / len(train_loader))

        val_sums = _masked_value_pass(model, test_loader, alpha, gamma, mask_x_true)
        val_loss = val_sums['loss'] / n_val
        lst_val_loss.append(val_loss)
        lst_val_presence_loss.append(val_sums['presence'] / n_val)
        lst_val_nonzero_loss.append(_safe_ratio(val_sums['nonzero'], val_sums['n_nonzero']))
        lst_val_pred_loss.append(val_sums['pred'] / len(test_loader))
        lst_val_f1.append(val_sums['f1'] / len(test_loader))

        if epoch % 10 == 0:
            print(f'{model_name} NonZero Epoch {epoch+1}/{num_epochs_2}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping, only once alpha has settled so losses are comparable.
        if epoch > 21:
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                best_model = _snapshot_state_dict(model)
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
                if early_stopping_counter >= patience * 2:
                    print("Early stopping")
                    break

    # A stage-1 checkpoint optimises a different objective. If stage 2 never took a
    # snapshot, return its final weights rather than discarding stage 2 entirely.
    if best_model is None:
        best_model = _snapshot_state_dict(model)
    return lst_train_loss, lst_val_loss, lst_train_presence_loss, lst_val_presence_loss, lst_train_nonzero_loss, lst_val_nonzero_loss, lst_train_pred_loss, lst_val_pred_loss, lst_train_f1, lst_val_f1, best_model

# Start

In [0]:
# ---- parameters to define before each experiment ----

# Size of the learned embedding. García-Jiménez et al. 2021 used latent_dim 10 to
# represent 717 taxa: https://academic.oup.com/bioinformatics/article/37/10/1444/5988714
latent_dim = 20

# Epoch budgets for stage 1 (presence / BCE) and stage 2 (non-zero values / masked
# MSE) of train_model, and the early-stopping patience (doubled in stage 2).
num_epochs_1, num_epochs_2 = 50, 200
patience = 10

gamma = 0  # weight for regressor loss, between 0 and 1
dropout = 0.1
mask_x_true = False  # forwarded to softly_masked_mse_loss

In [0]:
import warnings
# NOTE: this silences every UserWarning for the rest of the session, including
# PyTorch's "target size differs from input size" warning from nn.MSELoss.
warnings.filterwarnings("ignore", category=UserWarning)
from sklearn.metrics import f1_score
import os
import time

models = [
    AttentionAEmidconcat(input_dim, latent_dim, dropout=dropout),
    # AttentionAEend(input_dim, latent_dim, dropout=dropout),
    # AttentionAEmid(input_dim, latent_dim, dropout=dropout),
    # AttentionAEbegin(input_dim, latent_dim, dropout=dropout),
    # DeepShallowAutoencoder(input_dim, latent_dim),
    # ShallowVAE(input_dim, latent_dim),
    # DeepVAE(input_dim, latent_dim),
    # ShallowAutoencoder(input_dim, latent_dim),
    # DeepAutoencoder(input_dim, latent_dim),
]

model_names = [
    "AttentionAEmidconcat",
    # "AttentionAEend",
    # "AttentionAEmid",
    # "AttentionAEbegin",
    # "DeepShallowAE",
]

# Models and names are paired with zip(), which would silently drop unmatched entries.
if len(models) != len(model_names):
    raise ValueError(
        f"got {len(models)} models but {len(model_names)} model names; keep the two lists aligned"
    )

lst_xgb_models = []

# log_dirs = [f"run/{model_name}" for model_name in model_names]
# dct_writer = dict()

dct_history = dict()
dct_y_pred = dict()
dct_latent_vectors = dict()
dct_metrics = dict()
dct_presence = dict()
dct_nonzero = dict()

# torch.save does not create missing directories.
os.makedirs("model", exist_ok=True)

for model, model_name in zip(models, model_names):
    # log_dir = f"run/{model_name}"
    # dct_writer[model_name] = SummaryWriter(log_dir=log_dir)

    t0 = time.time()
    print(f"Training {model_name}")

    model.to(device)
    # dct_writer[model_name].add_graph(model, X_train_tensor)
    # dct_writer[model_name].close()

    # training
    (lst_train_loss, lst_val_loss,
     lst_train_presence_loss, lst_val_presence_loss,
     lst_train_nonzero_loss, lst_val_nonzero_loss,
     lst_train_pred_loss, lst_val_pred_loss,
     lst_train_f1, lst_val_f1,
     best_model) = train_model(
        model, model_name, train_loader, test_loader,
        num_epochs_1=num_epochs_1, num_epochs_2=num_epochs_2,
        patience=patience, gamma=gamma, mask_x_true=mask_x_true,
    )

    torch.save(best_model, f"model/{model_name}_best_model.pth")

    model.load_state_dict(best_model)
    t1 = time.time()
    print(f'\t\t ------------ model trained, time used = {t1 - t0} seconds ------------')

    print('\nmodel evaluating')
    model.eval()
    with torch.no_grad():
        model_output = model(X_test_tensor)
        X_test_latent = model_output[0].cpu().detach().numpy()
        # decoded1 / decoded2 are logits; sigmoid gives presence probabilities / values.
        X_test_presence = torch.sigmoid(model_output[1]).cpu().detach().numpy()
        X_test_nonzero = torch.sigmoid(model_output[2]).cpu().detach().numpy()
        X_train_latent = model(X_train_tensor)[0].cpu().detach().numpy()

    t2 = time.time()
    print(f'\t\t ------------ model applied on test set, time used = {t2 - t1} seconds ------------')
    # save model outputs
    dct_presence[model_name] = X_test_presence.flatten()
    dct_nonzero[model_name] = X_test_nonzero.flatten()
    dct_latent_vectors[model_name] = X_test_latent

    # get and save metrics on test set
    # Full reconstruction = presence probability x predicted value.
    X_test_reconstructed = X_test_presence * X_test_nonzero
    f1_test = f1_score(X_test.flatten() > 0, X_test_presence.flatten() > 0.5)
    # MSE of the full reconstruction. Comparing X_test with the presence probabilities
    # alone would mix values with probabilities.
    mse_test = ((X_test - X_test_reconstructed) ** 2).mean()
    masked_mse_test = ((X_test != 0) * (X_test_nonzero - X_test) ** 2).sum() / (X_test != 0).sum()
    soft_masked_mse_test = ((X_test_presence * X_test_nonzero - X_test_presence * X_test) ** 2).sum() / X_test_presence.sum()
    dct_metrics[model_name] = {'f1_test': f1_test, 'mse_test': mse_test, 'masked_mse_test': masked_mse_test, 'soft_masked_mse_test': soft_masked_mse_test}

    # Log losses
    dct_history[model_name] = {
            "train_loss": np.array(lst_train_loss),
            "val_loss": np.array(lst_val_loss),
            "train_presence_loss": np.array(lst_train_presence_loss),
            "val_presence_loss": np.array(lst_val_presence_loss),
            "train_nonzero_loss": np.array(lst_train_nonzero_loss),
            "val_nonzero_loss": np.array(lst_val_nonzero_loss),
            'train_pred_loss': np.array(lst_train_pred_loss),
            "val_pred_loss": np.array(lst_val_pred_loss),
            "train_f1_score": np.array(lst_train_f1),
            "val_f1_score": np.array(lst_val_f1),
        }

    # Train XGBoost model on latent features
    print("\nprediction using embedding by", model_name)
    xgb_model = XGBRegressor(n_estimators=200, learning_rate=0.05, max_depth=4)
    xgb_model.fit(X_train_latent, y_train)  # Use the latent features as input for regression
    lst_xgb_models.append(xgb_model)

    # Predict on the validation set
    dct_y_pred[model_name] = xgb_model.predict(X_test_latent)

    t3 = time.time()
    print(f'\t\t ------------ XGB trained and tested, time used = {t3 - t2} seconds ------------')



# Results

In [0]:
# Training curves per model. The total, presence and F1 histories span both training
# stages, while the non-zero and prediction histories exist only for stage 2. Those
# stage-2 curves are shifted right by the number of stage-1 epochs so every panel
# shares one epoch axis.
ncol = 4
nrow = len(model_names)
if gamma != 0:
    ncol += 1
plt.figure(figsize=(4 * ncol, 3 * nrow))

panel_specs = [
    # (history key suffix, title suffix, y-axis label)
    ('loss', 'Loss', 'Loss'),
    ('f1_score', 'F1 score', 'F1 score'),
    ('presence_loss', 'Presence Loss', 'Loss'),
    ('nonzero_loss', 'Non-zero Loss', 'Loss'),
]
if gamma != 0:
    panel_specs.append(('pred_loss', 'Prediction Loss', 'Loss'))

for i, (model_name, history) in enumerate(dct_history.items()):
    n_epochs_total = len(history['train_loss'])
    for j, (key, title, ylabel) in enumerate(panel_specs):
        train_curve = history[f'train_{key}']
        val_curve = history[f'val_{key}']
        first_epoch = n_epochs_total - len(train_curve)
        epochs = np.arange(first_epoch, first_epoch + len(train_curve))

        plt.subplot(nrow, ncol, ncol * i + j + 1)
        plt.plot(epochs, train_curve, '-', label='Train', color='blue', alpha=0.5)
        plt.plot(epochs, val_curve, '--', label='Validation', color='red', alpha=0.5)
        plt.title(f'{model_name} - {title}')
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)

plt.tight_layout()

In [0]:
# Histograms of the reconstructed test matrix, three panels per model:
#   presence x value, coarse bins  -> annotated with the presence F1 score
#   presence x value, fine bins    -> annotated with the soft-masked MSE
#   value head alone, fine bins    -> annotated with the masked MSE
ncol = 3
nrow = len(model_names)
plt.figure(figsize=(5 * ncol, 3 * nrow))
i = 1

for model_name in model_names:
    X_test_presence = dct_presence[model_name]
    X_test_nonzero = dct_nonzero[model_name]
    metrics = dct_metrics[model_name]
    f1 = metrics['f1_test']
    mse_test = metrics['mse_test']
    masked_mse_test = metrics['masked_mse_test']
    soft_masked_mse_test = metrics['soft_masked_mse_test']
    reconstruction = X_test_presence * X_test_nonzero

    panels = (
        (reconstruction, 10, 5e5, f'{model_name} - Reconstructed Presence Distribution',
         0.5, 300000, f'F1 score = {f1:.2f}'),
        (reconstruction, 100, 10000, f'{model_name} - Reconstructed Values Distribution',
         0.1, 6000, f'soft masked MSE = {soft_masked_mse_test:.3f}'),
        (X_test_nonzero, 100, 10000, f'{model_name} - Reconstructed non zero Values Distribution',
         0.1, 6000, f'masked MSE = {masked_mse_test:.3f}'),
    )
    for values, n_bins, y_max, title, text_x, text_y, annotation in panels:
        plt.subplot(nrow, ncol, i)
        sns.histplot(values, bins=n_bins)
        plt.xlim(-0.05, 1.05)
        plt.ylim(0, y_max)
        plt.title(title)
        plt.text(text_x, text_y, annotation, fontsize=14)
        i += 1

plt.tight_layout();

In [0]:
# for log_dir in log_dirs:
#     %tensorboard --logdir $log_dir

In [0]:
import gc
# Full (oldest-generation) garbage collection to release unreferenced objects.
gc.collect(generation=2)

In [0]:
# True vs predicted age on the test set for each model's XGBoost regressor,
# with the identity line and R^2 / MSE annotations.
nrow = len(model_names)
plt.figure(figsize=(5, 4 * nrow))

for i, (model_name, y_pred) in enumerate(dct_y_pred.items()):
    y_pred = y_pred.squeeze()
    ax = plt.subplot(nrow, 1, i + 1)
    ax.scatter(y_test, y_pred, alpha=0.5, s=6)
    ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
    ax.set_xlabel('True Values')
    ax.set_ylabel('Predicted Values')
    ax.set_title(model_name)

    r2 = r2_score(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    ax.text(0.05, 0.95, f'R^2: {r2:.2f}\nMSE: {mse:.2f}', transform=ax.transAxes, fontsize=12, verticalalignment='top')

    plt.tight_layout()

In [0]:
# Flattened presence probabilities and values of the first trained model, inspected below.
X_test_presence, X_test_nonzero = dct_presence[model_names[0]], dct_nonzero[model_names[0]]

# X_test_presence = X_test_nonzero  # when encoder 1 was not changed during training stage 2

In [0]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Confusion matrix of observed presence (X_test != 0) against predicted presence
# (probability >= 0.5), pooled over all test samples and features.
y_true = (X_test != 0).astype(int).flatten()
y_pred = (X_test_presence >= 0.5).astype(int).flatten()
cm = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots(figsize=(5, 4))
tick_labels = ['False', 'True']
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix of X_test != 0 and X_test_presence')
plt.show()

In [0]:
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

# ROC and precision-recall curves for presence detection (observed X_test != 0 vs
# predicted presence). Both curves need the continuous presence probabilities as
# scores. Scoring with thresholded 0/1 predictions collapses each curve to a single
# operating point and misstates the areas.
presence_true = (X_test != 0).astype(int).flatten()
presence_score = np.asarray(X_test_presence).flatten()

fpr, tpr, _ = roc_curve(presence_true, presence_score)
roc_auc = auc(fpr, tpr)
precision, recall, _ = precision_recall_curve(presence_true, presence_score)
# Average precision: the step-wise summary of the precision-recall curve.
auprc = average_precision_score(presence_true, presence_score)

# Plot ROC and PRC curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))

# ROC curve
ax1.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax1.set_xlim([0.0, 1.0])
ax1.set_ylim([0.0, 1.05])
ax1.set_xlabel('False Positive Rate')
ax1.set_ylabel('True Positive Rate')
ax1.set_title('ROC Curve')
ax1.legend(loc="lower right")

# PRC curve
ax2.plot(recall, precision, color='blue', lw=2, label=f'PRC curve (AP = {auprc:.2f})')
ax2.set_xlim([0.0, 1.0])
ax2.set_ylim([0.0, 1.05])
ax2.set_xlabel('Recall')
ax2.set_ylabel('Precision')
ax2.set_title('Precision-Recall Curve')
ax2.legend(loc="lower left")

plt.tight_layout()
plt.show()


In [0]:
from sklearn.metrics import mean_squared_error

# Element-wise agreement between reconstructions and observed test values:
#   - MSE of the value head over all entries,
#   - masked MSE over truly non-zero entries,
#   - presence-weighted ("softly masked") MSE of presence x value.
x_true = X_test.flatten()
x_value = np.asarray(X_test_nonzero).flatten()
x_presence = np.asarray(X_test_presence).flatten()
observed = x_true > 0

mse = mean_squared_error(x_true, x_value)
masked_mse = mean_squared_error(x_true[observed], x_value[observed])
# np.sum uses pairwise summation. The builtin sum loops in Python and accumulates
# the float32 presence array one element at a time in float32.
softly_masked_mse = np.sum((x_true * x_presence - x_value * x_presence) ** 2) / np.sum(x_presence)
identity = [x_true.min(), x_true.max()]

plt.figure(figsize=(15, 4 * len(model_names)))

plt.subplot(131)
plt.scatter(x_true, x_value, alpha=0.1, s=6)
plt.plot(identity, identity, 'r--', lw=2)
plt.xlabel('X_test')
plt.ylabel('X_test_nonzero')
plt.title('Scatterplot of X_test_nonzero vs X_test')
plt.text(0.05, 0.95, f'MSE: {mse:.3f}', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')

plt.subplot(132)
plt.scatter(x_true[observed], x_value[observed], alpha=0.1, s=6)
plt.plot(identity, identity, 'r--', lw=2)
plt.xlabel('X_test')
plt.ylabel('X_test_nonzero')
plt.title('Scatterplot of X_test_nonzero vs X_test (Non-zero)')
plt.text(0.05, 0.95, f'Masked MSE: {masked_mse:.3f}', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')

plt.subplot(133)
plt.scatter(x_true, x_value * x_presence, alpha=0.1, s=6)
plt.plot(identity, identity, 'r--', lw=2)
plt.xlabel('X_test')
plt.ylabel('X_test_nonzero * X_test_presence')
plt.title('Scatterplot of X_test_nonzero * X_test_presence vs X_test')
plt.text(0.05, 0.95, f'Softly Masked MSE: {softly_masked_mse:.3f}', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')

plt.tight_layout()
plt.show()



In [0]:
# Observed number of non-zero entries, the number expected from the predicted
# presence probabilities, and the total abundance of the test matrix.
(
    (X_test != 0).sum().round(0),
    X_test_presence.sum().round(0),
    X_test.sum().round(0),
)

In [0]:
# for model_name, model in zip(model_names, lst_models):
#     best_model = torch.load( f"model/{model_name}_best_model.pth")
#     model.load_state_dict(best_model)
#     model.eval()


# Interpretation

In [0]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# import umap

def plot_latent_space(latent, method='tsne', colorby = None, cbar_label=None):
    """Project ``latent`` to 2-D and scatter it on the current matplotlib axes.

    Args:
        latent: array-like of shape (n_samples, n_features) to embed.
        method: 'tsne' (seeded), 'pca', or 'umap' (seeded; needs the optional
            ``umap-learn`` package).
        colorby: optional per-sample values used to colour the points. When given,
            a colour bar labelled ``cbar_label`` is added.
        cbar_label: label of the colour bar.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
        ImportError: if ``method='umap'`` and ``umap-learn`` is not installed.
    """
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    elif method == 'pca':
        reducer = PCA(n_components=2)
    elif method == 'umap':
        # Optional dependency whose module-level import is commented out; import it
        # lazily here instead of failing with a NameError.
        import umap
        reducer = umap.UMAP(n_components=2, random_state=42)
    else:
        raise ValueError("Invalid method")

    reduced = reducer.fit_transform(latent)

    scatter = plt.scatter(reduced[:, 0], reduced[:, 1],
                          c=colorby, cmap='viridis',
                          s=10, alpha=0.6, edgecolors='w', linewidths=0.5)
    # A colour bar is only meaningful when points are actually colour-mapped.
    if colorby is not None:
        plt.colorbar(scatter, label=cbar_label)
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')

In [0]:
# 2-D PCA and t-SNE views of each model's test-set latent vectors, coloured by
# how many features are non-zero in each sample.
non_zero_counts = (X_test != 0).sum(axis=1)

plt.figure(figsize=(15, 4 * len(model_names)))
i = 1
for model_name, latent_vectors in dct_latent_vectors.items():
    for method, method_label in (('pca', 'PCA'), ('tsne', 'TSNE')):
        plt.subplot(len(dct_latent_vectors), 2, i)
        plot_latent_space(latent_vectors, method=method, colorby=non_zero_counts, cbar_label='Number of non-zero Features')
        plt.title(f'{model_name} - {method_label}')
        i += 1
plt.tight_layout()

In [0]:
# 2-D PCA and t-SNE views of each model's test-set latent vectors, coloured by age.
plt.figure(figsize=(15, 4 * len(model_names)))
i = 1
for model_name, latent_vectors in dct_latent_vectors.items():
    for method, method_label in (('pca', 'PCA'), ('tsne', 'TSNE')):
        plt.subplot(len(dct_latent_vectors), 2, i)
        plot_latent_space(latent_vectors, method=method, colorby=y_test, cbar_label='Age')
        plt.title(f'{model_name} - {method_label}')
        i += 1
plt.tight_layout()

In [0]:
# Sample metadata restricted to the test-set samples, in the same order as X_test / y_test.
df_meta = pd.read_csv("../data/metadata.txt", sep='\t', header=0, index_col=0)
# test_idx holds *positions* into the filtered df_age / df_genus (adults only,
# low-diversity samples dropped), not into this unfiltered table. Using
# .iloc[test_idx] would therefore select unrelated rows, so align by sample ID instead.
# NOTE(review): assumes metadata.txt is indexed by the same sample IDs as age.csv.
df_meta = df_meta.loc[df_age.index[test_idx], :]


In [0]:
# Gender column of the selected metadata rows.
df_meta.loc[:, 'gender']

In [0]:
# explainer = shap.KernelExplainer(model, torch.tensor(X_test, dtype=torch.float32))
# shap_values = explainer.shap_values(X_test)

# # Plotting SHAP values
# shap.summary_plot(shap_values, X_test)

In [0]:
# for model_name, model in zip(model_names, models):
#   explainer = shap.KernelExplainer(model.encoder.predict, input_data)
#   shap_values = explainer.shap_values(input_data)

# # Plotting SHAP values
# shap.summary_plot(shap_values, input_data)

In [0]:
# !tensorboard --logdir log_dir

In [0]:
# !tensorboard --logdir $log_dir

In [0]:
# !load_ext tensorboard
# !tensorboard --logdir $log_dir -- port 6006