## Analyze how the MO coefficients of water transform under rotation (PySCF calculations)

In [None]:
import warnings
from collections import defaultdict

import numpy as np

import pandas as pd

import plotly.express as px

from pymatgen.core.structure import Molecule

import torch

from e3nn import o3

from pyscf import dft
from pyscf import gto

from minimal_basis.transforms.rotations import RotationMatrix

from instance_mongodb import instance_mongodb_sei

In [None]:
# Run a restricted Kohn-Sham B3LYP/def2-SVP calculation for every rotated
# water structure in the MongoDB collection, keeping the PySCF Mole and RKS
# objects together with the Euler angles used to generate each structure.
# NOTE(review): collection.find() has no explicit sort, so the order of the
# structures (and hence which one lands at index 0) follows the server's
# natural order -- consider sorting on a stable key.
db = instance_mongodb_sei(project="mlts")
collection = db.rotated_water_initial_structures

data = defaultdict(list)

for doc in collection.find():
    pymatgen_molecule = Molecule.from_dict(doc["molecule"])
    coordinates = np.array(pymatgen_molecule.cart_coords)
    species = [str(site.specie) for site in pymatgen_molecule]

    # PySCF atom spec: "El x y z; El x y z; ..." with no trailing separator.
    atom_string = "; ".join(
        f"{element} {x} {y} {z}"
        for element, (x, y, z) in zip(species, coordinates)
    )

    mol = gto.Mole()
    mol.atom = atom_string
    mol.basis = 'def2-svp'
    mol.build()

    rks_h2o = dft.RKS(mol)
    rks_h2o.xc = 'b3lyp'
    rks_h2o.kernel()
    if not rks_h2o.converged:
        # MO coefficients from an unconverged SCF are not trustworthy, so make
        # the failure visible instead of silently analysing them.
        warnings.warn(
            f"SCF did not converge for structure with "
            f"euler_angles={doc['euler_angles']}"
        )

    data["molecule"].append(mol)
    data["rks"].append(rks_h2o)
    data["euler_angles"].append(doc["euler_angles"])


In [None]:
def labels_to_irreps(pyscf_labels):
    """Convert PySCF AO labels into an e3nn irreps string.

    Each label is expected in the ``mol.ao_labels()`` format, e.g.
    ``"0 O 2px"``: atom index, element symbol, then a field made of the
    principal quantum number, the angular-momentum letter and a component
    suffix. Every shell becomes one irrep ``1x{l}{parity}``, where the parity
    is ``e`` for even l and ``o`` for odd l. The 2l+1 components of a shell
    are assumed to be listed contiguously.

    Args:
        pyscf_labels: Iterable of PySCF AO label strings.

    Returns:
        An irreps string such as ``"1x0e+1x0e+1x1o+1x2e"``, or ``""`` when
        no labels are given.

    Raises:
        ValueError: If a label cannot be parsed, or its angular-momentum
            letter is not one of ``"spdfghi"``.
    """
    import re

    shell_letters = "spdfghi"

    letters = []
    for label in pyscf_labels:
        fields = label.split()
        # Third field looks like "1s", "2px", "3dxy", "10s": digits, then l.
        match = re.match(r"\d+([a-z])", fields[2]) if len(fields) >= 3 else None
        if match is None:
            raise ValueError(f"Cannot parse PySCF AO label: {label!r}")
        letters.append(match.group(1))

    irreps = []
    idx = 0
    while idx < len(letters):
        ang = shell_letters.find(letters[idx])
        if ang < 0:
            # Without this guard an unknown letter would never advance idx.
            raise ValueError(
                f"Unsupported angular momentum {letters[idx]!r} at AO {idx}"
            )
        parity = "e" if ang % 2 == 0 else "o"
        irreps.append(f"1x{ang}{parity}")
        idx += 2 * ang + 1  # skip the remaining components of this shell
    return "+".join(irreps)

def labels_to_ordering(pyscf_labels):
    """Return the row permutation that reorders PySCF AOs for e3nn.

    Labels are expected in the ``mol.ao_labels()`` format (e.g. ``"0 O 2px"``),
    and each shell's components are assumed to be contiguous. s and d shells
    keep their PySCF order. p shells listed as (px, py, pz) are reordered to
    (py, pz, px), which is e3nn's (y, z, x) ordering for l = 1.

    Args:
        pyscf_labels: Iterable of PySCF AO label strings.

    Returns:
        Integer ``np.ndarray`` of AO indices. ``matrix[ordering, :]`` puts the
        AO rows in the reordered sequence.

    Raises:
        ValueError: If a label cannot be parsed or belongs to a shell other
            than s, p or d, for which no ordering is defined here.
    """
    import re

    # Within-shell permutation, keyed by the angular-momentum letter.
    # NOTE(review): d is kept in PySCF order on the assumption that it already
    # matches e3nn's m = -2..2 ordering -- confirm before extending to f.
    shell_permutations = {
        "s": (0,),
        "p": (1, 2, 0),
        "d": (0, 1, 2, 3, 4),
    }

    letters = []
    for label in pyscf_labels:
        fields = label.split()
        # Third field looks like "1s", "2px", "3dxy", "10s": digits, then l.
        match = re.match(r"\d+([a-z])", fields[2]) if len(fields) >= 3 else None
        if match is None:
            raise ValueError(f"Cannot parse PySCF AO label: {label!r}")
        letters.append(match.group(1))

    indices = []
    idx = 0
    while idx < len(letters):
        permutation = shell_permutations.get(letters[idx])
        if permutation is None:
            # Without this guard an unknown letter would never advance idx.
            raise ValueError(
                f"Unsupported angular momentum {letters[idx]!r} at AO {idx}"
            )
        indices.extend(idx + offset for offset in permutation)
        idx += len(permutation)
    # Explicit int dtype so even an empty result is usable as an index array.
    return np.array(indices, dtype=int)


In [None]:
# For every rotated structure, compare the MO coefficients PySCF computed with
# the coefficients obtained by applying the Wigner D-matrix of the relative
# rotation (current structure vs. structure 0) to structure 0's coefficients.
calculated_coeff_matrices = []
self_rotated_coeff_matrices = []

# Conjugation P.T @ R @ P below. Presumably this maps RotationMatrix's
# (x, y, z) frame onto e3nn's (y, z, x) convention -- TODO confirm.
permutation_matrix = torch.tensor(
    [[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float64
)

for idx, angle in enumerate(data["euler_angles"]):
    rks = data['rks'][idx]
    molecule = data['molecule'][idx]
    ao_labels = molecule.ao_labels()

    irreps = labels_to_irreps(ao_labels)

    # Reorder AO rows so the D-matrix blocks act on them directly.
    ordering = labels_to_ordering(ao_labels)
    mo_coeff = rks.mo_coeff[ordering, :]

    labels = np.array(ao_labels)[ordering]
    eigenvalues = rks.mo_energy

    rotation_matrix = RotationMatrix(angle_type="euler", angles=angle)()
    rotation_matrix = torch.tensor(rotation_matrix, dtype=torch.float64)
    rotation_matrix = permutation_matrix.T @ rotation_matrix @ permutation_matrix

    if idx == 0:
        # Structure 0 is the reference that every other structure is compared to.
        original_coeff_matrix = mo_coeff
        rotation_matrix_0 = rotation_matrix

    # Rotation relative to the reference structure.
    rotation_matrix = rotation_matrix @ rotation_matrix_0.T

    D_matrix = o3.Irreps(irreps).D_from_matrix(rotation_matrix)
    D_matrix = D_matrix.detach().numpy()

    # Rotate all MOs (columns) at once: column i of D @ C equals D @ C[:, i].
    _new_coeff_matrix = D_matrix @ original_coeff_matrix

    calculated_coeff_matrices.append(mo_coeff)
    self_rotated_coeff_matrices.append(_new_coeff_matrix)

calculated_coeff_matrices = np.array(calculated_coeff_matrices)
self_rotated_coeff_matrices = np.array(self_rotated_coeff_matrices)
# Magnitudes only: presumably because the overall sign of each MO is arbitrary.
difference_coeff_matrices = np.abs(calculated_coeff_matrices) - np.abs(self_rotated_coeff_matrices)

In [None]:
# Animated heat map of the PySCF MO coefficients (AO rows reordered as in the
# comparison cell), one animation frame per rotated structure.
fig = px.imshow(
    calculated_coeff_matrices,
    animation_frame=0,
    range_color=[-1, 1],
    color_continuous_scale="RdBu_r",
)
fig.update_layout(
    title="Calculated Coefficient Matrices",
    yaxis_title="AO Index",
    xaxis_title="MO Eigenvalues (Ha)",
)
# Tick labels come from the last structure processed in the comparison loop.
fig.update_yaxes(tickvals=np.arange(len(molecule.ao_labels())), ticktext=labels)
fig.update_xaxes(tickvals=np.arange(len(eigenvalues)), ticktext=np.round(eigenvalues, 2))
fig.show()


In [None]:
# Animated heat map of the reference MO coefficients after rotation by each
# structure's relative D-matrix, one frame per rotated structure.
fig = px.imshow(
    self_rotated_coeff_matrices,
    animation_frame=0,
    range_color=[-1, 1],
    color_continuous_scale="RdBu_r",
)
fig.update_layout(
    title="Self Rotated Coefficient Matrices",
    yaxis_title="AO Index",
    xaxis_title="MO Eigenvalues (Ha)",
)
# Tick labels come from the last structure processed in the comparison loop.
fig.update_yaxes(tickvals=np.arange(len(molecule.ao_labels())), ticktext=labels)
fig.update_xaxes(tickvals=np.arange(len(eigenvalues)), ticktext=np.round(eigenvalues, 2))
fig.show()

In [None]:
# Animated heat map of |calculated| - |self-rotated| coefficients; values near
# zero mean the D-matrix rotation reproduces PySCF up to per-entry sign.
fig = px.imshow(
    difference_coeff_matrices,
    animation_frame=0,
    range_color=[-1, 1],
    color_continuous_scale="RdBu_r",
)
fig.update_layout(
    title="Difference Coefficient Matrices",
    yaxis_title="AO Index",
    xaxis_title="MO Eigenvalues (Ha)",
)
# Tick labels come from the last structure processed in the comparison loop.
fig.update_yaxes(tickvals=np.arange(len(molecule.ao_labels())), ticktext=labels)
fig.update_xaxes(tickvals=np.arange(len(eigenvalues)), ticktext=np.round(eigenvalues, 2))
fig.show()
    