# Analyse Reaction Model

In [None]:
import os

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.reaction import ReactionDataset

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units
from ase.data import atomic_numbers, atomic_names

import matplotlib.pyplot as plt

# Render matplotlib figures at 200 dpi by default.
plt.rcParams.update({"figure.dpi": 200})

import warnings

# Hide FutureWarning and UserWarning messages (installed in this order, before
# wandb is imported) so the notebook output stays readable.
for _category in (FutureWarning, UserWarning):
    warnings.simplefilter(action="ignore", category=_category)
del _category

import wandb

# Start a Weights & Biases run; the model artifacts below are fetched through it.
run = wandb.init()

## Inputs

In [None]:
# Name of the trained model analysed in this notebook. It is reused as the
# W&B project/artifact name, the config file stem under ``config/`` and the
# argument to the get_*_data_path helpers.
model_name: str = "rudorff_lilienfeld_sn2_model"

In [None]:
# Choose the W&B artifact versions that represent the trained models used in
# this notebook. Both models are logged under the same artifact name
# (``model_name``), so only the version tag distinguishes the coefficient
# model from the barrier model. Change the versions here to analyse others.
wandb_entity = "sudarshanvj"
coeff_model_version = "v7"
barrier_model_version = "v4"

artifact_coeff = run.use_artifact(
    f"{wandb_entity}/{model_name}/{model_name}:{coeff_model_version}", type="model"
)
artifact_dir_coeff = artifact_coeff.download()
artifact_barrier = run.use_artifact(
    f"{wandb_entity}/{model_name}/{model_name}:{barrier_model_version}", type="model"
)
artifact_dir_barrier = artifact_barrier.download()

## Setup

In [None]:
config_filename = os.path.join("config", f"{model_name}.yaml")
inputs = read_inputs_yaml(config_filename)

# JSON files describing each split of the reaction data.
train_json_filename = inputs["train_json"]
validate_json_filename = inputs["validate_json"]
test_json_filename = inputs["test_json"]

# Keyword arguments shared by every ReactionDataset. This is the same dict
# object as inputs["dataset_options"], so the update also adds the key there.
kwargs_dataset = inputs["dataset_options"]
kwargs_dataset.update(
    use_minimal_basis_node_features=inputs["use_minimal_basis_node_features"]
)


def _build_reaction_dataset(root, filename):
    """Build a ReactionDataset for one split using the shared basis file and options."""
    return ReactionDataset(
        root=root,
        filename=filename,
        basis_filename=inputs["basis_file"],
        **kwargs_dataset,
    )


train_dataset = _build_reaction_dataset(
    get_train_data_path(model_name), train_json_filename
)
validation_dataset = _build_reaction_dataset(
    get_validation_data_path(model_name), validate_json_filename
)
test_dataset = _build_reaction_dataset(
    get_test_data_path(model_name), test_json_filename
)

In [None]:
# The training loader is shuffled, so inspecting its first batch shows a
# different (random) training reaction on each run. The validation and test
# loaders are only evaluated, never trained on, so they are kept in a fixed
# order to make reruns of the analysis reproducible.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Performance of the Structure Interpolator

In [None]:
def _mean_atomic_displacement(batch):
    """Return the mean, over rows, of the Euclidean distance between the
    interpolated and reference transition-state positions of one batch.

    Assumes the position tensors are (n_atoms, 3) so that axis 1 holds the
    Cartesian components — TODO confirm.
    """
    interpolated = batch.pos_interpolated_transition_state.detach().numpy()
    reference = batch.pos_transition_state.detach().numpy()
    per_atom_error = np.linalg.norm(interpolated - reference, axis=1)
    return np.mean(per_atom_error)


# One value per validation reaction.
all_mae_norms = [_mean_atomic_displacement(batch) for batch in validation_loader]

fig = px.histogram(x=all_mae_norms, nbins=20, template="simple_white")
fig.update_layout(
    title="MAE structure prediction (Å)",
    autosize=False,
    width=600,
    height=500,
)
fig.update_xaxes(title_text="MAE (Å)")
fig.update_yaxes(title_text="Frequency")

fig.show()

## Performance of the Coefficient Model

In [None]:
# Load the trained coefficient model on the CPU.
# NOTE(review): torch.load unpickles a whole nn.Module, so only load trusted
# artifacts. From PyTorch 2.6 the default is weights_only=True, which rejects
# full-module pickles; pass weights_only=False explicitly if loading fails.
model = torch.load(os.path.join(artifact_dir_coeff, f"{model_name}.pt"), map_location=torch.device('cpu'))
# Evaluation mode: changes the behaviour of layers such as dropout or
# batch-norm (if the model has any) from their training behaviour.
model.eval()

# Inspect the first batch of the training loader (a single reaction).
data = next(iter(train_loader))

with torch.no_grad():
    output = model(data)

# Coefficients are compared by magnitude only: the sign is discarded on both
# the model and the DFT side.
output = np.abs(output.detach().numpy())
expected = np.abs(data.x_transition_state.detach().numpy())

sum_squares_output = np.sum(output**2)
print(f"Sum of squares of output: {sum_squares_output:.3f}")
sum_squares_expected = np.sum(expected**2)
print(f"Sum of squares of expected: {sum_squares_expected:.3f}")

# Element-wise absolute error between the two magnitude matrices.
difference = np.abs(output - expected)
sum_differences = np.sum(difference)
sumsq_differences = np.sum(difference**2)
print(f"Sum of differences: {sum_differences:.3f}")
print(f"Sum of squares of differences: {sumsq_differences:.3f}")
print("Max differences", np.max(difference))

fig, axs = plt.subplots(3, 1, figsize=(12, 8), sharey=True, sharex=True, facecolor='w')

image_model = axs[0].imshow(output, cmap="cividis")
fig.colorbar(image_model, ax=axs[0])
axs[0].set_title(r"$\left| \mathbf{C}_{\mathrm{model}} \right|$")

# Draw the DFT coefficients on the same colour scale as the model output so
# the two panels can be compared directly.
vmin, vmax = image_model.get_clim()
image_dft = axs[1].imshow(expected, cmap="cividis", vmin=vmin, vmax=vmax)
fig.colorbar(image_dft, ax=axs[1])
axs[1].set_title(r"$\left| \mathbf{C}_{\mathrm{DFT}} \right|$")

image_diff = axs[2].imshow(difference, cmap="Blues")
fig.colorbar(image_diff, ax=axs[2])
axs[2].set_title(r"$\left| |\mathbf{C}_{\mathrm{model}}| - |\mathbf{C}_{\mathrm{DFT}}| \right|$")

# Label each row with an element name; assumes data.species holds atomic
# numbers (used to index ase.data.atomic_names) — TODO confirm.
species = data.species.view(-1).detach().numpy()
species_labels = [atomic_names[int(z)] for z in species]
for ax in axs:
    ax.set_yticks(np.arange(len(species)))
    ax.set_yticklabels(species_labels)

# NOTE(review): hard-codes 17 coefficient columns labelled as 5 "s" followed
# by 12 "p" functions; this must match the basis behind x_transition_state.
axs[0].set_xticks(np.arange(17))
axs[0].set_xticklabels(5 * ["s"] + 12 * ["p"])

fig.tight_layout()
fig.set_dpi(100)
plt.show()

## Performance of the Barrier Model

In [None]:
# Load the trained barrier model on the CPU.
# NOTE(review): torch.load unpickles a whole nn.Module, so only load trusted
# artifacts. From PyTorch 2.6 the default is weights_only=True, which rejects
# full-module pickles; pass weights_only=False explicitly if loading fails.
model = torch.load(os.path.join(artifact_dir_barrier, f"{model_name}.pt"), map_location=torch.device('cpu'))
# Evaluation mode: changes the behaviour of layers such as dropout or
# batch-norm (if the model has any) from their training behaviour.
model.eval()

outputs = []
expecteds = []

with torch.no_grad():
    for data in validation_loader:
        # Average the model output over dim 1 to obtain the predicted barrier
        # (assumes a 2-D output — TODO confirm).
        output = model(data).mean(dim=1).detach().numpy()
        # Reference barrier: transition-state energy minus the stored total energy.
        expected = (data.total_energy_transition_state - data.total_energy).detach().numpy()

        # Only the first entry of each batch is kept; the loaders in this
        # notebook are built with batch_size=1.
        outputs.append(output[0])
        expecteds.append(expected[0])

# Convert to eV: ase_units.Ha is one Hartree expressed in eV, so this assumes
# the energies are stored in Hartree.
outputs = np.array(outputs).flatten() * ase_units.Ha
expecteds = np.array(expecteds).flatten() * ase_units.Ha

# Optional filter: drop entries whose DFT barrier is negative.
# idx_to_remove = np.where(expecteds < 0)[0]
# outputs = np.delete(outputs, idx_to_remove)
# expecteds = np.delete(expecteds, idx_to_remove)

print(f"Number of points: {len(outputs)}")
mae = np.mean(np.abs(outputs - expecteds))
print(f"Mean absolute error: {mae:.3f} eV")

# Parity plot of model output vs. DFT barrier.
fig = px.scatter(x=expecteds, y=outputs, template="simple_white")
fig.update_xaxes(title_text="DFT Computed Barrier (eV)")
fig.update_yaxes(title_text="Model Output Barrier (eV)")

# Span the y = x line over the range of BOTH axes; using only the model
# outputs could leave DFT values outside the drawn line.
parity_min = min(outputs.min(), expecteds.min())
parity_max = max(outputs.max(), expecteds.max())
fig.add_shape(
    type="line",
    x0=parity_min,
    y0=parity_min,
    x1=parity_max,
    y1=parity_max,
    line=dict(
        color="Red",
        width=4,
        dash="dashdot",
    ),
)
fig.update_layout(
    autosize=False,
    width=600,
    height=500,
)

fig.show()