Differences from the age-prediction notebooks:
1. No age prediction is performed here.
2. All microbiota samples are used, not only those filtered by age.
3. The latent space is visualized against several metadata factors.


In [0]:
%load_ext autoreload
%autoreload 2
# Automatically reload edited Python modules before each cell executes, so changes to the
# local helper modules imported below are picked up without restarting the kernel.
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

In [0]:
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score, f1_score
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from utils import *

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device


Open question: can enterotypes be clustered using the learned latent representations?

In [0]:
# NOTE(review): unpinned install via `!pip`; consider `%pip install torchinfo==<version>` so the
# package lands in the notebook's own environment and the run is reproducible.
!pip install torchinfo
from torchinfo import summary  # used by the model-summary cell further down

import mlflow
import mlflow.pytorch
# NOTE(review): per the MLflow docs, mlflow.pytorch.autolog() hooks PyTorch Lightning training;
# train_model below uses a plain loop and logs params/metrics explicitly — confirm autolog is needed.
mlflow.pytorch.autolog()

## Load data

In [0]:
# Abundance table: tab-separated, header row, first column is the sample index.
# Alternative input kept for reference: "../data/processed_genus_log_drop08_scaled.csv"
df_genus = pd.read_csv(
    "../data/genus_counts_log_scaled_reduced.csv",
    sep='\t',
    header=0,
    index_col=0,
)
print(df_genus.shape)

# Model input matrix: one row per sample, one column per feature (genera, per the file name).
X = df_genus.to_numpy()
input_dim = X.shape[1]

In [0]:
# Split samples: 20% test, then 30% of the remainder for validation
# (≈56% train / 24% validation / 20% test). Fixed random_state keeps the split reproducible.
train_idx, test_idx = train_test_split(range(X.shape[0]), test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.3, random_state=42)

X_train = X[train_idx, :]
X_val = X[val_idx, :]
X_test = X[test_idx, :]

# Convert to tensors and move to device
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)

# Autoencoder datasets: the input is also the reconstruction target.
train_dataset = TensorDataset(X_train_tensor, X_train_tensor)
val_dataset = TensorDataset(X_val_tensor, X_val_tensor)

# Create DataLoaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: train and validation should contain a similar fraction of nonzero entries
# (ratio near 1). Comparing densities directly avoids hard-coding the split fractions,
# which rounding of the split sizes makes only approximate.
(X_train != 0).mean() / (X_val != 0).mean()

## Models

In [0]:
from models_attention_with_regressor_head import *

# TODO: train as a denoising autoencoder — add noise to the input during training (e.g. zero out random values) and train the network to reconstruct the original data. This encourages the model to learn features that are robust to sparse noise.

## Loss Functions

In [0]:
from loss_functions import *

## Training Function

In [0]:


def train_model(model, model_name, train_loader, val_loader, num_epochs_1=100, num_epochs_2=50, 
                patience1=10, patience2=20, mask_x_true=True, mask_type=3, num_layers_to_freeze=1, filename='model'):
    """Train a presence/abundance autoencoder in two phases and track the run in MLflow.

    Phase 1 optimises only the presence loss (BCE between ``decoded1`` logits and ``X != 0``).
    The phase-1 weights chosen by ``early_stopping`` are restored, the first
    ``num_layers_to_freeze`` parameter tensors of ``model.encoder`` are frozen, and phase 2
    optimises ``alpha * presence_loss + (1 - alpha) * softly_masked_mse_loss`` where ``alpha``
    decays linearly from 1 and is floored at the notebook-global ``ALPHA``.

    Args:
        model: module whose forward returns ``(latent, decoded1, decoded2, _)``; ``decoded1``
            are presence logits, ``decoded2`` are passed through a sigmoid as predicted values.
        model_name: label used in the MLflow params and progress prints.
        train_loader, val_loader: loaders yielding ``(X, X)`` batches (as built in this notebook).
        num_epochs_1, patience1: epoch budget / early-stopping patience for phase 1.
        num_epochs_2, patience2: same for phase 2 (early stopping only after epoch 20).
        mask_x_true: forwarded to ``softly_masked_mse_loss``.
        mask_type: 0 = true presence, 1 = sigmoid(decoded1), 2 = sigmoid(decoded2),
            3 = detached sigmoid(decoded1); any other value = all-ones mask.
        num_layers_to_freeze: number of encoder parameter *tensors* frozen before phase 2.
            NOTE(review): weights and biases are separate tensors, so 1 freezes only the first
            weight — confirm this is intended.
        filename: artifact path passed to ``mlflow.pytorch.log_model``.

    Returns:
        The state dict selected by ``early_stopping`` (presumably the lowest validation loss).

    Note:
        Reads the notebook globals ``latent_dim``, ``dropout`` and ``ALPHA``.
    """
    bce_loss = nn.BCEWithLogitsLoss()

    def train_step(loader, optimizer=None, alpha=1.0):
        """Run one pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

        Returns averaged losses plus F1 / AUROC / average precision of the predicted presence
        probabilities, scored against the same presence target the BCE loss uses (``X != 0``).
        """
        is_train = optimizer is not None
        model.train(is_train)
        total_loss, total_nonzero_loss, total_presence_loss = 0.0, 0.0, 0.0
        true_chunks, prob_chunks = [], []

        # Evaluation passes do not need an autograd graph; skipping it saves memory and time.
        with torch.set_grad_enabled(is_train):
            for X_batch, _ in loader:
                if is_train:
                    optimizer.zero_grad()

                latent, decoded1, decoded2, _ = model(X_batch)
                presence_true = (X_batch != 0).float()
                loss_presence = bce_loss(decoded1, presence_true)

                mask = {
                    0: (X_batch != 0).float(),
                    1: torch.sigmoid(decoded1),
                    2: torch.sigmoid(decoded2),
                    3: torch.sigmoid(decoded1).detach()
                }.get(mask_type, torch.ones_like(X_batch))

                loss_nonzero = softly_masked_mse_loss(torch.sigmoid(decoded2), X_batch, mask, mask_x_true=mask_x_true)
                loss = alpha * loss_presence + (1 - alpha) * loss_nonzero

                if is_train:
                    loss.backward()
                    optimizer.step()

                total_loss += loss.item()
                total_presence_loss += loss_presence.item()
                total_nonzero_loss += (loss_nonzero * torch.sum(mask)).item()

                # Collect presence labels/probabilities per batch (no dependency on global array sizes).
                true_chunks.append(presence_true.flatten().cpu().numpy())
                prob_chunks.append(torch.sigmoid(decoded1).detach().flatten().cpu().numpy())

        y_true = np.concatenate(true_chunks).astype(int)
        y_prob = np.concatenate(prob_chunks)
        y_pred = (y_prob >= 0.5).astype(int)
        num_nonzero = max(int(y_true.sum()), 1)  # guard against an all-zero split
        num_batch = max(len(loader), 1)
        return {
            "loss": total_loss / num_batch,
            "presence_loss": total_presence_loss / num_batch,
            "nonzero_loss": total_nonzero_loss / num_nonzero,
            "f1": f1_score(y_true, y_pred),
            "auroc": roc_auc_score(y_true, y_prob),
            "average_precision": average_precision_score(y_true, y_prob)
        }

    def log_epoch(train_metrics, val_metrics, lr, step):
        """Log train/val metrics and the current learning rate at MLflow step ``step``."""
        for prefix, metrics in (("train", train_metrics), ("val", val_metrics)):
            for key, value in metrics.items():
                mlflow.log_metric(f"{prefix}_{key}", value, step=step)
        mlflow.log_metric("learning_rate", lr, step=step)

    with mlflow.start_run():
        params = {
            "model_name": model_name, "mask_x_true": mask_x_true, "mask_type": mask_type,
            "num_layers_to_freeze": num_layers_to_freeze, "latent_dim": latent_dim,
            "dropout": dropout, "num_epochs_1": num_epochs_1, "patience1": patience1,
            "num_epochs_2": num_epochs_2, "patience2": patience2, "ALPHA": ALPHA
        }
        mlflow.log_params(params)

        # One step counter across both phases, so phase-2 metrics do not land on the same
        # MLflow steps as phase-1 metrics.
        global_step = 0

        # Training Phase 1: Presence Prediction
        # (`verbose=` is omitted: it is deprecated in recent PyTorch; the LR is logged to MLflow.)
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
        min_val_loss, best_model, early_stopping_counter = float('inf'), None, 0

        for epoch in range(num_epochs_1):
            train_metrics = train_step(train_loader, optimizer)
            val_metrics = train_step(val_loader)

            scheduler.step(val_metrics["loss"])
            log_epoch(train_metrics, val_metrics, optimizer.param_groups[0]['lr'], global_step)
            global_step += 1

            if epoch % 10 == 0:
                print(f"{model_name} Phase 1 Epoch {epoch+1}/{num_epochs_1}, Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")

            stop, min_val_loss, best_model, early_stopping_counter = early_stopping(epoch, val_metrics["loss"], min_val_loss, best_model, model, early_stopping_counter, patience1)
            if stop:
                break

        # Load best model & freeze the first encoder parameter tensors
        if best_model is not None:
            model.load_state_dict(best_model)
        for i, (name, param) in enumerate(model.encoder.named_parameters()):
            if i < num_layers_to_freeze:
                param.requires_grad = False
                print(f"Freezing {name}")

        # Training Phase 2: Nonzero Value Prediction
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
        min_val_loss, early_stopping_counter = float('inf'), 0

        for epoch in range(num_epochs_2):
            # Presence-loss weight decays linearly from 1 over 20 epochs, floored at ALPHA.
            alpha = max(ALPHA, 1 - epoch / 20)
            train_metrics = train_step(train_loader, optimizer, alpha)
            val_metrics = train_step(val_loader, None, alpha)

            scheduler.step(val_metrics["loss"])
            log_epoch(train_metrics, val_metrics, optimizer.param_groups[0]['lr'], global_step)
            global_step += 1

            if epoch % 10 == 0:
                print(f"{model_name} Phase 2 Epoch {epoch+1}/{num_epochs_2}, Train Loss: {train_metrics['loss']:.4f}, Val Loss: {val_metrics['loss']:.4f}")

            # While alpha may still be annealing the loss definition changes between epochs,
            # so early stopping only starts after the 20-epoch annealing window.
            if epoch > 20:
                stop, min_val_loss, best_model, early_stopping_counter = early_stopping(epoch, val_metrics["loss"], min_val_loss, best_model, model, early_stopping_counter, patience2)
                if stop:
                    break

        # Log the selected weights (the ones returned to the caller), not the last-epoch weights.
        if best_model is not None:
            model.load_state_dict(best_model)
        mlflow.pytorch.log_model(model, filename)

        return best_model


In [0]:
import gc
# Release unreachable Python objects (e.g. left over from earlier runs in this kernel) before the experiment starts.
gc.collect()

# Run the experiment

In [0]:
# parameters to define before each experiment

latent_dim = 20 # García-Jiménez et al. 2021 used latent_dim 10 to represent 717 taxa https://academic.oup.com/bioinformatics/article/37/10/1444/5988714

num_epochs_1 = 50   # phase 1 (presence only) epoch budget
num_epochs_2 = 200  # phase 2 (presence + values) epoch budget
patience1 = 10
patience2 = 10
dropout = 0.1

mask_x_true = False
mask_type = 3  # 3 = detached sigmoid of the presence logits used as the soft mask

num_layers_to_freeze = 1

ALPHA = 0.9  # lower bound of the presence-loss weight during phase 2

SEED = 42  # seeds torch so weight initialisation and DataLoader shuffling are reproducible


model_name = "AttentionAEmid"
torch.manual_seed(SEED)
model = AttentionAEmid(input_dim, latent_dim, dropout=dropout)
model.to(device)


In [0]:
# torchinfo runs a forward pass on a random tensor of `input_size`; one mini-batch is enough to
# report layer shapes, whereas the full training-set shape is slow and can exhaust GPU memory.
summary(model, input_size=(batch_size, input_dim))

In [0]:

# Sweep idea: for ALPHA in np.arange(0, 1, 0.1): ...
filename = get_model_filename(model_name, mask_x_true, mask_type, num_layers_to_freeze, latent_dim, dropout, num_epochs_1, patience1, num_epochs_2, patience2, ALPHA)
print(filename)

t0 = time.time()

model_path = f"model/{filename}_best_model.pth"
if os.path.exists(model_path):
    print('model already trained, load model and latents')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
    best_model = torch.load(model_path, map_location=device)
    model.load_state_dict(best_model)
    latent = np.loadtxt(f"latents/latents_{filename}.txt")
else:
    print('model not trained, train model')
    # mask_type is passed explicitly: it is part of `filename`, so training must use the
    # configured value rather than train_model's default.
    best_model = train_model(model, model_name, train_loader, val_loader, num_epochs_1=num_epochs_1, num_epochs_2=num_epochs_2, patience1=patience1, patience2=patience2, mask_x_true=mask_x_true, mask_type=mask_type, num_layers_to_freeze=num_layers_to_freeze, filename=filename)
    # Training history is tracked in MLflow by train_model.

    # Create the output directory so a finished training run is not lost to a failed save.
    os.makedirs("model", exist_ok=True)
    torch.save(best_model, model_path)
    model.load_state_dict(best_model)


t1 = time.time()
print(f'\t\t ------------ model processed, time used = {t1 - t0} seconds ------------')



## Model computation graph

In [0]:
%sh
# System packages (Graphviz and its headers) needed to render the torchviz graph in the next cell.
# NOTE(review): that torchviz cell is currently commented out, so this install is unused unless it is re-enabled.
sudo apt-get install -y python3-dev graphviz libgraphviz-dev pkg-config

In [0]:
# from torchviz import make_dot
# output = model(X_test_tensor)
# make_dot(output, params=dict(model.named_parameters())).render(model_name, format="png")


In [0]:
# plot_training_history(history).show()

# Results

In [0]:
# Apply the trained model to the held-out test set.
print(filename)

# Inference: switch dropout etc. to eval behaviour and skip building an autograd graph.
model.train(False)
with torch.no_grad():
    model_output = model(X_test_tensor)

In [0]:
# get latents, reconstructions and metrics

latent = model_output[0].cpu().detach().numpy()
os.makedirs("latents", exist_ok=True)  # savetxt does not create missing directories
np.savetxt(f"latents/latents_{filename}.txt", latent)

decoded1_sigmoid = torch.sigmoid(model_output[1]).cpu().detach().numpy()  # prob of presence
recon_presence = (decoded1_sigmoid >= 0.5).astype(int)
recon_nonzero = torch.sigmoid(model_output[2]).cpu().detach().numpy()
mask_options = {
    0: (X_test != 0).astype(float),
    1: decoded1_sigmoid,
    2: recon_nonzero,
    3: decoded1_sigmoid
}
# Raise (rather than return) on an unknown mask_type; dict.get with an exception default
# would hand back the exception object and fail later with a confusing AttributeError.
if mask_type not in mask_options:
    raise ValueError("Invalid mask_type")
softmask = mask_options[mask_type]


predicted_mask = softmask.copy() # or recon_presence ????
predicted_mask[predicted_mask < 0.5] = 0
# reconstructed = recon_nonzero * predicted_mask # or recon_presence ????
reconstructed = recon_nonzero * recon_presence  # predicted values gated by predicted presence

# Presence is defined as X != 0 everywhere else in this notebook; use the same here.
f1_test = f1_score(X_test.flatten() != 0, recon_presence.flatten())
x_true_masked_mse_test = ((X_test != 0) * (recon_nonzero - X_test) ** 2).sum() / (X_test != 0).sum()
soft_masked_mse_test = ((softmask * recon_nonzero - X_test) ** 2).sum() / (X_test != 0).sum()
reconstructed_mse_test = ((X_test - reconstructed) ** 2).mean()

In [0]:
# Summary figure for the test-set reconstruction. The helper comes from a star import above;
# presumably it plots the soft mask / predicted values and annotates the metrics passed in.
plot_reconstructed_distribution(softmask, recon_nonzero, f1_test, x_true_masked_mse_test, soft_masked_mse_test, reconstructed)

**Evaluation of Presence/Absence Reconstruction**

In [0]:
# Flattened 0/1 labels over every (sample, feature) entry: true presence vs. thresholded prediction.
y_true = np.asarray(X_test != 0, dtype=int).flatten()
y_pred = recon_presence.flatten()

In [0]:
# Confusion matrix of true vs. predicted presence (y_true / y_pred are flattened 0/1 labels over all test entries).
plot_confusion_matrix(y_true, y_pred)

In [0]:
# Score with the presence probabilities, not the thresholded 0/1 predictions: a ROC curve built
# from hard labels collapses to a single operating point.
plot_auc(y_true, decoded1_sigmoid.flatten()).show()


**Evaluation of Nonzero Values Reconstruction**

In [0]:
# Predicted values vs. true test values. The helper comes from a star import; presumably a scatter
# annotated with the soft-masked and true-masked MSEs passed in.
scatter_plot_nonzero(X_test, recon_nonzero, softmask, soft_masked_mse_test, x_true_masked_mse_test)

In [0]:
r2 = r2_score(X_test.flatten(), reconstructed.flatten())

# One point per (sample, feature) entry: true test value vs. final presence-gated reconstruction.
plt.figure(figsize=(4.5, 4 ))
plt.scatter(X_test.flatten(), reconstructed.flatten(), alpha=0.1, s=1)
# Identity line: perfect reconstructions fall on it.
plt.plot([X_test.min(), X_test.max()], [X_test.min(), X_test.max()], 'r--', lw=2)
plt.xlabel('X_test')
plt.ylabel('reconstructed')
plt.title('reconstructed vs X_test')

# Metric annotations in axes coordinates, stacked from the top-left corner.
for y_pos, annotation in (
    (0.95, f'F1 Score = {f1_test:.3f}'),
    (0.85, f'R2 = {r2:.3f}'),
    (0.75, f'MSE = {reconstructed_mse_test:.3f}'),
):
    plt.text(0.08, y_pos, annotation, transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')



In [0]:
# Hexbin of the final reconstruction vs. true test values. Entries that are zero in both
# (true absences predicted as absent) are dropped so they do not dominate the counts.
plt.figure(figsize=(5, 4))

mask_origin = (X_test.flatten() != 0) | (reconstructed.flatten() != 0)
X_test_filtered = X_test.flatten()[mask_origin]
reconstructed_filtered = reconstructed.flatten()[mask_origin]

plt.hexbin(
    X_test_filtered,
    reconstructed_filtered,
    gridsize=50,
    cmap='Blues',
    mincnt=1,
    vmin=0,
    vmax=3000,  # fixed colour ceiling: bins with more than 3000 entries saturate
)
plt.colorbar(label='count in bin')
plt.plot([X_test.min(), X_test.max()], [X_test.min(), X_test.max()], 'r--', lw=2)
plt.xlabel('X_test')
plt.ylabel('reconstructed')
plt.title('Hexbin plot, removing (0, 0)')

for y_pos, annotation in (
    (0.95, f'F1 Score = {f1_test:.3f}'),
    (0.85, f'R2 = {r2:.3f}'),
    (0.75, f'MSE = {reconstructed_mse_test:.3f}'),
):
    plt.text(0.08, y_pos, annotation, transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')
plt.show()

# Interpretation

**Reduce to 2D for visualization**

In [0]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def perform_dimensionality_reduction(X_test, latent, test_idx, df_genus):
    """Project the test-set abundances and their latents to 2-D with PCA and t-SNE.

    Args:
        X_test: test-set abundance matrix, one row per sample.
        latent: latent vectors for the same rows, in the same order.
        test_idx: positions of the test rows in ``df_genus`` (used to recover sample IDs).
        df_genus: full abundance table; its index supplies the sample IDs.

    Returns:
        DataFrame indexed by test sample ID with columns
        ``{pca,tsne}_{original,latent}_dim_{1,2}``. PCA explained-variance ratios (in %) and
        t-SNE KL divergences are printed as a side effect.
    """
    df_reduced = pd.DataFrame(index=df_genus.index[test_idx])

    # PCA and t-SNE on X_test
    pca_original = PCA(n_components=2)
    df_reduced[['pca_original_dim_1', 'pca_original_dim_2']] = pca_original.fit_transform(X_test)
    tsne_original = TSNE(n_components=2, random_state=42)
    df_reduced[['tsne_original_dim_1', 'tsne_original_dim_2']] = tsne_original.fit_transform(X_test)

    # PCA and t-SNE on latent
    pca_latent = PCA(n_components=2)
    df_reduced[['pca_latent_dim_1', 'pca_latent_dim_2']] = pca_latent.fit_transform(latent)
    tsne_latent = TSNE(n_components=2, random_state=42)
    df_reduced[['tsne_latent_dim_1', 'tsne_latent_dim_2']] = tsne_latent.fit_transform(latent)

    print("PCA Latent Explained Variance Ratio:", pca_latent.explained_variance_ratio_ * 100)
    print("PCA Original Explained Variance Ratio:", pca_original.explained_variance_ratio_ * 100)
    print("t-SNE Latent KL Divergence:", tsne_latent.kl_divergence_)
    print("t-SNE Original KL Divergence:", tsne_original.kl_divergence_)

    return df_reduced


# Computed here (previously commented out) so `df_reduced` exists for the df_results cell
# after a fresh Restart & Run All.
df_reduced = perform_dimensionality_reduction(X_test, latent, test_idx, df_genus)

**Load sample metadata**

In [0]:
df_meta = pd.read_csv("../data/metadata.txt", sep='\t', header=0, index_col=0)
col_factors = ['age', 'antibiotics_current_use', 'gender', 'country', 'non_westernized', 'sequencing_platform', 'disease', "study_condition"]
print(df_meta.shape)

# Align metadata to the abundance table. Check for missing samples first: .loc would otherwise
# raise a bare KeyError before the mismatch message could ever be reached.
missing_samples = df_genus.index.difference(df_meta.index)
if len(missing_samples) > 0:
    raise ValueError(f"meta data and abundance data index mismatch: {len(missing_samples)} samples have no metadata")
df_meta = df_meta.loc[df_genus.index, :]
if not df_meta.index.equals(df_genus.index):  # e.g. duplicated sample IDs in the metadata
    raise ValueError("meta data and abundance data index mismatch")

# .copy() so the column assignments below modify an independent frame, not a slice of df_meta.
meta_test = df_meta.iloc[test_idx][col_factors].copy()

meta_test['age'] = meta_test['age'].fillna(-1).astype(int)  # -1 marks a missing age
meta_test['is_healthy'] = np.where(meta_test['disease'] == 'healthy', 'healthy', 'disease')
meta_test['study_condition'] = np.where(meta_test['study_condition'] == 'control', 'control', 'case')

# DataFrame.info() prints its report itself (and returns None), so it is not wrapped in print().
meta_test.info()

In [0]:
# Number of nonzero features in each test sample, indexed by sample ID.
df_non_zero_counts = pd.DataFrame(
    {'non_zero_counts': (X_test != 0).sum(axis=1)},
    index=df_genus.index[test_idx],
)
col_factors = ['non_zero_counts', 'age', 'antibiotics_current_use', 'gender', 'country', 'non_westernized', 'sequencing_platform', 'is_healthy', "study_condition"]

# One row per test sample: 2-D projections, nonzero counts and metadata, aligned on sample ID.
df_results = pd.concat(
    [df_reduced, df_non_zero_counts, meta_test],
    axis=1,
)
df_results

In [0]:
# 2-D PCA projections (original vs. latent) for each factor in col_factors.
# NOTE(review): plot_latents comes from a star import; behaviour inferred from its arguments.
plot_latents(df_results, col_factors,  'pca', show_vectors=False)

In [0]:
# Same view as above, using the t-SNE projections instead of PCA.
# NOTE(review): plot_latents comes from a star import; behaviour inferred from its arguments.
plot_latents(df_results, col_factors, 'tsne', show_vectors=False)

In [0]:
# def plot_subplots(df_results, reduced_method='pca', show_vectors=False):
#     fig, axes = plt.subplots(len(col_factors), 2, figsize=(15, 5 * len(col_factors)))

#     for i, factor in enumerate(col_factors):
#         sns.scatterplot(ax=axes[i, 0], x=f'{reduced_method}_original_dim_1', y=f'{reduced_method}_original_dim_2', hue=factor, data=df_results, s=10, alpha=0.5, legend='brief')
#         axes[i, 0].set_title(f'{reduced_method.upper()} Original - Colored by {factor}')
#         axes[i, 0].set_xlabel(f'{reduced_method}_original_dim_1')
#         axes[i, 0].set_ylabel(f'{reduced_method}_original_dim_2')
#         if show_vectors: 
#             add_vectors(axes[i, 0], df_results, pca_original, f'{reduced_method}_original_dim_1', f'{reduced_method}_original_dim_2', df_genus.columns, top=5)
#         axes[i, 0].legend(loc='center left', bbox_to_anchor=(1, 0.5))

#         sns.scatterplot(ax=axes[i, 1], x=f'{reduced_method}_latent_dim_1', y=f'{reduced_method}_latent_dim_2', hue=factor, data=df_results, s=10, alpha=0.5, legend='brief')
#         axes[i, 1].set_title(f'{reduced_method.upper()} Latent - Colored by {factor}')
#         axes[i, 1].set_xlabel(f'{reduced_method}_latent_dim_1')
#         axes[i, 1].set_ylabel(f'{reduced_method}_latent_dim_2')
#         if show_vectors: 
#             add_vectors(axes[i, 1], df_results, pca_latent, f'{reduced_method}_latent_dim_1', f'{reduced_method}_latent_dim_2', range(latent.shape[1]), top=5)
#         axes[i, 1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
        

#     plt.tight_layout()
#     plt.show()

# plot_subplots(df_results, 'pca', show_vectors=True)



In [0]:
# plot_subplots(df_results, 'tsne')

# Save the latents with their metadata to the catalog (for the age-prediction notebooks)

In [0]:
# Show meta_test with a 0..n-1 index without mutating it: the df_results cell aligns meta_test
# with other frames by sample ID, so an in-place reset would break re-running that cell.
meta_test.reset_index(drop=True)

In [0]:
# Preview meta_test with its current index moved into a column; returns a new frame and leaves meta_test unchanged.
meta_test.reset_index()

In [0]:
# Side-by-side table: metadata columns followed by one column per latent dimension.
# Both parts are in test-set row order, so they are joined positionally.
df_encoded_meta = pd.concat(
    [
        meta_test.reset_index(drop=True),
        pd.DataFrame(latent),
    ],
    axis=1,
)
df_encoded_meta

In [0]:
%sql
-- Make onesource_datascience_sbx the current catalog so two-part names (e.g. microbiota.encoded below) resolve inside it.
-- NOTE(review): relies on this %sql cell and the Python spark.sql calls sharing one Spark session.
USE CATALOG onesource_datascience_sbx

In [0]:
# Convert the pandas frame to a Spark DataFrame so it can be written to the catalog.
spark_df_encoded_meta = spark.createDataFrame(df_encoded_meta)

In [0]:
# Replace the catalog table with the current latents + metadata. Dropping first presumably
# avoids overwrite failures when the schema changes (e.g. a different latent_dim).
# NOTE(review): if the write below fails, the table stays dropped.
spark.sql("DROP TABLE IF EXISTS microbiota.encoded")
spark_df_encoded_meta.write.mode("overwrite").saveAsTable("microbiota.encoded")

In [0]:
# `is_healthy` holds the labels 'healthy' / 'disease' (not 0/1), so compare against the label;
# `== 1` always selected an empty frame.
df_encoded_meta_healthy = df_encoded_meta[df_encoded_meta['is_healthy'] == 'healthy']
print(df_encoded_meta_healthy.shape)

In [0]:
spark_df_encoded_