In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from tqdm import tqdm
import pathlib
import math
import sklearn
import torch_optimizer as optim
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
from metrics import *
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj
import networkx as nx

import biographs as bg
from Bio import SeqIO
from Bio.PDB.PDBParser import PDBParser

In [None]:

# Pick the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


# Get the embeddings

In [None]:
from bio_embeddings.embed import SeqVecEmbedder

workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test'
# Dictionary mapping three-letter PDB residue names to one-letter symbols.
# Histidine is written differently depending on its protonation state (HIS in
# standard PDB files; HID/HIE/HIP in Amber-prepared files), and CYX is Amber's
# disulfide-bonded cysteine. All of them map to the same letter so those residues
# are not silently dropped from the sequence (which would make the embedding
# shorter than the residue graph).
ressymbl = {'ALA': 'A', 'CYS': 'C', 'CYX': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G',
            'HIS': 'H', 'HID': 'H', 'HIE': 'H', 'HIP': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
            'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T',
            'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'}

# get structure from a pdb file
# Uses biopython
def get_structure(file):
    """Parse a PDB file with Biopython and return its ``Structure`` object.

    The structure id is the file's base name without its extension. The builtin
    ``id`` function object was passed here by mistake before.
    """
    parser = PDBParser()
    structure_id = os.path.splitext(os.path.basename(file))[0]
    return parser.get_structure(structure_id, file)

# Function to get sequence from pdb structure
# Uses structure made using biopython
def get_sequence(structure):
    """Return the one-letter sequence of every residue in ``structure``.

    Walks all models, chains and residues in order. A residue whose
    three-letter name is not a key of the module-level ``ressymbl`` table
    (e.g. water, ligands, ions, unknown residues) is skipped; it is NOT
    converted to 'A'.
    """
    return "".join(
        ressymbl[residue.get_resname()]
        for model in structure
        for chain in model
        for residue in chain
        if residue.get_resname() in ressymbl
    )

# Define the protein sequence
# seq = 'MKVTKVGGISHKKYTSEGRLVKSESEENRTDERLSALLNMRLDMY' 
# seq = get_sequence(get_structure(os.path.join(workdir, "R963A/raw/frame_1480.pdb")))
# A single reference frame supplies the sequence. Its embedding (protein_embd) is
# reused later by ProteinDataset.process as the node features of EVERY graph.
seq = get_sequence(get_structure(os.path.join(workdir, "WT/raw/frame_1480.pdb")))
# Create an instance of the SeqVecEmbedder
embedder = SeqVecEmbedder()

# Obtain the embedding of the protein sequence
# NOTE(review): presumably shaped (n_layers, len(seq), 1024) for SeqVec — confirm.
embedding = embedder.embed(seq)

# Sum the embedding over its first axis (dim=0) and convert the result to a
# NumPy array; the printed shape is checked against the graph size later.
protein_embd = torch.tensor(embedding).sum(dim=0).cpu().numpy()
print(protein_embd.shape)

In [None]:
import MDAnalysis as mda
import glob

# Strip every PDB in `workdir` down to its protein atoms and write the result to
# `<workdir>/raw/<same file name>`, the layout ProteinDataset reads from.
# workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/R963A_cluster'
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/WT_cluster'

# Sorted so the processing order does not depend on the directory listing.
pdb_files = sorted(glob.glob(os.path.join(workdir, '*.pdb')))
print(pdb_files)
if pdb_files:
    print(os.path.basename(pdb_files[0]))
else:
    print(f"No .pdb files found in {workdir}")

# Make sure the output directory exists before writing into it.
raw_dir = os.path.join(workdir, "raw")
os.makedirs(raw_dir, exist_ok=True)

for pdb in pdb_files:
    u = mda.Universe(pdb)
    name = os.path.basename(pdb)
    pdb_filename = os.path.join(raw_dir, name)

    ag = u.select_atoms("protein")

    # Write the protein atoms of the current frame to the PDB file. The writer has
    # its own name so it no longer shadows the loop variable `pdb`.
    with mda.Writer(pdb_filename, bonds=None, n_atoms=ag.atoms.n_atoms) as writer:
        writer.write(ag.atoms)

    print(f"Saved PDB file: {pdb_filename}")



# Get the protein graphs

In [None]:
from torch_geometric.data import Dataset, Dataset, download_url, Data,  Batch

# One-letter codes of the 20 standard amino acids (column order of the one-hot encoding)
pro_res_table = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

# Dictionary mapping three-letter PDB residue names to one-letter symbols.
# Histidine appears as HIS (standard PDB) or HID/HIE/HIP (Amber protonation
# states), and CYX is Amber's disulfide-bonded cysteine. All map to the same letter
# so these residues are not dropped from sequences.
ressymbl = {'ALA': 'A', 'CYS': 'C', 'CYX': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G',
            'HIS': 'H', 'HID': 'H', 'HIE': 'H', 'HIP': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
            'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S', 'THR': 'T',
            'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'}


class ProteinDataset(Dataset):
    """PyTorch Geometric dataset of residue-contact graphs built from PDB files.

    Layout: ``<root>/raw/*.pdb`` is read; ``<root>/processed/<stem>.pt`` is written.
    Each processed file holds ``Data(x=node_feats, edge_index=edge_index)`` where
    ``edge_index`` comes from the biographs residue network of that PDB file.

    NOTE(review): node features are not computed per file. Every graph reuses the
    notebook-global ``protein_embd`` (one embedding of a reference frame), so
    ``protein_embd`` must be defined before processing runs.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        # process() fills this in. If processing is skipped because every processed
        # file already exists, the graphs are loaded from disk below instead.
        self.data_prot = None
        # Forward the caller's transforms (they used to be dropped).
        super(ProteinDataset, self).__init__(root, transform=transform,
                                             pre_transform=pre_transform)
        self.data = self.processed_paths
        if self.data_prot is None:
            self.data_prot = [torch.load(path) for path in self.processed_paths
                              if os.path.exists(path)]

    @property
    def raw_file_names(self):
        """Names (not paths) of the regular files in ``<root>/raw``, sorted.

        PyG joins these names onto ``raw_dir`` to build ``raw_paths``, so they
        must be bare names. Returning paths/DirEntry objects doubled the prefix.
        """
        return sorted(entry.name for entry in os.scandir(self.raw_dir)
                      if entry.is_file())

    @property
    def processed_file_names(self):
        """One ``<stem>.pt`` per ``.pdb`` file in ``raw_dir``."""
        return [os.path.splitext(name)[0] + '.pt' for name in self.raw_file_names
                if pathlib.Path(name).suffix == '.pdb']

    def download(self):
        # Raw PDB files are provided locally; nothing to download.
        pass

    def process(self):
        """Turn every raw ``.pdb`` file into a graph and save it as ``<stem>.pt``."""
        data_list = []
        for file in tqdm(self.raw_paths):
            if pathlib.Path(file).suffix != ".pdb":
                continue
            data = self._process_file(file)
            if data is None:
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            stem = os.path.splitext(os.path.basename(file))[0]
            torch.save(data, os.path.join(self.processed_dir, stem + '.pt'))
            data_list.append(data)

        self.data_prot = data_list
        print(len(data_list))

    def _process_file(self, file):
        """Build the ``Data`` graph for one PDB file, or return None to skip it."""
        # Parsing only validates the file: unparseable PDBs are skipped. The
        # node features below do not depend on the parsed structure.
        try:
            self._get_structure(file)
        except Exception as exc:
            print('except', file, exc)
            return None

        # Same embedding for every frame (see the class docstring).
        node_feats = torch.tensor(protein_embd)
        mat = self._get_adjacency(file)

        n_graph, n_feat = mat.shape[0], node_feats.size(0)
        if n_graph > n_feat:
            # Edge indices would point past the last row of node features.
            print(f'skip {file}: adjacency has {n_graph} residues, '
                  f'node features have {n_feat}')
            return None

        edge_index = self._get_edgeindex(file, mat)
        print(f'Node features size :{node_feats.size()}')
        print(f'mat size :{mat.shape}')
        return Data(x=node_feats, edge_index=edge_index)

    def len(self):
        """Number of graphs available (processed in this session or loaded from disk)."""
        return len(self.data_prot)

    def get(self, idx):
        """Graph ``idx``; PyG's ``__getitem__`` applies ``transform`` on top of this."""
        return self.data_prot[idx]

    # Biographs returns the network as a networkx.Graph object
    def _get_adjacency(self, file):
        """Dense adjacency matrix of the biographs residue network of ``file``."""
        molecule = bg.Pmolecule(file)
        # molecule.network(cutoff=8) raises the distance cut-off (the original
        # note says the default is 5 A).
        network = molecule.network()
        mat = nx.adjacency_matrix(network)
        return mat.todense()

    # get adjacency matrix in coo format to pass in GCNN model
    def _get_edgeindex(self, file, adjacency_mat=None):
        """COO ``edge_index`` (LongTensor of shape [2, E]) for the GCN.

        Uses ``adjacency_mat`` when given. It was previously ignored and the
        adjacency recomputed from ``file``; ``file`` is now parsed only when no
        matrix is passed. Every positive entry (i, j) becomes the edge i -> j,
        so a symmetric matrix yields both directions.
        """
        m = self._get_adjacency(file) if adjacency_mat is None else adjacency_mat
        rows, cols = np.nonzero(np.asarray(m) > 0)
        return torch.tensor(np.array([rows, cols]), dtype=torch.long)

    # get structure from a pdb file
    # Uses biopython
    def _get_structure(self, file):
        """Parse ``file`` with Biopython; the structure id is the file's stem."""
        structure_id = os.path.splitext(os.path.basename(file))[0]
        return PDBParser().get_structure(structure_id, file)

    def _get_sequence(self, structure):
        """One-letter sequence of ``structure``.

        Residues whose name is not a key of the global ``ressymbl`` are skipped,
        not converted.
        """
        return "".join(
            ressymbl[residue.get_resname()]
            for model in structure
            for chain in model
            for residue in chain
            if residue.get_resname() in ressymbl
        )

    # One hot encoding for symbols
    def _get_one_hot_symbftrs(self, sequence):
        """One-hot encode ``sequence`` over ``pro_res_table``.

        Returns a float tensor of shape [len(sequence), len(pro_res_table)].
        """
        one_hot_symb = np.zeros((len(sequence), len(pro_res_table)))
        for row, res in enumerate(sequence):
            one_hot_symb[row][pro_res_table.index(res)] = 1
        return torch.tensor(one_hot_symb, dtype=torch.float)

    # Residue features calculated from pcp_dict
    def _get_res_ftrs(self, sequence):
        """Stack ``pcp_dict[res]`` for each residue into a float tensor.

        NOTE(review): ``pcp_dict`` is not defined in this notebook (perhaps it
        comes from ``from metrics import *``). Confirm before using this method.
        """
        res_ftrs_out = np.array([pcp_dict[res] for res in sequence])
        return torch.tensor(res_ftrs_out, dtype=torch.float)

    # total features after concatenating one_hot_symbftrs and res_ftrs
    def _get_node_ftrs(self, sequence):
        """Concatenate one-hot and physico-chemical residue features column-wise."""
        # These helpers are methods; the bare names used before raised NameError.
        one_hot_symb = self._get_one_hot_symbftrs(sequence)
        res_ftrs_out = self._get_res_ftrs(sequence)
        return torch.cat((one_hot_symb, res_ftrs_out), dim=1)


In [None]:
# Build the contact-graph dataset of the WT cluster frames
# (use "R963A_cluster/" as the root for the mutant instead).
# Node features of every graph come from the global `protein_embd` computed earlier.
prot_graphs = ProteinDataset("WT_cluster/")

1. Node feature shape: torch.Size([11530, 1024])
The node features for all nodes in the batch (10 graphs × 1153 nodes), each node having 1024 features. The first dimension (11530) is the total number of nodes across all graphs in the batch.
2. Batch tensor: tensor([0, 0, 0,  ..., 9, 9, 9])
A 1-dimensional tensor in which each element gives the index of the graph that the corresponding node belongs to.
3. Edge index shape: torch.Size([2, 145596])
The edge indices for all edges in the batch. The first dimension (2) holds the source and target node index of each edge; the second dimension (145596) is the total number of edges in the batch.

# Define the model

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Graph classifier: two GCN layers, mean pooling, then a three-layer MLP head.

    Pipeline: GCNConv(F -> F) -> ReLU -> GCNConv(F -> hidden) -> ReLU ->
    mean-pool per graph -> Linear(hidden -> 128) -> LeakyReLU -> Dropout(0.2) ->
    Linear(128 -> 64) -> LeakyReLU -> Dropout(0.2) -> Linear(64 -> num_classes).

    The attribute names (conv1, conv2, fc1, fc2, out) are the state_dict keys of
    saved checkpoints, so they must not be renamed.
    """

    def __init__(self, num_features_pro, hidden_channels, num_classes):
        super().__init__()
        # Message-passing layers; the first one keeps the input width.
        self.conv1 = GCNConv(num_features_pro, num_features_pro)
        self.conv2 = GCNConv(num_features_pro, hidden_channels)
        # Dense head applied to the pooled per-graph embedding.
        self.fc1 = Linear(hidden_channels, 128)
        self.fc2 = Linear(128, 64)
        self.out = Linear(64, num_classes)
        # Activation and regularisation used by the dense head only.
        self.relu = torch.nn.LeakyReLU()
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x, edge_index, batch):
        """Return raw (un-normalised) class logits, one row per graph."""
        # 1. Node embeddings: each conv is followed by a plain ReLU.
        for conv in (self.conv1, self.conv2):
            x = conv(x, edge_index).relu()

        # 2. Readout: average node embeddings per graph -> [num_graphs, hidden_channels].
        h = global_mean_pool(x, batch)

        # 3. Dense head: Linear -> LeakyReLU -> Dropout twice, then the output layer.
        for fc in (self.fc1, self.fc2):
            h = self.dropout(self.relu(fc(h)))
        return self.out(h)




In [None]:
# Load one saved graph to run the trained model on.
# NOTE(review): presumably a PyG `Data` object written by ProteinDataset.process
# (torch.save) — confirm. Absolute, machine-specific path; consider a DATA_DIR constant.
file_path = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/test/crRNA_cluster/processed/cluster_c3.pt'
data = torch.load(file_path)

In [None]:
# Inspect the loaded graph: node features, batch vector (presumably None for a
# single un-batched Data object) and edge index.
for attr_name in ("x", "batch", "edge_index"):
    print(getattr(data, attr_name))

### Load the model state and run the predictor

In [None]:
# Rebuild the 2-class GCN and load its trained weights.
model = GCN(num_features_pro=1024, hidden_channels=256, num_classes=2)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
# The prediction cell moves `data` to `device`, so the model has to live there too.
model = model.to(device)
print(model)

In [None]:
# Run the loaded model on the single graph `data` and report the predicted
# class, the raw logits and the softmax probabilities.
y_pred = []
model.eval()
# Put the graph wherever the model's parameters are, so data and model are on
# the same device whether or not the model was moved to `device`.
model_device = next(model.parameters()).device
with torch.no_grad():
    data = data.to(model_device)
    # data.batch is None for a single graph, so pooling treats all nodes as one graph.
    out = model(data.x, data.edge_index, data.batch)

    # Use softmax to get probabilities
    prob = torch.softmax(out, dim=1)
    pred = prob.argmax(dim=1)  # Use the class with highest probability

    y_pred.extend(pred.cpu().numpy())
print(f"Predicted class of the given input is {y_pred}")
print(f"Output of the model: {out}")
print(f"Probability of the output obtained using softmax: {prob}")

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer

# Define the explainer.
# The GCN emits one raw logit per class (2 here, softmax-ed in the prediction
# cell), so PyG must treat it as a multi-class model. In 'binary_classification'
# mode PyG instead treats the raw output as sigmoid logits of a single positive
# class, which does not match this model.
explainer_pyg = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
)

# Explain the prediction for this graph
explanation = explainer_pyg(data.x, data.edge_index, batch=data.batch)

# Print edge mask and node mask
print("Edge mask:", explanation.edge_mask)
print("Node mask:", explanation.node_mask)



In [None]:
# Shapes of the node/edge masks and the 10 largest edge-mask values.
print(explanation.node_mask.shape)
print(explanation.edge_mask.shape)
# Sort on the tensor itself rather than Python-sorting one 0-d tensor per edge.
print(explanation.edge_mask.sort(descending=True).values[:10])

In [None]:
# Function to explain a single graph
def explain_graph(data):
    """Return the edge mask that the global ``explainer_pyg`` produces for ``data``."""
    return explainer_pyg(data.x, data.edge_index, batch=data.batch).edge_mask

# Collect explanations for multiple graphs (one edge mask per graph).
# NOTE(review): `testset` is not defined anywhere in this notebook; it must come
# from an earlier session. TODO define it explicitly.
edge_masks = []
for data in testset:
    data = data.to(device)
    edge_mask = explain_graph(data)
    edge_masks.append(edge_mask.detach().cpu().numpy())

# Compute the average edge mask.
# An element-wise mean needs every mask to have the same length (and is only
# meaningful if the graphs share the same edge ordering). np.mean on ragged masks
# raises, so average only when all lengths agree.
mask_lengths = {len(mask) for mask in edge_masks}
if len(mask_lengths) == 1:
    average_edge_mask = np.mean(edge_masks, axis=0)
else:
    average_edge_mask = None
    print(f"Edge masks have differing lengths {sorted(mask_lengths)}; not averaging.")

In [None]:
# Debug: edge count of the 61st explained graph, then the mask shape of each graph.
print(len(edge_masks[60]))
for mask in edge_masks:
    print(mask.shape)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

def visualize_top_subgraphs(data, output, edge_mask, top_n=10):
    """Draw the ``top_n`` most important edges of an explained graph and save a PNG.

    Args:
        data: PyG ``Data`` whose ``edge_index`` the mask refers to.
        output: path of the PNG written with ``plt.savefig``.
        edge_mask: one importance value per column of ``data.edge_index``
            (e.g. ``explanation.edge_mask``); a tensor or array-like.
        top_n: number of highest-weighted edges to keep.

    Raises:
        ValueError: if ``edge_mask`` does not have one value per edge.
    """
    # Convert data to an undirected networkx graph (one edge per node pair).
    G = to_networkx(data, to_undirected=True)

    # edge_mask is indexed like the columns of data.edge_index, which for these
    # graphs lists each contact in both directions. G.edges() has a different
    # length and order, so mask values are mapped through the (u, v) endpoints
    # instead of by position. The two directions are merged by keeping the max.
    src, dst = data.edge_index.cpu().tolist()
    mask = torch.as_tensor(edge_mask).detach().cpu().reshape(-1).tolist()
    if len(mask) != len(src):
        raise ValueError(
            f"edge_mask has {len(mask)} values but data.edge_index has {len(src)} edges")
    for u, v, w in zip(src, dst, mask):
        if G.has_edge(u, v):
            G[u][v]['weight'] = max(w, G[u][v].get('weight', w))

    # Get top-N edges based on weight and keep only them.
    top_edges = sorted(G.edges(data=True), key=lambda e: e[2]['weight'], reverse=True)[:top_n]
    subgraph = G.edge_subgraph([(u, v) for u, v, _ in top_edges])

    # Plot the subgraph, colouring edges by their weight.
    pos = nx.spring_layout(subgraph)
    edge_weights = [subgraph[u][v]['weight'] for u, v in subgraph.edges()]
    nx.draw(subgraph, pos, with_labels=True, edge_color=edge_weights, edge_cmap=plt.cm.Reds, node_color='lightblue')

    plt.savefig(output, format="PNG")
    plt.show()

# Example usage (assuming 'data' and 'explanation' are defined):
# save the 30 strongest edges of the current explanation to test1_top30.png.
visualize_top_subgraphs(data, output="test1_top30.png", edge_mask=explanation.edge_mask, top_n=30)
# visualize_top_subgraphs(data, "avg_top20.png", average_edge_mask, top_n=20)


### Visualize the explanation (top-n edges) using Plotly

In [None]:
import networkx as nx
from torch_geometric.utils import to_networkx
import plotly.graph_objects as go

def visualize_top_subgraphs_plotly(data, edge_mask, top_n=10, save_path=None):
    """Interactive Plotly view of the ``top_n`` most important edges of a graph.

    Args:
        data: PyG ``Data`` whose ``edge_index`` the mask refers to.
        edge_mask: one importance value per column of ``data.edge_index``
            (e.g. ``explanation.edge_mask``); a tensor or array-like.
        top_n: number of highest-weighted edges to keep.
        save_path: optional image path written with ``fig.write_image``.

    Nodes are labelled with their index and coloured by their number of
    neighbours inside the plotted subgraph.

    Raises:
        ValueError: if ``edge_mask`` does not have one value per edge.
    """
    G = to_networkx(data, to_undirected=True)

    # edge_mask follows the columns of data.edge_index (each contact in both
    # directions), not G.edges(). Map values through the endpoints and merge the
    # two directions by keeping the larger value.
    src, dst = data.edge_index.cpu().tolist()
    mask = torch.as_tensor(edge_mask).detach().cpu().reshape(-1).tolist()
    if len(mask) != len(src):
        raise ValueError(
            f"edge_mask has {len(mask)} values but data.edge_index has {len(src)} edges")
    for u, v, w in zip(src, dst, mask):
        if G.has_edge(u, v):
            G[u][v]['weight'] = max(w, G[u][v].get('weight', w))

    # Keep the top-N edges by weight.
    top_edges = sorted(G.edges(data=True), key=lambda e: e[2]['weight'], reverse=True)[:top_n]
    subgraph = G.edge_subgraph([(u, v) for u, v, _ in top_edges])

    pos = nx.spring_layout(subgraph)

    # All edges go into one line trace; None breaks the line between segments.
    edge_x, edge_y = [], []
    for u, v in subgraph.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    nodes = list(subgraph.nodes())
    node_trace = go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        text=[str(n) for n in nodes],
        mode='markers+text',
        textposition='middle center',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=30,
            # Neighbour count of each node within the plotted subgraph.
            color=[len(subgraph[n]) for n in nodes],
            colorbar=dict(
                thickness=15,
                # title.side replaces the deprecated colorbar 'titleside' shorthand.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
            line_width=2)
    )

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        plot_bgcolor='white',
                        title_font_size=20,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        annotations=[dict(
                            text=f"Top {top_n} subgraphs",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002)],
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False),
                        width=1000,
                        height=800)
                    )

    # Save the plot if save_path is provided
    if save_path:
        fig.write_image(save_path)
    fig.show()

# Example usage (assuming 'data' and 'explanation' are defined):
# interactive view of the 40 strongest edges, also saved as a PNG.
visualize_top_subgraphs_plotly(data, edge_mask=explanation.edge_mask, top_n=40,
                               save_path='top40_crRNA_cluster5.png')


In [None]:
# 4 660 6 667 5 32 7 676 664 663 8 2 21 20 22 19 11 3 9 25 147 16 363 361 12 17 18 10 146 151
# 30 48 63 31 96 35 33 

# Getting isomorphic graphs/sub graphs from the individual cluster graphs:

In [None]:
import networkx as nx
from networkx.algorithms import isomorphism
from networkx.algorithms.isomorphism import ISMAGS

def find_common_subgraphs(graphs):
    """
    Match every later graph against the first one with networkx ISMAGS.

    For each ``other`` in ``graphs[1:]``, every mapping yielded by
    ``ISMAGS(graphs[0], other).find_isomorphisms()`` becomes the subgraph of
    ``graphs[0]`` induced by the mapping's keys. The results of all pairs are
    returned as one flat list, pair by pair in input order.

    NOTE(review): this is a pairwise search against the first graph, not a
    subgraph shared by *all* graphs. ``find_isomorphisms`` looks for the whole
    ``other`` inside ``graphs[0]``; for partial overlaps see
    ``ISMAGS.largest_common_subgraph``. Confirm this is the intended semantics.

    Parameters:
    graphs (list): List of NetworkX graphs.

    Returns:
    list: Subgraphs of ``graphs[0]`` (NetworkX subgraph views); empty if
    ``graphs`` has fewer than two graphs.
    """
    if not graphs:
        return []
    base_graph, *others = graphs
    return [
        base_graph.subgraph(mapping.keys())
        for other in others
        for mapping in ISMAGS(base_graph, other).find_isomorphisms()
    ]

# Example usage
# G1 = nx.fast_gnp_random_graph(10, 0.5, seed=1)
# G2 = nx.fast_gnp_random_graph(9, 0.5, seed=1)
# G3 = nx.fast_gnp_random_graph(8, 0.5, seed=1)

# common_subgraphs = find_common_subgraphs([G1, G2, G3])

# NOTE(review): `c4_graph` and `c3_graph` are not defined anywhere in this
# notebook; they must come from an earlier session. TODO define them here.
common_subgraphs = find_common_subgraphs([c4_graph, c3_graph])

for idx, subgraph in enumerate(common_subgraphs, start=1):
    print(f"Common Subgraph {idx}:")
    print(subgraph.nodes())
    print(subgraph.edges())


In [None]:
# Summaries of the two input graphs, then the list of common subgraphs found above.
print(c4_graph, c3_graph)
print(common_subgraphs)

In [None]:
import matplotlib.pyplot as plt

def visualize_subgraphs(subgraphs):
    """Draw each subgraph in its own 5x5-inch figure, titled by its 1-based position."""
    for number, sg in enumerate(subgraphs, start=1):
        plt.figure(figsize=(5, 5))
        layout = nx.spring_layout(sg)
        nx.draw(sg, layout, with_labels=True, node_color='lightblue', edge_color='gray')
        plt.title(f"Common Subgraph {number}")
        plt.show()

# Visualize the common subgraphs (one figure per subgraph found above)
visualize_subgraphs(common_subgraphs)


# Run the prediction on all the cluster centers, compose all cluster graphs into one, and look for intersections:

In [None]:

import plotly.graph_objs as go
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.utils import to_networkx

## Return a subgraph containing top N edges
def top_subgraphs(data, edge_mask, top_n=30):
    """Return the subgraph formed by the ``top_n`` highest-weighted edges.

    Args:
        data: PyG ``Data`` whose ``edge_index`` the mask refers to.
        edge_mask: one importance value per column of ``data.edge_index``
            (e.g. ``explanation.edge_mask``); a tensor or array-like.
        top_n: number of edges to keep.

    Returns:
        networkx edge-induced subgraph view of the undirected graph. Every edge
        carries its importance in the ``'weight'`` attribute.

    Raises:
        ValueError: if ``edge_mask`` does not have one value per edge.
    """
    # Convert data to an undirected networkx graph (one edge per node pair).
    G = to_networkx(data, to_undirected=True)

    # edge_mask is indexed like data.edge_index columns (each contact appears in
    # both directions), not like G.edges(). Map values through the endpoints and
    # keep the larger value of the two directions.
    src, dst = data.edge_index.cpu().tolist()
    mask = torch.as_tensor(edge_mask).detach().cpu().reshape(-1).tolist()
    if len(mask) != len(src):
        raise ValueError(
            f"edge_mask has {len(mask)} values but data.edge_index has {len(src)} edges")
    for u, v, w in zip(src, dst, mask):
        if G.has_edge(u, v):
            G[u][v]['weight'] = max(w, G[u][v].get('weight', w))

    # Get top-N edges based on weight and return the subgraph they induce.
    top_edges = sorted(G.edges(data=True), key=lambda e: e[2]['weight'], reverse=True)[:top_n]
    return G.edge_subgraph([(u, v) for u, v, _ in top_edges])

def get_graphs(pt_files, model_state, top_n=30, num_classes=3, save_dir=None):
    """Predict, explain and plot every saved graph; return their top-N edge subgraphs.

    For each ``.pt`` file, this function:
    - loads the ``Data`` graph;
    - prints the GCN's predicted class, logits and softmax probabilities;
    - runs GNNExplainer on the model's prediction;
    - keeps the ``top_n`` highest-weighted edges (``top_subgraphs``);
    - saves a Plotly figure ``top{top_n}_{stem}.png``.

    Args:
        pt_files: paths of ``torch.save``-d PyG ``Data`` objects.
        model_state: path of the GCN ``state_dict`` checkpoint.
        top_n: number of edges kept per graph.
        num_classes: output size of the checkpointed GCN (default 3).
        save_dir: directory for the figures; defaults to the notebook-global
            ``workdir``.

    Returns:
        list of networkx subgraphs, one per file, in ``pt_files`` order.
    """
    model = GCN(num_features_pro=1024, hidden_channels=256, num_classes=num_classes)
    # Everything below runs on CPU; map_location also lets GPU-saved checkpoints load.
    model.load_state_dict(torch.load(model_state, map_location='cpu'))
    model.eval()
    if save_dir is None:
        save_dir = workdir

    # The GCN emits one raw logit per class and is read with softmax, so PyG must
    # treat it as multi-class; 'binary_classification' assumes sigmoid logits.
    explainer_pyg = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            return_type='raw',
        ),
    )

    graphs = []
    for file in pt_files:
        data = torch.load(file)

        ## Evaluate output
        with torch.no_grad():
            out = model(data.x, data.edge_index, data.batch)
            prob = torch.softmax(out, dim=1)
            pred = prob.argmax(dim=1)  # class with the highest probability
        y_pred = list(pred.cpu().numpy())
        print(f"Predicted class of the given input is {y_pred}")
        print(f"Output of the model: {out}")
        print(f"Probability of the output obtained using softmax: {prob}")

        # Explain the prediction for this graph and keep its strongest edges.
        explanation = explainer_pyg(data.x, data.edge_index, batch=data.batch)
        graph = top_subgraphs(data, explanation.edge_mask, top_n=top_n)
        graphs.append(graph)

        name = os.path.splitext(os.path.basename(file))[0]
        visualize_graph_plotly(graph, save_path=os.path.join(save_dir, f'top{top_n}_{name}.png'))

    return graphs

def visualize_graph_plotly(graph, node_dist = None, save_path=None, title='Intersection Graph'):
    """Plot a networkx graph with Plotly, colouring nodes by their degree.

    Args:
        graph: networkx graph to draw.
        node_dist: ``k`` (optimal node distance) passed to ``nx.spring_layout``;
            None keeps the networkx default.
        save_path: if given, the figure is also written there with ``fig.write_image``.
        title: figure title (default 'Intersection Graph').
    """
    pos = nx.spring_layout(graph, k = node_dist)

    # All edges go into one line trace; None breaks the line between segments.
    edge_x = []
    edge_y = []
    for edge in graph.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=3, color='gray'),
        hoverinfo='none',
        mode='lines')

    node_x = [pos[node][0] for node in graph.nodes()]
    node_y = [pos[node][1] for node in graph.nodes()]
    node_degree = [graph.degree[node] for node in graph.nodes()]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=[str(node) for node in graph.nodes()],
        textposition='middle center',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='tealgrn',
            size=30,
            color=node_degree,
            colorbar=dict(
                thickness=15,
                # title.side replaces the deprecated colorbar 'titleside' shorthand.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
            line_width=2))

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # title.font replaces the deprecated 'titlefont_size' shorthand.
                        title=dict(text=title, font=dict(size=16)),
                        plot_bgcolor='white',
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        annotations=[dict(
                            text=" ",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002)],
                        xaxis=dict(showgrid=False, zeroline=False),
                        yaxis=dict(showgrid=False, zeroline=False),
                        width=1000,
                        height=1000)
                    )
    fig.update_layout(
        yaxis=dict(tickfont=dict(size=20)),
        xaxis=dict(tickfont=dict(size=20)))

    # Save the plot if save_path is provided
    if save_path:
        fig.write_image(save_path)
    fig.show()



In [None]:
# workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/R963A_cluster/processed/'
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16_win56/frame_1/processed/'
# Sorted so the order of graphs (and of the printed predictions and figures) does
# not depend on the filesystem's directory listing order.
pt_files = sorted(glob.glob(os.path.join(workdir, 'frame*.pt')))
model_state = 'best_model_GCNN.pt'

# Predict and explain each graph; get_graphs saves one figure per graph under `workdir`.
graphs_state0 = get_graphs(pt_files, model_state, top_n=60 )

In [None]:
# workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/R963A_cluster/processed/'
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16_win56/frame_16/processed/'
# Sorted so the order of graphs (and of the printed predictions and figures) does
# not depend on the filesystem's directory listing order.
pt_files = sorted(glob.glob(os.path.join(workdir, 'frame*.pt')))
model_state = 'best_model_GCNN.pt'

# Predict and explain each graph; get_graphs saves one figure per graph under `workdir`.
graphs_state1 = get_graphs(pt_files, model_state, top_n=60 )

In [None]:
# workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/R963A_cluster/processed/'
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16_win56/frame_56/processed/'
# Sorted so the order of graphs (and of the printed predictions and figures) does
# not depend on the filesystem's directory listing order.
pt_files = sorted(glob.glob(os.path.join(workdir, 'frame*.pt')))
model_state = 'best_model_GCNN.pt'

# Predict and explain each graph; get_graphs saves one figure per graph under `workdir`.
graphs_state2 = get_graphs(pt_files, model_state, top_n=60 )

In [None]:
# visualize_top_subgraphs_plotly(data, explanation.edge_mask, top_n=40, save_path='top40_crRNA_cluster5.png')
# for file in pt_files:
#     print(file)
#     data = torch.load(file)
#     print(file.split('/')[-1].split('.')[0])

### Finding intersection

In [None]:
# import networkx as nx
from collections import Counter
from mpl_toolkits.axes_grid1 import make_axes_locatable

workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/win0_vs_win16_win56/'

# Step 1: Pool the top-n edges of every explained graph of state 0.
all_edges = []
for G in graphs_state0:
    all_edges.extend(G.edges)

# Step 2: Count occurrences of each edge using a Counter.
# A networkx Graph holds each edge at most once, so an edge's count is the number
# of graphs whose top-n list contains it (edges are matched as (u, v) tuples).
edge_counts = Counter(all_edges)
# print(edge_counts)

# Step 3: Create a new graph for the intersection based on the criterion (minimum 3 graphs)
# NOTE(review): G_intersection is created empty here; it is not populated in the visible part of this cell.
G_intersection = nx.Graph()

## Calculate frequency edges appeared in top-n edge list
# x_ax / y_ax: the two node (residue) indices of each edge; z: its count / 100.
# NOTE(review): the hard-coded 100 presumably equals len(graphs_state0) (number of
# frames); if not, z is not a true frequency — confirm.
x_ax, y_ax, z = [], [], []
for edge, count in edge_counts.items():
    # print(edge[0], edge[1], count)
    x_ax.append(edge[0])
    y_ax.append(edge[1])
    z.append(count/100)

## Plot
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

# Marker size = raw edge count (multiplying by 100 undoes the /100 above).
new_z = [i*100 for i in z]
# Create the scatter plot: one point per edge at (node u, node v), coloured by z.
fig, ax = plt.subplots(figsize=(10, 8))
# plt.figure(figsize=(8,6))
# sc = ax.scatter(x_ax, y_ax, c=z, s=new_z, linewidth=0.2, edgecolors='black', cmap='viridis_r', alpha=0.8)
sc = ax.scatter(x_ax, y_ax, c=z, s=new_z, linewidth=0.2, edgecolors='black', cmap='viridis_r', alpha=0.7)

# Add a color bar
# plt.colorbar(sc, label='Frequency')

# Axis limits: 552 = 2 x 276, i.e. both monomers (see the +276 offset below).
plt.xlim(0, 552)
plt.ylim(0, 552)

plt.xlabel('Dimer', fontsize = 24, fontweight = 'bold')
plt.ylabel('Dimer', fontsize = 24, fontweight = 'bold')

# Domain boundary lines: the last residue index of each domain in monomer A,
# and the same boundaries shifted by 276 residues for monomer B.
monomer_A = [55,65,95,138,146,197,275]
monomer_B = [i+276 for i in monomer_A]

# Label positions, roughly inside each domain (same offset for monomer B).
text_loc_A = [25, 60, 80, 120, 142, 170, 230]
text_loc_B = [i+276 for i in text_loc_A]

# Domain names in the same order as the boundaries above. Each monomer gets
# dashed vertical and dash-dot horizontal lines with labels along the top/right edges.
domain = ['N-ter', 'Walker-A', 'A-ISM linker', 'ISM', 'Walker-B', 'B-Cter linker', 'C-ter']
for j,i in enumerate(monomer_A):
    plt.axvline(x=i, color='r', linestyle='--', linewidth=0.5)
    plt.text(text_loc_A[j] , 555, f'{domain[j]}', color='r', rotation=45, fontsize=9)

    plt.axhline(y=i, color='r', linestyle='-.', linewidth=0.5)
    plt.text(555, text_loc_A[j], f'{domain[j]}', color='r',rotation=45, fontsize=9)

for j,i in enumerate(monomer_B):
    plt.axvline(x=i, color='b', linestyle='--', linewidth=0.5)
    plt.text(text_loc_B[j] , 555, f'{domain[j]}', color='b', rotation=45, fontsize=9)
    
    plt.axhline(y=i, color='b', linestyle='-.', linewidth=0.5)
    plt.text(555, text_loc_B[j], f'{domain[j]}', color='b', rotation=45, fontsize=9)

# Thicker black lines mark the boundary between monomer A and monomer B.
plt.axvline(x=275, color='black', linestyle='--', linewidth=1)
plt.text(275-10 , 555, f'MonomerA->B', color='black', rotation=45, fontsize=12)

plt.axhline(y=275, color='black', linestyle='-.', linewidth=1)
plt.text(555, 275-2, f'MonomerA->B', color='black',rotation=45, fontsize=12)

# Create a divider to adjust the position of the colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=1.2)  # You can adjust 'size' and 'pad' to move the colorbar

# Add the colorbar to the custom axes
cbar = fig.colorbar(sc, cax=cax)
cbar.set_label('Frequency')

# plt.show()
# plt.xticks([i for i in range(0, 10)], [5*i for i in range(1, 11)], fontsize = 22)
plt.xticks(fontsize = 22)
plt.yticks(fontsize = 22)
plt.legend(frameon = False, fontsize = 20)
plt.savefig(workdir + 'state_0_top60.png', dpi = 600, bbox_inches = 'tight')

## Filter out common top-ranked edges
# for edge, count in edge_counts.items():
    # if count >= 40:  # Minimum intersection criterion: 3 graphs
    #     G_intersection.add_edge(*edge)

# # Print the edges of the intersection graph
# print("Edges in the intersection graph (appearing in at least 50 graphs):")
# print(list(G_intersection.edges))

# # Visualize the graph (optional)
# import matplotlib.pyplot as plt
# # nx.draw(G_intersection, with_labels=True, node_color='lightblue', edge_color='red')
# # plt.show()
# visualize_graph_plotly(G_intersection, save_path='state0_top30edges_intersection_cutoff40.png')


#### State 1: intersection of top-ranked edges

In [None]:
# workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/WT_R963A_test/R963A_cluster/processed/'
workdir = '/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/frame_56/processed/'
# glob.glob order is filesystem-dependent; sort so frames are always processed
# in the same (lexicographic) order and downstream results are reproducible.
pt_files = sorted(glob.glob(os.path.join(workdir, 'frame_*.pt')))
# NOTE(review): the other loading cells use 'best_model_GCNN.pt'; confirm this
# checkpoint is the intended one for this run.
model_state = 'best_model.pt'

graphs = get_graphs(pt_files, model_state, top_n=30)

In [None]:
from collections import Counter

import matplotlib.pyplot as plt

# An edge is kept if it appears in the top-n list of at least this many graphs.
MIN_GRAPH_COUNT = 5

# Step 1: pool the top-ranked edges of every graph.
all_edges = []
for G in graphs:
    all_edges.extend(G.edges)

# Step 2: count in how many graphs each edge appears.
edge_counts = Counter(all_edges)

# Step 3: build the "intersection" graph from the edges that meet the threshold.
G_intersection = nx.Graph()
G_intersection.add_edges_from(
    edge for edge, count in edge_counts.items() if count >= MIN_GRAPH_COUNT
)

# The message and the filename both derive from the threshold, so they cannot
# drift from it (the message previously claimed 50 graphs while the cutoff was 5).
print(f"Edges in the intersection graph (appearing in at least {MIN_GRAPH_COUNT} graphs):")
print(list(G_intersection.edges))

# nx.draw(G_intersection, with_labels=True, node_color='lightblue', edge_color='red')
visualize_graph_plotly(
    G_intersection,
    save_path=f'/Users/souviksinha/Desktop/Palermo_Lab/LabWork/GNN_classifier/PPI_GNN/Chinmai_proj/frame_56/processed/state1_top30edges_intersection_cutoff{MIN_GRAPH_COUNT}.png',
)


### Finding union

In [None]:
# C = nx.compose_all([c4_graph, c3_graph, c5_graph])
# print(list(C.nodes()))

# print(list(C.edges()))
# visualize_graph_plotly(C, node_dist=0.14, save_path='test_compose.png' )

In [None]:

# Union of all state-1 graphs: merge every graph's nodes and edges into one graph.
C = nx.compose_all(graphs)
print(list(C.nodes()))

print(list(C.edges()))
# '*' is a glob wildcard (and invalid in Windows filenames), so the former
# 'frame1*_...' output name was awkward to match or move; follow the state-based
# naming used for the state-1 intersection figure instead.
visualize_graph_plotly(C, node_dist=0.6, save_path='state1_top30edges_compose.png')

In [None]:
# Residue indices of interest, presumably the 0-based node ids of the union graph
# printed above (TODO confirm).
a = [0, 1, 2, 3, 4, 5, 8, 136, 10, 139, 9, 140, 143, 19, 147, 29, 32, 33, 36, 37, 1028, 133, 398, 400, 404, 406, 663, 408, 407, 665, 410, 1052, 1057, 417, 419, 421, 423, 180, 181, 183, 184, 75, 78, 79, 80, 81, 82, 83, 84, 85, 86, 340, 88, 87, 92, 93, 113, 115, 118, 119, 120, 123, 127, 64, 35, 39, 40, 42, 43, 309, 137, 396, 403, 550, 551, 426, 182, 185, 186, 187, 188, 189, 193, 72, 73, 74, 76, 77, 464, 336, 103, 104, 105, 114, 117, 68, 47, 71, 38]
# Shift each index by +1 (assumed 0-based ids -> 1-based resids) and wrap it in an
# atom-selection string; the syntax looks VMD-style (`noh`) — confirm for the target tool.
sele = (
    "resid "
    + " ".join(str(res + 1) for res in a)
    + ' and not name C N O and noh'
)
print(sele)