## Load libraries and define helper functions


In [None]:
from pathlib import Path
from typing import Any, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import polars as pl
import seaborn as sns
from matplotlib.figure import Figure
from matplotlib_venn import venn2
from tqdm import tqdm

from gnnepcsaft.data.graph import assoc_number
from gnnepcsaft.data.graphdataset import Esper, ThermoMLDataset
from gnnepcsaft.data.molfamily import complexity, get_family_groups

# Font sizes (points) shared by the figures of this notebook.
LABEL_FS = 9
TICKS_FS = 10
TITLE_FS = 11

# seaborn.set_theme() also applies the "notebook" plotting context, which
# overwrites the font-size rcParams. The custom sizes must therefore be applied
# AFTER set_theme(), otherwise they are silently discarded.
sns.set_theme(style="ticks")
mpl.rcParams.update(
    {
        "font.size": 11,
        "axes.titlesize": TITLE_FS,
        "axes.labelsize": LABEL_FS,
        "xtick.labelsize": TICKS_FS,
        "ytick.labelsize": TICKS_FS,
    }
)

# Every figure below is saved under images/; make sure the folder exists.
Path("images").mkdir(exist_ok=True)

# ThermoML molecules feed the val/test splits, Esper (2023) molecules the train
# split (see the family-group cells below).
test_dt = ThermoMLDataset("gnnepcsaft/data/thermoml")
train_dt = Esper("gnnepcsaft/data/esper2023")
# Reference tables; below they are only used through their ``inchis`` column.
gc_pcsaft = pl.read_csv("../refgc_pcsaft.csv")
winter = pl.read_csv("../ref_winter.csv")

In [None]:
def generate_family_group_distribution(
    df: pl.DataFrame, threshold: float = 21
) -> Tuple[Figure, Any]:
    """Draw a pie chart of how many rows of ``df`` fall into each ``famgroup``.

    Family groups whose row count is not strictly greater than ``threshold``
    are merged into a single ``"other"`` wedge. Wedges are ordered by size
    (largest first), ties broken alphabetically so the figure is reproducible.

    Args:
        df: Frame with a ``famgroup`` column; every row counts once.
        threshold: Minimum count (exclusive) for a family group to get its own
            wedge.

    Returns:
        The new figure and the value returned by ``plt.pie``
        (wedges, texts, autotexts).
    """
    counts = (
        df.group_by("famgroup")
        .agg(pl.len())
        .with_columns(
            pl.when(pl.col("len") > threshold)
            .then(pl.col("famgroup"))
            .otherwise(pl.lit("other"))
            .alias("famgroup")
        )
        .group_by("famgroup")
        .agg(pl.sum("len"))
        # group_by does not preserve order, so ties in "len" must be broken
        # explicitly to get the same wedge order on every run.
        .sort(["len", "famgroup"], descending=[True, False])
    )
    # Select columns by name rather than unpacking the frame positionally.
    labels = counts["famgroup"].to_list()
    sizes = counts["len"].to_list()

    fig = plt.figure(figsize=(3.3, 3.3), dpi=600)
    pie = plt.pie(
        x=sizes,
        explode=[0.01] * len(sizes),
        labels=labels,
        colors=sns.color_palette("Set2"),
        rotatelabels=True,
        autopct="%1.1f%%",
        textprops={"fontsize": 8},
    )
    return fig, pie

In [None]:
def violin_plot_mw(df: pl.DataFrame):
    """Draw one molecular-weight violin per distinct ``split`` value.

    Args:
        df: Frame with a ``split`` column (categories on the x axis) and an
            ``mw`` column (values on the y axis).

    Returns:
        The new matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6.85, 6.85), dpi=600)
    sns.violinplot(
        data=df,
        hue="split",
        x="split",
        y="mw",
        inner=None,
        cut=0,  # keep each violin within the observed data range
        palette="Set2",
        ax=ax,
    )
    ax.set_ylabel(r"molecular weight (g mol$^{-1}$)")
    ax.set_xlabel("")
    fig.tight_layout()
    return fig

In [None]:
def violin_plot_comp(df: pl.DataFrame):
    """Draw one complexity violin per distinct ``split`` value.

    Args:
        df: Frame with a ``split`` column (categories on the x axis) and a
            ``comp`` column (values on the y axis).

    Returns:
        The new matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6.85, 6.85), dpi=600)
    sns.violinplot(
        data=df,
        hue="split",
        x="split",
        y="comp",
        inner=None,
        cut=0,  # keep each violin within the observed data range
        palette="Set2",
        ax=ax,
    )
    ax.set_ylabel(r"Complexity")
    ax.set_xlabel("")
    fig.tight_layout()
    return fig

In [None]:
def data_venn_diagram(train_df: pl.DataFrame, test_df: pl.DataFrame):
    """Draw a two-set Venn diagram of the unique InChIs of two frames.

    The first circle (colour ``C0``) is built from ``train_df["inchis"]`` and the
    second (``C1``) from ``test_df["inchis"]``. Set names are left blank, so
    only the region counts are displayed.

    Returns:
        The new matplotlib figure.
    """
    train_inchis = set(train_df.get_column("inchis").to_list())
    test_inchis = set(test_df.get_column("inchis").to_list())

    fig = plt.figure(figsize=(3.3, 3.3), dpi=600)
    diagram = venn2(
        subsets=(train_inchis, test_inchis),
        set_labels=("", ""),
        set_colors=("C0", "C1"),
        alpha=0.5,
    )
    for set_label in diagram.set_labels:
        set_label.set_fontsize(LABEL_FS)
    # Some region-label entries can be None; only resize the real Text objects.
    for region_label in filter(None, diagram.subset_labels):
        region_label.set_fontsize(TICKS_FS)
    return fig

## Compute family-group data for each split


In [None]:
# Per-molecule descriptors of the Esper et al. training set, one list per column.
train_families = {
    column: [] for column in ("inchis", "famgroup", "na", "nb", "mw", "split")
}
for graph in train_dt:
    inchi = graph.InChI
    groups = get_family_groups(inchi)
    na, nb = assoc_number(inchi)
    # Molecules matching more than one family group are pooled together.
    # NOTE(review): groups[0] assumes get_family_groups never returns [].
    famgroup = "polyfunctional" if len(groups) > 1 else groups[0]
    row = {
        "inchis": inchi,
        "famgroup": famgroup,
        "na": na,
        "nb": nb,
        "mw": graph.mw.item(),
        "split": "train",
    }
    for column, value in row.items():
        train_families[column].append(value)

df_train_families = pl.DataFrame(train_families)
df_train_families.write_csv("../esper_et_al_families.csv")

In [None]:
# Built once so the per-molecule "seen in training?" check below is an O(1)
# set lookup instead of a linear scan of the training InChI list.
train_inchis = set(train_families["inchis"])

# Per-molecule descriptors of the ThermoML molecules, one list per column.
val_test_families = {
    "inchis": [],
    "famgroup": [],
    "na": [],
    "nb": [],
    "mw": [],
    "split": [],
}
for graph in test_dt:
    inchi = graph.InChI
    famgroup = get_family_groups(inchi)
    na, nb = assoc_number(inchi)
    # Molecules matching more than one family group are pooled together.
    # NOTE(review): famgroup[0] assumes get_family_groups never returns [].
    if len(famgroup) > 1:
        famgroup = ["polyfunctional"]
    val_test_families["inchis"].append(inchi)
    val_test_families["famgroup"].append(famgroup[0])
    val_test_families["na"].append(na)
    val_test_families["nb"].append(nb)
    val_test_families["mw"].append(graph.mw.item())
    # Molecules also present in the training set are labelled "val";
    # molecules never seen in training are labelled "test".
    val_test_families["split"].append("val" if inchi in train_inchis else "test")

df_val_test_families = pl.DataFrame(val_test_families)
df_val_test_families.write_csv("../thermoml_families.csv")

In [None]:
# Show the ThermoML molecules heavier than 1000 g/mol, heaviest first.
df_val_test_families.filter(pl.col("mw").gt(1000)).sort(by="mw", descending=True)

## Data plots


In [None]:
# Overlap of unique InChIs between the training (Esper) and ThermoML molecules.
fig = data_venn_diagram(train_df=df_train_families, test_df=df_val_test_families)
fig.savefig(
    fname="images/train_test_venn.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Molecular-weight distribution of the train, val and test splits side by side.
all_families = df_train_families.vstack(df_val_test_families)
fig = violin_plot_mw(all_families)
fig.savefig(
    fname="images/mw_violin.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Family-group composition of the ThermoML molecules not seen in training.
df = df_val_test_families.filter(pl.col("split").eq("test"))
fig = generate_family_group_distribution(df)[0]
fig.savefig(
    fname="images/test_pie.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Family-group composition of the ThermoML molecules also present in training.
df = df_val_test_families.filter(pl.col("split").eq("val"))
fig = generate_family_group_distribution(df)[0]
fig.savefig(
    fname="images/val_pie.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Family-group composition of the training molecules (default threshold).
df = df_train_families
fig = generate_family_group_distribution(df, threshold=21)[0]
fig.savefig(
    fname="images/train_pie.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Test-split molecules also listed in the Winter et al. reference table;
# family groups with 9 or fewer molecules are pooled into "other".
df = df_val_test_families.filter(
    (pl.col("split") == "test") & pl.col("inchis").is_in(winter["inchis"].to_list())
)
fig, pie = generate_family_group_distribution(df, threshold=9)
fig.savefig(
    fname="images/winter_pie.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Test-split molecules also listed in the GC PC-SAFT reference table; with a
# threshold of 0 every family group gets its own wedge (no "other").
df = df_val_test_families.filter(
    (pl.col("split") == "test")
    & pl.col("inchis").is_in(gc_pcsaft["inchis"].to_list())
)
fig, pie = generate_family_group_distribution(df, threshold=0)
fig.savefig(
    fname="images/gc_pie.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Molecular-weight distribution of the test-split molecules covered by each
# reference table, relabelled so each subset becomes its own violin.
test_only = df_val_test_families.filter(pl.col("split") == "test")
gc_subset = test_only.filter(
    pl.col("inchis").is_in(gc_pcsaft["inchis"].to_list())
).with_columns(split=pl.lit("GC PCSAFT subset"))
winter_subset = test_only.filter(
    pl.col("inchis").is_in(winter["inchis"].to_list())
).with_columns(split=pl.lit("Winter et al. subset"))
df = gc_subset.vstack(winter_subset)

fig = violin_plot_mw(df)
fig.savefig(
    fname="images/mw_violin_test_subsets.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

## Complexity


In [None]:
# Complexities cached by earlier runs (columns ``inchis`` and ``complexity``).
# On a first run the cache file does not exist yet, so start from an empty
# frame with the same column types as the newly computed complexities instead
# of failing.
complexity_cache_path = Path("../complexity.csv")
if complexity_cache_path.exists():
    complexity_bkp_df = pl.read_csv(complexity_cache_path)
else:
    complexity_bkp_df = pl.DataFrame(schema={"inchis": str, "complexity": float})

In [None]:
# Unique InChIs (train + ThermoML) whose complexity is not cached yet.
pending = (
    df_train_families.vstack(df_val_test_families)
    .unique("inchis")
    .filter(~pl.col("inchis").is_in(complexity_bkp_df["inchis"].to_list()))
)

mol_complexity = {"inchis": [], "complexity": []}
for row in tqdm(pending.iter_rows(named=True)):
    inchi = row["inchis"]
    comp = complexity(inchi)
    if not isinstance(comp, float):
        # Non-float results are reported and kept out of the cache.
        print(f"not float: {comp}")
        continue
    mol_complexity["inchis"].append(inchi)
    mol_complexity["complexity"].append(comp)

In [None]:
# Newly computed values, typed explicitly so they stack onto the cached frame
# even when nothing new was computed (empty lists).
complexity_df = pl.DataFrame(
    mol_complexity, schema={"inchis": str, "complexity": float}
)
# ``df`` (cached + new complexities) is read by the next cell.
df = pl.concat([complexity_bkp_df, complexity_df], how="vertical")
df.write_csv("../complexity.csv")

In [None]:
# InChI -> complexity lookup built from ``df`` (the merged cache of the
# previous cell); for duplicated InChIs the later row wins.
complexity_dict = dict(zip(df["inchis"].to_list(), df["complexity"].to_list()))

# Attach a Float32 ``comp`` column; InChIs without a cached value get null.
train_val_test_with_comp = df_train_families.vstack(df_val_test_families).with_columns(
    comp=pl.col("inchis").replace_strict(
        complexity_dict, default=None, return_dtype=pl.Float32
    )
)

In [None]:
# Complexity distribution per split, ignoring null, NaN and infinite values.
finite_comp_rows = train_val_test_with_comp.filter(
    pl.col("comp").is_not_null(), pl.col("comp").is_finite()
)
fig = violin_plot_comp(finite_comp_rows)
fig.savefig(
    fname="images/comp_violin.png",
    format="png",
    dpi=600,
    transparent=False,
    bbox_inches="tight",
)

In [None]:
# Complexity distribution of the test-split molecules covered by each reference
# table (GC PC-SAFT and Winter et al.), one violin per subset.
# ``comp`` is null for molecules without a cached complexity and may be NaN or
# infinite; drop those rows (same condition as for comp_violin.png) so they
# never reach the violin KDE.
has_finite_comp = pl.col("comp").is_finite() & ~pl.col("comp").is_null()
test_with_comp = train_val_test_with_comp.filter(
    pl.col("split") == "test", has_finite_comp
)
df = (
    test_with_comp.filter(pl.col("inchis").is_in(gc_pcsaft["inchis"].to_list()))
    .with_columns(pl.lit("GC PCSAFT subset").alias("split"))
    .vstack(
        test_with_comp.filter(
            pl.col("inchis").is_in(winter["inchis"].to_list())
        ).with_columns(pl.lit("Winter et al. subset").alias("split"))
    )
)

fig = violin_plot_comp(df)
fig.savefig(
    "images/comp_violin_test_subsets.png",
    dpi=600,
    format="png",
    bbox_inches="tight",
    transparent=False,
)