This code has been tested with the following packages:
- Python 3.8.13
- numpy 1.22.4
- pandas 1.4.2
- pymatgen 2022.5.26

In [1]:
from generate_graph import generate_graph
import numpy as np

Simple Usage:
- `cif_str`: path to the input structure file (CIF format)
- `graph_dict`: a dictionary that contains all information about the graph

In [2]:
cif_str = "example_input/DB0-m24_o19_o19_sra_repeat.cif"
graph_dict = generate_graph(cif_str)
graph_dict.keys()

dict_keys(['node_label', 'node_class', 'node_target', 'node_simple_feature', 'node_radial_feature', 'edges'])

Demo: Get the number of nodes in the graph

In [3]:
len(graph_dict["node_label"])

16
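Every `node_*` array has one row per node, so the node count can be read from any of them. The sketch below is a hypothetical helper (`check_graph` is not part of `generate_graph`) that checks this invariant and that every index in `edges` points to an existing node. It runs on a small made-up dictionary with the same keys as `graph_dict`.

```python
import numpy as np

NODE_KEYS = ["node_label", "node_class", "node_target",
             "node_simple_feature", "node_radial_feature"]

def check_graph(graph_dict):
    """Hypothetical sanity check: return the node count if the graph is consistent."""
    n_nodes = len(graph_dict["node_label"])
    for key in NODE_KEYS:
        # Every per-node array must have exactly one row per node.
        if len(graph_dict[key]) != n_nodes:
            raise ValueError(f"{key} has {len(graph_dict[key])} rows, expected {n_nodes}")
    edges = np.asarray(graph_dict["edges"])
    # Edges are (source, target) index pairs into the node arrays.
    if edges.ndim != 2 or edges.shape[1] != 2:
        raise ValueError("edges must have shape (n_edges, 2)")
    if edges.size and (edges.min() < 0 or edges.max() >= n_nodes):
        raise ValueError("edge index out of range")
    return n_nodes

# Toy 3-node graph with the same keys as graph_dict (values are made up).
toy = {
    "node_label": np.array(["V1", "O1", "O2"]),
    "node_class": np.array(["V", "O", "O"]),
    "node_target": np.array([1.0, -0.5, -0.5]),
    "node_simple_feature": np.zeros((3, 8)),
    "node_radial_feature": np.zeros((3, 168)),
    "edges": np.array([[0, 1], [0, 2]]),
}
print(check_graph(toy))  # 3
```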

`node_label` holds the identifier of each node (the element symbol followed by a running index)

In [4]:
graph_dict["node_label"]

array(['V1', 'V2', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'O1', 'O2', 'O3',
       'O4', 'O5', 'O6', 'F1', 'F2'], dtype='<U2')
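Since `edges` refers to nodes by integer position rather than by label, a lookup from label to index can be handy. A minimal sketch using the labels printed above, assuming the labels are unique (which holds for this example):

```python
import numpy as np

node_label = np.array(['V1', 'V2', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
                       'O1', 'O2', 'O3', 'O4', 'O5', 'O6', 'F1', 'F2'])

# Map each label (e.g. "O1") to its row index in all per-node arrays.
label_to_index = {str(label): i for i, label in enumerate(node_label)}

# A duplicate label would silently overwrite an entry, so check uniqueness.
assert len(label_to_index) == len(node_label)

print(label_to_index["V1"], label_to_index["F1"])  # 0 14
```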

`node_class` denotes the chemical element of each atom/node

In [5]:
graph_dict["node_class"]

array(['V', 'V', 'C', 'C', 'C', 'C', 'C', 'C', 'O', 'O', 'O', 'O', 'O',
       'O', 'F', 'F'], dtype='<U1')
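Because `node_class` holds element symbols, the number of atoms of each element in the structure can be tallied with `np.unique`. A short sketch using the values printed above:

```python
import numpy as np

node_class = np.array(['V', 'V', 'C', 'C', 'C', 'C', 'C', 'C',
                       'O', 'O', 'O', 'O', 'O', 'O', 'F', 'F'])

# np.unique returns the distinct elements (sorted) and how often each occurs.
elements, counts = np.unique(node_class, return_counts=True)
composition = {str(e): int(c) for e, c in zip(elements, counts)}
print(composition)  # {'C': 6, 'F': 2, 'O': 6, 'V': 2}
```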

`node_target` holds the (partial) charge of each atom; these values are used as the training targets for the ML model

In [6]:
graph_dict["node_target"]

array([ 1.305764,  1.296616, -0.191794,  0.07363 ,  0.080816, -0.193092,
        0.798576,  0.763946, -0.492667, -0.521448, -0.426144, -0.548471,
       -0.516446, -0.543787, -0.44461 , -0.440887])
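As a quick check of the targets, the charges can be summed over the whole structure and averaged per element. This sketch uses the values printed above; for this example the total comes out at about 2e-6, consistent with a charge-neutral structure.

```python
import numpy as np

node_class = np.array(['V', 'V', 'C', 'C', 'C', 'C', 'C', 'C',
                       'O', 'O', 'O', 'O', 'O', 'O', 'F', 'F'])
node_target = np.array([1.305764, 1.296616, -0.191794, 0.07363, 0.080816, -0.193092,
                        0.798576, 0.763946, -0.492667, -0.521448, -0.426144, -0.548471,
                        -0.516446, -0.543787, -0.44461, -0.440887])

# Net charge of the structure (close to zero for this example).
total_charge = node_target.sum()
print(f"total charge: {total_charge:.6f}")

# Mean charge per element, selecting atoms with a boolean mask.
mean_charge = {str(e): float(node_target[node_class == e].mean())
               for e in np.unique(node_class)}
for element, q in mean_charge.items():
    print(f"{element}: {q:+.4f}")
```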

`node_simple_feature` holds a feature vector for each atom. These features are "simple" because they are looked up from existing tables based on the node's class (the chemical element of the atom)

In [7]:
graph_dict["node_simple_feature"].shape

(16, 8)

In [8]:
graph_dict["node_simple_feature"][0]

array([2.300000e+01, 2.520000e+02, 5.094150e+01, 1.530000e+02,
       8.700000e+01, 1.273344e-01, 6.746187e+00, 2.070000e+02])
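If the simple features come only from element tables, every atom of the same element should share an identical row. The sketch below collapses `node_simple_feature` into one row per element and checks that assumption; `element_feature_table` is a hypothetical helper, and the feature values are made up for illustration rather than taken from the real tables.

```python
import numpy as np

def element_feature_table(node_class, node_simple_feature):
    """Collapse per-atom simple features into one row per element.

    Raises ValueError if two atoms of the same element have different rows,
    which would contradict the element-lookup assumption.
    """
    table = {}
    for element, row in zip(node_class, node_simple_feature):
        key = str(element)
        if key in table and not np.array_equal(table[key], row):
            raise ValueError(f"inconsistent simple features for {key}")
        table.setdefault(key, row)
    return table

# Made-up 2-column example (the real array has 8 columns per atom).
node_class = np.array(["V", "O", "O"])
features = np.array([[23.0, 1.5], [8.0, 3.4], [8.0, 3.4]])
table = element_feature_table(node_class, features)
print(sorted(table))  # ['O', 'V']
```

Storing one row per element instead of one per atom can save space for large structures, at the cost of a lookup when rebuilding the per-atom array.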

`node_radial_feature` holds another feature vector for each atom; computing it takes some time. It can be used on its own or combined with `node_simple_feature`, and it is also the training input for the CatBoost models.

In [9]:
graph_dict["node_radial_feature"].shape

(16, 168)
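One straightforward way to combine the radial features with the simple ones, assuming column-wise concatenation is what you want, is `np.hstack`; for this structure that gives 8 + 168 = 176 columns per atom. Placeholder arrays with the shapes shown above stand in for the real data.

```python
import numpy as np

# Placeholders with the shapes shown above: 16 atoms, 8 simple and 168 radial features.
node_simple_feature = np.zeros((16, 8))
node_radial_feature = np.ones((16, 168))

# Concatenate along the feature axis so row i still describes atom i.
combined = np.hstack([node_simple_feature, node_radial_feature])
print(combined.shape)  # (16, 176)
```

The two blocks span very different magnitudes (values from about 1e-11 to above 1e3 in the outputs shown), so scaling may matter for models that are sensitive to feature scale.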

In [10]:
graph_dict["node_radial_feature"][0]

array([2.31489731e-11, 2.53536730e-04, 1.70688043e+00, 1.36777656e+01,
       4.21455441e+00, 5.73082764e-03, 1.29592598e-06, 1.09087646e-01,
       5.11141972e+00, 6.91124952e+00, 4.58970014e+00, 1.01094415e+01,
       1.71522358e+01, 1.68032838e+01, 6.79405139e+00, 1.06229668e+01,
       2.10153894e+01, 8.03861663e+00, 1.28016816e+01, 2.24650971e+01,
       2.30036737e+00, 4.65354155e+00, 3.76653856e+01, 1.37658070e+01,
       1.26709530e+01, 1.40612295e+01, 2.96946097e+01, 2.54025447e+01,
       1.91638908e+02, 3.72044665e+02, 7.08036042e+02, 9.92644131e+02,
       1.06992173e+03, 1.15387163e+03, 8.83906646e+02, 1.07107641e+03,
       2.23400623e+03, 1.07052286e+03, 8.78935185e+02, 1.16459558e+03,
       1.06511322e+03, 9.92446379e+02, 7.08116448e+02, 3.73286028e+02,
       2.57583371e+02, 3.45372733e+02, 4.20869380e+00, 1.74149615e+01,
       5.10230979e+00, 3.15665701e+00, 1.03829490e+01, 1.70372387e+01,
       1.76578250e+01, 1.06971801e+01, 8.86933775e+00, 6.16107148e+00,
       ...])

`edges` lists the connections between nodes as pairs of node indices. The indices follow the same node order as all the arrays above, and the same pair can appear more than once (e.g. `[0, 14]`).

In [11]:
graph_dict["edges"]

array([[ 0, 14],
       [ 0, 14],
       [ 0, 11],
       [ 0, 10],
       [ 0,  9],
       [ 0,  8],
       [ 1, 15],
       [ 1, 15],
       [ 1,  8],
       [ 1, 13],
       [ 1, 10],
       [ 1, 12],
       [ 2,  3],
       [ 2,  6],
       [ 3,  4],
       [ 4,  5],
       [ 5,  7],
       [ 6, 13],
       [ 6, 11],
       [ 7, 12],
       [ 7,  9]])
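Because the indices follow the node order of the arrays above, `edges` can be turned into a neighbor list. The sketch below drops repeated pairs and treats every bond as undirected; both are assumptions about how you want to use the graph, not behavior documented for `generate_graph`.

```python
from collections import defaultdict

import numpy as np

edges = np.array([[0, 14], [0, 14], [0, 11], [0, 10], [0, 9], [0, 8],
                  [1, 15], [1, 15], [1, 8], [1, 13], [1, 10], [1, 12],
                  [2, 3], [2, 6], [3, 4], [4, 5], [5, 7], [6, 13],
                  [6, 11], [7, 12], [7, 9]])

# Keep only distinct pairs, then record each bond in both directions.
unique_edges = np.unique(edges, axis=0)
neighbors = defaultdict(set)
for i, j in unique_edges:
    neighbors[int(i)].add(int(j))
    neighbors[int(j)].add(int(i))

print(sorted(neighbors[0]))           # [8, 9, 10, 11, 14]
print(len(edges), len(unique_edges))  # 21 19
```

If the repeated pairs are meaningful for your model (for example as distinct bonds), skip the `np.unique` step and use a list instead of a set.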

Saving the whole graph to a single NumPy `.npz` file (the `example_output` directory must already exist)

In [12]:
np.savez("example_output/DB0-m24_o19_o19_sra_repeat.npz", **graph_dict)

Loading the saved graph 

In [13]:
graph_npz = np.load("example_output/DB0-m24_o19_o19_sra_repeat.npz")
graph_npz.files

['node_label',
 'node_class',
 'node_target',
 'node_simple_feature',
 'node_radial_feature',
 'edges']

Each array can be accessed by key, the same way as with a Python dictionary

In [14]:
graph_npz["node_label"]

array(['V1', 'V2', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'O1', 'O2', 'O3',
       'O4', 'O5', 'O6', 'F1', 'F2'], dtype='<U2')
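To confirm that nothing is lost when saving, a graph can be written and read back in memory. This sketch uses `io.BytesIO` instead of a file path and a small made-up dictionary; the `with` block closes the `.npz` archive after the arrays have been copied into a plain dictionary.

```python
import io

import numpy as np

graph_dict = {
    "node_label": np.array(["V1", "O1"]),
    "node_target": np.array([1.3, -0.5]),
    "edges": np.array([[0, 1]]),
}

buffer = io.BytesIO()
np.savez(buffer, **graph_dict)
buffer.seek(0)

# Copy every array out of the NpzFile so it stays usable after the archive closes.
with np.load(buffer) as graph_npz:
    loaded = {key: graph_npz[key] for key in graph_npz.files}

for key, value in graph_dict.items():
    assert np.array_equal(value, loaded[key]), key
print(sorted(loaded))  # ['edges', 'node_label', 'node_target']
```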