# Sample Notebook to save dataset for transformer training (Chignolin)

In [2]:
import os
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
from cgp import SideChainLens


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



In [3]:
# Names of all entries in the raw GMM dataset folder. os.listdir returns
# entries in arbitrary (filesystem-dependent) order, so sort them to make
# this array -- and anything derived from it -- identical on every run.
all_files = np.array(sorted(os.listdir("./uncoupled_gmm_dataset/")))

In [4]:
# Root output folder for the files this notebook writes. The trailing "/" is
# required because paths are built from it by plain string concatenation.
# (The spelling "mame" is kept as-is: other cells refer to this exact name.)
root_save_folder_mame = "/".join((".", "ChignolinGMMTransformerDataset", ""))

In [5]:
# Create the 80/10/10 train/val/test split of `all_files` once and persist it;
# on later runs the saved split is loaded so the partition never changes.
train_indices_path = root_save_folder_mame + "train_indices.npy"
val_indices_path = root_save_folder_mame + "val_indices.npy"
test_indices_path = root_save_folder_mame + "test_indices.npy"

if not os.path.exists(train_indices_path):
    # A partially present split would mix an old partition with a new one.
    assert not os.path.exists(val_indices_path), \
        f"{val_indices_path} exists but {train_indices_path} does not"
    assert not os.path.exists(test_indices_path), \
        f"{test_indices_path} exists but {train_indices_path} does not"

    # np.save does not create missing parent directories.
    os.makedirs(root_save_folder_mame, exist_ok=True)

    # Sort the names and use a fixed seed so that deleting the saved files and
    # re-running regenerates exactly the same split.
    split_seed = 0
    sorted_files = np.sort(all_files)
    num_files = len(sorted_files)
    rand_indices = np.random.default_rng(split_seed).permutation(num_files)

    # Split boundaries: first 80% train, next 10% val, remainder test.
    train_end = int(0.8 * num_files)
    val_end = int(0.9 * num_files)
    train_indices = sorted_files[rand_indices[:train_end]]
    val_indices = sorted_files[rand_indices[train_end:val_end]]
    test_indices = sorted_files[rand_indices[val_end:]]

    np.save(train_indices_path, train_indices)
    np.save(val_indices_path, val_indices)
    np.save(test_indices_path, test_indices)
else:
    train_indices = np.load(train_indices_path)
    val_indices = np.load(val_indices_path)
    test_indices = np.load(test_indices_path)
print(train_indices.shape, val_indices.shape, test_indices.shape)

(40000,) (5000,) (5000,)


In [6]:
# Per-system preprocessing. For every entry of every split: load the relaxed
# potential energies and the generated GMM component indices, keep only the
# samples whose energy is below the cutoff, and save the filtered arrays
# together with the phi/psi dihedrals of frame 0 into the split's folder.
prop_temp = 300.0
dt = 0.001
num_steps = 5
cutoff_to_use_kt = -50  # a sample is kept only if its energy is < this value

# Input locations are the same for every split, so they are built once.
root_relaxed_folder_name = f"./prop_temp_{prop_temp}_dt_{dt}_num_steps_{num_steps}/"
root_uncoupled_gmm_folder_name = "./uncoupled_gmm_dataset/"

for dataset_name, dataset_indices in (("train", train_indices),
                                      ("val", val_indices),
                                      ("test", test_indices)):
    save_folder_name = f"{root_save_folder_mame}prop_temp_{prop_temp}_dt_{dt}_num_steps_{num_steps}_cutoff_to_use_kt_{cutoff_to_use_kt}/{dataset_name}/"
    os.makedirs(save_folder_name, exist_ok=True)

    for file in dataset_indices:
        relaxed_folder_name = f"{root_relaxed_folder_name}{file}/"
        uncoupled_gmm_folder_name = f"{root_uncoupled_gmm_folder_name}{file}/"

        all_energies = []
        all_gen_components = []
        for batch_index in range(1):  # only batch 0 is read
            energies_path = f"{relaxed_folder_name}all_relaxed_potential_energies_{batch_index}.npy"
            try:
                energies = np.load(energies_path)
            except FileNotFoundError:
                # No relaxed energies for this entry/batch: report and skip.
                print(f"File not found for {file} batch {batch_index}")
                continue

            # The file stores a pickled dict (unwrapped with .item()), indexed
            # below by residue number. allow_pickle=True is only safe because
            # these files are produced locally; never load untrusted files so.
            gen_components_path = f"{uncoupled_gmm_folder_name}uncoupled_gmm_gen_components_{batch_index}.npy"
            gen_components = np.load(gen_components_path, allow_pickle=True).item()

            # Residue 7 gets a constant filler value of 64 with the same shape
            # as residue 1's entry.
            # NOTE(review): presumably residue 7 has no sampled component and
            # 64 is a dedicated placeholder token -- confirm.
            gen_components[7] = np.ones_like(gen_components[1]) * 64

            # Column 0 of each of the 12 residues, stacked along the last axis.
            concatenated_gen_components = np.stack(
                [gen_components[res_num][:, 0] for res_num in range(12)],
                axis=-1)

            # Every batch is expected to hold exactly 10000 samples.
            assert energies.shape[0] == 10000
            assert concatenated_gen_components.shape[0] == 10000
            all_energies.append(energies)
            all_gen_components.append(concatenated_gen_components)

        # Nothing could be loaded for this entry: write no output for it.
        if not all_energies:
            continue

        all_energies = np.concatenate(all_energies)
        all_gen_components = np.concatenate(all_gen_components)

        # Apply one shared mask so energies and components stay row-aligned.
        below_cutoff = all_energies < cutoff_to_use_kt
        all_gen_components = all_gen_components[below_cutoff]
        all_energies = all_energies[below_cutoff]
        assert all_energies.shape[0] == all_gen_components.shape[0]

        # Conditioning information: phi and psi dihedrals of the first frame,
        # paired along a new trailing axis (index 0 = phi, index 1 = psi).
        traj = md.load(f"{uncoupled_gmm_folder_name}uncoupled_gmm_0.h5", frame=0)
        side_chain_lens = SideChainLens(protein_top=traj.topology)
        side_chain_info = side_chain_lens.get_data(traj)
        phi_dihedrals = side_chain_info["phi"]["dihedrals"]
        psi_dihedrals = side_chain_info["psi"]["dihedrals"]
        c_info = np.stack((phi_dihedrals, psi_dihedrals), axis=-1)

        np.save(f"{save_folder_name}{file}_c_info.npy", c_info)
        np.save(f"{save_folder_name}{file}_energies.npy", all_energies)
        np.save(f"{save_folder_name}{file}_gen_components.npy",
                all_gen_components)



# Make Transformer Dataset

In [7]:
# Encode the protein sequence as integer residue-type ids, stored as a single
# row so it can later be repeated once per sample.
traj_folder_name = "./chignolin/fg_traj/"
# traj_folder_name already ends with "/", so no extra separator is added.
protein_filename = f"{traj_folder_name}chignolin_traj.h5"
protein_traj = md.load(protein_filename, frame=0)  # one frame is enough for the topology
protein_top = protein_traj.topology
# Residue codes with the first and last residue of the topology dropped.
# NOTE(review): presumably the dropped residues are terminal caps -- confirm.
all_residues = [residue.code for residue in protein_top.residues][1:-1]
# residue_names: sorted unique codes; residue_indices: for each residue, the
# position of its code in residue_names.
residue_names, residue_indices = np.unique(all_residues, return_inverse=True)
# Shape (1, n_residues). Identical to the former reshape(-1, 10) for a
# 10-residue sequence, but never splits a longer sequence into several rows.
residue_indices = residue_indices.reshape(1, -1)
residue_indices

array([[6, 6, 0, 3, 1, 4, 2, 4, 5, 6]])

In [8]:
# Assemble the final per-split arrays for transformer training from the
# per-system files: concatenate every system's samples, build the continuous
# (sin/cos of phi/psi) and categorical (residue id) source features, and wrap
# the component sequence with a leading 65 and a trailing 66.
prop_temp = 300.0
dt = 0.001
num_steps = 5
cutoff_to_use_kt = -50
root_load_folder_name = f"./ChignolinGMMTransformerDataset/prop_temp_{prop_temp}_dt_{dt}_num_steps_{num_steps}_cutoff_to_use_kt_{cutoff_to_use_kt}/"
for (dataset_name, dataset_indices) in [("train", train_indices), ("val", val_indices), ("test", test_indices)]:

    # root_load_folder_name already ends with "/", so no extra separator.
    load_folder_name = f"{root_load_folder_name}{dataset_name}/"
    all_energies = []
    all_gen_components = []
    all_c_info = []
    for file in dataset_indices:
        try:
            c_info = np.load(f"{load_folder_name}{file}_c_info.npy")
            energies = np.load(f"{load_folder_name}{file}_energies.npy")
            gen_components = np.load(
                f"{load_folder_name}{file}_gen_components.npy")
        except FileNotFoundError:
            # Entries without saved arrays are reported by name and skipped.
            print(file)
            continue

        all_energies.append(energies)
        all_gen_components.append(gen_components)
        # Repeat the entry's c_info along axis 0, once per retained sample.
        all_c_info.append(np.repeat(c_info, energies.shape[0], axis=0))

    # np.concatenate on an empty list raises an uninformative ValueError;
    # raise the same exception type with a message that names the culprit.
    if not all_energies:
        raise ValueError(
            f"No per-system arrays found for split '{dataset_name}' in {load_folder_name}")

    all_energies = np.concatenate(all_energies)
    all_gen_components = np.concatenate(all_gen_components)
    all_c_info = np.concatenate(all_c_info)

    # Every saved array must describe the same samples, row for row.
    assert all_energies.shape[0] == all_gen_components.shape[0] == all_c_info.shape[0], \
        f"row mismatch in split '{dataset_name}'"

    # Continuous source features: (sin phi, cos phi, sin psi, cos psi), where
    # index 0 / 1 of the last c_info axis are read as the two angles.
    phi = all_c_info[:, :, 0]
    psi = all_c_info[:, :, 1]
    all_src_cont = np.stack((np.sin(phi), np.cos(phi),
                             np.sin(psi), np.cos(psi)), axis=-1)

    # Target sequence: one column of 65 in front, one column of 66 behind.
    # NOTE(review): presumably start/end-of-sequence tokens -- confirm.
    boundary_column = np.ones_like(all_gen_components[:, :1])
    all_target = np.concatenate((boundary_column * 65,
                                 all_gen_components,
                                 boundary_column * 66), axis=-1)

    np.save(f"{root_load_folder_name}{dataset_name}_all_energies.npy", all_energies)
    np.save(f"{root_load_folder_name}{dataset_name}_all_gen_components.npy",
            all_gen_components)
    np.save(f"{root_load_folder_name}{dataset_name}_all_c_info.npy", all_c_info)

    np.save(f"{root_load_folder_name}{dataset_name}_all_src_cont.npy", all_src_cont)
    # Categorical source features: the residue-id row repeated once per sample.
    np.save(f"{root_load_folder_name}{dataset_name}_all_src_cat.npy", np.repeat(residue_indices, all_src_cont.shape[0], axis=0))

    np.save(f"{root_load_folder_name}{dataset_name}_all_target.npy", all_target)
