A routine to test OBIWAN on the COMP6v2 benchmark suite.

In [None]:
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from data import datasetting
from architectures.net_utils import getModel

In [None]:
# Weights to evaluate, and the architecture name handed to getModel.
checkpoint_path, model_name = "results/models/obiwan_ani1Uani2_FH_VL_2.404", "obiwan"

# COMP6v2 dataset overview (subset sizes and largest molecule)

In [None]:
def cardinalityAndMaxSizeMeter(dataset):
    """Count the samples of *dataset* and find its largest molecule.

    Args:
        dataset: iterable yielding ``((coords, species), energy)`` items. The
            molecule size of an item is taken as ``species.shape[0]``.

    Returns:
        Tuple ``(cardinality, max_size)``: the number of items yielded and the
        largest ``species.shape[0]`` seen (``0`` when *dataset* is empty).
    """
    n_samples = 0
    largest = 0

    for (_coords, species), _energy in dataset:
        largest = max(largest, species.shape[0])
        n_samples += 1

    return n_samples, largest

In [None]:
# glob.glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# subset order (DataFrame columns, errors_dict / CSV order) is reproducible.
ds_path_list = sorted(glob.glob("data/datasets/comp6v2/*.h5"))
if not ds_path_list:
    # Fail here with a clear message instead of an opaque max() error further down.
    raise FileNotFoundError("No COMP6v2 .h5 files found in data/datasets/comp6v2/")

# One COMP6v2Yielder per subset file, in the same order as ds_path_list.
ds_generator_list = [datasetting.COMP6v2Yielder(ds_path) for ds_path in ds_path_list]

In [None]:
# (cardinality, max_size) for every subset, aligned with ds_path_list.
ds_properties_list = list(map(cardinalityAndMaxSizeMeter, ds_generator_list))

In [None]:
# Key each subset by the first two dash-separated fields of its file name (the text
# before the first "."). os.path.basename is used instead of split("/") so the key is
# also correct on Windows, where glob returns backslash-separated paths.
ds_properties_dict = {
    "-".join(os.path.basename(ds_path).split(".")[0].split("-")[:2]): {
        "cardinality": cardinality,
        "max_size": max_size,
    }
    for ds_path, (cardinality, max_size) in zip(ds_path_list, ds_properties_list)
}

In [None]:
# Table with one column per subset and rows 'cardinality' and 'max_size'.
df = pd.DataFrame.from_dict(ds_properties_dict)
df

In [None]:
# Largest molecule over all subsets; passed to getModel as max_molecule_size.
# Iterate the values directly instead of looking each key back up.
max_molecule_size = max(props["max_size"] for props in ds_properties_dict.values())
max_molecule_size

# The tests

In [None]:
# Build the network sized for the largest COMP6v2 molecule, then load the trained weights.
model = getModel(max_molecule_size=max_molecule_size, model_name=model_name)
model.loadWeights(checkpoint_path)

In [None]:
# Energy RMSE (kcal/mol) of the loaded model on every subset, keyed by subset name.
errors_dict = {}

for ds_path in ds_path_list:

    # kcal/mol RMSE; print_error=True asks the helper to print it as well
    rmse = datasetting.OneTestOnCOMP6v2(ds_path=ds_path, model=model, print_error=True)

    # Same naming rule as ds_properties_dict. os.path.basename keeps the name correct
    # on Windows, where split("/") would leave the directory prefix in the key.
    ds_name = "-".join(os.path.basename(ds_path).split(".")[0].split("-")[:2])
    errors_dict[ds_name] = rmse

In [None]:
# Record this checkpoint's per-subset errors in results/comp6v2.csv, labelled by the
# checkpoint's file name.
datasetting.addRowToCSV(
    df_path="results/comp6v2.csv",
    errors_dict=errors_dict,
    checkpoint_name=checkpoint_path.rsplit("/", 1)[-1],
)

# Visualise the results

In [None]:
# Benchmark table of all recorded models (one row each, one column per subset).
benchmark_df = pd.read_csv(filepath_or_buffer="results/comp6v2.csv")

In [None]:
# Row-wise mean RMSE over the subset columns, leaving out the 'model' label column and
# the ANI-BenchMD subset. NOTE(review): the reason ANI-BenchMD is excluded is not stated here.
benchmark_df['mean'] = benchmark_df.drop(columns=['model', 'ANI-BenchMD']).mean(axis='columns')

In [None]:
# Rows of the two models compared in the plot (the same names are hard-coded in the
# cells below), ordered by mean RMSE. Selecting them by name rather than taking the two
# best rows keeps the later .loc lookups valid when other models score better. If a model
# has several rows, only its best (lowest-mean) row is kept.
# NOTE(review): checkpoint_path's file name is 'obiwan_ani1Uani2_FH_VL_2.404', not
# 'obiroi_sepnorm_ani1Uani2_FH_VL_2.404'; confirm which row is meant to be plotted.
compared_models = ['ANI-2x[0]', 'obiroi_sepnorm_ani1Uani2_FH_VL_2.404']
final_df = (
    benchmark_df[benchmark_df['model'].isin(compared_models)]
    .sort_values(by=['mean'])
    .drop_duplicates(subset='model', keep='first')
)
final_df

In [None]:
# Index rows by model name and drop the ANI-BenchMD subset from the comparison.
final_df = final_df.drop(columns=['ANI-BenchMD']).set_index('model')

In [None]:
# Mean RMSE of each compared model, formatted into the bar-chart legend labels.
anakin_mean = final_df.loc['ANI-2x[0]', 'mean']
obiwan_mean = final_df.loc['obiroi_sepnorm_ani1Uani2_FH_VL_2.404', 'mean']

anakin_legend = f'anakin (mean = {anakin_mean:.2f})'
obiwan_legend = f'obiwan (mean = {obiwan_mean:.2f})'

In [None]:
# Keep only the per-subset RMSE columns for plotting.
final_df = final_df.drop(columns=['mean'])

In [None]:
# Bar heights: each model's per-subset RMSE, in final_df column order.
anakin_bars, obiwan_bars = (
    final_df.loc[row_label].to_numpy()
    for row_label in ('ANI-2x[0]', 'obiroi_sepnorm_ani1Uani2_FH_VL_2.404')
)

In [None]:
# Subset names (the DataFrame's column labels) used as x-axis tick labels.
x_ticks = final_df.keys()

In [None]:
# Grouped bar chart of per-subset energy RMSE: anakin (red, left) vs obiwan (blue, right).
plt.rc('font', size=20)

x_axis = np.arange(len(x_ticks))

# (x offset within each group, bar heights, legend label, colour)
bar_series = [
    (-0.2, anakin_bars, anakin_legend, 'red'),
    (0.2, obiwan_bars, obiwan_legend, 'blue'),
]

# zorder=2 keeps the bars above the grid, which is drawn at zorder=1
for shift, heights, legend_label, colour in bar_series:
    plt.bar(x=x_axis + shift, height=heights, width=0.4, label=legend_label, color=colour, zorder=2)

# write each bar's RMSE, rounded to one decimal place, just above the bar top
for shift, heights, _, _ in bar_series:
    for position, value in enumerate(heights):
        plt.text(position + shift, value + 0.05, f'{value:.1f}', ha='center', va='bottom', fontsize=12)

plt.xticks(x_axis, x_ticks, rotation=60, ha='right', rotation_mode='anchor')
plt.xlabel('COMP6v2 subset')
plt.ylabel('RMSE [kcal/mol]')
plt.title('Energy errors on the COMP6v2 test set')
plt.legend()
plt.grid(color='0.9', zorder=1)

fig = plt.gcf()
fig.set_size_inches(20, 8)

In [None]:
# Logged metrics of the OBIWAN run, stored under the 'obiwan_ani' key. JSON is UTF-8 by
# specification, so the encoding is passed explicitly instead of relying on the
# platform's locale default.
with open("results/logs/comp6_final_test/obiwan_ani.json", "r", encoding="utf-8") as f:
    obiwan_on_anis = json.load(f)
obiwan_on_anis = obiwan_on_anis['obiwan_ani']

In [None]:
# Energy and force RMSE histories, for the training and validation splits.
energy_rmse, val_energy_rmse, forces_rmse, val_forces_rmse = (
    obiwan_on_anis[metric]
    for metric in ('energyRMSE', 'val_energyRMSE', 'forcesRMSE', 'val_forcesRMSE')
)

In [None]:
# reset all matplotlib parameters to their default values (undoes the larger font
# size set for the bar chart)
plt.rcdefaults()

plt.rc('font', size=14)

# Start a new figure explicitly. Outside Jupyter the bar chart stays the current
# figure, so plt.gcf() alone would draw these curves on top of it.
fig = plt.figure()

plt.plot(energy_rmse, color='cyan', label='training')
plt.plot(val_energy_rmse, color='blue', label='validation')

plt.legend()
plt.ylim(0., 10.)
plt.xlabel('Epoch')
plt.ylabel('RMSE [kcal/mol]')
plt.title('OBIWAN learning curves for energy prediction on the complete ANAKIN dataset', fontsize=13)

# dash-dot reference line at 1 kcal/mol, labelled "chemical accuracy"
plt.axhline(y=1, color='black', linestyle='-.')
plt.text(0, 1.1, 'chemical accuracy', color = 'black')

plt.yticks(range(11))
plt.grid(color='0.9')

fig.set_size_inches(10, 8)

plt.savefig("results/ani1U2_energy_error.png", dpi=300)