# Reaction Model 

In [None]:
import os

from pathlib import Path

import numpy as np
import numpy.typing as npt

import plotly.express as px
import plotly.graph_objects as go

import pandas as pd

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from coeffnet.dataset.reaction import ReactionDataset
from coeffnet.postprocessing.transformations import (
    OrthoCoeffMatrixToGridQuantities,
    NodeFeaturesToOrthoCoeffMatrix,
    DatapointStoredVectorToOrthogonlizationMatrix,
)

from cli_functions import create_timestamp

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units
from ase.data import atomic_numbers, atomic_names, atomic_masses, vdw_radii
from ase.data.colors import jmol_colors


import matplotlib.pyplot as plt

# Render matplotlib figures at 200 dpi.
plt.rcParams["figure.dpi"] = 200

import warnings

# Silence FutureWarning and UserWarning messages for the rest of the session.
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import wandb
# Weights & Biases API handle; used further down to look up runs by path.
wandb_api = wandb.Api()

from model_functions import construct_model_name

## Inputs

In [None]:
# --- Inputs to be filled out by the user --- #
# Root folders for the input data and for the configuration files.
__input_folder__ = Path("input")
__config_folder__ = Path("config")
# Dataset identifier; used to build the model name and the input data path.
dataset_name = "rudorff_lilienfeld_sn2_dataset"
# Configuration file that is read with read_inputs_yaml further down.
model_config = __config_folder__ / "rudorff_lilienfeld_model.yaml"
# Basis-set type and name; both are path components of the input folder, and
# the type also selects the "<type>_basis" entry of the dataset options.
basis_set_type = "full"
basis_set_name = "def2-svp"
# Number of grid points per axis passed to get_instance_grid below.
grid_points_per_axis = 25

In [None]:
# Derive the model name from the dataset name and report it.
model_name = construct_model_name(dataset_name=dataset_name, debug=False)
print(f"Model name: {model_name}")

# Timestamp that is appended to the model name when building dataset roots.
timestamp = create_timestamp()

# Read the model configuration and locate the input data for this basis set.
inputs = read_inputs_yaml(model_config)
dataset_options = inputs["dataset_options"]
input_foldername = __input_folder__.joinpath(
    dataset_name, basis_set_type, basis_set_name
)

In [None]:
# Both runs are looked up in the "<model_name>_archive" project.
# NOTE(review): later cells use `artifact_coeff` and `artifact_dir_barrier`,
# which are not defined in any visible cell -- presumably they are obtained
# from these two runs; confirm and define them before running those cells.
archive_project = f"sudarshanvj/{model_name}_archive"
coeff_run = wandb_api.run(f"{archive_project}/valiant-breeze-178")
barrier_run = wandb_api.run(f"{archive_project}/smooth-music-182")

## Loading data 

In [None]:
# Locate the JSON file that describes each data split.
input_foldername = __input_folder__ / dataset_name / basis_set_type / basis_set_name
train_json_filename = input_foldername / "train.json"
validate_json_filename = input_foldername / "validate.json"

# The test split used to be built from validate.json, which made every "test"
# result a copy of the validation result. Prefer a dedicated test.json and
# fall back to the validation file only when no such file is present.
# NOTE(review): assumes the test split is stored as test.json next to the
# other two splits -- confirm the file name.
test_json_filename = input_foldername / "test.json"
if not test_json_filename.exists():
    print(
        f"{test_json_filename} not found; "
        f"reusing {validate_json_filename} for the test split."
    )
    test_json_filename = validate_json_filename

# Every datapoint is moved to the GPU when one is available, else to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = T.ToDevice(device)

# Options shared by all three splits, and the tag used in their root paths.
basis_options = inputs["dataset_options"][f"{basis_set_type}_basis"]
dataset_tag = model_name + "_" + timestamp

train_dataset = ReactionDataset(
    root=get_train_data_path(dataset_tag),
    filename=train_json_filename,
    transform=transform,
    **basis_options,
)
validate_dataset = ReactionDataset(
    root=get_validation_data_path(dataset_tag),
    filename=validate_json_filename,
    transform=transform,
    **basis_options,
)
test_dataset = ReactionDataset(
    root=get_test_data_path(dataset_tag),
    filename=test_json_filename,
    transform=transform,
    **basis_options,
)

In [None]:
# One shuffled, single-datapoint loader per split.
train_loader, validation_loader, test_loader = (
    DataLoader(split_dataset, batch_size=1, shuffle=True)
    for split_dataset in (train_dataset, validate_dataset, test_dataset)
)

## Performance of the transition-state structure interpolator

In [None]:
# Collect, for every datapoint of every split, the mean distance between the
# interpolated and the reference transition-state positions.
all_mae_norms = []

for loader in (train_loader, validation_loader, test_loader):
    for data in loader:
        displacement = (
            data.pos_interpolated_transition_state.cpu().detach().numpy()
            - data.pos_transition_state.cpu().detach().numpy()
        )
        # Norm along axis 1 (assumed to be the Cartesian axis, i.e. one
        # distance per atom), then averaged over the rows.
        all_mae_norms.append(np.mean(np.linalg.norm(displacement, axis=1)))

# Histogram of the per-datapoint values.
fig = px.histogram(x=all_mae_norms, nbins=20, template="simple_white")
fig.update_layout(
    title="MAE structure prediction (Å)",
    autosize=False,
    width=600,
    height=500,
)
fig.update_xaxes(title_text="MAE (Å)")
fig.update_yaxes(title_text="Frequency")

fig.show()

## Performance of the model on the coefficient-matrix prediction task

In [None]:
def get_instance_grid(
    data,
    node_features,
    basis_name: str = "6-31g*",
    charge: int = -1,
    grid_points_per_axis: int = 10,
    buffer_grid: int = 5,
    uses_cartesian_orbitals: bool = True,
):
    """Build and evaluate the grid-quantity calculator for one datapoint.

    The orthogonalization matrix is rebuilt from the vector stored on
    ``data``, the node features are converted to an orthogonal-basis
    coefficient matrix, and both are handed to
    ``OrthoCoeffMatrixToGridQuantities`` together with the transition-state
    positions and species of ``data``.

    Args:
        data: Datapoint providing ``orthogonalization_matrix_transition_state``,
            ``basis_mask``, ``pos_transition_state``, ``species`` and
            ``indices_to_keep``.
        node_features: Node features to convert into a coefficient matrix.
        basis_name: Forwarded as ``basis_name``.
        charge: Forwarded as ``charge``.
        grid_points_per_axis: Forwarded as ``grid_points``.
        buffer_grid: Forwarded as ``buffer_grid``.
        uses_cartesian_orbitals: Forwarded to the calculator.

    Returns:
        The ``OrthoCoeffMatrixToGridQuantities`` instance after it has been
        called.
    """
    # Orthogonalization matrix of the transition state.
    ortho_matrix_builder = DatapointStoredVectorToOrthogonlizationMatrix(
        data.orthogonalization_matrix_transition_state
    )
    ortho_matrix_builder()
    ts_orthogonalization_matrix = ortho_matrix_builder.get_orthogonalization_matrix()

    # Coefficient matrix in the orthogonal basis.
    coeff_matrix_builder = NodeFeaturesToOrthoCoeffMatrix(
        node_features=node_features, mask=data.basis_mask
    )
    coeff_matrix_builder()

    grid_quantities = OrthoCoeffMatrixToGridQuantities(
        ortho_coeff_matrix=coeff_matrix_builder.get_ortho_coeff_matrix(),
        orthogonalization_matrix=ts_orthogonalization_matrix,
        positions=data.pos_transition_state,
        species=data.species,
        basis_name=basis_name,
        indices_to_keep=data.indices_to_keep,
        charge=charge,
        # Keyword spelling (sic) kept exactly as in the original call.
        uses_carterian_orbitals=uses_cartesian_orbitals,
        buffer_grid=buffer_grid,
        grid_points=grid_points_per_axis,
    )
    grid_quantities()

    return grid_quantities


def add_grid_to_fig(
    grid: npt.ArrayLike,
    molecular_orbital,
    fig: go.Figure,
    isomin: float,
    isomax: float,
    cmap: str = "cividis",
    surface_count: int = 3,
    datapoint=None,
):
    """Add orbital isosurfaces and the atoms of a datapoint to ``fig``.

    Args:
        grid: Array whose first three columns are the x, y and z coordinates
            of the grid points.
        molecular_orbital: Orbital values on the grid; flattened before
            plotting.
        fig: Figure that is modified in place.
        isomin: Lower bound of the isosurface values.
        isomax: Upper bound of the isosurface values.
        cmap: Name of the colour scale.
        surface_count: Number of isosurfaces to draw.
        datapoint: Datapoint whose ``species`` and ``pos_transition_state``
            are drawn as atom markers. When omitted, the notebook-level
            ``data`` variable is used, which is what this function always
            did before the parameter existed.

    Returns:
        None. ``fig`` is updated in place.
    """
    if datapoint is None:
        # Backward-compatible fallback to the notebook-global datapoint.
        datapoint = data

    fig.add_trace(
        go.Isosurface(
            x=grid[:, 0],
            y=grid[:, 1],
            z=grid[:, 2],
            value=molecular_orbital.flatten(),
            isomin=isomin,
            isomax=isomax,
            surface_count=surface_count,
            caps=dict(x_show=False, y_show=False, z_show=False),
            opacity=0.5,
            colorbar=dict(
                # Nested title attributes replace the deprecated
                # `titleside` / `titlefont` shorthands.
                title=dict(
                    text="Molecular orbital",
                    side="right",
                    font=dict(size=18),
                ),
                tickfont=dict(size=14),
            ),
            colorscale=cmap,
        )
    )

    # Move tensors to the CPU before converting; `.numpy()` fails on CUDA.
    species_numbers = datapoint.species.view(-1).cpu().detach().numpy()
    positions = datapoint.pos_transition_state.cpu().detach().numpy()

    # jmol colours are fractions of 1; plotly expects 0-255 rgb strings.
    colors_of_atom = [jmol_colors[int(species)] for species in species_numbers]
    colors_of_atom = [
        f"rgb({color[0]*255}, {color[1]*255}, {color[2]*255})"
        for color in colors_of_atom
    ]

    fig.add_trace(
        go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode="markers",
            marker=dict(
                color=colors_of_atom,
            ),
        )
    )

    # Hide tick labels, grid lines, zero lines and titles on all three axes.
    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_layout(
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            xaxis_title="",
            yaxis_title="",
            zaxis_title="",
        ),
        template="plotly_dark",
    )

In [None]:
# Make a parity plot of the predicted vs. expected coefficients.
# NOTE(review): `artifact_coeff` is not defined in any visible cell; it must
# exist before this cell is run.
# torch.load unpickles the file, so only point it at trusted checkpoints.
model = torch.load(artifact_coeff.file())

fig = go.Figure()

# Training points are drawn in blue with alpha=0.2; validation in orange and
# test in green.
parity_splits = (
    ("Train", train_loader, dict(color="blue", opacity=0.2)),
    ("Validation", validation_loader, dict(color="orange")),
    ("Test", test_loader, dict(color="green")),
)

for split_name, loader, marker in parity_splits:
    expected_chunks = []
    predicted_chunks = []

    # Iterate the loader of the current split. The previous version always
    # iterated the validation loader here, so every split showed validation
    # data.
    for data in loader:
        predicted = model(data)
        predicted = predicted.cpu().detach().numpy()
        expected = data.x_transition_state.cpu().detach().numpy()

        expected_chunks.append(expected.flatten())
        predicted_chunks.append(predicted.flatten())

    if not expected_chunks:
        # Nothing to plot for an empty split.
        continue

    # One trace per split instead of one per datapoint keeps the figure light
    # and gives each split a single legend entry.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate(expected_chunks),
            y=np.concatenate(predicted_chunks),
            mode="markers",
            name=split_name,
            marker=marker,
        )
    )

fig.show()

In [None]:
# Evaluate the model on the first datapoint the validation loader yields.
for data in validation_loader:
    predicted = model(data).cpu().detach().numpy()
    expected = data.x_transition_state.cpu().detach().numpy()
    break

sum_squares_predicted = np.sum(predicted**2)
print(f"Sum of squares of predicted: {sum_squares_predicted:.3f}")

sum_squares_expected = np.sum(expected**2)
print(f"Sum of squares of expected: {sum_squares_expected:.3f}")

# Compare against both +expected and -expected and keep whichever global sign
# gives the smaller total absolute deviation.
residual_same_sign = predicted - expected
residual_flipped_sign = predicted + expected
if np.abs(residual_same_sign).sum() < np.abs(residual_flipped_sign).sum():
    difference = np.abs(residual_same_sign)
else:
    difference = np.abs(residual_flipped_sign)

# `difference` is already non-negative, so no further abs() is needed.
sum_differences = np.sum(difference)
sumsq_differences = np.sum(difference**2)
print(f"Sum of differences: {sum_differences:.3f}")
print(f"Sum of squares of differences: {sumsq_differences:.3f}")
print("Max differences", np.max(difference))

# Three stacked panels: model output, reference, and their deviation.
fig, axs = plt.subplots(
    3, 1, figsize=(12, 8), sharey=True, sharex=True, facecolor="w"
)
model_ax, dft_ax, diff_ax = axs

model_image = model_ax.imshow(predicted, cmap="cividis")
fig.colorbar(model_image, ax=model_ax)
model_ax.set_title(r"$\mathbf{C}_{\mathrm{model}}$")

dft_image = dft_ax.imshow(expected, cmap="cividis")
fig.colorbar(dft_image, ax=dft_ax)
# Give the reference panel the same colour limits as the model panel.
dft_image.set_clim(model_image.get_clim())
dft_ax.set_title(r"$\mathbf{C}_{\mathrm{DFT}}$")

diff_image = diff_ax.imshow(difference, cmap="Blues")
fig.colorbar(diff_image, ax=diff_ax)
diff_ax.set_title(
    r"$\min \left( \left | \mathbf{C}_{\mathrm{model}} \pm \mathbf{C}_{\mathrm{DFT}} \right | \right)$"
)

# Label each row with the name of the corresponding species.
species_numbers = data.species.view(-1).cpu().detach().numpy().flatten()
species_names = [atomic_names[int(number)] for number in species_numbers]
row_positions = np.arange(len(species_numbers))
for ax in axs:
    ax.set_yticks(row_positions)
    ax.set_yticklabels(species_names)

# Label the columns by angular momentum: 1, 3 and 5 columns per s, p and d
# function respectively.
max_s_functions = 1 * train_dataset.max_s_functions
max_p_functions = 3 * train_dataset.max_p_functions
max_d_functions = 5 * train_dataset.max_d_functions
column_labels = (
    ["s"] * max_s_functions + ["p"] * max_p_functions + ["d"] * max_d_functions
)
model_ax.set_xticks(np.arange(max_s_functions + max_p_functions + max_d_functions))
model_ax.set_xticklabels(column_labels)

fig.tight_layout()
fig.set_dpi(150)
plt.show()


## Postprocessing the molecular orbitals from the coefficient matrix

In [None]:
# Evaluate the predicted and the expected coefficient matrices on a grid.
# NOTE(review): get_instance_grid is called with its default basis_name
# ("6-31g*") while the inputs cell sets basis_set_name = "def2-svp" -- confirm
# which basis is intended here.
predicted_ortho_coeff_matrix_to_grid_quantities = get_instance_grid(
    data, predicted, grid_points_per_axis=grid_points_per_axis, buffer_grid=1.2
)
expected_ortho_coeff_matrix_to_grid_quantities = get_instance_grid(
    data, expected, grid_points_per_axis=grid_points_per_axis, buffer_grid=1.2
)

predicted_molecular_orbital = (
    predicted_ortho_coeff_matrix_to_grid_quantities.get_molecular_orbital()
)
expected_molecular_orbital = (
    expected_ortho_coeff_matrix_to_grid_quantities.get_molecular_orbital()
)
grid = expected_ortho_coeff_matrix_to_grid_quantities.get_grid()

# Both figures use the value range of the expected orbital so that their
# isosurface levels and colours are directly comparable.
shared_isomin = np.min(expected_molecular_orbital)
shared_isomax = np.max(expected_molecular_orbital)

# The two orbitals used to be swapped, so `fig_predicted` showed the expected
# orbital and vice versa. Each figure now holds the orbital its name refers to.
fig_predicted = go.Figure()
add_grid_to_fig(
    grid,
    predicted_molecular_orbital,
    fig_predicted,
    isomin=shared_isomin,
    isomax=shared_isomax,
    cmap="RdBu",
    surface_count=5,
)
fig_expected = go.Figure()
add_grid_to_fig(
    grid,
    expected_molecular_orbital,
    fig_expected,
    isomin=shared_isomin,
    isomax=shared_isomax,
    cmap="RdBu",
    surface_count=5,
)

In [None]:
# Check that the volume-integrated square of the molecular orbital is similar
# for the predicted and the expected orbital.

# Volume of the bounding box of the grid: product of the x, y and z extents.
grid_extent = np.max(grid[:, :3], axis=0) - np.min(grid[:, :3], axis=0)
volume_grid = grid_extent[0] * grid_extent[1] * grid_extent[2]
print(f"Volume of grid: {volume_grid:.3f}")

# Each grid point is weighted with an equal share of the box volume.
n_grid_cells = grid_points_per_axis * grid_points_per_axis * grid_points_per_axis
sumsq_expected = np.sum(
    (expected_molecular_orbital**2) * volume_grid / n_grid_cells
)
sumsq_predicted = np.sum(
    (predicted_molecular_orbital**2) * volume_grid / n_grid_cells
)
print(f"Sum of squares of predicted: {sumsq_predicted:.3f}")
print(f"Sum of squares of expected: {sumsq_expected:.3f}")


In [None]:
# Title and fixed size for the expected-orbital figure.
fig_expected.update_layout(
    title_text="Expected molecular orbital",
    autosize=False,
    width=600,
    height=500,
)


In [None]:
# Title and fixed size for the predicted-orbital figure.
fig_predicted.update_layout(
    title_text="Predicted molecular orbital",
    autosize=False,
    width=600,
    height=500,
)

In [None]:
# Reshape the flat orbital values onto the cubic grid.
grid_shape = (grid_points_per_axis,) * 3
_predicted_molecular_orbital = np.reshape(predicted_molecular_orbital, grid_shape)
_expected_molecular_orbital = np.reshape(expected_molecular_orbital, grid_shape)

# Averaging over axes (0, 1) leaves one value per index of the last axis, so
# the x data needs exactly `grid_points_per_axis` coordinates. The previous
# version passed all N**3 rows of grid[:, 2] and relied on the plot pairing
# only the leading entries with the N averaged values.
# NOTE(review): assumes z is the fastest-varying coordinate in the row order
# of `grid`, so that its first N rows span the z axis -- confirm against the
# grid returned by get_grid().
z_axis = grid[:grid_points_per_axis, 2]

fig = go.Figure()
averaged_profiles = (
    ("Predicted", _predicted_molecular_orbital),
    ("Expected", _expected_molecular_orbital),
)
for profile_name, orbital_on_grid in averaged_profiles:
    fig.add_trace(
        go.Scatter(
            x=z_axis,
            y=np.mean(orbital_on_grid, axis=(0, 1)),
            mode="lines",
            name=profile_name,
        )
    )
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
    title="Average molecular orbital",
    title_x=0.5,
    xaxis_title="z (Å)",
    yaxis_title="xy-averaged molecular orbital",
    legend_title="",
    template="simple_white",
)


## Performance of the barrier model

In [None]:
# Load the barrier model.
# NOTE(review): `artifact_dir_barrier` is not defined in any visible cell; it
# must exist before this cell is run.
# torch.load unpickles the file, so only point it at trusted checkpoints.
model = torch.load(
    os.path.join(artifact_dir_barrier, f"{model_name}.pt"),
    map_location=torch.device("cpu"),
)
# The dataset transform moves every datapoint to `device`; the model has to
# live on the same device for the forward pass to work.
model = model.to(device)

loaders = {"train": train_loader, "validation": validation_loader, "test": test_loader}

# Collect one small frame per datapoint and concatenate once at the end;
# growing a DataFrame with pd.concat inside the loop copies it every time.
frames = []

for loader_name, loader in loaders.items():

    for data in loader:

        output = model(data)
        output = output.mean(dim=1)

        # `.cpu()` is required before `.numpy()` when running on a GPU.
        output = output.cpu().detach().numpy()
        expected = data.total_energy_transition_state - data.total_energy
        expected = expected.cpu().detach().numpy()

        # Scale by ase_units.Ha so the values match the eV axis labels
        # (assumes model output and stored energies are in Hartree).
        frames.append(
            pd.DataFrame(
                {
                    "output": output * ase_units.Ha,
                    "expected": expected * ase_units.Ha,
                    "loader": loader_name,
                }
            )
        )

if frames:
    df = pd.concat(frames, ignore_index=True)
else:
    df = pd.DataFrame(columns=["output", "expected", "loader"])

# Make a parity plot of the predicted vs. expected total energy barrier.
fig = go.Figure()
parity_traces = (
    ("train", "Train", "blue"),
    ("validation", "Validation", "orange"),
    ("test", "Test", "green"),
)
for loader_name, trace_name, color in parity_traces:
    split_rows = df[df["loader"] == loader_name]
    fig.add_trace(
        go.Scatter(
            x=split_rows["expected"],
            y=split_rows["output"],
            mode="markers",
            name=trace_name,
            marker=dict(color=color),
        )
    )

# Draw a dashed line at y=x spanning the range of expected values.
parity_limits = [np.min(df["expected"]), np.max(df["expected"])]
fig.add_trace(
    go.Scatter(
        x=parity_limits,
        y=parity_limits,
        mode="lines",
        name="y=x",
        line=dict(color="black", dash="dash"),
    )
)
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
    title="Activation Energy Parity Plot",
    title_x=0.5,
    xaxis_title="Expected (eV)",
    yaxis_title="Predicted (eV)",
    legend_title="",
    template="simple_white",
)
fig.show()
        