# Analysis of Reaction Model
This notebook performs the following analyses for the best-performing models (lowest validation loss) of a given dataset and basis-set configuration:
- Error histogram of the transition-state structure interpolator
- Energy comparison parity plot
- Coefficient matrix parity plot
- Molecular orbital visualization

In [None]:
import os

from pathlib import Path

from pprint import pprint

import numpy as np
import numpy.typing as npt

import plotly.express as px
import plotly.graph_objects as go

import pandas as pd

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.reaction import ReactionDataset
from minimal_basis.postprocessing.transformations import (
    OrthoCoeffMatrixToGridQuantities,
    NodeFeaturesToOrthoCoeffMatrix,
    DatapointStoredVectorToOrthogonlizationMatrix,
)

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units
from ase.data import atomic_numbers, atomic_names, atomic_masses, vdw_radii
from ase.data.colors import jmol_colors


import matplotlib.pyplot as plt

# Render matplotlib figures at a higher default resolution.
plt.rcParams["figure.dpi"] = 200

import warnings

# Silence FutureWarning / UserWarning noise from the libraries used below.
# NOTE(review): this hides *every* UserWarning for the rest of the session,
# including ones that may point at genuine problems.
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)


from model_functions import construct_model_name
from analysis_utils import get_instance_grid, add_grid_to_fig
import wandb

# Public W&B API client; used below to list runs (`wandb_api.runs`) and to
# fetch their logged model artifacts.
wandb_api = wandb.Api()

## Inputs

In [None]:
# Folder layout, relative to the notebook's working directory.
__input_folder__, __config_folder__ = Path("input"), Path("config")

# Which dataset / basis set the analysed runs were trained on.
dataset_name = "rudorff_lilienfeld_sn2_dataset"
basis_set_type, basis_set = "full", "6-31g*"

# Number of real-space grid points per Cartesian axis for orbital plots.
grid_points_per_axis = 25
debug = False

# YAML file holding the model and dataset options.
model_config = __config_folder__.joinpath("rudorff_lilienfeld_model.yaml")

In [None]:
# Collect the W&B runs of this model family that were trained with the
# selected basis set: one row per run holding its config, its final train /
# validation losses and its run name.
model_name = construct_model_name(
    dataset_name=dataset_name,
    debug=debug,
)
print(f"Model name: {model_name}")
runs = wandb_api.runs(f"sudarshanvj/{model_name}")

run_records = []
for run in runs:
    if (
        run.config.get("basis_set_type") != basis_set_type
        or run.config.get("basis_set") != basis_set
    ):
        continue
    record = dict(run.config)
    record["train_loss"] = run.summary.get("train_loss", None)
    record["val_loss"] = run.summary.get("val_loss", None)
    record["wandb_model_name"] = run.name
    run_records.append(record)

# Build the frame once from a list of records. This avoids the quadratic
# concat-in-a-loop. List-valued config entries are stored as single cells
# instead of raising a length mismatch against a one-row index.
df = pd.DataFrame(run_records)

In [None]:
# Filesystem-friendly version of the basis-set label, e.g. "6-31g*" -> "6-31gstar":
# "*" -> "star", "+" -> "plus", parentheses and commas dropped, spaces -> "_",
# and the result lower-cased.
_basis_set_name_table = str.maketrans(
    {"*": "star", "+": "plus", "(": None, ")": None, ",": None, " ": "_"}
)
basis_set_name = basis_set.translate(_basis_set_name_table).lower()

# Options from the YAML config, plus the folder holding the JSON inputs for
# this dataset / basis-set type / basis set.
inputs = read_inputs_yaml(model_config)
input_foldername = __input_folder__.joinpath(dataset_name, basis_set_type, basis_set_name)
dataset_options = inputs["dataset_options"][basis_set_type + "_basis"]

In [None]:
# Build the train / validation / test datasets and their data loaders.
_debug_string = "_debug" if debug else ""
train_json_filename = input_foldername / f"train{_debug_string}.json"
validate_json_filename = input_foldername / f"validate{_debug_string}.json"
test_json_filename = input_foldername / f"test{_debug_string}.json"
if not test_json_filename.exists():
    # No dedicated test split on disk: fall back to the validation split, but
    # say so, because the "test" results below then duplicate validation.
    print(
        f"{test_json_filename} not found; using {validate_json_filename} as the test set."
    )
    test_json_filename = validate_json_filename

_dataset_root_name = model_name + "_" + basis_set_type + "_" + basis_set_name

# Each split gets its own root via the matching helper from `utils`.
# NOTE(review): presumably ReactionDataset caches processed data under `root`.
# Separate roots keep one split's cache from being reused for another.
# TODO confirm against ReactionDataset.
train_dataset = ReactionDataset(
    root=get_train_data_path(_dataset_root_name),
    filename=train_json_filename,
    **dataset_options,
)
validate_dataset = ReactionDataset(
    root=get_validation_data_path(_dataset_root_name),
    filename=validate_json_filename,
    **dataset_options,
)
test_dataset = ReactionDataset(
    root=get_test_data_path(_dataset_root_name),
    filename=test_json_filename,
    **dataset_options,
)

# Evaluation does not need shuffling. A fixed order makes cells that take the
# first validation sample reproducible.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
validation_loader = DataLoader(validate_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Performance of the structure interpolator

In [None]:
# For every reaction in the train, validation and test splits, take the mean
# over rows of the Euclidean distance between interpolated and true
# transition-state positions, then histogram these values.
# (Assumes positions are (n_atoms, 3) — TODO confirm.)
all_mae_norms = []
for loader in (train_loader, validation_loader, test_loader):
    for data in loader:
        interpolated_positions = data.pos_interpolated_transition_state.detach().numpy()
        true_positions = data.pos_transition_state.detach().numpy()
        row_distances = np.linalg.norm(interpolated_positions - true_positions, axis=1)
        all_mae_norms.append(np.mean(row_distances))

fig = px.histogram(x=all_mae_norms, nbins=20, template="simple_white")
fig.update_xaxes(title_text="MAE (Å)")
fig.update_yaxes(title_text="Frequency")
fig.update_layout(
    title="MAE structure prediction (Å)",
    autosize=False,
    width=600,
    height=500,
)
fig.show()

## Performance of the `coeff_matrix` model 

In [None]:
# Pick the coeff_matrix run with the lowest validation loss and download its
# logged "model" artifact. The artifact itself is left in
# `best_coeff_matrix_model`.
if "prediction_mode" not in df:
    raise RuntimeError("No W&B runs matched the selected basis-set configuration.")
_coeff_matrix_runs = df[df["prediction_mode"] == "coeff_matrix"]
if _coeff_matrix_runs.empty:
    raise RuntimeError("No coeff_matrix runs found for the selected configuration.")
best_coeff_matrix_row = _coeff_matrix_runs.sort_values(by="val_loss").iloc[0]

# NOTE(review): W&B run names are display names and need not be unique.
# Matching by name takes the first run carrying that name.
best_coeff_matrix_run = next(
    (run for run in runs if run.name == best_coeff_matrix_row["wandb_model_name"]),
    None,
)
if best_coeff_matrix_run is None:
    raise RuntimeError(
        f"W&B run {best_coeff_matrix_row['wandb_model_name']!r} not found."
    )

best_coeff_matrix_artifacts = best_coeff_matrix_run.logged_artifacts()
best_coeff_matrix_model = next(
    (artifact for artifact in best_coeff_matrix_artifacts if artifact.type == "model"),
    None,
)
if best_coeff_matrix_model is None:
    raise RuntimeError(
        f"W&B run {best_coeff_matrix_run.name!r} has no logged model artifact."
    )
best_coeff_matrix_model.download()

In [None]:
# Parity plot of absolute coefficient-matrix entries: best coeff_matrix model
# prediction vs. data.x_transition_state, over the whole validation set.
#
# The artifact is a fully pickled model object (`.eval()` is called on it).
# It therefore needs weights_only=False; torch >= 2.6 defaults to True and
# would refuse to load it.
# NOTE(review): unpickling can execute arbitrary code; only load trusted
# artifacts. `weights_only` requires torch >= 1.13.
coeff_matrix_model = torch.load(
    best_coeff_matrix_model.file(),
    map_location=torch.device("cpu"),
    weights_only=False,
)
coeff_matrix_model.eval()

predicted_abs_coefficients = []
expected_abs_coefficients = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for data in validation_loader:
        predicted = coeff_matrix_model(data).detach().numpy()
        expected = data.x_transition_state.detach().numpy()
        predicted_abs_coefficients.append(np.abs(predicted.flatten()))
        expected_abs_coefficients.append(np.abs(expected.flatten()))

fig = go.Figure()
# Every sample uses the same marker style. One concatenated trace therefore
# draws the same picture as one trace per sample, with far fewer plotly traces.
if expected_abs_coefficients:
    fig.add_trace(
        go.Scatter(
            x=np.concatenate(expected_abs_coefficients),
            y=np.concatenate(predicted_abs_coefficients),
            mode="markers",
            marker=dict(
                color="LightSkyBlue",
                size=10,
                line=dict(
                    color="MediumPurple",
                    width=2,
                ),
            ),
        )
    )
# Change the template to a simple white background
fig.update_layout(template="simple_white")
fig.add_trace(
    go.Scatter(
        x=[0, 1],
        y=[0, 1],
        mode="lines",
        line=dict(color="MediumPurple", width=2, dash="dash"),
        name="Diagonal",
    )
)
fig.update_layout(showlegend=False)
# Add axis labels
fig.update_xaxes(title_text="Expected Absolute Coefficient Matrix")
fig.update_yaxes(title_text="Predicted Absolute Coefficient Matrix")
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
)

In [None]:
# Compare the predicted and reference coefficient matrices for one sample:
# the first batch yielded by validation_loader.
data = next(iter(validation_loader))
with torch.no_grad():
    predicted = coeff_matrix_model(data).detach().numpy()
expected = data.x_transition_state.detach().numpy()

sum_squares_output = np.sum(predicted**2)
print(f"Sum of squares of predicted: {sum_squares_output:.3f}")
sum_squares_expected = np.sum(expected**2)
print(f"Sum of squares of expected: {sum_squares_expected:.3f}")

# Compare up to a global sign, matching the min(|C_model ± C_DFT|) panel.
# Keep whichever of |C_model - C_DFT| and |C_model + C_DFT| has the smaller
# total. In the latter case, flip the prediction so that `predicted` (reused
# by the orbital cells below) is sign-aligned with `expected`.
difference_negative = np.abs(predicted - expected)
difference_positive = np.abs(predicted + expected)
if np.sum(difference_negative) < np.sum(difference_positive):
    difference = difference_negative
else:
    difference = difference_positive
    predicted = -predicted

sum_differences = np.sum(difference)
sumsq_differences = np.sum(difference**2)
print(f"Sum of differences: {sum_differences:.3f}")
print(f"Sum of squares of differences: {sumsq_differences:.3f}")
print(f"Max difference: {np.max(difference):.3f}")

fig, axs = plt.subplots(
    3, 1, figsize=(12, 8), sharey=True, sharex=True, facecolor="w"
)
predicted_image = axs[0].imshow(predicted, cmap="cividis")
fig.colorbar(predicted_image, ax=axs[0])
axs[0].set_title(r"$\mathbf{C}_{\mathrm{model}}$")
expected_image = axs[1].imshow(expected, cmap="cividis")
fig.colorbar(expected_image, ax=axs[1])
# Draw the reference matrix on the same colour scale as the prediction.
expected_image.set_clim(predicted_image.get_clim())
axs[1].set_title(r"$\mathbf{C}_{\mathrm{DFT}}$")
difference_image = axs[2].imshow(difference, cmap="Blues")
fig.colorbar(difference_image, ax=axs[2])
axs[2].set_title(
    r"$\min \left( \left | \mathbf{C}_{\mathrm{model}} \pm \mathbf{C}_{\mathrm{DFT}} \right | \right)$"
)

# Label rows with element names from data.species
# (assumes one matrix row per atom — TODO confirm).
tickvals = data.species.view(-1).detach().numpy().flatten()
tickvals_species = [atomic_names[int(tickval)] for tickval in tickvals]
for ax in axs:
    ax.set_yticks(np.arange(len(tickvals)))
    ax.set_yticklabels(tickvals_species)

# Label columns by function type. s/p/d/f/g blocks take 1/3/5/7/9 columns
# per function, up to the dataset-wide maxima.
max_s_functions = 1 * train_dataset.max_s_functions
max_p_functions = 3 * train_dataset.max_p_functions
max_d_functions = 5 * train_dataset.max_d_functions
max_f_functions = 7 * train_dataset.max_f_functions
max_g_functions = 9 * train_dataset.max_g_functions
basis_function_labels = (
    max_s_functions * ["s"]
    + max_p_functions * ["p"]
    + max_d_functions * ["d"]
    + max_f_functions * ["f"]
    + max_g_functions * ["g"]
)
# sharex=True: the x ticks set here apply to all three panels.
axs[0].set_xticks(np.arange(len(basis_function_labels)))
axs[0].set_xticklabels(basis_function_labels)

fig.tight_layout()
fig.set_dpi(75)
plt.show()


## Interpretation of the `coeff_matrix` results

In [None]:
# Evaluate the predicted and reference molecular orbitals of the sample
# selected above on a real-space grid. Render each as an isosurface figure.
predicted_ortho_coeff_matrix_to_grid_quantities = get_instance_grid(
    data, predicted, grid_points_per_axis=grid_points_per_axis, buffer_grid=1.2
)
expected_ortho_coeff_matrix_to_grid_quantities = get_instance_grid(
    data, expected, grid_points_per_axis=grid_points_per_axis, buffer_grid=1.2
)

predicted_molecular_orbital = (
    predicted_ortho_coeff_matrix_to_grid_quantities.get_molecular_orbital()
)
expected_molecular_orbital = (
    expected_ortho_coeff_matrix_to_grid_quantities.get_molecular_orbital()
)
grid = expected_ortho_coeff_matrix_to_grid_quantities.get_grid()

# Both figures use the isosurface range of the reference orbital, so their
# colour scales are directly comparable.
_isosurface_options = dict(
    isomin=np.min(expected_molecular_orbital),
    isomax=np.max(expected_molecular_orbital),
    cmap="RdBu",
    surface_count=5,
    species=data.species,
    positions=data.pos_transition_state,
)

# Pair each orbital with the figure of the same name:
# predicted -> fig_predicted, expected -> fig_expected.
fig_predicted = go.Figure()
add_grid_to_fig(
    grid,
    predicted_molecular_orbital,
    fig_predicted,
    **_isosurface_options,
)
fig_expected = go.Figure()
add_grid_to_fig(
    grid,
    expected_molecular_orbital,
    fig_expected,
    **_isosurface_options,
)

In [None]:
# Title and size `fig_expected`. update_layout returns the figure, so as the
# cell's last expression Jupyter renders it.
fig_expected.update_layout(
    title_text="Expected molecular orbital",
    autosize=False,
    width=600,
    height=500,
)


In [None]:
# Title and size `fig_predicted`. update_layout returns the figure, so as the
# cell's last expression Jupyter renders it.
fig_predicted.update_layout(
    title_text="Predicted molecular orbital",
    autosize=False,
    width=600,
    height=500,
)

In [None]:
# Compare predicted and reference orbitals along z by averaging each over the
# first two grid axes (the "xy-average").
# NOTE(review): like the averaging itself, this assumes get_grid() lists the
# points in C order, so reshaping to (N, N, N) puts z on the last axis. TODO
# confirm in analysis_utils.
_grid_shape = (grid_points_per_axis, grid_points_per_axis, grid_points_per_axis)
_predicted_molecular_orbital = np.reshape(predicted_molecular_orbital, _grid_shape)
_expected_molecular_orbital = np.reshape(expected_molecular_orbital, _grid_shape)

# One z value per grid plane (N values), the same length as the averaged
# profiles. grid[:, 2] alone has N**3 entries, which plotly silently truncates.
_z_values = np.reshape(grid[:, 2], _grid_shape)[0, 0, :]

fig = go.Figure()
for _trace_name, _orbital in (
    ("Predicted", _predicted_molecular_orbital),
    ("Expected", _expected_molecular_orbital),
):
    fig.add_trace(
        go.Scatter(
            x=_z_values,
            y=np.mean(_orbital, axis=(0, 1)),
            mode="lines",
            name=_trace_name,
        )
    )
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
    title="Average molecular orbital",
    title_x=0.5,
    xaxis_title="z (Å)",
    yaxis_title="xy-averaged molecular orbital",
    legend_title="",
    template="simple_white",
)


## Performance of the `relative_energy` model 

In [None]:
# Pick the relative_energy run with the lowest validation loss, download its
# logged "model" artifact, and load the model on the CPU for inference.
# The artifact itself is left in `best_relative_energy_model`.
if "prediction_mode" not in df:
    raise RuntimeError("No W&B runs matched the selected basis-set configuration.")
_relative_energy_runs = df[df["prediction_mode"] == "relative_energy"]
if _relative_energy_runs.empty:
    raise RuntimeError("No relative_energy runs found for the selected configuration.")
best_relative_energy_row = _relative_energy_runs.sort_values(by="val_loss").iloc[0]

# NOTE(review): W&B run names are display names and need not be unique.
# Matching by name takes the first run carrying that name.
best_relative_energy_run = next(
    (run for run in runs if run.name == best_relative_energy_row["wandb_model_name"]),
    None,
)
if best_relative_energy_run is None:
    raise RuntimeError(
        f"W&B run {best_relative_energy_row['wandb_model_name']!r} not found."
    )

best_relative_energy_artifacts = best_relative_energy_run.logged_artifacts()
best_relative_energy_model = next(
    (artifact for artifact in best_relative_energy_artifacts if artifact.type == "model"),
    None,
)
if best_relative_energy_model is None:
    raise RuntimeError(
        f"W&B run {best_relative_energy_run.name!r} has no logged model artifact."
    )
best_relative_energy_model.download()

# The artifact is a fully pickled model object. It needs weights_only=False,
# because torch >= 2.6 defaults to True.
# NOTE(review): unpickling can execute arbitrary code; only load trusted
# artifacts. `weights_only` requires torch >= 1.13.
relative_energy_model = torch.load(
    best_relative_energy_model.file(),
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Inference only: turn off training-mode behaviour (e.g. dropout), as is
# already done for the coeff_matrix model.
relative_energy_model.eval()

In [None]:
# Parity plot of the predicted vs. reference energy barrier
# (data.total_energy_transition_state - data.total_energy) for every split.
# Values are multiplied by ase.units.Ha (Hartree in eV). This assumes the
# stored energies and the model output are in Hartree — TODO confirm.
#
# Results go into `energy_df` rather than `df`. Rebinding `df` would clobber
# the W&B runs table that the model-selection cells read.
loaders = {"train": train_loader, "validation": validation_loader, "test": test_loader}

_energy_frames = []
with torch.no_grad():
    for loader_name, loader in loaders.items():
        for data in loader:
            output = relative_energy_model(data).mean(dim=1).detach().numpy()
            expected = (
                data.total_energy_transition_state - data.total_energy
            ).detach().numpy()
            _energy_frames.append(
                pd.DataFrame(
                    {
                        "output": output * ase_units.Ha,
                        "expected": expected * ase_units.Ha,
                        "loader": loader_name,
                    }
                )
            )

# Concatenate once at the end instead of growing a frame inside the loop.
if _energy_frames:
    energy_df = pd.concat(_energy_frames, ignore_index=True)
else:
    energy_df = pd.DataFrame(columns=["output", "expected", "loader"])

# Make a parity plot of the predicted vs. expected total energy barrier.
fig = go.Figure()
for loader_name, trace_name, color in (
    ("train", "Train", "blue"),
    ("validation", "Validation", "orange"),
    ("test", "Test", "green"),
):
    _subset = energy_df[energy_df["loader"] == loader_name]
    fig.add_trace(
        go.Scatter(
            x=_subset["expected"],
            y=_subset["output"],
            mode="markers",
            name=trace_name,
            marker=dict(color=color),
        )
    )
# Draw a dashed line at y=x.
fig.add_trace(
    go.Scatter(
        x=[np.min(energy_df["expected"]), np.max(energy_df["expected"])],
        y=[np.min(energy_df["expected"]), np.max(energy_df["expected"])],
        mode="lines",
        name="y=x",
        line=dict(color="black", dash="dash"),
    )
)
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
    title="Activation Energy Parity Plot",
    title_x=0.5,
    xaxis_title="Expected (eV)",
    yaxis_title="Predicted (eV)",
    legend_title="",
    template="simple_white",
)
fig.show()
        