In [13]:
from pathlib import Path
import prody
import numpy as np
import torch
import dgl
import pickle
from dgl.dataloading import GraphDataLoader

from typing import Iterable, Union, List, Dict

class Unparsable(Exception):
    pass

prody.confProDy(verbosity='none')

#processing code adapted from Ian Dunn: https://github.com/Dunni3/keypoint-diffusion/blob/main/data_processing/pdbbind_processing.py

#hard-code in element map for now (find all elements across pockets and hard code it)

In [83]:
#function to parse the receptors
def parse_pocket(pocket_path): #reads in pdb file of a receptor(binding pocket) into a prody AtomGroup
    
    receptor = prody.parsePDB(str(pocket_path))

    if receptor is None: #errors in reading in a pocket
        raise Unparsable
    
    return receptor

#function to return atom positions, features
def get_pocket_atoms(rec_atoms, element_map):

    #position, features and indices for all pocket atoms
    rec_atom_positions = rec_atoms.getCoords()
    rec_atom_features, other_atoms_mask = receptor_featurizer(element_map=element_map, rec_atoms=rec_atoms)

    #convert positions and features to tensors
    rec_atom_positions = torch.tensor(rec_atom_positions).float()
    rec_atom_features = torch.tensor(rec_atom_features).float()

    # remove "other" atoms from the receptor
    rec_atom_positions = rec_atom_positions[~other_atoms_mask]
    rec_atom_features = rec_atom_features[~other_atoms_mask]

    return rec_atom_positions, rec_atom_features


#function to featurize the receptor atoms
def receptor_featurizer(element_map, rec_atoms, protein_atom_elements = None):

    if rec_atoms is None and protein_atom_elements is None:
        raise ValueError
    
    if protein_atom_elements is None:
        protein_atom_elements: np.ndarray = rec_atoms.getElements()

    #one-hot encode atom elements
    onehot_elements = one_hot_encode(protein_atom_elements, element_map)

    #mask "other" atoms
    other_atoms_mask = torch.tensor(onehot_elements[:, -1] == 1).bool()

    #remove "other" category from onehot_elements, assuming they are last in the one-hot encoding
    protein_atom_features = onehot_elements[:, :-1]

    return protein_atom_features, other_atoms_mask


#function to one-hot encode all atoms of the receptor
def one_hot_encode(atom_elements: Iterable, element_map: Dict[str, int]):

    def element_to_idx(element_str, element_map=element_map):
        try:
            return element_map[element_str]
        except KeyError:
            return element_map['other']

    element_idxs = np.fromiter((element_to_idx(element) for element in atom_elements), int)
    onehot_elements = np.zeros((element_idxs.size, len(element_map.values())))
    onehot_elements[np.arange(element_idxs.size), element_idxs] = 1

    return onehot_elements

#function to build a graph from receptor atoms using dgl
def build_pocket_graph(atom_positions: torch.Tensor, atom_features: torch.Tensor, k: int, edge_algorithm: str):
    #add functionality for radius graphs too

    g = dgl.knn_graph(atom_positions, k=k, algorithm=edge_algorithm, dist='euclidean', exclude_self=True)
    g.ndata['x_0'] = atom_positions
    g.ndata['h_0'] = atom_features
    
    return g

In [84]:
#graph class to read in pocket and compute pocket representation

class GraphPocket:
#callable class to read a pocket and output the graph
    
    def __init__(self):

        #hard code element map and k for graph
        self.rec_elements = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'other':4}
        self.threshold_k = 3
        self.algorithm = 'bruteforce-blas'

    def __call__(self, pocket_path):

        pocket = parse_pocket(pocket_path)
        positions, features = get_pocket_atoms(pocket, self.rec_elements)

        graph = build_pocket_graph(positions, features, self.threshold_k, self.algorithm)

        return graph
        

In [85]:
pocket_to_graph = GraphPocket()

graph = pocket_to_graph(pocket_path='../../TOUGH-M1/data/11asA/11asA_pocket.pdb')

print(graph)

torch.Size([357, 4])
torch.Size([357, 4])
Graph(num_nodes=357, num_edges=1071,
      ndata_schemes={'x_0': Scheme(shape=(3,), dtype=torch.float32), 'h_0': Scheme(shape=(4,), dtype=torch.float32)}
      edata_schemes={})


In [28]:
#build a class for a tuple dataset from the positive and negative pairs

class GraphTupleDataset(dgl.data.DGLDataset):

    def __init__(self, name, pocket_list, pos_list, neg_list):
        self.pocket_list = pocket_list

        self.pos_list = list(filter(lambda p: p[0] in self.pocket_list and p[1] in self.pocket_list, pos_list))
        self.neg_list = list(filter(lambda p: p[0] in self.pocket_list and p[1] in self.pocket_list, neg_list))

        self.graphs = []
        self.labels = []

        #filter pos and neg based on pocket_list

        self.pocket_index_map = {}

        super().__init__(name=name)

    def __getitem__(self, index):

        pair, label = self.labels[index]
        pocket1 = self.graphs[self.pocket_index_map[pair[0]]]
        pocket2 = self.graphs[self.pocket_index_map[pair[1]]]

        return (pocket1, pocket2), label

    def __len__(self):

        return len(self.labels)
    
    def process(self):

        pocket_to_graph = GraphPocket()

        for i, pocket in enumerate(self.pocket_list):
            graph = pocket_to_graph(pocket_path=f'../../TOUGH-M1/data/{pocket}/{pocket}_pocket.pdb')
            self.graphs.append(graph)
            self.pocket_index_map[pocket] = i

        for pos_pair in self.pos_list:
            self.labels.append((pos_pair, 1))
        for neg_pair in self.neg_list:
            self.labels.append((neg_pair, 0))



In [55]:
##function to read pocket from graph, call the graphpocket class and generate a graph
from sklearn.model_selection import KFold, GroupShuffleSplit

#function for dataloading tuples of the pockets from pocket lists - used to get dataloader from a dataset class

def create_dataset(pos_path, neg_path, fold_nr, type, n_folds=5, seed=42):
    
    #load in the list of pocket and corresponding sequence clusters
    with open('../cluster_map.pkl', 'rb') as file:
        pocket_seq = pickle.load(file)

    pocket_list = [pdb[0] + pdb[1] for pdb in list(pocket_seq.keys())]

    pockets = list(pocket_seq.keys())
    clusters = list(pocket_seq.values())

    if type == 'seq':
        split = GroupShuffleSplit(n_splits=n_folds, test_size=1.0/n_folds, random_state=seed)
        folds = list(split.split(pockets, groups=clusters))
        train_index, test_index = folds[fold_nr] #fold number?
        pocket_train, pocket_test = [pocket_list[i] for i in train_index], [pocket_list[i] for i in test_index]
    
    if type == 'random':
        split = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
        folds = list(split.split(pocket_list))
        train_index, test_index = folds[fold_nr]
        pocket_train, pocket_test = [pocket_list[i] for i in train_index], [pocket_list[i] for i in test_index]
    
    print(pocket_test)


    with open(pos_path) as f:
        pos_pairs = [line.split()[:2] for line in f.readlines()]
    with open(neg_path) as f:
        neg_pairs = [line.split()[:2] for line in f.readlines()]

    train_dataset = GraphTupleDataset(name='train', pocket_list=pocket_train, pos_list=pos_pairs, neg_list=neg_pairs)

    test_dataset = GraphTupleDataset(name='test', pocket_list=pocket_test, pos_list=pos_pairs, neg_list=neg_pairs)

    return train_dataset, test_dataset

def get_dataloader(dataset, batch_size, num_workers, **kwargs):

    dataloader = GraphDataLoader(dataset, batch_size=batch_size, drop_last=False, num_workers=num_workers)

    return dataloader

t, tt = create_dataset('../data/TOUGH-M1_positive.list', '../data/TOUGH-M1_negative.list', 0, 'seq',5,42)


['1wnbC', '1dmyB', '3a5zG', '4iqgC', '4nz3A', '4jr0B', '3ixlA', '1jjvA', '3r77B', '2zofB', '4gqtA', '4m0rB', '4ckiA', '1a4eD', '2p0wA', '1jfpA', '3r7wB', '2vytA', '3qb8A', '4wn9B', '2bbsB', '1lmpA', '2bpbB', '3grsA', '3tazB', '1gwjA', '1wdnA', '3depA', '2v8mD', '4mmpA', '1us5A', '256bB', '1b4kA', '2z0kA', '1zzdA', '4ffbB', '3uedA', '4zyjA', '1woxB', '4lgdD', '1issB', '4ur8D', '3aq0D', '1sg4B', '2l6xA', '1mokC', '1sqcA', '2uvfA', '2ot9A', '1og2A', '4lhwB', '2pgzA', '1h3hA', '1r17A', '1kyfA', '3hkvA', '3ah1B', '1gx0A', '4rk7A', '4nhdD', '3okxB', '3vzbA', '3tnoA', '3vlzA', '4ww4A', '4lzcA', '4fbaB', '3vvdA', '2uzjA', '3hk9E', '3emmA', '451cA', '2f5xA', '4zbgA', '1i3zA', '4qwmA', '1yobA', '4pfrA', '2ns8D', '3gryA', '1ecmA', '3t9bA', '4nz6B', '2dr1A', '3tv1B', '1kqoC', '1yy5B', '3uoaB', '3rpnB', '1x0xA', '4ir7A', '1c0pA', '2e7uA', '3anxA', '3qomA', '4djtB', '4dxeB', '1kdmA', '3gg2C', '4tkgA', '2ifrA', '1siqA', '1xskA', '3whbA', '3il4C', '4lqsA', '4evrA', '2v72A', '2peyB', '3myuA', '3fuwA', 

In [57]:
print(tt)

Dataset("test", num_graphs=45175, save_path=/Users/vratinsrivastava/.dgl/test)


In [58]:
train_dataloader = get_dataloader(tt, batch_size=2, num_workers=1, shuffle=True, pin_memory=True)

print(train_dataloader.dataset)

Dataset("test", num_graphs=45175, save_path=/Users/vratinsrivastava/.dgl/test)


In [None]:
#test dataset

with open('../data/TOUGH-M1_positive.list') as f:
    pos_pairs = [line.split()[:2] for line in f.readlines()]
with open('../data/TOUGH-M1_positive.list') as f:
    neg_pairs = [line.split()[:2] for line in f.readlines()]
with open('../data/TOUGH-M1_pocket.list') as f:
    pocket_list = [line.split()[0] for line in f.readlines()]

print(pocket_list[:3])
print(pos_pairs[:3])

pocket_test = pocket_list[:3]
pos_test = [['11asA', '13pkD'],['11asA', '154lA']]
neg_test = [['13pkD', '154lA']]

dataset = GraphTupleDataset(name='mock_dataset', pocket_list=pocket_test, pos_list=pos_test, neg_list=neg_test)

sample, label = dataset[0]

print(f'Sample shape: {sample[0].ndata}, {sample[1].ndata}, Label: {label}')

# Print dataset length
print(f'Dataset length: {len(dataset)}')

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

nx_g = graph.to_networkx().to_undirected()

# Step 2: Extract Node Features (here, using 'x_0' for coloring)
# Normalize features for coloring or choose a specific feature
node_colors = graph.ndata['x_0'][:, 0].numpy()  # Use the first feature for coloring
node_colors = (node_colors - node_colors.min()) / (node_colors.max() - node_colors.min())

# Step 3: Visualize the Graph
plt.figure(figsize=(8, 8))
# Position nodes using the Spring layout
pos = nx.spring_layout(nx_g, seed=42)
# Draw the nodes (you can map node_colors to any colormap you like)
for edge in nx_g.edges():
    points = np.array([pos[edge[0]], pos[edge[1]]])
    plt.plot(points[:, 0], points[:, 1], 'k-', lw=0.5)

# Draw nodes
for node in nx_g.nodes():
    plt.scatter(pos[node][0], pos[node][1], c='b', s=50)

plt.show()

In [None]:
from Bio.PDB import PDBParser

parser = PDBParser()

structure = parser.get_structure(id = '1lasA', file = '../../TOUGH-M1/data/11asA/11asA_pocket.pdb')

atom_coords = []
for model in structure:
    for chain in model:
        for residue in chain:
            for atom in residue:
                atom_coords.append(atom.get_coord())  # Get 3D coordinates of each atom

In [None]:
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot atom positions
x, y, z = zip(*atom_coords)
ax.scatter(x, y, z)

ax.scatter(x, y, z, color='r')  # 'r' for nodes, change as needed

# Optionally, draw edges
# For demonstration, let's assume we have a simple edge list
edges = [(0, 1), (1, 2), (2, 3)]  # Example edges, replace with your graph's edges
for edge in edges:
    xs, ys, zs = zip(*[atom_coords[edge[0]], atom_coords[edge[1]]])
    ax.plot(xs, ys, zs, color='b')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

In [None]:
src, dst = graph.edges()

# Convert the tensors to lists for easy handling
src_list = src.tolist()
dst_list = dst.tolist()

# Combine the source and destination lists into a list of tuples representing edges
edge_list = list(zip(src_list, dst_list))

node_coords = np.array(atom_coords)  # Example, replace with actual node positions

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Plotting graph nodes (overlain on the protein structure)
ax.scatter(node_coords[:, 0], node_coords[:, 1], node_coords[:, 2], color='r', s=50)  # Red nodes

# Plotting edges
for edge in edge_list:
    start_pos, end_pos = node_coords[edge[0]], node_coords[edge[1]]
    ax.plot([start_pos[0], end_pos[0]], [start_pos[1], end_pos[1]], [start_pos[2], end_pos[2]], color='g')  # Green edges

# You might want to adjust plot limits and other aesthetics
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

In [None]:
import torch
import dgl 
from dgl.dataloading import GraphDataLoader

import pickle

In [None]:
#shift the sequence cluster code here (provided you have the information)