In [None]:
from argparse import Namespace
# Run configuration for this notebook.  Values that were used for other
# experiments are kept next to the active ones as comments.
_run_config = {
    # Checkpoint to load.
    # alt: "workdir/mixedtrain_simcond_12actspacemask-rcut3.5-path_linearOT/epoch=18099-step=4341640.ckpt"
    "sim_ckpt": "workdir/rcut3.5_energy_encodedim1_perturbeddata/epoch=6029-step=2020050.ckpt",
    # Input trajectories.
    "data_dir": "data/CrCoNi_data/dataset-perturbed",
    "suffix": "",
    "num_rollouts": 1,
    # Where encodings / trajectories are written.
    # alt: "./test/mixedtrain_simcond_12actspacemask-rcut3.5-path_linearOT"
    "out_dir": "./test/rcut3.5_energy_encodedim1_perturbeddata/encoded_dataset-perturbed",
    # Dataset options.
    # alt: num_frames=1, random_starting_point=True, localmask=True, sim_condition=True
    "num_frames": 20,
    "random_starting_point": False,
    "localmask": False,
    "sim_condition": False,
}
args = Namespace(**_run_config)

# Device string used for the model and the batches.
device = "cuda"

In [None]:
import os, torch, tqdm, time
import numpy as np
from mdgen.equivariant_wrapper import EquivariantMDGenWrapper

In [None]:
# Create the output directory up front; exist_ok=True makes re-running this cell harmless.
os.makedirs(args.out_dir, exist_ok=True)

In [None]:
from mdgen.dataset import EquivariantTransformerDataset_CrCoNi

# Dataset built from the trajectories in args.data_dir with stage="val".
# The frame / masking / conditioning switches all come from the run configuration.
_dataset_options = dict(
    traj_dirname=args.data_dir,
    cutoff=3.5,
    num_frames=args.num_frames,
    random_starting_point=args.random_starting_point,
    localmask=args.localmask,
    sim_condition=args.sim_condition,
    stage="val",
)
dataset = EquivariantTransformerDataset_CrCoNi(**_dataset_options)


In [None]:
# Show the first dataset item to see which fields an item carries.
print(dataset[0])

In [None]:
# Load the training checkpoint and rebuild the model from its saved hyper-parameters.
#
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoints that come from a trusted source.
# map_location="cpu" keeps the load independent of the device the checkpoint was
# saved from; the model is moved to `device` explicitly at the end of the cell.
ckpt = torch.load(args.sim_ckpt, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]
model = EquivariantMDGenWrapper(**hparams)
print(model.model)

# strict=False tolerates mismatching keys, so report them instead of silently
# running with partially initialised weights.
_load_result = model.load_state_dict(ckpt["state_dict"], strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print("WARNING: state_dict mismatch")
    print("  missing keys   :", list(_load_result.missing_keys))
    print("  unexpected keys:", list(_load_result.unexpected_keys))

model.eval().to(device)

In [None]:
# Hyper-parameters stored in the checkpoint, followed by the number of dataset items.
_saved_hparams = ckpt["hyper_parameters"]
_n_items = len(dataset)
print(_saved_hparams)
print(_n_items)

In [None]:
# embed_dim as recorded in the checkpoint's saved args; used further down to
# reshape encoded_h / encoded_v before they are written to disk.
embed_dim = ckpt["hyper_parameters"]['args'].embed_dim

In [None]:
# num_heads hyper-parameter from the checkpoint's saved args (printed for reference only).
print(ckpt["hyper_parameters"]['args'].num_heads)

In [None]:
# Shuffled single-item loader over the dataset, loaded in the main process (num_workers=0).
batch_size = 1
_loader_options = dict(batch_size=batch_size, num_workers=0, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset, **_loader_options)
# Pull one batch from the shuffled loader; it is only inspected in the next cells.
sample_batch = next(iter(val_loader))


In [None]:
# Field names present in a collated batch.
print(sample_batch.keys())

In [None]:
# Shape of the "x" field of item 499.
# NOTE(review): the index is hard-coded — presumably this fails when the dataset
# has fewer than 500 items; confirm against len(dataset).
print(dataset[499]["x"].shape)

In [None]:
# Raising here stops a "Run All" at this point; the sections below only execute
# when their cells are run explicitly.  NOTE(review): appears to be a deliberate barrier.
raise RuntimeError

## Test trained encoder

In [None]:
# Number of items the encoding loop below will iterate over.
print(len(dataset))

In [None]:
# Encode every dataset item with the trained processor and save the per-atom
# scalar (encoded_h) and vector (encoded_v) features to args.out_dir.
#
# After the loop, `sample_batch` (noised), `sample_batch_clean`, `prep_clean`,
# `t`, `encoded_h` and `encoded_v` hold the values of the LAST item; the
# analysis cells below rely on them.
batch_size = 1  # items are encoded one at a time

# Fields that are passed through untouched (not batched, not moved to the device).
_PASSTHROUGH_KEYS = ("name", "idx")

for idx in range(len(dataset)):
    # Fetch the item ONCE.  Calling dataset[idx] twice could yield two different
    # samples (random starting point) or the very same dict object (cached
    # dataset), either of which would corrupt the clean/perturbed pairing.
    item = dataset[idx]

    # Clean batch: add the leading batch axis and move tensors to the device.
    sample_batch_clean = {
        k: (v if k in _PASSTHROUGH_KEYS else v.unsqueeze(0).to(device))
        for k, v in item.items()
    }
    # Perturbed batch: independent tensor copies so in-place edits cannot leak
    # into the clean batch.
    sample_batch = {
        k: (v if k in _PASSTHROUGH_KEYS else v.clone())
        for k, v in sample_batch_clean.items()
    }
    # Gaussian noise with standard deviation 0.5 on the positions.
    sample_batch["x"] += 0.5 * torch.randn_like(sample_batch["x"])

    idx_dataset = sample_batch["idx"]
    model.stage = "inference"
    prep_clean = model.prep_batch(sample_batch_clean)
    t = torch.ones(batch_size).to(model.device)
    encoded_h, encoded_v = model.model.forward_processor(
        prep_clean["latents"], t, **prep_clean["model_kwargs"]
    )

    B, T, N, _ = prep_clean["latents"].shape
    assert encoded_h.shape[0] == B * T * N

    # Detach before saving so the files hold plain tensors instead of tensors
    # that still require grad.
    encoded_h_out = encoded_h.detach().reshape(B, T, -1, embed_dim)
    encoded_v_out = encoded_v.detach().reshape(B, T, -1, embed_dim, 3)
    print(idx, idx_dataset, B, T, N, encoded_h_out.shape, sample_batch['x'].shape)
    torch.save(encoded_h_out, os.path.join(args.out_dir, f"encoded_h-{idx_dataset}.pt"))
    torch.save(encoded_v_out, os.path.join(args.out_dir, f"encoded_v-{idx_dataset}.pt"))

In [None]:

# Encode the noised copy of the last item with the same processor call that was
# used for the clean one.
prep = model.prep_batch(sample_batch)
t = torch.ones(batch_size).to(model.device)
_perturbed_encoding = model.model.forward_processor(
    prep["latents"],
    t,
    **prep["model_kwargs"],
)
encoded_h_perturbed, encoded_v_perturbed = _perturbed_encoding


In [None]:
# Model output for the clean structure, summed over dim 2 and printed next to
# the reference value stored under prep_clean["E"].
energy = model.model(prep_clean["latents"], t, **prep_clean["model_kwargs"])
_energy_summed = energy.sum(dim=2)
_energy_ref = prep_clean["E"]
print(energy.shape)
print(_energy_summed.shape, _energy_ref.shape)
print(_energy_summed, _energy_ref)

In [None]:
# Field names of the (noised) batch left over from the encoding loop.
print(sample_batch.keys())

In [None]:
# Same comparison for the noised structure; the batch's "e_mace" field is printed as well.
energy_perturbed = model.model(prep["latents"], t, **prep["model_kwargs"])
_energy_perturbed_summed = energy_perturbed.sum(dim=2)
print(_energy_perturbed_summed, prep["E"], sample_batch["e_mace"])

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Parity plot of the first encoded_h component: clean ("Crystalline") structure
# on x, noised ("Perturbed") structure on y, coloured by species index.
#
# The PCA projection is currently disabled, so the raw encoder output is used:
# pca_2d = PCA(n_components=1)
# data_2d = pca_2d.fit_transform(encoded_h.detach().cpu().numpy())
# data_2d_perturbed = pca_2d.transform(encoded_h_perturbed.detach().cpu().numpy())
data_2d = encoded_h.detach().cpu().numpy()
data_2d_perturbed = encoded_h_perturbed.detach().cpu().numpy()
import matplotlib.pyplot as plt

_species_index = prep["species"].squeeze(0).squeeze(0).argmax(dim=-1).cpu().numpy()

plt.figure(figsize=(4, 3.5))
plt.scatter(data_2d[:, 0], data_2d_perturbed[:, 0], c=_species_index, cmap='viridis', s=1)
cbar = plt.colorbar()
cbar.set_label('species')

# Identity line y = x.  It must use the SAME range on both axes; joining
# (xmin, ymin) to (xmax, ymax) is only the identity when both ranges coincide.
_lo = min(plt.xlim()[0], plt.ylim()[0])
_hi = max(plt.xlim()[1], plt.ylim()[1])
plt.plot([_lo, _hi], [_lo, _hi], c="k", ls="--")

# No legend: none of the artists carries a label, so plt.legend() would only
# emit a warning.  The title no longer mentions PCA because none is applied.
plt.title("Component 0 of encoded_h")
plt.xlabel("Crystalline")
plt.ylabel("Perturbed")

In [None]:

# First two raw components of encoded_h for the clean and the noised structure.
plt.figure(figsize=(5, 3.5))
for _encoding, _label in ((encoded_h, "Crystalline"), (encoded_h_perturbed, "Perturbed")):
    _values = _encoding.detach().cpu().numpy()
    plt.scatter(_values[:, 0], _values[:, 1], s=1, label=_label)
# The two point clouds are distinguished by the legend.  A colorbar labelled
# "species" would be meaningless here because the scatters have no colour
# mapping, so none is drawn.
plt.legend()
plt.title("Components of encoded_h")
plt.xlabel("Component 1")
plt.ylabel("Component 2")

In [None]:
# Disabled cell (kept as a bare string literal, so nothing in it executes).
# When enabled it builds `edge_index` with get_neighborhood (cutoff 3.5, periodic
# in all directions) and derives the per-edge species arrays
# `center_atomic_numbers` / `neigh_atomic_numbers`.
'''
from neighborhood import get_neighborhood
edge_index, shifts, _, cell = get_neighborhood(prep["latents"].squeeze(0).squeeze(0).cpu().numpy(), cutoff=3.5, pbc=[True, True, True], cell=prep["model_kwargs"]["cell"].squeeze(0).squeeze(0).cpu().numpy())
atomic_numbers = prep["species"].squeeze(0).squeeze(0).argmax(dim=-1).cpu().numpy()
neigh_atomic_numbers = atomic_numbers[edge_index[1]]
center_atomic_numbers = atomic_numbers[edge_index[0]]
print("center_atomic_numbers=", center_atomic_numbers)
print("neigh_atomic_numbers=", neigh_atomic_numbers)
print(center_atomic_numbers.shape, neigh_atomic_numbers.shape)
'''

In [None]:
def warren_cowley_sro(center_type, neigh_type, center_pool, neigh_pool):
    """Warren-Cowley short-range-order parameter for one species pair.

    Parameters
    ----------
    center_type, neigh_type
        Species labels of the central and the neighbouring atom of interest.
    center_pool, neigh_pool : array-like
        Species label of the central / neighbouring atom of every pair; entry
        ``k`` of both sequences describes the same pair.  Lists are accepted.

    Returns
    -------
    alpha : float
        ``1 - P_cluster / x_neigh_type``.
    P_cluster : float
        Fraction of the pairs centred on ``center_type`` whose neighbour is
        ``neigh_type``.
    x_neigh_type : float
        Fraction of ``center_pool`` entries equal to ``neigh_type``.  When it is
        zero, ``alpha`` is nan or -inf (numpy floating-point division).

    Raises
    ------
    ZeroDivisionError
        If no pair is centred on ``center_type`` (``P_cluster`` is undefined).
    """
    center_pool = np.asarray(center_pool)
    neigh_pool = np.asarray(neigh_pool)

    # Pairs whose central atom has the requested species.
    is_center = center_pool == center_type
    n_selected = np.count_nonzero(is_center)
    if n_selected == 0:
        raise ZeroDivisionError(
            f"no pair is centred on species {center_type!r}; P_cluster is undefined"
        )

    P_cluster = np.count_nonzero(neigh_pool[is_center] == neigh_type) / n_selected
    # np.float64 keeps numpy's nan/inf semantics when the concentration is zero.
    x_neigh_type = np.float64(np.count_nonzero(center_pool == neigh_type)) / len(center_pool)
    alpha = 1 - P_cluster / x_neigh_type
    return alpha, P_cluster, x_neigh_type

In [None]:
def warren_cowley_sro_per_atom(idx_atom, _target_type, edge_pool, neigh_type_pool):
    """Per-atom order measure: one minus the fraction of neighbours of a given species.

    Parameters
    ----------
    idx_atom
        Index of the central atom.
    _target_type
        Species label that is counted among the neighbours.
    edge_pool : array-like
        Edge list whose row 0 holds the central-atom index of every edge.
    neigh_type_pool : array-like
        Species label of the neighbouring atom of every edge, in the same order
        as the edges of ``edge_pool``.

    Returns
    -------
    float
        ``1 - P`` where ``P`` is the fraction of ``idx_atom``'s neighbours whose
        species equals ``_target_type``.  The value is not normalised by the
        species concentration.

    Raises
    ------
    ZeroDivisionError
        If ``idx_atom`` has no edge in ``edge_pool``.
    """
    # Use the arguments rather than notebook-level globals, so the function
    # works for any edge list that is passed in.
    edge_pool = np.asarray(edge_pool)
    neigh_type_pool = np.asarray(neigh_type_pool)

    owned_edges = edge_pool[0] == idx_atom
    n_neigh = np.count_nonzero(owned_edges)
    if n_neigh == 0:
        raise ZeroDivisionError(f"atom {idx_atom!r} has no neighbours in edge_pool")

    _P_cluster = np.count_nonzero(neigh_type_pool[owned_edges] == _target_type) / n_neigh
    _alpha = 1 - _P_cluster
    return _alpha

In [None]:
# Disabled cell (kept as a bare string literal, so nothing in it executes).
# When enabled it evaluates warren_cowley_sro_per_atom for every atom, using the
# atom's own species as the target type, and collects the values in `alpha_conf`.
'''
alpha_conf = []
for i in range(prep["species"].shape[-2]):
    alpha_i = warren_cowley_sro_per_atom(i, prep["species"].squeeze(0).squeeze(0).argmax(dim=-1).cpu().numpy()[i], edge_index, neigh_atomic_numbers)
    alpha_conf.append(alpha_i)
alpha_conf = np.array(alpha_conf)
print("alpha_conf=", alpha_conf)
'''

In [None]:
# 2-component PCA of encoded_v[..., 0]: fitted on the clean structure, then
# applied unchanged to the noised one so both share the same axes.
pca_2d_x = PCA(n_components=2)
data_2d_x = pca_2d_x.fit_transform(encoded_v[..., 0].detach().cpu().numpy())
data_2d_x_perturbed = pca_2d_x.transform(encoded_v_perturbed[..., 0].detach().cpu().numpy())

plt.figure(figsize=(4, 3.5))
plt.scatter(data_2d_x[:, 0], data_2d_x[:, 1], s=5, label="Crystalline")
plt.scatter(data_2d_x_perturbed[:, 0], data_2d_x_perturbed[:, 1], s=5, label="Perturbed")
# The point clouds are told apart by the legend.  No colorbar is drawn: the
# scatters carry no colour mapping, so a "species" colorbar would be meaningless.
plt.legend()
plt.title("PCA of encoded_x")
plt.xlabel("PC1")
plt.ylabel("PC2")

In [None]:
# Components 0 and 1 of encoded_v[..., 0], clean versus noised structure.
plt.figure(figsize=(3.5, 3.5))
for _encoding, _label in ((encoded_v, "Crystalline"), (encoded_v_perturbed, "Perturbed")):
    _first = _encoding[..., 0, 0].detach().cpu().numpy()
    _second = _encoding[..., 1, 0].detach().cpu().numpy()
    plt.scatter(_first, _second, s=5, label=_label)
plt.title("Components of encoded_x")
plt.legend()
plt.xlabel("Component 1")
plt.ylabel("Component 2")

In [None]:
# Parity plot of encoded_v[..., 0]: clean structure on x, noised structure on y,
# one scatter series per component.
plt.figure(figsize=(3.5, 3.5))
_vx_clean = encoded_v[..., 0].detach().cpu().numpy()
_vx_perturbed = encoded_v_perturbed[..., 0].detach().cpu().numpy()

# Show at most the first four components; fewer when the encoder is narrower,
# so the cell also runs for small embedding sizes.
for _i_comp in range(min(4, _vx_clean.shape[-1])):
    plt.scatter(_vx_clean[..., _i_comp], _vx_perturbed[..., _i_comp], s=5)

# Identity line y = x.  It must use the SAME range on both axes; joining
# (xmin, ymin) to (xmax, ymax) is only the identity when both ranges coincide.
_lo = min(plt.xlim()[0], plt.ylim()[0])
_hi = max(plt.xlim()[1], plt.ylim()[1])
plt.plot([_lo, _hi], [_lo, _hi], 'k--', lw=1)
plt.title("Encoded_x")
plt.xlabel("Encoded_x of crystalline structure")
plt.ylabel("Encoded_x of perturbed structure")

In [None]:
# Shape check: the latents of `prep` against the PCA-projected encoding, and the
# latents again after squeeze(0).
print(prep["latents"].shape, data_2d_x.shape)
print(prep["latents"].squeeze(0).shape)

In [None]:
# PCA of encoded_v[..., 0], coloured by the first coordinate of the latents.
# NOTE(review): the colours are taken from `prep` (built from the noised
# sample_batch) while encoded_v was computed from `prep_clean` — confirm that
# this mix is intended.
_encoded_vx = encoded_v[..., 0].detach().cpu().numpy()
pca_2d_x = PCA(n_components=2)
data_2d_x = pca_2d_x.fit_transform(_encoded_vx)

_coordinate_x = prep["latents"].squeeze(0).reshape(-1, 3).cpu().numpy()[:, 0]

plt.figure(figsize=(4, 3.5))
plt.scatter(data_2d_x[:, 0], data_2d_x[:, 1], s=5, c=_coordinate_x, cmap='viridis')
cbar = plt.colorbar()
cbar.set_label('x')
plt.title("PCA of encoded_x")
plt.xlabel("PC1")
plt.ylabel("PC2")

In [None]:
# PCA of encoded_v[..., 1], coloured by species index.
pca_2d_y = PCA(n_components=2)
data_2d_y = pca_2d_y.fit_transform(encoded_v[..., 1].detach().cpu().numpy())

_species_index = prep["species"].squeeze(0).squeeze(0).argmax(dim=-1).cpu().numpy()

plt.figure(figsize=(4, 3.5))
plt.scatter(data_2d_y[:, 0], data_2d_y[:, 1], s=5, c=_species_index, cmap='viridis')
cbar = plt.colorbar()
# The colours encode the species index, not the y coordinate, so the colorbar
# is labelled accordingly.
cbar.set_label('species')
plt.title("PCA of encoded_y")
plt.xlabel("PC1")
plt.ylabel("PC2")

In [None]:
# PCA of encoded_v[..., 1], coloured by the second coordinate of the latents.
# NOTE(review): the colours are taken from `prep` (built from the noised
# sample_batch) while encoded_v was computed from `prep_clean` — confirm that
# this mix is intended.
_encoded_vy = encoded_v[..., 1].detach().cpu().numpy()
pca_2d_y = PCA(n_components=2)
data_2d_y = pca_2d_y.fit_transform(_encoded_vy)

_coordinate_y = prep["latents"].squeeze(0).reshape(-1, 3).cpu().numpy()[:, 1]

plt.figure(figsize=(4, 3.5))
plt.scatter(data_2d_y[:, 0], data_2d_y[:, 1], s=5, c=_coordinate_y, cmap='viridis')
cbar = plt.colorbar()
cbar.set_label('y')
plt.title("PCA of encoded_y")
plt.xlabel("PC1")
plt.ylabel("PC2")

In [None]:
# Raising here stops a "Run All" at this point; the generative-model section
# below only executes when its cells are run explicitly.
# NOTE(review): appears to be a deliberate barrier.
raise RuntimeError

## Test generative model

In [None]:
# Move every tensor field that inference reads onto the model's device.
_inference_keys = [
    'species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask',
    'species_next', 'x_next', "TKS_mask", "TKS_v_mask",
]
for key in _inference_keys:
    sample_batch[key] = sample_batch[key].to(device)

# model.inference is unpacked as a pair (positions first) by rollout() in this
# notebook; unpack it the same way here so pred_pos is the positions tensor
# rather than the whole tuple.
pred_pos, _ = model.inference(sample_batch)
# prep = model.prep_batch(sample_batch)


In [None]:
@torch.no_grad()
def rollout(model, batch):
    """Run one inference step and prepare the batch for the following step.

    Gradients are disabled for the whole call.

    Parameters
    ----------
    model
        Object exposing ``inference(batch)`` that returns a pair whose first
        element is the predicted positions.
    batch : dict
        Input batch; it is not modified.

    Returns
    -------
    positions
        First element of ``model.inference(batch)``.
    new_batch : dict
        Shallow copy of ``batch`` whose ``'x'`` entry is replaced by ``positions``.
    """
    positions, _second_output = model.inference(batch)
    # Optional masking that was tried and left disabled:
    # mask_act_space = (batch["mask"] != 0)
    # positions = positions*mask_act_space
    return positions, {**batch, 'x': positions}

In [None]:

# Species index (argmax over the one-hot "species" field) -> chemical symbol.
map_to_chemical_symbol = dict(enumerate(("Cr", "Co", "Ni")))

In [None]:
from ase import Atoms
from ase.geometry.geometry import get_distances

# Run a single generation step from randomly chosen dataset items and keep, for
# every item, the generated frames and the reference ("x_next") frames as ASE
# Atoms objects.
all_rollout_atoms = []
all_rollout_atoms_ref = []
start = time.time()

_n_random_items = 10  # number of randomly drawn starting items
_rollout_keys = [
    'species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask',
    'species_next', 'x_next', 'TKS_mask', 'TKS_v_mask',
]

for i_rollout in range(_n_random_items):
    idx = np.random.randint(0, len(dataset), 1)[0]
    item = dataset[idx]
    # A one-element DataLoader is used only to collate the item into a batch.
    batch = next(iter(torch.utils.data.DataLoader([item])))

    for key in _rollout_keys:
        batch[key] = batch[key].to(device)

    # Species index per frame and atom.  Transfer the whole label tensor to the
    # host once instead of synchronising with the device for every atom.
    labels = torch.argmax(batch["species"], dim=3).squeeze(0)
    symbols = [
        [map_to_chemical_symbol[i_elem] for i_elem in frame_labels]
        for frame_labels in labels.cpu().tolist()
    ]

    pred_pos, _ = rollout(model, batch)
    print("idx = ", idx, "rollout", i_rollout, pred_pos.shape)

    # Loop-invariant host copies.
    cell = batch['cell'][0][0].cpu().numpy()
    pred_frames = pred_pos[0].cpu().numpy()
    ref_frames = batch["x_next"][0].cpu().numpy()

    all_atoms = []
    all_atoms_ref = []
    # `i_frame` (not `t`) so the time tensor `t` used by the encoder cells is
    # not overwritten by a loop index.
    for i_frame in range(len(pred_frames)):
        # Pass the per-atom symbol list directly; no need to build and re-parse
        # a concatenated formula string.
        atoms = Atoms(symbols=symbols[i_frame], positions=pred_frames[i_frame], cell=cell, pbc=[1, 1, 1])
        all_atoms.append(atoms)
        atoms_ref = Atoms(symbols=symbols[i_frame], positions=ref_frames[i_frame], cell=cell, pbc=[1, 1, 1])
        all_atoms_ref.append(atoms_ref)
    all_rollout_atoms.append(all_atoms)
    all_rollout_atoms_ref.append(all_atoms_ref)


In [None]:

# Disabled cell (kept as a bare string literal, so nothing in it executes).
# When enabled it walks over the stored rollouts and prints every atom whose
# periodic distance between generated and reference position exceeds 0.1.
'''
for i_rollout in range(10):
    print("rollout", i_rollout, pred_pos.shape)
    all_atoms = all_rollout_atoms[i_rollout]
    all_atoms_ref = all_rollout_atoms_ref[i_rollout]
    for t in range(len(pred_pos[0])):
        print("t=",t)
        atoms = all_atoms[t]
        atoms_ref = all_atoms_ref[t]
        for i in range(atoms.positions.shape[0]):
            err = get_distances(atoms_ref.positions[i], atoms.positions[i], cell=atoms.cell, pbc=True)[1][0][0]

            if err>0.1:
                print(atoms.positions[i], atoms_ref.positions[i], err, err>0.1)
        
'''

In [None]:
import shutil
from ase.io import write

# Write every stored rollout to <out_dir>/rollout_<i>/ as two multi-frame xyz
# files: the generated frames and the reference frames.
#
# Iterate over what was actually generated instead of assuming ten rollouts.
for i, (_gen_frames, _ref_frames) in enumerate(zip(all_rollout_atoms, all_rollout_atoms_ref)):
    dirname = os.path.join(args.out_dir, f"rollout_{i}")
    os.makedirs(dirname, exist_ok=True)
    filename = os.path.join(dirname, "gentraj_.xyz")
    filename_ref = os.path.join(dirname, "reftraj_.xyz")

    # The frames are appended below, so an output left over from a previous run
    # is moved aside first.  Each file is checked on its own: backing up the
    # reference file only when the generated file exists would either fail on a
    # missing file or silently append to a stale one.
    for _path, _backup_name in (
        (filename, "bck.0.gentraj.xyz"),
        (filename_ref, "bck.0.reftraj.xyz"),
    ):
        if os.path.exists(_path):
            shutil.move(_path, os.path.join(dirname, _backup_name))

    for atoms in _gen_frames:
        write(filename, atoms, append=True)

    print(i, filename_ref)  # once per rollout rather than once per frame
    for ref_atoms in _ref_frames:
        write(filename_ref, ref_atoms, append=True)

In [None]:
# Raising here stops a "Run All" at this point; the trajectory generation below
# only executes when its cells are run explicitly.
# NOTE(review): appears to be a deliberate barrier.
raise RuntimeError

In [None]:
# Generate a trajectory by chaining rollouts: the positions predicted in one
# step become the input positions ("x") of the next step.
# idx = np.random.randint(0, len(dataset), 1)[0]
idx = 0
item = dataset.__getitem__(idx, random_starting_point=False)
# A one-element DataLoader is used only to collate the item into a batch.
batch = next(iter(torch.utils.data.DataLoader([item])))

traj_rollout_atoms = []  # one list of Atoms frames per rollout step
start = time.time()
for i_rollout in range(args.num_rollouts):
    for key in ['species', 'x', 'cell', 'num_atoms', 'mask', 'v_mask']:
        batch[key] = batch[key].to(device)

    # Species index per frame and atom.  Transfer the whole label tensor to the
    # host once instead of synchronising with the device for every atom.
    labels = torch.argmax(batch["species"], dim=3).squeeze(0)
    symbols = [
        [map_to_chemical_symbol[i_elem] for i_elem in frame_labels]
        for frame_labels in labels.cpu().tolist()
    ]

    pred_pos, next_batch = rollout(model, batch)
    print("idx = ", idx, "rollout", i_rollout, pred_pos.shape)

    # The per-frame distance matrix against "x_next" used to be computed here,
    # but its only consumer (snapping atoms that moved less than a threshold
    # back to their input position) is disabled, so the O(N^2) computation is
    # dropped:
    # out_pos = torch.stack([pred_pos[0][t][i] if err[i][i] > 1 else batch["x"][0][t][i] for i in range(len(pred_pos[0][t]))])
    cell = batch['cell'][0][0].cpu().numpy()
    all_atoms = [
        # Pass the per-atom symbol list directly; no formula string round-trip.
        Atoms(symbols=symbols[i_frame], positions=frame_positions, cell=cell, pbc=[1, 1, 1])
        for i_frame, frame_positions in enumerate(pred_pos[0].cpu().numpy())
    ]
    traj_rollout_atoms.append(all_atoms)
    batch = next_batch

In [None]:



# Write the chained rollouts into a single multi-frame xyz file.  Frames are
# appended, so a file left over from a previous run is moved aside first.
_traj_dir = args.out_dir
filename = os.path.join(_traj_dir, "gentraj_fromstart.xyz")
_backup_path = os.path.join(_traj_dir, "bck.0.gentraj.xyz")
if os.path.exists(filename):
    shutil.move(filename, _backup_path)
# Each entry of traj_rollout_atoms is the list of frames of one rollout step.
for atoms in traj_rollout_atoms:
    write(filename, atoms, append=True)
