## Analyse output of dataset

In [None]:
import os

import numpy as np

import plotly.express as px
import plotly.graph_objects as go

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T

from minimal_basis.dataset.reaction import ReactionDataset

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units
from ase.data import atomic_numbers, atomic_names

import matplotlib.pyplot as plt
# Raise the resolution of every figure drawn by matplotlib in this session.
plt.rcParams.update({'figure.dpi': 200})

import warnings
# Install "ignore" filters for both warning categories. They are installed
# before the wandb import below, so they are already active when it runs.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_ignored_category)

import wandb
# Start a Weights & Biases run with default arguments and keep its handle.
run = wandb.init()

In [None]:
# Read the YAML configuration that belongs to the selected model.
model_name = "rudorff_lilienfeld_e2"
config_filename = os.path.join("config", f"{model_name}_model.yaml")
inputs = read_inputs_yaml(config_filename)

train_json_filename = inputs["train_json"]
validate_json_filename = inputs["validate_json"]

# The dataset options are taken from the config and extended, in place, with
# the top-level "use_minimal_basis_node_features" setting.
kwargs_dataset = inputs["dataset_options"]
kwargs_dataset.update(
    use_minimal_basis_node_features=inputs["use_minimal_basis_node_features"]
)


def _build_reaction_dataset(root, filename):
    """Return a ReactionDataset for one split, using the shared config options."""
    return ReactionDataset(
        root=root,
        filename=filename,
        basis_filename=inputs["basis_file"],
        **kwargs_dataset
    )


train_dataset = _build_reaction_dataset(
    get_train_data_path(model_name), train_json_filename
)

validation_dataset = _build_reaction_dataset(
    get_validation_data_path(model_name), validate_json_filename
)

In [None]:
# Visualize, as an animated heatmap, the absolute values of three coefficient
# matrices of the first sample: `x`, `x_final_state` and `x_transition_state`.
# NOTE(review): `data` is not defined in the cells visible here -- presumably
# it is a batch drawn from a DataLoader (or one of the datasets); confirm it is
# created before this cell runs.
sample = data[0]
matrices_to_visualize = np.abs(
    [
        sample.x.detach().numpy(),
        sample.x_final_state.detach().numpy(),
        sample.x_transition_state.detach().numpy(),
    ]
)

fig = px.imshow(
    matrices_to_visualize,
    template="simple_white",
    color_continuous_scale="RdBu",
    labels=dict(x="Basis Functions", y="Atoms", color="Coefficient"),
    animation_frame=0,  # axis 0 indexes the three matrices: one frame each
    range_color=[0, 1],  # pin the colour scale to the interval [0, 1]
)

# One tick position per label, so labels and positions always have the same
# length (the previous hard-coded 35 positions did not match these 32 labels).
# NOTE(review): confirm that 5 s + 12 p + 15 d labels match the number of
# basis-function columns in the plotted matrices.
x_tick_label = 5 * ["s"] + 12 * ["p"] + 15 * ["d"]
fig.update_xaxes(ticktext=x_tick_label, tickvals=np.arange(len(x_tick_label)))

# Label the rows with the species of the *same* sample that is plotted, rather
# than with the species of every graph held by `data`.
# NOTE(review): assumes `species` holds atomic numbers, since the values are
# used as indices into ase.data.atomic_names -- confirm.
tickvals = sample.species.view(-1).detach().numpy().flatten()
y_tickvals_species = [atomic_names[int(tickval)] for tickval in tickvals]
fig.update_yaxes(ticktext=y_tickvals_species, tickvals=np.arange(len(tickvals)))
fig.show()