In [1]:
import torch
from torch.utils.data import Dataset

import numpy as np
import lmdb
import pickle
from pathlib import Path
from torch.utils.data import random_split
import multiprocessing as mp
import os
import pickle5 as pickle
from tqdm import tqdm
import glob
from copy import deepcopy


import dgl
from dgl import heterograph
from dgl import DGLGraph
import pytorch_lightning as pl


from bondnet.data.reaction_network import ReactionNetworkLMDB
from bondnet.data.utils import construct_rxn_graph_empty, create_rxn_graph
from bondnet.model.training_utils import load_model_lightning
from bondnet.test_utils import get_defaults
from bondnet.data.dataset import ReactionNetworkLMDBDataset
from bondnet.data.dataloader import DataLoaderReactionNetworkLMDB, collate_parallel_lmdb

class LmdbBaseDataset(Dataset):

    """
    Dataset class to
    1. write Reaction networks objecs to lmdb
    2. load lmdb files
    """

    def __init__(self, config, transform=None):
        super(LmdbBaseDataset, self).__init__()

        self.config = config
        self.path = Path(self.config["src"])

        # Get metadata in case
        # self.metadata_path = self.path.parent / "metadata.npz"
        self.env = self.connect_db(self.path)

        # If "length" encoded as ascii is present, use that
        # If there are additional properties, there must be length.
        length_entry = self.env.begin().get("length".encode("ascii"))
        if length_entry is not None:
            num_entries = pickle.loads(length_entry)
        else:
            # Get the number of stores data from the number of entries
            # in the LMDB
            num_entries = self.env.stat()["entries"]

        self._keys = list(range(num_entries))
        self.num_samples = num_entries

        # Get portion of total dataset
        self.sharded = False
        if "shard" in self.config and "total_shards" in self.config:
            self.sharded = True
            self.indices = range(self.num_samples)
            # split all available indices into 'total_shards' bins
            self.shards = np.array_split(
                self.indices, self.config.get("total_shards", 1)
            )
            # limit each process to see a subset of data based off defined shard
            self.available_indices = self.shards[self.config.get("shard", 0)]
            self.num_samples = len(self.available_indices)

        # TODO
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # if sharding, remap idx to appropriate idx of the sharded set
        if self.sharded:
            idx = self.available_indices[idx]

        #!CHECK, _keys should be less then total numbers of keys as there are more properties.
        datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii"))

        data_object = pickle.loads(datapoint_pickled)

        # TODO
        if self.transform is not None:
            data_object = self.transform(data_object)

        return data_object

    def connect_db(self, lmdb_path=None):
        env = lmdb.open(
            str(lmdb_path),
            subdir=False,
            readonly=False,
            lock=False,
            readahead=True,
            meminit=False,
            max_readers=1,
        )
        return env

    def close_db(self):
        if not self.path.is_file():
            for env in self.envs:
                env.close()
        else:
            self.env.close()

    def get_metadata(self, num_samples=100):
        pass

class LmdbMoleculeDataset(LmdbBaseDataset):
    def __init__(self, config, transform=None):
        super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform)

    @property
    def charges(self):
        charges = self.env.begin().get("charges".encode("ascii"))
        return pickle.loads(charges)

    @property
    def ring_sizes(self):
        ring_sizes = self.env.begin().get("ring_sizes".encode("ascii"))
        return pickle.loads(ring_sizes)

    @property
    def elements(self):
        elements = self.env.begin().get("elements".encode("ascii"))
        return pickle.loads(elements)

    @property
    def feature_info(self):
        feature_info = self.env.begin().get("feature_info".encode("ascii"))
        return pickle.loads(feature_info)

class LmdbReactionDataset(LmdbBaseDataset):
    def __init__(self, config, transform=None):
        super(LmdbReactionDataset, self).__init__(config=config, transform=transform)

    @property
    def dtype(self):
        dtype = self.env.begin().get("dtype".encode("ascii"))
        return  pickle.loads(dtype)
            
    @property
    def feature_size(self):
        feature_size = self.env.begin().get("feature_size".encode("ascii"))
        return pickle.loads(feature_size)

    @property
    def feature_name(self):
        feature_name = self.env.begin().get("feature_name".encode("ascii"))
        return pickle.loads(feature_name)
    
    @property
    def mean(self):
        mean = self.env.begin().get("mean".encode("ascii"))
        return pickle.loads(mean)
    
    @property
    def std(self):
        std = self.env.begin().get("std".encode("ascii"))
        return pickle.loads(std)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/dataset/lmdb_dev/mol.lmdb"
}
mol = LmdbMoleculeDataset(config=config)

print(
    mol.charges, 
    mol.ring_sizes, 
    mol.elements, 
    mol.feature_info
)

config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/dataset/lmdb_dev/reaction.lmdb"
}
reaction = LmdbReactionDataset(config=config)
print(len(reaction))

rxn_ntwk = ReactionNetworkLMDB(mol, reaction)

{0, 1, 2, -1, -2} {6} {'H', 'F', 'O', 'S', 'C'} {}
2340


In [3]:
reaction.feature_size

{'atom': 13, 'bond': 7, 'global': 8}

In [4]:

test = rxn_ntwk.subselect_reactions([4])[0][0]
test["reaction_molecule_info"]["reactants"]["reactants"] = [0]
test[0] = 1
test_copy = deepcopy(test)
test_copy[0] = 1
print(test["reaction_molecule_info"]["reactants"]["reactants"])

[0]


In [5]:
dataset = ReactionNetworkLMDBDataset(rxn_ntwk)

In [12]:
dataloader = DataLoaderReactionNetworkLMDB(
    dataset, batch_size=100, shuffle=True, collate_fn=collate_parallel_lmdb
)

In [7]:
sample = next(iter(dataloader))
print((sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"].keys()))

dict_keys(['reactants', 'atom_map', 'bond_map', 'init_reactants'])


In [8]:


config = get_defaults()

config = {
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 4,
        "num_workers": 1,
    },
}

dataset_loc = "../../../tests/data/testdata/barrier_100.json"
config = {
    "dataset": {
        "data_dir": dataset_loc,
        "target_var": "ts",
    },
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.2,
        "test_size": 0.2,
        "batch_size": 4,
        "num_workers": 1,
    },
}
config_model = get_defaults()
# update config with model settings
for key, value in config_model["model"].items():
    config["model"][key] = value
for key, value in config_model["model"].items():
    config["model"][key] = value
    
#from bondnet.data.datamodule import BondNetLightningDataModule
#dm = BondNetLightningDataModule(config)
# feat_size, feat_name = dm.prepare_data()
# config["model"]["in_feats"] = feat_size
# config["model"]["in_feats"] = feat_size
# config = get_defaults()
#config["model"]["in_feats"] = dataset.feature_info["feature_size"]

config["model"]["in_feats"] = reaction.feature_size
model = load_model_lightning(config["model"], load_dir="./test_lmdb/")

NB: using GatedGCNConv


In [10]:

device = "cuda" if torch.cuda.is_available() else "cpu"
nodes = ["atom", "bond", "global"]
for it, (batched_graph, label) in enumerate(dataloader):
    feats = {nt: batched_graph.nodes[nt].data["feat"] for nt in nodes}
    target = label["value"].view(-1).to(device)
    norm_atom = None
    norm_bond = None
    stdev = torch.tensor([1.0])
    # print(feats.keys())
    # if device is not None:
    # feats = {k: v.to(device) for k, v in feats.items()}
    # target = target.to(device)
    # norm_atom = norm_atom.to(device)
    # norm_bond = norm_bond.to(device)
    # stdev = stdev.to(device)

    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"].keys())
    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_bonds_total"])
    # print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_atoms_total"])
    
    reactions = label["reaction"]
    for nt, ft in feats.items():
        batched_graph.nodes[nt].data.update({"ft": ft})

    graphs = dgl.unbatch(batched_graph)
    reaction[0]
    #print(reaction[0]["mappings"])
    model(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
        reverse=False,
    )

    for rxn in reactions:
        reactants = [
            graphs[i] for i in rxn["reaction_molecule_info"]["reactants"]["reactants"]
        ]
        products = [
            graphs[i] for i in rxn["reaction_molecule_info"]["products"]["products"]
        ]
        # print(rxn["reaction_molecule_info"]["products"]["products"])
        # print(len(products))

        
        mappings = rxn["mappings"]
        #print(mappings)
        has_bonds = rxn["has_bonds"]
        
        
        #reactant_atom_map = rxn["reaction_molecule_info"]["reactants"]["atom_map"]
        #product_atom_map = rxn["reaction_molecule_info"]["products"]["atom_map"]
        #reactant_bond_map = rxn["reaction_molecule_info"]["reactants"]["bond_map"]
        #product_bond_map = rxn["reaction_molecule_info"]["products"]["bond_map"]
        

        #mappings = {"bond_map": None, "atom_map": None}
        #has_bonds = False
        g, fts = create_rxn_graph(
            reactants=reactants,
            products=products,
            mappings=mappings,
            device=device,
            has_bonds=has_bonds,
            reverse=False,
            reactant_only=False,
            empty_graph_fts=None,
        )

KeyboardInterrupt: 

In [13]:
project_name = "test_multi_gpu"


trainer = pl.Trainer(
    max_epochs=2,
    accelerator="gpu",
    devices=[0],
    accumulate_grad_batches=5,
    enable_progress_bar=True,
    gradient_clip_val=1.0,
    enable_checkpointing=True,
    precision=32,
)

trainer.fit(model, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type              | Params
-------------------------------------------------------
0  | embedding       | UnifySize         | 672   
1  | gated_layers    | ModuleList        | 3.0 M 
2  | readout_layer   | Set2SetThenCat    | 6.3 M 
3  | fc_layers       | ModuleList        | 656 K 
4  | loss            | MeanSquaredError  | 0     
5  | train_r2        | R2Score           | 0     
6  | train_torch_l1  | MeanAbsoluteError

Epoch 0:  66%|██████▋   | 389/585 [00:38<00:19, 10.19it/s, loss=4.78, v_num=6]
Epoch 0: 100%|██████████| 24/24 [00:16<00:00,  1.44it/s, loss=3.11, v_num=7, train_loss=3.300]

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['train_loss', 'train_r2', 'train_l1', 'train_mse']. Condition can be set using `monitor` key in lr scheduler dict