In [1]:
import torch
from torch.autograd import Variable

In [2]:
import numpy
import scipy.optimize

In [3]:
import requests
import traitlets

In [4]:
import py3dmol

In [5]:
import pdb_parsing

def fetch_pdb(pdbid):
    return requests.get("https://files.rcsb.org/download/%s.pdb" % str.upper(pdbid)).text

In [6]:
class FixedNamedAtomSystem(traitlets.TraitType):
    def __init__(self, atoms):
        self.atoms = atoms
    
    @property
    def dtype(self):
        return numpy.dtype([(n, "f4", 3) for n in self.atoms])
    
    def from_pdb(self, pdb):
        atoms = pdb_parsing.parse_pdb(pdb)
        atoms = atoms[atoms.apply(lambda r: r["atomn"] in self.atoms, axis=1)]
        for t in ("model", "chain"):
            assert atoms[t].nunique() == 1
        resi = atoms["resi"].unique()
        assert numpy.all(resi == numpy.arange(resi[0], resi[-1] + 1))
        
        atoms = atoms.set_index(["resi", "atomn"])
        
        result = numpy.empty_like(resi, dtype=self.dtype)
        
        for i, ri in enumerate(resi):
            for a in result.dtype.names:
                result[i][a] = atoms.loc[ri, a][["x", "y", "z"]].values
                
        return result
    
    def to_pdb(self, value, b = None):
        atom_records = numpy.zeros((len(value), 4), dtype=pdb_parsing.atom_record_dtype)

        atom_records["resn"] = "CEN"
        atom_records["chain"] = "X"
        atom_records["resi"] = numpy.arange(len(value)).reshape((-1, 1))

        for i, n in enumerate(value.dtype.names):
            atom_records[:,i]["atomn"] = n
            atom_records[:,i]["x"] = value[n][:,0]
            atom_records[:,i]["y"] = value[n][:,1]
            atom_records[:,i]["z"] = value[n][:,2]
            
        atom_records = atom_records.ravel()
        atom_records["atomi"] = numpy.arange(len(atom_records))
        if b is not None:
            atom_records["b"] = b
        
        return pdb_parsing.to_pdb(atom_records.ravel())
        
    def validate(self, obj, value):
        if isinstance(value, str):
            value = self.from_pdb(value)
            
        value = numpy.array(value, dtype=self.dtype, copy=False)
        assert value.ndim == 1
        
        return value

In [7]:
class BBModel(traitlets.HasTraits):
    state = FixedNamedAtomSystem(("N", "C", "CA", "O"))
    
    @property
    def coords(self):
        return self.state.view(dtype="f4").reshape((-1, 3))
    
    
class PDBModelViewer(object):
    def __init__(self, target, score=None):
        self.target = target
        self.view = py3dmol.view(1200, 600)
        
        self.target.observe(lambda c: self.update(), "state")
        self.score = score
        
        self.pdb = None
        
        self.update()
        self.view.zoomTo()
        self.update()
    
    def update(self):
        self.view.clear()
        self.pdb = self.target.traits()["state"].to_pdb(
            self.target.state,
            b= self.score(self.target.coords).atom_scores.numpy() if self.score else None)
        self.view.addModel(self.pdb, "pdb")
        if self.score:
            self.view.setStyle({"sphere" : {"colorscheme" : {"prop":'b',"gradient": 'rwb',"min":1,"max":-1}}})
        else:
            self.view.setStyle({"sphere" : {}})
        
        display(self.view.update())

class StupidLJScore:
    def __init__(self, r_m = 1.0, epsilon = 1.0):
        self.r_m = r_m
        self.epsilon = epsilon
        
    def __call__(self, coords):
        self.coords = Variable(torch.Tensor(coords), requires_grad=True)
        ind = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)

        ind_a = ind.view((-1, 1))
        ind_b = ind.view((1, -1))
        deltas = self.coords.view((-1, 1, 3)) - self.coords.view((1, -1, 3))

        dist = torch.norm(deltas, 2, -1)

        fd = (self.r_m / dist)
        fd2 = fd * fd
        fd6 = fd2 * fd2 * fd2
        fd12 = fd6 * fd6
        lj = self.epsilon * (fd12  - 3 * fd6)

        self.lj = torch.where(
            ind_a != ind_b,
            lj,
            Variable(torch.Tensor([0.0]), requires_grad=False)
        )


        self.atom_scores = torch.sum(self.lj.detach(), dim=-1)
        self.total_score = torch.sum(self.lj)
        (self.grads,) = torch.autograd.grad(self.total_score, self.coords)
        
        return self

class StupidMinimizer:
    def __init__(self, system, scorefn):
        self.system = system
        self.scorefn = scorefn
        
    def fun(self, x):
        coords = x.reshape((-1, 3))
        
        score = self.scorefn(coords)
        return (
            score.total_score.detach().numpy(),
            score.grads.numpy().reshape(-1)
        )
    
    def update_system(self, x):
        self.system.coords[:] = x.reshape(self.system.coords.shape)
        self.system.state = self.system.state
    
    def minimize(self):
        self.result = scipy.optimize.minimize(
            self.fun,
            self.system.coords.reshape(-1),
            jac=True,
            tol=1,
            options = dict(disp=False, maxiter=250),
            callback=self.update_system
        )
        
        return self

In [8]:
start_model = BBModel(state=fetch_pdb("1ubq"))
start_view = PDBModelViewer(start_model)

  silent = bool(old_value == new_value)


In [18]:
test_model = BBModel(state=fetch_pdb("1ubq"))
test_score = StupidLJScore(r_m=10)
view = PDBModelViewer(test_model, test_score)

  silent = bool(old_value == new_value)


In [19]:
StupidMinimizer(test_model, test_score).minimize()

<__main__.StupidMinimizer at 0x7f713df2cef0>