In [None]:
from argparse import Namespace
# Run configuration, kept in an argparse-style namespace so the notebook can
# share code paths with a command-line script.
args = Namespace(**{
    # checkpoint of the trained model to restore
    "sim_ckpt": "workdir/alchem_loss_regress-path_linear/epoch=9-step=3350.ckpt",
    # directory handed to the dataset as `traj_dirname`
    "data_dir": "data/alchem_CrCoNi_data",
    "suffix": "",
    # frames per dataset item
    "num_frames": 1,
    # number of successive model applications in the trajectory cell
    "num_rollouts": 100,
    # every output file is written below this directory
    "out_dir": "./test/alchem_loss_regress-path_linear-forwardpred_1teps",
})

In [None]:
import os, torch, tqdm, time
import numpy as np
from mdgen.equivariant_wrapper import EquivariantMDGenWrapper

In [None]:
# Create the output directory (and any missing parents); a no-op if it already exists.
os.makedirs(args.out_dir, exist_ok=True)

In [None]:
from mdgen.dataset import EquivariantTransformerDataset_CrCoNi

# Dataset the starting configurations are drawn from.
# NOTE(review): this evaluation notebook reads the "train" stage — confirm that is intended.
dataset = EquivariantTransformerDataset_CrCoNi(
    traj_dirname=args.data_dir,
    cutoff=2.5,
    num_frames=args.num_frames,
    stage="train",
)


In [None]:
# Restore the trained model from the checkpoint.
# SECURITY: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints from a trusted source.
# map_location="cpu" keeps the restored tensors off the GPU regardless of the
# device they were saved from; the model is moved to the GPU explicitly below.
ckpt = torch.load(args.sim_ckpt, map_location="cpu", weights_only=False)
model = EquivariantMDGenWrapper(**ckpt["hyper_parameters"])
model.load_state_dict(ckpt["state_dict"])
# Inference mode (disables dropout / batch-norm updates) on the GPU.
model.eval().to('cuda')

In [None]:
@torch.no_grad()
def rollout(model, batch):
    """Apply the model once and build the input for the next application.

    Args:
        model: object exposing ``inference(batch)`` that returns a pair whose
            first element is the prediction.
        batch: mapping of model inputs; it is not modified.

    Returns:
        ``(prediction, next_batch)`` where ``next_batch`` is a shallow copy of
        ``batch`` with its ``'species'`` entry replaced by the prediction.
    """
    prediction, _ = model.inference(batch)
    # Shallow copy: every other entry still refers to the same objects as `batch`.
    next_batch = {**batch, 'species': prediction}
    return prediction, next_batch

In [None]:

# Class index (argmax over the species axis) -> chemical symbol.
map_to_chemical_symbol = dict(enumerate(("Cr", "Co", "Ni")))

In [None]:
from ase import Atoms
from ase.geometry.geometry import get_distances

# Predict species for randomly drawn starting configurations and collect, for
# every draw, one list of generated structures and one list of reference
# structures (one entry per frame).
all_rollout_atoms = []
all_rollout_atoms_ref = []
start = time.time()
for i_rollout in range(10):
    idx = np.random.randint(0, len(dataset), 1)[0]
    # __getitem__ is called directly because it takes an extra keyword argument.
    item = dataset.__getitem__(idx, random_starting_point=False)
    # Wrapping the item in a DataLoader collates it into a batch of size 1.
    batch = next(iter(torch.utils.data.DataLoader([item])))

    for key in ['species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask']:
        batch[key] = batch[key].to('cuda')

    pred_s, _ = rollout(model, batch)
    labels = torch.argmax(pred_s, dim=-1).squeeze(0)
    # One device-to-host copy for all labels instead of one sync per atom.
    symbols = [
        [map_to_chemical_symbol[label] for label in frame_labels]
        for frame_labels in labels.cpu().tolist()
    ]

    print("idx = ", idx, "rollout", i_rollout, pred_s.shape)

    # Loop-invariant host copies of the coordinates and the cell.
    frames_x = batch["x"][0].cpu().numpy()
    cell_np = batch['cell'][0][0].cpu().numpy()

    all_atoms = []
    all_atoms_ref = []
    for t in range(len(pred_s[0])):
        # Generated structure: predicted symbols on the frame-0 coordinates.
        atoms = Atoms(symbols[t], positions=frames_x[0], cell=cell_np, pbc=[1,1,1])
        all_atoms.append(atoms)
        # Reference structure: frame-t coordinates.
        # NOTE(review): the reference also carries the *predicted* symbols — confirm
        # whether ground-truth species were intended here.
        atoms_ref = Atoms(symbols[t], positions=frames_x[t], cell=cell_np, pbc=[1,1,1])
        all_atoms_ref.append(atoms_ref)
    all_rollout_atoms.append(all_atoms)
    all_rollout_atoms_ref.append(all_atoms_ref)


In [None]:


# Compare every generated structure with its reference, atom by atom, and print
# the atoms whose periodic (minimum-image) displacement exceeds the threshold.
# The loop bounds come from the collected lists themselves, so this cell does
# not depend on variables defined in later cells.
for i_rollout, (all_atoms, all_atoms_ref) in enumerate(zip(all_rollout_atoms, all_rollout_atoms_ref)):
    print("rollout", i_rollout, len(all_atoms))
    for t, (atoms, atoms_ref) in enumerate(zip(all_atoms, all_atoms_ref)):
        print("t=",t)
        for i in range(atoms.positions.shape[0]):
            # get_distances returns (vectors, lengths); [0][0][0] is the displacement
            # vector of the single reference/generated pair.
            disp = get_distances(atoms_ref.positions[i], atoms.positions[i], cell=atoms.cell, pbc=True)[0][0][0]
            # Use the magnitude of the largest component so that negative
            # displacements are flagged as well.
            err = np.abs(disp).max()
            if err > 0.5:
                print(atoms.positions[i], atoms_ref.positions[i], err, err > 0.5)
        


In [None]:
# Generate trajectory: start from one dataset item and apply the model
# args.num_rollouts times, feeding each prediction back in as the next input.
# idx = np.random.randint(0, len(dataset), 1)[0]
idx = 0
item = dataset.__getitem__(idx, random_starting_point=False)
batch = next(iter(torch.utils.data.DataLoader([item])))

# Move the model inputs to the GPU once. rollout() hands back a shallow copy of
# the batch, so these tensors stay on the GPU for all later steps.
for key in ['species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask']:
    batch[key] = batch[key].to('cuda')

traj_rollout_atoms = []
start = time.time()
for i_rollout in range(args.num_rollouts):
    # Species of the current state: the dataset species at step 0, the previous
    # prediction afterwards. One host copy for all labels.
    labels = torch.argmax(batch["species"], dim=3).squeeze(0)
    symbols = [
        [map_to_chemical_symbol[label] for label in frame_labels]
        for frame_labels in labels.cpu().tolist()
    ]

    # `pred_pos` holds the *species* prediction that rollout() stores under
    # batch['species']; the name is kept for compatibility with existing notebook state.
    pred_pos, next_batch = rollout(model, batch)
    print("idx = ", idx, "rollout", i_rollout, pred_pos.shape)

    # Coordinates come from the batch, not from the species prediction.
    # Frame-0 coordinates are used, as for the generated structures in the
    # random-sample cell.
    # NOTE(review): confirm frame 0 (rather than frame t) is intended when num_frames > 1.
    positions_np = batch["x"][0][0].cpu().numpy()
    cell_np = batch['cell'][0][0].cpu().numpy()

    all_atoms = []
    for t in range(len(pred_pos[0])):
        atoms = Atoms(symbols[t], positions=positions_np, cell=cell_np, pbc=[1,1,1])
        all_atoms.append(atoms)
    traj_rollout_atoms.append(all_atoms)
    batch = next_batch

In [None]:
import shutil
from ase.io import write

def _backup_if_exists(path, backup_path):
    """Move `path` to `backup_path` when `path` exists; otherwise do nothing."""
    if os.path.exists(path):
        shutil.move(path, backup_path)


# One directory per random-sample rollout, holding the generated and the
# reference frames as multi-frame xyz files.
for i, (gen_frames, ref_frames) in enumerate(zip(all_rollout_atoms, all_rollout_atoms_ref)):
    dirname = os.path.join(args.out_dir, f"rollout_{i}")
    os.makedirs(dirname, exist_ok=True)
    filename = os.path.join(dirname, "gentraj_fromstart.xyz")
    filename_ref = os.path.join(dirname, "reftraj_fromstart.xyz")
    # Frames are appended below, so files from a previous run are moved aside
    # first. Each file is handled on its own: a missing reference file no longer
    # raises, and a stale one is never appended to.
    _backup_if_exists(filename, os.path.join(dirname, "bck.0.gentraj.xyz"))
    _backup_if_exists(filename_ref, os.path.join(dirname, "bck.0.reftraj.xyz"))
    for atoms in gen_frames:
        write(filename, atoms, append=True)
    for ref_atoms in ref_frames:
        write(filename_ref, ref_atoms, append=True)


# The sequential trajectory goes into a single file at the top of the output directory.
filename = os.path.join(args.out_dir, "gentraj_fromstart.xyz")
_backup_if_exists(filename, os.path.join(args.out_dir, "bck.0.gentraj.xyz"))
# Each entry of traj_rollout_atoms is the list of frames produced by one step.
for atoms in traj_rollout_atoms:
    write(filename, atoms, append=True)
