# Sample notebook: using generated components to sample internal coordinates (Chignolin)

In [4]:
import os
import numpy as np
import mdtraj as md
from cgp import SideChainLens, GaussianMixtureModel
import torch
import torch.nn as nn


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



In [5]:
# Target device passed to the GMM constructors and used as `map_location` when loading tensors.
DEVICE_TYPE = "cpu"
device = torch.device(DEVICE_TYPE)

In [6]:
root_protein_folder_name = "./datasets/chignolin/"
# NOTE: the root already ends with "/", so the folders below contain "//"
# (harmless for POSIX path resolution).

# Trajectory folder ("fg" presumably = fine-grained — confirm).
traj_folder_name = f"{root_protein_folder_name}/fg_traj/"
all_subsampled_indices = np.load(f"{traj_folder_name}all_subsampled_indices.npy")
protein_filename = f"{traj_folder_name}chignolin_traj.h5"

# Only frame 0 is read: it provides the topology and the frame passed to SideChainLens.get_data.
protein_traj = md.load(protein_filename, frame=0)
protein_top = protein_traj.topology

# Folder under which get_gmm_model looks for the saved GMMs.
network_folder_name = f"{root_protein_folder_name}/gmm/"


def get_gmm_model(in_dim, all_coordinates, num_components, protein_name,
                  num_datapoints=500000):
    """Construct a GaussianMixtureModel and load its saved state from disk.

    The model is built with initial parameters and then `load(epoch_num=None)`
    is called on it. The initial parameters are random means, identity
    covariances, uniform weights and a random offset; they are presumably
    overwritten by the load (TODO confirm).

    Parameters
    ----------
    in_dim : int
        Dimensionality of the modelled coordinate vector.
    all_coordinates : list of str
        Coordinate kinds covered by the model, e.g. ["dihedrals"] or
        ["distances", "angles"]; joined with "_" to build the folder tag.
    num_components : int
        Number of mixture components.
    protein_name : str
        Residue-context name; first part of the folder tag.
    num_datapoints : int, default 500000
        Dataset size encoded in the folder tag; only used to locate the folder.

    Returns
    -------
    GaussianMixtureModel
        The model after `load(epoch_num=None)` has been called on it.
    """
    coordinate_tag = "_".join(all_coordinates)

    # Keep the RNG call order (means, then offset) identical to the original code.
    means = torch.randn(num_components, in_dim).to(device)
    all_offset = torch.randn(1, in_dim).numpy()
    covs = torch.eye(in_dim).unsqueeze(0).repeat(num_components, 1, 1).to(device)
    weights = torch.ones(num_components, device=device) / num_components

    tag = (f"{protein_name}_protein_name_"
           f"{coordinate_tag}_coordinate_"
           f"{num_components}_num_components_"
           f"{num_datapoints}_num_datapoints")
    folder_name = network_folder_name + tag + "/"

    # Only the pure-dihedral model is constructed with sample_more_info=True.
    sample_more_info = coordinate_tag == "dihedrals"

    gmm_model = GaussianMixtureModel(means=means, covs=covs, weights=weights,
                                     all_offset=all_offset,
                                     folder_name=folder_name,
                                     sample_more_info=sample_more_info,
                                     device=device)
    # epoch_num=None presumably selects the latest saved checkpoint — confirm.
    gmm_model.load(epoch_num=None)
    return gmm_model




# Load GMMs

In [7]:
# Table mapping a residue-context name (column 0) to the number of GMM components
# used for its dihedrals (column 1). `ndmin=2` keeps the `[:, 0]` indexing below
# valid even when the file has only a single row.
all_tpp_names = np.loadtxt(
    f"{root_protein_folder_name}/tpp/sidechain_info/dihedral_name_n_components.txt",
    dtype=str, ndmin=2)


def residue_context_name(top, res_num):
    """Return the context name under which residue `res_num`'s data and GMMs are stored.

    With n = top.n_residues:
      - residue 0 must be the ACE cap      -> "ACE"
      - residue n-1 must be the NME cap    -> "NME"
      - residue 1                          -> names of residues 1, 2, 3 + "_FIRST"
      - residue n-2                        -> names of residues n-4, n-3, n-2 + "_LAST"
      - any other residue i                -> names of residues i-1, i, i+1
    Residue names are joined with "_". The checks are applied in the order listed.

    Raises ValueError if a terminal residue is not the expected cap.
    """
    last = top.n_residues - 1

    def name_of(i):
        return top.residue(i).name

    if res_num in (0, last):
        expected = "ACE" if res_num == 0 else "NME"
        if name_of(res_num) != expected:
            raise ValueError(
                f"Residue {res_num} is {name_of(res_num)!r}, expected cap {expected!r}")
        return expected

    if res_num == 1:
        window, suffix = (res_num, res_num + 1, res_num + 2), "_FIRST"
    elif res_num == last - 1:
        window, suffix = (res_num - 2, res_num - 1, res_num), "_LAST"
    else:
        window, suffix = (res_num - 1, res_num, res_num + 1), ""
    return "_".join(name_of(i) for i in window) + suffix


def dihedral_n_components(context_name):
    """Return the dihedral GMM component count listed for `context_name` in `all_tpp_names`.

    Raises KeyError (instead of an opaque IndexError) when the name is not listed.
    """
    matches = all_tpp_names[all_tpp_names[:, 0] == context_name, 1]
    if matches.size == 0:
        raise KeyError(
            f"{context_name!r} not found in dihedral_name_n_components.txt")
    return matches.astype(int)[0]


side_chain_lens = SideChainLens(protein_top=protein_top)
side_chain_info = side_chain_lens.get_data(protein_traj)
all_chi_info_by_residue_num = side_chain_info["chi"]

dihedral_models_by_res_num = {}
bond_length_angle_models_by_res_num = {}
full_z_matrix = []

for res_num, chi_info in all_chi_info_by_residue_num.items():
    print(res_num)
    # Residues without chi info get no models and contribute nothing to the z-matrix.
    if chi_info is None:
        continue

    protein_name = residue_context_name(protein_top, res_num)
    sidechain_dir = f"{root_protein_folder_name}/tpp/sidechain_info/{protein_name}"
    in_dim = np.load(f"{sidechain_dir}/dihedrals.npy").shape[1]
    n_components_dihedrals = dihedral_n_components(protein_name)

    # Bond lengths and angles share one single-component GMM over 2 * in_dim values
    # (assumes as many bond lengths and angles as dihedrals — TODO confirm).
    # Load order (bond/angle model first) is kept so RNG consumption is unchanged.
    bond_length_angle_models_by_res_num[res_num] = get_gmm_model(
        in_dim * 2, ["distances", "angles"], 1, protein_name)
    dihedral_models_by_res_num[res_num] = get_gmm_model(
        in_dim, ["dihedrals"], n_components_dihedrals, protein_name)
    full_z_matrix.append(chi_info["dihedral_indices"])

full_z_matrix = np.concatenate(full_z_matrix)

0
1
2
3
4
5
6
7
8
9
10
11


In [8]:
# Run configuration. NOTE(review): dim_model, dropout_p, beta_target and
# epoch_num are not read by any later cell of this notebook.
dim_model, dropout_p = 256, 0.1
kt_cutoff = -50
beta_target = 1.0
epoch_num = 14

dataset_tag = ("prop_temp_300.0_dt_0.001_num_steps_5_"
               f"cutoff_to_use_kt_{kt_cutoff}")
save_folder_name = dataset_tag + "/test/"

# Backbone ids are the filename prefix before the first "_"; np.unique
# deduplicates and sorts them.
all_backbones = np.unique(
    [file_name.split("_", 1)[0] for file_name in os.listdir(save_folder_name)])

In [9]:
def sample_backbone_ics(pred_components):
    """Sample internal coordinates for one backbone from the per-residue GMMs.

    For each residue in `dihedral_models_by_res_num`, in that dict's iteration order:
    - dihedrals come from `sample_from_selected_components(pred_components[:, res_num])`,
      presumably one selected component index per row (confirm);
    - bond lengths and angles come from that residue's `.sample(n_rows)`, which does
      not see `pred_components`. The first half of the last axis is treated as bond
      lengths and the second half as angles.

    Returns
    -------
    tuple of torch.Tensor
        (dihedrals, bond_lengths, angles, ics). Each is concatenated over residues
        along the last axis; `ics` is [bond_lengths | angles | dihedrals].

    Raises
    ------
    ValueError
        If no dihedral models have been loaded.
    """
    if not dihedral_models_by_res_num:
        raise ValueError("No dihedral models loaded; run the GMM-loading cell first.")

    n_samples = pred_components.shape[0]
    dihedral_parts, bond_length_parts, angle_parts = [], [], []
    for res_num, dihedral_model in dihedral_models_by_res_num.items():
        # Per-residue sampling order (dihedrals, then bonds/angles) matches the
        # original loop, so RNG consumption is unchanged.
        dihedral_parts.append(
            dihedral_model.sample_from_selected_components(pred_components[:, res_num]))
        bond_length_angle_samples, _ = bond_length_angle_models_by_res_num[res_num].sample(n_samples)
        half = bond_length_angle_samples.shape[-1] // 2
        bond_length_parts.append(bond_length_angle_samples[:, :half])
        angle_parts.append(bond_length_angle_samples[:, half:])

    dihedrals = torch.cat(dihedral_parts, dim=-1)
    bond_lengths = torch.cat(bond_length_parts, dim=-1)
    angles = torch.cat(angle_parts, dim=-1)
    ics = torch.cat([bond_lengths, angles, dihedrals], dim=-1)
    return dihedrals, bond_lengths, angles, ics


for backbone in all_backbones:
    prefix = f"{save_folder_name}{backbone}"
    # NOTE(review): torch.load unpickles the file; only load trusted, locally generated files.
    pred_components = torch.load(f"{prefix}_pred_components.pt", map_location=device)
    (all_dihedral_samples, all_bond_length_samples,
     all_angle_samples, all_ic_samples) = sample_backbone_ics(pred_components)

    # The tensors stay bound as globals after the loop (last backbone), for inspection below.
    for stem, samples in (("all_dihedral_samples", all_dihedral_samples),
                          ("all_bond_length_samples", all_bond_length_samples),
                          ("all_angle_samples", all_angle_samples),
                          ("all_ic_samples", all_ic_samples)):
        torch.save(samples, f"{prefix}_{stem}.pt")


In [17]:
# Shape of the concatenated dihedral samples of the last backbone processed above.
all_dihedral_samples.size()

torch.Size([96, 100])