#### Import all required packages

In [4]:
import matplotlib.pyplot as plt
import multiprocess as mp
import numpy as np
import pandas as pd
import pickle
import random
import tqdm
import torch
import skorch.callbacks.base

import os
import sys
sys.path.insert(0, 'adamwr') # you will need to have adamW optimizer cloned locally
sys.path.insert(0, 'cgcnn/')
import cgcnn
import mongo

from cgcnn.data import collate_pool, MergeDataset, StructureDataTransformer
from cgcnn.model import CrystalGraphConvNet
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import ShuffleSplit, train_test_split 
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skorch.callbacks import Checkpoint, LoadInitState 
from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
from skorch.dataset import CVSplit
from skorch import NeuralNetRegressor
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD
from cosine_scheduler import CosineLRWithRestarts
from adamw import AdamW

# Select which GPU to use if necessary.
# These are IPython `%env` magics: they set environment variables for the running kernel.
# Keep comments on their own lines; text placed after the value would become part of it.
# NOTE(review): CUDA reads these variables when it is first initialised, so this cell
# presumably has to run before any CUDA work is done — confirm.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


#### Load the cleavage energy docs and convert the structures into graph objects

In [5]:
# Load the pickled cleavage-energy documents.
# The context manager guarantees the file handle is closed once loading finishes.
# NOTE: unpickling can execute arbitrary code, so only load dataset files from a trusted source.
with open('../cleavage_energy_dataset/intermetallics_cleavage_energy_data.pkl', 'rb') as dataset_file:
    docs = pickle.load(dataset_file)

# Shuffle with a fixed seed so the document order is the same on every run.
random.seed(123)
random.shuffle(docs)

# Promote the fields nested under 'thinnest_structure' to the top level of each
# document, then drop the nested entry.
for doc in docs:
    thinnest = doc['thinnest_structure']
    doc["atoms"] = thinnest['atoms']
    doc["results"] = thinnest['results']
    doc["initial_configuration"] = thinnest['initial_configuration']
    del doc["thinnest_structure"]

In [None]:
# Transformer that converts each document into the graph inputs consumed by the model.
# NOTE(review): atom_init_loc is an absolute path on the original author's machine;
# point it at your local copy of atom_init.json.
SDT = StructureDataTransformer(
    atom_init_loc='/home/zulissi/software/cgcnn_sklearn/atom_init.json',
    max_num_nbr=12,
    step=0.8,
    radius=4,
    use_voronoi=False,
    use_tag=False,
    use_fixed_info=False,
    use_distance=False,
    train_geometry='initial',
)

# Transform the documents once; the result is reused both to read the feature
# sizes and to materialise the full list below.
SDT_out = SDT.transform(docs)
structures = SDT_out[0]

# Settings necessary to build the model (they are the sizes of the input vectors)
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

# Materialise every transformed structure using 4 worker processes.
# NOTE(review): passing a lambda to the pool relies on `multiprocess` (imported as mp)
# being able to serialise it — confirm if swapping in the stdlib `multiprocessing`.
with mp.Pool(4) as pool:
    SDT_list = list(
        tqdm.tqdm(
            pool.imap(lambda x: SDT_out[x], range(len(SDT_out)), chunksize=40),
            total=len(SDT_out),
        )
    )

#### Prepare prediction labels

In [7]:
# Build one (doc_index, log cleavage energy) row per document.
# enumerate() gives each document its own position directly; the previous
# docs.index(doc) lookup rescanned the list for every document and returned the
# first *equal* document, which is the wrong index when duplicates exist.
# Both columns end up as floats because they share one NumPy array.
target_list = np.array(
    [[doc_index, np.log(doc['cleavage_energy'])] for doc_index, doc in enumerate(docs)],
    dtype=float,
).reshape(-1, 2)  # keeps a (0, 2) shape when docs is empty so the frame can still be built
target_list = pd.DataFrame(target_list, columns=['doc_index', 'cleavage_energy'])

#### Split data into 80:20 train:test

In [None]:
# Hold out 20% of the samples as the test set; the fixed random_state makes the
# split identical on every run.
SDT_training, SDT_test, target_training, target_test = train_test_split(
    SDT_list,
    target_list,
    test_size=0.2,
    random_state=42,
)

#### Set up checkpoints

In [None]:
# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else 'cpu'

# Checkpoint that saves the parameters every time the validation loss reaches a new best.
cp = Checkpoint(
    monitor='valid_loss_best',
    fn_prefix='valid_best_',
)

# Callback to load the checkpoint with the best validation loss at the end of training
class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
    """Reload the parameters stored in 'valid_best_params.pt' when training ends.

    NOTE(review): the file name is presumably the one written by the Checkpoint
    callback configured with fn_prefix='valid_best_' — confirm if that prefix changes.
    """

    def on_train_end(self, net, X=None, y=None, **kwargs):
        """Load the saved parameters into ``net``.

        ``X`` and ``y`` are accepted but unused. The defaults and ``**kwargs``
        mirror the skorch ``Callback`` hook signature so that additional keyword
        arguments passed by skorch do not raise a ``TypeError``.
        """
        net.load_params('valid_best_params.pt')


load_best_valid_loss = train_end_load_best_valid_loss()

#### Set up the model and train the model with training data

In [None]:
# Further split the training data into train and validation sets (8:2 ratio)
# to avoid overfitting.
train_test_splitter = ShuffleSplit(test_size=0.2, random_state=42)

# Cosine learning-rate schedule with restarts.
# NOTE(review): epoch_size is the length of the full training set, while the
# splitter above holds out 20% of it for validation — confirm this is intended.
# batch_size must be kept in sync with the batch_size given to the network.
LR_schedule = LRScheduler(
    CosineLRWithRestarts,
    batch_size=87,
    epoch_size=len(SDT_training),
    restart_period=10,
    t_mult=1.2,
)

class MyNet(NeuralNetRegressor):
    """Regressor whose loss is computed on the first element of a tuple-valued prediction."""

    def get_loss(self, y_pred, y_true, **kwargs):
        """Return the parent class's loss, ignoring any extra module outputs.

        When ``y_pred`` is a tuple only its first element is scored; the
        remaining outputs are discarded for the purpose of the loss.
        """
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        return super().get_loss(y_pred, y_true, **kwargs)

# The hyperparameter values below are the best assignment found by SigOpt.
net = MyNet(
    CrystalGraphConvNet,
    # --- module (network architecture) ---
    module__orig_atom_fea_len=orig_atom_fea_len,
    module__nbr_fea_len=nbr_fea_len,
    module__classification=False,
    module__atom_fea_len=43,
    module__h_fea_len=114,
    module__n_conv=8,
    module__n_h=3,
    module__use_distance=False,
    module__cutoff=100,
    # --- optimisation ---
    criterion=torch.nn.L1Loss,
    optimizer=AdamW,
    optimizer__weight_decay=1e-2,
    lr=np.exp(-6.465085550816676),
    batch_size=87,
    max_epochs=300,
    # --- training data iterator ---
    iterator_train__pin_memory=True,
    iterator_train__num_workers=0,
    iterator_train__collate_fn=collate_pool,
    iterator_train__shuffle=True,  # VERY IMPORTANT
    # --- validation data iterator ---
    iterator_valid__pin_memory=True,
    iterator_valid__num_workers=0,
    iterator_valid__collate_fn=collate_pool,
    iterator_valid__shuffle=False,  # This should be False, which is the default
    # --- data handling, device and callbacks ---
    dataset=MergeDataset,
    train_split=CVSplit(cv=train_test_splitter),
    device=device,
    callbacks=[cp, load_best_valid_loss, LR_schedule],
)

# NOTE(review): fit() may initialise the net again on its own — confirm whether
# this explicit initialize() call is still required.
net.initialize()

# Train on the training split; the target is the 'cleavage_energy' column as a 2-D array.
net.fit(
    SDT_training,
    np.array(target_training[['cleavage_energy']]),
)

#### Make predictions and visualize the predictions with parity plot

In [None]:
# Collect actual and predicted values for both splits.
# The targets were stored as log(cleavage_energy), so np.exp undoes the log
# on both the stored targets and the network predictions.
# The names must match the ones produced by the train/test split
# (target_training / SDT_training); target_train / SDT_train were never defined.
training_data = {
    'doc_index': list(target_training['doc_index']),
    'type': 'train',
    'actual_value': np.exp(target_training['cleavage_energy']),
    'predicted_value': np.exp(net.predict(SDT_training).reshape(-1)),
}

test_data = {
    'doc_index': list(target_test['doc_index']),
    'type': 'test',
    'actual_value': np.exp(target_test['cleavage_energy']),
    'predicted_value': np.exp(net.predict(SDT_test).reshape(-1)),
}


df_training = pd.DataFrame(training_data)
df_test = pd.DataFrame(test_data)

In [None]:
# Parity plot of predicted vs. actual values for the train and test splits.
# All strings containing '\AA' are raw strings: '\A' is not a valid Python escape
# sequence, so the non-raw form triggers an invalid-escape warning. The resulting
# text is unchanged.
f, ax = plt.subplots(figsize=(8, 8))
ax.scatter(df_training['actual_value'], df_training['predicted_value'], color='yellowgreen',
           marker='o', alpha=0.5, label=r'train: MAE=%0.4f eV/$\AA^2$, RMSE=%0.3f eV/$\AA^2$'
           % (mean_absolute_error(df_training['actual_value'], df_training['predicted_value']),
              np.sqrt(mean_squared_error(df_training['actual_value'], df_training['predicted_value']))))

ax.scatter(df_test['actual_value'], df_test['predicted_value'], color='cornflowerblue',
           marker='o', alpha=0.5, label=r'test: MAE=%0.4f eV/$\AA^2$, RMSE=%0.3f eV/$\AA^2$'
           % (mean_absolute_error(df_test['actual_value'], df_test['predicted_value']),
              np.sqrt(mean_squared_error(df_test['actual_value'], df_test['predicted_value']))))

# Dashed y = x reference line.
# NOTE(review): the upper end is the largest training value minus 0.25 — presumably
# chosen to fit the fixed axis limits below; confirm.
ax.plot([min(df_training['actual_value']), max(df_training['actual_value'])-0.25],
        [min(df_training['actual_value']), max(df_training['actual_value'])-0.25], 'k--')

# format graph
ax.tick_params(labelsize=20)
ax.set_xlabel(r'DFT Energy (eV/$\AA^2$)', fontsize=20)
ax.set_ylabel(r'CGCNN predicted Energy (eV/$\AA^2$)', fontsize=20)
ax.set_xlim(0, 0.35)
ax.set_ylim(0, 0.35)
#ax.set_title('Multi-element ', fontsize=14)
ax.legend(fontsize=15, loc='upper left')

plt.show()

#### Get atomic contributions trajectories

As an example, we pick the test structures whose predictions are reasonably accurate, but you can loop over all test-data indices and write atomic-contribution trajectories for every test structure.

In [None]:
# Positions (within df_test) of the test samples whose absolute prediction error is below 0.00005.
visual_idx = np.where(np.array(abs(df_test['actual_value'] - df_test['predicted_value'])) < 0.00005)[0]

# Create the output directory up front; writing a trajectory into a missing
# directory would otherwise fail.
os.makedirs('./Traj', exist_ok=True)

for idx in visual_idx:
    # doc_index is stored as a float in the frame, so cast it back to an int
    # before using it as a position in docs / SDT_list.
    doc_idx = int(df_test.iloc[idx]['doc_index'])
    # NOTE(review): the second value returned by forward() is treated as the
    # per-atom contributions — confirm against the CrystalGraphConvNet implementation.
    out, atom_fea = net.forward([SDT_list[doc_idx]])
    contributions = atom_fea.cpu().data.numpy().reshape(-1)
    atoms = mongo.make_atoms_from_doc(docs[doc_idx])
    # The exponentiated contributions are stored in the atoms' initial-charges
    # field so that they are written out with the trajectory file.
    atoms.set_initial_charges(np.exp(contributions))
    atoms.write('./Traj/docs_%d.traj' % (doc_idx))