# Apply this code to sample TS from model.

In [1]:
!export CUDA_VISIBLE_DEVICES=0
!echo $CUDA_VISIBLE_DEVICES




In [2]:
# --- Importing and defining some functions ----
import os
import torch
import py3Dmol
import numpy as np

from typing import Optional
from torch import tensor
from e3nn import o3
from torch_scatter import scatter_mean

from oa_reactdiff.model import LEFTNet

default_float = torch.float64
torch.set_default_dtype(default_float)  # Use double precision for more accurate testing


def remove_mean_batch(
    x: tensor, 
    indices: Optional[tensor] = None
) -> tensor:
    """Remove the mean from each batch in x

    Args:
        x (tensor): input tensor.
        indices (Optional[tensor], optional): batch indices. Defaults to None.

    Returns:
        tensor: output tensor with batch mean as 0.
    """
    if indices == None:
         return x - torch.mean(x, dim=0)
    mean = scatter_mean(x, indices, dim=0)
    x = x - mean[indices]
    return x


def draw_in_3dmol(mol: str, fmt: str = "xyz") -> py3Dmol.view:
    """Draw the molecule

    Args:
        mol (str): str content of molecule.
        fmt (str, optional): format. Defaults to "xyz".

    Returns:
        py3Dmol.view: output viewer
    """
    viewer = py3Dmol.view(1024, 576)
    viewer.addModel(mol, fmt)
    viewer.setStyle({'stick': {}, "sphere": {"radius": 0.36}})
    viewer.zoomTo()
    return viewer


def assemble_xyz(z: list, pos: tensor) -> str:
    """Assembling atomic numbers and positions into xyz format

    Args:
        z (list): chemical elements
        pos (tensor): 3D coordinates

    Returns:
        str: xyz string
    """
    natoms =len(z)
    xyz = f"{natoms}\n\n"
    for _z, _pos in zip(z, pos.numpy()):
        xyz += f"{_z}\t" + "\t".join([str(x) for x in _pos]) + "\n"
    return xyz

### Building a LEFTNet model

A simple test is performed to verify SE(3) symmetry. The model here is for testing, so we only need to build a very small model.

Note: [LEFTNet](https://arxiv.org/abs/2304.04757) is a new SOTA-level SE(3) graph neural network. Although we use LEFTNet here, the properties it exhibits are model-independent (other SE(3) models, such as [EGNN](https://arxiv.org/pdf/2102.09844.pdf), will give the same results)

TL: EGNN is not SE$(3)$  equivariant?

In [3]:
num_layers = 2
hidden_channels = 8
in_hidden_channels = 4
num_radial = 4

model =  LEFTNet(
    num_layers=num_layers,
    hidden_channels=hidden_channels,
    in_hidden_channels=in_hidden_channels,
    num_radial=num_radial,
    object_aware=False,
)

sum(p.numel() for p in model.parameters() if p.requires_grad)



7882



### Create an "Object-Aware" LEFTNet

In [4]:
# --- Importing necessary function ---
from torch.utils.data import DataLoader

from oa_reactdiff.trainer.pl_trainer import DDPMModule


from oa_reactdiff.dataset import ProcessedTS1x, ProcessedSCAN
from oa_reactdiff.diffusion._schedule import DiffSchedule, PredefinedNoiseSchedule

from oa_reactdiff.diffusion._normalizer import FEATURE_MAPPING
from oa_reactdiff.analyze.rmsd import batch_rmsd

from oa_reactdiff.utils.sampling_tools import assemble_sample_inputs, write_tmp_xyz

In [5]:
!pwd

/misc/home/guest50/OAReactDiff



### Import the pre-trained model and redefine the schedule.

In [6]:
# TL fix: {
from oa_reactdiff.trainer.pl_trainer import DDPMModule
# } fix. Why didn' this carry over from the previous cell import statement?

device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")
print(device) # TL
print(device.index)
print(device.type)


tspath = os.path.abspath(os.path.join(os.getcwd(), "oa_reactdiff","trainer"))
print(tspath)
# zenodo_pretrained_ckpt
ddpm_trainer = DDPMModule.load_from_checkpoint(
    checkpoint_path="./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/leftnet-SCAN-6-w_selfoops-lr2.5e-4-rcmconly_passerini-03516f3022c5/ddpm-epoch=1899-val-totloss=736.23.ckpt",    
    #checkpoint_path="./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/leftnet-SCAN-6-w_selfoops-lr2.5e-4-rcmconly_passerini-03516f3022c5/ddpm-epoch=1799-val-totloss=787.66.ckpt", --
    #
    #checkpoint_path="./pretrained-ts1x-diff.ckpt", # original
    #checkpoint_path=f"{tspath}/checkpoint/OAReactDiff/leftnet-0-f1ff7dc18fa3/ddpm-epoch=1999-val-totloss=509.31.ckpt", # Our recapitulation
    map_location=device,
)
ddpm_trainer = ddpm_trainer.to(device)

cuda
None
cuda
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer




In [7]:
noise_schedule: str = "polynomial_2"
timesteps: int = 150
precision: float = 1e-5

gamma_module = PredefinedNoiseSchedule(
            noise_schedule=noise_schedule,
            timesteps=timesteps,
            precision=precision,
        )
schedule = DiffSchedule(
    gamma_module=gamma_module,
    norm_values=ddpm_trainer.ddpm.norm_values
)
ddpm_trainer.ddpm.schedule = schedule
ddpm_trainer.ddpm.T = timesteps
ddpm_trainer = ddpm_trainer.to(device)

In [8]:
def prep_ddpm_trainer(ckpt_path: str, device=device):
    ddpm_trainer = DDPMModule.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        map_location=device,
    )
    ddpm_trainer = ddpm_trainer.to(device)

    noise_schedule: str = "polynomial_2"
    timesteps: int = 150
    precision: float = 1e-5
    
    gamma_module = PredefinedNoiseSchedule(
                noise_schedule=noise_schedule,
                timesteps=timesteps,
                precision=precision,
            )
    schedule = DiffSchedule(
        gamma_module=gamma_module,
        norm_values=ddpm_trainer.ddpm.norm_values
    )
    ddpm_trainer.ddpm.schedule = schedule
    ddpm_trainer.ddpm.T = timesteps
    ddpm_trainer = ddpm_trainer.to(device)
    return ddpm_trainer


### Prepare dataset and data loader and select a reaction involving multiple molecules

In [9]:
import pickle

npz_path = "./oa_reactdiff/data/SCAN-6/train.pkl" # w  self-loops after Angstrom correction

train_pkl = pickle.load(open(npz_path, "rb"))

In [10]:
valid_pkl = pickle.load(open("./oa_reactdiff/data/SCAN-6/valid_addprop.pkl", "rb"))

In [11]:
train_pkl["use_ind"][:10]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [12]:
valid_pkl["use_ind"][:10]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [13]:
# 01/7/2025:
# Something doesn't make sense
# 957  cp /scr/trond/SCAN/SCAN_train_wo_selfloops-rcmconly_passerini_fix.pkl train_addprop.pkl
# 958  cp train_addprop.pkl train.pkl 
# 959  cp /scr/trond/SCAN/SCAN-6_valid_wo_selfloops-rcmconly_passerini_fix.pkl valid_addprop.pkl

train_pkl2 = pickle.load(open("/scr/trond/SCAN/SCAN_train_w_selfloops-rcmconly_passerini_fix.pkl", "rb"))
valid_pkl2 = pickle.load(open("/scr/trond/SCAN/SCAN-6_valid_w_selfloops-rcmconly_passerini_fix.pkl", "rb"))

In [14]:
#assert sorted(list(set(train_pkl2["use_ind"] + valid_pkl2["use_ind"]))) == list(range(len(train_pkl2["transition_state"]["rxn"])))

In [15]:
train_pkl2["use_ind"] == valid_pkl2["use_ind"]

True

In [16]:
train_pkl["reactant"]["rxn"][177]

'ALD1-TS177'

In [17]:
train_pkl["reactant"]['num_atoms'][177]

17

In [18]:
selected_passerini_rcmconly_pts = [13605, 17605, 1206, 7394, 15965, 5435, 17328]
selected_strecker_rcmconly_pts = [3978, 5528, 7950]
selected_wl2_pts = [6029, 784, 11193]

In [19]:
from parse import parse
from pprint import pprint

In [20]:
set(selected_passerini_rcmconly_pts+selected_strecker_rcmconly_pts+selected_wl2_pts)

{784,
 1206,
 3978,
 5435,
 5528,
 6029,
 7394,
 7950,
 11193,
 13605,
 15965,
 17328,
 17605}

In [21]:
prefixes = {'ALD1',
 'EN1',
 'HDF1',
 'WL1',
 'f260_DFG1',
 'rcmconly_passerini',
 'rcmconly_strecker'}

select_prefixes = {'WL1', 'rcmconly_passerini', 'rcmconly_strecker'}
rxn_pattern = "{prefix}-TS{id_num}"

for i, rxn in enumerate(train_pkl["transition_state"]["rxn"]):
    parsed = parse(rxn_pattern, rxn)
    prefix = parsed["prefix"]
    id_num = parsed["id_num"]
    #prefixes.add(prefix)
    if prefix not in select_prefixes: 
        continue
    if int(id_num) in set(selected_passerini_rcmconly_pts+selected_strecker_rcmconly_pts+selected_wl2_pts):
        print(rxn)
        print(i)
        print(i in train_pkl["use_ind"])
        print("")

WL1-TS784
55285
True

WL1-TS1206
55707
True

WL1-TS3978
58479
True

WL1-TS5435
59936
True

WL1-TS5528
60029
True

WL1-TS6029
60530
False

WL1-TS7394
61895
True

WL1-TS7950
62451
True

WL1-TS11193
65694
True

rcmconly_passerini-TS784
73180
True

rcmconly_passerini-TS1206
73602
True

rcmconly_passerini-TS3978
76374
True

rcmconly_passerini-TS5435
77831
True

rcmconly_passerini-TS5528
77924
True

rcmconly_passerini-TS6029
78425
False

rcmconly_passerini-TS7394
79790
True

rcmconly_passerini-TS7950
80346
True

rcmconly_passerini-TS11193
83589
True

rcmconly_passerini-TS13605
86001
True

rcmconly_passerini-TS15965
88361
True

rcmconly_passerini-TS17328
89724
True

rcmconly_passerini-TS17605
90001
True

rcmconly_strecker-TS784
91382
True

rcmconly_strecker-TS1206
91804
True

rcmconly_strecker-TS3978
94576
True

rcmconly_strecker-TS5435
96033
True

rcmconly_strecker-TS5528
96126
True

rcmconly_strecker-TS6029
96627
True

rcmconly_strecker-TS7394
97992
True

rcmconly_strecker-TS7950
98548
True

In [22]:
select_rxns = []

selection = {'WL1': selected_wl2_pts, 'rcmconly_strecker': selected_strecker_rcmconly_pts, 'rcmconly_passerini': selected_passerini_rcmconly_pts}

for prefix, id_list in selection.items():
    for id_num in id_list:
        select_rxns.append(f"{prefix}-TS{id_num}")

print(select_rxns)

['WL1-TS6029', 'WL1-TS784', 'WL1-TS11193', 'rcmconly_strecker-TS3978', 'rcmconly_strecker-TS5528', 'rcmconly_strecker-TS7950', 'rcmconly_passerini-TS13605', 'rcmconly_passerini-TS17605', 'rcmconly_passerini-TS1206', 'rcmconly_passerini-TS7394', 'rcmconly_passerini-TS15965', 'rcmconly_passerini-TS5435', 'rcmconly_passerini-TS17328']


In [23]:
# assert sorted(list(set(train_pkl["use_ind"] + valid_pkl["use_ind"]))) == list(range(len(train_pkl["transition_state"]["rxn"])))

In [24]:
len(train_pkl["transition_state"]["rxn"])

108754

In [25]:
len(valid_pkl["transition_state"]["rxn"])

108754

In [26]:
train_pkl["use_ind"] == valid_pkl["use_ind"]

True

In [27]:
selected_train_indices = []
selected_valid_indices = []
selected_rxns_dict = {}
for i, rxn in enumerate(train_pkl["transition_state"]["rxn"]):
    if rxn in select_rxns:
        print(rxn)
        print(i)
        print(i in train_pkl["use_ind"])
        if i in train_pkl["use_ind"]:
            selected_train_indices.append(train_pkl["use_ind"].index(i))
        elif i in valid_pkl["use_ind"]:
            selected_valid_indices.append(valid_pkl["use_ind"].index(i))
        selected_rxns_dict[i] = rxn
        print("")

WL1-TS784
55285
True

WL1-TS6029
60530
False

WL1-TS11193
65694
True

rcmconly_passerini-TS1206
73602
True

rcmconly_passerini-TS5435
77831
True

rcmconly_passerini-TS7394
79790
True

rcmconly_passerini-TS13605
86001
True

rcmconly_passerini-TS15965
88361
True

rcmconly_passerini-TS17328
89724
True

rcmconly_passerini-TS17605
90001
True

rcmconly_strecker-TS3978
94576
True

rcmconly_strecker-TS5528
96126
True

rcmconly_strecker-TS7950
98548
True



In [28]:
dataset = ProcessedSCAN(
    npz_path=npz_path,
    center=True,
    pad_fragments=0,
    device=device,
    zero_charge=False,
    remove_h=False,
    single_frag_only=False,
    swapping_react_prod=False,
    use_by_ind=True,
)
loader = DataLoader(
    dataset, 
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=dataset.collate_fn
)
itl = iter(loader)
idx = -1 # TL: why?

len(dataset)

97878

In [29]:
num_indices_to_select = 20
SEED = 747
np.random.default_rng(seed=SEED)
np.random.seed(SEED)
random_indices = np.random.choice(len(dataset), size=num_indices_to_select, replace=False)
random_indices = sorted(random_indices)
random_indices

[6854,
 12271,
 13864,
 23946,
 24032,
 25222,
 25256,
 29889,
 37106,
 37380,
 40165,
 51364,
 57459,
 67770,
 74604,
 83451,
 86121,
 90515,
 96706,
 97623]

In [30]:
train_indices = [train_pkl["use_ind"][x] for x in random_indices]
train_indices

[7656,
 13675,
 15472,
 26686,
 26778,
 28110,
 28149,
 33283,
 41300,
 41604,
 44688,
 57086,
 63812,
 75301,
 82867,
 92683,
 95662,
 100586,
 107463,
 108480]

In [31]:
# TL visualize each state:
atomic_num2sym = {
    1: 'H',    2: 'He',   3: 'Li',   4: 'Be',   5: 'B',    6: 'C',    7: 'N',    8: 'O',    9: 'F',    10: 'Ne',
    11: 'Na',  12: 'Mg',  13: 'Al',  14: 'Si',  15: 'P',   16: 'S',   17: 'Cl',  18: 'Ar',  19: 'K',   20: 'Ca',
    21: 'Sc',  22: 'Ti',  23: 'V',   24: 'Cr',  25: 'Mn',  26: 'Fe',  27: 'Co',  28: 'Ni',  29: 'Cu',  30: 'Zn',
    31: 'Ga',  32: 'Ge',  33: 'As',  34: 'Se',  35: 'Br',  36: 'Kr',  37: 'Rb',  38: 'Sr',  39: 'Y',   40: 'Zr',
    41: 'Nb',  42: 'Mo',  43: 'Tc',  44: 'Ru',  45: 'Rh',  46: 'Pd',  47: 'Ag',  48: 'Cd',  49: 'In',  50: 'Sn',
    51: 'Sb',  52: 'Te',  53: 'I',   54: 'Xe',  55: 'Cs',  56: 'Ba',  57: 'La',  58: 'Ce',  59: 'Pr',  60: 'Nd',
    61: 'Pm',  62: 'Sm',  63: 'Eu',  64: 'Gd', 65: 'Tb',  66: 'Dy',  67: 'Ho',  68: 'Er',  69: 'Tm',  70: 'Yb',
    71: 'Lu',  72: 'Hf',  73: 'Ta',  74: 'W',   75: 'Re',  76: 'Os',  77: 'Ir',  78: 'Pt',  79: 'Au',  80: 'Hg',
    81: 'Tl',  82: 'Pb',  83: 'Bi',  84: 'Po',  85: 'At',  86: 'Rn',  87: 'Fr',  88: 'Ra',  89: 'Ac',  90: 'Th',
    91: 'Pa',  92: 'U',   93: 'Np',  94: 'Pu',  95: 'Am',  96: 'Cm',  97: 'Bk',  98: 'Cf',  99: 'Es', 100: 'Fm',
    101: 'Md', 102: 'No', 103: 'Lr', 104: 'Rf', 105: 'Db', 106: 'Sg', 107: 'Bh', 108: 'Hs', 109: 'Mt', 110: 'Ds',
    111: 'Rg', 112: 'Cn', 113: 'Nh', 114: 'Fl', 115: 'Mc', 116: 'Lv', 117: 'Ts', 118: 'Og'
}

In [32]:
def xyz_block_from_node_features(xh: torch.tensor, comment: str="", c2a: dict=atomic_num2sym) -> str:
    num_atoms = xh.shape[0]
    xyz_lines = [str(num_atoms), comment]
    for row in xh:
        position = row[:3].cpu().numpy()
        z = c2a[row[-1].long().item()]
        xyz_lines.append(f"{z}\t" + "\t".join([str(x) for x in position]))
    return "\n".join(xyz_lines)

In [33]:
#!mkdir my_results
!mkdir results/sample_selected_TS-20250703-SCAN-6_7-w_wo_selfloops-lr2_5_lr5_0e-4-ValFix3

mkdir: cannot create directory 'results/sample_selected_TS-20250703-SCAN-6_7-w_wo_selfloops-lr2_5_lr5_0e-4-ValFix3': File exists


In [34]:
!ls -haltr results

total 44K
drwxr-xr-x  2 guest50 users 4.0K Jun 11 17:49 sample_20_TS-20250611-Tr1x-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jun 11 18:05 sample_20_TS-20250611-Tr1x-zenodo_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jun 11 18:57 sample_20_TS-20250611-SCAN-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jun 17 18:13 sample_20_TS-20250616-SCAN-w_selfloops-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jun 17 18:35 sample_20_TS-20250616-SCAN-wo_selfloops-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jun 25 15:44 sample_20_TS-20250625-SCAN-4-w_selfloops-lr2_5e-4-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users   10 Jun 26 14:54 sample_selected_TS-20250626-SCAN-4-various_models-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users   10 Jul  1 14:22 sample_selected_TS-20250701-SCAN-6_7-ValFix2-various_models-from_scratch_pretrained_ckpt
drwxr-xr-x  2 guest50 users 4.0K Jul  1 16:52 sample_selected_TS-20250701-SC

In [35]:
output_dir = os.path.abspath("results/sample_selected_TS-20250703-SCAN-6_7-w_wo_selfloops-lr2_5_lr5_0e-4-ValFix3")
#output_dir = None

In [36]:
!ls -haltr oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/

total 80K
drwxr-xr-x  4 guest50 users   61 Jun  4 21:37 ..
drwxr-xr-x  2 guest50 users   32 Jun  4 21:37 leftnet-SCAN-0-1b5f166db70f
drwxr-xr-x  2 guest50 users   32 Jun  4 21:41 leftnet-SCAN-0-f07862fcb99f
drwxr-xr-x  2 guest50 users   32 Jun  4 23:12 leftnet-SCAN-0-f851475eba02
drwxr-xr-x  2 guest50 users   32 Jun  4 23:20 leftnet-SCAN-0-c045e295fe38
drwxr-xr-x  2 guest50 users   32 Jun  4 23:27 leftnet-SCAN-0-f3fa37995f70
drwxr-xr-x  2 guest50 users   32 Jun  4 23:48 leftnet-SCAN-0-174f699cc8c1
drwxr-xr-x  2 guest50 users   32 Jun  4 23:51 leftnet-SCAN-0-2c3eab8f65b3
drwxr-xr-x  2 guest50 users   32 Jun  5 00:10 leftnet-SCAN-0-51a26ab46b18
drwxr-xr-x  2 guest50 users   32 Jun  5 14:49 leftnet-SCAN-0-0b401e663721
drwxr-xr-x  2 guest50 users   32 Jun  7 16:07 leftnet-SCAN-0-37a0a175a522
drwxr-xr-x  2 guest50 users   32 Jun  7 16:30 leftnet-SCAN-0-5579cb2cd21c
drwxr-xr-x  2 guest50 users  287 Jun  7 16:42 leftnet-SCAN-0-e1b24c127367
drwxr-xr-x  2 guest50 users 4.0K Jun  9 17:18 leftnet

In [37]:
!ls -haltr ./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/6-*

./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/6-w_selfoops-lr5e-4-rcmconly_passerini-SCAN-leftnetee6a8c83b353:
total 1.6G
-rw-r--r--  1 guest50 users  29K Jun 28 00:42 leftnet.py
-rw-r--r--  1 guest50 users 163M Jun 28 18:59 ddpm-epoch=774-val-totloss=704.41.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 04:35 ddpm-epoch=1180-val-totloss=708.74.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 06:52 ddpm-epoch=1278-val-totloss=711.23.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 10:17 ddpm-epoch=1421-val-totloss=707.90.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 12:36 ddpm-epoch=1519-val-totloss=702.43.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 18:19 ddpm-epoch=1762-val-totloss=706.70.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 18:33 ddpm-epoch=1771-val-totloss=710.16.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 18:39 ddpm-epoch=1777-val-totloss=710.10.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 21:02 ddpm-epoch=1878-val-totloss=710.04.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 22:

In [38]:
!ls -haltr ./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/7-*

./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/7-wo_selfoops-lr2.5e-4-rcmconly_passerini-SCAN-leftnet17383d5f6686:
total 1.6G
-rw-r--r--  1 guest50 users  29K Jun 28 00:44 leftnet.py
-rw-r--r--  1 guest50 users 163M Jun 28 19:31 ddpm-epoch=774-val-totloss=704.94.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 11:11 ddpm-epoch=1421-val-totloss=701.98.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 13:34 ddpm-epoch=1519-val-totloss=707.98.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 15:34 ddpm-epoch=1602-val-totloss=710.39.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 17:50 ddpm-epoch=1697-val-totloss=687.96.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 17:52 ddpm-epoch=1698-val-totloss=703.13.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 20:44 ddpm-epoch=1817-val-totloss=708.96.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 21:53 ddpm-epoch=1864-val-totloss=711.26.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 23:19 ddpm-epoch=1923-val-totloss=704.87.ckpt
-rw-r--r--  1 guest50 users 163M Jun 29 

In [39]:
!ls -haltr ./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8w*
#8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f
#8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8
#8w-lr5e-4-ValFix3-SCAN-leftnetefadd8bd15e7
#8w-lr2.5e-4-ValFix3-SCAN-leftnet3ccd31ef9ecd

./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8:
total 1.6G
-rw-r--r--  1 guest50 users  29K Jul  1 15:33 leftnet.py
-rw-r--r--  1 guest50 users 163M Jul  2 10:33 ddpm-epoch=774-val-totloss=764.85.ckpt
drwxr-xr-x 36 guest50 users 4.0K Jul  2 19:03 ..
-rw-r--r--  1 guest50 users 163M Jul  2 20:11 ddpm-epoch=1168-val-totloss=755.50.ckpt
-rw-r--r--  1 guest50 users 163M Jul  2 20:32 ddpm-epoch=1180-val-totloss=771.08.ckpt
-rw-r--r--  1 guest50 users 163M Jul  2 22:56 ddpm-epoch=1278-val-totloss=764.75.ckpt
-rw-r--r--  1 guest50 users 163M Jul  3 02:30 ddpm-epoch=1421-val-totloss=756.71.ckpt
-rw-r--r--  1 guest50 users 163M Jul  3 04:55 ddpm-epoch=1519-val-totloss=768.78.ckpt
-rw-r--r--  1 guest50 users 163M Jul  3 08:51 ddpm-epoch=1679-val-totloss=762.24.ckpt
-rw-r--r--  1 guest50 users 163M Jul  3 09:15 ddpm-epoch=1697-val-totloss=747.86.ckpt
-rw-r--r--  1 guest50 users 163M Jul  3 10:49 ddpm-epoch=1760-val-totloss=767.25.ckpt
-rw-r--r--  

In [40]:
#checkpoints_to_try = ["./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/leftnet-SCAN-6-w_selfoops-lr2.5e-4-rcmconly_passerini-03516f3022c5/ddpm-epoch=1899-val-totloss=736.23.ckpt"]

In [41]:
from collections import defaultdict
per_checkpoint_ts_rmsds = defaultdict(list)

In [42]:
from parse import parse

In [43]:
import os
basedir = os.path.abspath("oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN")
run_names = ["8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8", 
             "8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f", 
             "8w-lr2.5e-4-ValFix3-SCAN-leftnet3ccd31ef9ecd", 
             "8w-lr5e-4-ValFix3-SCAN-leftnetefadd8bd15e7"]
selected_checkpoint_paths = {}
for run_name in run_names:
    print(run_name)
    runpath = os.path.join(basedir, run_name)
    parsed = parse("8{self_loops}-{lr}e-4-{etc}", run_name)
    self_loops = parsed["self_loops"]+"_self_loops"
    lr = parsed["lr"]+"e-4"
 
    # Let's select only the latest saved checkpoint and those among the top-k older checkpoints that have better val-totloss.
    checkpoints = []
    for file in os.listdir(runpath):
         if file.endswith(".ckpt"):
             ckpt_path = os.path.join(runpath, file)
             mod_time = os.path.getmtime(ckpt_path)
             checkpoints.append(tuple([file, mod_time]))
    checkpoints.sort(key=lambda x: x[1], reverse=True) # Most recent first.
    
    selected_checkpoints = []
    seen_val_totloss = set()
    for ckpt, _ in checkpoints:
        parsed2 = parse("ddpm-epoch={epoch}-val-totloss={val-totloss}.ckpt", ckpt)
        val_totloss = float(parsed2["val-totloss"])
        epoch = int(parsed2["epoch"])
        if all(val_totloss < value for value in seen_val_totloss):
            seen_val_totloss.add(val_totloss)
            selected_checkpoints.append(ckpt)
            ckpt_path = os.path.join(runpath, ckpt)
            selected_checkpoint_paths[ckpt_path] = {"self_loops":self_loops, "lr":lr, "epoch":epoch, "val-totloss":val_totloss}
    print(list(reversed(selected_checkpoints)))
    print(len(selected_checkpoints))
    print("")

8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8
['ddpm-epoch=1697-val-totloss=747.86.ckpt', 'ddpm-epoch=1814-val-totloss=760.52.ckpt']
2

8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f
['ddpm-epoch=1293-val-totloss=753.83.ckpt', 'ddpm-epoch=1380-val-totloss=760.11.ckpt', 'ddpm-epoch=1749-val-totloss=762.73.ckpt', 'ddpm-epoch=1802-val-totloss=768.32.ckpt', 'ddpm-epoch=1895-val-totloss=771.19.ckpt']
5

8w-lr2.5e-4-ValFix3-SCAN-leftnet3ccd31ef9ecd
['ddpm-epoch=1872-val-totloss=644.14.ckpt', 'ddpm-epoch=1951-val-totloss=649.85.ckpt', 'ddpm-epoch=1963-val-totloss=655.64.ckpt']
3

8w-lr5e-4-ValFix3-SCAN-leftnetefadd8bd15e7
['ddpm-epoch=1697-val-totloss=655.96.ckpt', 'ddpm-epoch=1698-val-totloss=663.60.ckpt', 'ddpm-epoch=1923-val-totloss=663.67.ckpt']
3



In [44]:
selected_representation_triples = {}
rs_rmsdsx = []
ts_rmsdsx = []
ps_rmsdsx = []

selected_indices = selected_train_indices

for i in range(len(dataset)): 
    representations, res = next(itl)
    if i in selected_indices: #random_indices:
        xyz_blocks = []
        print(i)
        train_idx = train_pkl["use_ind"][i]
        print(train_idx)
        rxn_id = train_pkl["reactant"]["rxn"][train_idx]
        print(rxn_id)
        n_samples = representations[0]["size"].size(0)
        fragments_nodes = [
            repre["size"] for repre in representations
        ]
        conditions = torch.tensor([[0] for _ in range(n_samples)], device=device)
        # skipping permutation of indices in reactant state
        xh_fixed = [
            torch.cat(
                [repre[feature_type] for feature_type in FEATURE_MAPPING],
                dim=1,
            )
            for repre in representations
        ]
        #print(xh_fixed[2].shape[0])
        #print(train_pkl["reactant"]['num_atoms'][train_idx])
        assert xh_fixed[2].shape[0] == train_pkl["reactant"]['num_atoms'][train_idx]
        #ground_truth_ts = xh_fixed[1]

        if output_dir is not None:            
            # Now wrap up XYZs into an output file with informative comments.
            file_name = f"{rxn_id}.xyz"
            # Reactant state, two versions
            rs_ref_xyz = xyz_block_from_node_features(xh_fixed[0], comment=f"True/calculated reference reactant state.")
            xyz_blocks.append(rs_ref_xyz)
            #rs_rec_xyz = xyz_block_from_node_features(out_samples[0][0], comment=f"Reconstructed reactant state. RMSD: {str(round(rs_rmsds[0],6))} Å.")
            #xyz_blocks.append(rs_rec_xyz)
            
            # Transition state, two versions
            ts_ref_xyz = xyz_block_from_node_features(xh_fixed[1], comment=f"True/calculated reference transition state.")
            xyz_blocks.append(ts_ref_xyz)
            #ts_gen_xyz = xyz_block_from_node_features(out_samples[0][1], comment=f"Generated/inpainted transition state. RMSD: {str(round(ts_rmsds[0],6))} Å.")
            #xyz_blocks.append(ts_gen_xyz)
    
            # Product state, two versions
            ps_ref_xyz = xyz_block_from_node_features(xh_fixed[2], comment=f"True/calculated reference product state.")
            xyz_blocks.append(ps_ref_xyz)
            #ps_rec_xyz = xyz_block_from_node_feattotures(out_samples[0][2], comment=f"Reconstructed product state. RMSD: {str(round(ps_rmsds[0],6))} Å.")
            #xyz_blocks.append(ps_rec_xyz)

        for checkpoint_path, ckpt_vals in selected_checkpoint_paths.items():
            checkpoint_name = f"{os.path.dirname(checkpoint_path)}/{os.path.basename(checkpoint_path)}" 
            ddpm_trainer = prep_ddpm_trainer(checkpoint_path, device=device)
            out_samples, out_masks = ddpm_trainer.ddpm.inpaint(
                n_samples=n_samples,
                fragments_nodes=fragments_nodes,
                conditions=conditions,
                return_frames=1,
                resamplings=5,
                jump_length=5,
                timesteps=None,
                xh_fixed=xh_fixed,
                frag_fixed=[0, 2],
            )
            
            # # reactant state (ts_..):
            # rs_rmsds = batch_rmsd(
            #     fragments_nodes, 
            #     out_samples[0],
            #     xh_fixed,
            #     idx=0,
            # )
            # print(rs_rmsds)
            # rs_rmsdsx.append(rs_rmsds[0])
    
            # transition state (ts_..):
            ts_rmsds = batch_rmsd(
                fragments_nodes, 
                out_samples[0],
                xh_fixed,
                idx=1,
            )
            print(f"{checkpoint_name}: {ts_rmsds[0]}")
            per_checkpoint_ts_rmsds[checkpoint_name].append(ts_rmsds[0])

            if output_dir is not None:
                ts_gen_xyz = xyz_block_from_node_features(out_samples[0][1], comment=f"Generated/inpainted transition state. RMSD: {str(round(ts_rmsds[0],6))} Å. By model {checkpoint_name}.")
                xyz_blocks.append(ts_gen_xyz)
        
            # # product state (ps_..):
            # ps_rmsds = batch_rmsd(
            #     fragments_nodes, 
            #     out_samples[0],
            #     xh_fixed,
            #     idx=2,
            # )
            # print(ps_rmsds)
            # ps_rmsdsx.append(ps_rmsds[0])
            #assert len(rmsds) == 1
        print("")

        if output_dir is not None:
            with open(os.path.join(output_dir, file_name), "w") as f_out:
                f_out.write("\n".join(xyz_blocks))
            
        if i == selected_indices[-1]: # no need to keep iterating.
            break

49742
55285
WL1-TS784
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1814-val-totloss=760.52.ckpt: 0.9123466581722263
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1697-val-totloss=747.86.ckpt: 0.926495622053253
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f/ddpm-epoch=1895-val-totloss=771.19.ckpt: 0.9877012990558813
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f/ddpm-epoch=1802-val-totloss=768.32.ckpt: 0.6436422659688241
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr5e-4-ValFix3-SCAN-leftnet3f50bcab0e5f/ddpm-epoch=1749-val-totloss=762.73.ckpt: 0.7008328307908044
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/che

In [45]:
1-3/13

0.7692307692307692

In [46]:
#select_rxns
pprint(dict(sorted(selected_rxns_dict.items())))

{55285: 'WL1-TS784',
 60530: 'WL1-TS6029',
 65694: 'WL1-TS11193',
 73602: 'rcmconly_passerini-TS1206',
 77831: 'rcmconly_passerini-TS5435',
 79790: 'rcmconly_passerini-TS7394',
 86001: 'rcmconly_passerini-TS13605',
 88361: 'rcmconly_passerini-TS15965',
 89724: 'rcmconly_passerini-TS17328',
 90001: 'rcmconly_passerini-TS17605',
 94576: 'rcmconly_strecker-TS3978',
 96126: 'rcmconly_strecker-TS5528',
 98548: 'rcmconly_strecker-TS7950'}


In [47]:
# Report/analysis of RMSDs seen:

#rs_mean = round(np.mean(rs_rmsdsx), 6)
#rs_std = round(np.std(rs_rmsdsx), 6)
#print(f"Reactant state reconstruction RMSD was mean ± std.dev.: \t{rs_mean} \t± {rs_std} Å.")
ts_mean = round(np.mean(ts_rmsds), 6)
ts_std = round(np.std(ts_rmsds), 6)
print(f"Transition state inpainting RMSD was mean ± std.dev.: \t\t{ts_mean} \t± {ts_std} Å.")
#ps_mean = round(np.mean(ps_rmsdsx), 6)
#ps_std = round(np.std(ps_rmsdsx), 6)
#print(f"Product state reconstruction RMSD was mean ± std.dev.: \t\t{ps_mean} \t± {ps_std} Å.")

Transition state inpainting RMSD was mean ± std.dev.: 		0.356407 	± 0.0 Å.


Tried checkpoints, fraction of RMSD==1.0, and mean RMSD:

`checkpoint_path="./oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/leftnet-SCAN-6-w_selfoops-lr2.5e-4-rcmconly_passerini-03516f3022c5/...`

`ddpm-epoch=999-val-totloss=770.24.ckpt`: ...    

`ddpm-epoch=1799-val-totloss=787.66.ckpt`: 12/20, 0.819974 	± 0.273995 Å.


In [49]:
for checkpoint_name, ts_rmsds in per_checkpoint_ts_rmsds.items():
    print(checkpoint_name)
    ts_mean = round(np.mean(ts_rmsds), 6)
    ts_std = round(np.std(ts_rmsds), 6)
    failures = len([x for x in ts_rmsds if x == 1.0])
    print(checkpoint_name)
    print(f"    Transition state inpainting RMSD was mean ± std.dev.: \t\t{ts_mean} \t± {ts_std} Å.")
    print(f"    Number of 1.0 RMSD indicating complete failure: {failures} out of {len(ts_rmsds)}.")
    print("")

/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1814-val-totloss=760.52.ckpt
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1814-val-totloss=760.52.ckpt
    Transition state inpainting RMSD was mean ± std.dev.: 		0.652573 	± 0.330465 Å.
    Number of 1.0 RMSD indicating complete failure: 4 out of 12.

/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1697-val-totloss=747.86.ckpt
/misc/home/guest50/OAReactDiff/oa_reactdiff/trainer/checkpoint/OAReactDiff-SCAN/8wo-lr2.5e-4-ValFix3-SCAN-leftnetae2ad45e2dd8/ddpm-epoch=1697-val-totloss=747.86.ckpt
    Transition state inpainting RMSD was mean ± std.dev.: 		0.711723 	± 0.259058 Å.
    Number of 1.0 RMSD indicating complete failure: 3 out of 12.

/misc/home/guest50/OAReactDiff/oa_re

In [52]:
output_dir

'/misc/home/guest50/OAReactDiff/results/sample_selected_TS-20250703-SCAN-6_7-w_wo_selfloops-lr2_5_lr5_0e-4-ValFix3'