In [1]:
import os

import numpy as np

from e3nn import o3
from e3nn.nn.models.gate_points_2102 import Convolution as ConvolutionGatePoints
from e3nn.nn.models.gate_points_2102 import Network as NetworkGatePoints
from e3nn.math import soft_one_hot_linspace

import plotly.express as px

import torch
from torch_geometric.loader import DataLoader

from minimal_basis.dataset.dataset_reaction import ReactionDataset

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)

from ase import units as ase_units

import matplotlib.pyplot as plt

import warnings
# Ignore every FutureWarning and UserWarning for the rest of the session
# (presumably to keep the cell outputs below readable).
# NOTE(review): this is process-wide and also hides UserWarnings raised by the
# libraries imported above — consider narrowing the filter once the noisy
# source is known.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [2]:
# Read the run configuration from the YAML file under input_files/.
config_path = os.path.join("input_files", "reaction_model.yaml")
inputs = read_inputs_yaml(config_path)

# The "debug_*" entries of the configuration name the JSON files used here.
train_json_filename = inputs["debug_train_json"]
validate_json_filename = inputs["debug_validate_json"]

# Both splits are built with the same basis-set file.
basis_filename = inputs["basis_file"]

# Training split first, then validation (same order as the log output).
train_dataset = ReactionDataset(
    root=get_train_data_path(),
    filename=train_json_filename,
    basis_filename=basis_filename,
)
validation_dataset = ReactionDataset(
    root=get_validation_data_path(),
    filename=validate_json_filename,
    basis_filename=basis_filename,
)

INFO:minimal_basis.dataset.dataset_reaction:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_reaction:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_reaction:Parsing basis information.
INFO:minimal_basis.dataset.dataset_reaction:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_reaction:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_reaction:Parsing basis information.


In [3]:
# One sample per batch, served in dataset order, for both splits.
loader_options = dict(batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, **loader_options)
validation_loader = DataLoader(validation_dataset, **loader_options)

In [4]:
# Compare the interpolated transition-state (TS) coordinates with the reference
# TS coordinates for every training batch: one 3D scatter plot per batch plus a
# histogram of the per-batch mean displacement.

# write_html does not create missing parent directories, so make sure the
# output folder exists before the first figure is written.
plot_dir = "plots/hamiltonian_model"
os.makedirs(plot_dir, exist_ok=True)

all_mae_norms = []

for idx, data in enumerate(train_loader):

    # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the batch
    # ever lives on an accelerator.
    interpolated_ts_coords = (
        data.pos_interpolated_transition_state.detach().cpu().numpy()
    )
    real_ts_coords = data.pos_transition_state.detach().cpu().numpy()

    # Euclidean norm of the displacement of each row (presumably one row per
    # atom — TODO confirm against ReactionDataset).
    difference_ts_coords = interpolated_ts_coords - real_ts_coords
    norm_difference_ts_coords = np.linalg.norm(difference_ts_coords, axis=1)

    # "MAE" here is the mean of those per-row displacement norms.
    mae = np.mean(norm_difference_ts_coords)
    all_mae_norms.append(mae)

    # Plot the real and interpolated TS structures in the same scene.
    # Colour value 0 marks the real structure, 1 the interpolated one.
    fig = px.scatter_3d(
        x=np.concatenate((real_ts_coords[:, 0], interpolated_ts_coords[:, 0])),
        y=np.concatenate((real_ts_coords[:, 1], interpolated_ts_coords[:, 1])),
        z=np.concatenate((real_ts_coords[:, 2], interpolated_ts_coords[:, 2])),
        color=np.concatenate(
            (np.zeros(len(real_ts_coords)), np.ones(len(interpolated_ts_coords)))
        ),
    )

    # The title and the file name both carry the error of this sample.
    fig.update_layout(title=f"MAE of structure prediction: {mae:.3f} Å")
    fig.write_html(f"{plot_dir}/interpolated_ts_{idx}_mae_{mae:.3f}.html")

# Histogram of the per-sample errors over the whole training loader.
fig = px.histogram(x=all_mae_norms, nbins=20, template="simple_white")
fig.update_layout(title="MAE structure prediction (Å)")
fig.update_xaxes(title_text="MAE (Å)")
fig.update_yaxes(title_text="Frequency")
fig.write_html(f"{plot_dir}/interpolated_ts_mae_histogram.html")
fig.show()

In [5]:
# Seed torch before building the modules: their weights are randomly
# initialised, so without a fixed seed the outputs plotted from them change on
# every Restart & Run All.
random_seed = 0
torch.manual_seed(random_seed)

# --- Irreducible representations -------------------------------------------
node_input_irrep = o3.Irreps("1x0e+1x1o")  # irreps of the node inputs
node_attr_irrep = o3.Irreps("1x0e")  # irreps of the node attributes
irreps_sh = o3.Irreps("1x0e+1x1o+1x2e")  # irreps for the spherical harmonics
irreps_hidden = o3.Irreps("3x0e+4x1o")  # hidden irreps of the network
irreps_out = o3.Irreps("1x0e+1x1o")  # irreps of both models' outputs

# --- Scalar hyper-parameters -----------------------------------------------
max_radius = 4  # given to the network; also the end of the edge-length embedding
num_basis = 4  # number of basis functions for the edge-length embedding
radial_layers = 4
radial_neurons = 5
num_neighbors = 3
typical_number_of_nodes = 20

# Both models are given `num_basis` scalar channels as edge-attribute irreps.
edge_attr_irrep = o3.Irreps(f"{num_basis}x0e")

# Single convolution layer (constructed first so the RNG stream order is kept).
convolution = ConvolutionGatePoints(
    irreps_in=node_input_irrep,
    irreps_node_attr=node_attr_irrep,
    irreps_edge_attr=edge_attr_irrep,
    irreps_out=irreps_out,
    number_of_edge_features=1,
    radial_layers=radial_layers,
    radial_neurons=radial_neurons,
    num_neighbors=num_neighbors,
)

# Full gate-points network.
# NOTE(review): `layers` reuses `radial_layers`; confirm that the number of
# message-passing layers is really meant to equal the radial MLP depth.
# NOTE(review): in e3nn's gate_points_2102.Network, `irreps_edge_attr` appears
# to be the irreps used for the spherical harmonics of the edge vectors —
# confirm that `{num_basis}x0e` is intended here rather than `irreps_sh`.
network = NetworkGatePoints(
    irreps_in=node_input_irrep,
    irreps_hidden=irreps_hidden,
    irreps_out=irreps_out,
    irreps_node_attr=node_attr_irrep,
    irreps_edge_attr=edge_attr_irrep,
    layers=radial_layers,
    max_radius=max_radius,
    number_of_basis=num_basis,
    radial_layers=radial_layers,
    radial_neurons=radial_neurons,
    num_neighbors=num_neighbors,
    num_nodes=typical_number_of_nodes,
    reduce_output=False,
)

def _difference_heatmap(difference, row_labels):
    """Return a heat map of `difference` with one y tick per entry of `row_labels`.

    Parameters
    ----------
    difference : torch.Tensor
        Tensor to display; it is detached and converted to numpy first.
    row_labels : array-like
        Text shown on the y axis, one label per tick position 0..len-1.
    """
    heatmap = px.imshow(difference.detach().numpy())
    heatmap.update_yaxes(tickvals=np.arange(len(row_labels)), ticktext=row_labels)
    return heatmap


# Push only the first training batch through both models (see the `break`).
for idx, data in enumerate(train_loader):

    # Edge vectors and their lengths; the length is kept both flat and as a column.
    row, col = data.edge_index
    edge_vec = data.pos[row] - data.pos[col]
    edge_length = edge_vec.norm(dim=1)
    norm_edge_vec = edge_length.view(-1, 1)

    # NOTE(review): `sh` is computed but not passed to either model below.
    sh = o3.spherical_harmonics(
        irreps_sh, edge_vec, normalize=True, normalization="component"
    )

    # Soft one-hot expansion of the edge lengths on [0, max_radius].
    edge_length_embedding = soft_one_hot_linspace(
        edge_length,
        start=0.0,
        end=max_radius,
        number=num_basis,
        basis="smooth_finite",
        cutoff=True,
    )

    # Inputs of the convolution.
    node_input = data.x
    node_attr = data.species_initial_state
    edge_src, edge_dst = row, col
    edge_attr = edge_length_embedding
    edge_features = norm_edge_vec

    for label, tensor in (
        ("node_input", node_input),
        ("node_attr", node_attr),
        ("edge_src", edge_src),
        ("edge_dst", edge_dst),
        ("edge_attr", edge_attr),
        ("edge_features", edge_features),
    ):
        print(f"Shape of {label}: {tensor.shape}")

    # Labels for the heat-map rows: the flattened node attributes.
    tickvals = node_attr.view(-1).detach().numpy().flatten()

    # --- Single convolution layer ---
    output = convolution(
        node_attr=node_attr,
        node_input=node_input,
        edge_src=edge_src,
        edge_dst=edge_dst,
        edge_attr=edge_attr,
        edge_features=edge_features,
    )
    print(f"Shape of output: {output.shape}")

    # NOTE(review): per the printed shapes node_attr is (N, 1) and output is
    # (N, 4), so this subtraction broadcasts node_attr over every output
    # column — confirm that node_attr (not node_input) is the intended reference.
    difference = output - node_attr
    print(f"Shape of difference: {difference.shape}")

    fig = _difference_heatmap(difference, tickvals)
    fig.show()

    # --- Full network ---
    network_inputs = {
        "pos": data.pos,
        "x": data.x,
        "z": data.species_initial_state,
        "batch": data.batch,
    }
    output_network = network(network_inputs)
    print(f"Shape of output_network: {output_network.shape}")

    difference_network = output_network - node_attr
    fig = _difference_heatmap(difference_network, tickvals)
    fig.show()

    break

Shape of node_input: torch.Size([13, 4])
Shape of node_attr: torch.Size([13, 1])
Shape of edge_src: torch.Size([11])
Shape of edge_dst: torch.Size([11])
Shape of edge_attr: torch.Size([11, 4])
Shape of edge_features: torch.Size([11, 1])
Shape of output: torch.Size([13, 4])
Shape of difference: torch.Size([13, 4])


Shape of output_network: torch.Size([13, 4])
