In [2]:
import plotly.express as px
import numpy as np
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import plotly.express as px
from rdkit.Chem import AllChem
import pdb
from HEML.utils.xyz2mol import xyz2mol, xyz2AC_vdW, AC2mol, xyz2AC_huckel

# Element symbol -> atomic number.
# 'FE' duplicates 'Fe', presumably so that upper-case iron symbols (as they may
# appear in PDB records) also resolve -- TODO confirm.
atom_int_dict = dict(
    H=1, C=6, N=7, O=8, F=9, P=15, S=16, Cl=17, Br=35,
    Fe=26, FE=26,
    I=53,
)

# Atomic number -> element symbol: the inverse of atom_int_dict, skipping the
# upper-case 'FE' alias so that 26 maps back to the canonical 'Fe'.
int_atom_dict = {
    number: symbol
    for symbol, number in atom_int_dict.items()
    if symbol != "FE"
}

# Per-element marker radius; multiplied by a scale factor when atoms are drawn.
# H is deliberately small (0.5), presumably to keep hydrogens from cluttering
# the plot -- TODO confirm.
atomic_size = dict(
    H=0.5, C=1.7, N=1.55, O=1.52, F=1.47, P=1.80, S=1.80,
    Cl=1.75, Br=1.85, Fe=1.80, I=1.98,
)

# Per-element plotly marker colour (named CSS colours).
atom_colors = dict(
    H='white', C='black', N='blue', O='red', F='orange', P='green',
    S='yellow', Cl='green', Br='brown', Fe='orange', I='purple',
)


In [3]:
from HEML.utils.visualization import  shift_and_rotate
from HEML.utils.data import  get_nodes_and_edges_from_pdb

def plot_nodes_edge(file = "../../data/pdbs_processed/1a4e.pdb", distance_filter = 8.0):
    """Build plotly 3D traces (bonds + atoms) for a PDB structure.

    The structure is read with ``get_nodes_and_edges_from_pdb`` and moved with
    ``shift_and_rotate`` into a frame centred on the mean of four hard-coded
    N positions, with axes pointing from a hard-coded Fe position towards NA
    (x) and NB (y).

    Parameters
    ----------
    file : str
        Path of the PDB file to load. Previously this argument was ignored
        and the 1a4e path was always used.
    distance_filter : float
        Forwarded to ``get_nodes_and_edges_from_pdb``. The default of 8.0
        matches the value that was previously hard-coded.

    Returns
    -------
    tuple
        ``(trace_edges, trace_nodes)``: a ``go.Scatter3d`` drawing each bond as
        a separate line segment, and a ``go.Scatter3d`` drawing atoms as
        markers coloured/sized by element.

    Raises
    ------
    KeyError
        If an atom's element number is missing from ``int_atom_dict``,
        ``atom_colors`` or ``atomic_size``.
    """
    G = nx.Graph()
    atom_list, bond_list, xyz_list = get_nodes_and_edges_from_pdb(
        file, distance_filter=distance_filter
    )

    # NOTE(review): these reference coordinates are hard-coded and appear to
    # belong to the default structure (1a4e). They are NOT re-derived from
    # `file`, so any other structure is placed in this same frame -- TODO
    # derive them from the loaded structure if other files are plotted.
    NA_pos = [129.775,  39.761,  38.051]
    NB_pos = [130.581,  41.865,  36.409]
    NC_pos = [131.320,  43.348,  38.639]
    ND_pos = [130.469,  41.267,  40.273]
    Fe_pos = [130.581,  41.541,  38.350]

    center = np.mean([NA_pos, NB_pos, NC_pos, ND_pos], axis=0)
    x_axis = np.array(NA_pos) - np.array(Fe_pos)
    x_axis = x_axis / np.linalg.norm(x_axis)
    y_axis = np.array(NB_pos) - np.array(Fe_pos)
    y_axis = y_axis / np.linalg.norm(y_axis)
    # cross(y, x) rather than cross(x, y): kept unchanged because the frame
    # presumably has to match the convention the vector fields were computed
    # in -- TODO confirm.
    z_axis = np.cross(y_axis, x_axis)
    z_axis = z_axis / np.linalg.norm(z_axis)

    xyz_list = shift_and_rotate(
        xyz_list,
        center=center,
        x_axis=x_axis,
        y_axis=y_axis,
        z_axis=z_axis,
    )

    # Node i is the i-th atom; it stores its transformed position and element number.
    for i, atom in enumerate(atom_list):
        G.add_node(i, xyz=xyz_list[i], atom=atom)

    for bond in bond_list:
        G.add_edge(bond[0], bond[1])

    # A None between segments breaks the plotly polyline, so each bond is drawn
    # as its own segment inside a single trace.
    edge_x, edge_y, edge_z = [], [], []
    for start, end in G.edges():
        x0, y0, z0 = G.nodes[start]['xyz']
        x1, y1, z1 = G.nodes[end]['xyz']
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

    node_x, node_y, node_z = [], [], []
    for node in G.nodes():
        x, y, z = G.nodes[node]['xyz']
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)

    # Resolve element symbols once and reuse them for colour and size.
    symbols = [int_atom_dict[G.nodes[n]["atom"]] for n in G.nodes]
    scalar = 10  # multiplier from atomic_size to plotly marker size
    color = [atom_colors[s] for s in symbols]
    size = [scalar * atomic_size[s] for s in symbols]

    trace_nodes = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode="markers",
        text=[int_atom_dict[i] for i in atom_list],
        marker=dict(
            symbol='circle',
            size=size,
            color=color,
            opacity=0.8,
        ),
    )

    trace_edges = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        line=dict(width=1, color="#000000"),
        hoverinfo='none',
        mode='lines',
    )

    return trace_edges, trace_nodes

# Bond and atom traces for the default 1a4e structure, reused by every figure below.
trace_edges, trace_nodes = plot_nodes_edge(file="../../data/pdbs_processed/1a4e.pdb")


In [4]:
from HEML.utils.data import *
from HEML.utils.attrib import *
from HEML.utils.model import *

def get_cones_viz_from_pca(vector_scale = 3, components = 10, data_file = "../../data/protein_data.csv", dir_fields = "../../data/cpet/"): 
    """Build one plotly Cone trace per PCA component of the field dataset.

    Fields are loaded with ``pull_mats_w_label``, passed through a signed log
    transform ``sign(x) * log1p(|x|)``, and decomposed with ``pca``. Each row of
    ``pca_obj.components_`` is reshaped to the per-sample field shape, filtered
    with ``split_and_filter``, and drawn as cones on a fixed 3D grid.

    Parameters
    ----------
    vector_scale : float
        Passed to ``go.Cone`` as ``sizeref``.
    components : int
        Number of PCA components to compute (``pca_comps``).
    data_file, dir_fields : str
        Forwarded to ``pull_mats_w_label`` as ``dir_data`` / ``dir_fields``.

    Returns
    -------
    list
        One ``go.Cone`` per PCA component, in component order.
    """
    cones = []

    fields, _ = pull_mats_w_label(dir_data=data_file, dir_fields=dir_fields)
    # Signed log transform: compresses large magnitudes while preserving sign.
    fields = np.sign(fields) * np.log1p(np.abs(fields))

    _, pca_obj = pca(fields, verbose=True, pca_comps=components, write=False)
    # Per-sample field shape. The reshape below relies on `fields` being 5-D
    # (samples first) -- TODO confirm the meaning of the trailing axes.
    field_shape = fields.shape[1:5]

    # The plotting grid is the same for every component, so build it once.
    # NOTE(review): np.meshgrid defaults to indexing="xy", which swaps the
    # first two axes relative to "ij". Confirm this matches how
    # split_and_filter flattens the field; otherwise cones are placed at
    # transposed positions.
    axis_pts = np.arange(-3, 3.3, 0.3)
    grid_x, grid_y, grid_z = np.meshgrid(axis_pts, axis_pts, axis_pts)
    grid_x, grid_y, grid_z = grid_x.flatten(), grid_y.flatten(), grid_z.flatten()

    for pca_comp in pca_obj.components_:
        comp_vect_field = pca_comp.reshape(field_shape)

        u_1, v_1, w_1 = split_and_filter(
            comp_vect_field,
            cutoff=95,
            std_mean=True,
            min_max=False,
        )

        cones.append(go.Cone(
            x=grid_x,
            y=grid_y,
            z=grid_z,
            u=u_1,
            v=v_1,
            w=w_1,
            sizeref=vector_scale,
            opacity=0.99,
        ))

    return cones
        
# One cone trace per PCA component; indexed by component number in the figure cells below.
vector_field_pca = get_cones_viz_from_pca(
    vector_scale=5,
    components=10,
    data_file="../../data/protein_data.csv",
    dir_fields="../../data/cpet/",
)

21 71 95
[0.3168702  0.26424932 0.19303652 0.01938781 0.01341084 0.01035656
 0.00911447 0.00854392 0.00689237 0.00609141]


In [5]:
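# Most important PCA components from the last feature-importance run (presumably Boruta):
# component index, then importance +/- its uncertainty. Each one is plotted in the cells below.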
#3       0.089 +/- 0.019
#7       0.062 +/- 0.016
#6       0.049 +/- 0.016
#8       0.031 +/- 0.012
#4       0.028 +/- 0.013
#1       0.023 +/- 0.009

In [6]:
# Overlay the transformed structure with the cone field of PCA component 3.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[3]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()
fig.write_html("test_important_boruta.html")

In [7]:
# Overlay the transformed structure with the cone field of PCA component 7.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
# NOTE(review): component 7 is listed as the second most important component
# in the importance notes above, yet this export is named
# "test_not_important_boruta.html" -- confirm which component was intended.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[7]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()
fig.write_html("test_not_important_boruta.html")

In [8]:
# Overlay the transformed structure with the cone field of PCA component 6.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[6]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()

In [9]:
# Overlay the transformed structure with the cone field of PCA component 8.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[8]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()

In [10]:
# Overlay the transformed structure with the cone field of PCA component 4.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[4]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()

In [11]:
# Overlay the transformed structure with the cone field of PCA component 1.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[1]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()

In [12]:
# Overlay the transformed structure with the cone field of PCA component 0.
# All traces are 3D (Scatter3d / Cone), so axis styling and ranges go on
# layout.scene: the 2D layout.xaxis / layout.yaxis (and xaxis_range /
# yaxis_range) used previously have no effect on a 3D scene. The title size is
# set via title.font, because layout.titlefont is deprecated in plotly.
fig = go.Figure(
    data=[trace_edges, trace_nodes, vector_field_pca[0]],
    layout=go.Layout(
        title=dict(text='<br>Network graph made with Python', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        scene=dict(
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-5, 5]),
        ),
    ),
)
fig.show()