## Load libraries and model


- First step: Load libraries


In [None]:
import torch
from gnnepcsaft.train.utils import build_train_dataset, build_test_dataset
from gnnepcsaft.train.models import GNNePCSAFTL, HabitchNNL
from torch_geometric.loader import DataLoader
from rdkit import Chem
from gnnepcsaft.demo.utils import (
    _save_plot,
    plotdata,
    plotparams,
    predict_params_from_inchi,
    es_para,
)
from gnnepcsaft.demo.utils_binary import (
    binary_rho_plot,
    mape_rho,
)
import xgboost as xgb
import joblib
import polars as pl
from gnnepcsaft.data.rdkit_util import mw


# Show whether a CUDA device is visible; the checkpoints in this notebook are mapped to "cuda".
torch.cuda.is_available()

- Second step: Load data


In [None]:
# Build the training dataset named "esper".
# NOTE(review): the first argument is presumably the data/work directory — confirm in build_train_dataset.
es = build_train_dataset("gnnepcsaft", "esper")
# build_test_dataset receives the training set and returns two datasets,
# unpacked here in the order (tml_val, tml_train).
tml_val, tml_train = build_test_dataset("gnnepcsaft", es)
# Shuffled mini-batches of 512 graphs from the training set.
es_loader = DataLoader(es, batch_size=512, shuffle=True)
# One unshuffled batch that holds all of tml_train.
tml_loader = DataLoader(tml_train, batch_size=len(tml_train), shuffle=False)
# Table filtered by the "Binary" cells; the path is relative, so the notebook must be
# run from a directory that contains ./gnnepcsaft.
rho_binary_tml = pl.read_parquet("./gnnepcsaft/data/thermoml/raw/rho_binary.parquet")

- Third step: Define and load models
  - You can get one checkpoint from [Hugging Face](https://huggingface.co/wildsonbbl/gnnepcsaft).


In [None]:
# Restore the two GNN checkpoints, map them to "cuda" (second positional argument)
# and switch them to eval mode.
# NOTE(review): weights_only=False presumably allows full unpickling of the checkpoint —
# only load checkpoint files from a trusted source.
pna_msigmae = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/model-hlrn7lqv.ckpt", "cuda", weights_only=False
).eval()
pna_assoc = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/assoc_model-j7isfrga.ckpt",
    "cuda",
    weights_only=False,
).eval()

# joblib.load unpickles the file, so it must come from a trusted source.
rf_msigmae = joblib.load("./gnnepcsaft/train/checkpoints/rf_model.joblib")
# Empty booster first, then its state is read from the JSON model file.
xgb_msigmae = xgb.Booster()
xgb_msigmae.load_model("./gnnepcsaft/train/checkpoints/xgb_model.json")
# Baseline network used alongside pna_msigmae in the plotting cells.
habitch_msigmae = HabitchNNL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/model-u62sbl40.ckpt", "cuda", weights_only=False
).eval()

## Plotting


- Plot to check performance on ThermoML Archive data


### Pure


In [None]:
# Two-component InChI (C4H8O2 . C3H9NO) handled as a single "pure" species.
inchi = (
    "InChI=1S/C4H8O2.C3H9NO/c1-2-3-4(5)6;1-4-2-3-5/h2-3H2,1H3,(H,5,6);4-5H,2-3H2,1H3"
)
molecule_name = "Butanoic_acid_2-(methylamino)ethanol"

# Both parameter sets share the association model; only the second model differs.
pna = predict_params_from_inchi(inchi, pna_assoc.model, pna_msigmae.model)
habitch = predict_params_from_inchi(inchi, pna_assoc.model, habitch_msigmae.model)

# Reference parameters are optional: a falsy lookup result is kept as-is,
# otherwise the stacked tensors are flattened to the first row as a plain list.
esper = es_para.get(inchi, None)
if esper:
    esper = torch.hstack(esper).tolist()[0]

plotdata(
    inchi,
    molecule_name,
    list_params=[pna, habitch] + ([esper] if esper else []),
)

In [None]:
# Ion pair C8H15N2(+) / C2F6NO4S2(-) handled as a single "pure" species.
inchi = "InChI=1S/C8H15N2.C2F6NO4S2/c1-3-4-5-10-7-6-9(2)8-10;3-1(4,5)14(10,11)9-15(12,13)2(6,7)8/h6-8H,3-5H2,1-2H3;/q+1;-1"
# NOTE(review): the label reads "bmin"; the usual abbreviation for this cation is
# "bmim" — confirm before renaming, since plotdata receives this string.
molecule_name = "bmin-tf2n"

# Both parameter sets share the association model; only the second model differs.
pna = predict_params_from_inchi(inchi, pna_assoc.model, pna_msigmae.model)
habitch = predict_params_from_inchi(inchi, pna_assoc.model, habitch_msigmae.model)

# Reference parameters are optional: a falsy lookup result is kept as-is,
# otherwise the stacked tensors are flattened to the first row as a plain list.
esper = es_para.get(inchi, None)
if esper:
    esper = torch.hstack(esper).tolist()[0]

plotdata(
    inchi,
    molecule_name,
    list_params=[pna, habitch] + ([esper] if esper else []),
)

In [None]:
# Ion pair C8H15N2(+) / BF4(-) handled as a single "pure" species.
inchi = (
    "InChI=1S/C8H15N2.BF4/c1-3-4-5-10-7-6-9(2)8-10;2-1(3,4)5/h6-8H,3-5H2,1-2H3;/q+1;-1"
)
# NOTE(review): the label reads "bmin"; the usual abbreviation for this cation is
# "bmim" — confirm before renaming, since plotdata receives this string.
molecule_name = "bmin-bf4"

# Both parameter sets share the association model; only the second model differs.
pna = predict_params_from_inchi(inchi, pna_assoc.model, pna_msigmae.model)
habitch = predict_params_from_inchi(inchi, pna_assoc.model, habitch_msigmae.model)

# Reference parameters are optional: a falsy lookup result is kept as-is,
# otherwise the stacked tensors are flattened to the first row as a plain list.
esper = es_para.get(inchi, None)
if esper:
    esper = torch.hstack(esper).tolist()[0]

plotdata(
    inchi,
    molecule_name,
    list_params=[pna, habitch] + ([esper] if esper else []),
)

In [None]:
# Ion pair C6H11N2(+) / C2F6NO4S2(-) handled as a single "pure" species.
inchi = "InChI=1S/C6H11N2.C2F6NO4S2/c1-3-8-5-4-7(2)6-8;3-1(4,5)14(10,11)9-15(12,13)2(6,7)8/h4-6H,3H2,1-2H3;/q+1;-1"
# NOTE(review): the label reads "emin"; the usual abbreviation for this cation is
# "emim" — confirm before renaming, since plotdata receives this string.
molecule_name = "emin-tf2n"

# Both parameter sets share the association model; only the second model differs.
pna = predict_params_from_inchi(inchi, pna_assoc.model, pna_msigmae.model)
habitch = predict_params_from_inchi(inchi, pna_assoc.model, habitch_msigmae.model)

# Reference parameters are optional: a falsy lookup result is kept as-is, otherwise the
# stacked tensors are flattened to the first row and the molecular weight is appended.
# NOTE(review): the other pure-component cells in this notebook do not append
# mw(inchi) to the reference parameters — confirm which form plotdata expects.
esper = es_para.get(inchi, None)
if esper:
    esper = torch.hstack(esper).tolist()[0] + [mw(inchi)]

plotdata(
    inchi,
    molecule_name,
    list_params=[pna, habitch] + ([esper] if esper else []),
)

- Check the molecular structure


In [None]:
# Inspect the structure behind the most recently assigned `inchi`.
mol = Chem.MolFromInchi(inchi, removeHs=False, sanitize=True)
if mol is None:
    # RDKit reports an unparsable InChI by returning None instead of raising;
    # fail here with a clear message rather than with an AttributeError below.
    raise ValueError(f"RDKit could not parse InChI: {inchi}")
# Number of substructure matches for the SMILES pattern "O".
# NOTE(review): not read by any other cell in view, and the pattern looks like it
# matches every oxygen atom rather than only hydroxyl groups — confirm or remove.
n_oh = len(mol.GetSubstructMatches(Chem.MolFromSmiles("O")))
print(Chem.MolToSmiles(mol, isomericSmiles=True))

# Last expression of the cell: lets the notebook render the molecule.
mol

- Plot parameter trends against chain length


In [None]:
# Each family is a list of SMILES: 1..n-1 leading carbons followed by a fixed tail.
n = 15
list_smiles = [
    [f"{'C' * i}{tail}" for i in range(1, n)]
    for tail in (
        "C(=O)C",  # ketone
        "C(=O)OCC",  # ester
        "CC1=CC=CC=C1",  # benzene
        "C[N+](=O)[O-]",  # 1-nitro alkane
    )
]
# One x-axis label per family, in the same order as list_smiles.
list_xlabel = [
    r"$[C]_n(=O)C$",
    r"$[C]_n(=O)OCC$",
    r"$[C]_nC1=CC=CC=C1$",
    r"$[C]_nC[N+](=O)[O-]$",
]

fig, axs = plotparams(
    list_smiles,
    [pna_msigmae.model],
    list_xlabel,
)

In [None]:
# Each family is a list of SMILES: 1..n-1 leading carbons followed by a fixed tail.
n = 10
list_smiles = [
    [f"{'C' * i}{tail}" for i in range(1, n)]
    for tail in (
        "C",  # alkanes
        "CO",  # alcohols
        "C(=O)O",  # acid
        "OC",  # ether
    )
]
# One x-axis label per family, in the same order as list_smiles.
list_xlabel = [r"$C_n$", r"$C_nO$", r"$C_nC(=O)O$", r"$C_nOC$"]
fig, axs = plotparams(
    list_smiles,
    [pna_msigmae.model],
    list_xlabel=list_xlabel,
)

### Binary


In [None]:
# Binary system: component 1 is C5H14NO.ClH, component 2 is CH4N2O.
inchi1 = "InChI=1S/C5H14NO.ClH/c1-6(2,3)4-5-7;/h7H,4-5H2,1-3H3;1H/q+1;/p-1"
inchi2 = "InChI=1S/CH4N2O/c2-1(3)4/h(H4,2,3,4)"

mw1, mw2 = mw(inchi1), mw(inchi2)

# Keep a single (pressure, composition) slice for this pair of components.
# NOTE(review): the pressure and mole-fraction predicates use exact float equality,
# so they only match rows that store exactly these values.
rho_data = rho_binary_tml.filter(
    (pl.col("inchi1") == inchi1)
    & (pl.col("inchi2") == inchi2)
    & (pl.col("P_kPa") == 100.0)
    & (pl.col("mole_fraction_c1") == 0.3333)
)

# Parameters for both components from each model; the association model is shared.
pna_inchi1 = predict_params_from_inchi(inchi1, pna_assoc.model, pna_msigmae.model)
habitch_inchi1 = predict_params_from_inchi(
    inchi1,
    pna_assoc.model,
    habitch_msigmae.model,
)
pna_inchi2 = predict_params_from_inchi(inchi2, pna_assoc.model, pna_msigmae.model)
habitch_inchi2 = predict_params_from_inchi(
    inchi2,
    pna_assoc.model,
    habitch_msigmae.model,
)

# One [component 1, component 2] pair per model.
list_params = [
    [pna_inchi1, pna_inchi2],
    [habitch_inchi1, habitch_inchi2],
]
binary_rho_plot(rho_data, mw1, mw2, list_params)
_save_plot("fig10.png")

In [None]:
# Row count per (pressure, composition) pair for the current binary system,
# most populated pair first.
(
    rho_binary_tml.filter(
        (pl.col("inchi1") == inchi1) & (pl.col("inchi2") == inchi2)
    )
    .group_by("P_kPa", "mole_fraction_c1")
    .len()
    .sort("len", descending=True)
)

In [None]:
# Evaluate mape_rho on the slice and parameter sets built in the plotting cell.
# NOTE(review): presumably a mean absolute percentage error of density — confirm in mape_rho.
mape_rho(rho_data, mw1, mw2, list_params)

In [None]:
# Binary system: component 1 is C5H14NO.ClH, component 2 is C3H4O4.
inchi1 = "InChI=1S/C5H14NO.ClH/c1-6(2,3)4-5-7;/h7H,4-5H2,1-3H3;1H/q+1;/p-1"
inchi2 = "InChI=1S/C3H4O4/c4-2(5)1-3(6)7/h1H2,(H,4,5)(H,6,7)"

mw1, mw2 = mw(inchi1), mw(inchi2)

# Keep a single (pressure, composition) slice for this pair of components.
# NOTE(review): the pressure and mole-fraction predicates use exact float equality,
# so they only match rows that store exactly these values.
rho_data = rho_binary_tml.filter(
    (pl.col("inchi1") == inchi1)
    & (pl.col("inchi2") == inchi2)
    & (pl.col("P_kPa") == 101.0)
    & (pl.col("mole_fraction_c1") == 0.3333)
)

# Parameters for both components from each model; the association model is shared.
pna_inchi1 = predict_params_from_inchi(inchi1, pna_assoc.model, pna_msigmae.model)
habitch_inchi1 = predict_params_from_inchi(
    inchi1,
    pna_assoc.model,
    habitch_msigmae.model,
)
pna_inchi2 = predict_params_from_inchi(inchi2, pna_assoc.model, pna_msigmae.model)
habitch_inchi2 = predict_params_from_inchi(
    inchi2,
    pna_assoc.model,
    habitch_msigmae.model,
)

# One [component 1, component 2] pair per model.
list_params = [
    [pna_inchi1, pna_inchi2],
    [habitch_inchi1, habitch_inchi2],
]
binary_rho_plot(rho_data, mw1, mw2, list_params)
_save_plot("fig11.png")

In [None]:
# Row count per (pressure, composition) pair for the current binary system,
# most populated pair first.
(
    rho_binary_tml.filter(
        (pl.col("inchi1") == inchi1) & (pl.col("inchi2") == inchi2)
    )
    .group_by("P_kPa", "mole_fraction_c1")
    .len()
    .sort("len", descending=True)
)

In [None]:
# Evaluate mape_rho on the slice and parameter sets built in the plotting cell.
# NOTE(review): presumably a mean absolute percentage error of density — confirm in mape_rho.
mape_rho(rho_data, mw1, mw2, list_params)

## ONNX save


- Save as ONNX model


In [None]:
from gnnepcsaft.data.graphdataset import ThermoMLDataset

# Load the ThermoML dataset stored under ./gnnepcsaft/data/thermoml and move it to "cuda".
tml_all = ThermoMLDataset("./gnnepcsaft/data/thermoml").to("cuda")
# Batch size equals the dataset length, so the first batch holds every graph.
tml_all_loader = DataLoader(tml_all, len(tml_all))
gh = next(iter(tml_all_loader))
# Example inputs handed to the export: node features, edge index, edge attributes and None.
# NOTE(review): the trailing None presumably fills an optional model argument — confirm.
example_input = gh.x, gh.edge_index, gh.edge_attr, None  # type: ignore

In [None]:
from gnnepcsaft.demo.utils import save_exported_program

# Export both models with the same example input; the files are written one directory
# above the working directory and the return values are kept for inspection.
exp_msigmae = save_exported_program(
    pna_msigmae.model, example_input, "../msigmae_7.onnx"
)

exp_assoc = save_exported_program(pna_assoc.model, example_input, "../assoc_8.onnx")

In [None]:
from gnnepcsaft.demo.utils import test_onnx


# Run test_onnx with the dataset, the two exported files and the two source models.
# NOTE(review): presumably compares ONNX outputs against the PyTorch models — confirm in test_onnx.
test_onnx(
    tml_all,  # type: ignore
    "../msigmae_7.onnx",
    "../assoc_8.onnx",
    pna_msigmae.model,
    pna_assoc.model,
)

## SI info


In [None]:
# Pure-component table summarised per compound in the next cell.
tml = pl.read_parquet("./gnnepcsaft/data/thermoml/raw/pure.parquet")
# Mixture table with c1/inchi1 and c2/inchi2 columns, read from the parent directory.
choline_mix = pl.read_csv("../choline_mix.csv")

In [None]:
# Per-compound summary: temperature range over all rows, then point counts and value
# ranges for the rows with tp == 1 (density columns) and tp == 3 (VP columns).
# The keyword order fixes the output column order, which later cells rely on.
tml_grouped = (
    tml.group_by("c1", "inchi1")
    .agg(
        T_K_min=pl.col("TK").min(),
        T_K_max=pl.col("TK").max(),
        Density_n_points=(pl.col("tp") == 1).sum(),
        Density_kg_m3_min=pl.col("m").filter(pl.col("tp") == 1).min(),
        Density_kg_m3_max=pl.col("m").filter(pl.col("tp") == 1).max(),
        Density_P_Pa_min=pl.col("PPa").filter(pl.col("tp") == 1).min(),
        Density_P_Pa_max=pl.col("PPa").filter(pl.col("tp") == 1).max(),
        VP_n_points=(pl.col("tp") == 3).sum(),
        VP_Pa_min=pl.col("m").filter(pl.col("tp") == 3).min(),
        VP_Pa_max=pl.col("m").filter(pl.col("tp") == 3).max(),
    )
    .rename({"c1": "chemical_name", "inchi1": "InChI"})
    .sort("chemical_name")
)

In [None]:
from ctypes import ArgumentError


# First row is the header: the summary columns followed by nine parameter column names.
# NOTE(review): "mu_kb" sits where a dispersion-energy column would be expected and
# "dipole_moment" is listed separately — confirm against predict_params_from_inchi's output order.
pcsaft_params = [
    tml_grouped.columns
    + [
        "m",
        "sigma",
        "mu_kb",
        "kappa_AB",
        "epsilon_kb_AB",
        "dipole_moment",
        "na",
        "nb",
        "molecular_weight_g_mol",
    ]
]
# One output row per compound; compounds whose prediction fails are reported and skipped.
for row in tml_grouped.iter_rows(named=False):
    summary_values = list(row)
    try:
        mol_parameters = predict_params_from_inchi(
            inchi=row[1],
            model_assoc=pna_assoc.model,
            model_msigmae=pna_msigmae.model,
            device="cuda",
        )
    # NOTE(review): ArgumentError is imported from ctypes in this cell; if the intent is
    # to catch RDKit's Boost.Python.ArgumentError, this clause will not match it — confirm.
    except (ArgumentError, ValueError) as e:
        print(e, *row[:2])
    else:
        pcsaft_params.append(summary_values + mol_parameters)

In [None]:
# InChIs already covered by the ThermoML summary; only mixture components that are
# missing from it get a row in this cell.
inchis = tml_grouped["InChI"].to_list()
# (name, InChI) pairs already in the table. Without this check, re-running the cell
# would append the same components to pcsaft_params a second time.
already_listed = {(entry[0], entry[1]) for entry in pcsaft_params}
for c, inchi in (
    choline_mix.select("c2", "inchi2")
    .rename({"c2": "c1", "inchi2": "inchi1"})
    .vstack(choline_mix.select("c1", "inchi1"))
    .unique()
    .sort("c1")
    .filter(~pl.col("inchi1").is_in(inchis))
    .iter_rows()
):
    if (c, inchi) in already_listed:
        continue
    # NOTE(review): unlike the ThermoML loop, no device is passed and prediction
    # errors are not caught here — confirm this is intended.
    mol_parameters = predict_params_from_inchi(
        inchi,
        model_assoc=pna_assoc.model,
        model_msigmae=pna_msigmae.model,
    )
    # Twelve placeholders line up with the tml_grouped columns: name, InChI, then
    # empty ranges with 0 in the two *_n_points columns (no experimental points).
    pcsaft_params.append(
        [c, inchi, None, None, 0, None, None, None, None, 0, None, None]
        + mol_parameters
    )
    already_listed.add((c, inchi))

In [None]:
# save to csv
import csv

# Write the header and all rows in one call. The encoding is set explicitly so the file
# does not depend on the platform's default encoding (names may contain non-ASCII text);
# newline="" is what the csv module expects for files it writes.
with open(
    "../SI_thermoml_pcsaft_params.csv", "w", newline="", encoding="utf-8"
) as f:
    writer = csv.writer(f)
    writer.writerows(pcsaft_params)