In [1]:
%load_ext ipycache

  from IPython.utils.traitlets import Unicode


In [2]:
import os
import sys

#Make a local checkout of `adamwr` importable. Presumably this is where the
#`adamw` and `cosine_scheduler` modules imported further down live - TODO confirm.
#The location can be overridden with the ADAMWR_PATH environment variable so the
#notebook also runs on other machines; the default is the original author's path.
ADAMWR_PATH = os.environ.get('ADAMWR_PATH', '/home/zulissi/software/adamwr')
sys.path.insert(0, ADAMWR_PATH)

This notebook demonstrates building, training, saving, loading, and using a scikit-learn-compatible CGCNN model.

In [3]:
import os
import sys
import numpy as np
import cgcnn
import pickle

#Select which GPU to use if necessary.
#The %env magics set environment variables for the kernel process.
#NOTE(review): device index 1 is presumably specific to the original machine -
#adjust or remove these lines when running elsewhere.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


## Load the dataset as mongo docs

In [4]:
def _load_pickled_entry(path, key):
    """Return ``pickle.load(file)[key]`` for the pickle stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the payload has been read.

    NOTE: unpickling can execute arbitrary code - only load files you trust.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)[key]


#Mongo documents and the matching relaxed / unrelaxed structure lists
docs_all_unfiltered = _load_pickled_entry('CO_docs_connectivity.pkl', 'docs_all')
SDT_list_distance_relaxed = _load_pickled_entry('SDT_list_distance_relaxed.pkl', 'SDT_list_distance_relaxed')
SDT_list_distance_unrelaxed = _load_pickled_entry('SDT_list_distance_unrelaxed.pkl', 'SDT_list_distance_unrelaxed')

## Filter by energy

In [5]:
#Keep only the entries whose energy lies strictly between -3 and 1.0; the two
#structure lists are filtered in lock-step with the documents.
energy_window = [
    entry
    for entry in zip(docs_all_unfiltered, SDT_list_distance_relaxed, SDT_list_distance_unrelaxed)
    if -3 < entry[0]['energy'] < 1.0
]
docs_all_unfiltered, SDT_list_distance_relaxed, SDT_list_distance_unrelaxed = zip(*energy_window)

## Make the target list

In [6]:
import random
import pickle
from sklearn.preprocessing import StandardScaler

#Standardise the energies with a StandardScaler; the fitted scaler (SS) is kept so
#predictions can later be mapped back with SS.inverse_transform.
energy_column = np.array([doc['energy'] for doc in docs_all_unfiltered]).reshape(-1,1)
SS = StandardScaler()
target_list = SS.fit_transform(energy_column)


## CGCNN model with skorch to make it sklearn compliant

In [7]:
from torch.optim import Adam, SGD
from sklearn.model_selection import ShuffleSplit
from skorch.callbacks import Checkpoint, LoadInitState #needs skorch 0.4.0, conda-forge version at 0.3.0 doesn't cut it
from cgcnn.data import collate_pool
from skorch import NeuralNetRegressor
from cgcnn.model import CrystalGraphConvNet
import torch
from cgcnn.data import MergeDataset
import skorch.callbacks.base


#Run on the GPU when torch can see one, otherwise fall back to the CPU
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else 'cpu'

#Prefix shared by the checkpoint writer and the loader callback below, so the
#file that is written and the file that is read back cannot drift apart
BEST_VALID_PREFIX = 'valid_best_'

#Make a checkpoint to save parameters every time there is a new best for validation loss
cp = Checkpoint(monitor='valid_loss_best',fn_prefix=BEST_VALID_PREFIX)

#Callback to load the checkpoint with the best validation loss at the end of training
class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
    """Reload the best-validation-loss parameters into the net when training ends.

    Parameters
    ----------
    f_params : str
        Parameter file to load. Defaults to ``'valid_best_params.pt'``, i.e. the
        checkpoint prefix above followed by ``'params.pt'``.
    """
    def __init__(self, f_params=BEST_VALID_PREFIX + 'params.pt'):
        #Stored under the same name as the constructor argument so that
        #get_params / cloning of the callback keeps working
        self.f_params = f_params

    def on_train_end(self, net, X=None, y=None, **kwargs):
        #Extra keyword arguments are accepted (and ignored) to match the
        #signature of the base-class hook
        net.load_params(self.f_params)

load_best_valid_loss = train_end_load_best_valid_loss()


In [8]:
import pandas as pd
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from skorch.dataset import CVSplit
from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
from adamw import AdamW
from cosine_scheduler import CosineLRWithRestarts

#Colour used for each data split in the parity plots, in drawing order
_PARITY_COLOURS = (('train', 'orange'), ('valid', 'blue'), ('test', 'green'))


def _inverse_scale(values):
    """Map standardised values back to energies with the global scaler ``SS``.

    The values are passed as a single column (the shape the scaler was fitted
    with) and the result is returned as a flat 1-D array.
    """
    return SS.inverse_transform(np.asarray(values).reshape(-1, 1)).reshape(-1)


def _parity_frame(actual, predicted):
    """Build an actual-vs-predicted DataFrame from standardised values."""
    return pd.DataFrame({'actual_value': _inverse_scale(actual),
                         'predicted_value': _inverse_scale(predicted)})


def _plot_parity(frames, title, figure_path):
    """Draw, save and show a parity plot for the train/valid/test frames.

    ``frames`` maps 'train', 'valid' and 'test' to DataFrames with
    'actual_value' and 'predicted_value' columns.
    """
    import os

    fig, ax = plt.subplots(figsize=(8,8))
    for split_name, colour in _PARITY_COLOURS:
        actual = frames[split_name]['actual_value']
        predicted = frames[split_name]['predicted_value']
        ax.scatter(actual, predicted, color=colour, marker='o', alpha=0.5,
                   label='%s\nMAE=%0.2f, RMSE=%0.2f, R$^2$=%0.2f'
                   % (split_name,
                      mean_absolute_error(actual, predicted),
                      np.sqrt(mean_squared_error(actual, predicted)),
                      r2_score(actual, predicted)))

    #Diagonal reference line spanning the range of the training targets
    lower = min(frames['train']['actual_value'])
    upper = max(frames['train']['actual_value'])
    ax.plot([lower, upper], [lower, upper], 'k--')

    # format graph
    ax.tick_params(labelsize=14)
    ax.set_xlabel('DFT E (eV)', fontsize=14)
    ax.set_ylabel('CGCNN predicted E (eV)', fontsize=14)
    ax.set_title(title, fontsize=14)
    ax.legend(fontsize=12)

    #Create the output directory on first use instead of failing in savefig
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.show()
    plt.close(fig)


def _fit_plot_and_predict(SDT_train, SDT_test, SDT_all, target_train, target_test,
                          test_split, criterion, max_epochs, title, figure_path):
    """Fit one CGCNN model, draw its parity plot and predict every structure.

    Returns the de-standardised predictions for ``SDT_all`` as a 1-D array.
    """
    batch_size = 214

    #Set the size of the features from the first structure
    structures = SDT_train[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]

    #Specify the internal train/validation split used by skorch during fitting
    train_test_splitter = ShuffleSplit(test_size=test_split, random_state=42)

    # warm restart scheduling from https://arxiv.org/pdf/1711.05101.pdf
    LR_schedule = LRScheduler(CosineLRWithRestarts, batch_size=batch_size, epoch_size=len(SDT_train), restart_period=10, t_mult=1.2)

    net = NeuralNetRegressor(
        CrystalGraphConvNet,
        module__orig_atom_fea_len = orig_atom_fea_len,
        module__nbr_fea_len = nbr_fea_len,
        batch_size=batch_size,
        module__classification=False,
        lr=0.0056,
        max_epochs=max_epochs,
        module__atom_fea_len=46,
        module__h_fea_len=83,
        module__n_conv=8,
        module__n_h=4,
        optimizer__weight_decay=1e-5,
        optimizer=AdamW, # from https://arxiv.org/pdf/1711.05101.pdf
        iterator_train__pin_memory=True,
        iterator_train__num_workers=0,
        iterator_train__collate_fn = collate_pool,
        iterator_valid__pin_memory=True,
        iterator_valid__num_workers=0,
        iterator_valid__collate_fn = collate_pool,
        device=device,
        criterion=criterion,
        dataset=MergeDataset,
        train_split = CVSplit(cv=train_test_splitter),
        callbacks=[cp, load_best_valid_loss, LR_schedule]
    )

    net.initialize()
    net.fit(SDT_train, target_train)

    #Get the train/validation split used during the training (the splitter has a
    #fixed random_state, so asking it again is expected to give the same indices)
    train_indices, valid_indices = next(train_test_splitter.split(SDT_train))

    #Predict the training pool once and slice it for the two internal splits
    flat_train_targets = np.asarray(target_train).reshape(-1)
    train_pool_predictions = net.predict(SDT_train)
    frames = {
        'train': _parity_frame(flat_train_targets[train_indices],
                               train_pool_predictions[train_indices]),
        'valid': _parity_frame(flat_train_targets[valid_indices],
                               train_pool_predictions[valid_indices]),
        'test': _parity_frame(target_test, net.predict(SDT_test)),
    }
    _plot_parity(frames, title, figure_path)

    all_energies = _inverse_scale(net.predict(SDT_all))

    #Drop our reference to the network and release cached GPU memory before the
    #next model is trained, to keep peak GPU usage down
    del net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return all_energies


def compare_relaxed_unrelaxed(targets, docs_all, SDT_list_relaxed, SDT_list_unrelaxed, filter_to_use, fname, test_split=0.05, valid_split=0.05):
    """Fit CGCNN models on relaxed and unrelaxed structures and compare them.

    Two networks are trained on the structures selected by ``filter_to_use``:
    one on the relaxed structures (MSE loss, 100 epochs) and one on the
    unrelaxed structures (L1 loss, 150 epochs). A parity plot is saved for
    each, and both networks are then used to predict every structure.

    Parameters
    ----------
    targets : array-like
        Standardised target values, one per structure (scaled with the
        global scaler ``SS``).
    docs_all : sequence of dict
        One document per structure; read for 'movement_data', 'results'
        and 'mongo_id'.
    SDT_list_relaxed, SDT_list_unrelaxed : sequence
        Relaxed / unrelaxed structure inputs, aligned with ``targets``.
    filter_to_use : sequence of bool
        Mask selecting which structures are used for fitting.
    fname : str
        Prefix of the saved figures
        ('graphs/<fname>_parity_relaxed.pdf' and '..._unrelaxed.pdf').
    test_split : float
        Fraction used for the internal split during fitting (shown as
        'valid' in the plots).
    valid_split : float
        Fraction held out before fitting (shown as 'test' in the plots).
        NOTE(review): the two names look swapped relative to how the
        splits are labelled; the original behaviour is kept.

    Returns
    -------
    pandas.DataFrame
        One row per document with the relaxed and unrelaxed predictions,
        their absolute difference ('residual') and metadata from the
        documents.

    Raises
    ------
    ValueError
        If the mask, targets and structure lists differ in length.
    """
    mask = np.asarray(filter_to_use, dtype=bool)
    targets = np.asarray(targets)
    if not (len(mask) == len(targets) == len(SDT_list_relaxed) == len(SDT_list_unrelaxed)):
        raise ValueError('filter_to_use, targets and both SDT lists must have the same length')

    #Hold out a test set; the same split is applied to relaxed and unrelaxed inputs
    SDT_training_relaxed_train, SDT_training_relaxed_test, \
    SDT_training_unrelaxed_train, SDT_training_unrelaxed_test, \
    target_training, target_test = train_test_split([a for a,b in zip(SDT_list_relaxed, mask) if b],
                                                    [a for a,b in zip(SDT_list_unrelaxed, mask) if b],
                                                    targets[mask],
                                                    test_size=valid_split,
                                                    random_state=42)

    #Fit the relaxed model and predict every relaxed structure
    relaxed_energies = _fit_plot_and_predict(
        SDT_training_relaxed_train, SDT_training_relaxed_test, SDT_list_relaxed,
        target_training, target_test, test_split,
        criterion=torch.nn.MSELoss,
        max_epochs=100,
        title='Relaxed',
        figure_path='graphs/%s_parity_relaxed.pdf'%fname)

    #Fit the unrelaxed model and predict every unrelaxed structure
    unrelaxed_energies = _fit_plot_and_predict(
        SDT_training_unrelaxed_train, SDT_training_unrelaxed_test, SDT_list_unrelaxed,
        target_training, target_test, test_split,
        criterion=torch.nn.L1Loss,
        max_epochs=150, #should be 170
        title='Unrelaxed',
        figure_path='graphs/%s_parity_unrelaxed.pdf'%fname)

    #Generate a dataframe (one row per input document) with the relaxed/unrelaxed data and other metrics
    df = pd.DataFrame({'relaxed_energies':relaxed_energies.reshape((-1)),
                       'unrelaxed_energies':unrelaxed_energies.reshape((-1)),
                       'index':range(len(docs_all)),
                       'residual':np.abs(relaxed_energies-unrelaxed_energies).reshape((-1)),
                       'max_connectivity_change':[doc['movement_data']['max_connectivity_change'] for doc in docs_all],
                       'max_surface_movement':[doc['movement_data']['max_surface_movement'] for doc in docs_all],
                       'adsorbate_movement':[doc['movement_data']['max_adsorbate_movement'] for doc in docs_all],
                       'bare_slab_movement':[doc['movement_data']['max_bare_slab_movement'] for doc in docs_all],
                       'fmax':[doc['results']['fmax'] for doc in docs_all],
                       'energy':[doc['results']['energy'] for doc in docs_all],
                       'mongo_id':[doc['mongo_id'] for doc in docs_all]})

    return df

In [10]:
def _residual_within_three_sigma(frame):
    """Boolean Series marking rows whose residual is below three standard
    deviations of the (relaxed - unrelaxed) prediction difference."""
    spread = np.std(frame['relaxed_energies'] - frame['unrelaxed_energies'])
    return frame['residual'] < 3*spread


#Positional arguments shared by all three fitting rounds
common_args = (target_list,
               docs_all_unfiltered,
               SDT_list_distance_relaxed,
               SDT_list_distance_unrelaxed)

#Round 1: fit on everything
data_frame_first = compare_relaxed_unrelaxed(*common_args,
                                             [True]*len(target_list),
                                             'first_filter')

#Round 2: refit on the structures that passed the first 3-sigma filter
df_filter = _residual_within_three_sigma(data_frame_first)
data_frame_second = compare_relaxed_unrelaxed(*common_args,
                                              df_filter.values,
                                              'second_filter')

#Keep only the structures that passed the 3-sigma filter in both rounds
df_filter = df_filter & _residual_within_three_sigma(data_frame_second)

#Fit the final model after two filter steps
data_frame_final = compare_relaxed_unrelaxed(*common_args,
                                             df_filter,
                                             'final_model',
                                             test_split=0.2, valid_split=0.1)


RuntimeError: CUDA error: out of memory