In [None]:
import os
import numpy as np
from e3nn import o3
from minimal_basis.dataset.dataset_hamiltonian import HamiltonianDataset
from minimal_basis.model.model_hamiltonian import (
    EquivariantConv,
    SimpleHamiltonianModel,
)
import warnings
import plotly.express as px
import torch
from torch_geometric.loader import DataLoader
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# Features from the Hamiltonian for chemical reactions

We look at a single reaction whose structures have been rotated 20 different times. Ideally, our features should encode the atomic interactions while removing any dependence on the rotation.

In [None]:
# Load the rotated SN2 reaction data, using the STO-3G basis-set definition;
# processed files are stored under ./output.
dataset = HamiltonianDataset(
    filename="input_files/train_data_rotated_sn2.json",
    basis_file="input_files/sto-3g.json",
    root="./output",
)

In [None]:
# Gather the initial-state Fock and overlap matrices of every rotated snapshot.
# The trailing [0] keeps only the first entry of each (presumably the first
# spin channel -- TODO confirm against the dataset schema).
fock_matrices, overlap_matrices = [], []
for entry in dataset.input_data:
    state_idx = entry['state'].index('initial_state')
    fock_matrices.append(entry['fock_matrices'][state_idx][0])
    overlap_matrices.append(entry['overlap_matrices'][state_idx][0])
fock_matrices = np.asarray(fock_matrices)
overlap_matrices = np.asarray(overlap_matrices)

# Animated heatmap of the Fock matrices, one frame per snapshot.
fock_figure = px.imshow(
    fock_matrices,
    animation_frame=0,
    color_continuous_scale='RdBu',
    range_color=(-1, 1),
    title='Fock matrix',
    labels=dict(x='Basis function', y='Basis function', animation_frame='Snapshot'),
)
fock_figure.show()


In [None]:
# Animated heatmap of the overlap matrices, one frame per snapshot.
fig = px.imshow(
    overlap_matrices,
    title='Overlap matrix',
    animation_frame=0,
    range_color=(-1, 1),
    color_continuous_scale='RdBu',
    labels=dict(x='Basis function', y='Basis function', animation_frame='Snapshot'),
)
fig.show()

In [None]:
# One graph per batch, in stored order, so later cells can line every snapshot up
# with its rotation angles and compare it against the first snapshot.
loader = DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
def rotate_three_dimensions(alpha, beta, gamma):
    """Return the 3x3 rotation matrix for the angles ``alpha``, ``beta``, ``gamma``.

    The matrix equals ``Rz(alpha) @ Ry(beta) @ Rx(gamma)``: elementary
    rotations about the z, y and x axes, with the angles in radians. The
    matrix is only built and returned; no coordinates are rotated here.

    Returns:
        numpy.ndarray of shape (3, 3).
    """
    ca, sa = np.cos(alpha), np.sin(alpha)
    cb, sb = np.cos(beta), np.sin(beta)
    cg, sg = np.cos(gamma), np.sin(gamma)

    first_row = [ca * cb, ca * sb * sg - sa * cg, ca * sb * cg + sa * sg]
    second_row = [sa * cb, sa * sb * sg + ca * cg, sa * sb * cg - ca * sg]
    third_row = [-sb, cb * sg, cb * cg]

    return np.array([first_row, second_row, third_row])


In [None]:
# Collect, for every snapshot yielded by the loader (one graph per batch), the
# node features, global attributes, minimal-basis irreps, initial-state
# minimal-basis Fock matrix and rotation angles.
node_features = []
global_features = []
angles = []
irreps_fock_matrices = []
minimal_fock_matrices = []

for data in loader:
    node_features.append(data.x.detach().numpy())
    # Index [0] picks the single graph of the batch.
    irreps_fock_matrices.append(data.irreps_minimal_basis[0])
    global_features.append(data.global_attr.detach().numpy())
    minimal_fock_matrices.append(np.array(data.minimal_fock_matrix['initial_state']))
    angles.append(data.angles[0])

node_features = np.array(node_features)
global_features = np.array(global_features)
# Insert a singleton row axis so each snapshot renders as a 1 x N heatmap.
global_features = global_features.reshape(-1, 1, global_features.shape[-1])
angles = np.array(angles)
minimal_fock_matrices = np.array(minimal_fock_matrices)

In [None]:
# Build one representation matrix (D-matrix) of the minimal-basis irreps per
# snapshot. Each rotation is expressed relative to the first snapshot's rotation,
# i.e. R_rel = R_i @ R_0^T.
D_matrices = []
reference_rotation = None

for snapshot_idx, (alpha, beta, gamma) in enumerate(angles):
    absolute_rotation = torch.tensor(rotate_three_dimensions(alpha, beta, gamma))
    if reference_rotation is None:
        reference_rotation = absolute_rotation

    relative_rotation = absolute_rotation @ reference_rotation.T
    D_matrices.append(
        irreps_fock_matrices[snapshot_idx].D_from_matrix(relative_rotation)
    )

# Stack the per-snapshot matrices and convert them to a numpy array.
D_matrices = torch.stack(D_matrices).detach().numpy()

In [None]:
# Rotate the first snapshot's minimal-basis Fock matrix into every other
# snapshot's frame and compare it with the matrix computed for that snapshot.
# Rows and columns both transform with D, so the rotated matrix is
# F_i = D_i @ F_0 @ D_i^T. It is a similarity transform, so only the right-hand
# factor is transposed.
reference_fock = minimal_fock_matrices[0]
minimal_fock_matrices_rotated = np.zeros_like(minimal_fock_matrices)

for i in range(len(angles)):
    D = D_matrices[i]
    minimal_fock_matrices_rotated[i, ...] = D @ reference_fock @ D.T

minimal_fock_matrices_rotated_diff = minimal_fock_matrices - minimal_fock_matrices_rotated

# [:, 0, 0, ...] presumably selects the single batch entry and the first spin
# channel of each snapshot -- TODO confirm the array layout.
fig = px.imshow(
    minimal_fock_matrices_rotated_diff[:, 0, 0, ...],
    animation_frame=0,
    labels=dict(x="Basis", y="Basis", color="Value"),
    range_color=[-1, 1],
    title="Difference between the computed and rotated Fock matrix"
)
fig.show()

- **Node features**: The diagonal elements of the sub-diagonalised Hamiltonian matrix
- **Edge features**: Bond lengths of the _z_-matrix interpolated transition state structures
- **Global features**: Eigenvalues of the minimal basis representation of the Hamiltonian matrix

In [None]:
# Animated heatmap of the node features, one frame per snapshot.
node_feature_labels = {
    'x': 'Dimension of Irreducible Representation',
    'y': 'Node feature',
    'animation_frame': 'Snapshot',
}
fig = px.imshow(
    node_features,
    title='Node features',
    animation_frame=0,
    color_continuous_scale='RdBu',
    labels=node_feature_labels,
)
fig.show()

In [None]:
# Plot the global attribute (presumably for the first spin channel). Each
# snapshot is drawn as a single-row image, so the y tick labels are hidden.
fig = px.imshow(
    global_features,
    title='Global Features',
    animation_frame=0,
    range_color=(-1, 1),
    color_continuous_scale='RdBu',
    labels=dict(x='(Minimal) Basis function', animation_frame='Snapshot'),
)
fig.update_yaxes(showticklabels=False)
fig.show()

## Simple model

For each node and global attribute, we construct the following model:

- Take the tensor product between the features of the reactants and products
- Parameterise the weights of these tensor products by the bond lengths of the _interpolated_ transition state structure

$$
f_{\mathrm{output}} = \frac{1}{\mathrm{norm}} \, f_{i} \otimes h\left( \lVert x_{ij} \rVert \right) f_{j}
$$

In [None]:
# Equivariance check for a single EquivariantConv layer: evaluate it on the
# original inputs and on inputs rotated by a random rotation `rot`.
# Node features are declared as scalars (0e), so D_in is the identity and only
# the interpolated transition-state positions are actually rotated.
# NOTE(review): if data.x contains non-scalar (l > 0) components, declaring
# them as 0e hides their rotation behaviour -- confirm this is intended.
irreps_in = o3.Irreps(f"{dataset[0].irreps_node_features.dim}x0e")
irreps_out = o3.Irreps("20x0e")

rot = o3.rand_matrix()
D_in = irreps_in.D_from_matrix(rot)
D_out = irreps_out.D_from_matrix(rot).detach().numpy()

conv = EquivariantConv(
    irreps_in=irreps_in,
    irreps_out=irreps_out,
    hidden_layers=64,
    num_basis=10,
    max_radius=4.0,
)
equivart_output_model = []
equivart_rotated_output_model = []

# Only forward passes are needed here, so skip building the autograd graph.
with torch.no_grad():
    for data in loader:
        output = conv(
            data.x,
            data.x_final_state,
            data.edge_index_interpolated_TS,
            data.pos_interpolated_TS,
        )
        equivart_output_model.append(output.detach().numpy())
        # Same sample with features transformed by D_in and positions by rot
        # (row-vector convention: x -> x @ D^T).
        rotated_output = conv(
            data.x @ D_in.T,
            data.x_final_state @ D_in.T,
            data.edge_index_interpolated_TS,
            data.pos_interpolated_TS @ rot.T,
        )
        equivart_rotated_output_model.append(rotated_output.detach().numpy())

equivart_output_model = np.array(equivart_output_model)
equivart_rotated_output_model = np.array(equivart_rotated_output_model)

# Plot the output of the equivariant model
fig = px.imshow(
    equivart_output_model,
    animation_frame=0,
    color_continuous_scale='RdBu',
    title='Output of the equivariant model',
    labels={'x': 'Dimension of Irreducible Representation', 'y': 'Node feature', 'animation_frame': 'Snapshot'},
)
fig.show()


In [None]:
# Equivariance requires conv(rotated inputs) == conv(inputs) rotated by D_out,
# i.e. equivart_rotated_output_model == equivart_output_model @ D_out^T.
rotated_equivart_output_model = equivart_output_model @ D_out.T

# Compare the output obtained from the rotated inputs with the rotated original
# output. Comparing equivart_output_model with its own rotation would only test
# D_out, and for scalar (0e) outputs that difference is trivially zero.
equivart_output_model_diff = equivart_rotated_output_model - rotated_equivart_output_model

# Plot the difference between the original and rotated output
fig = px.imshow(
    equivart_output_model_diff,
    animation_frame=0,
    color_continuous_scale='RdBu',
    title='Difference between the original and rotated output of the equivariant model',
    labels={'x': 'Dimension of Irreducible Representation', 'y': 'Node feature', 'animation_frame': 'Snapshot'},
)
fig.show()

In [None]:
# Build the full SimpleHamiltonianModel and print its raw output for every snapshot.
model_kwargs = dict(
    irreps_in=irreps_in,
    irreps_intermediate=irreps_out,
    hidden_layers=64,
    num_basis=10,
    max_radius=4.0,
)
model = SimpleHamiltonianModel(**model_kwargs)

for batch in loader:
    print(model(batch))