In [None]:
import os.path as osp
import torch, jax
from torch_geometric.loader import DataLoader
from data.graphdataset import ThermoMLDataset, ramirez
from train.train import create_model

In [None]:
# Build a lookup from InChI string to the parameter row stored in the
# ramirez2022 dataset. Each batch's flat `para` tensor is viewed as rows of
# three values, one row per InChI in that batch.
train_dataset = ramirez("./data/ramirez2022")
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=1)
ra_data = {}
for graph in train_loader:
    para_rows = graph.para.view(-1, 3)
    for inchi, para in zip(graph.InChI, para_rows):
        ra_data[inchi] = para

In [None]:

# All tensors and the model in this notebook are placed on the CPU.
device = torch.device("cpu")

# ThermoML dataset rooted at data/thermoml, served one graph per iteration in
# a fixed (unshuffled) order.
path = osp.join("data", "thermoml")
dataset = ThermoMLDataset(path)
loader = DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
# Build the configuration object returned by configs.default.get_config();
# individual fields are overridden on it before the model is created.
from configs.default import get_config
config = get_config()

In [None]:
# Overrides applied on top of the configuration returned by get_config().
# They are set one attribute at a time, in the order listed here.
config_overrides = {
    "num_train_steps": 1000000,
    "log_every_steps": 100,
    "num_para": 3,
    "checkpoint_every_steps": 2000,
    "learning_rate": 0.001,
    "patience": 10,
    "warmup_steps": 500,
    "optimizer": "adam",
    "batch_size": 128,
    "propagation_depth": 3,
    "hidden_dim": 64,
    "num_mlp_layers": 2,
    "pre_layers": 2,
    "post_layers": 1,
}
for field_name, field_value in config_overrides.items():
    setattr(config, field_name, field_value)

In [None]:
# Floating-point dtype the model is cast to when it is moved onto `device`.
model_dtype = torch.float32

In [None]:
# Build the network from the configuration, then move it to the target
# device and cast it to the chosen dtype.
model = create_model(config)
model = model.to(device=device, dtype=model_dtype)

In [None]:
# Restore the model weights from the last training checkpoint, if one exists.
ckp_path = "./training/last_checkpoint.pth"
if osp.exists(ckp_path):
    # map_location remaps every stored tensor onto `device`, so a checkpoint
    # written from a GPU run still loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
else:
    # Without a checkpoint the model keeps whatever weights create_model gave
    # it, so the parameters predicted and saved further down are not the
    # output of a trained network. Make that visible instead of staying silent.
    print(f"WARNING: no checkpoint found at {ckp_path!r}; model weights were not restored.")


In [None]:
# Predict a parameter vector for every graph in `loader` and store it under
# the graph's InChI string.
#
# eval() switches layers such as dropout / batch-norm to inference behaviour,
# and no_grad() stops autograd from recording a graph for each forward pass;
# neither is needed for pure prediction.
model.eval()
para = {}
with torch.no_grad():
    for graph in loader:
        graph = graph.to(device)
        parameters = model(graph)
        # The loader yields one graph per batch, so row 0 is the only
        # prediction in the output.
        parameters = parameters.tolist()[0]
        # The second tuple element is always +inf here — presumably a
        # placeholder score/loss filled in by a later stage; TODO confirm
        # against the consumer of the saved dictionary.
        para[graph.InChI[0]] = (parameters, float("inf"))

In [None]:
# Number of distinct InChI keys that received a predicted parameter entry.
len(para)

In [None]:
import pickle

# Serialize the `para` dictionary to the processed-data folder. Mode "wb"
# creates the file or overwrites an existing one.
with open("./data/thermoml/processed/para3.pkl", "wb") as handle:
    pickle.dump(para, handle)

In [None]:
# Tally the graphs in `loader` by which of the `rho` / `vp` tensors hold any
# non-zero entry (presumably an all-zero tensor marks a missing property —
# TODO confirm against the dataset builder).
n_rho_vp = n_rho = n_vp = ntrain = 0
size_vp = size_rho = 0
for graph in loader:
    # rho is entirely zero: counted under n_vp.
    if (graph.rho == 0).all():
        n_vp += 1
        continue
    # vp is entirely zero: counted under n_rho.
    if (graph.vp == 0).all():
        n_rho += 1
        continue
    # Both tensors contain at least one non-zero value.
    n_rho_vp += 1
    # Of those, also count graphs whose InChI is not a key of ra_data.
    ntrain += graph.InChI[0] not in ra_data

In [None]:
# Display the counters: graphs with non-zero rho and vp, graphs whose vp is
# all zero, graphs whose rho is all zero, and the rho+vp graphs whose InChI is
# not a key of ra_data.
n_rho_vp, n_rho, n_vp, ntrain