In [None]:

import numpy as np
from transition1x import Dataloader
from sklearn.preprocessing import OneHotEncoder
from ase.data import chemical_symbols

from pymatgen.core import Molecule
from pymatgen.analysis.molecule_matcher import KabschMatcher

import torch
from torch_geometric.data import Data

# -------------------------
# Small helpers
# -------------------------
def _rmsd(A, B):
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    return float(np.sqrt(((A - B) ** 2).sum(axis=1).mean()))

def pad5(Zoh: np.ndarray) -> np.ndarray:
    """Right-pad a 2-D (one-hot) matrix with zero columns up to width 5.

    The existing columns are copied into the leftmost positions. The result is
    always float32 with shape (Zoh.shape[0], 5).
    """
    n_rows = Zoh.shape[0]
    n_cols = Zoh.shape[1]
    padded = np.zeros((n_rows, 5), dtype=np.float32)
    padded[:, :n_cols] = Zoh.astype(np.float32)
    return padded

def zs_to_symbols(zs: np.ndarray):
    """Translate atomic numbers into element symbols using ASE's table.

    ``chemical_symbols`` is indexed directly by Z (entry 0 is the dummy 'X').
    Each value is cast to ``int`` first so numpy integer scalars work too.
    Returns a list of strings in the input order.
    """
    symbols = []
    for atomic_number in zs:
        symbols.append(chemical_symbols[int(atomic_number)])
    return symbols

# -------------------------
# Kabsch via pymatgen (with optional reflection)
# -------------------------
def kabsch_align_pmg(
    A_xyz: np.ndarray,
    B_xyz: np.ndarray,
    symbols: list[str],
    allow_reflection: bool = True,
):
    """
    Align B -> A using pymatgen's KabschMatcher (proper rotation only),
    optionally also try a mirrored B and take the better RMSD.

    Both inputs are first centered on their own centroids. The returned
    coordinates therefore live in A's *centered* frame (centroid at the
    origin); add ``A.mean(0)`` back if absolute placement is needed.

    Every RMSD in ``info`` uses the same per-atom definition as ``_rmsd``
    (sqrt of the mean over atoms of the squared distance). The RMSD returned
    by ``KabschMatcher.fit`` is deliberately not used: pymatgen normalises it
    differently (mean over all 3N Cartesian components), so subtracting it
    from ``rmsd_raw`` would mix two definitions.

    Args:
      A_xyz: reference coordinates, shape [N, 3].
      B_xyz: coordinates to align, shape [N, 3], same atom order as A.
      symbols: element symbols of the N atoms (shared by A and B).
      allow_reflection: if True, also fit B mirrored through the XY plane
        and keep whichever fit has the lower RMSD.

    Returns:
      aligned_B  (np.ndarray, shape [N,3])  -- coords of B aligned to A
      info       (dict)                      -- diagnostics:
        "choice"    "rot" or "refl", the fit that was kept
        "rmsd_raw"  RMSD of A vs B before any centering/rotation
        "rmsd_rot"  RMSD after proper-rotation alignment
        "rmsd_refl" RMSD after mirror + rotation (inf when not tried)
        "rmsd_best" min(rmsd_rot, rmsd_refl)

    NOTE(review): callers align TS and P independently, so one may come back
    reflected and the other not (the mirror image of a chiral structure).
    Confirm this is acceptable for downstream path models, or pass
    allow_reflection=False.
    """
    A = np.asarray(A_xyz, dtype=np.float64)
    B = np.asarray(B_xyz, dtype=np.float64)

    raw_rmsd = _rmsd(A, B)

    # Center both at their centroids (callers may center A already; this is local).
    A0 = A - A.mean(0, keepdims=True)
    B0 = B - B.mean(0, keepdims=True)

    # The target never changes, so a single matcher serves every candidate.
    matcher = KabschMatcher(Molecule(symbols, A0))

    # Proper rotation (det=+1). fit() returns (aligned_mol, rmsd); the pymatgen
    # rmsd is discarded and recomputed with _rmsd for consistency with raw_rmsd.
    aligned_mol_rot, _ = matcher.fit(Molecule(symbols, B0))
    B_rot = np.array(aligned_mol_rot.cart_coords, dtype=np.float64)
    rmsd_rot = _rmsd(A0, B_rot)

    rmsd_refl = np.inf
    B_best = B_rot
    choice = "rot"
    if allow_reflection:
        # Mirroring one axis (z -> -z) flips handedness; the proper rotation
        # found afterwards then spans every improper orthogonal transform.
        B_ref = B0.copy()
        B_ref[:, 2] *= -1.0
        aligned_mol_ref, _ = matcher.fit(Molecule(symbols, B_ref))
        B_ref_aligned = np.array(aligned_mol_ref.cart_coords, dtype=np.float64)
        rmsd_refl = _rmsd(A0, B_ref_aligned)
        if rmsd_refl < rmsd_rot:
            B_best = B_ref_aligned
            choice = "refl"

    info = {
        "choice": choice,
        "rmsd_raw": raw_rmsd,
        "rmsd_rot": float(rmsd_rot),
        "rmsd_refl": float(rmsd_refl),
        "rmsd_best": float(min(rmsd_rot, rmsd_refl)),
    }
    return B_best.astype(np.float64), info

stage = "test"
# -------------------------
# Data loading
# -------------------------
dataloader = Dataloader("data/transition1x.h5", datasplit=stage, only_final=True)

# Be robust to unseen atomic numbers: with handle_unknown="ignore", an element
# outside the fitted set encodes as an all-zero row instead of raising.
atom_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
atom_encoder.fit(np.array([1, 6, 7, 8]).reshape(-1, 1))  # extend if needed

ALLOW_REFLECTION = True  # toggle here if you want strictly proper rotations

idx = 0
dataset = []
# `with` guarantees the log is flushed and closed even if one of the sanity
# checks below raises part-way through the split.
with open(f"mis-align-{stage}.dat", "w") as f_misalign:
    f_misalign.write("# rxn_id    d_rmsd_RT    d_rmsd_RP    dcentroid_RT    dcentroid_RP    Displacement_R\n")

    for molecule in dataloader:
        # --- atomic numbers & positions ---
        Z_r = np.asarray(molecule["reactant"]["atomic_numbers"], dtype=np.int32)
        Z_t = np.asarray(molecule["transition_state"]["atomic_numbers"], dtype=np.int32)
        Z_p = np.asarray(molecule["product"]["atomic_numbers"], dtype=np.int32)

        pos_r_raw = np.asarray(molecule["reactant"]["positions"], dtype=np.float64)
        pos_t_raw = np.asarray(molecule["transition_state"]["positions"], dtype=np.float64)
        pos_p_raw = np.asarray(molecule["product"]["positions"], dtype=np.float64)

        # sanity checks: Kabsch alignment needs a 1:1 atom correspondence
        if not (len(Z_r) == len(Z_t) == len(Z_p)):
            raise ValueError("R/TS/P must have the same atom count")

        if not (np.array_equal(Z_r, Z_t) and np.array_equal(Z_r, Z_p)):
            raise ValueError(
                "Atom ordering differs between R/TS/P. "
                "Run a permutation matcher (e.g., BruteForce/Hungarian/Genetic) first."
            )

        # center the reactant at origin (your convention); verify before use
        pos_r_aligned = pos_r_raw - pos_r_raw.mean(0, keepdims=True)
        assert np.all(np.abs(pos_r_aligned.mean(0)) < 1e-6), "Reactant not centered as expected"

        symbols = zs_to_symbols(Z_r)

        # --- align TS and P to Reactant using KabschMatcher (with optional reflection) ---
        pos_ts_aligned, info_ts = kabsch_align_pmg(pos_r_aligned, pos_t_raw, symbols, allow_reflection=ALLOW_REFLECTION)
        pos_p_aligned, info_p = kabsch_align_pmg(pos_r_aligned, pos_p_raw, symbols, allow_reflection=ALLOW_REFLECTION)

        # --- features (one-hot, padded to 5 like your original code) ---
        # Z_r == Z_t == Z_p was verified above, so one encoding serves all three.
        padded_z = pad5(atom_encoder.transform(Z_r.reshape(-1, 1)))

        # --- torch Data object ---
        data = Data(
            rxn=molecule['rxn'],

            E_transition_state=torch.tensor(
                molecule["transition_state"]["wB97x_6-31G(d).atomization_energy"], dtype=torch.float32
            ),
            E_reactant=torch.tensor(
                molecule["reactant"]["wB97x_6-31G(d).atomization_energy"], dtype=torch.float32
            ),
            E_product=torch.tensor(
                molecule["product"]["wB97x_6-31G(d).atomization_energy"], dtype=torch.float32
            ),

            pos_transition_state=torch.tensor(pos_ts_aligned, dtype=torch.float32),
            formula_transition_state=molecule["transition_state"]["formula"],
            z_transition_state=torch.tensor(padded_z, dtype=torch.float32),

            pos_reactant=torch.tensor(pos_r_aligned, dtype=torch.float32),
            formula_reactant=molecule["reactant"]["formula"],
            z_reactant=torch.tensor(padded_z, dtype=torch.float32),

            pos_product=torch.tensor(pos_p_aligned, dtype=torch.float32),
            formula_product=molecule["product"]["formula"],
            z_product=torch.tensor(padded_z, dtype=torch.float32),
        )

        dataset.append(data)

        # --- optional: write multi-frame XYZ (R, TS_aligned, P_aligned)
        # from ase import Atoms, io
        # io.write(f"{molecule['rxn']}-{idx}.xyz", Atoms(symbols=symbols, positions=pos_r_aligned), format="xyz")
        # io.write(f"{molecule['rxn']}-{idx}.xyz", Atoms(symbols=symbols, positions=pos_ts_aligned), format="xyz", append=True)
        # io.write(f"{molecule['rxn']}-{idx}.xyz", Atoms(symbols=symbols, positions=pos_p_aligned),  format="xyz", append=True)

        # misalignment log (ΔRMSD = raw - best; positive means alignment helped)
        r_centroid = pos_r_aligned.mean(0)
        dcent_rt = float(np.linalg.norm(r_centroid - pos_ts_aligned.mean(0)))
        dcent_rp = float(np.linalg.norm(r_centroid - pos_p_aligned.mean(0)))
        disp_r = float(np.linalg.norm(r_centroid - pos_r_raw.mean(0)))
        # The first three characters of the rxn id are dropped (presumably a
        # "rxn" prefix) so that np.loadtxt can parse column 0 as a number.
        f_misalign.write(
            f"{molecule['rxn'][3:]}    "
            f"{info_ts['rmsd_raw'] - info_ts['rmsd_best']}    "
            f"{info_p['rmsd_raw'] - info_p['rmsd_best']}    "
            f"{dcent_rt}    {dcent_rp}    {disp_r}\n"
        )
        idx += 1


In [None]:
# Name the output after the split actually processed above; a hard-coded
# "test" would silently overwrite data/test.pt when `stage` is changed.
torch.save(dataset, f"data/{stage}.pt")

In [None]:
import matplotlib.pyplot as plt

# Columns of mis-align-{stage}.dat (see its header line):
#   0 rxn id, 1 ΔRMSD R-TS, 2 ΔRMSD R-P, 3 centroid dist R-TS,
#   4 centroid dist R-P, 5 reactant centroid displacement.
# skiprows=1 drops the "#" header line.
drmsd = np.loadtxt(f"mis-align-{stage}.dat", skiprows=1)

fig, ax = plt.subplots()
ax.hist(drmsd[:, 1], bins=50, alpha=0.5, label="R-TS misalign")
ax.hist(drmsd[:, 2], bins=50, alpha=0.5, label="R-P misalign")
# The logged value is a difference (raw - aligned), not an RMSD itself.
ax.set_xlabel("ΔRMSD = RMSD_raw − RMSD_aligned (Angstrom)")
ax.set_ylabel("Count")
ax.legend()
ax.set_title("Distribution of change of the RMSD\nbetween Reactant and TS/Product\ndue to Kabsch alignment")
fig.savefig("misalign_histogram.png")

In [None]:

# Columns 3/4: distance between the reactant centroid and the aligned
# TS/product centroid.
fig, ax = plt.subplots()
ax.hist(drmsd[:, 3], bins=50, alpha=0.5, label="centroid distance of aligned R-TS")
ax.hist(drmsd[:, 4], bins=50, alpha=0.5, label="centroid distance of aligned R-P")
ax.set_xlabel("Centroid distance (Angstrom)")
ax.set_ylabel("Count")
ax.legend()
ax.set_title("Distribution of centroid distances\nbetween Reactant and aligned TS/Product")
# Own file name so the ΔRMSD histogram (misalign_histogram.png) is not overwritten.
fig.savefig("centroid_distance_histogram.png")

In [None]:
# Column 5: norm of the raw reactant centroid, i.e. the shift removed by
# centering the reactant at the origin.
fig, ax = plt.subplots()
ax.hist(drmsd[:, 5], bins=50, alpha=0.5, label="Displacement_R")
ax.set_xlabel("Reactant centroid displacement (Angstrom)")
ax.set_ylabel("Count")
ax.legend()
ax.set_title("Distribution of the reactant centroid shift\nremoved by centering")
# Own file name so earlier histograms are not overwritten.
fig.savefig("reactant_displacement_histogram.png")

In [4]:
import torch
from mdgen.dataset import EquivariantTransformerDataset_Transition1x
data_dir = "data/Transition1x"
stage = "val"
# NOTE(review): this rebinds both `stage` and `dataset`. The energy cells
# below iterate this mdgen dataset, not the list built from the HDF5 loader.
dataset = EquivariantTransformerDataset_Transition1x(
    data_dirname=data_dir,
    sim_condition=False,
    tps_condition=True,
    num_species=5,
    stage=stage,
)

# Index every sample once and persist the materialized list, so the split can
# later be reloaded with torch.load instead of being rebuilt.
tps_masked_dataset = [dataset[i] for i in range(len(dataset))]

torch.save(tps_masked_dataset, f"{data_dir}/tps_masked_{stage}.pt")

In [None]:
# One list per state. Each list is built by its own pass over `dataset`, in the
# order reactant, product, transition state.
E_reactant, E_product, E_transition_state = (
    [getattr(sample, field) for sample in dataset]
    for field in ("E_reactant", "E_product", "E_transition_state")
)

In [None]:
import matplotlib.pyplot as plt

# Overlay the three states' energy distributions. alpha + legend keep the
# later histograms from hiding the earlier ones.
fig, ax = plt.subplots()
for energies, state_label in (
    (E_reactant, "reactant"),
    (E_product, "product"),
    (E_transition_state, "transition state"),
):
    ax.hist(energies, alpha=0.5, label=state_label)
ax.set_xlabel("Atomization energy (wB97x/6-31G(d), dataset units)")
ax.set_ylabel("Count")
ax.set_title("Atomization energy by state")
ax.legend()
plt.show()

In [None]:

# Per-reaction E_TS - E_R. Both are atomization energies of the same atoms, so
# this presumably equals the forward barrier; confirm against the data source.
dE_ts_minus_r = np.array(E_transition_state) - np.array(E_reactant)
fig, ax = plt.subplots()
ax.hist(dE_ts_minus_r)
ax.set_xlabel("E_TS − E_R (dataset energy units)")
ax.set_ylabel("Count")
ax.set_title("Distribution of TS − reactant energy")
plt.show()