In [1]:
# load all modules
import numpy as np
import pandas as pd

import pickle
import tensorflow as tf
import tensorflow_addons as tfa
import nfp
from nfp_extensions import RBFExpansion, CifPreprocessor

from pymatgen.core.structure import Structure
from tqdm import tqdm

In [2]:
# Load the trained hybrid model and its preprocessor, then build a batched
# tf.data pipeline of features for a small demo subset of crystals.

# Tunable settings (previously inline magic numbers).
NUM_NEIGHBORS = 12   # passed to CifPreprocessor; presumably must match training — confirm
N_DEMO = 5           # number of rows of the CSV to predict
BATCH_SIZE = 32
MAX_SITES = 256      # padding limits forwarded to preprocessor.padded_shapes
MAX_BONDS = 2048

# Initialize the preprocessor and restore its saved state from JSON.
preprocessor = CifPreprocessor(num_neighbors=NUM_NEIGHBORS)
preprocessor.from_json('preprocessor_hybrid.json')

# Load the hybrid model. The nfp layers and the custom RBFExpansion layer are
# passed as custom_objects so Keras can deserialize them.
model = tf.keras.models.load_model(
    'hybrid_model.hdf5',
    custom_objects={**nfp.custom_objects, 'RBFExpansion': RBFExpansion})

# Demo: predict energies of the first N_DEMO hypothetical crystals listed in
# "hypothetical_structure_energies.csv". `.copy()` makes the subset an
# independent frame, so adding the prediction column later writes to this
# frame and not to a slice of the full CSV.
test = pd.read_csv('hypothetical_structure_energies.csv').head(N_DEMO).copy()


# POSCAR files come from unpacking "relaxed_hypothetical_structures.tar.gz".
def poscar_file(structure_id):
    """Return the relative path of the relaxed POSCAR for `structure_id`."""
    return 'relaxed_hypotheticals/POSCAR_{}'.format(structure_id)


def get_crystal(structure_id):
    """Load the POSCAR for `structure_id` as a pymatgen Structure (primitive=True)."""
    return Structure.from_file(poscar_file(structure_id), primitive=True)


def _feature_generator():
    """Yield the preprocessor's feature matrices for each structure, in `test` row order."""
    for structure_id in tqdm(test.structure_id):
        yield preprocessor.construct_feature_matrices(get_crystal(structure_id), train=False)


# Build the dataset lazily. Structures are parsed and featurized only when the
# dataset is iterated (e.g. by model.predict), then padded into batches.
test_dataset = (
    tf.data.Dataset.from_generator(
        _feature_generator,
        output_types=preprocessor.output_types,
        output_shapes=preprocessor.output_shapes)
    .padded_batch(
        batch_size=BATCH_SIZE,
        padded_shapes=preprocessor.padded_shapes(max_sites=MAX_SITES, max_bonds=MAX_BONDS),
        padding_values=preprocessor.padding_values)
)

In [3]:
# Predict energies. Iterating `test_dataset` here is what actually reads the
# POSCARs and builds the features (the tqdm progress bar is emitted by this cell).
predictions = model.predict(test_dataset)

# Keras returns an array per batch-concatenated input (presumably shape
# (n_structures, 1) for this single-output model — confirm). Flatten it
# explicitly instead of relying on pandas to coerce a 2-D array into one
# column, and fail loudly if the count does not match the input rows.
predicted_energies = np.ravel(predictions)
assert len(predicted_energies) == len(test), (
    f"got {len(predicted_energies)} predictions for {len(test)} structures")

# Save predicted energies alongside the input rows.
test['predicted_energyperatom'] = predicted_energies
test.to_csv('predicted_test.csv', index=False)

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.93it/s]


In [4]:
# Display the demo subset. The `energyperatom` column read from the CSV sits
# next to the model's `predicted_energyperatom` for each structure.
test

Unnamed: 0,composition,structure_id,energyperatom,predicted_energyperatom
0,K1Zn1Bi1,KZnBi_NaBeSb_194,-2.36629,-2.363573
1,K1Zn1Bi1,KZnBi_NaBeAs_194,-2.366068,-2.36163
2,K1Zn1Bi1,KZnBi_LiBeSb_186,-2.365981,-2.362384
3,K1Zn1Bi1,KZnBi_EuPPt_164,-2.365732,-2.364271
4,K1Zn1Bi1,KZnBi_KZnSb_194_2,-2.358296,-2.358582
