In [2]:
import os
import numpy as np
from e3nn import o3
from minimal_basis.dataset.dataset_hamiltonian import HamiltonianDataset
from minimal_basis.model.model_hamiltonian import (
    EquivariantConv,
    SimpleHamiltonianModel,
)
import warnings
import plotly.express as px
import torch
from torch_geometric.loader import DataLoader

from utils import (
    get_test_data_path,
    get_validation_data_path,
    get_train_data_path,
    read_inputs_yaml,
)
from ase import units as ase_units

# Suppress FutureWarning and UserWarning for the rest of the session.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_ignored_category)

import wandb
# Open a Weights & Biases run; the model artifact is fetched through this
# handle further down in the notebook.
wandb_run = wandb.init(project="hamiltonian", entity="sudarshanvj")

[34m[1mwandb[0m: Currently logged in as: [33msudarshanvj[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Read the run configuration from input_files/hamiltonian_model.yaml.
inputs = read_inputs_yaml(os.path.join("input_files", "hamiltonian_model.yaml"))

# The "debug_*" entries of the configuration name the JSON files to load.
train_json_filename = inputs["debug_train_json"]
validate_json_filename = inputs["debug_validate_json"]

# Both datasets are built with the same basis-set file; only the root
# directory and the JSON file differ between them.
shared_dataset_kwargs = dict(basis_file=inputs["basis_file"])

train_dataset = HamiltonianDataset(
    root=get_train_data_path(),
    filename=train_json_filename,
    **shared_dataset_kwargs,
)
validation_dataset = HamiltonianDataset(
    root=get_validation_data_path(),
    filename=validate_json_filename,
    **shared_dataset_kwargs,
)

INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information from ./input_files/6-31G_star.json
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with data.
INFO:minimal_basis.dataset.dataset_hamiltonian:Successfully loaded json file with basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information.
INFO:minimal_basis.dataset.dataset_hamiltonian:Parsing basis information from ./input_files/6-31G_star.json


In [4]:
# Same loader settings for both splits: batches of one sample, no shuffling.
loader_options = dict(batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, **loader_options)
validation_loader = DataLoader(validation_dataset, **loader_options)

In [9]:
# Download version v5 of the trained model artifact through the active W&B run.
artifact = wandb_run.use_artifact('sudarshanvj/hamiltonian/hamiltonian_model:v5', type='model')
artifact_dir = artifact.download()

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load artifacts from a trusted project. Newer PyTorch releases may
# also need an explicit `weights_only=False` to load a whole pickled module --
# confirm against the installed version.
model_checkpoint_path = os.path.join(artifact_dir, "hamiltonian_model.pt")
model = torch.load(model_checkpoint_path, map_location=torch.device('cpu') )

# Switch to inference mode; as the last expression of the cell this also
# displays the module structure.
model.eval()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


SimpleHamiltonianModel(
  (conv): EquivariantConv(
    (tp): FullyConnectedTensorProduct(14x0e x 14x0e -> 20x0e | 3920 paths | 3920 weights)
    (fc): FullyConnectedNet[10, 64, 3920]
  )
  (conv_global): EquivariantConv(
    (tp): FullyConnectedTensorProduct(14x0e x 14x0e -> 20x0e | 3920 paths | 3920 weights)
    (fc): FullyConnectedNet[10, 64, 3920]
  )
)

In [11]:
# Compare interpolated and reference transition-state (TS) geometries for every
# sample in the training loader, write one 3D scatter plot per sample, and
# collect reference/predicted model outputs for the parity plot.
all_mae_norms = []
output_comparison = []

# write_html does not create missing parent directories, so make sure the
# output folder exists before the first figure is written.
os.makedirs("plots/hamiltonian_model", exist_ok=True)

for idx, data in enumerate(train_loader):

    interpolated_ts_coords = data.pos_interpolated_TS.detach().numpy()
    # NOTE(review): pos_real_TS is indexed with [0], presumably because the
    # loader wraps it in a per-batch list (batch_size=1) -- confirm.
    real_ts_coords = np.array(data.pos_real_TS[0])

    difference_ts_coords = interpolated_ts_coords - real_ts_coords

    # Per-row Euclidean norm of the coordinate difference, then its mean.
    # This is a mean displacement rather than a component-wise mean absolute
    # error; the "MAE" label is kept in names and titles for continuity.
    norm_difference_ts_coords = np.linalg.norm(difference_ts_coords, axis=1)
    mae = np.mean(norm_difference_ts_coords)
    all_mae_norms.append(mae)

    # Plot the real and interpolated TS structures on the same axes;
    # colour value 0 marks the real structure, 1 the interpolated one.
    fig = px.scatter_3d(
        x=np.concatenate((real_ts_coords[:, 0], interpolated_ts_coords[:, 0])),
        y=np.concatenate((real_ts_coords[:, 1], interpolated_ts_coords[:, 1])),
        z=np.concatenate((real_ts_coords[:, 2], interpolated_ts_coords[:, 2])),
        color=np.concatenate((np.zeros(len(real_ts_coords)), np.ones(len(interpolated_ts_coords)))),
    )
    # Set the title of the plot as the mean displacement
    fig.update_layout(title=f"MAE: {mae:.3f} Å")
    fig.write_html(f"plots/hamiltonian_model/interpolated_ts_{idx}_mae_{mae:.3f}.html")

    real_y = data.y
    # Inference only: skip building the autograd graph.
    # NOTE(review): the recorded run of this cell stopped with
    # "AssertionError: Incorrect last dimension for x", presumably raised
    # inside model(data) because the feature width of `data` does not match
    # what the loaded model expects -- confirm that the dataset and the model
    # artifact version belong together.
    with torch.no_grad():
        predicted_y = model(data)
    output_comparison.append([real_y.detach().numpy(), predicted_y.detach().numpy()])

# Plot a histogram of the per-sample mean displacement
fig = px.histogram(x=all_mae_norms, nbins=20)
fig.update_layout(title="Histogram of MAE")
fig.show()

# Reverse the axes and drop singleton dimensions so that index 0 holds the
# reference values and index 1 the predictions.
output_comparison = np.array(output_comparison).T
output_comparison = np.squeeze(output_comparison)
# Multiply the energies with the Hartree to eV conversion factor
output_comparison[0] *= ase_units.Hartree
output_comparison[1] *= ase_units.Hartree


AssertionError: Incorrect last dimension for x

In [None]:
# Parity plot: reference energies on the x axis, predictions on the y axis.
real_energies = output_comparison[0].flatten()
predicted_energies = output_comparison[1].flatten()
fig = px.scatter(x=real_energies, y=predicted_energies)

# The x = y reference line runs across the combined range of both series.
parity_low = min(output_comparison[0].min(), output_comparison[1].min())
parity_high = max(output_comparison[0].max(), output_comparison[1].max())
fig.add_shape(
    type="line",
    x0=parity_low,
    y0=parity_low,
    x1=parity_high,
    y1=parity_high,
    line=dict(color="Red", width=4, dash="dashdot"),
)

# Make the aspect ratio of the plot equal and label both axes.
fig.update_xaxes(scaleanchor="y", scaleratio=1, title_text="Real energies (eV)")
fig.update_yaxes(scaleanchor="x", scaleratio=1, title_text="Predicted energies (eV)")
fig.update_layout(title="Parity plot of real and predicted energies")

fig.show()