In [1]:
# ipycache provides the %%cache cell magic used by the expensive cells below to
# save their results to a pickle file and reload them instead of recomputing.
%load_ext ipycache

  from IPython.utils.traitlets import Unicode


This notebook demonstrates how to build, train, save, load, and use a scikit-learn-compatible CGCNN model.

In [2]:
import os
import sys
import numpy as np
import cgcnn

## Load the dataset as Mongo documents

In [3]:
import random
import pickle
import tqdm
import multiprocess as mp

# Load a selection of documents (a pickled, mutable sequence of Mongo-style docs).
# The location can be overridden with the CO_DOCS_PATH environment variable.
# NOTE: pickle.load can execute arbitrary code -- only load pickles from a trusted source.
CO_DOCS_PATH = os.environ.get('CO_DOCS_PATH', '/global/homes/z/zulissi/CO_docs.pkl')
with open(CO_DOCS_PATH, 'rb') as docs_file:  # `with` guarantees the file handle is closed
    docs = pickle.load(docs_file)

# Shuffle in place with a fixed seed so the document order is reproducible across runs
random.seed(42)
random.shuffle(docs)

## Add the maximum change in connectivity as an additional metric of reconstruction

In [4]:
import mongo
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.analysis.structure_analyzer import VoronoiConnectivity


In [5]:
%%cache CO_docs_connectivity.pkl docs

def doc_to_connectivity_array(doc):
    """Return the pymatgen Voronoi connectivity matrix of the slab stored in ``doc``.

    Only atoms with tag 0 are kept; every atom with a non-zero tag (the adsorbate,
    tagged 1) is removed before the connectivity is computed.

    pymatgen reports connectivity per periodic image; for every atom pair this
    keeps only the strongest connection over all images (reduction over axis 2).
    NOTE(review): presumably the result is (n_atoms, n_atoms) -- confirm against pymatgen.
    """
    slab = mongo.make_atoms_from_doc(doc)
    slab = slab[slab.get_tags() == 0]

    structure = AseAtomsAdaptor.get_structure(slab)
    per_image_connectivity = VoronoiConnectivity(structure).connectivity_array
    return np.max(per_image_connectivity, axis=2)

def max_connectivity_change(doc):
    """Return the largest absolute element-wise change in slab connectivity.

    The connectivity of the final structure (``doc`` itself) is compared with the
    connectivity of ``doc['initial_configuration']``; both matrices come from
    ``doc_to_connectivity_array``.

    NOTE(review): the element-wise subtraction assumes both configurations contain
    the same slab atoms in the same order -- confirm this holds for the data set.
    """
    final_connectivity = doc_to_connectivity_array(doc)
    initial_connectivity = doc_to_connectivity_array(doc['initial_configuration'])

    abs_change = np.abs(final_connectivity - initial_connectivity)
    return np.max(abs_change)

# Score every document in a 16-process pool. Pool.imap yields results in input
# order, so scores[i] belongs to docs[i]. Passing `total` lets tqdm show a real
# progress bar with percentage and ETA instead of a bare iteration count.
with mp.Pool(16) as pool:
    scores = list(tqdm.tqdm(pool.imap(max_connectivity_change, docs, chunksize=40),
                            total=len(docs)))

# Attach the score to each document. This mutates `docs` in place, which is what
# the %%cache magic on this cell then saves.
for doc, score in zip(docs, scores):
    doc['movement_data']['max_connectivity_change'] = score

[Saved variables 'docs' to file '/global/u2/z/zulissi/software/cgcnn_sklearn/CO_docs_connectivity.pkl'.]


20833it [18:10, 19.11it/s]


## Featurize the documents with the data transformer (the resulting feature sizes are used to set up the network model)

In [6]:
%%cache SDT_list_distance_relaxed.pkl SDT_list_distance_relaxed

from torch.utils.data import Dataset, DataLoader
import mongo
from cgcnn.data import StructureData, ListDataset, StructureDataTransformer
import numpy as np
import tqdm
from sklearn.preprocessing import StandardScaler


# Featurizer for the "relaxed" list. Its settings match the "unrelaxed" cell's
# transformer except for use_tag, which is False here.
SDT = StructureDataTransformer(
    atom_init_loc='atom_init.json',
    max_num_nbr=12, radius=1, step=0.2,
    use_distance=True, use_tag=False, use_fixed_info=False,
)


import multiprocess as mp
from sklearn.model_selection import ShuffleSplit

SDT_out = SDT.transform(docs)

# Collect SDT_out[i] for every index with a 16-process pool (40 indices per task).
# A lambda can be shipped to the workers because `multiprocess` serializes with dill;
# imap yields results in index order.
with mp.Pool(16) as pool:
    SDT_list_distance_relaxed = list(
        tqdm.tqdm(
            pool.imap(lambda idx: SDT_out[idx], range(len(SDT_out)), chunksize=40),
            total=len(SDT_out),
        )
    )
      

[Saved variables 'SDT_list_distance_relaxed' to file '/global/u2/z/zulissi/software/cgcnn_sklearn/SDT_list_distance_relaxed.pkl'.]


100%|##########| 20833/20833 [33:47<00:00, 10.28it/s]


In [7]:
%%cache SDT_list_distance_unrelaxed.pkl SDT_list_distance_unrelaxed

from torch.utils.data import Dataset, DataLoader
import mongo
from cgcnn.data import StructureData, ListDataset, StructureDataTransformer
import numpy as np
import tqdm
from sklearn.preprocessing import StandardScaler

# Featurizer for the "unrelaxed" list. Its settings match the "relaxed" cell's
# transformer except for use_tag, which is True here.
# NOTE(review): both cells call SDT.transform(docs) on the same documents, so use_tag
# is the only difference between the two cached lists -- confirm this is intended
# (e.g. whether the unrelaxed set should be built from the initial configurations).
SDT = StructureDataTransformer(
    atom_init_loc='atom_init.json',
    max_num_nbr=12, radius=1, step=0.2,
    use_distance=True, use_tag=True, use_fixed_info=False,
)


import multiprocess as mp
from sklearn.model_selection import ShuffleSplit

SDT_out = SDT.transform(docs)

# Collect SDT_out[i] for every index with a 16-process pool (40 indices per task).
# A lambda can be shipped to the workers because `multiprocess` serializes with dill;
# imap yields results in index order.
with mp.Pool(16) as pool:
    SDT_list_distance_unrelaxed = list(
        tqdm.tqdm(
            pool.imap(lambda idx: SDT_out[idx], range(len(SDT_out)), chunksize=40),
            total=len(SDT_out),
        )
    )


[Saved variables 'SDT_list_distance_unrelaxed' to file '/global/u2/z/zulissi/software/cgcnn_sklearn/SDT_list_distance_unrelaxed.pkl'.]


100%|##########| 20833/20833 [35:59<00:00,  9.65it/s]
