In [None]:
import os

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.dataset_reaction import ReactionDataset
from minimal_basis.transforms.absolute import Absolute

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units
from ase.data import atomic_numbers, atomic_names

import matplotlib.pyplot as plt
# Higher-resolution matplotlib figures. NOTE(review): the visible cells plot with
# plotly, so this setting only affects matplotlib output, if any.
plt.rcParams['figure.dpi'] = 200

# Silence every FutureWarning and UserWarning for the rest of the kernel session.
# This keeps cell output short but also hides deprecation notices from dependencies.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# Start a Weights & Biases run with default settings (no project/config passed here).
# NOTE(review): `run` is not referenced in the visible cells; confirm the run is needed.
import wandb
run = wandb.init()

In [None]:
# Read the experiment configuration and build the train / validation datasets.
inputs = read_inputs_yaml(os.path.join("config", "interp_sn2_model.yaml"))

train_json_filename = inputs["train_json"]
validate_json_filename = inputs["validate_json"]

# Shared dataset keyword arguments: the config's `dataset_options` section plus the
# top-level `use_minimal_basis_node_features` flag. This aliases (and therefore
# updates) `inputs["dataset_options"]` in place.
kwargs_dataset = inputs["dataset_options"]
kwargs_dataset["use_minimal_basis_node_features"] = inputs["use_minimal_basis_node_features"]

# The two splits differ only in their data root and JSON file. Root getters are
# called lazily so each path is resolved right before its dataset is built.
train_dataset, validation_dataset = (
    ReactionDataset(
        root=root_path_getter(),
        filename=json_filename,
        basis_filename=inputs["basis_file"],
        **kwargs_dataset
    )
    for root_path_getter, json_filename in (
        (get_train_data_path, train_json_filename),
        (get_validation_data_path, validate_json_filename),
    )
)

In [None]:
# One reaction per batch; both loaders reshuffle their dataset on every pass.
train_loader, validation_loader = (
    DataLoader(split_dataset, batch_size=1, shuffle=True)
    for split_dataset in (train_dataset, validation_dataset)
)

In [None]:
# Compare two transition-state (TS) geometry guesses against the reference TS:
#   1. the dataset's interpolated TS (`pos_interpolated_transition_state`), and
#   2. a linear-interpolation baseline: the midpoint of `pos` and `pos_final_state`.
# Per reaction, the reported "MAE" is the mean over atoms of the Euclidean distance
# between guessed and reference atomic positions.
all_mae_norms = []
all_mae_norms_linear_interp = []

for data in train_loader:
    # Convert every tensor to NumPy up front so the arithmetic below never mixes
    # torch tensors with NumPy arrays.
    interpolated_ts_coords = data.pos_interpolated_transition_state.detach().numpy()
    real_ts_coords = data.pos_transition_state.detach().numpy()
    initial_coords = data.pos.detach().numpy()
    final_coords = data.pos_final_state.detach().numpy()
    linear_interp_coords = (initial_coords + final_coords) / 2

    # Per-atom displacement norms. axis=1 assumes positions are (n_atoms, 3) — TODO confirm.
    norm_difference_ts_coords = np.linalg.norm(
        interpolated_ts_coords - real_ts_coords, axis=1
    )
    norm_difference_linear_interp_coords = np.linalg.norm(
        linear_interp_coords - real_ts_coords, axis=1
    )

    # train_loader uses batch_size=1, so each batch (and each mean) is one reaction.
    all_mae_norms.append(np.mean(norm_difference_ts_coords))
    all_mae_norms_linear_interp.append(np.mean(norm_difference_linear_interp_coords))

# Overlay both error distributions (shared bins via a single px trace group) so the
# interpolated TS can be judged against the linear-interpolation baseline.
fig = px.histogram(
    x=np.concatenate([all_mae_norms, all_mae_norms_linear_interp]),
    color=["Interpolated TS"] * len(all_mae_norms)
    + ["Linear interpolation"] * len(all_mae_norms_linear_interp),
    labels={"color": "Method"},
    nbins=20,
    barmode="overlay",
    opacity=0.75,
    template="simple_white",
)
fig.update_layout(
    title="MAE structure prediction (Å)",
    # Fixed, narrower figure size for readability.
    autosize=False,
    width=600,
    height=500,
)
fig.update_xaxes(title_text="MAE (Å)")
fig.update_yaxes(title_text="Frequency")
fig.show()

In [None]:
# Visualize |coefficients| of the initial state (`x`), final state (`x_final_state`)
# and transition state (`x_transition_state`) for one reaction, one animation
# frame per state.
# The reaction is chosen explicitly rather than taken from the `data` variable left
# over from the shuffled loop above, so the figure is reproducible on re-run.
sample_index = 0
sample = train_dataset[sample_index]

state_names = ["Initial state", "Final state", "Transition state"]
coefficient_matrices = np.abs(
    np.stack(
        [
            sample.x.detach().numpy(),
            sample.x_final_state.detach().numpy(),
            sample.x_transition_state.detach().numpy(),
        ]
    )
)

# range_color caps the colour scale at 1 so larger magnitudes saturate.
fig = px.imshow(
    coefficient_matrices,
    template="simple_white",
    color_continuous_scale="RdBu",
    labels=dict(x="Basis Functions", y="Atoms", color="|Coefficient|"),
    animation_frame=0,
    range_color=[0, 1],
)
# Label the animation slider with state names instead of frame indices 0/1/2.
for step, state_name in zip(fig.layout.sliders[0].steps, state_names):
    step.label = state_name

# Orbital-type label per basis-function column. NOTE(review): this layout
# (5 s, 12 p, then 3 x [s, d, d, d, d, d]) is hard-coded; confirm it matches the
# configured `basis_file` if the basis set changes.
orbital_labels = 5 * ["s"] + 12 * ["p"] + 3 * ["s", "d", "d", "d", "d", "d"]
fig.update_xaxes(ticktext=orbital_labels, tickvals=np.arange(len(orbital_labels)))

# Row labels come from the same sample as the matrices so they line up.
# Assumes `species` holds atomic numbers (used to index ase's atomic_names) — TODO confirm.
species = sample.species.view(-1).detach().numpy()
y_tick_labels = [atomic_names[int(z)] for z in species]
fig.update_yaxes(ticktext=y_tick_labels, tickvals=np.arange(len(y_tick_labels)))
fig.show()