In [4]:
import torch
from torch.autograd import Variable

In [5]:
import numpy
import scipy.optimize

In [6]:
import requests
import traitlets

In [7]:
import tmol.extern.py3dmol as py3dmol

In [8]:
import tmol.pdb_parsing as pdb_parsing

def fetch_pdb(pdbid):
    return requests.get("https://files.rcsb.org/download/%s.pdb" % str.upper(pdbid)).text

In [9]:
class FixedNamedAtomSystem(traitlets.TraitType):
    def __init__(self, atoms):
        self.atoms = atoms
    
    @property
    def dtype(self):
        return numpy.dtype([(n, "f4", 3) for n in self.atoms])
    
    def from_pdb(self, pdb):
        atoms = pdb_parsing.parse_pdb(pdb)
        atoms = atoms[atoms.apply(lambda r: r["atomn"] in self.atoms, axis=1)]
        for t in ("model", "chain"):
            assert atoms[t].nunique() == 1
        resi = atoms["resi"].unique()
        assert numpy.all(resi == numpy.arange(resi[0], resi[-1] + 1))
        
        atoms = atoms.set_index(["resi", "atomn"])
        
        result = numpy.empty_like(resi, dtype=self.dtype)
        
        for i, ri in enumerate(resi):
            for a in result.dtype.names:
                result[i][a] = atoms.loc[ri, a][["x", "y", "z"]].values
                
        return result
    
    def to_pdb(self, value, b = None):
        atom_records = numpy.zeros((len(value), 4), dtype=pdb_parsing.atom_record_dtype)

        atom_records["resn"] = "CEN"
        atom_records["chain"] = "X"
        atom_records["resi"] = numpy.arange(len(value)).reshape((-1, 1))

        for i, n in enumerate(value.dtype.names):
            atom_records[:,i]["atomn"] = n
            atom_records[:,i]["x"] = value[n][:,0]
            atom_records[:,i]["y"] = value[n][:,1]
            atom_records[:,i]["z"] = value[n][:,2]
            
        atom_records = atom_records.ravel()
        atom_records["atomi"] = numpy.arange(len(atom_records))
        if b is not None:
            atom_records["b"] = b
        
        return pdb_parsing.to_pdb(atom_records.ravel())
        
    def validate(self, obj, value):
        if isinstance(value, str):
            value = self.from_pdb(value)
            
        value = numpy.array(value, dtype=self.dtype, copy=False)
        assert value.ndim == 1
        
        return value

In [21]:
class BBModel(traitlets.HasTraits):
    state = FixedNamedAtomSystem(("N", "C", "CA", "O"))
    
    @property
    def coords(self):
        return self.state.view(dtype="f4").reshape((-1, 3))
    def compute_connectivity( self ) :
        self.connectivity = Variable( torch.Tensor( numpy.ones( [ self.coords.shape[0], self.coords.shape[0] ] ) * 5 ), requires_grad=False  )
        inds = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)
        ind_a = inds.view((-1, 1))
        ind_b = inds.view((1, -1))

        #temp_ind = Variable(torch.Tensor(numpy.arange(10)), requires_grad=False)
        #temp_ind_a = temp_ind.view((-1,1))
        #temp_ind_b = temp_ind.view((1,-1))
        #print( temp_ind_a )
        #print( temp_ind_b )
        #print( ( temp_ind_a % 4 == 0) & (temp_ind_b - temp_ind_a < 4) & (temp_ind_b - temp_ind_a >= 0))
        #temp_conn = Variable( torch.Tensor(numpy.ones([10,10])*5 ), requires_grad=False)
        #print( temp_conn )
        #temp_conn = torch.where(
        #    (torch.fmod(temp_ind_a,4) == 0) & (temp_ind_b - temp_ind_a < 4) & (temp_ind_b - temp_ind_a >= 0),
        #    temp_ind_b - temp_ind_a,
        #    temp_conn
        #)
        #print( temp_conn )
        #print( "(temp_ind_b - temp_ind_a).size()", (temp_ind_b - temp_ind_a).size() )
        #print( ind_b.size())
        #print( ind_a.size() )
        #print( "(ind_b - ind_a).size() ", (ind_b - ind_a).size() )
        #print( ((torch.fmod(ind_a,4) == 0) & (ind_b - ind_a < 4) & (ind_b - ind_a >= 0)).size() )
        #print( "self.connectivity.size()", self.connectivity.size() )
        #print( "(ind_b - ind_a).size()", (ind_b - ind_a).size() )

        # TOTALLY HACKY APPROACH TO CONNECTIVITY INFORMATION FOR 4-ATOM SYSTEM
        # N -- distance to atoms i+1, i+2, i+3
        self.connectivity = torch.where(
            (ind_a % 4 == 0) & (ind_b - ind_a < 4) & (ind_b - ind_a >= 0),
            (ind_b - ind_a),
            self.connectivity
        )

        # N -- distance to atoms i+4, i+5
        self.connectivity = torch.where(
            (ind_a % 4 == 0) & (ind_b - ind_a < 6) & (ind_b - ind_a >= 4),
            ind_b - ind_a - 1,
            self.connectivity
        )

        # CA -- distance to atoms i+1, i+2
        self.connectivity = torch.where(
            (ind_a % 4 == 1) & (ind_b - ind_a < 3) & (ind_b - ind_a >= 0),
            ind_b - ind_a,
            self.connectivity
        )

        # CA -- distance to atoms i+3, i+4, i+5
        self.connectivity = torch.where(
            (ind_a % 4 == 1) & (ind_b - ind_a < 6) & (ind_b - ind_a >= 3),
            ind_b - ind_a - 1,
            self.connectivity
        )

        # C -- distance to atom i+1
        self.connectivity = torch.where(
            (ind_a % 4 == 2) & (ind_b - ind_a < 2) & (ind_b - ind_a >= 0),
            ind_b - ind_a,
            self.connectivity
        )

        # C -- distance to atoms i+2, i+3, i+4, i+5
        self.connectivity = torch.where(
            (ind_a % 4 == 2) & (ind_b - ind_a < 6) & (ind_b - ind_a >= 2),
            ind_b - ind_a - 1,
            self.connectivity
        )

        # O -- distance to atoms i+1, i+2, i+3, i+4
        self.connectivity = torch.where(
            (ind_a % 4 == 2) & (ind_b - ind_a < 5) & (ind_b - ind_a >= 1),
            ind_b - ind_a + 1,
            self.connectivity
        )
        #print( self.connectivity[:10,:10])
        #print( torch.sum(((self.connectivity > 3 ) & (ind_a < ind_b )).int() ))
    
        # hack: 4 atom types for Nbb, CAbb, Cbb, and Obb
        self.atom_types = torch.LongTensor( numpy.arange(self.coords.shape[0]) % 4 )
    
class PDBModelViewer(object):
    def __init__(self, target, score=None):
        self.target = target
        self.view = py3dmol.view(1200, 600)
        
        self.target.observe(lambda c: self.update(), "state")
        self.score = score
        
        self.pdb = None
        
        self.update()
        self.view.zoomTo()
        self.update()
    
    def update(self):
        self.view.clear()
        self.pdb = self.target.traits()["state"].to_pdb(
            self.target.state,
            b= self.score(self.target.coords).atom_scores.numpy() if self.score else None)
        self.view.addModel(self.pdb, "pdb")
        if self.score:
            self.view.setStyle({"sphere" : {"colorscheme" : {"prop":'b',"gradient": 'rwb',"min":1,"max":-1}}})
        else:
            self.view.setStyle({"sphere" : {}})
        
        display(self.view.update())

class LJParams :
    def __init__( self ) :
        self.n_atom_types = 4
        self.radii = torch.Tensor( [ 1.802452, 2.011760, 1.916661, 1.540580 ] )
        self.r_ms = self.radii.view(1,-1)+self.radii.view(-1,1)
        self.r_ms[0,3] = 3.0 #lj_hbond_dis_
        self.r_ms[3,0] = 3.0 #lj_hbond_dis_
        self.dis2sigma = 0.6
        self.m_rm_over_eps = 12 * ( pow(self.dis2sigma,-7) - pow(self.dis2sigma,-13))
        self.b_over_eps = ( 13 * pow(self.dis2sigma,-12) - 14 * pow(self.dis2sigma,-6))
        self.atom_epsilons = torch.Tensor( [ 0.161725, 0.062642, 0.141799, 0.142417 ])
        self.epsilons = torch.sqrt( self.atom_epsilons.view(1,-1) * self.atom_epsilons.view(-1,1) )
        self.lj_atr_fade_begin_x = 5.5
        self.lj_atr_fade_end_x = 6.0
        self.lj_atr_fade_begin_val = self.epsilons * ( pow( self.radii / self.lj_atr_fade_begin_x, 12) + 2 * pow( self.radii / self.lj_atr_fade_begin_x, 6 ))
        self.lj_atr_fade_begin_slope = self.epsilons * ( 1 / self.lj_atr_fade_begin_x ) * ( -12 * pow( self.radii / self.lj_atr_fade_begin_x, 12) + 12 * pow( self.radii / self.lj_atr_fade_begin_x, 6 ))
        
        
class StupidLJScore:
    def __init__(self, connectivity, atom_types, lj_params ):
        self.connectivity = connectivity
        self.atom_types = atom_types
        self.lj_params = lj_params
    def __call__(self, coords):
        self.coords = Variable(torch.Tensor(coords), requires_grad=True)
        ind = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)

        atypes = self.atom_types.view(-1,1)*self.lj_params.n_atom_types + self.atom_types.view(1,-1)
        r_ms = Variable( self.lj_params.r_ms.view(-1)[ atypes.view(-1) ].view(-1,self.coords.shape[0]), requires_grad=False )
        epsilons = Variable( self.lj_params.epsilons.view(-1)[ atypes.view(-1) ].view(-1,self.coords.shape[0]), requires_grad=False )
        #for fading off the atractive component
        x0 = self.lj_params.lj_atr_fade_begin_x 
        y0 = Variable( self.lj_params.lj_atr_fade_begin_val.view(-1)[ atypes.view(-1)].view(-1,self.coords.shape[0]), requires_grad=False )
        sl0 = Variable( self.lj_params.lj_atr_fade_begin_slope.view(-1)[ atypes.view(-1)].view(-1,self.coords.shape[0]), requires_grad=False )
        x1 = self.lj_params.lj_atr_fade_end_x
        y1 = 0
        sl1 = 0

        # for linear damping of the lj repulsive component for distances below 0.6 r_m
        m = self.lj_params.m_rm_over_eps * epsilons / r_ms
        b = self.lj_params.b_over_eps * epsilons
        
        
        ind_a = ind.view((-1, 1))
        ind_b = ind.view((1, -1))
        deltas = self.coords.view((-1, 1, 3)) - self.coords.view((1, -1, 3))

        dist = torch.norm(deltas, 2, -1)

        #regular lj
        fd = (r_ms / dist)
        fd2 = fd * fd
        fd6 = fd2 * fd2 * fd2
        fd12 = fd6 * fd6
        lj_reg = epsilons * (fd12  - 2 * fd6)

        # lj below 0.6 r_m
        lj_linear = m * dist + b
        
        # attractive ramp to zero -- Frank's closed form quick spline solution
        c0 = y0
        c1 = sl0
        c2 = 3*y1-sl1-2*c1-3*c0
        c3 = sl1-2*y1+c1+2*c0
        dhat=(dist-x0)/(x1-x0)
        
        lj_fade_to_0 = (((c3 * dhat) + c2 )*dhat + c1)*dhat + c0

        lj = torch.where(
            dist < self.lj_params.dis2sigma * r_ms,
            lj_linear,
            lj_reg
        )
        
        lj = torch.where(
            dist > x0,
            lj_fade_to_0,
            lj )
        
        lj = torch.where(
            dist > x1,
            Variable( torch.Tensor([0.0]), requires_grad=False),
            lj )
        
        # take only upper triangle + where the number of chemical bonds is 4 or greater
        self.lj = torch.where(
            (ind_a < ind_b) & (self.connectivity > 3),
            #ind_a < ind_b,
            lj,
            Variable(torch.Tensor([0.0]), requires_grad=False)
        )

        self.atom_scores = torch.sum(self.lj.detach(), dim=-1)
        self.total_score = torch.sum(self.lj)
        (self.grads,) = torch.autograd.grad(self.total_score, self.coords)
        print( self.total_score ) 
        return self

class StupidMinimizer:
    def __init__(self, system, scorefn):
        self.system = system
        self.scorefn = scorefn
        
    def fun(self, x):
        coords = x.reshape((-1, 3))
        
        score = self.scorefn(coords)
        return (
            score.total_score.detach().numpy(),
            score.grads.numpy().reshape(-1)
        )
    
    def update_system(self, x):
        self.system.coords[:] = x.reshape(self.system.coords.shape)
        self.system.state = self.system.state
    
    def minimize(self):
        self.result = scipy.optimize.minimize(
            self.fun,
            self.system.coords.reshape(-1),
            jac=True,
            tol=1,
            options = dict(disp=False, maxiter=250),
            callback=self.update_system
        )
        
        return self

In [22]:
start_model = BBModel(state=fetch_pdb("1ubq"))
start_model.compute_connectivity()
#start_view = PDBModelViewer(start_model)

  silent = bool(old_value == new_value)


In [23]:
test_model = BBModel(state=fetch_pdb("1ubq"))
test_model.compute_connectivity()
#print( test_model.atom_types[0:20])
lj_params = LJParams()
test_score = StupidLJScore(test_model.connectivity, test_model.atom_types, lj_params )


  silent = bool(old_value == new_value)


In [24]:
view = PDBModelViewer(test_model, test_score)

Variable containing:
 2145.0037
[torch.FloatTensor of size ()]



Variable containing:
 2145.0037
[torch.FloatTensor of size ()]



In [36]:
eps = 1.6; sig = 2.05; m = -12 * ( 1/pow(0.6,13) - 1/pow(0.6,7)) * eps / sig;
b=eps*(13/pow(0.6,12)-14/pow(0.6,6)); print(b);
testd = 0.6*sig;
lj_reg = eps * ( pow(sig/testd,12) - 2*pow(sig/testd,6));
lj_lin = m * testd + b;
print( lj_reg, lj_lin, lj_reg - lj_lin )

9075.2783469849
666.4427471723111 666.4427471723102 9.094947017729282e-13


view = PDBModelViewer(test_model, test_score)

In [11]:
StupidMinimizer(test_model, test_score).minimize()

<__main__.StupidMinimizer at 0x11a0f2b00>