In [1]:
from __future__ import annotations
from pymatgen.core import Structure

from zeopp_mace_wf import MofDiscovery
from glob import glob
import json
import subprocess
from atomate2.ase.jobs import AseRelaxMaker, StoreTrajectoryOption, AseResult
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core import Structure, Molecule
from typing import Union, List, Dict
from jobflow import run_locally, Response
from dataclasses import dataclass, field
from jobflow import Maker, job
import json

from mace_mof_opt import MACEMofStaticMaker
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

from pymatviz import ptable_heatmap, ptable_heatmap_plotly
from pymatviz.enums import Key

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from concurrent.futures import ProcessPoolExecutor, TimeoutError, as_completed
from scipy.stats import gaussian_kde
from pymatgen.core import Element


In [None]:
# Raw QMOF database, flattened with json_normalize so nested fields become dotted
# columns (e.g. 'info.mofid.topology', used further down) and indexed by QMOF ID.
# Point QMOF_JSON_PATH at the local copy of qmof.json before running.
QMOF_JSON_PATH = '/path/to/qmof_database/qmof.json'
with open(QMOF_JSON_PATH) as qmof_file:
    qmof = json.load(qmof_file)
qmof_props = pd.json_normalize(qmof).set_index('qmof_id')

In [3]:
# Outputs of the upstream MACE runs, read from the working directory:
#   dict_                 <- qmof_mace_mof_results.json
#   qmof_struct_completed <- filtered_qmof_structures_mace.json (records carrying a
#                            'qmof_id' and a pymatgen 'structure' dict)
results_path = 'qmof_mace_mof_results.json'
structures_path = 'filtered_qmof_structures_mace.json'

with open(results_path, 'r') as results_file:
    dict_ = json.load(results_file)

with open(structures_path, 'r') as structures_file:
    qmof_struct_completed = json.load(structures_file)

In [4]:
# Per-element reference energies keyed by atomic number Z.
# NOTE(review): these look like MACE isolated-atom energies (E0s) in eV — confirm which
# model/training set they were taken from. Z = 84-88 (Po..Ra) are absent from the table.
e0s =  {1: -1.11734008, 2: 0.00096759, 3: -0.29754725, 4: -0.01781697, 5: -0.26885011, 6: -1.26173507,
        7: -3.12438806, 8: -1.54838784, 9: -0.51882044, 10: -0.01241601, 11: -0.22883163, 12: -0.00951015, 13: -0.21630193, 
       14: -0.8263903, 15: -1.88816619, 16: -0.89160769, 17: -0.25828273, 18: -0.04925973, 19: -0.22697913, 20: -0.0927795, 
       21: -2.11396364, 22: -2.50054871, 23: -3.70477179, 24: -5.60261985, 25: -5.32541181, 26: -3.52004933, 27: -1.93555024, 
       28: -0.9351969, 29: -0.60025846, 30: -0.1651332, 31: -0.32990651, 32: -0.77971828, 33: -1.68367812, 34: -0.76941032, 35: -0.22213843,
       36: -0.0335879, 37: -0.1881724, 38: -0.06826294, 39: -2.17084228, 40: -2.28579303, 41: -3.13429753, 42: -4.60211419,
       43: -3.45201492, 44: -2.38073513, 45: -1.46855515, 46: -1.4773126, 47: -0.33954585, 48: -0.16843877, 49: -0.35470981,
       50: -0.83642657, 51: -1.41101987, 52: -0.65740879, 53: -0.18964571, 54: -0.00857582, 55: -0.13771876, 56: -0.03457659,
       57: -0.45580806, 58: -1.3309175, 59: -0.29671824, 60: -0.30391193, 61: -0.30898427, 62: -0.25470891, 63: -8.38001538,
       64: -10.38896525, 65: -0.3059505, 66: -0.30676216, 67: -0.30874667, 68: -0.31610927, 69: -0.25190039, 70: -0.06431414,
       71: -0.31997586, 72: -3.52770927, 73: -3.54492209, 74: -4.65658356, 75: -4.70108713, 76: -2.88257209, 77: -1.46779304,
       78: -0.50269936, 79: -0.28801193, 80: -0.12454674, 81: -0.31737194, 82: -0.77644932, 83: -1.32627283, 89: -0.26827152,
       90: -0.90817426, 91: -2.47653193, 92: -4.90438537, 93: -7.63378961, 94: -10.77237713}

# Re-key by element symbol (Z -> e.g. "H") and store every energy twice as a 2-tuple.
# NOTE(review): the (energy, energy) pair is presumably the format QMOFAnalyzer expects
# for its e0s argument — TODO confirm what the two slots mean there.
# The literal above is rebuilt on every run, so re-running this cell is safe.
e0s = {Element.from_Z(atomic_num).symbol: (energy, energy) for atomic_num, energy in e0s.items()}

In [5]:
# Energy sources compared in this notebook: three MACE-MP0b sizes plus the PBE reference.
models = [f"macemp0b_{size}" for size in ("small", "medium", "large")] + ["pbe_energy"]

# Model pairs passed to the heatmap cell as dict_diff_models, keyed by panel id.
dict_diff_models = {
    "diff_macemp0b_small_macemp0b_large": ["macemp0b_small", "macemp0b_large"],
}

# Human-readable titles for each model / difference panel.
custom_map_title_name = dict(
    macemp0b_small="MACE-MP0b small",
    macemp0b_medium="MACE-MP0b medium",
    macemp0b_large="MACE-MP0b large",
    diff_macemp0b_small_macemp0b_large="MACE-MP0b small vs large",
)

In [None]:
# Wrap the relaxed structures, the MACE results and the reference energies in the
# project's analyzer; the error/MAE definitions live in ptable_heatmap_mace.
from ptable_heatmap_mace import QMOFAnalyzer

analyzer = QMOFAnalyzer(
    qmof_struct_completed,
    dict_,
    e0s,
    models=models,
)
analyzer.calculate_errors()
analyzer.compute_mae()

Periodic table heatmap

In [None]:
# Build the periodic-table heatmap data, then plot it; dict_diff_models adds the
# model-difference panels and custom_map_title_name supplies the panel titles.
heatmap_kwargs = {"dict_diff_models": dict_diff_models}
analyzer.generate_heatmap_data(**heatmap_kwargs)
analyzer.plot_heatmap(**heatmap_kwargs, map_title_name=custom_map_title_name)

Scatter plot: model predictions vs. PBE reference

In [None]:
# Model vs. PBE scatter plot, titled via custom_map_title_name.
analyzer.scatter_plot(
    map_title_name=custom_map_title_name,
)

Energy-per-atom MAE by MOF topology

In [None]:
# Mean energy-per-atom MAE per MOF topology for one model.
# The analyzer's atomic-error table is indexed by QMOF ID so it can be joined with the
# QMOF metadata (qmof_props), which carries the MOFid topology label.
TOPOLOGY_MODEL = "macemp0b_large_atomic"
# Topology labels that mark MOFid failures rather than real nets.
# NOTE(review): "ERROR,UNKNOWN" and "MISMATCH" are the labels filtered so far; the bare
# "ERROR" / "UNKNOWN" labels are excluded too in case they occur on their own.
INVALID_TOPOLOGIES = ["ERROR,UNKNOWN", "ERROR", "UNKNOWN", "MISMATCH"]
# Plot only every TOPOLOGY_STRIDE-th topology (after sorting by MAE) to keep the
# x-axis legible; set to 1 to plot all of them. The stride is shown in the title.
TOPOLOGY_STRIDE = 5

df = analyzer.display_dict_atomic()
df2 = df.set_index('qmof_id')

merged_df = df2.join(qmof_props, on="qmof_id")
merged_df = merged_df[
    (merged_df['model'] == TOPOLOGY_MODEL)
    & (~merged_df['info.mofid.topology'].isin(INVALID_TOPOLOGIES))
]

topology_mae = (
    merged_df.groupby('info.mofid.topology')['mae_energy_per_atom']
    .mean()
    .sort_values()
)

if topology_mae.empty:
    # Series.plot(kind='bar') raises on an empty series; report instead of crashing.
    print(f"No valid topologies found for {TOPOLOGY_MODEL}; nothing to plot.")
else:
    shown_mae = topology_mae.iloc[::TOPOLOGY_STRIDE]
    fig, ax = plt.subplots(figsize=(12, 6))
    shown_mae.plot(kind='bar', color='skyblue', ax=ax)
    ax.set_ylabel('Energy MAE per Atom (eV)')
    ax.set_xlabel('MOF Topology')
    ax.set_title(
        f"{TOPOLOGY_MODEL}: {len(shown_mae)} of {len(topology_mae)} topologies "
        f"(every {TOPOLOGY_STRIDE}th by ascending MAE)"
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()

Metal oxidation states in the MOFs with the highest energy errors (outliers)

In [None]:
# Split the analyzer's per-model results into high-error outliers and the remaining
# inliers; qmof_props provides the QMOF metadata (criterion defined in QMOFAnalyzer).
(outliers, inliers) = analyzer.filter_outliers(
    qmof_props,
)

In [None]:
from pymatgen.analysis.bond_valence import BVAnalyzer
from collections import Counter


def _oxidation_label(species: str, oxidation: int) -> str:
    """Format a metal species and its oxidation state as a bar tick label.

    Any charge suffix on ``species`` (e.g. ``"Zn2+"``) is stripped so only the element
    symbol remains, and the sign is written once after the magnitude:
    ``("Zn2+", 2) -> "Zn (2+)"``, ``("X1-", -1) -> "X (1-)"``, ``("Ni0+", 0) -> "Ni (0)"``.
    """
    element = species.rstrip('0123456789+-')
    if oxidation == 0:
        return f"{element} (0)"
    sign = '+' if oxidation > 0 else '-'
    return f"{element} ({abs(oxidation)}{sign})"


def _metal_oxidation_states(structure_dict: dict, bva: BVAnalyzer) -> list[tuple[str, int]]:
    """Return ``(species_string, oxidation_state)`` for every metal site of a structure.

    Oxidation states come from the bond-valence analyzer ``bva`` and are attached to the
    sites before ``species_string`` is read, so the species strings carry their charge.
    Any exception from ``Structure.from_dict`` or ``bva.get_valences`` propagates.
    """
    struct = Structure.from_dict(structure_dict)
    oxi_list = bva.get_valences(struct)
    struct.add_oxidation_state_by_site(oxi_list)
    return [
        (site.species_string, oxi)
        for site, oxi in zip(struct, oxi_list)
        if site.specie.is_metal
    ]


# Index the structure records by QMOF ID once (keeping the first record per ID) instead
# of scanning the whole list for every outlier row.
structures_by_id = {}
for item in qmof_struct_completed:
    structures_by_id.setdefault(item['qmof_id'], item)

bva = BVAnalyzer()
# Oxidation states depend only on the structure, not on the model, so each QMOF ID is
# analysed once and reused across models (None marks a failed assignment).
oxidation_cache = {}

for model in models:
    if model == "pbe_energy":
        continue  # only the MACE models have outliers to analyse here
    model_atomic = model + "_atomic"
    outliers_model = outliers[outliers["model"] == model_atomic]

    metal_oxidations = []
    n_failed = 0
    # Assumes the outlier frame is indexed by QMOF ID (the iterrows index == row.name).
    for qmof_id_st, _ in tqdm(outliers_model.iterrows(), total=outliers_model.shape[0], desc="Processing"):
        qmof_st = structures_by_id.get(qmof_id_st)
        if qmof_st is None:
            print(f"Structure not found for {qmof_id_st}")
            continue

        if qmof_id_st not in oxidation_cache:
            try:
                oxidation_cache[qmof_id_st] = _metal_oxidation_states(qmof_st["structure"], bva)
            except Exception:
                # Best effort: bond-valence assignment fails for some structures.
                # Record the failure (reported below) instead of hiding it.
                oxidation_cache[qmof_id_st] = None

        metal_info = oxidation_cache[qmof_id_st]
        if metal_info is None:
            n_failed += 1
            continue
        metal_oxidations.append((qmof_id_st, metal_info))

    if n_failed:
        print(
            f"{model}: oxidation states could not be assigned for "
            f"{n_failed} of {outliers_model.shape[0]} outlier structures"
        )

    # Count each (species, oxidation state) pair at most once per MOF, so a bar's height
    # is the number of outlier MOFs containing that pair.
    unique_metal_oxidations = [(qmof_id, list(set(info))) for qmof_id, info in metal_oxidations]
    all_unique_metal_info = [pair for _, info in unique_metal_oxidations for pair in info]
    counter = Counter(all_unique_metal_info)

    if not counter:
        print(f"No metal oxidation states to plot for {model}")
        continue

    species_oxidation_pairs_sorted = sorted(counter)  # by species string, then oxidation
    counts_sorted = [counter[pair] for pair in species_oxidation_pairs_sorted]
    xtick_labels = [_oxidation_label(sp, ox) for sp, ox in species_oxidation_pairs_sorted]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(counts_sorted)), counts_sorted, tick_label=xtick_labels)
    ax.set_xlabel("Metal Species and Oxidation State")
    ax.set_ylabel("Count")
    ax.set_title(f"Histogram of Metal Species and Oxidation States in MOF Structures {model}")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.show()