## Load libraries and models


First step: Load libraries


In [None]:
import torch
from gnnepcsaft.train.utils import build_train_dataset, build_test_dataset
from gnnepcsaft.train.models import GNNePCSAFTL
from torch_geometric.loader import DataLoader
from rdkit import Chem
from gnnepcsaft.demo.utils import (
    plotdata,
    plotparams,
)

# Show whether torch can see a CUDA device (bare expression -> cell output).
# Note: the checkpoints in the model-loading cell are mapped to "cpu" anyway.
torch.cuda.is_available()

Second step: Load data


In [None]:
# Training dataset built from the "esper" source.
es = build_train_dataset("gnnepcsaft", "esper")
# Second dataset built with `es` as input; unpacked as (validation, train).
# NOTE(review): "tml" presumably refers to ThermoML data — confirm in
# gnnepcsaft.train.utils.build_test_dataset.
tml_val, tml_train = build_test_dataset("gnnepcsaft", es)

# Shuffled mini-batches of 512 graphs for `es`; the tml training split is
# served as one unshuffled batch holding every graph.
es_loader = DataLoader(dataset=es, batch_size=512, shuffle=True)
tml_loader = DataLoader(
    dataset=tml_train,
    batch_size=len(tml_train),
    shuffle=False,
)

Third step: Define and load models

- Checkpoints can be downloaded from [Hugging Face](https://huggingface.co/wildsonbbl/gnnepcsaft); the cell below expects them in `./gnnepcsaft/train/checkpoints/`.


In [None]:
# Restore both pretrained checkpoints with their tensors mapped to the CPU.
model = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/pna_msigmae_1.0-epoch=19999-mape_den.ckpt",
    map_location="cpu",
)
model_assoc = GNNePCSAFTL.load_from_checkpoint(
    "./gnnepcsaft/train/checkpoints/gatv2_assoc_1.0-epoch=16874-mape_den.ckpt",
    map_location="cpu",
)

# Put both models in inference mode. A loop (a statement, not an expression)
# keeps the cell from displaying any output.
for _loaded in (model, model_assoc):
    _loaded.eval()

## Plotting


In [None]:
# Read the molecule as an InChI string, e.g. ethanol:
#   InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3
# Strip surrounding whitespace: a pasted identifier often carries stray
# spaces, which would otherwise become part of the string handed to
# `plotdata` here and to `Chem.MolFromInchi` in the next cell.
inchi = input("InChI: ").strip()

# Display label only; no name lookup is done for the InChI.
molecule_name = "UNKNOWN"
# NOTE(review): the list presumably holds the models whose predictions are
# plotted, and the last argument a reference model — confirm in
# gnnepcsaft.demo.utils.plotdata.
plotdata(
    inchi,
    molecule_name,
    [model_assoc.model, model.model],
    model.model,
)

In [None]:
# Parse the InChI, keeping explicit hydrogens and sanitizing the molecule.
mol = Chem.MolFromInchi(inchi, removeHs=False, sanitize=True)
# MolFromInchi returns None for an unparsable InChI; fail here with a clear
# message instead of an AttributeError on the substructure search below.
if mol is None:
    raise ValueError(f"Could not parse InChI: {inchi!r}")

# Number of substructure matches of the SMILES "O" (a single oxygen atom).
# NOTE(review): the name says "pyrimidine" but the query is a lone O atom,
# so this presumably counts oxygens — confirm which was intended.
n_pyrimidine = len(mol.GetSubstructMatches(Chem.MolFromSmiles("O")))
# Canonical SMILES, including stereochemistry.
print(Chem.MolToSmiles(mol, isomericSmiles=True))

mol  # bare expression: shown as the cell output

In [None]:
# Ethyl esters C_n(=O)OCC, n = 1..49 (n = 1 is ethyl formate).
smiles = [f"{'C' * n_carbons}(=O)OCC" for n_carbons in range(1, 50)]
plotparams(smiles, [model.model], r"$C_n(=O)OCC$")

In [None]:
# Linear alkanes C_n, n = 1..49 (methane upward).
smiles = [length * "C" for length in range(1, 50)]
plotparams(smiles, [model.model], r"$C_n$")

In [None]:
# Linear primary alcohols C_nO, n = 1..49 (n = 1 is methanol, "CO").
smiles = [f"{'C' * n_carbons}O" for n_carbons in range(1, 50)]
plotparams(smiles, [model.model], r"$C_nO$")

## Export models to ONNX


In [None]:
from gnnepcsaft.demo.utils import save_exported_program

example_input = es[0].x, es[0].edge_index, es[0].edge_attr, es[0].batch  # type: ignore
exp_msigmae = save_exported_program(model.model, example_input, "../test_msigmae.onnx")
exp_assoc = save_exported_program(
    model_assoc.model, example_input, "../test_assoc.onnx"
)