In [1]:
import wandb, argparse, torch, json
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    EarlyStopping,
    ModelCheckpoint,
)

from bondnet.data.dataset import ReactionNetworkDatasetGraphs
from bondnet.data.dataloader import DataLoaderReactionNetworkParallel
from bondnet.data.dataset import train_validation_test_split
from bondnet.utils import seed_torch
from bondnet.data.utils import find_rings
from bondnet.model.training_utils import (
    get_grapher,
    LogParameters,
    load_model_lightning,
)


import torch

# Seed via bondnet's helper (imported from bondnet.utils) before any other work.
seed_torch()
torch.set_float32_matmul_precision("high")  # might have to disable on older GPUs

import torch.multiprocessing

# The start method may only be set once per process. Without force=True,
# re-running this cell in a live kernel raises
# "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method("spawn", force=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick CUDA when it is available, otherwise fall back to the CPU, and show
# the resulting device as the cell output.
backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend)
device

device(type='cuda')

In [3]:
# Run configuration. Key order is kept as-is because dict order is preserved
# if the config is ever serialised.
config = {
    "augment": True,
    "batch_size": 4,
    "debug": False,
    "classifier": False,
    "classif_categories": 3,
    "cat_weights": [1.0, 1.0, 1.0],
    "embedding_size": 24,
    "epochs": 100,
    "extra_features": ["bond_length"],
    "extra_info": [],
    "filter_species": [3, 5],
    # fc_* settings
    "fc_activation": "ReLU",
    "fc_batch_norm": True,
    "fc_dropout": 0.2,
    "fc_hidden_size_1": 256,
    "fc_hidden_size_shape": "flat",
    "fc_num_layers": 1,
    # gated_* settings
    "gated_activation": "ReLU",
    "gated_batch_norm": False,
    "gated_dropout": 0.1,
    "gated_graph_norm": False,
    "gated_hidden_size_1": 512,
    "gated_hidden_size_shape": "flat",
    "gated_num_fc_layers": 1,
    "gated_num_layers": 2,
    "gated_residual": True,
    # optimisation / run settings
    "learning_rate": 0.003,
    "precision": 32,
    "loss": "mse",
    "num_lstm_iters": 3,
    "num_lstm_layers": 1,
    "on_gpu": True,
    "restore": False,
    "target_var": "ts",
    "target_var_transfer": "diff",
    "weight_decay": 0.0,
    "max_epochs": 100,
    "max_epochs_transfer": 100,
    "transfer": False,
    "filter_outliers": True,
}

# Path is relative to the notebook's working directory.
dataset_loc = "../../../tests/data/testdata/barrier_100.json"

# Pull the values used by later cells out of the config.
extra_keys = config["extra_features"]
on_gpu = config["on_gpu"]
debug = config["debug"]
precision = config["precision"]

# Only the string forms "16" and "32" are converted to int; any other value
# (including the ints 16/32) is passed through untouched.
if precision in ("16", "32"):
    precision = int(precision)

# CUDA is only probed when the config asks for a GPU.
use_cuda = on_gpu and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [4]:
# Collect the constructor arguments first so the call itself stays a single
# line; the grapher is built from the same extra feature keys that are passed
# to the dataset.
dataset_kwargs = {
    "grapher": get_grapher(extra_keys),
    "file": dataset_loc,
    "target": config["target_var"],
    "classifier": config["classifier"],
    "classif_categories": config["classif_categories"],
    "filter_species": config["filter_species"],
    "filter_outliers": config["filter_outliers"],
    "filter_sparse_rxns": False,
    "debug": debug,
    "extra_keys": extra_keys,
    "extra_info": config["extra_info"],
}
dataset = ReactionNetworkDatasetGraphs(**dataset_kwargs)

reading file from: ../../../tests/data/testdata/barrier_100.json
rxn raw len: 100
Program finished in 16.616420175880194 seconds
.............failures.............
reactions len: 89
valid ind len: 89
bond break fail count: 		0
default fail count: 		11
sdf map fail count: 		0
product bond fail count: 	0
about to group and organize
number of grouped reactions: 89
features: 214
labels: 89
molecules: 214
constructing graphs & features....
number of graphs valid: 214
number of graphs: 214


In [10]:
# Preview only the first two label dicts. Displaying the whole list prints the
# full atom/bond maps of every reaction and floods the notebook output.
dataset.labels[:2]

[{'value': tensor([-0.3627]),
  'value_rev': tensor([-0.4074]),
  'id': ['10527514179105274'],
  'environment': None,
  'atom_map': [[{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}],
   [{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}]],
  'bond_map': [[{0: 0, 1: 3, 2: 1, 3: 4, 4: 2}],
   [{0: 0, 1: 1, 2: 3, 3: 4, 4: 2}]],
  'total_bonds': [[0, 1], [1, 2], [4, 5], [1, 4], [3, 4]],
  'total_atoms': [0, 1, 2, 3, 4, 5],
  'reaction_type': [],
  'extra_info': {},
  'scaler_mean': tensor(0.4038),
  'scaler_stdev': tensor(0.5013)},
 {'value': tensor([-0.5762]),
  'value_rev': tensor([-0.5762]),
  'id': ['680555887768054'],
  'environment': None,
  'atom_map': [[{0: 0,
     1: 1,
     2: 2,
     3: 3,
     4: 4,
     5: 5,
     6: 6,
     7: 7,
     8: 8,
     9: 9,
     10: 10,
     11: 11,
     12: 12}],
   [{0: 0,
     1: 1,
     2: 2,
     3: 3,
     4: 4,
     5: 5,
     6: 6,
     7: 7,
     8: 8,
     9: 9,
     10: 10,
     11: 11,
     12: 12}]],
  'bond_map': [[{0: 10,
     1: 0,
     2: 11,
     3: 8,


In [19]:
# Inspect the dataset's `graphs` attribute; the saved output of this cell
# shows it was `None` when the cell was last run.
print(dataset.graphs)

None


In [15]:
# List of Molecules
# Fresh containers on every run of this cell: three lists filled per molecule
# and three sets that accumulate the distinct values seen.
dgl_graphs, pmg_objects, molecule_ind_list = [], [], []
charge_set, ring_size_set, element_set = set(), set(), set()


def clean(input):
    """Return ``input`` with every digit character removed.

    Used on formula tokens such as ``"Cl2"`` to keep only the element symbol.
    Characters are tested with ``str.isdigit``; everything else is kept in its
    original order.
    """
    return "".join(char for char in input if not char.isdigit())


def clean_op(input):
    """Return only the digit characters of ``input``, in their original order.

    Used on formula tokens such as ``"Cl2"`` to extract the atom count as a
    string. The result is an empty string when ``input`` contains no digits.
    """
    digits = [char for char in input if char.isdigit()]
    return "".join(digits)


for ind, mol_wrapper in enumerate(dataset.reaction_network.molecule_wrapper):
    pmg_mol = mol_wrapper.pymatgen_mol

    # Keep the pymatgen molecule together with its index in the reaction
    # network (this would just be given by the index in HiPRGen anyways).
    pmg_objects.append(pmg_mol)
    molecule_ind_list.append(ind)

    # The formula string is split on whitespace; for each token clean() keeps
    # the non-digit characters and clean_op() the digit characters, which are
    # summed to get the atom count.
    formula = pmg_mol.composition.formula.split()
    elements = [clean(token) for token in formula]
    counts = [int(clean_op(token)) for token in formula]
    atom_num = np.sum(np.array(counts))
    element_set.update(elements)

    charge_set.add(pmg_mol.charge)

    # Record the length of every entry find_rings returns for this molecule
    # (presumably one entry per ring -- confirm against bondnet.data.utils).
    bond_list = [[edge[0], edge[1]] for edge in mol_wrapper.mol_graph.graph.edges]
    cycles = find_rings(atom_num, bond_list, edges=False)
    ring_size_set.update(len(cycle) for cycle in cycles)


# Gather the entries of reaction_network.molecules into dgl_graphs, preserving
# their order (the position is not needed, so no enumerate).
for molecule_in_rxn_network in dataset.reaction_network.molecules:
    dgl_graphs.append(molecule_in_rxn_network)

<class 'dgl.heterograph.DGLGraph'>


In [17]:
# Summary of what the molecule cell collected: (label, value) pairs printed
# one per line with print's default separator.
molecule_summary = [
    ("Number of dgl graphs in dataset: ", len(dgl_graphs)),
    ("Number of pmg objects in dataset: ", len(pmg_objects)),
    ("Number of molecule indices in dataset: ", len(molecule_ind_list)),
    ("Set of charges in dataset: ", charge_set),
    ("Set of ring sizes in dataset: ", ring_size_set),
    ("Set of elements in dataset: ", element_set),
]
for description, item in molecule_summary:
    print(description, item)

Number of dgl graphs in dataset:  214
Number of pmg objects in dataset:  214
Number of molecule indices in dataset:  214
Set of charges in dataset:  {0}
Set of ring sizes in dataset:  {3, 4, 5, 6}
Set of elements in dataset:  {'P', 'H', 'S', 'Cl', 'N', 'O', 'C', 'F'}


In [49]:
# `dgl` is used below but is not imported by any earlier visible cell, so the
# cell fails with NameError on a fresh kernel; import it here.
import dgl

# Batch all molecule graphs, copy each node type's "feat" entry to a second
# node-data key "ft", then split the batch back into individual graphs.
batched_graph = dgl.batch(dgl_graphs)
feats = batched_graph.ndata["feat"]
for nt, ft in feats.items():
    batched_graph.nodes[nt].data.update({"ft": ft})
graphs = dgl.unbatch(batched_graph)

# Reactions


In [53]:
from bondnet.data.utils import create_rxn_graph


# Per-reaction containers, re-created on every run of this cell so the loop
# below never appends to stale results.
extra_info = []
reaction_molecule_info = []
label_list = []
reverse_list = []
reaction_molecule_info_list = []  # not filled in this cell; kept so the name stays defined
empty_reaction_graphs = []
empty_reaction_fts = []

for ind, rxn in enumerate(dataset.reaction_network.reactions):
    # Read extra_info from the reaction being iterated. The earlier code used
    # `reaction_in_rxn_network`, which is not defined in the visible notebook:
    # on a fresh kernel that is a NameError, and with stale kernel state it
    # would append the same object for every reaction.
    extra_info.append(rxn.extra_info)

    # Labels are looked up by position, so dataset.labels must be in the same
    # order as the reactions.
    label_list.append(dataset.labels[ind]["value"])
    reverse_list.append(dataset.labels[ind]["value_rev"])

    mappings = {
        "bond_map": rxn.bond_mapping,
        "atom_map": rxn.atom_mapping,
        "total_bonds": rxn.total_bonds,
        "total_atoms": rxn.total_atoms,
        "num_bonds_total": rxn.num_bonds_total,
        "num_atoms_total": rxn.num_atoms_total,
    }

    # Index 0 of the atom/bond mappings is used for reactants, index 1 for
    # products.
    molecule_info_temp = {
        "reactants": {
            "molecule_index": rxn.reactants,
            "atom_map": rxn.atom_mapping[0],
            "bond_map": rxn.bond_mapping[0],
        },
        "products": {
            "molecule_index": rxn.products,
            "atom_map": rxn.atom_mapping[1],
            "bond_map": rxn.bond_mapping[1],
        },
    }

    reaction_molecule_info.append(molecule_info_temp)

    #### taken from create_rxn_graph in bondnet
    reactants = [graphs[i] for i in rxn.reactants]
    products = [graphs[i] for i in rxn.products]

    # One flag per molecule: True when its bond map is non-empty.
    has_bonds = {
        "reactants": [len(mp) > 0 for mp in rxn.bond_mapping[0]],
        "products": [len(mp) > 0 for mp in rxn.bond_mapping[1]],
    }
    # Only reported, not raised: the graph is still built when the number of
    # bond maps and the number of molecule graphs disagree.
    if len(has_bonds["reactants"]) != len(reactants) or len(
        has_bonds["products"]
    ) != len(products):
        print("unequal mapping & graph len")

    empty_graph, empty_fts = create_rxn_graph(
        reactants=reactants,
        products=products,
        mappings=mappings,
        device=None,
        has_bonds=has_bonds,
        reverse=False,
    )

    empty_reaction_graphs.append(empty_graph)
    empty_reaction_fts.append(empty_fts)

In [54]:
# Length of every per-reaction list built by the reaction cell; all of them
# should report the same number.
reaction_summary = [
    ("Number of reaction molecule info: ", reaction_molecule_info),
    ("Number of extra info: ", extra_info),
    ("Number of labels: ", label_list),
    ("Number of reverse labels: ", reverse_list),
    ("Number of empty reaction graphs: ", empty_reaction_graphs),
    ("Number of empty reaction fts: ", empty_reaction_fts),
]
for description, collected in reaction_summary:
    print(description, len(collected))

Number of reaction molecule info:  89
Number of extra info:  89
Number of labels:  89
Number of reverse labels:  89
Number of empty reaction graphs:  89
Number of empty reaction fts:  89
