In [1]:
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
import pickle
import lmdb
from torch.utils.data import random_split
import multiprocessing as mp
import os
import pickle
from tqdm import tqdm
import glob


class LmdbBaseDataset(Dataset):

    """
    Dataset class to
    1. write Reaction networks objecs to lmdb
    2. load lmdb files
    """

    def __init__(self, config, transform=None):
        super(LmdbBaseDataset, self).__init__()

        self.config = config
        self.path = Path(self.config["src"])

        # Get metadata in case
        # self.metadata_path = self.path.parent / "metadata.npz"
        self.env = self.connect_db(self.path)

        # If "length" encoded as ascii is present, use that
        # If there are additional properties, there must be length.
        length_entry = self.env.begin().get("length".encode("ascii"))
        if length_entry is not None:
            num_entries = pickle.loads(length_entry)
        else:
            # Get the number of stores data from the number of entries
            # in the LMDB
            num_entries = self.env.stat()["entries"]

        self._keys = list(range(num_entries))
        self.num_samples = num_entries

        # Get portion of total dataset
        self.sharded = False
        if "shard" in self.config and "total_shards" in self.config:
            self.sharded = True
            self.indices = range(self.num_samples)
            # split all available indices into 'total_shards' bins
            self.shards = np.array_split(
                self.indices, self.config.get("total_shards", 1)
            )
            # limit each process to see a subset of data based off defined shard
            self.available_indices = self.shards[self.config.get("shard", 0)]
            self.num_samples = len(self.available_indices)

        # TODO
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # if sharding, remap idx to appropriate idx of the sharded set
        if self.sharded:
            idx = self.available_indices[idx]

        #!CHECK, _keys should be less then total numbers of keys as there are more properties.
        datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii"))

        data_object = pickle.loads(datapoint_pickled)

        # TODO
        if self.transform is not None:
            data_object = self.transform(data_object)

        return data_object

    def connect_db(self, lmdb_path=None):
        env = lmdb.open(
            str(lmdb_path),
            subdir=False,
            readonly=False,
            lock=False,
            readahead=True,
            meminit=False,
            max_readers=1,
        )
        return env

    def close_db(self):
        if not self.path.is_file():
            for env in self.envs:
                env.close()
        else:
            self.env.close()

    def get_metadata(self, num_samples=100):
        pass


class LmdbMoleculeDataset(LmdbBaseDataset):
    def __init__(self, config, transform=None):
        super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform)

    @property
    def charges(self):
        charges = self.env.begin().get("charges".encode("ascii"))
        return pickle.loads(charges)

    @property
    def ring_sizes(self):
        ring_sizes = self.env.begin().get("ring_sizes".encode("ascii"))
        return pickle.loads(ring_sizes)

    @property
    def elements(self):
        elements = self.env.begin().get("elements".encode("ascii"))
        return pickle.loads(elements)

    @property
    def feature_info(self):
        feature_info = self.env.begin().get("feature_info".encode("ascii"))
        return pickle.loads(feature_info)


class LmdbReactionDataset(LmdbBaseDataset):
    def __init__(self, config, transform=None):
        super(LmdbReactionDataset, self).__init__(config=config, transform=transform)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/scripts/helpers/lmdb_big/molcule.lmdb"
}

In [3]:
mol = LmdbMoleculeDataset(config=config)
len(mol)

282

In [4]:
mol[1].keys()

dict_keys(['molecule_index', 'molecule_graph', 'molecule_wrapper'])

In [5]:
mol.charges, mol.ring_sizes, mol.elements, mol.feature_info

({0},
 {3, 4, 5, 6, 7},
 {'C', 'H', 'N', 'O'},
 {'feature_size': {'atom': 12, 'bond': 8, 'global': 8},
  'feature_scaler_mean': {'atom': tensor([1.9802, 0.4530, 0.2090, 0.0629, 0.0533, 0.0787, 0.0090, 0.0051, 0.3138,
           0.4485, 0.0892, 0.1485]),
   'bond': tensor([1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.7535]),
   'global': tensor([11.8440, 11.7270, 92.9330,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000])},
  'feature_scaler_std': {'atom': tensor([1.1909, 0.7687, 0.4066, 0.2427, 0.2246, 0.2693, 0.0943, 0.0712, 0.4640,
           0.4973, 0.2851, 0.3556]),
   'bond': tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9680]),
   'global': tensor([ 6.5702,  7.2032, 53.8251,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000])}})

In [6]:
config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/scripts/helpers/lmdb_big/reaction.lmdb"
}

In [7]:
reaction = LmdbReactionDataset(config=config)

In [8]:
len(reaction)

94

In [9]:
reaction[0]["reaction_graph"]
reaction[0]["reaction_feature"]

{'global': tensor([[0., 0., 0., 0., 0., 0., 0., 0.]]),
 'atom': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [10]:
reaction[
    0
].keys()  # here key not euqal to reaction index as reaction data might be fail or shuffled.

dict_keys(['reaction_index', 'reaction_graph', 'reaction_feature', 'reaction_molecule_info', 'label', 'reverse_label', 'extra_info'])

In [11]:
from bondnet.data.reaction_network import ReactionNetworkLMDB

In [12]:
rxn_ntwk = ReactionNetworkLMDB(mol, reaction)
# reaction[0]
rxn_ntwk.subselect_reactions([4])[0][0]["reaction_molecule_info"]["reactants"][
    "reactants"
]

[0, 1]

In [13]:
rxn_ntwk.subselect_reactions([4])[0][0]["reaction_molecule_info"]["reactants"]

{'atom_map': [{0: 0,
   1: 1,
   2: 2,
   3: 3,
   4: 4,
   5: 5,
   6: 6,
   7: 7,
   8: 8,
   9: 9,
   10: 10,
   11: 11,
   12: 12,
   13: 13},
  {0: 16, 1: 14, 2: 15}],
 'bond_map': [{0: 17,
   1: 13,
   2: 5,
   3: 3,
   4: 2,
   5: 11,
   6: 6,
   7: 10,
   8: 14,
   9: 1,
   10: 0,
   11: 4,
   12: 16,
   13: 18},
  {0: 12, 1: 15}],
 'init_reactants': [177, 178],
 'reactants': [0, 1]}

In [14]:
rxn_ntwk.subselect_reactions([4])[0][0]["reaction_molecule_info"]["reactants"][
    "reactants"
] = [0]

In [15]:
rxn_ntwk.subselect_reactions([4])[0][0]["reaction_molecule_info"]["reactants"][
    "reactants"
]

[0, 1]

In [16]:
# check if dictionary is mutable
test = rxn_ntwk.subselect_reactions([4])[0][0]["reaction_molecule_info"]["reactants"]
test[0] = 1

In [17]:
test[0]

1

In [18]:
from copy import deepcopy

test = rxn_ntwk.subselect_reactions([4])[0][0]
test["reaction_molecule_info"]["reactants"]["reactants"] = [0]
test[0] = 1
test_copy = deepcopy(test)
test_copy[0] = 1
print(test["reaction_molecule_info"]["reactants"]["reactants"])

[0]


In [19]:
test["reaction_molecule_info"]["reactants"]["reactants"]

[0]

In [20]:
mol[0]

{'molecule_index': 0,
 'molecule_graph': Graph(num_nodes={'atom': 15, 'bond': 14, 'global': 1},
       num_edges={('atom', 'a2a', 'atom'): 15, ('atom', 'a2b', 'bond'): 28, ('atom', 'a2g', 'global'): 15, ('bond', 'b2a', 'atom'): 28, ('bond', 'b2b', 'bond'): 14, ('bond', 'b2g', 'global'): 14, ('global', 'g2a', 'atom'): 15, ('global', 'g2b', 'bond'): 14, ('global', 'g2g', 'global'): 1},
       metagraph=[('atom', 'atom', 'a2a'), ('atom', 'bond', 'a2b'), ('atom', 'global', 'a2g'), ('bond', 'atom', 'b2a'), ('bond', 'bond', 'b2b'), ('bond', 'global', 'b2g'), ('global', 'atom', 'g2a'), ('global', 'bond', 'g2b'), ('global', 'global', 'g2g')]),
 'molecule_wrapper': Molecule Summary
 Site: C (-0.1124, 1.6954, -0.6250)
 Site: N (-0.0176, 0.2572, -0.7638)
 Site: C (-0.8723, -0.5042, -0.2525)
 Site: C (-2.1147, -0.2901, 0.5344)
 Site: O (-0.6770, -1.8984, -0.5422)
 Site: C (-0.0621, -2.6138, 0.3801)
 Site: O (0.3844, -2.2258, 1.4259)
 Site: C (0.0309, -4.0320, -0.0299)
 Site: N (0.1239, -5.1496, -0

In [21]:
from bondnet.data.dataset import BaseDataset
from bondnet.data.dataset import ReactionNetworkLMDBDataset
from bondnet.data.dataloader import DataLoaderReactionNetworkLMDB, collate_parallel_lmdb


dataset = ReactionNetworkLMDBDataset(rxn_ntwk)

In [31]:
dataloader = DataLoaderReactionNetworkLMDB(
    dataset, batch_size=4, shuffle=True, collate_fn=collate_parallel_lmdb
)

In [32]:
sample = next(iter(dataloader))
print((sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"].keys()))

dict_keys(['atom_map', 'bond_map', 'init_reactants', 'reactants'])


In [33]:
sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"]["reactants"] = [0]
print(sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"]["reactants"])

[0]


In [34]:
sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"]["reactants"]

[0]

In [35]:
from bondnet.model.training_utils import load_model_lightning
from bondnet.test_utils import get_defaults

config = get_defaults()

config = {
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 4,
        "num_workers": 1,
    },
}

dataset_loc = "../../../tests/data/testdata/barrier_100.json"
config = {
    "dataset": {
        "data_dir": dataset_loc,
        "target_var": "ts",
    },
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.2,
        "test_size": 0.2,
        "batch_size": 4,
        "num_workers": 1,
    },
}
config_model = get_defaults()
# update config with model settings
for key, value in config_model["model"].items():
    config["model"][key] = value
for key, value in config_model["model"].items():
    config["model"][key] = value
from bondnet.data.datamodule import BondNetLightningDataModule

dm = BondNetLightningDataModule(config)
# feat_size, feat_name = dm.prepare_data()
# config["model"]["in_feats"] = feat_size
# config["model"]["in_feats"] = feat_size

# config = get_defaults()
config["model"]["in_feats"] = dataset.feature_info["feature_size"]

model = load_model_lightning(config["model"], load_dir="./test_lmdb/")

NB: using GatedGCNConv


In [41]:
import pytorch_lightning as pl
import wandb
import os
from bondnet.data.utils import create_rxn_graph
import dgl
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
nodes = ["atom", "bond", "global"]
for it, (batched_graph, label) in enumerate(dataloader):
    # print(feats)
    # print(label)
    # print(label)
    feats = {nt: batched_graph.nodes[nt].data["feat"] for nt in nodes}
    target = label["value"].view(-1).to(device)
    # norm_atom = label["norm_atom"]
    norm_atom = None
    # norm_bond = label["norm_bond"]
    norm_bond = None
    stdev = torch.tensor([1.0])
    # print(feats.keys())
    # if device is not None:
    # feats = {k: v.to(device) for k, v in feats.items()}
    # target = target.to(device)
    # norm_atom = norm_atom.to(device)
    # norm_bond = norm_bond.to(device)
    # stdev = stdev.to(device)

    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"].keys())
    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_bonds_total"])
    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_atoms_total"])
    reactions = label["reaction"]
    for nt, ft in feats.items():
        batched_graph.nodes[nt].data.update({"ft": ft})

    graphs = dgl.unbatch(batched_graph)
    model(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
        reverse=False,
    )

    for rxn in reactions:
        reactants = [
            graphs[i] for i in rxn["reaction_molecule_info"]["reactants"]["reactants"]
        ]
        products = [
            graphs[i] for i in rxn["reaction_molecule_info"]["products"]["products"]
        ]
        # print(rxn["reaction_molecule_info"]["products"]["products"])
        # print(len(products))

        mappings = rxn["reaction_molecule_info"]["mappings"]
        has_bonds = rxn["reaction_molecule_info"]["has_bonds"]

        g, fts = create_rxn_graph(
            reactants=reactants,
            products=products,
            mappings=mappings,
            device=device,
            has_bonds=has_bonds,
            reverse=False,
            reactant_only=False,
            empty_graph_fts=None,
        )

In [42]:
project_name = "test_multi_gpu"


trainer = pl.Trainer(
    max_epochs=2,
    accelerator="gpu",
    devices=[0],
    accumulate_grad_batches=5,
    enable_progress_bar=True,
    gradient_clip_val=1.0,
    enable_checkpointing=True,
    precision=32,
)

trainer.fit(model, dataloader)



Epoch 0: 100%|██████████| 24/24 [00:01<00:00, 12.76it/s, loss=2.1, v_num=0, train_loss=1.910, train_r2=-1.30, train_l1=1.050, train_mse=1.950]

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['train_loss', 'train_r2', 'train_l1', 'train_mse']. Condition can be set using `monitor` key in lr scheduler dict