In [1]:
import glob
import torch
import torch_geometric
import torch_scatter

import e3nn
from e3nn import rs, o3
from e3nn.point.data_helpers import DataPeriodicNeighbors
from e3nn.networks import GatedConvParityNetwork
from e3nn.kernel_mod import Kernel
from e3nn.point.message_passing import Convolution

import pymatgen
from pymatgen.core.structure import Structure

import time, os
import datetime
import pickle
from mendeleev import element
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import h5py
from class_evaluate_MPdata import ComprehensiveEvaluation, AtomEmbeddingAndSumLastLayer

# Tensors created without an explicit dtype (here and inside the e3nn / evaluation
# helpers) default to float64.
torch.set_default_dtype(torch.float64)

# Fall back to CPU so the notebook still runs (more slowly) on machines without CUDA;
# `device` is used below for torch.load(map_location=...), model.to() and predict_phdos().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
# Trained checkpoint; the cells below read its 'state' (weights) and 'model_kwargs' entries.
# torch.load unpickles arbitrary objects, so only load checkpoints from a trusted source.
model_checkpoint_path = 'models/200803-1018_len51max1000_fwin101ord3_trial_run_full_data.torch'
data = torch.load(model_checkpoint_path, map_location=device)

In [3]:
# nn.Linear stores its weight as (out_features, in_features), so unpack in that order
# to recover the embedding layer's input/output sizes from the checkpoint.
embed_out_dim, embed_in_dim = data['state']['linear.weight'].shape
model = AtomEmbeddingAndSumLastLayer(
    embed_in_dim,
    embed_out_dim,
    GatedConvParityNetwork(**data['model_kwargs']),
)

In [4]:
# Restore the trained weights, move the module to `device`, and switch it to inference
# mode. Module.to() and Module.eval() both return the module itself, so the final
# expression displays the model architecture.
model.load_state_dict(data['state'])
model.to(device).eval()

AtomEmbeddingAndSumLastLayer(
  (linear): Linear(in_features=118, out_features=64, bias=True)
  (model): GatedConvParityNetwork(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Convolution(
          (kernel): Kernel (64x0e -> 32x0e,32x0e,32x1o)
        )
        (1): GatedBlockParity (32x0e + 32x0e + 32x1o -> 32x0e,32x1o)
      )
      (1): ModuleList(
        (0): Convolution(
          (kernel): Kernel (32x0e,32x1o -> 32x0e,64x0e,32x1e,32x1o)
        )
        (1): GatedBlockParity (32x0e + 64x0e + 32x1e,32x1o -> 32x0e,32x1e,32x1o)
      )
      (2): Convolution(
        (kernel): Kernel (32x0e,32x1e,32x1o -> 51x0e)
      )
    )
  )
  (relu): ReLU()
)

In [5]:
# Phonon-DOS dataset bundle; only its 'phfre' entry (the frequency grid the predicted
# DOS is sampled on) is used below. pickle.load can execute code — trusted files only.
phdos_bundle_path = 'models/phdos_e3nn_len51max1000_fwin101ord3.pkl'
with open(phdos_bundle_path, 'rb') as fh:
    data_dict = pickle.load(fh)
phfre = data_dict['phfre']

In [6]:
# Only structures with at most this many sites are evaluated
# (matches "maxsites13" in the output file names below).
MAX_SITES = 13

# CIF catalogue: parallel lists of file names, Materials Project ids and site counts.
with open('models/cif_unique_files.pkl', 'rb') as f:
    ciflist_dict = pickle.load(f)

# Index directly so a missing key fails here with a clear KeyError instead of a
# TypeError on None further down.
cif_name = ciflist_dict['cif_name']
cif_id = ciflist_dict['cif_id']
num_sites = ciflist_dict['num_sites']
assert len(cif_name) == len(cif_id) == len(num_sites), \
    "cif_name, cif_id and num_sites must be aligned lists"

# Keep names and ids in lockstep with the site-count filter.
cif_name_suc = [name for name, n in zip(cif_name, num_sites) if n <= MAX_SITES]
cif_id_suc = [cid for cid, n in zip(cif_id, num_sites) if n <= MAX_SITES]

In [7]:
# Number of materials that pass the site-count filter (displayed as cell output).
n_selected = len(cif_id_suc)
n_selected

4348

In [16]:
# Accumulates, per evaluated material, the structures returned by ComprehensiveEvaluation.
structures = []
# Temperatures in K (0 degC and 20 degC) at which heat capacity is evaluated.
T_lst = [273.15, 293.15]
# HDF5 file the prediction loop creates on first use and appends to afterwards.
h5_file = 'phdos_maxsites13_2020Aug27.h5'
# Indices into cif_id_suc / cif_name_suc that the prediction loop skips.
# NOTE(review): the reason these two are excluded is not recorded — document it.
skip_ids = [4319, 4334]

In [17]:
# Predict the phonon DOS and heat capacity for every filtered material, and stream the
# results into `h5_file` one material per iteration. The file is reopened on each
# iteration rather than held open, presumably so that rows written so far survive a
# crash mid-run.
#
# NOTE(review): this cell is not idempotent. If `h5_file` already exists (e.g. from a
# previous run), rows are appended again, and `structures` keeps growing unless the
# cell that initialises it is re-run first.


def _h5_create(hf, name, values, maxshape=(None, None)):
    """Create a gzip-compressed, chunked dataset `name` in the open file `hf`.

    The dataset is resizable up to `maxshape` (None = unlimited along that axis).
    """
    hf.create_dataset(name, data=values, compression="gzip", compression_opts=9,
                      chunks=True, maxshape=maxshape)


def _h5_append(hf, name, values):
    """Append `values` along axis 0 of the resizable dataset `name` in `hf`."""
    values = np.asarray(values)
    n_new = values.shape[0]
    if n_new == 0:
        # `dset[-0:]` would address the whole dataset, so skip empty batches explicitly.
        return
    dset = hf[name]
    dset.resize(dset.shape[0] + n_new, axis=0)
    dset[-n_new:] = values


for i in [x for x in range(len(cif_id_suc)) if x not in skip_ids]:
    material_id = cif_id_suc[i]
    chunk_evaluation = ComprehensiveEvaluation([cif_name_suc[i]], data['model_kwargs'],
                                               cif_path='../data/', chunk_id=material_id)
    chunk_evaluation.predict_phdos(chunk_evaluation.data, model, device=device)
    chunk_evaluation.cal_heatcap(chunk_evaluation.phdos, phfre.tolist(), T_lst,
                                 chunk_evaluation.structures)
    structures.append(chunk_evaluation.structures)

    # Rows for this material, in dataset-creation order. Conversion happens before the
    # file is opened, so a failure here cannot leave behind an empty new file.
    # material_id is stored as a (1, 1) array so it shares the 2-D (None, None) maxshape.
    new_rows = {
        "material_id": np.array([material_id])[None, :],
        "phdos_max1": np.array(chunk_evaluation.phdos),
        "phdos_norm": np.array(chunk_evaluation.phdos_norm),
        "heat_cap_mol": np.array(chunk_evaluation.C_v_mol),
        "heat_cap_kg": np.array(chunk_evaluation.C_v_kg),
    }

    if os.path.exists(h5_file):
        with h5py.File(h5_file, 'a') as hf:
            for dset_name, values in new_rows.items():
                _h5_append(hf, dset_name, values)

            print("{}   Calculating mp-{}          ".format(i, cif_id_suc[i]), end="\r", flush=True)
    else:
        with h5py.File(h5_file, 'w') as hf:
            for dset_name, values in new_rows.items():
                _h5_create(hf, dset_name, values)
            # Shared frequency grid, written once. `(None,)` is a 1-tuple; a bare
            # `(None)` is just `None`, which would make the dataset fixed-size.
            _h5_create(hf, "phfre", phfre, maxshape=(None,))
            print("Created new h5py data\n")


Created new h5py data

4347   Calculating mp-1205455          

In [20]:
# Persist the per-material structures next to the HDF5 output. The file name reuses the
# HDF5 base name ('phdos_maxsites13_2020Aug27.h5' -> '...Aug27.pickle'), so the two
# outputs cannot drift apart if `h5_file` is changed. `pickle` is imported in the
# top imports cell.
structures_pickle_path = os.path.splitext(h5_file)[0] + '.pickle'
with open(structures_pickle_path, "wb") as output_file:
    pickle.dump(structures, output_file)

In [16]:
# Read the predictions back. Prefer a copy under predictions/ if one exists; otherwise
# use the file the prediction loop writes in the working directory.
h5_read_path = os.path.join('predictions', h5_file)
if not os.path.exists(h5_read_path):
    h5_read_path = h5_file

with h5py.File(h5_read_path, 'r') as hf:
    material_id_r = hf['material_id'][:]
    phdos_max1_r = hf['phdos_max1'][:]
    phdos_norm_r = hf['phdos_norm'][:]
    # The writer stores two heat-capacity datasets; there is no 'heat_cap' key.
    heat_cap_mol_r = hf['heat_cap_mol'][:]
    heat_cap_kg_r = hf['heat_cap_kg'][:]
    phfre_r = hf['phfre'][:]

In [23]:
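# NOTE: legacy chunked pipeline (pickles to predictions/), apparently superseded by the per-material HDF5 loop above.
# It uses chunk_evaluation.C_v and cif_path='data/', whereas the loop above uses C_v_mol/C_v_kg and '../data/'.
# chunk_size = 50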
# cif_name_chunks = [cif_name_suc[i:i+chunk_size] for i in range(0,len(cif_name_suc),chunk_size)]
# cif_id_chunks = [cif_id_suc[i:i+chunk_size] for i in range(0,len(cif_id_suc),chunk_size)]

# for chunk_id in range(253,len(cif_name_chunks)):
# #     print("Calculating chunk: {:3d}\n".format(chunk_id), end="\r", flush=True)
#     chunk_evaluation = ComprehensiveEvaluation(cif_name_chunks[chunk_id], data['model_kwargs'], cif_path='data/', chunk_id=chunk_id)
#     chunk_evaluation.predict_phdos(chunk_evaluation.data,model,device=device)
#     # T_lst = np.linspace(5,800,160,endpoint=True)
#     T_lst = [273.15, 293.15]
#     chunk_evaluation.cal_heatcap(chunk_evaluation.phdos,phfre.tolist(),T_lst,chunk_evaluation.structures)

#     with open('predictions/max50sites_chunk_{:03}.pkl'.format(chunk_id), 'wb') as f:
#         pickle.dump({'material_id': cif_id_chunks[chunk_id],
#                      'phdos_max1': chunk_evaluation.phdos,
#                      'phdos_norm': chunk_evaluation.phdos_norm,
#                      'T': T_lst,
#                      'heat_capacity': chunk_evaluation.C_v}, f)

In [None]:
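# NOTE: plots for the legacy chunked pipeline; needs chunk_size, chunk_id and cif_id_chunks from the commented-out cell above.
# idx = np.arange(chunk_size)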
# fig = plt.figure(figsize=(18, 10), constrained_layout=True)
# outer = gridspec.GridSpec(10, 5, wspace=0.25, hspace=2)
# for i in range(chunk_size):
#     inner = gridspec.GridSpecFromSubplotSpec(1, 2, wspace=0.5, hspace=0,
#                     subplot_spec=outer[i])

#     ax0 = plt.Subplot(fig, inner[0])
#     ax0.plot(phfre,chunk_evaluation.phdos[i],c='C0')
# #     ax0.set_xlabel('Frequency [1/cm]')
# #     ax0.set_ylabel('Phonon DOS [a.u.]')
#     ax0.set_title('mp-{}: '.format(cif_id_chunks[chunk_id][i])+str(chunk_evaluation.cif_strlist[i][1][5:]))
#     fig.add_subplot(ax0)

#     ax1 = plt.Subplot(fig, inner[1])
#     ax1.plot(T_lst,chunk_evaluation.C_v[i],c='C1')
# #     ax1.set_xlabel('T [K]')
# #     ax1.set_ylabel('Heat Capacity [J/(mol K)]')
#     ax1.set_title('Num Sites:{:3d}'.format(chunk_evaluation.structures[i].num_sites))
#     fig.add_subplot(ax1)
# fig.suptitle('CIF Chunk: {}'.format(chunk_id), horizontalalignment='center')
# # fig.show()
# fig.savefig('figs/phdos_Cv_{:03}.png'.format(chunk_id), dpi=300, bbox_inches='tight')
# plt.close(fig)