In [None]:
from torch_geometric.data import Data
import torch

In [None]:
import numpy as np
import torch
from transition1x import Dataloader
from sklearn.preprocessing import OneHotEncoder
from ase import Atoms
import ase.io
from ase.data import chemical_symbols

# -------------------------
# Kabsch with auto-reflection
# -------------------------
def _rmsd(A, B):
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    return np.sqrt(((A - B) ** 2).sum(axis=1).mean())

def kabsch_align_auto_reflect(A, B):
    """
    Align B -> A using the best among:
      - proper rotation (det=+1),
      - reflective alignment (det=-1),
      - raw (no transform) as a safety fallback.
    Returns: (aligned_B, info_dict)
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)

    raw_rmsd = _rmsd(A, B)

    Ac = A - A.mean(0, keepdims=True)
    Bc = B - B.mean(0, keepdims=True)

    H = Bc.T @ Ac
    U, S, Vt = np.linalg.svd(H)

    # proper rotation (enforce det=+1)
    R_no = Vt.T @ U.T
    if np.linalg.det(R_no) < 0:
        Vt_fix = Vt.copy()
        Vt_fix[-1, :] *= -1
        R_no = Vt_fix.T @ U.T
    B_no = (Bc @ R_no) + A.mean(0, keepdims=True)
    rmsd_no = _rmsd(A, B_no)

    # reflective alignment (allow det=-1)
    R_rf = Vt.T @ U.T  # do NOT fix det
    B_rf = (Bc @ R_rf) + A.mean(0, keepdims=True)
    rmsd_rf = _rmsd(A, B_rf)

    # choose best
    candidates = [
        (rmsd_no, "proper",   B_no, R_no, float(np.linalg.det(R_no))),
        (rmsd_rf, "reflect",  B_rf, R_rf, float(np.linalg.det(R_rf))),
        (raw_rmsd, "raw",     B,    np.eye(3), 1.0),
    ]
    best = min(candidates, key=lambda x: x[0])

    info = {
        "choice": best[1],
        "rmsd_best": best[0],
        "rmsd_no_reflect": rmsd_no,
        "rmsd_reflect": rmsd_rf,
        "rmsd_raw": raw_rmsd,
        "det_R_chosen": best[4],
    }
    return best[2].astype(np.float64), info

# -------------------------
# Utilities
# -------------------------
def pad5(Zoh: np.ndarray) -> np.ndarray:
    """Right-pad a 2-D feature matrix with zero columns to a width of 5.

    The input's columns are copied (cast to float32) into the leading columns
    of a freshly allocated zero matrix with the same number of rows.

    Returns a float32 array of shape (Zoh.shape[0], 5).
    """
    n_rows, n_cols = Zoh.shape[0], Zoh.shape[1]
    padded = np.zeros((n_rows, 5), dtype=np.float32)
    padded[:, :n_cols] = Zoh.astype(np.float32)
    return padded

def zs_to_symbols(zs: np.ndarray):
    """Translate a sequence of atomic numbers into a list of element symbols.

    Each entry is cast to int and used as an index into ase's
    chemical_symbols table (chemical_symbols[Z] gives the symbol; index 0 is
    the 'X' placeholder).
    """
    symbols = []
    for atomic_number in zs:
        symbols.append(chemical_symbols[int(atomic_number)])
    return symbols

# -------------------------
# Data loading
# -------------------------
# Reader for the "test" split of the Transition1x HDF5 file; iterated once below.
dataloader = Dataloader("data/transition1x.h5", datasplit="test", only_final=True)

# Be robust to unseen atomic numbers: with handle_unknown="ignore" an element
# outside the fitted set is encoded as an all-zero row instead of raising.
atom_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
atom_encoder.fit(np.array([1, 6, 7, 8]).reshape(-1, 1))  # extend if your set is larger

idx = 0  # running reaction counter (only used by the optional XYZ dump below)
dataset = []

# The context manager guarantees the log is flushed and closed even when one
# of the sanity checks inside the loop raises.
with open("mis-align-test.dat", "w") as f_misalign:
    f_misalign.write("# rxn_id    d_rmsd_RT    d_rmsd_RP    dcentroid_RT    dcentroid_RP\n")
    for molecule in dataloader:
        # --- atomic numbers & positions ---
        Z_r = np.asarray(molecule["reactant"]["atomic_numbers"], dtype=np.int32)
        Z_t = np.asarray(molecule["transition_state"]["atomic_numbers"], dtype=np.int32)
        Z_p = np.asarray(molecule["product"]["atomic_numbers"], dtype=np.int32)

        # Only the reactant is centred here; TS and product are re-centred
        # onto it by the alignment (unless the "raw" fallback is selected).
        pos_r_raw = np.asarray(molecule["reactant"]["positions"], dtype=np.float64)
        pos_r_raw = pos_r_raw - pos_r_raw.mean(0)
        pos_t_raw = np.asarray(molecule["transition_state"]["positions"], dtype=np.float64)
        pos_p_raw = np.asarray(molecule["product"]["positions"], dtype=np.float64)

        # sanity checks
        if not (len(Z_r) == len(Z_t) == len(Z_p)):
            raise ValueError("R/TS/P must have the same atom count")

        if not (np.array_equal(Z_r, Z_t) and np.array_equal(Z_r, Z_p)):
            raise ValueError("Atom ordering differs between R/TS/P. Permutation matching required before alignment.")

        # --- align TS and P to Reactant (auto-reflection when helpful) ---
        pos_ts_aligned, info_ts = kabsch_align_auto_reflect(pos_r_raw, pos_t_raw)
        pos_p_aligned,  info_p  = kabsch_align_auto_reflect(pos_r_raw, pos_p_raw)
        pos_r_aligned = pos_r_raw.copy()

        # The centroid must be ~0 in magnitude on every axis; without abs() a
        # large negative component would slip through the check.
        assert np.all(np.abs(pos_r_aligned.mean(0)) < 1e-6), "Reactant position not centered"

        # Optional: inspect decisions
        # print("TS align:", info_ts, "  P align:", info_p)

        # --- features (one-hot, padded to 5 like your original code) ---
        z_r = atom_encoder.transform(Z_r.reshape(-1, 1))
        z_t = atom_encoder.transform(Z_t.reshape(-1, 1))
        z_p = atom_encoder.transform(Z_p.reshape(-1, 1))
        padded_z_r = pad5(z_r)
        padded_z_t = pad5(z_t)
        padded_z_p = pad5(z_p)

        # --- torch Data object ---
        data = Data(
            rxn=molecule['rxn'],

            E_transition_state=torch.tensor(
                molecule["transition_state"]["wB97x_6-31G(d).atomization_energy"],
                dtype=torch.float32
            ),
            E_reactant=torch.tensor(
                molecule["reactant"]["wB97x_6-31G(d).atomization_energy"],
                dtype=torch.float32
            ),
            E_product=torch.tensor(
                molecule["product"]["wB97x_6-31G(d).atomization_energy"],
                dtype=torch.float32
            ),

            pos_transition_state=torch.tensor(pos_ts_aligned, dtype=torch.float32),
            formula_transition_state=molecule["transition_state"]["formula"],
            z_transition_state=torch.tensor(padded_z_t, dtype=torch.float32),

            pos_reactant=torch.tensor(pos_r_aligned, dtype=torch.float32),
            formula_reactant=molecule["reactant"]["formula"],
            z_reactant=torch.tensor(padded_z_r, dtype=torch.float32),

            pos_product=torch.tensor(pos_p_aligned, dtype=torch.float32),
            formula_product=molecule["product"]["formula"],
            z_product=torch.tensor(padded_z_p, dtype=torch.float32),
        )

        dataset.append(data)

        # --- optional debug dump: multi-frame XYZ (R, TS_aligned, P_aligned) ---
        # Use per-atom symbols to match length (safer than a formula string)
        symbols = zs_to_symbols(Z_r)
        # fname = f"{molecule['rxn']}-{idx}.xyz"
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_r_aligned), format="xyz")
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_ts_aligned), format="xyz", append=True)
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_p_aligned),  format="xyz", append=True)

        # fname = f"{molecule['rxn']}-{idx}-raw.xyz"
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_r_raw), format="xyz")
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_t_raw), format="xyz", append=True)
        # ase.io.write(fname, Atoms(symbols=symbols, positions=pos_p_raw),  format="xyz", append=True)

        # --- misalignment log: one row per reaction ---
        # d_rmsd_*    : raw RMSD minus best RMSD, i.e. how much the chosen
        #               alignment improved on the untransformed geometry.
        # dcentroid_* : distance between the reactant centroid and the
        #               centroid of the aligned TS / product.
        d_rmsd_ts = -info_ts['rmsd_best'] + info_ts['rmsd_raw']
        d_rmsd_p = -info_p['rmsd_best'] + info_p['rmsd_raw']
        dcentroid_ts = np.linalg.norm(pos_r_aligned.mean(0) - pos_ts_aligned.mean(0))
        dcentroid_p = np.linalg.norm(pos_r_aligned.mean(0) - pos_p_aligned.mean(0))
        # molecule['rxn'][3:] drops the first three characters of the reaction id.
        f_misalign.write(
            f"{molecule['rxn'][3:]}    {d_rmsd_ts}    {d_rmsd_p}    {dcentroid_ts}    {dcentroid_p}\n"
        )
        idx += 1

In [None]:
# Serialise the list currently bound to `dataset` to disk in one file.
torch.save(obj=dataset, f="data/test.pt")

In [None]:

from mdgen.dataset import EquivariantTransformerDataset_Transition1x
data_dir = "data"
# NOTE: this rebinds `dataset`; from here on the name refers to the mdgen
# dataset object rather than to the list built by the preprocessing cell.
dataset = EquivariantTransformerDataset_Transition1x(
    data_dirname=data_dir,
    sim_condition=False,
    tps_condition=True,
    num_species=5,
    stage="test",
)

# Materialise every item exactly as returned by dataset[i], in index order.
tps_masked_dataset = [dataset[sample_idx] for sample_idx in range(len(dataset))]

torch.save(tps_masked_dataset, f"{data_dir}/tps_masked_test.pt")

In [None]:
# Gather the three per-sample energy attributes into plain Python lists.
# NOTE(review): `dataset` is whatever the name is bound to when this cell
# runs; after the preceding cell that is the mdgen dataset object, so confirm
# its items expose E_reactant / E_product / E_transition_state.
E_reactant = [
    sample.E_reactant
    for sample in dataset
]
E_product = [
    sample.E_product
    for sample in dataset
]
E_transition_state = [
    sample.E_transition_state
    for sample in dataset
]

In [None]:
import matplotlib.pyplot as plt
# Overlay the three energy distributions on one axes. Each series keeps its
# own default binning (one hist call per series, as before); the partial
# transparency and the legend keep all three visible and identifiable.
fig, ax = plt.subplots()
for energies, series_name in (
    (E_reactant, "E_reactant"),
    (E_product, "E_product"),
    (E_transition_state, "E_transition_state"),
):
    ax.hist(energies, alpha=0.5, label=series_name)
ax.set_xlabel("energy")
ax.set_ylabel("count")
ax.set_title("Energy distributions")
ax.legend();

In [None]:

# Histogram of the element-wise difference E_transition_state - E_reactant.
energy_gap = np.array(E_transition_state) - np.array(E_reactant)
_ = plt.hist(energy_gap)