# Sample notebook: using the generated components to reconstruct side chains from sampled internal coordinates (Chignolin)

In [1]:
import torch

In [2]:
from openmm import *
from openmm.app import *
from openmm.unit import *

In [4]:
import os
import numpy as np
import mdtraj as md
import torch
import matplotlib.pyplot as plt
from cgp import IdentityModel, SideChainLens, Factory, ProteinImplicit


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



In [5]:
# Device used to load the IC samples and to run the side-chain reconstruction.
device = torch.device(type="cpu")

In [7]:

# Load the all-atom reference trajectory, build the side-chain lens for its
# topology, and stack every residue's chi dihedral index rows into one
# z-matrix used by the reconstruction loop below.
root_protein_folder_name = "./datasets/chignolin/"
# os.path.join avoids the doubled "//" that string concatenation produced.
traj_folder_name = os.path.join(root_protein_folder_name, "fg_traj")
protein_filename = os.path.join(traj_folder_name, "chignolin_traj.h5")
# NOTE(review): these frame-index sets are not referenced by the cells below;
# kept so the notebook still fails early if the dataset is incomplete.
all_folded_indices = np.load(os.path.join(traj_folder_name, "all_folded_indices.npy"))
all_unfolded_indices = np.load(os.path.join(traj_folder_name, "all_unfolded_indices.npy"))
all_misfolded_indices = np.load(os.path.join(traj_folder_name, "all_misfolded_indices.npy"))

traj = md.load(protein_filename)
protein_top = traj.topology

side_chain_lens = SideChainLens(protein_top=protein_top)
side_chain_info = side_chain_lens.get_data(traj)
# Atom indices kept in the coarse-grained (CG) representation; the loop below
# slices each reference frame down to these atoms.
cg_indices = side_chain_lens.atom_indices_coarser

# Per-residue chi information; entries may be None (those residues are skipped).
all_chi_info_by_residue_num = side_chain_info["chi"]
z_matrix_parts = [
    chi_info["dihedral_indices"]
    for chi_info in all_chi_info_by_residue_num.values()
    if chi_info is not None
]
if not z_matrix_parts:
    # np.concatenate([]) would otherwise fail with an unhelpful message.
    raise ValueError("No residue in the topology has chi dihedral information.")
full_z_matrix = np.concatenate(z_matrix_parts)



In [9]:
# Settings identifying which sampling run's outputs are reconstructed.
# NOTE(review): dim_model, dropout_p, beta_target and epoch_num are not
# referenced by the cells below; presumably they record the configuration of
# the run that produced the samples — confirm.
dim_model = 256
dropout_p = 0.1
kt_cutoff = -50
beta_target = 1.0
epoch_num = 6
dataset_tag = f"prop_temp_300.0_dt_0.001_num_steps_5_cutoff_to_use_kt_{kt_cutoff}"
save_folder_name = f"{dataset_tag}/test/"

# Backbone frame indices are discovered from the "<frame>_all_ic_samples.pt"
# files in save_folder_name (the files the loop below loads). Only those files
# are considered, so stray entries such as ".ipynb_checkpoints" cannot break
# the integer parse. Indices are sorted numerically rather than as strings.
IC_SAMPLES_SUFFIX = "_all_ic_samples.pt"
all_backbones = np.array(
    sorted(
        {
            int(file_name[: -len(IC_SAMPLES_SUFFIX)])
            for file_name in os.listdir(save_folder_name)
            if file_name.endswith(IC_SAMPLES_SUFFIX)
            and file_name[: -len(IC_SAMPLES_SUFFIX)].isdecimal()
        }
    ),
    dtype=int,
)


In [None]:
# For every backbone frame, rebuild full-atom structures from its sampled
# side-chain internal coordinates and save them next to the samples.
for backbone in all_backbones:
    print(backbone)
    all_ic_samples = torch.load(f"{save_folder_name}{backbone}_all_ic_samples.pt", map_location=device)
    # Presumably replays the pre-computed IC samples as the "model" the factory
    # draws from — TODO confirm against IdentityModel.
    ic_model = IdentityModel(all_samples=all_ic_samples)

    # Slice the reference frame down to the CG atoms *before* replicating it
    # once per sample; replicating the full-atom frame first copied far more
    # data for the same CG coordinates.
    cg_frame = traj[backbone].atom_slice(cg_indices)
    cg_traj = cg_frame[[0] * all_ic_samples.shape[0]]
    cg_traj_pos = torch.tensor(cg_traj.xyz, device=device)

    side_chain_factory = Factory(lens=side_chain_lens)
    # total_log_prob is returned by the factory but not used in this notebook.
    all_reconstructed_positions, total_log_prob = side_chain_factory.reconstruct_from_cg_traj(
        cg_traj_pos=cg_traj_pos,
        ic_model=ic_model,
        cg_c_info=None,
        z_matrix=full_z_matrix,
        grad_enabled=False,
        device=device,
    )
    reconstructed_traj = md.Trajectory(
        all_reconstructed_positions.detach().cpu().numpy(),
        topology=protein_top,
    )
    reconstructed_traj.save(f"{save_folder_name}{backbone}_chignolin_traj_reconstructed.h5")
