# Evaluate models' generative performance

## Setup

In [None]:
%run ./utils/imports.py

import utils.utils as utils
from models import VanillaVAE, GMMVAE, SensitivityModelVanillaVAE, SensitivityModelGMMVAE, modules

## Load the data

In [None]:
# Directory holding the tabular inputs (placeholder; point it at the real data)
dataset_dir = "path/to/files"

# Sensitivity measurements, cell-line features and drug Mol2Vec embeddings,
# all read from `dataset_dir`
sensitivity_table = pd.read_csv(os.path.join(dataset_dir, "sensitivity_table.csv"))
cell_lines_biological_data = pd.read_csv(os.path.join(dataset_dir, "cell_lines_biological_data.csv"))
drugs_mol2vec_reprs = pd.read_csv(os.path.join(dataset_dir, "drugs_Mol2Vec_reprs.csv"))

# Drug inhibition profiles (the guiding data).
# NOTE(review): unlike the tables above this is read from its own placeholder
# path, not from `dataset_dir` -- update it before running.
NO_TRUE_CLUSTER_LABELS = 3  # presumably the number of true guiding clusters
drugs_inhib_profiles = pd.read_csv("path/to/file")

# ID -> index lookups, one per feature table (built by the project helper)
cell_line_ID_to_index_mapper = utils.get_ID_to_idx_mapper(cell_lines_biological_data, id_col="cell_line_id")
drugs_ID_to_smiles_rep_index_mapper = utils.get_ID_to_idx_mapper(drugs_mol2vec_reprs, id_col="PubChem CID")
drugs_ID_to_inhib_profiles_index_mapper = utils.get_ID_to_idx_mapper(drugs_inhib_profiles, id_col="PubChem CID")

# Combined dataset. `[:, 1:]` drops the first column of each feature table
# (presumably its ID column -- confirm against the CSV layouts).
full_dataset = utils.DatasetThreeTables(
    sensitivity_table,
    cell_lines_biological_data.values[:, 1:],
    drugs_mol2vec_reprs.values[:, 1:],
    drugs_inhib_profiles.values[:, 1:],
    cell_line_ID_to_index_mapper,
    drugs_ID_to_smiles_rep_index_mapper,
    drugs_ID_to_inhib_profiles_index_mapper,
    drug_ID_name="PubChem CID",
    cell_line_ID_name="COSMIC_ID",
    guiding_data_class_name="guiding_data_class",
    sensitivity_metric="LN_IC50",
    drug_ID_index=1,
    cell_line_ID_index=3,
    sensitivity_metric_index=4,
)

# Train/val/test split; the arguments suggest whole cell lines are held out
NUM_VAL_CELL_LINES, NUM_TEST_CELL_LINES = 100, 100
SPLIT_SEED = 11  # seed of the 1st run
dataset_train, dataset_val, dataset_test = full_dataset.train_val_test_split(
    NUM_VAL_CELL_LINES, NUM_TEST_CELL_LINES, seed=SPLIT_SEED
)

## Load the model to assess

In [None]:
# Run whose checkpoint is evaluated in this notebook
runs_dir = "final_runs/GMM_VAE__IP__no_comps=3__fixed_comp_std=1.0"
exp = "run_1_split_seed_11"
model_dir = os.path.join(runs_dir, exp, "version_0")

# Architecture/training configs stored next to the run's checkpoints
with open(os.path.join(model_dir, "whole_model_config.json"), "r") as config_file:
    whole_model_config = json.load(config_file)

with open(os.path.join(model_dir, "sensitivity_prediction_network_config.json"), "r") as config_file:
    sensitivity_prediction_network_config = json.load(config_file)

In [None]:
%%time
# Rebuild the GMM-VAE sensitivity model from its saved configs and load the
# latest checkpoint of the selected run.
# Presumably maps the encoder's log-variance output to a standard deviation.
var_transformation = lambda x: torch.exp(x) ** 0.5

# Establish drug model (GMM-VAE)
drug_gmm_vae = GMMVAE(
    whole_model_config["drug_encoder_layers"],
    whole_model_config["drug_input_decoder_layers"],
    whole_model_config["drug_guiding_decoder_layers"],
    whole_model_config["no_gmm_components"],
    var_transformation=var_transformation,
    learning_rate=whole_model_config["learning_rate"],
    loss_function_weights=whole_model_config["vae_loss_function_weights"],
    batch_norm=False,
    optimizer="adam",
    encoder_dropout_rate=0,
    decoders_dropout_rate=0,
)

# Establish cell line model
cell_line_aen = modules.AutoencoderConfigurable(whole_model_config["cell_line_encoder_layers"], whole_model_config["cell_line_decoder_layers"])

# Establish sensitivity prediction network
sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

# Establish checkpoint: take the file with the largest integer after the last "="
# (assumes names like "...=<int>.ckpt", i.e. a 5-character extension).
sorted_checkpoints = sorted(os.listdir(os.path.join(model_dir, "checkpoints")), reverse=True, key=lambda x: int(x.split("=")[-1][:-5]))
checkpoint_path = os.path.join(model_dir, "checkpoints", sorted_checkpoints[0])

model = SensitivityModelGMMVAE.load_from_checkpoint(
    checkpoint_path,
    drug_model=drug_gmm_vae,
    cell_line_model=cell_line_aen,
    sensitivity_prediction_network=sensitivity_prediction_network,
    learning_rate=whole_model_config["learning_rate"],
    aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
    sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"],
)

MODEL_ALIAS = "gmm_vae_fixed_std_run_1"
# Sanity-check that the alias matches the loaded run. The run folder name
# ("run_<k>_split_seed_<s>") is read with os.path so both "/" and "\\"
# separators work, and the full run number is compared, not one character.
assert MODEL_ALIAS.rsplit("_", 1)[-1] == os.path.basename(os.path.dirname(model_dir)).split("_")[1]
assert MODEL_ALIAS.split("_")[2] in model_dir

model.eval()
drug_model = model.drug_model
drug_model.eval()

## Latent spaces

In [None]:
# Styling shared by the latent-space figures
figsize = (4, 4)
marker_size = 50
title_fontsize = 12
ax_label_fontsize = 12
alpha = 0.9

# One fixed colour per true cluster label: the first three tab10 colours
palette = {label: sns.color_palette("tab10")[label] for label in range(3)}

## Visualize latent spaces - Vanilla VAE

In [None]:
# Load Vanilla model. A forward-slash path is used because "\\" is only a
# separator on Windows, whereas "/" works on every platform.
vanilla_model_dir = "final_runs/Vanilla_VAE__IP__dr=0.5_0.5__gam=0.1"
vanilla_exp = "run_1_split_seed_11"
# Take the checkpoint with the largest integer after the last "="
# (assumes names like "...=<int>.ckpt").
sorted_checkpoints = sorted(os.listdir(os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints")), reverse=True, key=lambda x: int(x.split("=")[-1][:-5]))
checkpoint_path = os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints", sorted_checkpoints[0])

# Setup the model's hyperparams (must match the architecture the checkpoint was trained with)
DRUG_INPUT_DIM = 300
DRUG_GUIDING_DIM = 294
CELL_LINE_INPUT_DIM = 241

DRUG_LATENT_DIM = 10
CELL_LINE_LATENT_DIM = 10

DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
DRUG_INPUT_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_INPUT_DIM)
DRUG_GUIDING_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_GUIDING_DIM)

CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
CELL_LINE_DECODER_LAYERS = (CELL_LINE_LATENT_DIM, 64, 128, CELL_LINE_INPUT_DIM)

# Establish sensitivity prediction network config (input = both latent codes concatenated)
vanilla_sensitivity_prediction_network_config = {"layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
                                                 "learning_rate": 0.001,
                                                 "l2_term": 0,
                                                 "dropout_rate1": 0.5,
                                                 "dropout_rate2": 0.5}
# Establish drug model
drug_vanilla_vae = VanillaVAE(DRUG_ENCODER_LAYERS, DRUG_INPUT_DECODER_LAYERS, DRUG_GUIDING_DECODER_LAYERS,
                              var_transformation=lambda x: torch.exp(x) ** 0.5, learning_rate=0.0005,
                              loss_function_weights=(1., 1., 1., 1., 0.0), batch_norm=False, optimizer="adam",
                              encoder_dropout_rate=0, decoders_dropout_rate=0, clip_guiding_rec=True)

# Establish cell line model
vanilla_cell_line_aen = modules.AutoencoderConfigurable(CELL_LINE_ENCODER_LAYERS, CELL_LINE_DECODER_LAYERS)

# Forward network
vanilla_sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(vanilla_sensitivity_prediction_network_config)

vanilla_model = SensitivityModelVanillaVAE.load_from_checkpoint(checkpoint_path,
                                                                drug_model=drug_vanilla_vae,
                                                                cell_line_model=vanilla_cell_line_aen,
                                                                sensitivity_prediction_network=vanilla_sensitivity_prediction_network,
                                                                vae_training_num_epochs=100,
                                                                vae_training_step_rate=1000,
                                                                vae_dataloader=None,
                                                                learning_rate=0.0005)
vanilla_model.eval()

In [None]:
# Plot target: "PCA" for a 2-D PCA of the standardized latent means, otherwise
# a pair of latent-dimension indices to plot directly.
axes = "PCA"
# Output path for the figure; None disables saving
save_dir = None

figsize = (4, 4)

# Drugs that have both a guiding cluster label and a Mol2Vec embedding
drugs_to_plot = list(drugs_inhib_profiles["PubChem CID"])
data = drugs_inhib_profiles[["PubChem CID", "guiding_cluster_label"]].merge(
    drugs_mol2vec_reprs, on="PubChem CID", how="inner"
)
# After the merge the first two columns are the ID and the label; the rest are Mol2Vec features
X_to_plot = torch.Tensor(data.iloc[:, 2:].values)
true_classes = data["guiding_cluster_label"]

# Encode with the Vanilla VAE drug encoder; only the latent means are plotted
encoder = vanilla_model.drug_model.encoder
encoder.eval()
z_means, z_stds = encoder(X_to_plot)
z_means = z_means.detach().numpy()

plt.figure(figsize=figsize)

colors, cmap = (true_classes, "tab10") if true_classes is not None else (None, None)

# NOTE(review): xlabel/ylabel are computed but never applied to the axes.
if axes != "PCA":
    ax = sns.scatterplot(x=z_means[:, axes[0]], y=z_means[:, axes[1]], hue=colors, s=marker_size,
                         alpha=alpha, palette=cmap)
    xlabel, ylabel = f"Latent dim {axes[0]}", f"Latent dim {axes[1]}"
else:
    standard_scaler = StandardScaler()
    pca = PCA(n_components=2)
    X_scaled = standard_scaler.fit_transform(z_means)
    X_projected = pca.fit_transform(X_scaled)
    xlabel, ylabel = "PCA 1", "PCA 2"
    print("Explained variance: {:.2f} + {:.2f}".format(*pca.explained_variance_ratio_[:2]))
    ax = sns.scatterplot(x=X_projected[:, 0], y=X_projected[:, 1], hue=colors, s=marker_size,
                         alpha=alpha, palette=palette, legend=False)

plt.title("Vanilla VAE", fontsize=title_fontsize)
plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

## Visualize latent spaces - GMM

In [None]:
# Latent-space figure for the GMM-VAE drug encoder.
# Plot target: "PCA" for a 2-D PCA of the standardized latent means, otherwise
# a pair of latent-dimension indices to plot directly.
axes = "PCA"
# Output path for the figure; None disables saving
save_dir = None

# Drugs that have both a guiding cluster label and a Mol2Vec embedding
drugs_to_plot = list(drugs_inhib_profiles["PubChem CID"])
data = drugs_inhib_profiles[["PubChem CID", "guiding_cluster_label"]].merge(
    drugs_mol2vec_reprs, on="PubChem CID", how="inner"
)
# After the merge the first two columns are the ID and the label; the rest are Mol2Vec features
X_to_plot = torch.Tensor(data.iloc[:, 2:].values)
true_classes = data["guiding_cluster_label"]

# Encode with the GMM-VAE drug encoder; only the latent means are plotted
encoder = model.drug_model.encoder
encoder.eval()
z_means, z_stds = encoder(X_to_plot)
z_means = z_means.detach().numpy()

plt.figure(figsize=figsize)

colors, cmap = (true_classes, "tab10") if true_classes is not None else (None, None)

# NOTE(review): xlabel/ylabel are computed but never applied to the axes.
if axes != "PCA":
    ax = sns.scatterplot(x=z_means[:, axes[0]], y=z_means[:, axes[1]], hue=colors, s=marker_size,
                         alpha=alpha, palette=cmap)
    xlabel, ylabel = f"Latent dim {axes[0]}", f"Latent dim {axes[1]}"
else:
    standard_scaler = StandardScaler()
    pca = PCA(n_components=2)
    X_scaled = standard_scaler.fit_transform(z_means)
    X_projected = pca.fit_transform(X_scaled)
    xlabel, ylabel = "PCA 1", "PCA 2"
    print("Explained variance: {:.2f} + {:.2f}".format(*pca.explained_variance_ratio_[:2]))
    ax = sns.scatterplot(x=X_projected[:, 0], y=X_projected[:, 1], hue=colors, s=marker_size,
                         alpha=alpha, palette=palette, legend=False)

plt.title("GMM VAE", fontsize=title_fontsize)
plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

## Generative performance

In [None]:
# Styling for the (smaller) generated-data figures
figsize = (2.5, 2.5)
marker_size = 50
title_fontsize = 12
ax_label_fontsize = 12
alpha = 0.9

# One fixed colour per cluster label: the first three tab10 colours
palette = {label: sns.color_palette("tab10")[label] for label in range(3)}

### Visualize guiding true data

In [None]:
# PCA view of the true drug inhibition profiles, coloured by their cluster label.
drugs_to_plot = list(drugs_inhib_profiles["PubChem CID"])
data = drugs_inhib_profiles.copy()

# Output path for the figure; None disables saving
save_dir = None

# Feature matrix only: the ID column and the cluster-label column must not enter
# the scaler/PCA, and every drug (including the first row) is kept.
X_to_plot = data.drop(columns=["PubChem CID", "guiding_cluster_label"]).values
true_classes = data["guiding_cluster_label"]

pca = PCA(n_components=2)
standard_scaler = StandardScaler()
X_scaled = standard_scaler.fit_transform(X_to_plot)
X_projected = pca.fit_transform(X_scaled)
xlabel = "PCA 1"
ylabel = "PCA 2"
print(f"Explained variance: {pca.explained_variance_ratio_[0]:.{2}f} + {pca.explained_variance_ratio_[1]:.{2}f}")

# Plot
plt.figure(figsize=figsize)

# Colour by this table's own labels, row-aligned with X_projected.
ax = sns.scatterplot(x=X_projected[:, 0], y=X_projected[:, 1], hue=true_classes, s=marker_size, alpha=alpha, palette=palette, legend=False)

plt.title("True IP", fontsize=title_fontsize)

if save_dir is not None:
    plt.tight_layout()
    plt.savefig(save_dir, format="pdf")

plt.tight_layout()
plt.show()

### Generate samples and visualize - Vanilla

In [None]:
# Load Vanilla model. A forward-slash path is used because "\\" is only a
# separator on Windows, whereas "/" works on every platform.
vanilla_model_dir = "final_runs/Vanilla_VAE__IP__dr=0.5_0.5__gam=0.1"
vanilla_exp = "run_1_split_seed_11"
# Take the checkpoint with the largest integer after the last "="
# (assumes names like "...=<int>.ckpt").
sorted_checkpoints = sorted(os.listdir(os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints")), reverse=True, key=lambda x: int(x.split("=")[-1][:-5]))
checkpoint_path = os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints", sorted_checkpoints[0])

# Setup the model's hyperparams (must match the architecture the checkpoint was trained with)
DRUG_INPUT_DIM = 300
DRUG_GUIDING_DIM = 294
CELL_LINE_INPUT_DIM = 241

DRUG_LATENT_DIM = 10
CELL_LINE_LATENT_DIM = 10
CLIP_GUIDING_REC = True

DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
DRUG_INPUT_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_INPUT_DIM)
DRUG_GUIDING_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_GUIDING_DIM)

CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
CELL_LINE_DECODER_LAYERS = (CELL_LINE_LATENT_DIM, 64, 128, CELL_LINE_INPUT_DIM)

# Establish sensitivity prediction network config (input = both latent codes concatenated)
vanilla_sensitivity_prediction_network_config = {"layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
                                                 "learning_rate": 0.001,
                                                 "l2_term": 0,
                                                 "dropout_rate1": 0.5,
                                                 "dropout_rate2": 0.5}
# Establish drug model (uses the CLIP_GUIDING_REC constant declared above)
drug_vanilla_vae = VanillaVAE(DRUG_ENCODER_LAYERS, DRUG_INPUT_DECODER_LAYERS, DRUG_GUIDING_DECODER_LAYERS,
                              var_transformation=lambda x: torch.exp(x) ** 0.5, learning_rate=0.0005,
                              loss_function_weights=(1., 1., 1., 1., 0.0), batch_norm=False, optimizer="adam",
                              encoder_dropout_rate=0, decoders_dropout_rate=0, clip_guiding_rec=CLIP_GUIDING_REC)

# Establish cell line model
vanilla_cell_line_aen = modules.AutoencoderConfigurable(CELL_LINE_ENCODER_LAYERS, CELL_LINE_DECODER_LAYERS)

# Forward network
vanilla_sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(vanilla_sensitivity_prediction_network_config)

vanilla_model = SensitivityModelVanillaVAE.load_from_checkpoint(checkpoint_path,
                                                                drug_model=drug_vanilla_vae,
                                                                cell_line_model=vanilla_cell_line_aen,
                                                                sensitivity_prediction_network=vanilla_sensitivity_prediction_network,
                                                                vae_training_num_epochs=100,
                                                                vae_training_step_rate=1000,
                                                                vae_dataloader=None,
                                                                learning_rate=0.0005)
vanilla_model.eval()

def generate_from_vanilla(drug_model, n_samples=100, model_as_std=True, seed=None):
    """Decode random draws from a Vanilla VAE's standard-normal latent prior.

    Draws ``n_samples`` latent vectors from N(0, I) of dimension
    ``drug_model.latent_dim`` and passes them through both decoders.

    Args:
        drug_model: VAE exposing ``latent_dim``, ``input_decoder`` and
            ``guiding_decoder``; it is switched to eval mode.
        n_samples: Number of latent draws.
        model_as_std: If True the unit scale vector is treated as standard
            deviations (squared to form the covariance), otherwise as
            variances. With unit scales both give the identity covariance.
        seed: Optional integer seed for reproducible draws. Seeding happens in
            a forked RNG context, so the global torch RNG state is left
            untouched. ``None`` (default) draws from the current global RNG.

    Returns:
        Tuple ``(input_reconstructions, guiding_reconstructions)``: the decoder
        outputs for the sampled latents, presumably one row per sample.
    """
    drug_model.eval()
    # Standard-normal prior: zero mean, unit scale in every latent dimension
    means = torch.zeros(drug_model.latent_dim)
    stds = torch.ones(drug_model.latent_dim)

    covariance = torch.diag(stds ** 2) if model_as_std else torch.diag(stds)
    rv = td.MultivariateNormal(means, covariance)

    if seed is None:
        z_sample = rv.sample([n_samples])
    else:
        # Fork the CPU RNG so seeding does not disturb the notebook's global stream
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            z_sample = rv.sample([n_samples])

    input_reconstructions = drug_model.input_decoder(z_sample)
    guiding_reconstructions = drug_model.guiding_decoder(z_sample)

    return input_reconstructions, guiding_reconstructions

In [None]:
# Decode draws from the Vanilla VAE prior and keep them as numpy arrays
n_samples = 900

vanilla_drug_model = vanilla_model.drug_model
print(type(vanilla_drug_model))
vanilla_drug_model.eval()

input_reconstructions, guiding_reconstructions = (
    reconstruction.detach().numpy()
    for reconstruction in generate_from_vanilla(vanilla_drug_model, n_samples=n_samples)
)

In [None]:
# PCA view of the guiding data generated by the Vanilla VAE
# Output path for the figure; None disables saving
save_dir = None

# Visualize guiding with PCA
scaler = StandardScaler()
pca = PCA(n_components=2)
data_scaled = scaler.fit_transform(guiding_reconstructions)
data_transformed = pca.fit_transform(data_scaled)
print(pca.explained_variance_ratio_)

# Plot. x/y are passed as keywords because seaborn >= 0.12 no longer accepts
# them positionally. A single colour is used since the prior has no clusters.
plt.figure(figsize=figsize)
sns.scatterplot(x=data_transformed[:, 0], y=data_transformed[:, 1], color=sns.color_palette("tab10", 10)[7])

plt.title("Vanilla VAE", fontsize=title_fontsize)

if save_dir is not None:
    plt.tight_layout()
    plt.savefig(save_dir)

plt.show()

### Generate samples - GMM

In [None]:
# Draw n_samples from each GMM component and decode them into input and guiding space
cluster_numbers = [0, 1, 2]
input_samples_all_clusters = {}
guiding_samples_all_clusters = {}

n_samples = 300
# isinstance rather than type(...) == ...: also accepts subclasses of the model
assert isinstance(model, SensitivityModelGMMVAE)
drug_model = model.drug_model
drug_model.eval()

GENERATED_GUIDING_SEED = 22
for cluster_number in cluster_numbers:
    # NOTE(review): the same seed is passed for every cluster -- presumably
    # intentional (shared noise across components); confirm.
    input_reconstructions, guiding_reconstructions = utils.generate_from_cluster(drug_model, cluster_number, n_samples=n_samples,
                                                                                seed=GENERATED_GUIDING_SEED, model_as_std=False)
    # Convert to numpy
    input_reconstructions, guiding_reconstructions = input_reconstructions.detach().numpy(), guiding_reconstructions.detach().numpy()

    input_samples_all_clusters[cluster_number] = input_reconstructions
    guiding_samples_all_clusters[cluster_number] = guiding_reconstructions

### Visualize generated samples

In [None]:
# Output path for the figure; None disables saving
save_dir = None

# Stack per-cluster samples row-wise (dict insertion order) and build one label
# per row. Concatenating along axis 0 needs no hard-coded feature widths and also
# copes with clusters that have different sample counts.
all_generated_data_inputs = np.concatenate(list(input_samples_all_clusters.values()), axis=0)
input_cluster_labels = [
    cluster_label
    for cluster_label, samples in input_samples_all_clusters.items()
    for _ in range(samples.shape[0])
]

all_generated_data_guiding = np.concatenate(list(guiding_samples_all_clusters.values()), axis=0)
guiding_cluster_labels = [
    cluster_label
    for cluster_label, samples in guiding_samples_all_clusters.items()
    for _ in range(samples.shape[0])
]
print(all_generated_data_inputs.shape, all_generated_data_guiding.shape)

# Visualize guiding with PCA
scaler = StandardScaler()
pca = PCA(n_components=2)
data_scaled = scaler.fit_transform(all_generated_data_guiding)
data_transformed = pca.fit_transform(data_scaled)
print(pca.explained_variance_ratio_)

# Plot (x/y as keywords: seaborn >= 0.12 rejects positional x/y)
plt.figure(figsize=figsize)
sns.scatterplot(x=data_transformed[:, 0], y=data_transformed[:, 1], hue=guiding_cluster_labels, palette=palette, legend=False)

plt.title("GMM VAE constrained", fontsize=title_fontsize)

if save_dir is not None:
    plt.tight_layout()
    plt.savefig(save_dir)

plt.show()

### Cluster-feature tables

#### Generate data to plot

In [None]:
# Per-cluster feature means/stds for the true and generated guiding data, plus a
# feature ordering used by the heatmaps.

# True guiding data (training split).
# NOTE(review): assumes the last column of `dataset_train.drugs_inhib_profiles`
# is the cluster label and the other 294 columns are the features.
true_inhib_profiles = pd.DataFrame(data=dataset_train.drugs_inhib_profiles)
true_cluster_mean_features_table_guiding = np.zeros(shape=(len(cluster_numbers), 294))
true_cluster_std_features_table_guiding = np.zeros(shape=(len(cluster_numbers), 294))

label_column = true_inhib_profiles.columns[-1]
for cluster_number in true_inhib_profiles[label_column].unique():
    # Feature rows of this cluster with the label column dropped (mask built once)
    cluster_features = true_inhib_profiles[true_inhib_profiles[label_column] == cluster_number].values[:, :-1]
    true_cluster_mean_features_table_guiding[int(cluster_number)] = cluster_features.mean(axis=0)
    true_cluster_std_features_table_guiding[int(cluster_number)] = cluster_features.std(axis=0)

# Generated guiding data
cluster_mean_features_table_guiding = np.zeros(shape=(len(cluster_numbers), 294))
cluster_std_features_table_guiding = np.zeros(shape=(len(cluster_numbers), 294))

for cluster_number in guiding_samples_all_clusters:
    cluster_mean_features_table_guiding[cluster_number] = guiding_samples_all_clusters[cluster_number].mean(axis=0)
    cluster_std_features_table_guiding[cluster_number] = guiding_samples_all_clusters[cluster_number].std(axis=0)

# Generated input data: must use the INPUT samples, whose width is the
# 300-dimensional Mol2Vec input (not the 294-dimensional guiding width).
cluster_mean_features_table_input = np.zeros(shape=(len(cluster_numbers), 300))
cluster_std_features_table_input = np.zeros(shape=(len(cluster_numbers), 300))

for cluster_number in input_samples_all_clusters:
    cluster_mean_features_table_input[cluster_number] = input_samples_all_clusters[cluster_number].mean(axis=0)
    cluster_std_features_table_input[cluster_number] = input_samples_all_clusters[cluster_number].std(axis=0)

true_cluster_mean_features_table_guiding_df = pd.DataFrame(data=true_cluster_mean_features_table_guiding, columns=drugs_inhib_profiles.columns[1:-1])
cluster_mean_features_table_guiding_df = pd.DataFrame(data=cluster_mean_features_table_guiding, columns=drugs_inhib_profiles.columns[1:-1])

# Order features by their cluster-1 mean (descending), ties broken by cluster 0
# and then cluster 2. This is what sequential sorts by 2, 0, 1 intend, but those
# only give this order with a stable sort, and pandas' default quicksort is not
# stable, so a single multi-key sort is used.
true_cluster_mean_features_table_guiding_df_T = true_cluster_mean_features_table_guiding_df.transpose().sort_values(by=[1, 0, 2], ascending=False)
ordered_feats = true_cluster_mean_features_table_guiding_df_T.index

#### Plot cluster means

In [None]:
# Shared colour range for both mean heatmaps, anchored to the true-data means
vmin, vmax = (
    true_cluster_mean_features_table_guiding_df.min().min(),
    true_cluster_mean_features_table_guiding_df.max().max(),
)

In [None]:
# Cosmetic params for the cluster-mean heatmaps (each assigned once)
title_fontsize = 10
ax_label_fontsize = 8
alpha = 0.99

figsize = (3.5, 1.)
palette = "viridis"

##### Plot for true data

In [None]:
# Heatmap of per-cluster feature means of the true inhibition profiles,
# with features in `ordered_feats` order
save_dir = None  # output path; None disables saving

plt.figure(figsize=figsize)
ax1 = sns.heatmap(
    data=true_cluster_mean_features_table_guiding_df[ordered_feats],
    xticklabels=False,
    vmin=vmin,
    vmax=vmax,
    cmap=palette,
    cbar=False,
)
ax1.set_title("True IP cluster means", fontsize=title_fontsize)
ax1.set_ylabel("Cluster", fontsize=ax_label_fontsize)

plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

##### Plot for model

In [None]:
# Heatmap of per-cluster feature means of the generated guiding data, drawn with
# the current vmin/vmax and `ordered_feats` so it is comparable to the true data
save_dir = None  # output path; None disables saving

plt.figure(figsize=figsize)
ax2 = sns.heatmap(
    data=cluster_mean_features_table_guiding_df[ordered_feats],
    xticklabels=False,
    vmin=vmin,
    vmax=vmax,
    cmap=palette,
    cbar=False,
)
ax2.set_title("Generated guiding data feature means", fontsize=title_fontsize)
ax2.set_ylabel("Cluster", fontsize=ax_label_fontsize)

plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

#### Plot cluster stds

In [None]:
# Shared colour range for both std heatmaps, anchored to the true-data stds
# (the table is a numpy array, so one reduction yields the scalar)
vmin, vmax = true_cluster_std_features_table_guiding.min(), true_cluster_std_features_table_guiding.max()

In [None]:
# Cosmetic params for the cluster-std heatmaps (each assigned once)
title_fontsize = 10
ax_label_fontsize = 8
alpha = 0.99

figsize = (3.5, 1.)
palette = "rocket"

##### Plot for true data

In [None]:
# Heatmap of per-cluster feature STANDARD DEVIATIONS of the true inhibition
# profiles, with features in `ordered_feats` order
save_dir = None  # output path; None disables saving

plt.figure(figsize=figsize)
ax1 = sns.heatmap(
    data=pd.DataFrame(data=true_cluster_std_features_table_guiding,
                      columns=drugs_inhib_profiles.columns[1:-1])[ordered_feats],
    xticklabels=False,
    vmin=vmin,
    vmax=vmax,
    cmap=palette,
    cbar=False,
)
ax1.set_title("True IP cluster stds", fontsize=title_fontsize)
ax1.set_ylabel("Cluster", fontsize=ax_label_fontsize)

plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

##### Plot for model

In [None]:
# Heatmap of per-cluster feature standard deviations of the generated guiding
# data, drawn with the current vmin/vmax and `ordered_feats`
save_dir = None  # output path; None disables saving

plt.figure(figsize=figsize)
ax2 = sns.heatmap(
    data=pd.DataFrame(data=cluster_std_features_table_guiding,
                      columns=drugs_inhib_profiles.columns[1:-1])[ordered_feats],
    xticklabels=False,
    vmin=vmin,
    vmax=vmax,
    cmap=palette,
    cbar=False,
)
ax2.set_title("Generated guiding data feature stds", fontsize=title_fontsize)
ax2.set_ylabel("Cluster", fontsize=ax_label_fontsize)

plt.tight_layout()
if save_dir is not None:
    plt.savefig(save_dir)
plt.show()

## Numerical computations of generative performance

In [None]:
# Result directory of each model family; each holds one sub-folder per run
model_filepaths = {
    "vanilla_vae": "final_runs/Vanilla_VAE__IP__dr=0.5_0.5__gam=0.1",
    "gmm_vae_constrained": "final_runs/GMM_VAE__IP__no_comps=3__fixed_comp_std=1.0",
    "gmm_vae_unconstrained": "final_runs/GMM_VAE__IP__no_comps=3__trained_comp_std",
}

# Repeated runs, each with its own data-split seed
run_names = [
    "run_1_split_seed_11",
    "run_2_split_seed_13",
    "run_3_split_seed_26",
    "run_4_split_seed_76",
    "run_5_split_seed_92",
]

# metric name -> {model name -> per-run values}; every metric gets its own empty dict
generative_results_dict = {
    metric: {}
    for metric in (
        "latent_clustering",
        "generated_ip_clustering",
        "feature_means_rmse",
        "feature_stds_rmse",
        "feature_means_corr",
        "feature_stds_corr",
    )
}

### Latent space clustering

#### Vanilla

In [None]:
# Silhouette score of the true guiding clusters in the Vanilla VAE latent space, per run
vanilla_latent_silhs = []

# --- Loop-invariant setup (identical for every run) ---
# Use the portable forward-slash path from `model_filepaths` ("\\" is a
# separator only on Windows).
vanilla_model_dir = model_filepaths["vanilla_vae"]

# Setup the model's hyperparams
DRUG_INPUT_DIM = 300
DRUG_GUIDING_DIM = 294
CELL_LINE_INPUT_DIM = 241

DRUG_LATENT_DIM = 10
CELL_LINE_LATENT_DIM = 10
CLIP_GUIDING_REC = True

DRUG_ENCODER_LAYERS = (DRUG_INPUT_DIM, 128, 64, DRUG_LATENT_DIM)
DRUG_INPUT_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_INPUT_DIM)
DRUG_GUIDING_DECODER_LAYERS = (DRUG_LATENT_DIM, 64, 128, DRUG_GUIDING_DIM)

CELL_LINE_ENCODER_LAYERS = (CELL_LINE_INPUT_DIM, 128, 64, CELL_LINE_LATENT_DIM)
CELL_LINE_DECODER_LAYERS = (CELL_LINE_LATENT_DIM, 64, 128, CELL_LINE_INPUT_DIM)

# Drugs with both a guiding cluster label and a Mol2Vec embedding; after the
# merge, columns 2: are the Mol2Vec features.
drugs_to_plot = list(drugs_inhib_profiles["PubChem CID"])
data = drugs_inhib_profiles[["PubChem CID", "guiding_cluster_label"]].merge(drugs_mol2vec_reprs, on="PubChem CID", how="inner")
X_to_plot = torch.Tensor(data.iloc[:, 2:].values)
true_classes = data["guiding_cluster_label"]

for vanilla_exp in run_names:
    # Latest checkpoint of this run (largest integer after the last "=")
    sorted_checkpoints = sorted(os.listdir(os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints")), reverse=True, key=lambda x: int(x.split("=")[-1][:-5]))
    checkpoint_path = os.path.join(vanilla_model_dir, vanilla_exp, "version_0", "checkpoints", sorted_checkpoints[0])

    # Fresh modules every run, so each checkpoint is loaded into new weights
    vanilla_sensitivity_prediction_network_config = {"layers": (DRUG_LATENT_DIM + CELL_LINE_LATENT_DIM, 512, 256, 128, 1),
                                                     "learning_rate": 0.001,
                                                     "l2_term": 0,
                                                     "dropout_rate1": 0.5,
                                                     "dropout_rate2": 0.5}
    drug_vanilla_vae = VanillaVAE(DRUG_ENCODER_LAYERS, DRUG_INPUT_DECODER_LAYERS, DRUG_GUIDING_DECODER_LAYERS,
                                  var_transformation=lambda x: torch.exp(x) ** 0.5, learning_rate=0.0005,
                                  loss_function_weights=(1., 1., 1., 1., 0.0), batch_norm=False, optimizer="adam",
                                  encoder_dropout_rate=0, decoders_dropout_rate=0, clip_guiding_rec=CLIP_GUIDING_REC)
    vanilla_cell_line_aen = modules.AutoencoderConfigurable(CELL_LINE_ENCODER_LAYERS, CELL_LINE_DECODER_LAYERS)
    vanilla_sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurableDropout(vanilla_sensitivity_prediction_network_config)

    vanilla_model = SensitivityModelVanillaVAE.load_from_checkpoint(checkpoint_path,
                                                                    drug_model=drug_vanilla_vae,
                                                                    cell_line_model=vanilla_cell_line_aen,
                                                                    sensitivity_prediction_network=vanilla_sensitivity_prediction_network,
                                                                    vae_training_num_epochs=100,
                                                                    vae_training_step_rate=1000,
                                                                    vae_dataloader=None,
                                                                    learning_rate=0.0005)
    vanilla_model.eval()

    # Encode (inference only, so no autograd graph is built)
    encoder = vanilla_model.drug_model.encoder
    encoder.eval()
    with torch.no_grad():
        z_means, z_stds = encoder(X_to_plot)
    z_means = z_means.numpy()

    silh_score = metrics.silhouette_score(z_means, true_classes)
    vanilla_latent_silhs.append(silh_score)

generative_results_dict["latent_clustering"]["vanilla_vae"] = vanilla_latent_silhs

#### GMM constrained

In [None]:
# Silhouette score of the true guiding clusters in the constrained GMM-VAE latent space, per run
gmm_constrained_latent_silhs = []

# --- Loop-invariant setup (identical for every run) ---
runs_dir = model_filepaths["gmm_vae_constrained"]
# Presumably maps the encoder's log-variance output to a standard deviation
var_transformation = lambda x: torch.exp(x) ** 0.5

# Drugs with both a guiding cluster label and a Mol2Vec embedding; after the
# merge, columns 2: are the Mol2Vec features.
drugs_to_plot = list(drugs_inhib_profiles["PubChem CID"])
data = drugs_inhib_profiles[["PubChem CID", "guiding_cluster_label"]].merge(drugs_mol2vec_reprs, on="PubChem CID", how="inner")
X_to_plot = torch.Tensor(data.iloc[:, 2:].values)
true_classes = data["guiding_cluster_label"]

# NOTE: this loop rebinds notebook-level names (`exp`, `model_dir`,
# `whole_model_config`, `model`, `drug_model`, ...). Afterwards they refer to the
# last run in `run_names`, not to the model loaded at the top of the notebook.
for exp in run_names:
    model_dir = os.path.join(runs_dir, exp, "version_0")

    with open(os.path.join(model_dir, "whole_model_config.json"), "r") as f:
        whole_model_config = json.load(f)

    with open(os.path.join(model_dir, "sensitivity_prediction_network_config.json"), "r") as f:
        sensitivity_prediction_network_config = json.load(f)

    # Establish drug model
    drug_gmm_vae = GMMVAE(whole_model_config["drug_encoder_layers"],
                          whole_model_config["drug_input_decoder_layers"],
                          whole_model_config["drug_guiding_decoder_layers"],
                          whole_model_config["no_gmm_components"],
                          var_transformation=var_transformation,
                          learning_rate=whole_model_config["learning_rate"],
                          loss_function_weights=whole_model_config["vae_loss_function_weights"],
                          batch_norm=False, optimizer="adam",
                          encoder_dropout_rate=0, decoders_dropout_rate=0)

    # Establish cell line model
    cell_line_aen = modules.AutoencoderConfigurable(whole_model_config["cell_line_encoder_layers"], whole_model_config["cell_line_decoder_layers"])

    # Establish sensitivity prediction network
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

    # Establish checkpoint: largest integer after the last "=" (assumes "...=<int>.ckpt")
    sorted_checkpoints = sorted(os.listdir(os.path.join(model_dir, "checkpoints")), reverse=True, key=lambda x: int(x.split("=")[-1][:-5]))
    checkpoint_path = os.path.join(model_dir, "checkpoints", sorted_checkpoints[0])

    model = SensitivityModelGMMVAE.load_from_checkpoint(checkpoint_path,
                                                        drug_model=drug_gmm_vae,
                                                        cell_line_model=cell_line_aen,
                                                        sensitivity_prediction_network=sensitivity_prediction_network,
                                                        learning_rate=whole_model_config["learning_rate"],
                                                        aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
                                                        sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"])

    model.eval()
    drug_model = model.drug_model
    drug_model.eval()

    # Encode (inference only, so no autograd graph is built)
    encoder = drug_model.encoder
    encoder.eval()
    with torch.no_grad():
        z_means, z_stds = encoder(X_to_plot)
    z_means = z_means.numpy()

    silh_score = metrics.silhouette_score(z_means, true_classes)
    gmm_constrained_latent_silhs.append(silh_score)

generative_results_dict["latent_clustering"]["gmm_vae_constrained"] = gmm_constrained_latent_silhs

#### GMM unconstrained

In [None]:
# Silhouette score of the unconstrained GMM-VAE drug latent means with respect to
# the true guiding-cluster labels, one score per run in `run_names`.
gmm_unconstrained_latent_silhs = []

runs_dir = model_filepaths["gmm_vae_unconstrained"]

# Maps the encoder's second output to a standard deviation via sqrt(exp(x));
# presumably that output is a log-variance - confirm against GMMVAE.
var_transformation = lambda x: torch.exp(x) ** 0.5

# Drug inputs (Mol2Vec representations) and their true guiding-cluster labels do
# not depend on the run, so build them once instead of once per run.
data = drugs_inhib_profiles[["PubChem CID", "guiding_cluster_label"]].merge(drugs_mol2vec_reprs, on="PubChem CID", how="inner")
X_to_plot = torch.Tensor(data.iloc[:, 2:].values)
true_classes = data["guiding_cluster_label"]

for exp in run_names:
    model_dir = os.path.join(runs_dir, exp, "version_0")

    with open(os.path.join(model_dir, "whole_model_config.json"), "r") as f:
        whole_model_config = json.load(f)

    with open(os.path.join(model_dir, "sensitivity_prediction_network_config.json"), "r") as f:
        sensitivity_prediction_network_config = json.load(f)

    # Build the sub-model architectures from the run's saved config; the trained
    # weights are expected to come from the checkpoint loaded below.
    drug_gmm_vae = GMMVAE(
        whole_model_config["drug_encoder_layers"],
        whole_model_config["drug_input_decoder_layers"],
        whole_model_config["drug_guiding_decoder_layers"],
        whole_model_config["no_gmm_components"],
        var_transformation=var_transformation,
        learning_rate=whole_model_config["learning_rate"],
        loss_function_weights=whole_model_config["vae_loss_function_weights"],
        batch_norm=False,
        optimizer="adam",
        encoder_dropout_rate=0,
        decoders_dropout_rate=0,
    )
    cell_line_aen = modules.AutoencoderConfigurable(
        whole_model_config["cell_line_encoder_layers"],
        whole_model_config["cell_line_decoder_layers"],
    )
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

    # Pick the checkpoint with the largest trailing number in its file name.
    # Names are expected to end in "=<number>.ckpt" (e.g. "epoch=E-step=S.ckpt");
    # non-checkpoint files in the directory are ignored instead of crashing int().
    checkpoint_dir = os.path.join(model_dir, "checkpoints")
    checkpoint_files = [name for name in os.listdir(checkpoint_dir) if name.endswith(".ckpt")]
    latest_checkpoint = max(checkpoint_files, key=lambda name: int(name.split("=")[-1][:-len(".ckpt")]))
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    model = SensitivityModelGMMVAE.load_from_checkpoint(checkpoint_path,
                                                        drug_model=drug_gmm_vae,
                                                        cell_line_model=cell_line_aen,
                                                        sensitivity_prediction_network=sensitivity_prediction_network,
                                                        learning_rate=whole_model_config["learning_rate"],
                                                        aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
                                                        sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"])

    model.eval()
    drug_model = model.drug_model
    drug_model.eval()

    encoder = drug_model.encoder
    encoder.eval()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        z_means, z_stds = encoder(X_to_plot)
    z_means = z_means.numpy()

    silh_score = metrics.silhouette_score(z_means, true_classes)
    gmm_unconstrained_latent_silhs.append(silh_score)

generative_results_dict["latent_clustering"]["gmm_vae_unconstrained"] = gmm_unconstrained_latent_silhs

### Clustering and feature maps of generated inhibition profiles (IPs)

In [None]:
# Per-cluster mean and std of every guiding (inhibition-profile) feature in the
# training set. Row k of each table holds the statistics for cluster label k;
# generated samples are compared against these tables in the cells below.
cluster_numbers = [0, 1, 2]

true_inhib_profiles = pd.DataFrame(data=dataset_train.drugs_inhib_profiles)
# Presumably the last column is the guiding-cluster label and all preceding
# columns are features (matches the columns[1:-1] header used below).
true_labels = true_inhib_profiles.iloc[:, -1].to_numpy()
true_features = true_inhib_profiles.iloc[:, :-1].to_numpy()
# Derive the feature width from the data instead of hard-coding it.
n_guiding_features = true_features.shape[1]

true_cluster_mean_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))
true_cluster_std_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))

for cluster_number in pd.unique(true_labels):
    # Compute the cluster mask once and reuse it for both statistics.
    in_cluster = true_labels == cluster_number
    true_cluster_mean_features_table_guiding[int(cluster_number)] = true_features[in_cluster].mean(axis=0)
    # numpy std (ddof=0), matching how the generated samples' std is computed.
    true_cluster_std_features_table_guiding[int(cluster_number)] = true_features[in_cluster].std(axis=0)

true_cluster_mean_features_table_guiding_df = pd.DataFrame(data=true_cluster_mean_features_table_guiding, columns=drugs_inhib_profiles.columns[1:-1])

#### GMM constrained

In [None]:
# For every run of the constrained GMM-VAE: sample guiding data from each GMM
# component, score how well the samples cluster by component (silhouette), and
# compare per-cluster feature means/stds with the true training-set tables.
gmm_constrained_generated_ip_silhs = []
gmm_constrained_feature_means_rmses = []
gmm_constrained_feature_means_corrs = []
gmm_constrained_feature_stds_rmses = []
gmm_constrained_feature_stds_corrs = []

runs_dir = model_filepaths["gmm_vae_constrained"]

# Maps the encoder's second output to a standard deviation via sqrt(exp(x));
# presumably that output is a log-variance - confirm against GMMVAE.
var_transformation = lambda x: torch.exp(x) ** 0.5

# Sampling settings (run-invariant). GMM component k is compared with true
# guiding cluster k, i.e. row k of the true_* feature tables.
cluster_numbers = [0, 1, 2]
n_samples = 300
GENERATED_GUIDING_SEED = 22

for exp in run_names:
    model_dir = os.path.join(runs_dir, exp, "version_0")

    with open(os.path.join(model_dir, "whole_model_config.json"), "r") as f:
        whole_model_config = json.load(f)

    with open(os.path.join(model_dir, "sensitivity_prediction_network_config.json"), "r") as f:
        sensitivity_prediction_network_config = json.load(f)

    # Build the sub-model architectures from the run's saved config; the trained
    # weights are expected to come from the checkpoint loaded below.
    drug_gmm_vae = GMMVAE(
        whole_model_config["drug_encoder_layers"],
        whole_model_config["drug_input_decoder_layers"],
        whole_model_config["drug_guiding_decoder_layers"],
        whole_model_config["no_gmm_components"],
        var_transformation=var_transformation,
        learning_rate=whole_model_config["learning_rate"],
        loss_function_weights=whole_model_config["vae_loss_function_weights"],
        batch_norm=False,
        optimizer="adam",
        encoder_dropout_rate=0,
        decoders_dropout_rate=0,
    )
    cell_line_aen = modules.AutoencoderConfigurable(
        whole_model_config["cell_line_encoder_layers"],
        whole_model_config["cell_line_decoder_layers"],
    )
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

    # Pick the checkpoint with the largest trailing number in its file name.
    # Names are expected to end in "=<number>.ckpt" (e.g. "epoch=E-step=S.ckpt");
    # non-checkpoint files in the directory are ignored instead of crashing int().
    checkpoint_dir = os.path.join(model_dir, "checkpoints")
    checkpoint_files = [name for name in os.listdir(checkpoint_dir) if name.endswith(".ckpt")]
    latest_checkpoint = max(checkpoint_files, key=lambda name: int(name.split("=")[-1][:-len(".ckpt")]))
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    model = SensitivityModelGMMVAE.load_from_checkpoint(checkpoint_path,
                                                        drug_model=drug_gmm_vae,
                                                        cell_line_model=cell_line_aen,
                                                        sensitivity_prediction_network=sensitivity_prediction_network,
                                                        learning_rate=whole_model_config["learning_rate"],
                                                        aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
                                                        sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"])
    assert isinstance(model, SensitivityModelGMMVAE)

    model.eval()
    drug_model = model.drug_model
    drug_model.eval()

    # ---- Sample input and guiding reconstructions from every GMM component ----
    input_samples_all_clusters = {}
    guiding_samples_all_clusters = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for cluster_number in cluster_numbers:
            input_reconstructions, guiding_reconstructions = utils.generate_from_cluster(
                drug_model, cluster_number, n_samples=n_samples, seed=GENERATED_GUIDING_SEED
            )
            input_reconstructions = input_reconstructions.detach().numpy()
            guiding_reconstructions = guiding_reconstructions.detach().numpy()
            input_samples_all_clusters[cluster_number] = input_reconstructions
            guiding_samples_all_clusters[cluster_number] = guiding_reconstructions

    # ---- Silhouette of generated guiding samples, labelled by source component ----
    all_generated_data_guiding = np.concatenate(
        [guiding_samples_all_clusters[c] for c in cluster_numbers], axis=0
    )
    guiding_cluster_labels = [
        c for c in cluster_numbers for _ in range(guiding_samples_all_clusters[c].shape[0])
    ]
    silh_score = metrics.silhouette_score(all_generated_data_guiding, guiding_cluster_labels)
    gmm_constrained_generated_ip_silhs.append(silh_score)

    # ---- Per-component feature mean/std tables of the generated guiding data ----
    n_guiding_features = all_generated_data_guiding.shape[1]
    cluster_mean_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))
    cluster_std_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))
    for cluster_number, samples in guiding_samples_all_clusters.items():
        cluster_mean_features_table_guiding[cluster_number] = samples.mean(axis=0)
        cluster_std_features_table_guiding[cluster_number] = samples.std(axis=0)

    # ---- Compare generated vs. true per-cluster feature profiles ----
    pearson_rs_mean, rmses_mean = [], []
    pearson_rs_std, rmses_std = [], []
    for cluster_number in range(true_cluster_mean_features_table_guiding.shape[0]):
        true_mean_vector = true_cluster_mean_features_table_guiding[cluster_number]
        true_std_vector = true_cluster_std_features_table_guiding[cluster_number]
        generated_mean_vector = cluster_mean_features_table_guiding[cluster_number]
        generated_std_vector = cluster_std_features_table_guiding[cluster_number]

        pearson_rs_mean.append(pearsonr(true_mean_vector, generated_mean_vector)[0])
        rmses_mean.append(np.sqrt(np.mean((true_mean_vector - generated_mean_vector) ** 2)))

        pearson_rs_std.append(pearsonr(true_std_vector, generated_std_vector)[0])
        rmses_std.append(np.sqrt(np.mean((true_std_vector - generated_std_vector) ** 2)))

    # One value per run: average over clusters.
    gmm_constrained_feature_means_rmses.append(np.mean(rmses_mean))
    gmm_constrained_feature_means_corrs.append(np.mean(pearson_rs_mean))

    gmm_constrained_feature_stds_rmses.append(np.mean(rmses_std))
    gmm_constrained_feature_stds_corrs.append(np.mean(pearson_rs_std))

generative_results_dict["generated_ip_clustering"]["gmm_vae_constrained"] = gmm_constrained_generated_ip_silhs

generative_results_dict["feature_means_rmse"]["gmm_vae_constrained"] = gmm_constrained_feature_means_rmses
generative_results_dict["feature_means_corr"]["gmm_vae_constrained"] = gmm_constrained_feature_means_corrs

generative_results_dict["feature_stds_rmse"]["gmm_vae_constrained"] = gmm_constrained_feature_stds_rmses
generative_results_dict["feature_stds_corr"]["gmm_vae_constrained"] = gmm_constrained_feature_stds_corrs

#### GMM unconstrained

In [None]:
# For every run of the unconstrained GMM-VAE: sample guiding data from each GMM
# component, score how well the samples cluster by component (silhouette), and
# compare per-cluster feature means/stds with the true training-set tables.
gmm_unconstrained_generated_ip_silhs = []
gmm_unconstrained_feature_means_rmses = []
gmm_unconstrained_feature_means_corrs = []
gmm_unconstrained_feature_stds_rmses = []
gmm_unconstrained_feature_stds_corrs = []

runs_dir = model_filepaths["gmm_vae_unconstrained"]

# Maps the encoder's second output to a standard deviation via sqrt(exp(x));
# presumably that output is a log-variance - confirm against GMMVAE.
var_transformation = lambda x: torch.exp(x) ** 0.5

# Sampling settings (run-invariant). GMM component k is compared with true
# guiding cluster k, i.e. row k of the true_* feature tables.
cluster_numbers = [0, 1, 2]
n_samples = 300
GENERATED_GUIDING_SEED = 22

for exp in run_names:
    model_dir = os.path.join(runs_dir, exp, "version_0")

    with open(os.path.join(model_dir, "whole_model_config.json"), "r") as f:
        whole_model_config = json.load(f)

    with open(os.path.join(model_dir, "sensitivity_prediction_network_config.json"), "r") as f:
        sensitivity_prediction_network_config = json.load(f)

    # Build the sub-model architectures from the run's saved config; the trained
    # weights are expected to come from the checkpoint loaded below.
    drug_gmm_vae = GMMVAE(
        whole_model_config["drug_encoder_layers"],
        whole_model_config["drug_input_decoder_layers"],
        whole_model_config["drug_guiding_decoder_layers"],
        whole_model_config["no_gmm_components"],
        var_transformation=var_transformation,
        learning_rate=whole_model_config["learning_rate"],
        loss_function_weights=whole_model_config["vae_loss_function_weights"],
        batch_norm=False,
        optimizer="adam",
        encoder_dropout_rate=0,
        decoders_dropout_rate=0,
    )
    cell_line_aen = modules.AutoencoderConfigurable(
        whole_model_config["cell_line_encoder_layers"],
        whole_model_config["cell_line_decoder_layers"],
    )
    sensitivity_prediction_network = modules.FeedForwardThreeLayersConfigurable(sensitivity_prediction_network_config)

    # Pick the checkpoint with the largest trailing number in its file name.
    # Names are expected to end in "=<number>.ckpt" (e.g. "epoch=E-step=S.ckpt");
    # non-checkpoint files in the directory are ignored instead of crashing int().
    checkpoint_dir = os.path.join(model_dir, "checkpoints")
    checkpoint_files = [name for name in os.listdir(checkpoint_dir) if name.endswith(".ckpt")]
    latest_checkpoint = max(checkpoint_files, key=lambda name: int(name.split("=")[-1][:-len(".ckpt")]))
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    model = SensitivityModelGMMVAE.load_from_checkpoint(checkpoint_path,
                                                        drug_model=drug_gmm_vae,
                                                        cell_line_model=cell_line_aen,
                                                        sensitivity_prediction_network=sensitivity_prediction_network,
                                                        learning_rate=whole_model_config["learning_rate"],
                                                        aen_reconstruction_loss_weight=whole_model_config["aen_reconstruction_weight"],
                                                        sensitivity_prediction_weight=whole_model_config["sensitivity_prediction_weight"])
    assert isinstance(model, SensitivityModelGMMVAE)

    model.eval()
    drug_model = model.drug_model
    drug_model.eval()

    # ---- Sample input and guiding reconstructions from every GMM component ----
    input_samples_all_clusters = {}
    guiding_samples_all_clusters = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for cluster_number in cluster_numbers:
            input_reconstructions, guiding_reconstructions = utils.generate_from_cluster(
                drug_model, cluster_number, n_samples=n_samples, seed=GENERATED_GUIDING_SEED
            )
            input_reconstructions = input_reconstructions.detach().numpy()
            guiding_reconstructions = guiding_reconstructions.detach().numpy()
            input_samples_all_clusters[cluster_number] = input_reconstructions
            guiding_samples_all_clusters[cluster_number] = guiding_reconstructions

    # ---- Silhouette of generated guiding samples, labelled by source component ----
    all_generated_data_guiding = np.concatenate(
        [guiding_samples_all_clusters[c] for c in cluster_numbers], axis=0
    )
    guiding_cluster_labels = [
        c for c in cluster_numbers for _ in range(guiding_samples_all_clusters[c].shape[0])
    ]
    silh_score = metrics.silhouette_score(all_generated_data_guiding, guiding_cluster_labels)
    gmm_unconstrained_generated_ip_silhs.append(silh_score)

    # ---- Per-component feature mean/std tables of the generated guiding data ----
    n_guiding_features = all_generated_data_guiding.shape[1]
    cluster_mean_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))
    cluster_std_features_table_guiding = np.zeros(shape=(len(cluster_numbers), n_guiding_features))
    for cluster_number, samples in guiding_samples_all_clusters.items():
        cluster_mean_features_table_guiding[cluster_number] = samples.mean(axis=0)
        cluster_std_features_table_guiding[cluster_number] = samples.std(axis=0)

    # ---- Compare generated vs. true per-cluster feature profiles ----
    pearson_rs_mean, rmses_mean = [], []
    pearson_rs_std, rmses_std = [], []
    for cluster_number in range(true_cluster_mean_features_table_guiding.shape[0]):
        true_mean_vector = true_cluster_mean_features_table_guiding[cluster_number]
        true_std_vector = true_cluster_std_features_table_guiding[cluster_number]
        generated_mean_vector = cluster_mean_features_table_guiding[cluster_number]
        generated_std_vector = cluster_std_features_table_guiding[cluster_number]

        pearson_rs_mean.append(pearsonr(true_mean_vector, generated_mean_vector)[0])
        rmses_mean.append(np.sqrt(np.mean((true_mean_vector - generated_mean_vector) ** 2)))

        pearson_rs_std.append(pearsonr(true_std_vector, generated_std_vector)[0])
        rmses_std.append(np.sqrt(np.mean((true_std_vector - generated_std_vector) ** 2)))

    # One value per run: average over clusters.
    gmm_unconstrained_feature_means_rmses.append(np.mean(rmses_mean))
    gmm_unconstrained_feature_means_corrs.append(np.mean(pearson_rs_mean))

    gmm_unconstrained_feature_stds_rmses.append(np.mean(rmses_std))
    gmm_unconstrained_feature_stds_corrs.append(np.mean(pearson_rs_std))

generative_results_dict["generated_ip_clustering"]["gmm_vae_unconstrained"] = gmm_unconstrained_generated_ip_silhs

generative_results_dict["feature_means_rmse"]["gmm_vae_unconstrained"] = gmm_unconstrained_feature_means_rmses
generative_results_dict["feature_means_corr"]["gmm_vae_unconstrained"] = gmm_unconstrained_feature_means_corrs

generative_results_dict["feature_stds_rmse"]["gmm_vae_unconstrained"] = gmm_unconstrained_feature_stds_rmses
generative_results_dict["feature_stds_corr"]["gmm_vae_unconstrained"] = gmm_unconstrained_feature_stds_corrs

In [None]:
# Save numerical results.
# Uncomment to persist generative_results_dict to disk; the plotting section
# below reloads it from "generative_results_dict.pkl".
# with open("generative_results_dict.pkl", "wb") as f:
#     pickle.dump(generative_results_dict, f)

## Plot numerical evaluations of generative performance

In [None]:
# Reload the numerical results; the file is expected to be the one written by
# the (commented-out) save cell above.
# NOTE: pickle.load can execute arbitrary code - only load files you created.
with open("generative_results_dict.pkl", "rb") as results_file:
    generative_results_dict = pickle.load(results_file)

In [None]:
# Shared cosmetic settings for the figures in this section.
figsize = (2.5, 2.5)
marker_size, title_fontsize, ax_label_fontsize = 50, 11, 10
alpha = 0.9
xticklabels_fontsize, yticklabelsfontsize = 9, 6

# Cluster index -> colour: the first three "tab10" colours.
palette = {idx: sns.color_palette("tab10")[idx] for idx in range(3)}

# Model key -> colour, reused by the per-model bar plots.
palette_dict = {
    "vanilla_vae": sns.color_palette("tab10")[7],
    "gmm_vae_constrained": sns.color_palette("tab10")[4],
    "gmm_vae_unconstrained": sns.color_palette("tab10")[6],
}

# Order of the models along the x-axis of the bar plots.
default_order = ["vanilla_vae", "gmm_vae_constrained", "gmm_vae_unconstrained"]

### Latent space clustering

In [None]:
# One row per (model, run) from generative_results_dict["latent_clustering"].
# Runs are numbered from 1 in stored order; their count is taken from the data
# rather than hard-coded, so any number of runs per model works.
df = pd.DataFrame(
    [
        (score, run, model_name)
        for model_name, scores in generative_results_dict["latent_clustering"].items()
        for run, score in enumerate(scores, start=1)
    ],
    columns=["Silhouette score", "run", "model"],
)

In [None]:
save_dir = None      # file path; when set, the figure is also saved there
show_labels = False  # True: readable tick/axis labels; False: bare panel (default)

plt.figure(figsize=(1.8, 1.5))
palette = [palette_dict[model_key] for model_key in default_order]
# NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of `errorbar="sd"`.
ax = sns.barplot(data=df, x="model", y="Silhouette score", palette=palette,
                 order=default_order, ci="sd")

if show_labels:
    ax.set_xticklabels(["Vanilla VAE", "GMM VAE constrained", "GMM VAE unconstrained"],
                       rotation=45, ha="right", fontsize=xticklabels_fontsize)
    ax.set_ylabel("Silhouette score", fontsize=8)
else:
    ax.set_xticklabels([])
    ax.set_ylabel("")
ax.set_xlabel("")
plt.yticks(fontsize=8)

plt.tight_layout()
if save_dir:
    plt.savefig(save_dir)
plt.show()

### Clustering of generated inhibition profiles (IPs)

In [None]:
# One row per (model, run) from generative_results_dict["generated_ip_clustering"].
# Runs are numbered from 1 in stored order; their count is taken from the data
# rather than hard-coded, so any number of runs per model works.
df = pd.DataFrame(
    [
        (score, run, model_name)
        for model_name, scores in generative_results_dict["generated_ip_clustering"].items()
        for run, score in enumerate(scores, start=1)
    ],
    columns=["Silhouette score", "run", "model"],
)

In [None]:
save_dir = None      # file path; when set, the figure is also saved there
show_labels = False  # True: readable tick/axis labels; False: bare panel (default)

plt.figure(figsize=(1.5, 1.5))
# Only the GMM models are plotted here (default_order[1:]).
palette = [palette_dict[model_key] for model_key in default_order[1:]]
# NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of `errorbar="sd"`.
ax = sns.barplot(data=df, x="model", y="Silhouette score", palette=palette,
                 order=default_order[1:], ci="sd")

if show_labels:
    ax.set_xticklabels(["GMM VAE constrained", "GMM VAE unconstrained"], rotation=90)
    ax.set_ylabel("Silhouette score", fontsize=8)
else:
    ax.set_xticklabels([])
    ax.set_ylabel("")
ax.set_xlabel("")
plt.yticks(fontsize=8)

plt.tight_layout()
if save_dir:
    plt.savefig(save_dir)
plt.show()

### Feature maps - RMSE

In [None]:
# One row per (model, run, mode): RMSE between generated and true per-cluster
# feature means ("mean") and feature stds ("STD"). Runs are numbered from 1 in
# stored order; their count is taken from the data rather than hard-coded.
df = pd.DataFrame(
    [
        (value, run, model_name, mode)
        for results_key, mode in (("feature_means_rmse", "mean"), ("feature_stds_rmse", "STD"))
        for model_name, values in generative_results_dict[results_key].items()
        for run, value in enumerate(values, start=1)
    ],
    columns=["RMSE", "run", "model", "mode"],
)

In [None]:
save_dir = None      # file path; when set, the figure is also saved there
show_labels = False  # True: readable tick/axis labels; False: bare panel (default)

plt.figure(figsize=(1.8, 1.5))
# One colour per hue level ("mean", "STD").
palette = [sns.color_palette("viridis")[-3], sns.color_palette("rocket")[-3]]

# NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of `errorbar="sd"`.
ax = sns.barplot(data=df, x="model", y="RMSE", palette=palette,
                 order=default_order[1:], ci="sd", hue="mode")

if show_labels:
    ax.set_xticklabels(["GMM VAE constrained", "GMM VAE unconstrained"], rotation=90)
    ax.set_ylabel("RMSE", fontsize=8)
else:
    ax.set_xticklabels([])
    ax.set_ylabel("")
ax.set_xlabel("")
plt.yticks(fontsize=8)

# No legend on this panel; the Pearson-correlation figure keeps one.
ax.legend_.remove()

plt.tight_layout()
if save_dir:
    plt.savefig(save_dir)
plt.show()

### Feature maps - Pearson correlation

In [None]:
# One row per (model, run, mode): Pearson r between generated and true
# per-cluster feature means ("mean") and feature stds ("STD"). Runs are numbered
# from 1 in stored order; their count is taken from the data rather than hard-coded.
df = pd.DataFrame(
    [
        (value, run, model_name, mode)
        for results_key, mode in (("feature_means_corr", "mean"), ("feature_stds_corr", "STD"))
        for model_name, values in generative_results_dict[results_key].items()
        for run, value in enumerate(values, start=1)
    ],
    columns=["Pearson", "run", "model", "mode"],
)

In [None]:
save_dir = None      # file path; when set, the figure is also saved there
show_labels = False  # True: readable tick/axis labels; False: bare panel (default)

plt.figure(figsize=(1.8, 1.5))
# One colour per hue level ("mean", "STD").
palette = [sns.color_palette("viridis")[-3], sns.color_palette("rocket")[-3]]

# NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of `errorbar="sd"`.
ax = sns.barplot(data=df, x="model", y="Pearson", palette=palette,
                 order=default_order[1:], ci="sd", hue="mode")

if show_labels:
    ax.set_xticklabels(["GMM VAE constrained", "GMM VAE unconstrained"], rotation=90)
    ax.set_ylabel("Pearson r", fontsize=8)
else:
    ax.set_xticklabels([])
    ax.set_ylabel("")
ax.set_xlabel("")
plt.yticks(fontsize=8)

# Legend placed outside the axes, to the right and vertically centred.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

if save_dir:
    plt.tight_layout()
    plt.savefig(save_dir)

plt.show()