In [1]:
import os
import numpy as np
import lmdb
import pickle
from pathlib import Path

from copy import deepcopy


import dgl
from dgl import heterograph
from dgl import DGLGraph

import torch
from torch.utils.data import Dataset
import pytorch_lightning as pl

from bondnet.data.reaction_network import ReactionNetworkLMDB
from bondnet.data.utils import construct_rxn_graph_empty, create_rxn_graph
from bondnet.model.training_utils import load_model_lightning
from bondnet.test_utils import get_defaults
from bondnet.data.dataset import ReactionNetworkLMDBDataset
from bondnet.data.dataloader import DataLoaderReactionNetworkLMDB, collate_parallel_lmdb
from bondnet.data.lmdb import LmdbMoleculeDataset, LmdbReactionDataset, TransformMol

In [2]:
env = lmdb.open(
            str("/home/santiagovargas/dev/bondnet/bondnet/scripts/helpers/test_rapter/reaction.lmdb"),
            subdir=False,
            readonly=False,
            lock=False,
            readahead=True,
            meminit=False,
            max_readers=1,
    )

length_entry = env.begin().get("length".encode("ascii"))
            

In [3]:
length_entry = env.begin().get("length".encode("ascii"))
num_entries = pickle.loads(length_entry)
num_entries



with env.begin() as txn:
   myList = [ key for key, _ in txn.cursor() ]
   print(len(myList))



1197


In [4]:
config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/scripts/helpers/test_rapter/molecule.lmdb"
}
mol = LmdbMoleculeDataset(config=config, transform=TransformMol)

config = {
    "src": "/home/santiagovargas/dev/bondnet/bondnet/scripts/helpers/test_rapter/reaction.lmdb"
}
reaction = LmdbReactionDataset(config=config)


In [5]:
rxn_ntwk = ReactionNetworkLMDB(mol, reaction)

In [6]:
rxn_ntwk.reactions.feature_size

{'atom': 20, 'bond': 8, 'global': 3}

In [7]:
rxn_ntwk.reactions.feature_name["bond"]

['metal bond',
 'ring inclusion',
 'ring size_3',
 'ring size_4',
 'ring size_5',
 'ring size_6',
 'ring size_7',
 'bond_length']

In [8]:
#print features
[print(rxn_ntwk.molecules.__getitem__(i)["molecule_graph"].ndata["feat"]["global"][0][-5:]) for i in range(10)]

tensor([-0.3089, -0.3642, -0.2198])
tensor([-1.1864, -1.1951, -1.3521])
tensor([-0.4551, -0.5027, -0.5119])
tensor([-1.0402, -1.0566, -1.0600])
tensor([-0.0163, -0.0872, -0.3484])
tensor([-0.0163, -0.0872, -0.3484])
tensor([-1.3327, -1.3336, -1.4630])
tensor([-1.1864, -1.1951, -1.0236])
tensor([-0.8939, -0.9181, -0.8785])
tensor([-0.3089, -0.2257,  0.1093])


[None, None, None, None, None, None, None, None, None, None]

In [9]:
rxn_ntwk.molecules.__getitem__(0)["molecule_graph"]

Graph(num_nodes={'atom': 9, 'bond': 8, 'global': 1},
      num_edges={('atom', 'a2a', 'atom'): 9, ('atom', 'a2b', 'bond'): 16, ('atom', 'a2g', 'global'): 9, ('bond', 'b2a', 'atom'): 16, ('bond', 'b2b', 'bond'): 8, ('bond', 'b2g', 'global'): 8, ('global', 'g2a', 'atom'): 9, ('global', 'g2b', 'bond'): 8, ('global', 'g2g', 'global'): 1},
      metagraph=[('atom', 'atom', 'a2a'), ('atom', 'bond', 'a2b'), ('atom', 'global', 'a2g'), ('bond', 'atom', 'b2a'), ('bond', 'bond', 'b2b'), ('bond', 'global', 'b2g'), ('global', 'atom', 'g2a'), ('global', 'bond', 'g2b'), ('global', 'global', 'g2g')])

In [10]:
[print(rxn_ntwk.molecules.__getitem__(i)["molecule_wrapper"].charge) for i in range(100)]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [11]:

test = rxn_ntwk.subselect_reactions([4])[0][0]
test["reaction_molecule_info"]["reactants"]["reactants"] = [0]
test[0] = 1
test_copy = deepcopy(test)
test_copy[0] = 1
print(test["reaction_molecule_info"]["reactants"]["reactants"])

[0]


In [12]:
rxn_ntwk.reactions[0]["reaction_molecule_info"]

{'reactants': {'init_reactants': [0, 1], 'has_bonds': [True, True]},
 'products': {'init_products': [2, 3], 'has_bonds': [True, True]},
 'mappings': {'bond_map': [[{0: 0, 1: 10, 2: 3, 3: 9, 4: 7, 5: 8, 6: 11, 7: 4},
    {0: 6, 1: 1}],
   [{0: 5, 1: 3, 2: 9, 3: 7, 4: 11, 5: 8, 6: 4}, {0: 2, 1: 6, 2: 1}]],
  'atom_map': [[{0: 0, 1: 1, 2: 5, 3: 6, 4: 7, 5: 8, 6: 9, 7: 10, 8: 11},
    {0: 2, 1: 3, 2: 4}],
   [{0: 0, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, 7: 11},
    {0: 1, 1: 2, 2: 3, 3: 4}]],
  'total_bonds': [[0, 1],
   [2, 4],
   [1, 2],
   [5, 8],
   [8, 10],
   [0, 6],
   [2, 3],
   [6, 7],
   [8, 9],
   [5, 6],
   [1, 6],
   [8, 11]],
  'total_atoms': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
  'num_bonds_total': 12,
  'num_atoms_total': 12}}

In [13]:
rxn_ntwk.reactions[0]

{'reaction_index': 156621542715663,
 'reaction_molecule_info': {'reactants': {'init_reactants': [0, 1],
   'has_bonds': [True, True]},
  'products': {'init_products': [2, 3], 'has_bonds': [True, True]},
  'mappings': {'bond_map': [[{0: 0,
      1: 10,
      2: 3,
      3: 9,
      4: 7,
      5: 8,
      6: 11,
      7: 4},
     {0: 6, 1: 1}],
    [{0: 5, 1: 3, 2: 9, 3: 7, 4: 11, 5: 8, 6: 4}, {0: 2, 1: 6, 2: 1}]],
   'atom_map': [[{0: 0, 1: 1, 2: 5, 3: 6, 4: 7, 5: 8, 6: 9, 7: 10, 8: 11},
     {0: 2, 1: 3, 2: 4}],
    [{0: 0, 1: 5, 2: 6, 3: 7, 4: 8, 5: 9, 6: 10, 7: 11},
     {0: 1, 1: 2, 2: 3, 3: 4}]],
   'total_bonds': [[0, 1],
    [2, 4],
    [1, 2],
    [5, 8],
    [8, 10],
    [0, 6],
    [2, 3],
    [6, 7],
    [8, 9],
    [5, 6],
    [1, 6],
    [8, 11]],
   'total_atoms': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
   'num_bonds_total': 12,
   'num_atoms_total': 12}},
 'label': tensor([0.2825]),
 'reverse_label': tensor([3.3280]),
 'extra_info': [],
 'mappings': {'bond_map': [[{0: 0,

In [14]:
dataset = ReactionNetworkLMDBDataset(rxn_ntwk)

In [15]:
dataloader = DataLoaderReactionNetworkLMDB(
    dataset, batch_size=100, shuffle=True, collate_fn=collate_parallel_lmdb
)

In [16]:
sample = next(iter(dataloader))
print((sample[1]["reaction"][0]["reaction_molecule_info"]["reactants"].keys()))

dict_keys(['init_reactants', 'has_bonds', 'reactants'])


In [17]:


config = get_defaults()

config = {
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 4,
        "num_workers": 1,
    },
}

dataset_loc = "../../../tests/data/testdata/barrier_100.json"
config = {
    "dataset": {
        "data_dir": dataset_loc,
        "target_var": "ts",
    },
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.2,
        "test_size": 0.2,
        "batch_size": 4,
        "num_workers": 1,
    },
}
config_model = get_defaults()
# update config with model settings
for key, value in config_model["model"].items():
    config["model"][key] = value
for key, value in config_model["model"].items():
    config["model"][key] = value
    
#from bondnet.data.datamodule import BondNetLightningDataModule
#dm = BondNetLightningDataModule(config)
# feat_size, feat_name = dm.prepare_data()
# config["model"]["in_feats"] = feat_size
# config["model"]["in_feats"] = feat_size
# config = get_defaults()
#config["model"]["in_feats"] = dataset.feature_info["feature_size"]
#reaction = dataset.reaction_network.reactions[0]
config["model"]["in_feats"] = reaction.feature_size
model = load_model_lightning(config["model"], load_dir="./test_lmdb/")

NB: using GatedGCNConv
NB: using Set2SetThenCat
:::NO INITIALIZER USED:::


In [18]:

device = "cuda" if torch.cuda.is_available() else "cpu"
nodes = ["atom", "bond", "global"]
for it, (batched_graph, label) in enumerate(dataloader):
    feats = {nt: batched_graph.nodes[nt].data["feat"] for nt in nodes}
    target = label["value"].view(-1).to(device)
    norm_atom = None
    norm_bond = None
    stdev = torch.tensor([1.0])
    # print(feats.keys())
    # if device is not None:
    # feats = {k: v.to(device) for k, v in feats.items()}
    # target = target.to(device)
    # norm_atom = norm_atom.to(device)
    # norm_bond = norm_bond.to(device)
    # stdev = stdev.to(device)

    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"].keys())
    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_bonds_total"])
    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_atoms_total"])
    
    reactions = label["reaction"]
    for nt, ft in feats.items():
        batched_graph.nodes[nt].data.update({"ft": ft})

    graphs = dgl.unbatch(batched_graph)
    #reaction[0]
    #print(reaction[0]["reaction_graph"])
    #print(reaction[0]["reaction_feature"])

    #print(reactions[0]["mappings"])
    #print(reactions[0]["reaction_graph"])
    #print(reactions[0]["reaction_feature"])

    #print(reactions[0]["reaction_feature"]["global"].shape)
    #print(reactions[0]["reaction_feature"]["bond"].shape)
    #print(reactions[0]["reaction_feature"]["atom"].shape)
    #print(reactions[0]["mappings"])
    #print(reactions[0]["reaction_feature"])
    #print(reactions[0]["reaction_feature"])
    model(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
        reverse=False,
    )
""" 
    for rxn in reactions:
        reactants = [
            graphs[i] for i in rxn["reaction_molecule_info"]["reactants"]["reactants"]
        ]
        products = [
            graphs[i] for i in rxn["reaction_molecule_info"]["products"]["products"]
        ]
        # print(rxn["reaction_molecule_info"]["products"]["products"])
        # print(len(products))

        
        mappings = rxn["mappings"]
        #print(mappings)
        #has_bonds = rxn["has_bonds"]
        
        #reactant_atom_map = rxn["reaction_molecule_info"]["reactants"]["atom_map"]
        #product_atom_map = rxn["reaction_molecule_info"]["products"]["atom_map"]
        #reactant_bond_map = rxn["reaction_molecule_info"]["reactants"]["bond_map"]
        #product_bond_map = rxn["reaction_molecule_info"]["products"]["bond_map"]
        

        #mappings = {"bond_map": None, "atom_map": None}
        #has_bonds = False
        #print("mappings {} hasbonds {}".format(mappings["bond_map"], has_bonds))
        g, fts = create_rxn_graph(
            reactants=reactants,
            products=products,
            mappings=mappings,
            device=device,
            has_bonds=None,
            reverse=False,
            reactant_only=False,
            empty_graph_fts=None,
        )
"""   

' \n    for rxn in reactions:\n        reactants = [\n            graphs[i] for i in rxn["reaction_molecule_info"]["reactants"]["reactants"]\n        ]\n        products = [\n            graphs[i] for i in rxn["reaction_molecule_info"]["products"]["products"]\n        ]\n        # print(rxn["reaction_molecule_info"]["products"]["products"])\n        # print(len(products))\n\n        \n        mappings = rxn["mappings"]\n        #print(mappings)\n        #has_bonds = rxn["has_bonds"]\n        \n        #reactant_atom_map = rxn["reaction_molecule_info"]["reactants"]["atom_map"]\n        #product_atom_map = rxn["reaction_molecule_info"]["products"]["atom_map"]\n        #reactant_bond_map = rxn["reaction_molecule_info"]["reactants"]["bond_map"]\n        #product_bond_map = rxn["reaction_molecule_info"]["products"]["bond_map"]\n        \n\n        #mappings = {"bond_map": None, "atom_map": None}\n        #has_bonds = False\n        #print("mappings {} hasbonds {}".format(mappings["bond_map"

In [None]:
reactions[0].keys()

dict_keys(['reaction_index', 'reaction_graph', 'reaction_feature', 'reaction_molecule_info', 'label', 'reverse_label', 'extra_info', 'mappings'])

In [19]:
project_name = "test_multi_gpu"


trainer = pl.Trainer(
    max_epochs=2,
    accelerator="gpu",
    devices=[0],
    accumulate_grad_batches=5,
    enable_progress_bar=True,
    gradient_clip_val=1.0,
    enable_checkpointing=True,
    precision=32,
)

trainer.fit(model, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/santiagovargas/anaconda3/envs/bondnet_new/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type              | Params
-------------------------------------------------------
0  | embedding       | UnifySize         | 124   
1  | gated_layers    | ModuleList        | 1.6 K 
2  | readout_layer   | Set2SetThenCat    | 2.6 K 
3  | fc_layers       | ModuleList        | 3.3 K 
4  | loss            | MeanSquaredError  | 0     
5  | train_r2        | R2Score           | 0     
6  | train_torch_l1  | MeanAbsoluteError | 0     
7  | train_torch_mse | MeanSquaredError  | 0     
8  | val_r2          | R2Score           | 0     
9  | val_torch_l1    | MeanAbsoluteError | 0     
10 | val_torch_mse   | MeanSquaredError  | 0     
11 | test_r2         | R2Score           | 0     
12 | test_torch_l1   | MeanAbsoluteError | 0     
13 | test_torch_mse  | MeanSquaredError  | 0     
-------------------------------------------------------
7.6 K     Trainable params
0         Non-trainable params
7.6 K     Total params
0.030     To

Epoch 0: 100%|██████████| 12/12 [00:08<00:00,  1.34it/s, v_num=14, train_loss=1.020]

MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['train_loss', 'train_r2', 'train_l1', 'train_mse']. Condition can be set using `monitor` key in lr scheduler dict

In [19]:
from bondnet.model.training_utils import get_grapher
from bondnet.data.dataset import ReactionNetworkDatasetGraphs
from bondnet.data.dataloader import DataLoaderReactionNetworkParallel, collate_parallel

In [20]:

dataset_loc = "/home/santiagovargas/dev/bondnet/bondnet/dataset/rapter_new_parse/qtaim/test_rapter_filtered_species.pkl"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

extra_keys = {"bond": ["bond_length"]}
precision = "32"

if precision == "16" or precision == "32":
    precision = int(precision)

extra_keys = {}

dataset = ReactionNetworkDatasetGraphs(
    grapher=get_grapher(extra_keys),
    file=dataset_loc,
    target="ts",
    classifier=False,
    classif_categories=3,
    filter_species=[5, 5],
    filter_outliers=False,
    filter_sparse_rxns=False,
    debug=False,
    extra_keys={"bond":["bond_length"]},
    extra_info={}
)

fg_list None
reading file from: /home/santiagovargas/dev/bondnet/bondnet/dataset/rapter_new_parse/qtaim/test_rapter_filtered_species.pkl
rxn raw len: 1189
Program finished in 9.034226115094498 seconds
.............failures.............
reactions len: 1189
valid ind len: 1189
bond break fail count: 		0
default fail count: 		0
sdf map fail count: 		0
product bond fail count: 	0
about to group and organize
number of grouped reactions: 1189
---> generating grouped reactions


grouped reactions: 100%|██████████| 1189/1189 [00:24<00:00, 48.60it/s]


--> generating labels


labeled reactions: 100%|██████████| 1189/1189 [00:00<00:00, 16818.69it/s]


features: 3418
labels: 1189
molecules: 3418
constructing graphs & features....


mol graphs: 100%|██████████| 3418/3418 [00:07<00:00, 486.48it/s]


number of graphs valid: 3418
number of graphs: 3418


In [21]:
dataset.feature_size

{'atom': 20, 'bond': 7, 'global': 3}

In [22]:
dataloader_normal = DataLoaderReactionNetworkParallel(
    dataset, batch_size=100, shuffle=True, collate_fn=collate_parallel
)

In [23]:
next(iter(dataloader_normal))

(Graph(num_nodes={'atom': 3252, 'bond': 3113, 'global': 280},
       num_edges={('atom', 'a2a', 'atom'): 3252, ('atom', 'a2b', 'bond'): 6223, ('atom', 'a2g', 'global'): 3252, ('bond', 'b2a', 'atom'): 6223, ('bond', 'b2b', 'bond'): 3113, ('bond', 'b2g', 'global'): 3113, ('global', 'g2a', 'atom'): 3252, ('global', 'g2b', 'bond'): 3113, ('global', 'g2g', 'global'): 280},
       metagraph=[('atom', 'atom', 'a2a'), ('atom', 'bond', 'a2b'), ('atom', 'global', 'a2g'), ('bond', 'atom', 'b2a'), ('bond', 'bond', 'b2b'), ('bond', 'global', 'b2g'), ('global', 'atom', 'g2a'), ('global', 'bond', 'g2b'), ('global', 'global', 'g2g')]),
 {'value': tensor([[-0.4440],
          [ 0.5022],
          [-0.1840],
          [-0.3610],
          [-0.1371],
          [-0.1253],
          [ 0.3932],
          [-0.3252],
          [ 0.2052],
          [ 1.0891],
          [ 0.5040],
          [ 0.0088],
          [ 5.6143],
          [ 0.1299],
          [ 1.4528],
          [ 0.9241],
          [-0.1249],
      

In [24]:
config = get_defaults()

config = {
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.1,
        "test_size": 0.1,
        "batch_size": 4,
        "num_workers": 1,
    },
}

dataset_loc = "../../../tests/data/testdata/barrier_100.json"
config = {
    "dataset": {
        "data_dir": dataset_loc,
        "target_var": "ts",
    },
    "model": {
        "extra_features": [],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.2,
        "test_size": 0.2,
        "batch_size": 4,
        "num_workers": 1,
    },
}
config_model = get_defaults()
# update config with model settings
for key, value in config_model["model"].items():
    config["model"][key] = value
for key, value in config_model["model"].items():
    config["model"][key] = value
    
#from bondnet.data.datamodule import BondNetLightningDataModule
#dm = BondNetLightningDataModule(config)
# feat_size, feat_name = dm.prepare_data()
# config["model"]["in_feats"] = feat_size
# config["model"]["in_feats"] = feat_size
# config = get_defaults()
#config["model"]["in_feats"] = dataset.feature_info["feature_size"]
#reaction = dataset.reaction_network.reactions[0]
config["model"]["in_feats"] = dataset.feature_size
model = load_model_lightning(config["model"], load_dir="./test_lmdb/")

NB: using GatedGCNConv
NB: using Set2SetThenCat
:::NO INITIALIZER USED:::


In [27]:
device = "cuda" if torch.cuda.is_available() else "cpu"
nodes = ["atom", "bond", "global"]
for it, (batched_graph, label) in enumerate(dataloader_normal):
    feats = {nt: batched_graph.nodes[nt].data["feat"] for nt in nodes}
    target = label["value"].view(-1).to(device)
    norm_atom = None
    norm_bond = None
    stdev = torch.tensor([1.0])
    # print(feats.keys())
    # if device is not None:
    # feats = {k: v.to(device) for k, v in feats.items()}
    # target = target.to(device)
    # norm_atom = norm_atom.to(device)
    # norm_bond = norm_bond.to(device)
    # stdev = stdev.to(device)

    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"].keys())
    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_bonds_total"])
    #print(label["reaction"][0]["reaction_molecule_info"]["mappings"]["num_atoms_total"])
    
    reactions = label["reaction"]
    for nt, ft in feats.items():
        batched_graph.nodes[nt].data.update({"ft": ft})

    graphs = dgl.unbatch(batched_graph)
    #reaction[0]
    #print(reaction[0]["reaction_graph"])
    #print(reaction[0]["reaction_feature"])

    #print(reactions[0]["mappings"])
    #print(reactions[0]["reaction_graph"])
    #print(reactions[0]["reaction_feature"])

    #print(reactions[0]["reaction_feature"]["global"].shape)
    #print(reactions[0]["reaction_feature"]["bond"].shape)
    #print(reactions[0]["reaction_feature"]["atom"].shape)
    #print(reactions[0]["mappings"])
    #print(reactions[0]["reaction_feature"])
    #print(reactions[0]["reaction_feature"])
    model(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
        reverse=False,
    )

torch.Size([21, 10]) torch.Size([18, 10]) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18] 2 atom
torch.Size([21, 10]) torch.Size([3, 10]) [0, 1, 2] [19, 12, 20] 2 atom
torch.Size([24, 10]) torch.Size([20, 10]) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] [18, 14, 11, 6, 13, 5, 19, 20, 21, 3, 10, 1, 8, 9, 12, 16, 23, 7, 0, 4] 2 bond
torch.Size([24, 10]) torch.Size([2, 10]) [0, 1] [2, 15] 2 bond
torch.Size([21, 10]) torch.Size([18, 10]) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 2 atom
torch.Size([21, 10]) torch.Size([3, 10]) [0, 1, 2] [18, 19, 20] 2 atom
torch.Size([23, 10]) torch.Size([19, 10]) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18] [15, 7, 8, 10, 20, 17, 18, 19, 14, 0, 1, 5, 12, 3, 22, 21, 13, 9, 16] 2 bond
torch.Size([23, 10]) torch.Size([2, 10]) [0, 1] [4, 11] 2 bond
torc

In [26]:
reactions[0]

<bondnet.data.reaction_network.ReactionInNetwork at 0x7f002bd2cc90>