In [None]:
from collections import Counter

from sklearn.decomposition import PCA
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.constants import k, Avogadro

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import mdtraj as md

from EBC import EBC

In [None]:
# Global matplotlib styling: LaTeX-rendered text in a serif face, with a
# large base font size (the figures drawn by this notebook are 20+ inches wide).
plt.rc('text', usetex=True)
plt.rc('font', family='serif', size=48)

# Hex colours used to paint graph nodes; indexed by integer cluster label.
COLORMAP = np.array([
    '#C40E63',
    '#001180',
    '#F5B709',
    '#0EA6C2',
    '#DB611A',
    '#C70202',
    '#2EBD8E',
    '#007828',
])

In [None]:
def plot_graph_and_clusters(ebc, tau, n=8, overlap_cutoff=7, save_pdb_frames=False, save_name_cluster=None,
                            save_name_graph=None, trajectory=None, rmsds=None):
    """Plot the overlap graph of top-``n`` flow targets and report one state per cluster.

    Steps:

    1. ``ebc.diffusion_matrix`` is raised to the power ``int(tau)``. For every
       row, the column indices of its ``n`` largest entries form a sorted
       tuple (the row's "target set").
    2. Each distinct target set becomes a graph node; two nodes are joined when
       they share at least ``overlap_cutoff`` indices. Only the largest
       connected component is kept.
    3. Figure 0 colours nodes by the mean of ``ebc.pi`` over the node's indices
       (times 1e3) and adds a colour bar. Figure 1 colours nodes by the most
       common ``ebc.proto_labels`` value among the node's indices (through the
       module-level ``COLORMAP``). Figure 2 repeats figure 1 with the cluster
       id drawn on every node and is never saved.
    4. For every cluster id, the row with the largest ``ebc.pi`` among all rows
       whose target set is a node of that cluster is passed to ``ebc.select``;
       the result is used to index ``trajectory`` / ``rmsds`` (presumably a
       frame index -- confirm against the EBC implementation). One line per
       cluster is printed: RMSD, cluster id, log of the population, index.

    Parameters
    ----------
    ebc : object
        Fitted model exposing ``diffusion_matrix``, ``pi``, ``proto_labels``
        and ``select``.
    tau : int
        Exponent applied to the diffusion matrix (cast with ``int``); also used
        in the PDB file names.
    n : int, default 8
        Number of highest-flow targets kept per row.
    overlap_cutoff : int, default 7
        Minimum number of shared targets for two nodes to be connected.
    save_pdb_frames : bool, default False
        If true, write ``frame_{cluster_id}_{tau}.pdb`` for every cluster.
    save_name_cluster : str or None
        Output path for figure 1; nothing is written when ``None``.
    save_name_graph : str or None
        Output path for figure 0; nothing is written when ``None``.
    trajectory : optional
        Indexable collection of frames offering ``save_pdb``. Defaults to the
        notebook-level variable ``trajectory`` when it exists.
    rmsds : optional
        Per-frame RMSD values. Defaults to the notebook-level variable
        ``rmsds`` when it exists; ``nan`` is printed when neither is available.

    Raises
    ------
    ValueError
        If ``save_pdb_frames`` is true but no trajectory is available.
    """
    # Fall back to the notebook-level variables so that existing calls keep
    # working, but do not crash with NameError when they were never defined
    # (e.g. only the pre-computed PCA coordinates were loaded).
    if trajectory is None:
        trajectory = globals().get('trajectory')
    if rmsds is None:
        rmsds = globals().get('rmsds')
    if save_pdb_frames and trajectory is None:
        # Fail before the expensive plotting rather than after it.
        raise ValueError('save_pdb_frames=True requires a trajectory; pass `trajectory=` or load it first')

    flow = np.linalg.matrix_power(ebc.diffusion_matrix, int(tau))
    # argsort of the negated matrix puts the largest entries first.
    target_sets = [tuple(sorted(row)) for row in np.argsort(-flow, axis=1)[:, :n]]

    # Map every distinct target set to the rows that produced it. Dicts keep
    # insertion order, so the keys double as the first-seen ordered node list.
    node_dict = {}
    for row_id, key in enumerate(target_sets):
        node_dict.setdefault(key, []).append(row_id)
    nodes = list(node_dict)

    # Connect every pair of nodes that shares enough targets.
    node_sets = [set(node) for node in nodes]
    edges = []
    for ida, set_a in enumerate(node_sets):
        for idb in range(ida + 1, len(nodes)):
            if len(set_a & node_sets[idb]) >= overlap_cutoff:
                edges.append((nodes[ida], nodes[idb]))

    g = nx.Graph()
    g.add_nodes_from(nodes)
    g.add_edges_from(edges)
    # Keep only the largest connected component (first one wins on ties).
    g = g.subgraph(max(nx.connected_components(g), key=len))
    node_color = np.array([np.mean(ebc.pi[list(x)]) for x in g.nodes]) * 1e3

    # Figure 0: nodes coloured by mean population.
    plt.figure(0, figsize=(24, 20), dpi=120)
    ax = plt.gca()
    pos = nx.kamada_kawai_layout(g)
    nx.draw_networkx_edges(g, pos=pos, ax=ax, width=1, alpha=1, arrows=False, edge_color='black')
    pc = nx.draw_networkx_nodes(g, pos=pos, ax=ax, node_color=node_color, cmap=plt.cm.inferno)
    cbar = plt.colorbar(pc)
    cbar.set_label('Population', labelpad=30)
    cbar.outline.set_visible(False)
    plt.box(False)
    ax.grid(False)
    plt.axis('off')
    if save_name_graph is not None:
        plt.savefig(save_name_graph, bbox_inches='tight')
    plt.show()

    # Assign each node the majority cluster label of its targets and collect,
    # per cluster, the rows belonging to its nodes. Plain lists are used so
    # that groups of different sizes never have to be packed into one array.
    node_color_ids, cluster_rows = [], {}
    for key in g.nodes:
        cluster_id = Counter(ebc.proto_labels[list(key)]).most_common(1)[0][0]
        node_color_ids.append(cluster_id)
        cluster_rows.setdefault(cluster_id, []).extend(node_dict[key])
    node_color_ids = np.asarray(node_color_ids)
    node_color = COLORMAP[node_color_ids]
    labeldict = dict(zip(g.nodes, node_color_ids))

    # Figure 1: nodes coloured by cluster.
    plt.figure(1, figsize=(20, 20), dpi=120)
    ax = plt.gca()
    nx.draw_networkx_edges(g, pos=pos, ax=ax, width=1, alpha=1, arrows=False, edge_color='black')
    nx.draw_networkx_nodes(g, pos=pos, ax=ax, node_color=node_color)
    plt.box(False)
    ax.grid(False)
    plt.axis('off')
    if save_name_cluster is not None:
        plt.savefig(save_name_cluster, bbox_inches='tight')
    plt.show()

    # Figure 2: same as figure 1, with the cluster id printed on each node.
    plt.figure(2, figsize=(20, 20), dpi=120)
    ax = plt.gca()
    nx.draw_networkx_edges(g, pos=pos, ax=ax, width=1, alpha=1, arrows=False, edge_color='black')
    nx.draw_networkx_nodes(g, pos=pos, ax=ax, node_color=node_color)
    nx.draw_networkx_labels(g, pos, labeldict)
    plt.show()

    # Report the most populated row of every cluster.
    for cluster_id, rows in cluster_rows.items():
        state_ids = np.asarray(rows)
        cluster_pop = ebc.pi[state_ids]
        max_pop_id = np.argmax(cluster_pop)
        max_pop = cluster_pop[max_pop_id]
        state_id = state_ids[max_pop_id]
        traj_id = ebc.select(state_id)
        if save_pdb_frames:
            trajectory[traj_id].save_pdb(f'frame_{cluster_id}_{tau}.pdb')
        rmsd_value = rmsds[traj_id] if rmsds is not None else float('nan')
        print(rmsd_value, cluster_id, np.log(max_pop), traj_id)

Use the following cell if you have access to the full trajectory file (it is not included with the notebook because of its size):

In [None]:
# Keep every STRIDE-th sample of both the energies and the trajectory frames.
STRIDE = 1

# NOTE(review): allow_pickle=True lets np.load execute arbitrary code from an
# untrusted file. The alternative loading cell reads this same file without
# it, so the flag is probably unnecessary -- confirm and drop.
potentials = np.load('data/TRP_CAGE_energies.npy', allow_pickle=True)[::STRIDE]
trajectory = md.load('data/TRP_CAGE_trajectory.dcd', top='data/TRP_CAGE.pdb', stride=STRIDE)  # not included due to size
ref = md.load('data/TRP_CAGE.pdb')

# Atom selections on the reference topology.
# NOTE(review): heavy_indices is not used anywhere in the visible notebook.
heavy_indices = [index for index, atom in enumerate(ref.topology.atoms) if atom.element.symbol != 'H']
backbone_indices = ref.topology.select('backbone')

# Record the atom count of the full system *before* slicing to the backbone.
n_atoms = trajectory.n_atoms

# Restrict both structures to the backbone, align the trajectory on the
# reference and measure the per-frame RMSD to it.
trajectory = trajectory.atom_slice(backbone_indices)
ref = ref.atom_slice(backbone_indices)
trajectory = trajectory.superpose(ref)
rmsds = md.rmsd(trajectory, ref)

# One flat feature vector per frame. `coords` stays the full backbone
# coordinates; the PCA projection below is only written to disk.
coords = trajectory.xyz.reshape(len(trajectory.xyz), -1)

pca = PCA(n_components=3)
pca_coords = pca.fit_transform(coords)
np.save('data/TRP_CAGE_pca_coords.npy', pca_coords)
ref.save_pdb('ref.pdb')

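Otherwise, a low-dimensional version of the trajectory (its 3-component PCA projection, stored in `data/TRP_CAGE_pca_coords.npy`) can be loaded with the following cell: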

In [None]:
# Energies and the pre-computed PCA coordinates (the file written by the
# trajectory-loading cell).
potentials = np.load('data/TRP_CAGE_energies.npy')
coords = np.load('data/TRP_CAGE_pca_coords.npy')

# The reference structure supplies the atom count used to rescale the energies.
# NOTE: this cell defines neither `trajectory` nor `rmsds`, which
# plot_graph_and_clusters looks up at notebook level.
ref = md.load('data/TRP_CAGE.pdb')
n_atoms = ref.n_atoms

In [None]:
# Rescale the energies IN PLACE. Because `potentials` is overwritten, this
# cell is not idempotent: re-run a loading cell before running it again.
potentials *= 1000               # presumably kJ -> J; confirm the units of the energies file
potentials -= potentials.min()   # shift so that the lowest energy is zero
potentials /= n_atoms            # energy per atom
potentials /= k * Avogadro       # divide by k * N_A, i.e. the molar gas constant R

In [None]:
# Build the EBC model with the hyper-parameters used throughout this notebook
# and fit it on the coordinates and the rescaled energies.
ebc = EBC(
    temperature=100,
    proto_radius=0.5,
    pca_components=3,
    n_clusters=7,
    knn=16,
)
ebc.fit(coords, potentials)

In [None]:
# Export the graph figures and one PDB frame per cluster for each diffusion
# time. Each tau appears once: a repeated value only recomputes the same
# figures and overwrites the same files.
# Saving PDB frames requires the trajectory-loading cell to have been run.
for tau in [50, 60, 75, 100, 120, 125]:
    plot_graph_and_clusters(ebc, tau=tau, save_pdb_frames=True, save_name_cluster=f'TAU{tau}.pdf', save_name_graph=f'TAU{tau}_graph.pdf')

In [None]:
# Display-only pass for two diffusion times: no figure paths are given and
# PDB export is switched off, so nothing is written to disk.
for tau in (60, 120):
    plot_graph_and_clusters(
        ebc,
        tau=tau,
        save_pdb_frames=False,
    )