# Study the parsing of matrix features with the Data object

In [1]:
import numpy as np
import plotly.express as px
from instance_mongodb import instance_mongodb_sei

from pymatgen.core.structure import Molecule
from pymatgen.analysis.graphs import MoleculeGraph
from pymatgen.analysis.local_env import OpenBabelNN

import torch

from minimal_basis.data.data_hamiltonian import MatrixSplitAtoms

from e3nn import o3

from utils import rotate_three_dimensions

In [2]:
# Visualise computed Hamiltonian
db = instance_mongodb_sei(project="mlts")
collection = db.rotated_waters_dataset

hamiltonians = []
angles = []
structures = []
overlap_matrices = []
molecule_graphs = []

for doc in collection.find({}).limit(10):
    hamiltonians.append(doc["fock_matrices"][0])
    overlap_matrices.append(doc["overlap_matrices"][0])
    structure = Molecule.from_dict(doc["structures"][0])
    structures.append(structure)
    molecule_graph = MoleculeGraph.with_local_env_strategy(structure, OpenBabelNN())
    molecule_graphs.append(molecule_graph)
    angles.append(doc["angles"])

hamiltonians = np.array(hamiltonians)
overlap_matrices = np.array(overlap_matrices)
angles = np.array(angles)

basis_info_atom = {'O': 5, 'H': 1}
node_features = []
edge_features = []
for idx, molecule_graph in enumerate(molecule_graphs):
    split_instance = MatrixSplitAtoms(
        molecule_graph=molecule_graph,
        matrix=hamiltonians[idx],
        basis_info_atom=basis_info_atom,
    )
    node_features.append(split_instance.node_features)
    edge_features.append(split_instance.edge_features)

In [3]:
fig = px.imshow(
    hamiltonians[:, 1, ...], animation_frame=0, labels=dict(x="Basis", y="Basis", color="Value"),
)
fig.update_xaxes(
    ticktext=[ "O1s", "O2s", "O2p", "O2p", "O2p", "H1s", "H1s",],
    tickvals=np.arange(11),
)
fig.update_yaxes(
    ticktext=[ "O1s", "O2s", "O2p", "O2p", "O2p", "H1s", "H1s",],
    tickvals=np.arange(11),
)
fig.update_layout(title_text='DFT computed Hamiltonian for a rotated water molecule', title_x=0.5)
fig.show()

In [4]:
# Plot one of the node and edge features
fig = px.imshow(
    node_features[0][0], animation_frame=0, labels=dict(x="Basis", y="Basis", color="Value"),
    title="Node features for the first atom in the molecule"
)
fig.show()

In [21]:
# Plot one of the edge features
edge_features = np.array(edge_features)
rotated_edge_features = edge_features[:,0,0]
fig = px.imshow(
    rotated_edge_features, animation_frame=0, labels=dict(x="Basis", y="Basis", color="Value"),
    title="Edge features for the first edge", range_color=[-1, 1]
)
# Set only integer ticks on the x and y axes
fig.update_xaxes(tickmode='linear')
fig.update_yaxes(tickmode='linear')
fig.show()

(10, 5, 1)


In [39]:
irreps_edge_feature = o3.Irreps("2x0e+1x1o")

D_matrices = []
residual_rotated_edge_features = []

for idx, angle in enumerate(angles):
    alpha, beta, gamma = angle

    rotation_matrix = rotate_three_dimensions(alpha, beta, gamma)
    rotation_matrix = torch.tensor(rotation_matrix)
    if idx == 0:
        rotation_matrix_0 = rotation_matrix

    # Reference the rotation matrix to the first one
    rotation_matrix = rotation_matrix @ rotation_matrix_0.T

    D_matrix = irreps_edge_feature.D_from_matrix(rotation_matrix)
    D_matrix = D_matrix.detach().numpy()

    reference_rotated_edge_features = rotated_edge_features[0]
    # reference_rotated_edge_features = reference_rotated_edge_features.T

    dmat_rotated_edge_features = D_matrix @ reference_rotated_edge_features
    
    # dmat_rotated_edge_features =  reference_rotated_edge_features @ D_matrix.T
    # dmat_rotated_edge_features = D_matrix @ dmat_rotated_edge_features.T

    residual_rotated_edge_features.append(dmat_rotated_edge_features - rotated_edge_features[idx])

residual_rotated_edge_features = np.array(residual_rotated_edge_features)
fig = px.imshow(
    residual_rotated_edge_features, animation_frame=0, labels=dict(x="Basis", y="Basis", color="Value"),
    title="Residual edge features for the first edge", range_color=[-1, 1],
)
fig.update_xaxes(tickmode='linear')
fig.update_yaxes(tickmode='linear')
fig.show()


In [70]:
edge_feature_1, edge_feature_2 = edge_features[0, :, 0]
irreps_in1 = o3.Irreps("2x0e+1x1o")
irreps_in2 = o3.Irreps("2x0e+1x1o")
fcpt = o3.FullyConnectedTensorProduct(irreps_in1=irreps_in1, irreps_in2=irreps_in2, irreps_out="1x0e+1x1o+1x2e")
irreps_output = fcpt.irreps_out
print(f"Output irreps: {irreps_output}")

tensor_edge_feature_1 = torch.tensor(edge_feature_1, dtype=torch.float32).view(-1)
tensor_edge_feature_2 = torch.tensor(edge_feature_2, dtype=torch.float32).view(-1)
print(tensor_edge_feature_1.shape)
print(tensor_edge_feature_2.shape)

rot = o3.rand_matrix()
D_in = irreps_in1.D_from_matrix(rot)
D_out = irreps_output.D_from_matrix(rot)
output_rotated = fcpt(tensor_edge_feature_1@D_in.T, tensor_edge_feature_2@D_in.T)
output = fcpt(tensor_edge_feature_1, tensor_edge_feature_2) @ D_out.T
torch.allclose(output_rotated, output, rtol=1e-4, atol=1e-4)


Output irreps: 1x0e+1x1o+1x2e
torch.Size([5])
torch.Size([5])



The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.



True