# Calculate moduli from elastic stiffness tensors

**Instruction**  
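The aim of this notebook is to calculate elastic moduli from elastic stiffness tensors and to compare the calculated values with experimental ones.  
Running the code below requires the dataset `../dataset/stiffness_tensors_rev.csv` and the local `utility` module.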

**Contents**
1. Read and split dataset for each method
2. Visualize scatter plot

## 1. Read and split dataset for each method

In [None]:
import pandas as pd
import numpy as np
from utility import *
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error
# Global matplotlib defaults used by every figure in this notebook.
plt.rcParams.update({
    "font.family": 'Arial',
    "font.size": 14,
})

In [None]:
# Load the stiffness-tensor dataset (path is relative to the working directory).
df = pd.read_csv('../dataset/stiffness_tensors_rev.csv')
# Sanity check: print (rows, columns) and display the first records.
print(df.shape)
df.head()

In [None]:
# Names of the quantities taken from the result of stiffnesstensor2modulus,
# in the order they are read from it (res[0] ... res[13]).
moduli = ['E_V', 'E_R', 'E_H', 'E_RH', 'K_V', 'K_R', 'K_H', 'K_RH', 'G_V', 'G_R', 'G_H', 'G_RH', 'n_H', 'A_L']

# Columns 3..38 of every row are reshaped into one 6x6 stiffness matrix.
# Casting to float here hands the helper a numeric array rather than an
# object array taken from a mixed-dtype row.
tensors = df.iloc[:, 3:39].to_numpy(dtype=float).reshape(-1, 6, 6)

rows = []
for tensor in tensors:
    res = stiffnesstensor2modulus(tensor)
    rows.append([res[j] for j in range(len(moduli))])

# Store the results as float columns. Initialising the columns with '' made
# them object/string columns, and writing floats into such a column cell by
# cell fails when pandas infers a dedicated string dtype.
# reshape keeps the array two-dimensional even when df has no rows.
values = np.array(rows, dtype=float).reshape(-1, len(moduli))
for j, item in enumerate(moduli):
    df[item] = values[:, j]

# Drop every row that contains a missing value in any column, then renumber.
df = df.dropna()
df = df.reset_index(drop=True)

In [None]:
# One sub-table per data source, selected by a substring match on 'Method'.
df_exp, df_hf, df_dft, df_nnp = (
    df[df['Method'].str.contains(tag)]
    for tag in ('Exp', 'S-HF-3c', 'DFT', 'NNP')
)

## 2. Visualize scatter plot

In [None]:
def plotdata(df_exp, df_calc, prop):
    """Pair experimental and calculated values of one property by compound.

    Every entry of ``df_calc['Compound']`` that also occurs in
    ``df_exp['Compound']`` contributes one data point, in the row order of
    ``df_calc``.

    Parameters
    ----------
    df_exp : pandas.DataFrame
        Experimental data with a 'Compound' column and a ``prop`` column.
    df_calc : pandas.DataFrame
        Calculated data with the same two columns.
    prop : str
        Name of the column to compare.

    Returns
    -------
    y_exp : list
        Mean of all experimental ``prop`` values of each matched compound.
    y_calc : list
        First calculated ``prop`` value of each matched compound.
    names : list
        The matched compound names.
    """
    exp_names = df_exp['Compound']
    calc_names = df_calc['Compound']
    # Build the lookup once; the membership test used to rebuild a list of
    # all experimental names for every calculated row.
    known = set(exp_names)

    y_exp = []
    y_calc = []
    names = []
    for name in calc_names:
        if name not in known:
            continue
        y_calc.append(df_calc.loc[calc_names == name, prop].values[0])
        # Several experimental records may exist per compound: average them.
        y_exp.append(df_exp.loc[exp_names == name, prop].values.mean())
        names.append(name)

    return y_exp, y_calc, names

In [None]:
# Pair the experimental and NNP values of E_RH by compound name.
y_exp, y_calc, names = plotdata(df_exp, df_nnp, 'E_RH')

In [None]:
# Mean absolute error between the paired experimental and calculated values.
mean_absolute_error(y_exp, y_calc)

In [None]:
def _modulus_label(item):
    """Return (mathtext symbol, unit suffix) for a modulus column name."""
    if any(key in item for key in ('E', 'K', 'G')):
        # Brace the two-letter subscript so mathtext does not subscript only 'R'.
        symbol = item[0] + '_{RH}' if 'RH' in item else item
        return f'${symbol}$', ' (GPa)'
    if 'n_H' in item:
        return '$\u03BD$', ''
    return '$A$', ''


def _method_color(figname, default='gray'):
    """Pick the marker colour from a method tag contained in ``figname``."""
    if figname is not None:
        for tag, method_color in (('NNP', 'skyblue'), ('DFT', 'green'), ('HF', 'orange')):
            if tag in str(figname):
                return method_color
    return default


def visuzalize(df_exp, df_calc, text=False, figname=None, mode='normal', color=None):
    """Plot calculated against experimental values for every entry of ``moduli``.

    One subplot is drawn per modulus on a 4x4 grid.

    Parameters
    ----------
    df_exp, df_calc : pandas.DataFrame
        Experimental and calculated tables, paired per compound by ``plotdata``.
    text : bool
        If true, write the compound name next to each point ('normal' mode).
    figname : str or None
        Output file name. When given, the figure is saved at 300 dpi, and a
        'NNP', 'DFT' or 'HF' tag inside the name selects the marker colour.
    mode : {'normal', 'error'}
        'normal' draws a parity plot annotated with the MAE; 'error' draws
        (calculated - experimental) against the experimental value.
    color : str or None
        Explicit marker colour. Defaults to the colour derived from
        ``figname``, or gray when no method tag is found.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'normal' nor 'error'.
    """
    if mode not in ('normal', 'error'):
        raise ValueError(f"mode must be 'normal' or 'error', got {mode!r}")
    if color is None:
        # Previously the colour was left undefined when figname was None or
        # contained no method tag, which crashed at the first scatter call.
        color = _method_color(figname)

    fig = plt.figure(figsize=(12,12))
    for i, item in enumerate(moduli):
        ax = fig.add_subplot(4, 4, i+1)
        y_exp, y_calc, names = plotdata(df_exp, df_calc, item)
        symbol, unit = _modulus_label(item)
        if not y_exp:
            # No compound in common: leave the axes empty instead of failing
            # on min()/max() of an empty list.
            ax.set(title=symbol)
            continue

        exp_range = [min(y_exp), max(y_exp)]
        if mode == 'normal':
            ax.scatter(y_exp, y_calc, c=color, ec='k', s=80)
            ax.plot(exp_range, exp_range, c='k', linestyle='dashed')
            ax.set(title=symbol, xlabel=f'Exp. {symbol}{unit}', ylabel=f'Calc. {symbol}{unit}')

            # Keep the MAE string in its own variable: reusing `text` for it
            # meant the compound labels below were never drawn.
            mae_label = f'MAE: {mean_absolute_error(y_exp, y_calc):.2f}'
            ax.text(0.1, 0.95, mae_label, transform=ax.transAxes, horizontalalignment='left', verticalalignment='top')
            if text:
                for x, y, name in zip(y_exp, y_calc, names):
                    ax.text(x, y, name)
        else:
            error = np.array(y_calc) - np.array(y_exp)
            ax.scatter(y_exp, error, c=color, ec='k', s=80)
            ax.plot(exp_range, [0, 0], c='k', linestyle='dashed')
            # The error carries the unit of the quantity itself, so the
            # dimensionless ones get no '(GPa)' suffix.
            ax.set(title=symbol, xlabel=f'Exp. {symbol}{unit}', ylabel=f'Error{unit}')
    fig.tight_layout()
    if figname is not None:
        fig.savefig(figname, dpi=300)
    # plt.show() also works with non-GUI/inline backends, where fig.show()
    # only emits a warning.
    plt.show()

In [None]:
# Draw one 4x4 grid of parity plots per calculation method.
# 'SAVE_FIG_NAME' is a placeholder: replace it with the real output file name.
# visuzalize picks the marker colour from the method tag ('NNP', 'DFT' or
# 'HF') contained in figname, so include that tag in the name.

# NNP
visuzalize(df_exp, df_nnp, figname='SAVE_FIG_NAME', mode='normal')
# visuzalize(df_exp, df_nnp, figname='SAVE_FIG_NAME', mode='error')

# HF
visuzalize(df_exp, df_hf, figname='SAVE_FIG_NAME', mode='normal')
# visuzalize(df_exp, df_hf, figname='SAVE_FIG_NAME', mode='error')

# DFT
visuzalize(df_exp, df_dft, figname='SAVE_FIG_NAME', mode='normal')
# visuzalize(df_exp, df_dft, figname='SAVE_FIG_NAME', mode='error')

In [None]:
# Save the table including the computed moduli columns to the working
# directory. The DataFrame index is written as well (index=False is not set).
df.to_csv('./modulus.csv')

In [None]:
# Print the MAE of the NNP (PFP) values against experiment for every column
# from position 39 onwards (the moduli columns).
for col in df_exp.iloc[:, 39:].columns:
    # mean model
    # mae = mean_absolute_error(np.ones(df_exp.shape[0])*df_exp[col].mean(), df_exp[col])

    # PFP
    # Pair the two tables by compound first: df_nnp and df_exp are different
    # row subsets, so comparing their columns row by row is not guaranteed to
    # compare the same compound (or even the same number of rows).
    exp_vals, calc_vals, _names = plotdata(df_exp, df_nnp, col)
    mae = mean_absolute_error(exp_vals, calc_vals)

    print(col, mae)

In [None]:
# Parity plot of E_RH: all three calculation methods against experiment.
fig = plt.figure(figsize=(6,4.5))
ax = fig.add_subplot(111)

# Drawing order (HF, PFP, DFT) decides which markers end up on top.
for df_method, face, tag in (
    (df_hf, 'orange', 'S-HF-3c'),
    (df_nnp, 'skyblue', 'PFP'),
    (df_dft, 'green', 'DFT'),
):
    y_exp, y_calc, names = plotdata(df_exp, df_method, 'E_RH')
    ax.scatter(y_exp, y_calc, c=face, ec='k', s=60, label=tag)

# Dashed y = x reference line.
ax.plot([0, 45], [0, 45], linestyle='dashed', c='k')
ax.legend()
ax.set(
    xlabel='Exp. $E_{RH}$ (GPa)',
    ylabel='Pred. $E_{RH}$ (GPa)',
    xlim=(0,45),
    ylim=(0,100),
)
plt.tight_layout()
plt.savefig('SAVE_FIG_NAME.png', dpi=300)

In [None]:
# Error plot of E_RH: (calculated - experimental) against the experimental
# value for all three calculation methods.
fig = plt.figure(figsize=(6,3))
ax = fig.add_subplot(111)

for df_method, face, tag in (
    (df_hf, 'orange', 'S-HF-3c'),
    (df_nnp, 'skyblue', 'PFP'),
    (df_dft, 'green', 'DFT'),
):
    y_exp, y_calc, names = plotdata(df_exp, df_method, 'E_RH')
    error = np.array(y_calc)-np.array(y_exp)
    ax.scatter(y_exp, error, c=face, ec='k', s=60, label=tag)

# Dashed zero-error reference line; no legend is drawn on this figure.
ax.plot([0, 45], [0, 0], linestyle='dashed', c='k')
ax.set(xlabel='Exp. $E_{RH}$ (GPa)', ylabel='Error (GPa)', xlim=(0,45))
plt.tight_layout()
plt.savefig('SAVE_FIG_NAME.png', dpi=300)