In [1]:
import numpy as np
import torch
import mdtraj as md
import matplotlib.pyplot as plt
from cgp import SideChainLens, res_code_to_res_info
from openmm import *
from openmm.app import *
from openmm.unit import *
import os


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



In [2]:
def get_bond_info(prmtop_file):
    prmtop = AmberPrmtopFile(f"{prmtop_file}.prmtop")
    system = prmtop.createSystem()
    bond_length_force = system.getForce(0)
    assert bond_length_force.getName() == "HarmonicBondForce"
    bond_angle_force = system.getForce(1)
    assert bond_angle_force.getName() == "HarmonicAngleForce"

    all_bond_length_info = {}
    for i in range(bond_length_force.getNumBonds()):
        bond_atoms = bond_length_force.getBondParameters(i)[0:2]
        bond_atoms = tuple(sorted(bond_atoms))
        all_bond_length_info[bond_atoms] = [bond_length_force.getBondParameters(i)[2]._value, 
                                            bond_length_force.getBondParameters(i)[3]._value]


    all_bond_angle_info = {}
    for i in range(bond_angle_force.getNumAngles()):
        bond_angle_atoms = bond_angle_force.getAngleParameters(i)[0:3]
        bond_angle_atoms = tuple(sorted(bond_angle_atoms))
        all_bond_angle_info[bond_angle_atoms] = [bond_angle_force.getAngleParameters(i)[3]._value, 
                                                bond_angle_force.getAngleParameters(i)[4]._value]
    return all_bond_length_info, all_bond_angle_info

In [3]:
all_residues = np.loadtxt("all_residues.txt", dtype=str)

In [4]:
first_tpp = "_".join(all_residues[1:4])
last_tpp = "_".join(all_residues[-4:-1])

In [5]:
tpp_folder_name = "./starter_files/"
all_tpp_files = [file for file in os.listdir('concatenated_no_solvent') 
                 if file.endswith('.h5')]
all_dihedral_names = []
kbT = 0.0083144621 * 300  # units of kJ/mol

for file in all_tpp_files:
    dihedral_gaussian_info = {}
    tpp_traj = md.load('concatenated_no_solvent/' + file)
    tpp_top = tpp_traj.topology
    side_chain_lens = SideChainLens(protein_top=tpp_top)
    side_chain_info = side_chain_lens.get_data(tpp_traj)
    c_info = np.concatenate((side_chain_info["phi"]["dihedrals"],
                             side_chain_info["psi"]["dihedrals"]),
                            axis=1)
    chi_info = side_chain_info["chi"][2]

    prmtop_file = tpp_folder_name + file[:-3]
    all_bond_length_info, all_bond_angle_info = get_bond_info(prmtop_file=prmtop_file)
    if chi_info is not None:
        save_folder_name = 'sidechain_info/' + file[:-3]
        os.makedirs(save_folder_name, exist_ok=True)
        np.save(save_folder_name + '/c_info.npy', c_info)
        np.save(save_folder_name + '/dihedrals.npy', chi_info["dihedrals"])
        np.save(save_folder_name + '/distances.npy', chi_info["distances"])
        np.save(save_folder_name + '/angles.npy', chi_info["angles"])


        all_mean_distances = []
        all_std_distances = []
        for distance_indices in chi_info["dihedral_indices"][:, :2]:
            distance_indices = tuple(sorted(distance_indices))
            all_mean_distances.append(all_bond_length_info[distance_indices][0])
            all_std_distances.append(np.sqrt(kbT / all_bond_length_info[distance_indices][1]))
        all_mean_distances = np.array(all_mean_distances)
        all_std_distances = np.array(all_std_distances)
        np.save(save_folder_name + '/all_mean_distances.npy', all_mean_distances)
        np.save(save_folder_name + '/all_std_distances.npy', all_std_distances)

        all_mean_angles = []
        all_std_angles = []
        for angle_indices in chi_info["dihedral_indices"][:, :3]:
            angle_indices = tuple(sorted(angle_indices))
            all_mean_angles.append(all_bond_angle_info[angle_indices][0])
            all_std_angles.append(np.sqrt(kbT / all_bond_angle_info[angle_indices][1]))
        all_mean_angles = np.array(all_mean_angles)
        np.save(save_folder_name + '/all_mean_angles.npy', all_mean_angles)
        np.save(save_folder_name + '/dihedral_indices.npy', chi_info["dihedral_indices"])

        np.save(save_folder_name + '/all_std_angles.npy', all_std_angles)

        all_atom_names = list(tpp_top.atoms)
        np.save(save_folder_name + '/all_atom_names.npy', all_atom_names)

        all_dihedral_names.append(file[:-3])

    if file[:-3] == first_tpp:
        save_folder_name = 'sidechain_info/' + file[:-3]+"_FIRST"
        chi_info = side_chain_info["chi"][1]

        if chi_info is not None:
            os.makedirs(save_folder_name, exist_ok=True)
            np.save(save_folder_name + '/c_info.npy', c_info)
            np.save(save_folder_name + '/dihedrals.npy', chi_info["dihedrals"])
            np.save(save_folder_name + '/distances.npy', chi_info["distances"])
            np.save(save_folder_name + '/angles.npy', chi_info["angles"])
            all_dihedral_names.append(file[:-3]+"_FIRST")





            all_mean_distances = []
            all_std_distances = []
            for distance_indices in chi_info["dihedral_indices"][:, :2]:
                distance_indices = tuple(sorted(distance_indices))
                all_mean_distances.append(all_bond_length_info[distance_indices][0])
                all_std_distances.append(np.sqrt(kbT / all_bond_length_info[distance_indices][1]))
            all_mean_distances = np.array(all_mean_distances)
            all_std_distances = np.array(all_std_distances)
            np.save(save_folder_name + '/all_mean_distances.npy', all_mean_distances)
            np.save(save_folder_name + '/all_std_distances.npy', all_std_distances)

            all_mean_angles = []
            all_std_angles = []
            for angle_indices in chi_info["dihedral_indices"][:, :3]:
                angle_indices = tuple(sorted(angle_indices))
                all_mean_angles.append(all_bond_angle_info[angle_indices][0])
                all_std_angles.append(np.sqrt(kbT / all_bond_angle_info[angle_indices][1]))
            all_mean_angles = np.array(all_mean_angles)
            np.save(save_folder_name + '/all_mean_angles.npy', all_mean_angles)
            np.save(save_folder_name + '/all_std_angles.npy', all_std_angles)
            
            np.save(save_folder_name + '/dihedral_indices.npy', chi_info["dihedral_indices"])

            all_atom_names = list(tpp_top.atoms)
            np.save(save_folder_name + '/all_atom_names.npy', all_atom_names)



    elif file[:-3] == last_tpp:
        save_folder_name = 'sidechain_info/' + file[:-3]+"_LAST"
        chi_info = side_chain_info["chi"][3]

        if chi_info is not None:
            os.makedirs(save_folder_name, exist_ok=True)
            np.save(save_folder_name + '/c_info.npy', c_info)
            np.save(save_folder_name + '/dihedrals.npy', chi_info["dihedrals"])
            np.save(save_folder_name + '/distances.npy', chi_info["distances"])
            np.save(save_folder_name + '/angles.npy', chi_info["angles"])
            all_dihedral_names.append(file[:-3]+"_LAST")

            all_mean_distances = []
            all_std_distances = []
            for distance_indices in chi_info["dihedral_indices"][:, :2]:
                distance_indices = tuple(sorted(distance_indices))
                all_mean_distances.append(all_bond_length_info[distance_indices][0])
                all_std_distances.append(np.sqrt(kbT / all_bond_length_info[distance_indices][1]))
            all_mean_distances = np.array(all_mean_distances)
            all_std_distances = np.array(all_std_distances)
            np.save(save_folder_name + '/all_mean_distances.npy', all_mean_distances)
            np.save(save_folder_name + '/all_std_distances.npy', all_std_distances)

            all_mean_angles = []
            all_std_angles = []
            for angle_indices in chi_info["dihedral_indices"][:, :3]:
                angle_indices = tuple(sorted(angle_indices))
                all_mean_angles.append(all_bond_angle_info[angle_indices][0])
                all_std_angles.append(np.sqrt(kbT / all_bond_angle_info[angle_indices][1]))
            all_mean_angles = np.array(all_mean_angles)
            all_std_angles = np.array(all_std_angles)

            np.save(save_folder_name + '/all_mean_angles.npy', all_mean_angles)
            np.save(save_folder_name + '/all_std_angles.npy', all_std_angles)

            np.save(save_folder_name + '/dihedral_indices.npy', chi_info["dihedral_indices"])

            all_atom_names = list(tpp_top.atoms)
            np.save(save_folder_name + '/all_atom_names.npy', all_atom_names)




In [6]:
ace_chi_info = side_chain_info["chi"][0]
ace_c_info = np.zeros((ace_chi_info["dihedrals"].shape[0], 6))
save_folder_name = 'sidechain_info/ACE/'
os.makedirs(save_folder_name, exist_ok=True)
np.save(save_folder_name + '/dihedrals.npy', ace_chi_info["dihedrals"])
np.save(save_folder_name + '/distances.npy', ace_chi_info["distances"])
np.save(save_folder_name + '/angles.npy', ace_chi_info["angles"])
np.save(save_folder_name + '/c_info.npy', ace_c_info)

all_mean_distances = []
all_std_distances = []
for distance_indices in ace_chi_info["dihedral_indices"][:, :2]:
    distance_indices = tuple(sorted(distance_indices))
    all_mean_distances.append(all_bond_length_info[distance_indices][0])
    all_std_distances.append(np.sqrt(kbT / all_bond_length_info[distance_indices][1]))
all_mean_distances = np.array(all_mean_distances)
all_std_distances = np.array(all_std_distances)
np.save(save_folder_name + '/all_mean_distances.npy', all_mean_distances)
np.save(save_folder_name + '/all_std_distances.npy', all_std_distances)

all_mean_angles = []
all_std_angles = []
for angle_indices in ace_chi_info["dihedral_indices"][:, :3]:
    angle_indices = tuple(sorted(angle_indices))
    all_mean_angles.append(all_bond_angle_info[angle_indices][0])
    all_std_angles.append(np.sqrt(kbT / all_bond_angle_info[angle_indices][1]))
all_mean_angles = np.array(all_mean_angles)
all_std_angles = np.array(all_std_angles)
np.save(save_folder_name + '/all_mean_angles.npy', all_mean_angles)
np.save(save_folder_name + '/all_std_angles.npy', all_std_angles)

np.save(save_folder_name + '/dihedral_indices.npy', ace_chi_info["dihedral_indices"])

all_dihedral_names.append("ACE")

all_atom_names = list(tpp_top.atoms)
np.save(save_folder_name + '/all_atom_names.npy', all_atom_names)


nme_chi_info = side_chain_info["chi"][4]
nme_c_info = np.zeros((nme_chi_info["dihedrals"].shape[0], 6))
save_folder_name = 'sidechain_info/NME/'
os.makedirs(save_folder_name, exist_ok=True)
np.save(save_folder_name + '/dihedrals.npy', nme_chi_info["dihedrals"])
np.save(save_folder_name + '/distances.npy', nme_chi_info["distances"])
np.save(save_folder_name + '/angles.npy', nme_chi_info["angles"])
np.save(save_folder_name + '/c_info.npy', nme_c_info)


all_dihedral_names.append("NME")

all_mean_distances = []
all_std_distances = []
for distance_indices in nme_chi_info["dihedral_indices"][:, :2]:
    distance_indices = tuple(sorted(distance_indices))
    all_mean_distances.append(all_bond_length_info[distance_indices][0])
    all_std_distances.append(np.sqrt(kbT / all_bond_length_info[distance_indices][1]))
all_mean_distances = np.array(all_mean_distances)
all_std_distances = np.array(all_std_distances)
np.save(save_folder_name + '/all_mean_distances.npy', all_mean_distances)
np.save(save_folder_name + '/all_std_distances.npy', all_std_distances)

all_mean_angles = []
all_std_angles = []
for angle_indices in nme_chi_info["dihedral_indices"][:, :3]:
    angle_indices = tuple(sorted(angle_indices))
    all_mean_angles.append(all_bond_angle_info[angle_indices][0])
    all_std_angles.append(np.sqrt(kbT / all_bond_angle_info[angle_indices][1]))
all_mean_angles = np.array(all_mean_angles)
all_std_angles = np.array(all_std_angles)
np.save(save_folder_name + '/all_mean_angles.npy', all_mean_angles)
np.save(save_folder_name + '/all_std_angles.npy', all_std_angles)

np.save(save_folder_name + '/dihedral_indices.npy', nme_chi_info["dihedral_indices"])
all_atom_names = list(tpp_top.atoms)
np.save(save_folder_name + '/all_atom_names.npy', all_atom_names)




In [8]:
all_dihedral_names = np.loadtxt("sidechain_info/dihedral_names.txt", dtype=str)
all_dihedral_names

array(['TYR_TYR_ASP', 'TYR_TYR_ASP_FIRST', 'ASP_PRO_GLU', 'GLY_THR_TRP',
       'PRO_GLU_THR', 'GLU_THR_GLY', 'TYR_ASP_PRO', 'THR_TRP_TYR',
       'THR_TRP_TYR_LAST', 'ACE', 'NME'], dtype='<U17')

In [9]:
all_dihedral_names_n_components = []
for dihedral_names in all_dihedral_names:
    if dihedral_names.split("_")[-1] == "FIRST":
        residue_name = dihedral_names.split("_")[0]
        print(dihedral_names, residue_name)
    elif dihedral_names.split("_")[-1] == "LAST":

        residue_name = dihedral_names.split("_")[2]
        print(dihedral_names, residue_name)
    elif dihedral_names == "NME":
        residue_name = "NME"
    elif dihedral_names == "ACE":
        residue_name = "ACE"
    else:
        residue_name = dihedral_names.split("_")[1]
    offset = np.array(res_code_to_res_info[residue_name].n_components)
    all_dihedral_names_n_components.append([dihedral_names, str(offset)])

TYR_TYR_ASP_FIRST TYR
THR_TRP_TYR_LAST TYR


In [10]:
all_offsets_by_dihedral = {}
for dihedral_names in all_dihedral_names:
    if dihedral_names.split("_")[-1] == "FIRST":
        residue_name = dihedral_names.split("_")[0]
        print(dihedral_names, residue_name)
    elif dihedral_names.split("_")[-1] == "LAST":

        residue_name = dihedral_names.split("_")[2]
        print(dihedral_names, residue_name)
    elif dihedral_names == "NME":
        residue_name = "NME"
    elif dihedral_names == "ACE":
        residue_name = "ACE"
    else:
        residue_name = dihedral_names.split("_")[1]
    offset = np.array(res_code_to_res_info[residue_name].all_offset)
    all_offsets_by_dihedral[dihedral_names] = offset

TYR_TYR_ASP_FIRST TYR
THR_TRP_TYR_LAST TYR


In [11]:
np.savetxt("sidechain_info/dihedral_name_n_components.txt", all_dihedral_names_n_components, fmt="%s")

In [12]:
# for dihedral_names in all_dihedral_names:
#     dihedrals = np.load(f"sidechain_info/{dihedral_names}/dihedrals.npy")
#     offset = all_offsets_by_dihedral[dihedral_names]
#     np.save(f"sidechain_info/{dihedral_names}/dihedrals_offset.npy", offset)
#     print(dihedrals.shape, offset.shape)
#     dihedrals_offset = dihedrals + offset
#     to_subtract = (dihedrals_offset > np.pi) * 2 * np.pi
#     to_add = (dihedrals_offset < -np.pi) * 2 * np.pi
#     dihedrals_offset -= to_subtract
#     dihedrals_offset += to_add


#     in_dim = dihedrals.shape[1]
#     fig, ax = plt.subplots(2, in_dim, figsize=((in_dim) * 5, 10))
#     fig.suptitle(f"tpp_name: {dihedral_names}", fontsize=36)
#     for i in range(in_dim):
#         ax[0, i].hist(dihedrals[:, i], bins=np.arange(-np.pi, np.pi, 0.1), density=True, label="Data")
#         ax[1, i].hist(dihedrals_offset[:, i], bins=np.arange(-np.pi, np.pi, 0.1), density=True, label="Data")
#         ax[0, i].set_title(f"Input {i}", fontsize=24)
#     plt.show()
