# In [7]:
import random

from pathlib import Path

import numpy as np

import plotly.express as px

from pymatgen.core.structure import Molecule
from pymatgen.analysis.local_env import OpenBabelNN
from pymatgen.analysis.graphs import MoleculeGraph

import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl
# Render every matplotlib figure in this notebook at 150 dots per inch.
mpl.rcParams.update({"figure.dpi": 150})

from minimal_basis.data.data_reaction import CoefficientMatrix, ModifiedCoefficientMatrix

from instance_mongodb import instance_mongodb_sei

from monty.serialization import loadfn, dumpfn


# STO-3G basis-set description, read from the local input_files directory.
basis_set = loadfn(Path("input_files").joinpath("sto-3g.json"))

# In [28]:
def get_quantities(cursor):
    """Yield, for each document, the quantities needed for plotting.

    Parameters
    ----------
    cursor : iterable of dict
        Documents (e.g. a pymongo cursor). Each must provide
        ``calcs_reversed[0]["alpha_eigenvalues"]``,
        ``calcs_reversed[0]["alpha_coeff_matrix"]``,
        ``output["final_energy"]`` and ``input["initial_molecule"]``.

    Yields
    ------
    dict
        ``final_energy``: the value stored in ``output["final_energy"]``.
        ``alpha_eigenvalues``: numpy array sorted in ascending order.
        ``alpha_coeff_matrix``: the stored coefficient matrix as a numpy array.
        ``molecule_graph``: a ``MoleculeGraph`` built from the initial
        molecule using the ``OpenBabelNN`` bonding strategy.
    """
    for document in cursor:
        last_calc = document["calcs_reversed"][0]
        final_energy = document["output"]["final_energy"]

        # NOTE(review): the eigenvalues are sorted, but the coefficient matrix
        # is NOT permuted to match. Callers that use an index into the sorted
        # eigenvalues to pick an orbital from the coefficient matrix rely on
        # the stored eigenvalues already being in ascending order. Confirm
        # this, or permute the matrix along its orbital axis.
        alpha_eigenvalues = np.sort(last_calc["alpha_eigenvalues"])
        alpha_coeff_matrix = np.array(last_calc["alpha_coeff_matrix"])

        molecule = Molecule.from_dict(document["input"]["initial_molecule"])
        molecule_graph = MoleculeGraph.with_local_env_strategy(molecule, OpenBabelNN())

        yield {
            "final_energy": final_energy,
            "alpha_eigenvalues": alpha_eigenvalues,
            "alpha_coeff_matrix": alpha_coeff_matrix,
            "molecule_graph": molecule_graph,
        }

db = instance_mongodb_sei(project="mlts")

collections_data = db.minimal_basis

# Create a new collection to store the data in
collections_data_new = db.minimal_basis_interpolated_sn2
groupname = "sn2_interpolated_from_transition_states"

# Find the unique reaction labels *within this group*. An unfiltered
# ``distinct`` also returns labels belonging to other groups. The per-label
# query below (which filters on the group) would then match no documents and
# fail on ``data_to_plot[0]``.
reaction_labels = collections_data.distinct("tags.label", {"tags.group": groupname})

df = pd.DataFrame()

# Index of the atom whose minimal-basis representation is plotted.
atom_idx = 0

for reaction_label in reaction_labels:

    # All structures of this reaction in the group, ordered by ascending "tags.scaling".
    cursor = collections_data.find(
        {
            "tags.label": reaction_label,
            "tags.group": groupname,
        }
    ).sort("tags.scaling", 1)
    data_to_plot = list(get_quantities(cursor))
    if not data_to_plot:
        print(f"Skipping {reaction_label}: no documents found.")
        continue

    # Index of the highest negative alpha eigenvalue of the first structure
    # (presumably the HOMO). If several entries share that value, the first
    # one is chosen.
    reference_eigenvalues = data_to_plot[0]["alpha_eigenvalues"]
    negative_indices = np.flatnonzero(reference_eigenvalues < 0)
    if negative_indices.size == 0:
        print(f"Skipping {reaction_label}: no negative alpha eigenvalues.")
        continue
    selected_eigenval_index = negative_indices[
        np.argmax(reference_eigenvalues[negative_indices])
    ]

    # The same orbital index is applied to every structure along the path.
    # This assumes the orbital count and ordering are identical across
    # structures (TODO confirm).
    selected_coeff_matrix = []
    for data in data_to_plot:
        coeff_matrix = ModifiedCoefficientMatrix(
            molecule_graph=data["molecule_graph"],
            basis_info_raw=basis_set,
            coefficient_matrix=data["alpha_coeff_matrix"],
            store_idx_only=selected_eigenval_index,
            set_to_absolute=True,
        )
        selected_coeff_matrix.append(
            coeff_matrix.get_minimal_basis_representation_atom(atom_idx)
        )

    selected_coeff_matrix = np.array(selected_coeff_matrix)

    # One animation frame per structure along the first axis.
    fig = px.imshow(
        selected_coeff_matrix,
        color_continuous_scale="RdBu_r",
        animation_frame=0,
        range_color=[0, 0.1],
    )
    fig.show()

    # Only the first plottable reaction is shown; remove to plot them all.
    break

