In [12]:
import io
import pickle

import numpy as np
import py3Dmol as p3d
import torch
from ase.build.rotate import minimize_rotation_and_translation as transf
from ase.io import read

In [65]:
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-step noise variances (betas) of a diffusion process.

    Args:
        beta_schedule: schedule name -- "quad", "linear", "const", "jsd" or
            "sigmoid".
        beta_start: value the schedule starts from ("quad", "linear",
            "sigmoid"); unused by "const" and "jsd".
        beta_end: value the schedule ends at; "const" repeats it for every
            step and "jsd" ignores it.
        num_diffusion_timesteps: number of steps T.

    Returns:
        1-D numpy array of length T.

    Raises:
        NotImplementedError: for an unrecognised ``beta_schedule``.
    """
    steps = num_diffusion_timesteps

    def _build():
        if beta_schedule == "quad":
            # evenly spaced in sqrt(beta), then squared
            sqrt_grid = np.linspace(
                beta_start**0.5, beta_end**0.5, steps, dtype=np.float64
            )
            return sqrt_grid**2
        if beta_schedule == "linear":
            return np.linspace(beta_start, beta_end, steps, dtype=np.float64)
        if beta_schedule == "const":
            return beta_end * np.ones(steps, dtype=np.float64)
        if beta_schedule == "jsd":
            # 1/T, 1/(T-1), ..., 1/2, 1
            return 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)
        if beta_schedule == "sigmoid":
            # logistic curve sampled over [-6, 6], rescaled to lie between
            # beta_start and beta_end
            logits = np.linspace(-6, 6, steps)
            return 1 / (np.exp(-logits) + 1) * (beta_end - beta_start) + beta_start
        raise NotImplementedError(beta_schedule)

    betas = _build()
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

# Atomic number -> element symbol for Z = 1..36 (H through Kr).
_ELEMENT_SYMBOLS = {
    z: sym
    for z, sym in enumerate(
        (
            "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
            "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
            "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
            "Ga", "Ge", "As", "Se", "Br", "Kr",
        ),
        start=1,
    )
}


def make_xyz(symbols, pos_array):
    """Render one molecular geometry as an XYZ-format string.

    Layout: the atom count, an empty comment line, then one
    ``<symbol> <v0> <v1> ...`` line per atom. No trailing newline is
    written, so callers concatenating several frames must add a separator.

    Args:
        symbols: per-atom element symbols (str) or atomic numbers (int-like).
        pos_array: per-atom coordinate rows; every value of a row is written
            in order (presumably shape (n_atoms, 3) -- TODO confirm).

    Returns:
        The XYZ text.

    Raises:
        ValueError: if ``symbols`` and ``pos_array`` differ in length. The
            header count would otherwise disagree with the atom lines (or
            atoms would be silently dropped by ``zip``).
        KeyError: if an atomic number has no entry in the lookup table.
    """
    symbols = list(symbols)  # materialise so iterators/generators still work
    if len(symbols) != len(pos_array):
        raise ValueError(
            f"got {len(symbols)} symbols but {len(pos_array)} coordinate rows"
        )

    lines = [f"{len(pos_array)}", ""]  # atom count + empty comment line
    for sym, xyz in zip(symbols, pos_array):
        if not isinstance(sym, str):
            try:
                sym = _ELEMENT_SYMBOLS[sym]
            except KeyError:
                raise KeyError(f"no element symbol for atomic number {sym!r}") from None
        # f"{v}" (not str(v)) keeps the exact number formatting used before
        lines.append(" ".join([f"{sym}", *(f"{v}" for v in xyz)]))
    return "\n".join(lines)

In [7]:
# Per-step betas, then the running product of (1 - beta) over steps,
# reversed so that index 0 holds the smallest cumulative product (last step).
betas = get_beta_schedule(
    beta_schedule="sigmoid",
    beta_start=1e-07,
    beta_end=2e-03,
    num_diffusion_timesteps=5000,
)
alphas = np.cumprod(1.0 - betas)[::-1]

In [8]:
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
with open("test_save.pkl", "rb") as pkl_file:
    data = pickle.load(pkl_file)

In [67]:
first_sample = data[0]
symbol = first_sample["atom_type"].numpy()
ts_pos = first_sample["pos"].numpy()
# keep only the first len(symbol) atoms of every stored frame
n_atoms = len(symbol)
traj_pos = first_sample["pos_gen"].numpy()[:, :n_atoms, :]

In [68]:
# Ground-truth TS geometry, and the final frame of the generated trajectory.
xyz_ts = make_xyz(symbol, ts_pos)
xyz_gen = make_xyz(symbol, traj_pos[-1])

In [69]:
# Multiply each frame by sqrt of its matching entry from the tail of `alphas`
# (presumably mapping stored frames back to the data scale -- TODO confirm).
n_frames = len(traj_pos)
if n_frames > len(alphas):
    raise ValueError(f"trajectory has {n_frames} frames but only {len(alphas)} alphas")
# Explicit start index instead of alphas[-n_frames:], which would select the
# WHOLE array when n_frames == 0.
frame_scale = np.sqrt(alphas[len(alphas) - n_frames:]).reshape(-1, 1, 1)
traj = traj_pos * frame_scale

every = 1
start_from = 0
# One XYZ block per kept frame, each followed by a newline separator.
xyz_traj = "".join(make_xyz(symbol, pos) + "\n" for pos in traj[start_from::every])

In [70]:
# Two linked side-by-side panels: xyz_ts on the left, xyz_gen on the right.
panel_models = (xyz_ts, xyz_gen)
view = p3d.view(viewergrid=(1, 2), width=800, height=400, linked=True)
for col, model in enumerate(panel_models):
    view.addModel(model, "xyz", viewer=(0, col))
view.setStyle({"stick": {}})
view.show()


In [72]:
# Animate the trajectory frames in xyz_traj (one model per frame, looping forward).
viewer = p3d.view(width=1000, height=500)
viewer.addModelsAsFrames(xyz_traj)
viewer.animate({"loop": "forward"})
viewer.setStyle({"stick": {}})
for angle, axis in ((-60, "x"), (40, "y")):
    viewer.rotate(angle, axis)
viewer.show()

In [63]:
# Load a saved trajectory plus the r/p/TS geometries. The context manager
# closes the NpzFile handle; each array is read out before it closes.
with np.load("traj_1_372_62_1.npz") as traj_data:
    symbol = traj_data["atom_type"]
    r_pos = traj_data["r_pos"]
    p_pos = traj_data["p_pos"]
    ts_pos = traj_data["pos"]
    # keep only the first len(symbol) atoms of every generated frame
    traj_pos = traj_data["pos_gen"][:, :len(symbol), :]

# Scale frames by sqrt of the LAST n_frames alphas -- same alignment as the
# earlier trajectory cell, and identical to the old behaviour when
# n_frames == len(alphas).
n_frames = len(traj_pos)
if n_frames > len(alphas):
    raise ValueError(f"trajectory has {n_frames} frames but only {len(alphas)} alphas")
traj = traj_pos * np.sqrt(alphas[len(alphas) - n_frames:]).reshape(-1, 1, 1)

every = 20
start_from = 0
# Keep the 1st-based every-th frame of traj[start_from:] (indices every-1,
# 2*every-1, ...); includes the final frame when the count divides evenly.
xyz_traj = "".join(
    make_xyz(symbol, pos) + "\n" for pos in traj[start_from:][every - 1::every]
)

xyz_gen = make_xyz(symbol, traj[-1])
xyz_r = make_xyz(symbol, r_pos)
xyz_p = make_xyz(symbol, p_pos)
xyz_ts = make_xyz(symbol, ts_pos)

# Parse the XYZ text with ASE in memory (no scratch tmp.xyz in the cwd), then
# rigidly align the generated geometry onto the ground-truth TS.
gt = read(io.StringIO(xyz_ts), format="extxyz")
gen = read(io.StringIO(xyz_gen), format="extxyz")
transf(gt, gen)  # updates `gen`'s positions in place
xyz_gen = make_xyz(symbol, gen.get_positions())

In [43]:
# Four linked panels, left to right: xyz_r, xyz_ts, xyz_gen, xyz_p.
panel_models = (xyz_r, xyz_ts, xyz_gen, xyz_p)
view = p3d.view(viewergrid=(1, len(panel_models)), width=800, height=400, linked=True)
for col, model in enumerate(panel_models):
    view.addModel(model, "xyz", viewer=(0, col))
view.setStyle({"stick": {}})
view.show()


In [44]:
# Animate the subsampled trajectory frames in xyz_traj (looping forward).
viewer = p3d.view(width=1000, height=500)
viewer.addModelsAsFrames(xyz_traj)
viewer.animate({"loop": "forward"})
viewer.setStyle({"stick": {}})
for angle, axis in ((-60, "x"), (40, "y")):
    viewer.rotate(angle, axis)
viewer.show()
