<a href="https://colab.research.google.com/github/vvithurshan/Graph-Neural-Network-Model-generator-for-Protein-Complex/blob/main/grinn_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
import mdtraj as md
import os
import MDAnalysis as mda
import networkx as nx
import numpy as np
import torch
from Bio.SeqUtils import seq1
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict
import pickle
from sklearn.model_selection import train_test_split

In [None]:
# Load the dataset table from the working directory; the graph-building cell
# below reads the 'Concat' and 'IC50_median' columns of every row.
df = pd.read_csv("Actual_12A21_VRC01.csv")

In [None]:
# Cache of embeddings keyed by the string passed to ProtBert().
protbert_dict = defaultdict(lambda: None)

# Lazily-initialised (tokenizer, model) pair shared by every ProtBert() call.
_protbert_backend = None


def _load_protbert():
    """Return the shared ProtBert (tokenizer, model) pair, loading it on first use."""
    global _protbert_backend
    if _protbert_backend is None:
        tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert")
        model = AutoModel.from_pretrained("Rostlab/prot_bert")
        _protbert_backend = (tokenizer, model)
    return _protbert_backend


def ProtBert(aa):
    """Return the ProtBert last-hidden-state embedding for ``aa``.

    Results are memoised in ``protbert_dict`` so each distinct input is only
    pushed through the model once. The tokenizer and model themselves are
    loaded a single time and reused, instead of being re-created on every
    cache miss.

    Parameters
    ----------
    aa : str
        Text handed to the tokenizer (the caller passes a one-letter
        amino-acid code).

    Returns
    -------
    torch.Tensor
        ``last_hidden_state`` of the model for the encoded input, computed
        without gradient tracking.
    """
    # .get() avoids inserting a None placeholder into the defaultdict on a miss.
    cached = protbert_dict.get(aa)
    if cached is not None:
        return cached

    tokenizer, model = _load_protbert()
    with torch.no_grad():
        token = tokenizer.encode(aa, add_special_tokens=False)
        token_tensor = torch.tensor([token])
        embeddings = model(token_tensor).last_hidden_state

    protbert_dict[aa] = embeddings
    return embeddings

In [None]:
# Build one residue-level graph per complex listed in `df`, then split the
# graphs and their IC50 labels into train and test sets.
Graph_lst = []
IC50_lst = []
chain_encoding = {'A':0,'B':1,'C':2}
chain_encoding_for_energy = {'A':'H','B':'L','C':'G'}

# An edge is added only when the mean pairwise 'Total' energy is at or below this value.
ENERGY_CUTOFF = -2
# Fixed seed so the train/test split is reproducible across kernel restarts.
SPLIT_SEED = 42

for index, row in df.iterrows():
    name = row['Concat']
    ic50 = row['IC50_median']

    pdb_location = f"CHARMM-GRINN/{name}/{name}.pdb"
    pickle_loc = f"CHARMM-GRINN/{name}/grinn_output/energies.pickle"

    # Skip complexes for which either the structure or the energies are missing.
    if not (os.path.isfile(pdb_location) and os.path.isfile(pickle_loc)):
        continue

    pdb = md.load(pdb_location)
    top = pdb.topology
    u = mda.Universe(pdb_location)
    # NOTE(review): assumes mdtraj chains and MDAnalysis segments are listed in
    # the same order, so chain.index can index this list - confirm.
    chain_ids = [i.segid for i in u.segments]

    G = nx.Graph()
    seq_lst = []   # residue labels in the format used by the energies keys
    node_lst = []  # graph node names, parallel to seq_lst
    for chain in top.chains:
        chain_id = chain_ids[chain.index]
        for residue in chain.residues:
            residue_name = residue.name
            residue_id = residue.resSeq
            # Restrict the selection to the current segment: residue numbers
            # may repeat across chains, and an unrestricted "resid N" query
            # would return the CA atom of the first chain that has residue N.
            ca_atom = u.select_atoms(
                f"segid {chain_id} and resid {residue_id} and name CA"
            )
            ca_atom_position = ca_atom.positions[0]
            node = f"{chain_id}_{residue_name}_{residue_id}"
            G.add_node(node)
            node_lst.append(node)
            seq = f"{chain_encoding_for_energy[chain_id]}{residue_name}{residue_id}"
            seq_lst.append(seq)

            ## add feature
            AA = seq1(residue_name)
            G.nodes[node]['Encoding'] = ProtBert(AA)
            G.nodes[node]['Chain_id'] = chain_encoding[chain_id]
            G.nodes[node]['CA_coords'] = ca_atom_position

    ## pickle
    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # energies files produced by a trusted pipeline.
    with open(pickle_loc, 'rb') as pickle_file:
        energies = pickle.load(pickle_file)
    # energies['EASN34-EGLN64']['Total']

    n_residues = len(seq_lst)
    for pair_1 in range(n_residues - 1):
        for pair_2 in range(pair_1 + 1, n_residues):
            # The key may be stored in either residue order, so try both.
            candidate_keys = (
                f"{seq_lst[pair_1]}-{seq_lst[pair_2]}",
                f"{seq_lst[pair_2]}-{seq_lst[pair_1]}",
            )
            pairwise_energy_lst = None
            for pair in candidate_keys:
                try:
                    pairwise_energy_lst = energies[pair]['Total']
                except KeyError:
                    continue
                break
            if pairwise_energy_lst is None:
                # No energy recorded for this residue pair.
                continue

            pairwise_energy = np.mean(pairwise_energy_lst)
            if pairwise_energy <= ENERGY_CUTOFF:
                # Connect the residue nodes themselves; using the integer
                # indices would create new, feature-less nodes in the graph.
                G.add_edge(node_lst[pair_1], node_lst[pair_2], weight = pairwise_energy)

    Graph_lst.append(G)
    IC50_lst.append(ic50)

X_train, X_test, y_train, y_test = train_test_split(
    Graph_lst, IC50_lst, test_size=0.2, random_state=SPLIT_SEED
)

In [None]:
# Inspect the nodes of the first graph together with their attribute dicts.
Graph_lst[0].nodes(data=True)

In [None]:
# Print the edges of the first graph together with their attribute dicts.
print(Graph_lst[0].edges(data=True))
