In [1]:
import os
import numpy as np
from e3nn import o3
from minimal_basis.dataset.dataset_hamiltonian import HamiltonianDataset
from minimal_basis.model.model_hamiltonian import (
    EquivariantConv,
    SimpleHamiltonianModel,
)
import warnings
import plotly.express as px
import torch
from torch_geometric.loader import DataLoader

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the model configuration. It names the train/validation JSON files
# and the basis-set file that both datasets are parsed against.
inputs = read_inputs_yaml(
    os.path.join("input_files", "hamiltonian_model.yaml")
)

train_json_filename, validate_json_filename = (
    inputs["train_json"],
    inputs["validate_json"],
)

# Build one HamiltonianDataset per split; only the root directory and the
# JSON filename differ between them.
train_dataset = HamiltonianDataset(
    root=get_train_data_path(),
    filename=train_json_filename,
    basis_file=inputs["basis_file"],
)
validation_dataset = HamiltonianDataset(
    root=get_validation_data_path(),
    filename=validate_json_filename,
    basis_file=inputs["basis_file"],
)

INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information from ./input_files/6-31G_star.json
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information from ./input_files/6-31G_star.json


In [3]:
# One graph per batch, in dataset order. The analysis below pairs loader
# samples with dataset.input_data by index, so order must be preserved.
train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=1)
validation_loader = DataLoader(dataset=validation_dataset, shuffle=False, batch_size=1)

In [15]:
def get_all_ts_structures(dataset, state_label="transition_state"):
    """Yield the Cartesian coordinates of one labelled structure per entry.

    For every entry of ``dataset.input_data``, the structure whose position in
    the parallel ``'state'`` list holds ``state_label`` is selected. Its
    ``cart_coords`` are then yielded.

    Args:
        dataset: Object exposing ``input_data``, an iterable of dicts that
            contain parallel ``'structures'`` and ``'state'`` lists.
        state_label: State to select. Defaults to ``"transition_state"``, so
            existing callers behave as before.

    Yields:
        The ``cart_coords`` attribute of the selected structure, one per entry
        and in ``input_data`` order.

    Raises:
        ValueError: If an entry has no structure labelled ``state_label``.
    """
    for entry_idx, input_data in enumerate(dataset.input_data):
        structures = input_data["structures"]
        states = input_data["state"]
        try:
            state_position = states.index(state_label)
        except ValueError:
            # Re-raise with context; the bare list.index message does not say
            # which entry or which label was missing.
            raise ValueError(
                f"Entry {entry_idx} of input_data has no structure labelled "
                f"{state_label!r}; available states: {states}"
            ) from None
        yield structures[state_position].cart_coords
# Reference transition-state coordinates for each split, in input_data order.
# NOTE(review): the loop below pairs train_loader samples with these by index.
# This assumes HamiltonianDataset keeps input_data order and does not drop
# entries — confirm against the dataset implementation.
train_ts_structures = list(get_all_ts_structures(train_dataset))
validation_ts_structures = list(get_all_ts_structures(validation_dataset))

# Output directory for the per-sample 3D comparison plots. It is created up
# front so the cell also runs when the directory does not exist yet.
PLOT_DIR = os.path.join("plots", "hamiltonian_model")
os.makedirs(PLOT_DIR, exist_ok=True)

# One value per training sample: mean per-atom Euclidean distance (Å)
# between the interpolated and the reference TS geometry.
all_mae_norms = []

for idx, data in enumerate(train_loader):

    interpolated_ts_coords = data.pos_interpolated_TS.detach().numpy()
    real_ts_coords = np.asarray(train_ts_structures[idx])

    # Guard the index-based pairing. A mismatch means loader and input_data
    # are out of sync (or several graphs were merged into one batch).
    if interpolated_ts_coords.shape != real_ts_coords.shape:
        raise ValueError(
            f"Sample {idx}: interpolated TS coordinates have shape "
            f"{interpolated_ts_coords.shape} but the reference TS has shape "
            f"{real_ts_coords.shape}."
        )

    # Per-atom displacement vectors and their Euclidean lengths.
    difference_ts_coords = interpolated_ts_coords - real_ts_coords
    norm_difference_ts_coords = np.linalg.norm(difference_ts_coords, axis=1)
    # Mean of the per-atom displacement norms (reported as "MAE").
    mae = np.mean(norm_difference_ts_coords)
    all_mae_norms.append(mae)

    # Overlay reference and interpolated TS atoms in one 3D scatter. The
    # categorical labels give a discrete legend instead of a 0-1 colour bar.
    structure_labels = (
        ["real TS"] * len(real_ts_coords)
        + ["interpolated TS"] * len(interpolated_ts_coords)
    )
    fig = px.scatter_3d(
        x=np.concatenate((real_ts_coords[:, 0], interpolated_ts_coords[:, 0])),
        y=np.concatenate((real_ts_coords[:, 1], interpolated_ts_coords[:, 1])),
        z=np.concatenate((real_ts_coords[:, 2], interpolated_ts_coords[:, 2])),
        color=structure_labels,
        labels={"color": "structure"},
    )
    # Title each plot with its error so files can be skimmed quickly.
    fig.update_layout(title=f"MAE: {mae:.3f} Å")
    fig.write_html(os.path.join(PLOT_DIR, f"interpolated_ts_{idx}_mae_{mae:.3f}.html"))

# Distribution of the per-sample errors across the training set.
fig = px.histogram(x=all_mae_norms, nbins=20, labels={"x": "MAE (Å)"})
fig.update_layout(title="Histogram of MAE")
fig.show()
