# Apply this code to sample transition states (TS) from the model.

In [1]:
# --- Importing and defining some functions ----
import os
import torch
import py3Dmol
import numpy as np

from typing import Optional
from torch import tensor
from e3nn import o3
from torch_scatter import scatter_mean

from oa_reactdiff.model import LEFTNet

# Run everything in double precision, for more accurate numerical testing.
default_float = torch.float64
torch.set_default_dtype(default_float)


def remove_mean_batch(
    x: tensor,
    indices: Optional[tensor] = None
) -> tensor:
    """Subtract the per-batch mean from the rows of ``x``.

    Args:
        x (tensor): input tensor; means are taken over dim 0 (one row per entry).
        indices (Optional[tensor], optional): batch id of each row of ``x``. When None,
            all rows are treated as one batch. Defaults to None.

    Returns:
        tensor: ``x`` with each batch's mean subtracted, so every batch has zero mean.
    """
    # Identity check: with a tensor argument, `==` would dispatch to Tensor.__eq__.
    if indices is None:
        return x - x.mean(dim=0)
    # One mean row per batch id; index it back so it lines up with the rows of x.
    batch_means = scatter_mean(x, indices, dim=0)
    return x - batch_means[indices]


def draw_in_3dmol(mol: str, fmt: str = "xyz") -> py3Dmol.view:
    """Render one molecule in an interactive py3Dmol viewer.

    Args:
        mol (str): molecule contents as text (e.g. an XYZ block).
        fmt (str, optional): py3Dmol format name of ``mol``. Defaults to "xyz".

    Returns:
        py3Dmol.view: a 1024x576 viewer in stick + sphere style, zoomed to fit.
    """
    ball_and_stick = {"stick": {}, "sphere": {"radius": 0.36}}
    view = py3Dmol.view(1024, 576)
    view.addModel(mol, fmt)
    view.setStyle(ball_and_stick)
    view.zoomTo()
    return view


def assemble_xyz(z: list, pos: tensor) -> str:
    """Assemble element labels and coordinates into an XYZ-format string.

    Args:
        z (list): chemical element label of each atom.
        pos (tensor): coordinates, one row per atom, in the same order as ``z``.

    Returns:
        str: XYZ text: the atom count, an empty comment line, then one tab-separated
        ``<element> <coords...>`` line per atom, each terminated by a newline.
    """
    # detach().cpu() so this also works for CUDA tensors and tensors that require
    # grad; a bare .numpy() raises for both.
    coords = pos.detach().cpu().numpy()
    lines = [f"{len(z)}", ""]
    lines.extend(
        f"{_z}\t" + "\t".join(str(c) for c in row) for _z, row in zip(z, coords)
    )
    return "\n".join(lines) + "\n"

### Building a LEFTNet model

A simple test is performed to verify SE(3) symmetry. The model here is for testing, so we only need to build a very small model.

Note: [LEFTNet](https://arxiv.org/abs/2304.04757) is a new SOTA-level SE(3) graph neural network. Although we use LEFTNet here, the properties it exhibits are model-independent (other SE(3) models, such as [EGNN](https://arxiv.org/pdf/2102.09844.pdf), will give the same results)

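TL: Is EGNN SE$(3)$-equivariant? EGNN is E$(n)$-equivariant (rotations, translations *and* reflections). Since SE$(3) \subset$ E$(3)$, it is also SE$(3)$-equivariant. However, because it is reflection-equivariant, it cannot distinguish mirror images (chirality), which a model that is only SE$(3)$-equivariant may.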

In [2]:
# Small model: it is only used for testing, so its size does not matter.
num_layers = 2
hidden_channels = 8
in_hidden_channels = 4
num_radial = 4

leftnet_kwargs = dict(
    num_layers=num_layers,
    hidden_channels=hidden_channels,
    in_hidden_channels=in_hidden_channels,
    num_radial=num_radial,
    object_aware=False,
)
model = LEFTNet(**leftnet_kwargs)

# Trainable parameter count (shown as the cell output).
trainable = [p for p in model.parameters() if p.requires_grad]
sum(p.numel() for p in trainable)



7882



### Create an "Object-Aware" LEFTNet

In [3]:
# --- Importing necessary function ---
from torch.utils.data import DataLoader

from oa_reactdiff.trainer.pl_trainer import DDPMModule


from oa_reactdiff.dataset import ProcessedTS1x, ProcessedSCAN
from oa_reactdiff.diffusion._schedule import DiffSchedule, PredefinedNoiseSchedule

from oa_reactdiff.diffusion._normalizer import FEATURE_MAPPING
from oa_reactdiff.analyze.rmsd import batch_rmsd

from oa_reactdiff.utils.sampling_tools import assemble_sample_inputs, write_tmp_xyz

In [4]:
# Relative paths in the following cells (checkpoint, data) resolve against this directory.
# os.getcwd() is portable, unlike the IPython-only `!pwd` shell escape.
print(os.getcwd())

/misc/home/guest50/OAReactDiff



### Load the pre-trained model and redefine the noise schedule.

In [5]:
# Re-imported so this cell also runs on its own (e.g. when cells are executed out of
# order); if the import cell above already ran, this just rebinds the cached module name.
from oa_reactdiff.trainer.pl_trainer import DDPMModule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # TL

tspath = os.path.abspath(os.path.join(os.getcwd(), "oa_reactdiff", "trainer"))
print(tspath)

# Checkpoint to load (relative to the working directory). Alternatives used before:
#   "./pretrained-ts1x-diff.ckpt"  (original)
#   f"{tspath}/checkpoint/OAReactDiff/leftnet-0-f1ff7dc18fa3/ddpm-epoch=1999-val-totloss=509.31.ckpt"  (our recapitulation)
#   f"{tspath}/checkpoint/OAReactDiff/leftnet-0-84724de349da/ddpm-epoch=1964-val-totloss=510.42.ckpt"
#   f"{tspath}/checkpoint/OAReactDiff/leftnet-0-84724de349da/ddpm-epoch=1999-val-totloss=523.77.ckpt"
CKPT_PATH = "./trained_models/pretrained-ts1x-diff-zenodo.ckpt"  # zenodo pretrained checkpoint

ddpm_trainer = DDPMModule.load_from_checkpoint(
    checkpoint_path=CKPT_PATH,
    map_location=device,
)
ddpm_trainer = ddpm_trainer.to(device)

cuda
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer




In [6]:
# Override the noise schedule that came with the checkpoint with one defined here.
noise_schedule: str = "polynomial_2"
timesteps: int = 150
precision: float = 1e-5

gamma_module = PredefinedNoiseSchedule(
    noise_schedule=noise_schedule, timesteps=timesteps, precision=precision
)
schedule = DiffSchedule(gamma_module=gamma_module, norm_values=ddpm_trainer.ddpm.norm_values)

# Install it on the model and keep T equal to the schedule's number of timesteps.
ddpm_trainer.ddpm.schedule = schedule
ddpm_trainer.ddpm.T = timesteps
# Move again so the newly attached schedule also lands on `device`
# (assumes it holds tensors/buffers - TODO confirm).
ddpm_trainer = ddpm_trainer.to(device)


### Prepare the dataset and data loader, and select reactions involving multiple molecules

In [7]:
import pickle

# Despite its name this is a pickle file; the name is kept because the dataset cell
# passes it as `npz_path`.
npz_path = "./oa_reactdiff/data/transition1x/train.pkl"

# The context manager closes the file handle after loading.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(npz_path, "rb") as pkl_file:
    train_pkl = pickle.load(pkl_file)

In [8]:
# First ten entries of `use_ind`, which maps dataset positions to raw pickle positions.
train_pkl["use_ind"][0:10]

[0, 1, 2, 3, 4, 5, 6, 8, 9, 10]

In [9]:
# Reaction id at raw pickle position 177. This is not necessarily dataset item 177,
# which maps to raw position train_pkl["use_ind"][177].
reactant_info = train_pkl["reactant"]
reactant_info["rxn"][177]

'rxn0306'

In [10]:
# Reactant atom count at raw pickle position 177.
reactant_info = train_pkl["reactant"]
reactant_info["num_atoms"][177]

10

In [11]:
# Reactant atom count of the reaction that *dataset* item 177 corresponds to. Dataset
# positions map to raw pickle positions through `use_ind`, exactly as in the sampling loop.
train_pkl["reactant"]["num_atoms"][train_pkl["use_ind"][177]]

10

In [11]:
dataset_kwargs = dict(
    npz_path=npz_path,
    center=True,
    pad_fragments=0,
    device=device,
    zero_charge=False,
    remove_h=False,
    single_frag_only=False,
    swapping_react_prod=False,
    use_by_ind=True,
)
dataset = ProcessedTS1x(**dataset_kwargs)

# batch_size=1 with shuffle=False: the k-th batch from the loader is dataset item k,
# which the sampling loop relies on when matching batches to `random_indices`.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=dataset.collate_fn,
)
itl = iter(loader)
# NOTE(review): `idx` is not read by any visible cell; it looks like a leftover.
idx = -1

len(dataset)

9000

In [12]:
num_indices_to_select = 20
SEED = 747

# The legacy global-state API (np.random.seed + np.random.choice) is used on purpose so
# the recorded selection is reproduced; a Generator from np.random.default_rng draws a
# different sequence. Note that default_rng(...) returns a new Generator and does NOT
# seed the global state, so calling it without keeping the result has no effect.
np.random.seed(SEED)
random_indices = sorted(
    np.random.choice(len(dataset), size=num_indices_to_select, replace=False)
)
random_indices

[177,
 501,
 625,
 902,
 2616,
 3130,
 3267,
 3294,
 4171,
 4258,
 4392,
 4424,
 4983,
 5334,
 6474,
 6941,
 7484,
 8388,
 8703,
 8894]

In [13]:
# Translate the sampled dataset positions into raw pickle positions.
use_ind = train_pkl["use_ind"]
train_indices = [use_ind[k] for k in random_indices]
train_indices

[202,
 560,
 700,
 1009,
 2924,
 3505,
 3660,
 3691,
 4684,
 4779,
 4931,
 4967,
 5612,
 5999,
 7281,
 7802,
 8399,
 9402,
 9748,
 9952]

In [14]:
# TL visualize each state: atomic number -> element symbol for Z = 1..118,
# used to label atoms when writing XYZ blocks.
atomic_num2sym = dict(
    enumerate(
        """
        H  He Li Be B  C  N  O  F  Ne
        Na Mg Al Si P  S  Cl Ar K  Ca
        Sc Ti V  Cr Mn Fe Co Ni Cu Zn
        Ga Ge As Se Br Kr Rb Sr Y  Zr
        Nb Mo Tc Ru Rh Pd Ag Cd In Sn
        Sb Te I  Xe Cs Ba La Ce Pr Nd
        Pm Sm Eu Gd Tb Dy Ho Er Tm Yb
        Lu Hf Ta W  Re Os Ir Pt Au Hg
        Tl Pb Bi Po At Rn Fr Ra Ac Th
        Pa U  Np Pu Am Cm Bk Cf Es Fm
        Md No Lr Rf Db Sg Bh Hs Mt Ds
        Rg Cn Nh Fl Mc Lv Ts Og
        """.split(),
        start=1,
    )
)

In [15]:
def xyz_block_from_node_features(
    xh: torch.Tensor, comment: str = "", c2a: Optional[dict] = None
) -> str:
    """Format one structure's node-feature matrix as an XYZ block.

    Args:
        xh (torch.Tensor): node features, one row per atom. Columns 0-2 are written as
            the Cartesian position; the last column is read as the atomic number.
        comment (str, optional): text for the XYZ comment (second) line. Defaults to "".
        c2a (Optional[dict], optional): atomic number -> element symbol map. Defaults
            to the notebook's ``atomic_num2sym`` table, looked up at call time.

    Returns:
        str: XYZ text (atom count, comment, one tab-separated line per atom) without a
        trailing newline.

    Raises:
        KeyError: if a rounded atomic number is not a key of ``c2a``.
    """
    if c2a is None:
        c2a = atomic_num2sym
    # detach().cpu() so tensors on the GPU or requiring grad can be converted to numpy.
    feats = xh.detach().cpu()
    lines = [str(feats.shape[0]), comment]
    for row in feats:
        # Round instead of truncating (.long()): a sampled value such as 5.9999 must
        # map to carbon (6), not boron (5).
        atomic_number = int(torch.round(row[-1]).item())
        coords = row[:3].numpy()
        lines.append(f"{c2a[atomic_number]}\t" + "\t".join(str(v) for v in coords))
    return "\n".join(lines)

In [16]:
# Directory for one XYZ file per sampled reaction. When None, the sampling loop below
# only computes and prints RMSDs. To save files set, e.g.:
#   output_dir = os.path.abspath("results/sample_20_TS-20250611-Tr1x-zenodo_pretrained_ckpt")
output_dir = None

In [17]:
selected_representation_triples = {}
rs_rmsdsx = []
ts_rmsdsx = []
ps_rmsdsx = []

# One entry per state, in output order:
# (fragment index, comment for the reference structure,
#  label for the sampled/reconstructed structure, list collecting that state's RMSDs)
state_specs = (
    (0, "True/calculated reference reactant state.", "Reconstructed reactant state.", rs_rmsdsx),
    (1, "True/calculated reference transition state.", "Generated/inpainted transition state.", ts_rmsdsx),
    (2, "True/calculated reference product state.", "Reconstructed product state.", ps_rmsdsx),
)

# A set gives O(1) membership tests; stop once the last selected index is done.
selected_indices = {int(k) for k in random_indices}
last_selected = max(selected_indices, default=-1)

# A fresh pass over `loader` (batch_size=1, shuffle=False, so batch i is dataset item i).
# Unlike next(itl) on a shared iterator, this still works when the cell is re-run.
for i, (representations, res) in enumerate(loader):
    if i > last_selected:
        break  # every selected index has been processed
    if i not in selected_indices:
        continue

    xyz_blocks = []
    print(i)
    train_idx = train_pkl["use_ind"][i]
    print(train_idx)
    rxn_id = train_pkl["reactant"]["rxn"][train_idx]

    n_samples = representations[0]["size"].size(0)
    fragments_nodes = [repre["size"] for repre in representations]
    conditions = torch.tensor([[0] for _ in range(n_samples)], device=device)
    # Concatenate the feature groups (FEATURE_MAPPING order) for every state.
    # No permutation of indices is applied to the reactant state.
    xh_fixed = [
        torch.cat([repre[feature_type] for feature_type in FEATURE_MAPPING], dim=1)
        for repre in representations
    ]
    print(xh_fixed[2].shape[0])
    print(train_pkl["reactant"]["num_atoms"][train_idx])

    # Fragments 0 (reactant) and 2 (product) are fixed; fragment 1 (TS) is inpainted.
    out_samples, out_masks = ddpm_trainer.ddpm.inpaint(
        n_samples=n_samples,
        fragments_nodes=fragments_nodes,
        conditions=conditions,
        return_frames=1,
        resamplings=5,
        jump_length=5,
        timesteps=None,
        xh_fixed=xh_fixed,
        frag_fixed=[0, 2],
    )
    # NOTE(review): to check the fixed fragments are unchanged, compare two tensors, e.g.
    # torch.allclose(out_samples[0][0], xh_fixed[0]) - allclose takes (a, b), not a - b.
    # The small non-zero reactant/product RMSDs printed below suggest they are close but
    # possibly not bit-identical - TODO confirm before asserting equality.

    # RMSD of each state against its reference, printed in reactant/TS/product order.
    per_state_rmsds = []
    for frag_idx, _, _, rmsd_store in state_specs:
        rmsds = batch_rmsd(fragments_nodes, out_samples[0], xh_fixed, idx=frag_idx)
        print(rmsds)
        rmsd_store.append(rmsds[0])
        per_state_rmsds.append(rmsds)
    rs_rmsds, ts_rmsds, ps_rmsds = per_state_rmsds
    print("")

    # File output is optional; the loop-termination check above no longer depends on it.
    if output_dir is not None:
        file_name = f"{rxn_id}.xyz"
        for (frag_idx, ref_comment, gen_label, _), rmsds in zip(state_specs, per_state_rmsds):
            xyz_blocks.append(xyz_block_from_node_features(xh_fixed[frag_idx], comment=ref_comment))
            xyz_blocks.append(
                xyz_block_from_node_features(
                    out_samples[0][frag_idx],
                    comment=f"{gen_label} RMSD: {str(round(rmsds[0], 6))} Å.",
                )
            )
        # Explicit UTF-8: the comments contain "Å", which a non-UTF-8 default locale may
        # fail to encode.
        with open(os.path.join(output_dir, file_name), "w", encoding="utf-8") as f_out:
            f_out.write("\n\n".join(xyz_blocks))

177
202
10
10
[0.0027844603802332233]
[0.010383678689219172]
[0.0034456645464990944]

501
560
10
10
[0.004800592802929688]
[0.008966076543586072]
[0.0038860981323944267]

625
700
10
10
[0.0033001688482348045]
[0.013241237793895636]
[0.0038895617165288043]

902
1009
12
12
[0.004022391827383418]
[0.00587251775167008]
[0.004040244613062573]

2616
2924
13
13
[0.0038898312913396857]
[0.014055473134347036]
[0.00386503858588595]

3130
3505
13
13
[0.0038547169104909823]
[0.15032817487660272]
[0.0034979079422697563]

3267
3660
13
13
[0.002978078083873235]
[0.015897363142569863]
[0.004239871766747407]

3294
3691
13
13
[0.0034875060368150367]
[0.019910638847494697]
[0.004123994470618216]

4171
4684
14
14
[0.003089296499216545]
[0.25038845089827844]
[0.003843110593978436]

4258
4779
14
14
[0.0037580238508104578]
[0.02379710488750487]
[0.004163554335264125]

4392
4931
15
15
[0.0034576414326273233]
[0.41491431382881055]
[0.0038971820264111605]

4424
4967
15
15
[0.002939046068775061]
[0.0167431400775

In [18]:
# Report/analysis of RMSDs seen (np.std with its default ddof=0, i.e. population std).

def _rounded_mean_std(values, ndigits=6):
    """Return (mean, std) of ``values``, each rounded to ``ndigits`` decimals."""
    return round(np.mean(values), ndigits), round(np.std(values), ndigits)

rs_mean, rs_std = _rounded_mean_std(rs_rmsdsx)
ts_mean, ts_std = _rounded_mean_std(ts_rmsdsx)
ps_mean, ps_std = _rounded_mean_std(ps_rmsdsx)

print(f"Reactant state reconstruction RMSD was mean ± std.dev.: \t{rs_mean} \t± {rs_std} Å.")
print(f"Transition state inpainting RMSD was mean ± std.dev.: \t\t{ts_mean} \t± {ts_std} Å.")
print(f"Product state reconstruction RMSD was mean ± std.dev.: \t\t{ps_mean} \t± {ps_std} Å.")

Reactant state reconstruction RMSD was mean ± std.dev.: 	0.003581 	± 0.000502 Å.
Transition state inpainting RMSD was mean ± std.dev.: 		0.15342 	± 0.214209 Å.
Product state reconstruction RMSD was mean ± std.dev.: 		0.003756 	± 0.000391 Å.


# Improve the visualization of multiple states together:

In [19]:
from glob import glob
import plotly.express as px

from oa_reactdiff.analyze.rmsd import xyz2pmg, pymatgen_rmsd

from pymatgen.core import Molecule
from collections import OrderedDict


def draw_reaction(
    react_path: str,
    idx: int = 0,
    prefix: str = "gen",
    offset: float = 8.0,
    sphere_radius: float = 0.3,
    width: int = 1024,
    height: int = 576,
) -> py3Dmol.view:
    """Draw the {reactants, transition states, products} of the reaction side by side.

    Reads ``{react_path}/{prefix}_{idx}_{react,ts,prod}.xyz``. It shifts the i-th
    structure by ``offset * i`` along every Cartesian axis so the three do not overlap,
    and shows them all as one py3Dmol model.

    Args:
        react_path (str): path to the reaction's XYZ files.
        idx (int, optional): index for the generated reaction. Defaults to 0.
        prefix (str, optional): prefix for distinguishing true sample and generated structure.
            Defaults to "gen".
        offset (float, optional): coordinate shift between consecutive states. Defaults to 8.0.
        sphere_radius (float, optional): radius of the atom spheres. Defaults to 0.3.
        width (int, optional): viewer width. Defaults to 1024.
        height (int, optional): viewer height. Defaults to 576.

    Returns:
        py3Dmol.view: viewer containing the three shifted structures.
    """
    atom_blocks = []
    natoms = 0
    for ii, t in enumerate(["react", "ts", "prod"]):
        pmatg_mol = xyz2pmg(f"{react_path}/{prefix}_{idx}_{t}.xyz")
        pmatg_mol_prime = Molecule(
            species=pmatg_mol.atomic_numbers,
            coords=pmatg_mol.cart_coords + offset * ii,
        )
        # Count atoms from the parsed structures instead of assuming 3x the reactant's
        # count, so the combined XYZ header stays valid even if the counts differ.
        natoms += len(pmatg_mol_prime)
        # Drop pymatgen's two-line XYZ header (count + comment) and keep the atom lines.
        atom_blocks.append("\n".join(pmatg_mol_prime.to(fmt="xyz").split("\n")[2:]) + "\n")
    mol = f"{natoms}\n\n" + "".join(atom_blocks)
    viewer = py3Dmol.view(width, height)
    viewer.addModel(mol, "xyz")
    viewer.setStyle({'stick': {}, "sphere": {"radius": sphere_radius}})
    viewer.zoomTo()
    return viewer

def draw_reaction_mod(react_path: str, idx: int = 0, prefix: str = "gen") -> py3Dmol.view:
    """Draw the {reactants, transition states, products} of the reaction.

    Currently identical to ``draw_reaction`` and implemented by delegating to it, so the
    two definitions cannot silently drift apart. Edit this body when a modified
    visualization is wanted.

    Args:
        react_path (str): path to the reaction's XYZ files.
        idx (int, optional): index for the generated reaction. Defaults to 0.
        prefix (str, optional): prefix for distinguishing true sample and generated structure.
            Defaults to "gen".

    Returns:
        py3Dmol.view: viewer returned by ``draw_reaction``.
    """
    return draw_reaction(react_path, idx=idx, prefix=prefix)

In [21]:

# Uses `xh_fixed` left over from the last reaction processed by the sampling loop.
ts_ref_xyz = xyz_block_from_node_features(
    xh_fixed[1],
    comment="True/calculated reference transition state.",
)
ts_ref_xyz

'19\nTrue/calculated reference transition state.\nC\t1.0337306\t-1.2022932\t-1.2106149\nC\t-0.38546628\t-0.66329134\t-1.0804629\nC\t-1.1488522\t-0.7932403\t0.21562387\nC\t-0.79307747\t-0.21712416\t1.4323623\nC\t0.42886135\t0.57680243\t1.5876906\nC\t0.29236686\t1.6744617\t-0.4098736\nC\t-0.60238266\t0.8148652\t-1.0047852\nH\t1.031492\t-2.293493\t-1.1177677\nH\t1.4485413\t-0.9401047\t-2.190542\nH\t1.7149061\t-0.81102854\t-0.45051065\nH\t-0.994589\t-1.1071033\t-1.8740281\nH\t-2.1521523\t-1.2065308\t0.1466279\nH\t-1.5953531\t0.007382358\t2.134896\nH\t0.090429194\t-0.6706312\t2.2016292\nH\t0.4975539\t1.259573\t2.4311655\nH\t1.3690892\t0.19764867\t1.2128975\nH\t1.3610688\t1.5040045\t-0.50576895\nH\t0.0047180653\t2.7122762\t-0.25444165\nH\t-1.6008847\t1.1578263\t-1.2640975'