In [None]:
from argparse import Namespace

import torch

# Run configuration for this notebook, mirroring the options a command-line
# entry point would receive.
args = Namespace(
    # Checkpoint of the trained model to sample from.
    sim_ckpt="workdir/TPS_Transition1x-rcut12-path_linearOT/PotentialEnergy/epoch=198-step=1894990.ckpt",
    # Root directory handed to the dataset class.
    data_dir="data/RGD1/",
    suffix="",
    num_rollouts=2,
    # out_dir="./test/Transition1x-rcut12-path_linearOT/",
    # Directory that receives one `rollout_<i>` sub-directory per processed item.
    out_dir="./test/Transition1x-rcut12-path_linearOT/RGD1_linked_reactions/cyclobutane_cyclization/",
    num_frames=1,
    localmask=False,
    # Conditioning switches forwarded to the dataset and consulted by the export loop.
    tps_condition=True,
    sim_condition=False
    )
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when tensors are moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import os, torch, tqdm, time
import numpy as np
from mdgen.equivariant_wrapper import EquivariantMDGenWrapper

In [None]:
os.makedirs(args.out_dir, exist_ok=True)

In [None]:

from mdgen.dataset import EquivariantTransformerDataset_Transition1x

# Build the evaluation dataset from the configured data directory. The
# conditioning flags come from `args`; the species count and stage name are
# fixed for this notebook.
dataset = EquivariantTransformerDataset_Transition1x(
    data_dirname=args.data_dir,
    sim_condition=args.sim_condition,
    tps_condition=args.tps_condition,
    num_species=5,
    stage="cyclobutane_cyclization",
)


In [None]:
print(len(dataset))


In [None]:
# Load the checkpoint onto the CPU first: this works regardless of the device
# the checkpoint was saved from and avoids holding a second copy of the weights
# on the GPU next to the model itself.
# NOTE(security): weights_only=False unpickles arbitrary Python objects
# (presumably needed for the stored hyper-parameter objects -- TODO confirm).
# Only load checkpoints from a trusted source.
ckpt = torch.load(args.sim_ckpt, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]

# Inference-time overrides of the hyper-parameters stored with the checkpoint.
hparams['args'].guided = False
# hparams['args'].sampling_method = 'euler'
# hparams['args'].guidance_pref = 2
hparams['args'].inference_steps = 50

model = EquivariantMDGenWrapper(**hparams)
print(model.model)

# strict=False tolerates key mismatches between checkpoint and model, so report
# them explicitly; otherwise a partially loaded model goes unnoticed.
load_result = model.load_state_dict(ckpt["state_dict"], strict=False)
print("missing keys:", load_result.missing_keys)
print("unexpected keys:", load_result.unexpected_keys)
model.eval().to(device)

In [None]:
print(ckpt["hyper_parameters"])
print(len(dataset))

In [None]:
embed_dim = ckpt["hyper_parameters"]['args'].embed_dim

In [None]:
print(ckpt["hyper_parameters"]['args'].num_heads)

In [None]:
# Draw a single collated batch for the smoke test further down. The loader
# shuffles without a fixed seed, so the sampled item differs between runs.
batch_size = 1
val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True)
sample_batch = next(iter(val_loader))


In [None]:
print(sample_batch.keys())

In [None]:
# Print the shape of the "x" entry of one fixed dataset item. The index 499 is
# hard-coded, so this presumably needs len(dataset) >= 500 -- TODO confirm
# against the dataset/stage in use.
print(dataset[499]["x"].shape)

## Test generative model

In [None]:
# Move the entries the model consumes onto `device`. Only a genuinely missing
# key is tolerated; any other failure (e.g. a CUDA error) is allowed to surface
# instead of being reported as "not found".
for key in ['species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask', 'species_next', 'x_next', "TKS_mask", "TKS_v_mask"]:
    try:
        sample_batch[key] = sample_batch[key].to(device)
    except KeyError:
        print(f"{key} not found")


# `rollout` unpacks two values from `model.inference` and runs it under
# no_grad; do the same here so `pred_pos` holds only the first returned value
# (the predicted positions) and no autograd graph is kept.
with torch.no_grad():
    pred_pos, _ = model.inference(sample_batch)

# Disabled diagnostic, kept as comments so it is not echoed as cell output:
# model.stage = "inference"
# prep = model.prep_batch(sample_batch)
# B,T,L,_ = prep["latents"].shape
# t = torch.ones((B,), device=prep["latents"].device)
# print(model.potential_model(prep['latents'], t, **prep['model_kwargs']).sum(dim=2).squeeze(-1)[:,1])


In [None]:
@torch.no_grad()
def rollout(model, batch):
    """Run one inference pass and return the prediction with an updated batch.

    Args:
        model: Object exposing ``inference(batch)`` that returns a pair whose
            first element is the predicted positions.
        batch: Mapping of model inputs. It is not modified.

    Returns:
        A tuple ``(positions, new_batch)`` where ``new_batch`` is a shallow
        copy of ``batch`` whose ``'x'`` entry is replaced by ``positions``.
    """
    predicted_positions, _ = model.inference(batch)

    # Shallow copy: the other entries are shared with the caller's batch.
    updated_batch = dict(batch)
    updated_batch['x'] = predicted_positions
    return predicted_positions, updated_batch


# Lookup from integer species index to element symbol. Only indices 0-3 are
# mapped; any other index raises KeyError on lookup.
map_to_chemical_symbol = dict(enumerate(("H", "C", "N", "O")))

In [None]:
print(len(dataset))

In [None]:
# idx_rollouts = np.random.choice(len(dataset), size=1, replace=False)
# Select every dataset index, in order, for generation. Note that
# args.num_rollouts is not consulted here, so the whole dataset is processed.
idx_rollouts = np.arange(len(dataset))
print(idx_rollouts)

In [None]:
from ase import Atoms
from ase.geometry.geometry import get_distances
import shutil, os
from ase.io import write

# For every selected dataset index: run the model once, wrap the predicted and
# reference frames as ASE Atoms objects, sanity-check them, and export both
# trajectories to <args.out_dir>/rollout_<i>/ as multi-frame XYZ files.

# Cross-rollout accumulators; this loop never appends to them.
all_rollout_atoms_ref_0 = []
all_rollout_atoms = []
all_rollout_atoms_ref = []
start = time.time()

for i_rollout, idx in enumerate(idx_rollouts):
    item = dataset[idx]
    # Collate the single item through a DataLoader so it gets the same batch
    # layout the model receives elsewhere in this notebook.
    batch = next(iter(torch.utils.data.DataLoader([item])))

    # Only a genuinely missing key is tolerated; other failures must surface
    # rather than being reported as "not found".
    for key in ['species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask', 'species_next', 'x_next', "TKS_mask", "TKS_v_mask"]:
        try:
            batch[key] = batch[key].to(device)
        except KeyError:
            print(f"{key} not found")

    # Species index per (frame, atom), converted to element symbols with one
    # device-to-host transfer instead of one per atom.
    labels = torch.argmax(batch["species"], dim=3).squeeze(0)
    symbols = [[map_to_chemical_symbol[i_elem] for i_elem in frame_labels] for frame_labels in labels.cpu().tolist()]

    pred_pos, _ = rollout(model, batch)

    # The same cell is attached to every frame of this item.
    cell = batch['cell'][0][0].cpu().numpy()

    all_atoms = []
    all_atoms_ref = []
    all_atoms_ref_0 = []
    for t in range(len(pred_pos[0])):
        print("rollout", i_rollout, "idx = ", idx, "t", t)
        # Pass the symbol list directly rather than a concatenated formula
        # string, so no re-parsing of the joined text is involved.
        frame_symbols = symbols[t]

        atoms = Atoms(symbols=frame_symbols, positions=pred_pos[0][t].cpu().numpy(), cell=cell, pbc=[1,1,1])
        all_atoms.append(atoms)

        if args.sim_condition:
            # With sim conditioning, "x" is the starting frame and "x_next" the target.
            atoms_ref_0 = Atoms(symbols=frame_symbols, positions=batch["x"][0][t].cpu().numpy(), cell=cell, pbc=[1,1,1])
            all_atoms_ref_0.append(atoms_ref_0)
            ref_positions = batch["x_next"][0][t]
        else:
            ref_positions = batch["x"][0][t]
        atoms_ref = Atoms(symbols=frame_symbols, positions=ref_positions.cpu().numpy(), cell=cell, pbc=[1,1,1])
        all_atoms_ref.append(atoms_ref)

        if args.tps_condition:
            # These checks encode the expectation that only frame 1 is
            # generated while every other frame is returned unchanged.
            if t == 1:
                err = pred_pos[0][t]-batch["x"][0][t]
                print(torch.abs(err).max(), torch.abs(err).min(), torch.abs(err).mean(), )
                assert not torch.allclose(pred_pos[0][t], batch["x"][0][t]), "frame 1 is identical to the reference"
            else:
                assert torch.allclose(pred_pos[0][t], batch["x"][0][t]), f"frame {t} differs from the reference"

    out_dir = args.out_dir
    dirname = os.path.join(out_dir, f"rollout_{i_rollout}")
    os.makedirs(dirname, exist_ok=True)

    with open(os.path.join(dirname, "README.md"), "w") as fp:
        fp.write("Data index from Transition1x: %d"%idx)
    filename = os.path.join(dirname, "gentraj_1.xyz")
    filename_ref = os.path.join(dirname, "reftraj_1.xyz")
    print(filename_ref)

    assert not np.allclose(all_atoms[1].positions, all_atoms_ref[1].positions)

    # Every exported frame gets a 25 x 25 x 25 cubic cell.
    for atoms in all_atoms + all_atoms_ref:
        atoms.set_cell(np.eye(3,3)*25)

    # Write each trajectory in a single call. This overwrites any file left by
    # a previous run, so the two outputs no longer depend on each other's
    # existence and stale frames can never be appended to.
    write(filename, all_atoms)
    write(filename_ref, all_atoms_ref)


In [None]:
17.4/287

In [None]:

# Disabled per-atom error report, kept as an inert string literal (nothing in
# it executes). It indexes all_rollout_atoms / all_rollout_atoms_ref, which the
# generation loop never appends to, and assumes at least 10 rollouts; both need
# addressing before this is re-enabled.
'''
for i_rollout in range(10):
    print("rollout", i_rollout, pred_pos.shape)
    all_atoms = all_rollout_atoms[i_rollout]
    all_atoms_ref = all_rollout_atoms_ref[i_rollout]
    for t in range(len(pred_pos[0])):
        print("t=",t)
        atoms = all_atoms[t]
        atoms_ref = all_atoms_ref[t]
        for i in range(atoms.positions.shape[0]):
            err = get_distances(atoms_ref.positions[i], atoms.positions[i], cell=atoms.cell, pbc=True)[1][0][0]

            if err>0.1:
                print(atoms.positions[i], atoms_ref.positions[i], err, err>0.1)
        
'''