In [1]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

In [2]:
path = Path('../data')

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import griddata
from griddata.grid import Grid
import numba
from math import exp, sqrt, cos, sin

In [4]:
class GridPDB:
    def __init__(self, file):
        if file.endswith('pdb'):
            self.pdbfile = file
            self.parse_pdb()
        if file.endswith('mol2'):
            self.mol2file = file
            self.parse_mol2()
        
    def parse_mol2(self):
        self.atoms = []
        self.atomtypes = []
        self.coords = []
        flag = False
        for line in open(self.mol2file):
            if line.startswith("@<TRIPOS>ATOM"):
                flag = True
                continue
            if line.startswith("@<TRIPOS>BOND"):
                break
            if flag:
                name = line[8:16].strip()
                if name[0] == 'H': continue
                    
                x = line[16:26]
                y = line[26:36]
                z = line[36:46]
                self.atoms.append(name)
                self.atomtypes.append(name[0])
                self.coords.append(list(map(float, (x, y, z))))
                
        self.atoms = np.array(self.atoms)
        self.atomtypes = np.array(self.atomtypes)
        self.coords = np.array(self.coords, dtype=np.float32)
        self.center = np.average(self.coords, axis=0)
    
    def parse_pdb(self):
        self.atoms = []
        self.atomtypes = []
        self.coords = []
        for line in open(self.pdbfile):
            if line.startswith("ATOM"):
                name = line[11:17].strip()
                if name[0] == 'H': continue
                if name[0].isdigit(): continue
                    
                x = line[30:38]
                y = line[38:46]
                z = line[46:54]
                self.atoms.append(name)
                self.atomtypes.append(name[0])
                self.coords.append(list(map(float, (x, y, z))))
                
        self.atoms = np.array(self.atoms)
        self.atomtypes = np.array(self.atomtypes)
        self.coords = np.array(self.coords, dtype=np.float32)
        self.center = np.average(self.coords, axis=0)
    
    def compute_grid(self, size=20, spacing=1.0):
        nx, ny, nz = [int(size/spacing)+1 for _ in range(3)]
        xmin, ymin, zmin = [_-int(size/2) for _ in pdb.center]
        grid = np.zeros((nx, ny, nz), dtype=np.float32)
        self.ndelements = coords_to_grid_numba(self.coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing)
    
    def save_grid(self, filename):
        g = Grid()
        g.n_elements = np.cumprod(self.elements.shape)
        g.center = list(self.center)
        g.shape = self.elements.shape
        g.spacing = (self.spacing, self.spacing, self.spacing)
        g.set_elements(self.ndelements.flatten())
        griddata.save(g, open(filename, 'w'), format='dx')

In [5]:
def coords_to_grid_np(coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw):
    assert grid.shape == (nx, ny, nz)
    ncoords = len(coords)
    X,Y,Z = np.mgrid[xmin:xmin+nx*spacing:spacing, 
                     ymin:ymin+ny*spacing:spacing,
                     zmin:zmin+nz*spacing:spacing]

    xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
    for i in range(ncoords):
        r = np.linalg.norm(xyz - (coords[i]), axis=1).reshape((nx, ny, nz))
        grid += 1 - np.exp(-(rvdw/r)**12)
    return grid

In [121]:
@numba.jit('f4[:,:,:](f4[:,:], f4[:,:,:], i8, i8, i8, f8, f8, f8, f8, f8)', nopython=True)
def coords_to_grid_numba(coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw):
    exps = 0.001
    rmax = 30
    expt = np.exp(-(rvdw/np.arange(0,rmax,exps))**12)
    nc = len(coords)
    for i in range(nx):
        ix = xmin + i*spacing
        for j in range(ny):
            iy = ymin + j*spacing
            for k in range(nz):
                iz = zmin + k*spacing
                for l in range(nc):
                    dx = ix - coords[l,0]
                    dy = iy - coords[l,1]
                    dz = iz - coords[l,2]
                    r = sqrt(dx*dx + dy*dy + dz*dz)
                    #grid[i,j,k] += 1 - exp(-(rvdw/r)**12)
                    if r > rmax: continue
                    grid[i,j,k] += 1 - expt[int(r/exps)]
    return grid

In [122]:
pdb = GridPDB('../../../2018/refined-set/10gs/10gs_pocket.pdb')
size = 20
rvdw = 1.4
spacing = 1.0
nx, ny, nz = [int(size/spacing)+1 for _ in range(3)]
xmin, ymin, zmin = [_-int(size/2) for _ in pdb.center]
grid = np.zeros((nx, ny, nz), dtype=np.float32)
%timeit coords_to_grid_numba(pdb.coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw)

29.3 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [117]:
pdb = GridPDB('../../../2018/refined-set/10gs/10gs_pocket.pdb')
size = 20
spacing = 1.0
rvdw = 1.4
nx, ny, nz = [int(size/spacing)+1 for _ in range(3)]
xmin, ymin, zmin = [_-int(size/2) for _ in pdb.center]
grid = np.zeros((nx, ny, nz), dtype=np.float32)
%timeit coords_to_grid_np(pdb.coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw)

345 ms ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [123]:
grid = np.zeros((nx, ny, nz), dtype=np.float32)
grid1 = coords_to_grid_numba(pdb.coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw)
grid = np.zeros((nx, ny, nz), dtype=np.float32)
grid2 = coords_to_grid_np(pdb.coords, grid, nx, ny, nz, xmin, ymin, zmin, spacing, rvdw)

In [126]:
grid1

array([[[  8.47720439e-05,   1.05965046e-04,   4.41204465e-05, ...,
           5.44053762e-07,   1.26176860e-06,   7.78798858e-06],
        [  6.54129311e-04,   8.96209094e-04,   2.86681083e-04, ...,
           5.84005465e-06,   9.77592663e-06,   4.39566466e-05],
        [  1.57928886e-03,   2.78580957e-03,   1.80941517e-03, ...,
           1.61369229e-04,   2.89633434e-04,   2.19976049e-04],
        ..., 
        [  1.03132258e-07,   1.39413302e-07,   1.88481494e-07, ...,
           1.30732979e-05,   4.44853686e-06,   1.34741424e-06],
        [  3.96200221e-07,   4.81297945e-07,   6.02961109e-07, ...,
           3.73741500e-06,   1.52943824e-06,   5.03256842e-07],
        [  1.71884722e-06,   1.61948117e-06,   1.69106067e-06, ...,
           8.00277007e-07,   3.95969323e-07,   1.60187739e-07]],

       [[  1.67078874e-03,   2.37828190e-03,   4.83293523e-04, ...,
           6.20913909e-07,   1.41529472e-06,   9.44735802e-06],
        [  7.64495656e-02,   1.47052869e-01,   8.75425804e-0

In [125]:
np.sum(grid1 - grid2)

3.6914661

In [314]:
class PdbBindDataset(Dataset):
    def __init__(self, csvfile, rootdir, transform=None, filter_kd=False):
        self.df = pd.read_csv(csvfile)
        self.rootdir = rootdir
        self.transform = transform
        if filter_kd:
            self.df = self.df[self.df.afftype == 'Kd']
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        pdbfile = '{}/{}/{}_protein.pdb'.format(self.rootdir, row.code, row.code)
        pocketfile = '{}/{}/{}_pocket.pdb'.format(self.rootdir, row.code, row.code)
        ligandfile = '{}/{}/{}_ligand.mol2'.format(self.rootdir, row.code, row.code)
        sample = {
            'code': row.code,
            'pdbfile': pdbfile,
            'pocket': GridPDB(pocketfile),
            'ligand': GridPDB(ligandfile),
            'channels': [],
            'affinity': row.affinity
        }
        if self.transform:
            sample = self.transform(sample)
        return sample

In [335]:
class Channel:
    """Convert atomic coordinates into grid (channel)
    
    Args:
        atomtypes: list of atom types to convert into grid
        size: size of grid in angstrom
        spacing: grid spacing in angstrom
        rvdw: r_vdw parameter in grid
    """
    def __init__(self, atomtypes, size, spacing, rvdw):
        self.atomtypes = atomtypes
        self.size = size
        self.spacing = spacing
        self.rvdw = rvdw
    
    def __call__(self, sample):
        size = float(self.size)
        spacing = float(self.spacing)
        rvdw = float(self.rvdw)
        nx, ny, nz = [int(size/spacing)+1 for _ in range(3)]
        xmin, ymin, zmin = [_-size/2 for _ in sample['pocket'].center]
        idx = [_ in self.atomtypes for _ in sample['pocket'].atomtypes]
        grid = np.zeros((nx, ny, nz), dtype=np.float32)
        grid = coords_to_grid_numba(sample['pocket'].coords[idx], grid, 
                                    nx, ny, nz, xmin, ymin, zmin, spacing, rvdw)
        sample['channels'].append(grid)
        
        idx = [_ in self.atomtypes for _ in sample['ligand'].atomtypes]
        grid = np.zeros((nx, ny, nz), dtype=np.float32)
        grid = coords_to_grid_numba(sample['ligand'].coords[idx], grid, 
                                    nx, ny, nz, xmin, ymin, zmin, self.spacing, self.rvdw)
        sample['channels'].append(grid)
        return sample

class Rotate:
    """Rotate input structure
    
    Args:
        degree: maximum degree to rotate (+/-)
    """
    def __init__(self, degree):
        self.degree = degree
    
    def __call__(self, sample):
        theta = (np.random.random_sample(3,) - 0.5)*self.degree/180*np.pi
        rx = np.matrix((( 1,             0,              0),
                        ( 0, cos(theta[0]), -sin(theta[0])),
                        ( 0, sin(theta[0]),  cos(theta[0]))))
        ry = np.matrix((( cos(theta[1]), 0, sin(theta[1])),
                        (             0, 1,             0),
                        (-sin(theta[1]), 0, cos(theta[1]))))
        rz = np.matrix((( cos(theta[2]), -sin(theta[2]), 0),
                        ( sin(theta[2]),  cos(theta[2]), 0),
                        (             0,              0, 1)))
        r = rx * ry * rz
        sample['pocket'].coords = np.array(np.dot(r, (sample['pocket'].coords).T).T, dtype=np.float32)
        sample['ligand'].coords = np.array(np.dot(r, (sample['ligand'].coords).T).T, dtype=np.float32)
        return sample
    
class Center:
    """Center input structure"""
    def __call__(self, sample):
        com = sample['pocket'].center
        sample['pocket'].coords = sample['pocket'].coords - com
        sample['ligand'].coords = sample['ligand'].coords - com
        return sample
    
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        grids = np.vstack([c[np.newaxis,:] for c in sample['channels']])
        return {
            'grids': torch.from_numpy(grids),
            'affinity': torch.from_numpy(np.array([sample['affinity']]))
        }

In [336]:
rotate = Rotate(90)
channel_c = Channel(['C'], 20, 1.0, 1.4)
channel_o = Channel(['O'], 20, 1.0, 1.4)
channel_n = Channel(['N'], 20, 1.0, 1.4)
composed = transforms.Compose([Center(),
                               rotate,
                               channel_c,
                               channel_o,
                               channel_n,
                               ToTensor()])

In [339]:
pdbbind_dataset = PdbBindDataset(csvfile=path/'refined_set.csv',
                                 rootdir='../../2018/refined-set/',
                                 filter_kd=True)
sample = pdbbind_dataset[0]

In [310]:
# test
print(sample['pdbfile'])
rotated = rotate(sample)['pocket'].coords
with open('test.pdb', 'w') as f:
    for c in rotated:
        f.write("%8.3f%8.3f%8.3f\n" % (c[0], c[1], c[2]))

../../2018/refined-set//2tpi/2tpi_protein.pdb


In [340]:
assert len(composed(sample)['grids'].shape) == 4

In [341]:
tfms = transforms.Compose([Center(),
                           rotate,
                           channel_c,
                           channel_o,
                           channel_n,
                           ToTensor()])
ds = PdbBindDataset(csvfile=path/'refined_set.csv',
                    rootdir='../../2018/refined-set/',
                    filter_kd=True,
                    transform=tfms)

In [342]:
dataloader = DataLoader(ds, batch_size=4,
                        shuffle=True, num_workers=0)