In [None]:
# Number of rollout directories (rollout_0 ... rollout_{num_samples - 1}) to evaluate.
num_samples: int = 13

In [None]:
from pymatgen.core import Molecule
from pymatgen.analysis.molecule_matcher import BruteForceOrderMatcher, GeneticOrderMatcher, HungarianOrderMatcher, KabschMatcher
import numpy as np
from pymatgen.io.xyz import XYZ

def xh2pmg(species, xh):
    """Build a pymatgen ``Molecule`` from species labels and a 2-D array.

    Only the first three columns of ``xh`` are used as Cartesian coordinates;
    any further columns are ignored.
    """
    return Molecule(species=species, coords=xh[:, :3])


def xyz2pmg(xyzfile):
    """Read an .xyz file with pymatgen's ``XYZ`` reader and return its Molecule."""
    return XYZ.from_file(xyzfile).molecule


def rmsd_core(mol1, mol2, threshold=0.5, same_order=False):
    """Return the aligned RMSD between two pymatgen Molecules.

    Args:
        mol1: Reference molecule (the matcher target).
        mol2: Molecule aligned onto ``mol1``.
        threshold: RMSD threshold passed to ``GeneticOrderMatcher``. It is also
            the upper bound of the value returned when that matcher yields
            candidate matches.
        same_order: If True, atoms are assumed to correspond one-to-one already
            and only a rigid Kabsch alignment is done. Otherwise the atom
            ordering is optimised as well: brute force when there are fewer than
            1e4 element-preserving permutations, else a genetic search with a
            Hungarian fallback.

    Returns:
        The RMSD reported by the selected pymatgen matcher.

    Raises:
        RuntimeError: (``same_order`` only) if the Kabsch RMSD exceeds the
            translation-only RMSD by more than round-off, which would indicate
            a broken alignment.
    """
    import math  # np.math was removed in NumPy 2.0

    if same_order:
        _, rmsd = KabschMatcher(mol1).fit(mol2)

        # Sanity check: an optimal rotation can only lower the RMSD relative to
        # removing translation alone. The tolerance avoids spurious failures
        # from floating-point round-off (e.g. for identical structures).
        # NOTE(review): confirm pymatgen normalises its RMSD like this per-atom
        # formula; if it averages over all 3N components the check is looser.
        A = np.asarray(mol1.cart_coords, dtype=np.float64)
        B = np.asarray(mol2.cart_coords, dtype=np.float64)
        A0 = A - A.mean(0, keepdims=True)
        B0 = B - B.mean(0, keepdims=True)
        rmsd_raw_centered = float(np.sqrt(((A0 - B0) ** 2).sum(axis=1).mean()))
        if rmsd_raw_centered < rmsd - 1e-6:
            raise RuntimeError(
                f"Kabsch RMSD ({rmsd}) exceeds translation-only RMSD "
                f"({rmsd_raw_centered}) for species {mol1.species} vs {mol2.species}"
            )
        return rmsd

    # Number of atom orderings a brute-force search would try: the product of
    # n_i! over each element's atom count n_i. Stop once the limit is passed.
    _, counts = np.unique(mol1.atomic_numbers, return_counts=True)
    total_permutations = 1
    for c in counts:
        total_permutations *= math.factorial(int(c))
        if total_permutations >= 1e4:
            break

    if total_permutations < 1e4:
        _, rmsd = BruteForceOrderMatcher(mol1).fit(mol2)
        return rmsd

    pairs = GeneticOrderMatcher(mol1, threshold=threshold).fit(mol2)
    if len(pairs) == 0:
        # No candidate found by the genetic search: fall back to Hungarian assignment.
        _, rmsd = HungarianOrderMatcher(mol1).fit(mol2)
        return rmsd
    # Each pair ends with its RMSD; the result never exceeds `threshold`.
    return min([threshold] + [pair[-1] for pair in pairs])


def pymatgen_rmsd(
    species,
    mol1,
    mol2,
    ignore_chirality: bool = True,
    threshold: float = 0.5,
    same_order: bool = False,
):
    """Return the RMSD between two molecules, optionally ignoring chirality.

    Args:
        species: Not used for the comparison (an .xyz file carries its own
            species). It is kept so existing callers keep working.
        mol1, mol2: pymatgen Molecules, or paths to .xyz files.
        ignore_chirality: If True, also compare against ``mol2`` mirrored
            through the xy-plane (z -> -z) and return the smaller RMSD.
        threshold: Forwarded to ``rmsd_core``.
        same_order: Forwarded to ``rmsd_core``.

    Returns:
        The (possibly mirror-minimised) RMSD.
    """
    # xyz2pmg takes only the file path; passing `species` as well raised TypeError.
    if isinstance(mol1, str):
        mol1 = xyz2pmg(mol1)
    if isinstance(mol2, str):
        mol2 = xyz2pmg(mol2)
    rmsd = rmsd_core(mol1, mol2, threshold, same_order=same_order)
    if ignore_chirality:
        # Explicit copy so mirroring can never mutate mol2's own coordinates.
        coords = np.array(mol2.cart_coords, dtype=np.float64)
        coords[:, -1] *= -1
        mol2_reflect = Molecule(
            species=mol2.species,
            coords=coords,
        )
        rmsd_reflect = rmsd_core(
            mol1, mol2_reflect, threshold, same_order=same_order)
        rmsd = min(rmsd, rmsd_reflect)
    return rmsd


def batch_rmsd_sb(
    species,
    fragments_node,
    pred_xh,
    target_xh,
    threshold: float = 0.5,
    same_order: bool = False,
):
    """Compute one RMSD per fragment between predicted and target coordinates.

    Atoms are assumed to be stored contiguously per fragment, and
    ``fragments_node[k]`` is the atom count of fragment k. Each fragment is
    compared with chirality ignored. Each value is capped at 1.0.

    Returns:
        A list with one (capped) RMSD per fragment.
    """
    offsets = np.cumsum(fragments_node)
    starts = np.concatenate([[0], offsets[:-1]])

    per_fragment = []
    for lo, hi in zip(starts, offsets):
        frag_species = species[lo:hi]
        pred_mol = xh2pmg(frag_species, pred_xh[lo:hi])
        target_mol = xh2pmg(frag_species, target_xh[lo:hi])
        value = pymatgen_rmsd(
            frag_species,
            pred_mol,
            target_mol,
            ignore_chirality=True,
            threshold=threshold,
            same_order=same_order,
        )
        # Cap at 1.0 (presumably to bound failed alignments) before averaging downstream.
        per_fragment.append(min(value, 1.0))
    return per_fragment

In [None]:
import ase, ase.io


In [None]:
# Quick check on one sample: frame 1 of trial 1 in rollout_0 vs. the reference.
gen_path = 'rollout_0/gentraj_1.xyz'
ref_path = 'rollout_0/reftraj_1.xyz'
atoms = ase.io.read(gen_path, format='xyz', index=":")[1]
atoms_ref = ase.io.read(ref_path, format='xyz', index=":")[1]

# Atoms per fragment, from the fragment label of each atom (row 1 of the file).
fragments_idx = np.loadtxt("rollout_0/Fragment_idx.dat")[1]
_, fragments_node = np.unique(fragments_idx, return_counts=True)

per_fragment_rmsd = batch_rmsd_sb(
    atoms.get_chemical_symbols(),
    fragments_node,
    atoms.get_positions(),
    atoms_ref.get_positions(),
)
rmsd = per_fragment_rmsd[0]  # only the first fragment is reported here

print(f"RMSD: {rmsd} Angstrom")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Flattened Cartesian components where the generated and reference structures differ.
coord_diff = atoms_ref.get_positions().ravel() - atoms.get_positions().ravel()
mask = np.where(coord_diff != 0)
print(mask)

In [None]:
# Parity plot of every Cartesian component; the dashed line is y = x.
ref_flat = np.array(atoms_ref.get_positions()).ravel()
gen_flat = np.array(atoms.get_positions()).ravel()
mask = np.where(ref_flat - gen_flat != 0)

plt.figure(figsize=(3.5, 3.5))
plt.scatter(ref_flat, gen_flat)
plt.plot(plt.xlim(), plt.xlim(), 'k--')
plt.xlabel('Reference positions')
plt.ylabel('Generated positions')


In [None]:
# Evaluate every rollout. Among N_TRIALS generated samples, keep the one with
# the lowest precomputed energy, then compare it with the reference structure
# (mean per-fragment RMSD) and with the reference energy.
N_TRIALS = 30  # gentraj_<i>.xyz files per rollout directory
BOX_LENGTH = 25  # edge of the cubic cell assigned to every structure

# Rollout indices to skip (file name suggests references that are not true TSs).
# NOTE(review): absolute, machine-specific path.
wrong_ref = np.loadtxt("/home/tuoping/odefed_mdgen/TS-GEN-transition_states/edge_e3/experiments/test-psi4/log.nonts")

all_positions = []
all_positions_ref = []
all_energy_atoms = []
all_energy_atoms_ref = []
all_atomic_numbers = []
mean_distances = []
mean_distances_read = []  # never filled; kept so the name stays defined
# Rollout number of each entry in the all_* / mean_distances lists. The lists
# are indexed by position, which differs from the rollout number once any
# rollout is skipped.
kept_dirs = []

for i_dir in range(num_samples):
    if i_dir in wrong_ref:
        continue
    dirname = f'rollout_{i_dir}'

    # Frame 1 of each trajectory is the structure being evaluated.
    atoms_ref = ase.io.read(f'{dirname}/reftraj_1.xyz', format='xyz', index=":")[1]
    atoms_ref.set_cell(np.eye(3, 3) * BOX_LENGTH)
    atomic_numbers = atoms_ref.get_atomic_numbers()
    n_atoms = len(atomic_numbers)
    # The .dat files appear to hold per-atom energies; scale back to totals.
    energy_atoms_ref = np.loadtxt(f"{dirname}/all_energy_atoms_ref.dat")[1] * n_atoms

    # The fragment assignment depends only on the rollout, so load it once.
    fragments_idx = np.loadtxt(f"{dirname}/Fragment_idx.dat")[1]
    fragments_node = np.unique(fragments_idx, return_counts=True)[1]

    sample_atoms = []
    sample_energy_atoms = []
    for i_trial in range(N_TRIALS):
        atoms = ase.io.read(f'{dirname}/gentraj_{i_trial}.xyz', format='xyz', index=":")[1]
        atoms.set_cell(np.eye(3, 3) * BOX_LENGTH)
        assert np.all(atomic_numbers == atoms.get_atomic_numbers())
        sample_atoms.append(atoms)
        sample_energy_atoms.append(
            np.loadtxt(f"{dirname}/all_energy_atoms_{i_trial}.dat")[1] * n_atoms
        )
    del atoms

    idx_min = np.argmin(sample_energy_atoms)
    print(dirname, idx_min)
    best_atoms = sample_atoms[idx_min]

    all_energy_atoms.append(sample_energy_atoms[idx_min])
    all_energy_atoms_ref.append(energy_atoms_ref)
    rmsd = np.mean(batch_rmsd_sb(
        atoms_ref.get_chemical_symbols(),
        fragments_node,
        best_atoms.get_positions(),
        atoms_ref.get_positions(),
    ))
    mean_distances.append(rmsd)

    # Take positions straight from the selected sample instead of slicing a
    # concatenated list, which breaks if any trial were ever skipped.
    all_positions.extend(best_atoms.get_positions())
    all_atomic_numbers.extend(atomic_numbers)
    all_positions_ref.extend(atoms_ref.get_positions())
    kept_dirs.append(i_dir)
    del atoms_ref

all_positions = np.array(all_positions)
all_positions_ref = np.array(all_positions_ref)



In [None]:
# Persist the per-rollout mean RMSDs; the energy arrays can be saved the same way:
# np.save("all_energy_atoms.npy", np.array(all_energy_atoms))
# np.save("all_energy_atoms_ref.npy", np.array(all_energy_atoms_ref))
mean_distances_to_save = np.asarray(mean_distances)
np.save("mean_distances_median.npy", mean_distances_to_save)

In [None]:
# Disabled comparison plot, kept as a bare string literal so it never runs.
# NOTE(review): `mean_distances_read` is never appended to by the evaluation
# loop, so this would fail if re-enabled as-is. The axis labels also say
# "positions" although RMSD values are plotted.
'''
mean_distances = np.array(mean_distances)
mean_distances_read = np.array(mean_distances_read)
plt.figure(figsize=(3.5, 3.5))
plt.scatter(mean_distances.ravel(), mean_distances_read.ravel())
plt.plot(plt.xlim(), plt.xlim(), 'k--')
plt.xlabel('Reference positions')
plt.ylabel('Generated positions')
'''

In [None]:
# Parity plot of generated vs. reference energies; the dashed line is y = x.
plt.figure(figsize=(3.5, 3.5))
plt.scatter(x=all_energy_atoms_ref, y=all_energy_atoms)
diag = plt.xlim()
plt.plot(diag, diag, 'k--')
plt.xlabel('Reference energies')
plt.ylabel('Generated energies')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Optional styling
# sns.set(style="whitegrid")

plt.figure(figsize=(6, 4))
# batch_rmsd_sb already caps each fragment RMSD at 1.0; the clip just pins the
# histogram range to [0, 1]. density=True so the y-axis matches its label.
plt.hist(x=np.clip(np.asarray(mean_distances), 0, 1), bins=50, density=True)
plt.xlabel('RMSD')
plt.ylabel('Density')
plt.title('RMSD Distribution')
plt.tight_layout()
plt.show()

In [None]:
# Summary statistics of the per-rollout mean RMSD.
for label, stat in (
    ("MEDIAN ERROR: ", np.median),
    ("MEAN ERROR: ", np.mean),
    ("MAX ERROR: ", np.max),
):
    print(label, stat(mean_distances))

In [None]:
eV_2_kcalmol = 23.0605  # 1 eV expressed in kcal/mol
OUTLIER_CUTOFF_KCALMOL = 50  # |dE| at or above this is excluded as an outlier

abs_err_kcalmol = np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref)) * eV_2_kcalmol
# Despite the name, these are the indices that are KEPT (|dE| below the cutoff).
idx_remove_outliers = np.where(abs_err_kcalmol < OUTLIER_CUTOFF_KCALMOL)
kept_err = abs_err_kcalmol[idx_remove_outliers]

# Report the fraction actually removed, which is what the label says.
frac_removed = 1 - kept_err.size / abs_err_kcalmol.size if abs_err_kcalmol.size else 0.0
print(frac_removed, " outliers removed")
if kept_err.size:
    print("Median Energy Difference: ", np.median(kept_err))
    print("Mean Energy Difference: ", kept_err.mean())
    print("Max Energy Difference: ", kept_err.max())
else:
    print("No samples left after outlier removal.")

# print((np.array(all_energy_atoms) - np.array(all_energy_atoms_ref)))

In [None]:
# How many samples have |dE| <= 1.58 kcal/mol (threshold converted to eV).
abs_err_ev = np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref))
n_within = np.where(abs_err_ev <= 1.58/eV_2_kcalmol)[0].shape[0]
n_samples = np.array(all_energy_atoms).shape[0]
print(n_within)
print(n_samples)
print(n_within/n_samples)

In [None]:
eV_2_kcalmol = 23.0605  # 1 eV expressed in kcal/mol
abs_err_kcalmol = np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref)) * eV_2_kcalmol

plt.figure(figsize=(3.5, 3.5))
plt.scatter(np.array(mean_distances), abs_err_kcalmol)
# No y = 0 reference line: zero cannot be shown on a log axis.
plt.xlim(0.01, 1)
plt.axhline(1.58, ls="--", c="grey")  # 1.58 kcal/mol threshold used throughout
plt.xscale("log")
plt.yscale('log')
# Raw strings: "\A", "\D" and "\ " are invalid escapes (SyntaxWarning on Python 3.12+).
plt.xlabel(r'r.m.s.d $[\AA]$')
plt.ylabel(r'$\Delta E\ [kcal/mol]$')

In [None]:


abs_err_ev = np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref))

plt.figure(figsize=(6, 4))
# Errors are clipped at 1 eV (~23 kcal/mol) before conversion, so all larger
# errors pile up in the last bin. density=True so the y-axis matches its label.
plt.hist(x=abs_err_ev.clip(0, 1) * eV_2_kcalmol, bins=100, density=True)
plt.xlabel(r"$\Delta E_{TS}\ (kcal/mol)$")
plt.ylabel('Density')
plt.axvline(1.58, ls='--', c='k')
plt.tight_layout()
plt.show()

In [None]:
# Samples whose |dE| exceeds 1.58 kcal/mol: count, positions, and the errors themselves.
abs_err_kcalmol = np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref)) * eV_2_kcalmol
outlier_idx = np.where(abs_err_kcalmol > 1.58)[0]
print(len(outlier_idx))
print(outlier_idx)
print(abs_err_kcalmol[outlier_idx])

In [None]:
# Outlier indices recorded from earlier runs (first list: xstd-l1).
# NOTE(review): `a` is overwritten just below and `b` is not used in this notebook.
a = [  1,   3,   4,   5,   8,   9,  12,  13,  24,  25,  27,  28,  39,  40,  43,  51,  62,  70,
  71,  73,  83,  88,  89,  93,  98,  99, 101, 102, 104, 105, 107, 114, 119, 120, 121, 123,
 125, 128, 130, 134, 137, 138, 142, 146, 153, 155, 156, 161, 164, 166, 167, 172, 175, 179,
 186, 189, 190, 193, 196, 197, 199, 201, 202, 204, 208, 210, 211, 215, 218, 221, 223,]


b = [  1,   3,   4,   7,  12,  13,  25,  27,  28,  31,  40,  43,  45,  50,  51,  58,  59,  60,
  62,  67,  70,  71,  75,  80,  83,  88,  89,  94, 101, 104, 105, 107, 114, 119, 122, 123,
 127, 128, 130, 133, 134, 135, 136, 137, 138, 142, 143, 145, 151, 155, 156, 157, 161, 162,
 163, 164, 166, 171, 172, 174, 175, 179, 180, 182, 186, 190, 191, 198, 199, 201, 202, 204,
 205, 208, 214, 215, 218, 219, 221, 223,]

# Positions (into all_energy_atoms) of predictions with |dE| > 1.58 kcal/mol.
a = np.where((np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref)))*eV_2_kcalmol > 1.58)[0]

wrong_ref = np.loadtxt("/home/tuoping/odefed_mdgen/TS-GEN-transition_states/edge_e3/experiments/test-psi4/log.nonts")
print(wrong_ref)

# wrong_ref holds rollout numbers, but `a` holds list positions. The evaluation
# loop skips rollouts listed in wrong_ref, so rebuild the position -> rollout
# mapping with the same rule before comparing.
kept_dirs = [i for i in range(num_samples) if i not in wrong_ref]
assert len(kept_dirs) == len(all_energy_atoms), "position -> rollout mapping out of sync with the evaluation loop"

n_total = 0
n_wrongpred = 0
n_wrongref = 0
for x in a:
    n_total += 1
    if kept_dirs[x] not in wrong_ref:
        n_wrongpred += 1
    else:
        n_wrongref += 1

if n_total == 0:
    print("No predictions above the 1.58 kcal/mol threshold.")
else:
    print(n_wrongref/n_total, n_wrongref)
    print('Ratio of wrong predictions: ', n_wrongpred/n_total, n_wrongpred)

In [None]:
# Empirical cumulative distribution of |dE| in kcal/mol.
plt.figure(figsize=(6, 4))
err_e_zRP_sorted = np.sort(np.abs(np.array(all_energy_atoms) - np.array(all_energy_atoms_ref))*eV_2_kcalmol)
err_e_zR_cP_cum = np.arange(1, len(err_e_zRP_sorted) + 1) / len(err_e_zRP_sorted)
plt.plot(err_e_zRP_sorted, err_e_zR_cP_cum)
# Raw string: "\D" and "\ " are invalid escapes (SyntaxWarning on Python 3.12+).
plt.xlabel(r"$\Delta E_{TS}\ (kcal/mol)$")
plt.ylabel('Cumulative probability')
plt.xscale("log")
plt.ylim(0, 1)
plt.xlim(0.1, 100)
plt.axvline(1.58, ls='--', c='k')
plt.tight_layout()
plt.show()

In [None]:
# Intentional stop so "Run All" halts here. The cells below relax structures
# with an ASE `calculator`, which is not defined in any cell of this notebook.
raise RuntimeError("Intentional stop: define `calculator` and run the relaxation cells manually.")

In [None]:
# Positions (into mean_distances) of samples whose RMSD exceeds 0.05.
mask_ts = np.where(np.array(mean_distances) > 0.05)[0]

import os
from ase.optimize import BFGS

# mean_distances / all_energy_atoms_ref are indexed by position in the
# evaluation loop, which skips rollouts listed in wrong_ref. Map positions
# back to rollout numbers with the same rule.
kept_dirs = [i for i in range(num_samples) if i not in wrong_ref]

if os.path.exists("relax.extxyz"): os.remove("relax.extxyz")
for idx in mask_ts:
    i_dir = kept_dirs[idx]
    dirname = f'rollout_{i_dir}'
    traj_ref = ase.io.read(f'{dirname}/reftraj_1.xyz', format='xyz', index=":")
    for i_atoms in range(3):
        atoms_ref = traj_ref[i_atoms]  # same object as traj_ref[i_atoms]
        atoms_ref.set_cell(np.eye(3, 3) * 25)
        # Move the geometric centre to the middle of the 25 Å box.
        center = np.mean(atoms_ref.positions, axis=0)
        atoms_ref.positions += 12.5 - center
        if i_atoms in [0, 2]:
            # Relax the two endpoint frames. Frame 1 keeps the reference TS energy.
            # NOTE(review): `calculator` must be defined before running this cell.
            atoms_ref.calc = calculator
            opt = BFGS(atoms_ref, logfile='mace_ase.log')
            opt.run(fmax=0.01)
            energy_atoms = atoms_ref.get_potential_energy()
            atoms_ref.info['energy'] = energy_atoms
            atoms_ref.calc = None
            atoms_ref.set_array('energy', np.full(len(atoms_ref), energy_atoms))
    # NOTE(review): the TS energy comes from all_energy_atoms_ref.dat, while the
    # endpoint energies come from `calculator`. Confirm both use the same reference.
    ts_energy = all_energy_atoms_ref[idx]
    traj_ref[1].info['energy'] = ts_energy
    traj_ref[1].set_array('energy', np.full(len(traj_ref[1]), ts_energy))
    print("Barrier = ", traj_ref[1].info['energy']-traj_ref[0].info['energy'])
    ase.io.write("relax.extxyz", traj_ref, append=True, format="extxyz")

In [None]:
from ase import Atoms
from ase.io import read
# Angle (degrees) at atom 3 between atoms 1-3-4, in frame 0 of rollout_1's reference trajectory.
dirname = "rollout_1"
traj_ref = read(f"{dirname}/reftraj_1.xyz", format="xyz", index=":")
mol = traj_ref[0]
theta = mol.get_angle(1, 3, 4, mic=False)
print(f"Angle = {theta:.2f}°")

In [None]:
from ase import Atoms
from ase.io import read
# Angle (degrees) at atom 3 between atoms 1-3-4, in frame 0 of rollout_2's reference trajectory.
dirname = "rollout_2"
traj_ref = read(f"{dirname}/reftraj_1.xyz", format="xyz", index=":")
mol = traj_ref[0]
theta = mol.get_angle(1, 3, 4, mic=False)
print(f"Angle = {theta:.2f}°")

In [None]:
from ase import Atoms
from ase.io import read
# Angle (degrees) at atom 3 between atoms 1-3-4, in frame 0 of generated trial 1 of rollout_1.
dirname = "rollout_1"
traj = read(f"{dirname}/gentraj_1.xyz", format="xyz", index=":")
mol = traj[0]
theta = mol.get_angle(1, 3, 4, mic=False)
print(f"Angle = {theta:.2f}°")