In [None]:
import pickle
from pathlib import Path
from typing import Tuple

import numpy as np
import py3Dmol
from ase import Atoms, build
from ase.io import read
from mace.calculators import mace_off
from pymatgen.core import Molecule
from tblite.ase import TBLite

from oa_reactdiff.analyze.rmsd import pymatgen_rmsd, xyz2pmg
from oa_reactdiff.dataset import ProcessedTS1x

In [None]:
def draw_reaction(
    react_path: str,
    idx: int = 0,
    prefix: str = "gen",
    spacing: float = 10.0,
    width: int = 2024,
    height: int = 1576,
) -> py3Dmol.view:
    """Draw the {reactant, transition state, product} of a reaction side by side.

    Reads ``{react_path}/{prefix}_{idx}_{react,ts,prod}.xyz``, shifts the i-th
    structure by ``i * spacing`` along x, and renders all three as one xyz
    model in a single py3Dmol viewer.

    Args:
        react_path (str): directory containing the reaction's xyz files.
        idx (int, optional): index of the generated reaction. Defaults to 0.
        prefix (str, optional): file prefix distinguishing true samples from
            generated structures. Defaults to "gen".
        spacing (float, optional): x offset between consecutive structures,
            in the xyz coordinate units. Defaults to 10.0.
        width (int, optional): viewer width in pixels. Defaults to 2024.
        height (int, optional): viewer height in pixels. Defaults to 1576.

    Returns:
        py3Dmol.view: viewer showing reactant, TS and product from left to right.
    """
    molecules = [
        xyz2pmg(f"{react_path}/{prefix}_{idx}_{tag}.xyz")
        for tag in ("react", "ts", "prod")
    ]
    # The xyz header must count every atom of the combined model; summing the
    # per-structure counts stays correct even if the files disagree in size.
    natoms = sum(m.num_sites for m in molecules)
    mol = f"{natoms}\n\n"
    for ii, pmatg_mol in enumerate(molecules):
        coords = np.array(pmatg_mol.cart_coords)
        coords[:, 0] += ii * spacing
        shifted = Molecule(
            species=pmatg_mol.atomic_numbers,
            coords=coords,
        )
        # Drop the 2-line xyz header (atom count + comment); keep atom lines only.
        mol += "\n".join(shifted.to(fmt="xyz").split("\n")[2:]) + "\n"
    # py3Dmol.view's first positional parameter is `query`, so the viewer size
    # must be passed by keyword.
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol, "xyz")
    viewer.setStyle({'stick': {'radius': 0.20}, "sphere": {"radius": 0.35}})
    viewer.zoomTo()
    return viewer

def draw_true_reaction(
    molecules: Tuple,
    spacing: float = 10.0,
    width: int = 2024,
    height: int = 1576,
) -> py3Dmol.view:
    """Draw a sequence of pymatgen molecules side by side in one viewer.

    Used below for the ground-truth (reactant, TS, product) tuple built by
    ``data2pmg``. Any number of molecules works: the i-th one is shifted by
    ``i * spacing`` along x.

    Args:
        molecules (Tuple): pymatgen ``Molecule`` objects to render, left to right.
        spacing (float, optional): x offset between consecutive molecules, in
            the molecules' coordinate units. Defaults to 10.0.
        width (int, optional): viewer width in pixels. Defaults to 2024.
        height (int, optional): viewer height in pixels. Defaults to 1576.

    Returns:
        py3Dmol.view: viewer containing all molecules as a single xyz model.
    """
    # Count the atoms of every molecule rather than assuming three copies of
    # the first one, so the xyz header is valid for any input.
    natoms = sum(m.num_sites for m in molecules)
    mol = f"{natoms}\n\n"
    for ii, pmatg_mol in enumerate(molecules):
        coords = np.array(pmatg_mol.cart_coords)
        coords[:, 0] += ii * spacing
        shifted = Molecule(
            species=pmatg_mol.atomic_numbers,
            coords=coords,
        )
        # Drop the 2-line xyz header (atom count + comment); keep atom lines only.
        mol += "\n".join(shifted.to(fmt="xyz").split("\n")[2:]) + "\n"
    # py3Dmol.view's first positional parameter is `query`, so the viewer size
    # must be passed by keyword.
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol, "xyz")
    viewer.setStyle({'stick': {'radius': 0.20}, "sphere": {"radius": 0.35}})
    viewer.zoomTo()
    return viewer

def data2pmg(data):
    """Convert one processed dataset sample into three pymatgen molecules.

    Builds a ``Molecule`` from each ``charge_{i}`` / ``pos_{i}`` pair of
    ``data`` for i = 0, 1, 2.

    Args:
        data: mapping with ``charge_{i}`` (atomic numbers) and ``pos_{i}``
            (Cartesian coordinates) entries for i in {0, 1, 2}. Values must
            support ``.cpu().numpy()`` (presumably torch tensors — TODO confirm
            against ProcessedTS1x).

    Returns:
        tuple: ``(mol0, mol1, mol2)`` pymatgen ``Molecule`` objects.
    """

    def _to_molecule(i):
        # Builds the Molecule for structure `i` of the sample.
        # reshape(-1) rather than squeeze(): squeeze would turn a single-atom
        # charge tensor into a 0-d array that Molecule cannot iterate over.
        species = data[f"charge_{i}"].cpu().numpy().reshape(-1)
        coords = data[f"pos_{i}"].cpu().numpy()
        return Molecule(species=species, coords=coords)

    return tuple(_to_molecule(i) for i in range(3))

In [None]:
# Configuration for the validation split; the whole dict is forwarded to
# ProcessedTS1x as keyword arguments.
val_config = dict(
    datadir="oa_reactdiff/data/transition1x/",
    remove_h=False,
    bz=1,
    num_workers=0,
    clip_grad=True,
    gradient_clip_val=None,
    ema=False,
    ema_decay=0.999,
    swapping_react_prod=False,
    append_frag=False,
    use_by_ind=True,
    reflection=False,
    single_frag_only=False,
    only_ts=False,
    lr_schedule_type=None,
    lr_schedule_config=dict(
        gamma=0.8,
        step_size=100,
    ),  # step
)

# dataset
VALID_PKL = Path("oa_reactdiff/data/transition1x/valid_addprop.pkl")
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally
# produced files. The `with` block closes the handle after loading.
with open(VALID_PKL, "rb") as fh:
    dataraw = pickle.load(fh)
dataset = ProcessedTS1x(VALID_PKL, **val_config)

In [None]:
# Select one validation reaction: `idx` indexes the processed dataset, and
# `ori_idx` is the matching row of the raw pickle, looked up via `use_ind`.
idx = 0
ori_idx = dataraw["use_ind"][idx]
smi = dataraw["product"]["smi"][ori_idx][0]  # first product SMILES entry
print(smi)
data = dataset[idx]

In [None]:
# Compare the reactant - TS energy difference stored in the dataset with the
# same difference computed by TBLite for one raw sample.
RAW_SAMPLE_IDX = 5  # index into the raw pickle arrays (not the dataset `idx`)
print(
    dataraw['reactant']['wB97x_6-31G(d).energy'][RAW_SAMPLE_IDX]
    - dataraw['transition_state']['wB97x_6-31G(d).energy'][RAW_SAMPLE_IDX]
)
mol1 = Atoms(
    numbers=dataraw['reactant']['charges'][RAW_SAMPLE_IDX],
    positions=dataraw['reactant']['positions'][RAW_SAMPLE_IDX],
)
mol2 = Atoms(
    numbers=dataraw['transition_state']['charges'][RAW_SAMPLE_IDX],
    positions=dataraw['transition_state']['positions'][RAW_SAMPLE_IDX],
)
# `Atoms.set_calculator` is deprecated in ASE; assign `.calc` directly.
mol1.calc = TBLite()
mol2.calc = TBLite()
# NOTE(review): assumes the stored energies and TBLite results share units — confirm.
print(mol1.get_potential_energy() - mol2.get_potential_energy())


In [None]:
# Energy difference between two generated structures, evaluated with MACE-OFF.
atoms_r = read('./fix/iteration_2/gen_2_react.xyz')
# NOTE(review): the name `atoms_p` suggests a product, but this reads the
# *reactant* of generated sample 12 — confirm whether e.g. `gen_2_prod.xyz`
# was intended.
atoms_p = read('./fix/iteration_2/gen_12_react.xyz')
calc = mace_off(model="medium", device='cpu')
# One calculator instance is attached to both structures; `.calc` replaces
# the deprecated `set_calculator`.
atoms_r.calc = calc
atoms_p.calc = calc
print(atoms_r.get_potential_energy() - atoms_p.get_potential_energy())

In [None]:
# Ground truth: the dataset's three structures for `data`, rendered left to right.
mol = data2pmg(data)
true_viewer = draw_true_reaction(mol)
true_viewer

In [None]:
# Generated reaction with the same index, read from the sampler output directory.
gen_viewer = draw_reaction(react_path="./fix/iteration_2", idx=idx, prefix="gen")
gen_viewer