## Load libraries and model


- First step: Load libraries


In [None]:
import torch
from gnnepcsaft.train.utils import build_train_dataset, build_test_dataset
from gnnepcsaft.train.models import GNNePCSAFTL, HabitchNNL
from torch_geometric.loader import DataLoader
from rdkit import Chem
from gnnepcsaft.demo.utils import (
    plotdata,
    plotparams,
)
import xgboost as xgb
import joblib


# Show, as the cell output, whether PyTorch can see a usable CUDA device.
torch.cuda.is_available()

- Second step: Load data


In [None]:
# Both dataset builders receive "gnnepcsaft" as their first argument
# (presumably the data root directory — confirm in gnnepcsaft.train.utils).
workdir = "gnnepcsaft"

# `es` is the training set built from the "esper" source; build_test_dataset
# takes it as input and returns two splits, unpacked here as (tml_val, tml_train).
es = build_train_dataset(workdir, "esper")
tml_val, tml_train = build_test_dataset(workdir, es)

# `es` is served in shuffled mini-batches of 512; `tml_train` is served as a
# single, unshuffled batch that holds the whole split.
es_loader = DataLoader(es, shuffle=True, batch_size=512)
tml_loader = DataLoader(tml_train, shuffle=False, batch_size=len(tml_train))

- Third step: Define and load models

  - You can get one checkpoint from [Hugging Face](https://huggingface.co/wildsonbbl/gnnepcsaft).


In [None]:
# Every load_from_checkpoint call below gets "cpu" as its second positional
# argument (presumably the map_location — confirm against the method signature).
load_device = "cpu"

# Two GNNePCSAFTL checkpoints (the base model and the "assoc" one) and one
# HabitchNNL checkpoint, each switched to eval mode right after loading.
model = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/model-hlrn7lqv.ckpt", load_device
).eval()
model_assoc = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/assoc_model-j7isfrga.ckpt", load_device
).eval()
habitch_msigmae = HabitchNNL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/model-u62sbl40.ckpt", load_device
).eval()

# Models persisted in other formats: a joblib file and an XGBoost JSON booster.
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary code;
# only point it at files from a trusted source.
rf_msigmae = joblib.load("./gnnepcsaft/train/checkpoints/rf_model.joblib")
xgb_msigmae = xgb.Booster()
xgb_msigmae.load_model("./gnnepcsaft/train/checkpoints/xgb_model.json")

## Plotting


- Plot to check performance on ThermoML Archive data


In [None]:
# Ethanol (C2H6O), identified by its standard InChI.
inchi = "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"
molecule_name = "UNKNOWN"

# plotdata receives a list of models followed by one single model; their exact
# roles are not visible here — see gnnepcsaft.demo.utils.plotdata to confirm.
compared_models = [model_assoc.model, model.model, habitch_msigmae.model]
plotdata(inchi, molecule_name, compared_models, model.model)

In [None]:
# Ion pair: a C8H15N2 cation (+1) with a C2F6NO4S2 anion (-1), given as one InChI.
inchi = "InChI=1S/C8H15N2.C2F6NO4S2/c1-3-4-5-10-7-6-9(2)8-10;3-1(4,5)14(10,11)9-15(12,13)2(6,7)8/h6-8H,3-5H2,1-2H3;/q+1;-1"
molecule_name = "bmin-tf2n"

# plotdata receives a list of models followed by one single model; their exact
# roles are not visible here — see gnnepcsaft.demo.utils.plotdata to confirm.
compared_models = [model_assoc.model, habitch_msigmae.model]
plotdata(inchi, molecule_name, compared_models, model.model)

In [None]:
# Ion pair: a C8H15N2 cation (+1) with a BF4 anion (-1), given as one InChI.
inchi = "InChI=1S/C8H15N2.BF4/c1-3-4-5-10-7-6-9(2)8-10;2-1(3,4)5/h6-8H,3-5H2,1-2H3;/q+1;-1"
molecule_name = "bmin-bf4"

# plotdata receives a list of models followed by one single model; their exact
# roles are not visible here — see gnnepcsaft.demo.utils.plotdata to confirm.
compared_models = [model_assoc.model, habitch_msigmae.model]
plotdata(inchi, molecule_name, compared_models, model.model)

In [None]:
# Ion pair: a C6H11N2 cation (+1) with a C2F6NO4S2 anion (-1), given as one InChI.
# `inchi` is also read by the structure-check cell further down.
inchi = "InChI=1S/C6H11N2.C2F6NO4S2/c1-3-8-5-4-7(2)6-8;3-1(4,5)14(10,11)9-15(12,13)2(6,7)8/h4-6H,3H2,1-2H3;/q+1;-1"
molecule_name = "emin-tf2n"

# plotdata receives a list of models followed by one single model; their exact
# roles are not visible here — see gnnepcsaft.demo.utils.plotdata to confirm.
compared_models = [model_assoc.model, habitch_msigmae.model]
plotdata(inchi, molecule_name, compared_models, model.model)

- Check mol structure


In [None]:
# Parse whichever InChI was assigned last, sanitizing it and keeping hydrogens
# (removeHs=False).
mol = Chem.MolFromInchi(inchi, sanitize=True, removeHs=False)

# Count the substructure matches of the pattern built from the SMILES "O".
# NOTE(review): the name suggests hydroxyl groups, but this pattern presumably
# matches every oxygen atom — confirm the intent. The value is not read by any
# other cell visible in this notebook.
oxygen_pattern = Chem.MolFromSmiles("O")
n_oh = len(mol.GetSubstructMatches(oxygen_pattern))

# Print the SMILES form of the parsed molecule.
print(Chem.MolToSmiles(mol, isomericSmiles=True))

# Last expression of the cell: the molecule itself is rendered as the output.
mol

- Plot parameter trends against chain length


In [None]:
# Exclusive upper bound for the number of prepended carbons: i runs from 1 to n - 1.
n = 15

# One homologous series per entry: "C" * i prepends i carbon atoms to a fixed tail.
list_smiles = [
    ["C" * i + "C(=O)C" for i in range(1, n)],  # ketones
    ["C" * i + "C(=O)OCC" for i in range(1, n)],  # ethyl esters
    ["C" * i + "CC1=CC=CC=C1" for i in range(1, n)],  # alkylbenzenes
    # The tail must not end in whitespace: a trailing blank is not part of the
    # SMILES string and would be carried into every generated entry.
    ["C" * i + "C[N+](=O)[O-]" for i in range(1, n)],  # 1-nitroalkanes
]

# Axis labels (LaTeX math text), listed in the same order as the series above.
list_xlabel = [
    r"$[C]_n(=O)C$",
    r"$[C]_n(=O)OCC$",
    r"$[C]_nC1=CC=CC=C1$",
    r"$[C]_nC[N+](=O)[O-]$",
]

# Pass the labels by keyword, as the other plotparams call in this notebook does,
# so the call does not depend on the position of the parameter.
fig, axs = plotparams(list_smiles, [model.model], list_xlabel=list_xlabel)

In [None]:
# Exclusive upper bound for the number of prepended carbons: i runs from 1 to n - 1.
n = 10

# Fixed tails appended to a chain of i carbon atoms, one homologous series each.
tails = [
    "C",  # alkanes
    "CO",  # alcohols
    "C(=O)O",  # acids
    "OC",  # ethers
]
list_smiles = [["C" * i + tail for i in range(1, n)] for tail in tails]

# Axis labels (LaTeX math text), listed in the same order as `tails`.
list_xlabel = [r"$C_n$", r"$C_nO$", r"$C_nC(=O)O$", r"$C_nOC$"]

fig, axs = plotparams(list_smiles, [model.model], list_xlabel=list_xlabel)

## ONNX save


- Save the models in ONNX format


In [None]:
from gnnepcsaft.demo.utils import save_exported_program

# Fetch the first training sample once: the four tensors of the example input
# then come from the same object, and the dataset is indexed a single time
# instead of four.
sample = es[0]
example_input = (sample.x, sample.edge_index, sample.edge_attr, sample.batch)  # type: ignore

# Export both models with the same example input; the target paths are relative
# to the directory the notebook is run from.
exp_msigmae = save_exported_program(model.model, example_input, "../msigmae_7.onnx")
exp_assoc = save_exported_program(model_assoc.model, example_input, "../assoc_8.onnx")