Skip to content

Commit

Permalink
Merge pull request #74 from NDoering99/main
Browse files Browse the repository at this point in the history
rework representative frame calculation with jit compiler (added dependency numba)
  • Loading branch information
talagayev committed Apr 2, 2024
2 parents 39f6eb9 + 02fcb1d commit 7b17ee3
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 77 deletions.
2 changes: 1 addition & 1 deletion openmmdl/_version.py
@@ -1 +1 @@
__version__ = "1.0.0+723.g961795f.dirty"
__version__ = "1.0.0+724.g9044002.dirty"
88 changes: 58 additions & 30 deletions openmmdl/openmmdl_analysis/binding_mode_processing.py
@@ -1,9 +1,12 @@
import os
import itertools
import pandas as pd
from MDAnalysis.analysis import rms
import numpy as np
from MDAnalysis.analysis import rms, diffusionmap
from MDAnalysis.analysis.distances import dist
from tqdm import tqdm
from pathlib import Path
from numba import jit


def gather_interactions(df, ligand_rings, peptide=None):
Expand Down Expand Up @@ -637,8 +640,39 @@ def update_values(df, new, unique_data):
values_to_update = new.loc[frame_value, list(unique_data.values())]
df.loc[idx, list(unique_data.values())] = values_to_update


def calculate_representative_frame(traj, bmode_frame_list, lig):
@jit
def calc_rmsd_2frames(ref, frame):
    """
    Compute the RMSD between a reference and a frame.

    The coordinates are compared as given: no centering or superposition
    is applied inside this function.

    Args:
        ref: Reference coordinates, indexed as ``ref[atom][axis]`` with
            axis 0, 1 and 2.
        frame: Coordinates to compare against ``ref``, indexed the same
            way. Its length determines how many atoms are visited.

    Returns:
        Square root of the mean squared per-atom distance.
    """
    n_atoms = len(frame)
    squared_dist = np.zeros(n_atoms)
    for idx in range(n_atoms):
        # Per-axis differences for this atom.
        dx = ref[idx][0] - frame[idx][0]
        dy = ref[idx][1] - frame[idx][1]
        dz = ref[idx][2] - frame[idx][2]
        squared_dist[idx] = dx ** 2 + dy ** 2 + dz ** 2

    return np.sqrt(squared_dist.mean())


def calculate_distance_matrix(pdb_md, selection):
    """Build the symmetric pairwise RMSD matrix over all trajectory frames.

    Args:
        pdb_md: Object exposing a sized, frame-indexable ``trajectory`` and a
            ``select_atoms`` method (the caller passes an MDAnalysis
            Universe).
        selection: Selection string handed to ``pdb_md.select_atoms``.

    Returns:
        numpy.ndarray: Square array with one row and column per frame.
        Entry ``[i, j]`` is the value of ``calc_rmsd_2frames`` for frames
        ``i`` and ``j``; the diagonal stays at zero.
    """
    n_frames = len(pdb_md.trajectory)
    distances = np.zeros((n_frames, n_frames))
    # Resolve the selection once instead of re-parsing the selection string
    # for every frame pair; the same atoms are then compared in every frame.
    # NOTE(review): assumes the selection does not depend on coordinates
    # (e.g. no "around" keyword) - confirm if other callers are added.
    atoms = pdb_md.select_atoms(selection)
    for i in tqdm(range(n_frames)):
        # Indexing the trajectory moves the universe to frame i.
        pdb_md.trajectory[i]
        # NOTE(review): relies on ``positions`` handing back a fresh array
        # on each access, so frame_i survives the seeks below - the original
        # code made the same assumption.
        frame_i = atoms.positions
        # Only the upper triangle is computed; results are mirrored.
        for j in range(i + 1, n_frames):
            pdb_md.trajectory[j]
            rmsd = calc_rmsd_2frames(frame_i, atoms.positions)
            distances[i][j] = rmsd
            distances[j][i] = rmsd
    return distances


def calculate_representative_frame(bmode_frames, DM):
"""Calculates the most representative frame for a binding mode. This is based upon the average RMSD of a frame to all other frames in the binding mode.
Args:
Expand All @@ -649,30 +683,24 @@ def calculate_representative_frame(traj, bmode_frame_list, lig):
Returns:
int: Number of the most representative frame.
"""
representative_frame = -1
min_mean_rmsd = 1000.0
for reference_frame in bmode_frame_list:
rmsd_values = []
traj.trajectory[reference_frame]
reference_frame_state = traj.select_atoms(
f"protein or nucleic or resname {lig}"
).positions
for frame in bmode_frame_list:
if frame != reference_frame:
traj.trajectory[frame]
calculation_frame_state = traj.select_atoms(
f"protein or nucleic or resname {lig}"
).positions
calc_rmsd = rms.rmsd(
reference_frame_state,
calculation_frame_state,
center=True,
superposition=True,
)
rmsd_values.append(calc_rmsd)
mean_rmsd = sum(rmsd_values) / len(rmsd_values)
if mean_rmsd < min_mean_rmsd:
min_mean_rmsd = mean_rmsd
representative_frame = reference_frame

return representative_frame
frames = bmode_frames
mean_rmsd_per_frame = {}
# first loop : first frame
for frame_i in frames:
mean_rmsd_per_frame[frame_i] = 0
# accumulate the RMSD between frame_i and every other frame, then
# compute the mean
for frame_j in frames:
# Skip comparing a frame with itself.
if not frame_j == frame_i:
# add the RMSD between frame_i and frame_j to the running total
# stored for frame_i
mean_rmsd_per_frame[frame_i] += DM[frame_i - 1, frame_j - 1]
# mean calculation
mean_rmsd_per_frame[frame_i] /= len(frames)

# Representative frame = frame with the lowest mean RMSD to all other
# frames of the cluster
repre = min(mean_rmsd_per_frame, key=mean_rmsd_per_frame.get)

return repre
Expand Up @@ -311,7 +311,6 @@ def binding_site_markov_network(
verticalalignment="center",
)


# Add the legend to the plot
plt.legend(handles=legend_handles, loc="upper right", fontsize=48)

Expand Down
103 changes: 69 additions & 34 deletions openmmdl/openmmdl_analysis/openmmdlanalysis.py
Expand Up @@ -52,6 +52,7 @@
df_iteration_numbering,
update_values,
calculate_representative_frame,
calculate_distance_matrix,
)
from openmmdl.openmmdl_analysis.markov_state_figure_generation import (
min_transition_calculation,
Expand Down Expand Up @@ -193,39 +194,54 @@ def main():
)
parser.add_argument(
"-rep",
dest = "representative_frame",
help = "Calculate the representative frame for each binding mode. Defaults to False",
default = False,
dest="representative_frame",
help="Calculate the representative frame for each binding mode. Defaults to False",
default=False,
)

parser.add_argument(
"--watereps",
dest="water_eps",
help="Set the Eps for clustering, this defines how big clusters can be spatially in Angstrom",
default=1.0,
)


pdb_md = None
input_formats = [".pdb", ".dcd", ".sdf", ".csv", ".tpr", ".xtc", "trr",]
input_formats = [
".pdb",
".dcd",
".sdf",
".csv",
".tpr",
".xtc",
"trr",
]
args = parser.parse_args()
if input_formats[0] not in args.topology and input_formats[4] not in args.topology:
print("Topology is missing, try the absolute path")
if input_formats[1] not in args.trajectory and input_formats[5] not in args.trajectory and input_formats[6] not in args.trajectory :
if (
input_formats[1] not in args.trajectory
and input_formats[5] not in args.trajectory
and input_formats[6] not in args.trajectory
):
print("Trajectory is missing, try the absolute path")

# set variables for analysis and preprocess input files
topology = args.topology
trajectory = args.trajectory
# enable gromacs support and write topology and trajectory files
if ".tpr" in args.topology and (".xtc" in args.trajectory or ".trr" in args.trajectory):
if ".tpr" in args.topology and (
".xtc" in args.trajectory or ".trr" in args.trajectory
):
print("\033[1mGromacs format detected. Writing compatible file formats.\033[0m")
u = mda.Universe(args.topology, args.trajectory)
with mda.Writer("trajectory.dcd", n_atoms=u.atoms.n_atoms) as W:
first_frame_saved = False
for ts in u.trajectory:
if not first_frame_saved:
with mda.Writer("topology.pdb", n_atoms=u.atoms.n_atoms) as pdb_writer:
with mda.Writer(
"topology.pdb", n_atoms=u.atoms.n_atoms
) as pdb_writer:
pdb_writer.write(u.atoms)
first_frame_saved = True
W.write(u.atoms)
Expand All @@ -234,7 +250,7 @@ def main():
trajectory = "trajectory.dcd"
water_eps = float(args.water_eps)
stable_water_analysis = bool(args.stable_water_analysis)

# The following is the current water analysis if no ligand is present.
if not args.ligand_sdf and args.peptide == None and stable_water_analysis:
print("All analyses will be run which can be done without a ligand present")
Expand Down Expand Up @@ -264,7 +280,7 @@ def main():
special_ligand = args.special_ligand
reference = args.reference
peptide = args.peptide

generate_representative_frame = args.representative_frame

if reference != None:
Expand All @@ -279,8 +295,7 @@ def main():

if not pdb_md:
pdb_md = mda.Universe(topology, trajectory)



# Writing out the complex of the protein and ligand with water around 10A of the ligand
complex = pdb_md.select_atoms(
f"protein or nucleic or resname {ligand} or (resname HOH and around 10 resname {ligand}) or resname {special_ligand}"
Expand Down Expand Up @@ -308,7 +323,9 @@ def main():
lig_rd_ring = lig_rd.GetRingInfo()
except AttributeError:
print("\033[1mCould not get the ring information.\033[0m")
print("\033[1mTry to remove lone pairs prior to running an analysis!\033[0m")
print(
"\033[1mTry to remove lone pairs prior to running an analysis!\033[0m"
)
exit()

# getting the index of the first atom of the ligand from the complex pdb
Expand Down Expand Up @@ -502,9 +519,9 @@ def main():

# Check if the fingerprint has been encountered before
if fingerprint in treshold_fingerprint_dict:
grouped_frames_treshold.at[index, "Binding_fingerprint_treshold"] = (
treshold_fingerprint_dict[fingerprint]
)
grouped_frames_treshold.at[
index, "Binding_fingerprint_treshold"
] = treshold_fingerprint_dict[fingerprint]
else:
# Assign a new label if the fingerprint is new
label = f"Binding_Mode_{label_counter}"
Expand Down Expand Up @@ -621,7 +638,9 @@ def main():
# Generate a dictionary for negative ionizables
pi_dict = generate_interaction_dict("pi", highlighted_pi)
# Generate a dictionary for pication
pication_dict = generate_interaction_dict("pication", highlighted_pication)
pication_dict = generate_interaction_dict(
"pication", highlighted_pication
)
# Generate a dictionary for metal interactions
metal_dict = generate_interaction_dict("metal", highlighted_metal)

Expand Down Expand Up @@ -672,15 +691,19 @@ def main():
f.write(svg)

# Convert the svg to a png
cairosvg.svg2png(url=f"{binding_mode}.svg", write_to=f"{binding_mode}.png")
cairosvg.svg2png(
url=f"{binding_mode}.svg", write_to=f"{binding_mode}.png"
)

# Generate the interactions legend and combine it with the ligand png
merged_image_paths = create_and_merge_images(
binding_mode, occurrence_percent, split_data, merged_image_paths
)

# Create Figure with all Binding modes
arranged_figure_generation(merged_image_paths, "all_binding_modes_arranged.png")
arranged_figure_generation(
merged_image_paths, "all_binding_modes_arranged.png"
)
generate_ligand_image(
ligand, "complex.pdb", "lig_no_h.pdb", "lig.smi", "ligand_numbering.svg"
)
Expand All @@ -706,24 +729,32 @@ def main():
"Percentage Occurrence": [],
}
if generate_representative_frame:
DM = calculate_distance_matrix(
pdb_md,
f"protein or nucleic or resname {ligand} or resname {special_ligand}",
)
modes_to_process = top_10_binding_modes.index
for mode in tqdm(modes_to_process):
result_dict["Binding Mode"].append(mode)
first_frame = grouped_frames_treshold.loc[
grouped_frames_treshold["Binding_fingerprint_treshold"].str.contains(mode),
grouped_frames_treshold["Binding_fingerprint_treshold"].str.contains(
mode
),
"FRAME",
].iloc[0]
all_frames = grouped_frames_treshold.loc[
grouped_frames_treshold["Binding_fingerprint_treshold"].str.contains(mode),
grouped_frames_treshold["Binding_fingerprint_treshold"].str.contains(
mode
),
"FRAME",
].tolist()
percent_occurrence = (top_10_binding_modes[mode] / total_binding_modes) * 100
percent_occurrence = (
top_10_binding_modes[mode] / total_binding_modes
) * 100
result_dict["First Frame"].append(first_frame)
result_dict["All Frames"].append(all_frames)
result_dict["Percentage Occurrence"].append(percent_occurrence)
representative_frame = calculate_representative_frame(
pdb_md, all_frames, ligand
)
representative_frame = calculate_representative_frame(all_frames, DM)
result_dict["Representative Frame"].append(representative_frame)
top_10_binding_modes_df = pd.DataFrame(result_dict)
top_10_binding_modes_df.to_csv("top_10_binding_modes.csv")
Expand Down Expand Up @@ -761,19 +792,25 @@ def main():
# Extract the string representation of the tuple
tuple_string = row2["LIGCOO"]
# Split the string into individual values using a comma as the delimiter
ligcoo_values = tuple_string.strip("()").split(",")
ligcoo_values = tuple_string.strip("()").split(
","
)
# Convert the string values to float
ligcoo_values = [
float(value.strip()) for value in ligcoo_values
float(value.strip())
for value in ligcoo_values
]

# Extract the string representation of the tuple for PROTCOO
tuple_string = row2["PROTCOO"]
# Split the string into individual values using a comma as the delimiter
protcoo_values = tuple_string.strip("()").split(",")
protcoo_values = tuple_string.strip("()").split(
","
)
# Convert the string values to float
protcoo_values = [
float(value.strip()) for value in protcoo_values
float(value.strip())
for value in protcoo_values
]

bindingmode_dict[column]["LIGCOO"].append(
Expand Down Expand Up @@ -818,9 +855,7 @@ def main():
}

for interaction_type, interaction_data in interaction_types.items():
plot_barcodes_grouped(
interaction_data, df_all, interaction_type
)
plot_barcodes_grouped(interaction_data, df_all, interaction_type)

plot_waterbridge_piechart(df_all, waterbridge_barcodes, waterbridge_interactions)
print("\033[1mBarcodes generated\033[0m")
Expand Down

0 comments on commit 7b17ee3

Please sign in to comment.