In [1]:
import pickle
import os 
import numpy as np 
import dgl
from gaspy_utils import make_atoms_from_doc
from pymatgen.io.ase import AseAtomsAdaptor
import json
from pymatgen.core.structure import Structure
from pymatgen.analysis.structure_analyzer import VoronoiConnectivity
from ase.constraints import FixAtoms
from sklearn.preprocessing import OneHotEncoder
import copy
import torch

In [2]:
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('gaspy_docs/docs.pkl', 'rb') as docs_file:
    gasdb = pickle.load(docs_file)

In [3]:
# Adsorbate label of every document (with repeats).
adsorbates = list(map(lambda entry: entry['adsorbate'], gasdb))

In [4]:
# Deduplicate and sort so the order is reproducible: iteration order of a set
# of strings changes between interpreter runs (string hash randomisation).
adsorbates = sorted(set(adsorbates))

In [5]:
n_unique_adsorbates = len(adsorbates)
print(f"We have {n_unique_adsorbates} unique adsorbates")

We have 5 unique adsorbates


In [6]:
from collections import defaultdict
from dgl import backend as F

In [7]:
class AtomInitializer(object):
    """
    Base class holding a lookup table from atom type to its feature vector.

    !!! Use one AtomInitializer per dataset !!!

    Subclasses fill ``self._embedding``; ``get_atom_fea`` reads from it.
    """
    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def _invert_embedding(self):
        # Map every embedding value back to the atom type it belongs to.
        # NOTE(review): the values must be hashable. np.ndarray values (which
        # AtomCustomJSONInitializer stores) make this raise TypeError.
        return dict((value, key) for key, value in self._embedding.items())

    def get_atom_fea(self, atom_type):
        """Return the feature vector of ``atom_type``; AssertionError if unknown."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """Install ``state_dict`` as the embedding table and rebuild derived lookups."""
        self._embedding = state_dict
        self.atom_types = set(self._embedding)
        self._decodedict = self._invert_embedding()

    def state_dict(self):
        """Return the embedding table itself (not a copy)."""
        return self._embedding

    def decode(self, idx):
        """Return the atom type whose embedding value equals ``idx``."""
        if not hasattr(self, '_decodedict'):
            self._decodedict = self._invert_embedding()
        return self._decodedict[idx]


class AtomCustomJSONInitializer(AtomInitializer):
    """
    Load atom feature vectors from a JSON file.

    The file must hold a single JSON object whose keys are element numbers
    (JSON strings, converted to ``int`` here) and whose values are lists of
    numbers, the feature vector of that element.

    Parameters
    ----------

    elem_embedding_file: str
        The path to the .json file
    """
    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as json_file:
            raw_embedding = json.load(json_file)
        # JSON object keys are always strings; atoms are looked up by int.
        by_element = {int(element): vector
                      for element, vector in raw_embedding.items()}
        super().__init__(set(by_element))
        for element, vector in by_element.items():
            self._embedding[element] = np.array(vector, dtype=float)

In [8]:
def distance_to_adsorbate_feature(atoms, VC, max_dist=6):
    """
    One-hot encode each atom's bond-graph distance to the nearest adsorbate atom.

    Two atoms are treated as bonded when their strongest connection (maximum over
    the last axis of ``VC.connectivity_array``) is more than 0.3 of the central
    atom's strongest connection.  The distance is the number of hops along such
    bonds from any adsorbate atom (``tag == 1``).  Distances of ``max_dist - 1``
    or more, and unreachable atoms, are all put in the last bin.

    Parameters
    ----------
    atoms: ase.Atoms (anything with ``get_tags()``)
        Adsorbate atoms must carry tag 1.
    VC: pymatgen VoronoiConnectivity (anything with ``connectivity_array``)
        3-d connectivity array ``[atom, neighbor, image]``.
    max_dist: int
        Number of one-hot bins.

    Returns
    -------
    np.ndarray of shape (n_atoms, max_dist), float64 one-hot rows.

    Raises
    ------
    ValueError
        If no atom has tag 1.
    """
    # Strongest connection to each neighbour over all periodic images.
    conn = np.max(VC.connectivity_array, 2).astype(float)

    # Normalise each row by its strongest connection.  Rows without any
    # connectivity stay zero instead of becoming NaN (0/0).
    row_max = conn.max(axis=1, keepdims=True)
    conn = np.divide(conn, row_max, out=np.zeros_like(conn), where=row_max > 0)

    # Binary (possibly asymmetric) adjacency matrix.
    adjacency = (conn > 0.3).astype(np.int64)
    n_atoms = len(adjacency)

    # Shortest hop count between every pair.  Every atom is 0 hops from itself.
    # Pairs not reached within max_dist - 1 hops keep a value beyond the last bin.
    unreachable = max_dist + 1
    distances = np.full((n_atoms, n_atoms), unreachable, dtype=np.int64)
    np.fill_diagonal(distances, 0)

    # reachable[i, j] is True when a walk of exactly `hops` bonds leads from i to j.
    # Tracking booleans instead of integer matrix powers avoids overflow and
    # recomputing the power from scratch at every hop.
    reachable = np.eye(n_atoms, dtype=bool)
    for hops in range(1, max_dist):
        reachable = (reachable.astype(np.int64) @ adjacency) > 0
        distances[reachable & (distances > hops)] = hops

    adsorbate_mask = np.asarray(atoms.get_tags()) == 1
    if not adsorbate_mask.any():
        raise ValueError(
            "distance_to_adsorbate_feature needs at least one adsorbate atom (tag == 1)")

    # Minimum distance from any adsorbate atom to each atom, clipped into range.
    min_distance = distances[adsorbate_mask].min(axis=0)
    min_distance = np.minimum(min_distance, max_dist - 1)

    # One-hot with a fixed width of max_dist columns.
    return np.eye(max_dist)[min_distance]

In [9]:
def crystal_atom_featurizer(atoms, atom_initializer=None):
    """
    Build the node-feature matrix for an ASE ``Atoms`` object.

    Each row holds, in order:
    * the element embedding from ``atom_initializer.get_atom_fea``,
    * the atom's tag,
    * 1.0 if the atom is fixed by any ``FixAtoms`` constraint, else 0.0,
    * the one-hot distance to the adsorbate from ``distance_to_adsorbate_feature``.

    Parameters
    ----------
    atoms: ase.Atoms
    atom_initializer: AtomInitializer, optional
        Source of the element embeddings.  Defaults to the notebook-level ``ari``.

    Returns
    -------
    num_atoms: int
    atom_feats_dict: defaultdict
        ``{'n_feat': float32 tensor of shape (num_atoms, n_features)}``, built
        with ``dgl.backend.tensor``.
    """
    if atom_initializer is None:
        atom_initializer = ari
    crystal = AseAtomsAdaptor.get_structure(atoms)
    VC = VoronoiConnectivity(crystal)
    num_atoms = atoms.get_global_number_of_atoms()
    atomic_numbers = atoms.get_atomic_numbers()
    tags = atoms.get_tags()

    # Union of all FixAtoms constraints.  With none present no atom is fixed.
    # The previous version raised IndexError here, and it ignored every
    # FixAtoms constraint after the first.
    fixed_indices = set()
    for constraint in atoms.constraints:
        if isinstance(constraint, FixAtoms):
            fixed_indices.update(int(idx) for idx in constraint.get_indices())

    rows = []
    for i in range(num_atoms):
        atom_feats = list(atom_initializer.get_atom_fea(atomic_numbers[i]))
        atom_feats.append(tags[i])
        atom_feats.append(i in fixed_indices)
        rows.append(np.array(atom_feats).astype(np.float32))
    node_feats = np.stack(rows, axis=0)

    distance_feats = distance_to_adsorbate_feature(atoms, VC).astype(np.float32)

    atom_feats_dict = defaultdict(list)
    # Convert to a tensor once instead of once per atom.
    atom_feats_dict['n_feat'] = F.tensor(np.hstack([node_feats, distance_feats]))
    return num_atoms, atom_feats_dict

In [10]:
# Element-embedding table; crystal_atom_featurizer reads it through the global `ari`.
ATOM_INIT_FILE = '../atom_init.json'
ari = AtomCustomJSONInitializer(ATOM_INIT_FILE)

In [11]:
class GaussianDistance(object):
    """
    Expand scalar values onto a fixed grid of Gaussian basis functions.

    Unit: angstrom
    """
    def __init__(self, dmin, dmax, step, var=None):
        """
        Parameters
        ----------

        dmin: float
          Centre of the first Gaussian.
        dmax: float
          Upper end of the grid; centres run from dmin to dmax (inclusive,
          up to floating-point rounding of np.arange) in steps of ``step``.
        step: float
          Spacing between Gaussian centres.
        var: float, optional
          Width parameter used in the exponent; defaults to ``step``.
        """
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """
        Apply the Gaussian filter to a numpy array.

        Parameters
        ----------

        distances: np.ndarray of any shape

        Returns
        -------
        np.ndarray with one extra trailing axis of length ``len(self.filter)``,
        holding ``exp(-(d - centre)**2 / var**2)``.
        """
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-offsets ** 2 / self.var ** 2)

In [12]:
# Gaussian expansion of edge strengths used by crystal_bond_featurizer.
gdf = GaussianDistance(dmin=0, dmax=8, step=0.2)

In [13]:
def crystal_bond_featurizer(atoms, atoms_init_config, train_geometry, gdf, max_neighbors):
    """
    Build directed edges and edge features from Voronoi connectivity.

    For every central atom, the candidate neighbours are all (neighbour, PBC
    image) entries of the Voronoi connectivity array.  Each candidate gets a
    strength normalised by the central atom's strongest connection; how it is
    computed depends on ``train_geometry``:

    * ``'initial'``: strength from the initial configuration.
    * ``'final'`` (and any unrecognised value): strength from ``atoms``.
    * ``'final-adsorbate'``: for pairs involving an adsorbate atom (tag == 1),
      1.0 if the final-geometry strength exceeds 0.3, else 0.0.  Other pairs
      use the initial-configuration strength.

    Candidates are sorted by strength and the ``max_neighbors`` strongest are
    kept.  Each kept neighbour ``j`` of atom ``i`` with ``j != i`` adds two edges:
    ``i -> j`` carrying the strength and its Gaussian expansion, and ``j -> i``
    carrying zeros.

    Parameters
    ----------
    atoms: ase.Atoms
        Final configuration; adsorbate atoms should have tag == 1.
    atoms_init_config: ase.Atoms
        Initial configuration.  NOTE(review): assumed to contain the same atoms
        in the same order, with the same connectivity-array shape, as ``atoms``.
        Confirm this.
    train_geometry: str
    gdf: GaussianDistance
    max_neighbors: int

    Returns
    -------
    src_list, dst_list: list of int
    total_bonds: int
        Number of candidate (neighbour, image) entries per atom before truncation.
    bond_feats_dict: defaultdict
        ``'e_feat'``: float32 tensor of shape (E,).
        ``'gdf_feat'``: float32 tensor of shape (E, G), where G is the length of
        the Gaussian expansion.
    """
    conn = VoronoiConnectivity(AseAtomsAdaptor.get_structure(atoms)).connectivity_array
    crystal_initial_config = AseAtomsAdaptor.get_structure(copy.deepcopy(atoms_init_config))
    conn_initial_config = VoronoiConnectivity(crystal_initial_config).connectivity_array

    # Hoisted out of the triple loop: tags, and each central atom's strongest
    # connection (over all neighbours and images) used for normalisation.
    tags = atoms.get_tags()
    row_max = conn.max(axis=(1, 2))
    init_row_max = conn_initial_config.max(axis=(1, 2))

    def _strength(ii, jj, kk):
        # Normalised strength for central atom ii, neighbour jj, image kk.
        if train_geometry == 'initial':
            return conn_initial_config[ii][jj][kk] / init_row_max[ii]
        if train_geometry == 'final-adsorbate':
            if tags[ii] == 1 or tags[jj] == 1:
                return 1.0 if conn[ii][jj][kk] / row_max[ii] > 0.3 else 0.0
            return conn_initial_config[ii][jj][kk] / init_row_max[ii]
        return conn[ii][jj][kk] / row_max[ii]

    n_atoms, n_nbrs, n_images = conn.shape
    all_nbrs = []
    for ii in range(n_atoms):
        curnbr = []
        for jj in range(n_nbrs):
            for kk in range(n_images):
                # NOTE(review): this compares the neighbour index with the image
                # index, as the original `jj is not kk` did.  The original comment
                # ("not the currently selected center one") suggests `ii != jj`
                # was intended.  Kept unchanged so the features stay the same.
                if jj != kk and conn[ii][jj][kk] != 0:
                    curnbr.append([ii, _strength(ii, jj, kk), jj])
                else:
                    curnbr.append([ii, 0.0, jj])
        all_nbrs.append(np.array(curnbr))
    all_nbrs = np.array(all_nbrs)
    total_bonds = all_nbrs.shape[1]

    # Keep the max_neighbors strongest candidates per atom (stable sort).
    all_nbrs = [sorted(nbrs, key=lambda x: x[1], reverse=True) for nbrs in all_nbrs]
    nbr_fea_idx = np.array([[x[2] for x in nbr[:max_neighbors]] for nbr in all_nbrs])
    nbr_fea = np.array([[x[1] for x in nbr[:max_neighbors]] for nbr in all_nbrs])
    gdf_nbr = gdf.expand(nbr_fea)
    n_gdf = gdf_nbr.shape[-1]

    src_list, dst_list = [], []
    e_feat, gdf_feat = [], []
    for i in range(len(nbr_fea_idx)):
        # Index features by slot position.  `list.index(j)` returned the first
        # slot, so every repeat of a neighbour reached through several PBC
        # images reused that first slot's features.
        for slot, j in enumerate(nbr_fea_idx[i]):
            j = int(j)
            if j == i:
                continue
            src_list.extend([i, j])
            dst_list.extend([j, i])
            # Both entries are stored as scalars.  Mixing a 0-d array with a
            # shape-(1,) array made a ragged array, which newer NumPy rejects.
            e_feat.extend([float(nbr_fea[i][slot]), 0.0])
            gdf_feat.extend([gdf_nbr[i][slot], np.zeros(n_gdf)])

    bond_feats_dict = defaultdict(list)
    bond_feats_dict['e_feat'] = F.tensor(np.asarray(e_feat, dtype=np.float32))
    # reshape keeps the (0, G) shape when there are no edges.
    bond_feats_dict['gdf_feat'] = F.tensor(
        np.asarray(gdf_feat, dtype=np.float32).reshape(-1, n_gdf))
    return src_list, dst_list, total_bonds, bond_feats_dict

In [14]:
# Smoke test: node features for the first document.
first_doc_atoms = make_atoms_from_doc(gasdb[0])
num_atoms, atoms_feats_dict = crystal_atom_featurizer(first_doc_atoms)

In [15]:
# Smoke test: edge features for the first document in 'final-adsorbate' mode.
src_list, dst_list, total_bonds, bond_feats_dict = crystal_bond_featurizer(
    atoms=make_atoms_from_doc(gasdb[0]),
    atoms_init_config=make_atoms_from_doc(gasdb[0]['initial_configuration']),
    train_geometry='final-adsorbate',
    gdf=gdf,
    max_neighbors=12,
)

In [16]:
def make_dgl_graph(doc, atom_featurizer, bond_featurizer, gdf, train_geometry, max_neighbors):
    """
    Convert one GASpy document into a featurised DGLGraph.

    Nodes and their ``ndata`` come from ``atom_featurizer(final_atoms)``.  Edges
    and their ``edata`` come from ``bond_featurizer`` applied to the final and
    initial configurations.  The document's adsorbate, mpid, miller index,
    element list and energy are attached as plain attributes on the graph.
    """
    final_atoms = make_atoms_from_doc(doc)
    initial_atoms = make_atoms_from_doc(doc['initial_configuration'])

    n_nodes, node_data = atom_featurizer(final_atoms)
    src, dst, _total_bonds, edge_data = bond_featurizer(
        final_atoms, initial_atoms, train_geometry, gdf, max_neighbors)

    graph = dgl.DGLGraph()
    graph.add_nodes(n_nodes)
    graph.add_edges(src, dst)
    graph.ndata.update(node_data)
    graph.edata.update(edge_data)

    # Metadata for later filtering and splitting; `target` is the label.
    graph.adsorbate = doc['adsorbate']
    graph.mpid = doc['mpid']
    graph.miller = doc['miller']
    graph.comp = list(set(final_atoms.get_chemical_symbols()))
    graph.target = doc['energy']
    return graph

In [17]:
# Targets = adsorption energies, as a column vector of shape (n_docs, 1).
targets = np.array([entry['energy'] for entry in gasdb])[:, np.newaxis]

In [18]:
targets.shape[0]

47279

In [19]:
import tqdm
import multiprocessing as mp 
from functools import partial

In [20]:
# Bind every setting except the document, so the callable can be mapped over
# `gasdb` by a worker pool.
make_final_adsorbate_graphs = partial(
    make_dgl_graph,
    atom_featurizer=crystal_atom_featurizer,
    bond_featurizer=crystal_bond_featurizer,
    gdf=gdf,
    train_geometry='final-adsorbate',
    max_neighbors=12,
)

In [21]:
from tqdm.contrib.concurrent import process_map

In [22]:
# Serial sanity check on the first 10 documents before going parallel.
test_graphs = list(map(make_final_adsorbate_graphs, gasdb[:10]))

In [26]:
# Parallel conversion of the first 1000 documents with a progress bar.
r1 = process_map(make_final_adsorbate_graphs, gasdb[:1000], max_workers=23)
r = r1

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [None]:
pool = mp.Pool(23)  # 23 worker processes

In [None]:
graphs = pool.map(func=make_final_adsorbate_graphs, iterable=gasdb)

In [None]:
pool.close()
pool.join()

# Convert every document with a progress bar.
# - `tqdm` was imported as a module (`import tqdm`), so the progress bar is
#   `tqdm.tqdm`; calling the module raised TypeError.
# - `pool.imap` yields results lazily and in input order, so the bar advances as
#   graphs finish.  `pool.map` blocks until everything is done and returns a
#   plain list, which has no `.start()`/`.join()`.
# NOTE(review): this recomputes `graphs` from the previous cell.
graphs = []
with mp.Pool(24) as pool:
    iterator = pool.imap(make_final_adsorbate_graphs, gasdb)
    _graphs = list(tqdm.tqdm(iterator, total=len(gasdb),
                             desc='Transforming docs in a chunk'))
    graphs.extend(_graphs)

In [None]:
n_graphs = len(graphs)
print(f"Number of converted graphs  {n_graphs}")

In [29]:
# NOTE(review): this pickles `r1` (the 1000-document sample from the
# process_map cell), not the full `graphs` list. Confirm which is intended.
# The absolute path is cluster-specific.
GRAPHS_OUTPUT_PATH = '/scratch/westgroup/mpnn/gasdb_dgl_graphs/init_gasdb_dgl_graphs.pkl'
with open(GRAPHS_OUTPUT_PATH, 'wb') as graphs_file:
    pickle.dump(r1, graphs_file)

In [None]:
import dgl.function as fn
from dgl.nn.pytorch.utils import Identity

In [None]:
class CG(torch.nn.Module):
    """
    One gated graph-convolution message-passing step on a DGL graph.

    For every edge the message is ``sigmoid(z) * leaky_relu(z)``, where
    ``z = lin(cat[src n_feat, dst n_feat, edge gdf_feat])``.  Messages are summed
    with ``fn.sum`` into the node field ``'m'``.

    Parameters
    ----------
    in_feats: int
        Width of the node feature ``'n_feat'``.
    h_feats: int
        Width of each message (and of the resulting ``'m'``).
    out_feats: int
        Stored on the module; not used by this layer.
    edge_feats: int, default 41
        Width of the edge feature ``'gdf_feat'``.  41 matches
        ``len(GaussianDistance(0, 8, 0.2).filter)``, the expansion used in this
        notebook; it was previously hard-coded.
    """
    def __init__(self, in_feats, h_feats, out_feats, edge_feats=41):
        super().__init__()
        self.in_feats = in_feats
        self.h_feats = h_feats
        self.out_feats = out_feats
        self.edge_feats = edge_feats
        self.lin = torch.nn.Linear(2 * in_feats + edge_feats, h_feats)
        self.sigmoid = torch.nn.Sigmoid()
        # NOTE(review): named `softplus` but holds LeakyReLU. CGCNN uses Softplus
        # at this point; confirm which is intended. Kept as LeakyReLU so model
        # behaviour is unchanged.
        self.softplus = torch.nn.LeakyReLU()

    def get_msg(self, edges):
        """Edge UDF: gated message from source/destination node and edge features."""
        z = torch.cat([edges.src['n_feat'], edges.dst['n_feat'], edges.data['gdf_feat']], -1)
        z = self.lin(z)
        return {'z': self.sigmoid(z) * self.softplus(z)}

    def forward(self, graph):
        """Write summed messages into ``graph.ndata['m']`` (in place) and return ``graph``."""
        graph.update_all(message_func=self.get_msg,
                         reduce_func=fn.sum('z', 'm'))
        return graph

In [None]:
# NOTE(review): in_feats must equal the width of ndata['n_feat'] produced by
# crystal_atom_featurizer (embedding + tag + fixed flag + distance one-hot).
cg = CG(in_feats=100, h_feats=64, out_feats=64)

In [None]:
# NOTE(review): `g` was never defined at top level in this notebook (it only
# exists inside make_dgl_graph), so this cell failed on a fresh kernel.
# It now uses the first sample graph built above. Confirm this is the intended input.
g = test_graphs[0]
# Call the module rather than `.forward()` so nn.Module hooks run.
z = cg(g)

In [None]:
z.ndata['m'].size()