In [None]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from ivae_scorer.utils import set_all_seeds
import seaborn as sns
import dotenv
from pathlib import Path
import pandas as pd

# Anchor every directory on the folder that holds the file located by
# dotenv.find_dotenv(). NOTE(review): find_dotenv() returns "" when nothing is
# found, in which case project_path silently becomes the current directory.
project_path = Path(dotenv.find_dotenv()).parent
results_path = project_path / "results"
data_path = project_path / "data"
figs_path = results_path / "figs"
tables_path = results_path / "tables"

# set_all_seeds is a project helper (presumably seeds the Python/NumPy/TF RNGs;
# confirm in ivae_scorer.utils). It runs before TensorFlow is switched to
# deterministic op implementations.
set_all_seeds(seed=42)
tf.config.experimental.enable_op_determinism()

In [None]:
# Model identifiers; each one is used below as the name of the sub-directory of
# results_path that holds that model's score pickles.
models = [
    "ivae_kegg",
    "ivae_reactome",
]

In [None]:
# Metrics: read each model's scores_metrics.pkl and stack the frames row-wise
# into a single table with a fresh 0..n-1 index.
metric_scores = pd.concat(
    [pd.read_pickle(results_path / name / "scores_metrics.pkl") for name in models],
    axis=0,
    ignore_index=True,
)
metric_scores.head()

In [None]:
# "paper" context with the "ticks" style, minus the top and right spines.
custom_params = {"axes.spines.top": False, "axes.spines.right": False}
sns.set_theme(style="ticks", context="paper", font_scale=2, rc=custom_params)
fac = 0.6  # common scale factor applied to both facet height and aspect

# One violin facet per metric, each with its own y axis (sharey=False);
# models on the x axis, coloured by data split.
g = sns.catplot(
    kind="violin",
    data=metric_scores,
    # variable mapping
    x="model",
    y="score",
    hue="split",
    col="metric",
    # facet geometry
    sharey=False,
    height=9 * fac,
    aspect=16 / 9 * fac,
    legend_out=False,
    # violin appearance
    split=False,
    cut=0,
    fill=False,
    density_norm="count",
    inner="quart",
    linewidth=2,
)

# Re-anchor the legend: frameless, untitled, laid out in up to three columns.
sns.move_legend(
    g,
    loc="lower center",
    bbox_to_anchor=(0.5, 0.9),
    ncol=3,
    frameon=False,
    title=None,
)

In [None]:
# Clustering: read each model's scores_clustering.pkl and stack the frames
# row-wise into a single table with a fresh 0..n-1 index.
clustering_scores = pd.concat(
    [pd.read_pickle(results_path / name / "scores_clustering.pkl") for name in models],
    axis=0,
    ignore_index=True,
)
clustering_scores.head()

In [None]:
# "paper" context with the "ticks" style, minus the top and right spines.
custom_params = {"axes.spines.top": False, "axes.spines.right": False}
sns.set_theme(style="ticks", context="paper", font_scale=2, rc=custom_params)
fac = 0.6  # common scale factor applied to both facet height and aspect

# One violin facet per model on a shared y axis; layers on the x axis,
# coloured by data split.
g = sns.catplot(
    kind="violin",
    data=clustering_scores,
    # variable mapping
    x="layer",
    y="score",
    hue="split",
    col="model",
    # facet geometry
    sharey=True,
    height=9 * fac,
    aspect=16 / 9 * fac,
    legend_out=False,
    # violin appearance
    split=False,
    cut=0,
    fill=False,
    density_norm="count",
    inner="quart",
    linewidth=2,
)

# Re-anchor the legend: frameless, untitled, laid out in up to three columns.
sns.move_legend(
    g,
    loc="lower center",
    bbox_to_anchor=(0.5, 0.9),
    ncol=3,
    frameon=False,
    title=None,
)

In [None]:
# Informed: read each model's scores_informed.pkl and stack the frames
# row-wise into a single table with a fresh 0..n-1 index.
informed_scores = pd.concat(
    [pd.read_pickle(results_path / name / "scores_informed.pkl") for name in models],
    axis=0,
    ignore_index=True,
)
informed_scores.head()

In [None]:
# "paper" context with the "ticks" style, minus the top and right spines.
custom_params = {"axes.spines.top": False, "axes.spines.right": False}
sns.set_theme(style="ticks", context="paper", font_scale=2, rc=custom_params)
fac = 0.6  # common scale factor applied to both facet height and aspect

# One violin facet per model on a shared y axis; layers on the x axis,
# coloured by data split.
g = sns.catplot(
    kind="violin",
    data=informed_scores,
    # variable mapping
    x="layer",
    y="score",
    hue="split",
    col="model",
    # facet geometry
    sharey=True,
    height=9 * fac,
    aspect=16 / 9 * fac,
    legend_out=False,
    # violin appearance
    split=False,
    cut=0,
    fill=False,
    density_norm="count",
    inner="quart",
    linewidth=2,
)

# Re-anchor the legend: frameless, untitled, laid out in up to three columns.
sns.move_legend(
    g,
    loc="lower center",
    bbox_to_anchor=(0.5, 0.9),
    ncol=3,
    frameon=False,
    title=None,
)