In [1]:
import pyrosetta
from pyrosetta import init, rosetta
from pyrosetta.rosetta import core
from pyrosetta.rosetta.core.import_pose import pose_from_file
from pyrosetta.rosetta.core.io.pdb import dump_pdb
from pyrosetta.rosetta.core.io.pdb import add_to_multimodel_pdb
from pyrosetta.rosetta.core.chemical import VariantType
from pyrosetta.io import poses_from_silent
from pyrosetta.rosetta.protocols.rna.denovo import RNA_DeNovoProtocol
pyrosetta.init("-mute all -out:level 0")

import numpy as np
import tempfile
import os
class Arguments:
    """Bare attribute bag: create with ``Arguments()`` and assign fields freely."""

    def __init__(self):
        # Intentionally empty; attributes are attached by callers after construction.
        pass

┌──────────────────────────────────────────────────────────────────────────────┐
│                                 PyRosetta-4                                  │
│              Created in JHU by Sergey Lyskov and PyRosetta Team              │
│              (C) Copyright Rosetta Commons Member Institutions               │
│                                                                              │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRE PURCHASE OF A LICENSE │
│              See LICENSE.md or email license@uw.edu for details              │
└──────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2024 [Rosetta PyRosetta4.Release.python310.linux 2024.08+release.717d2e8232174371f0c672564f23a097062db88a 2024-02-21T10:16:44] retrieved from: http://www.pyrosetta.org


In [2]:
def get_template_pose(seq,type="RNA",**kwargs):
    """Build a pose from a one-letter sequence via ``pyrosetta.pose_from_sequence``.

    Letter case is normalised first: lower-case when ``type`` is exactly ``"RNA"``,
    upper-case for any other value. Presumably the case selects nucleotide vs.
    protein residue types in Rosetta's one-letter parsing; confirm.
    Any extra keyword arguments are forwarded unchanged.
    """
    wants_rna = type == "RNA"
    normalized = seq.lower() if wants_rna else seq.upper()
    return pyrosetta.pose_from_sequence(normalized, **kwargs)

In [3]:
def perturb_RNA(pose,seqlen=None,da=10,db=10,dg=10,dd=10,de=10,dz=10,dx=5,amplify=1.0):
    """Randomly kick the seven RNA torsions of one randomly chosen residue.

    A clone of ``pose`` is taken first. Then one residue ``rid``, drawn uniformly
    from ``1..seqlen`` (Rosetta numbering is 1-based), has its alpha, beta, gamma,
    delta, epsilon, zeta and chi torsions shifted in place. Each shift is drawn
    from ``N(0, sigma) * amplify``, in whatever units the pose's torsion getters
    use (presumably degrees; confirm).

    Args:
        pose: pose to perturb; it is modified in place.
        seqlen: number of residues to sample from. Defaults to the length of
            ``pose.secstruct()``, i.e. one entry per residue of ``pose`` itself.
        da, db, dg, dd, de, dz, dx: standard deviations for alpha, beta, gamma,
            delta, epsilon, zeta and chi respectively.
        amplify: scalar multiplier applied to all seven sampled shifts.

    Returns:
        ``(oldpose, pose)``: the untouched clone, and the (same) perturbed ``pose``.

    Uses numpy's global RNG (``np.random``), so seed it for reproducibility.
    """
    if seqlen is None:
        # Derive the residue count from the pose we were given, not from any global.
        seqlen = len(pose.secstruct())
    oldpose = pose.clone()
    rid = np.random.randint(seqlen) + 1
    sigmas = np.array([da, db, dg, dd, de, dz, dx]).astype(float)
    deltas = np.random.normal(np.zeros(7), sigmas) * amplify
    # Same draw order and the same torsion update order as before: alpha ... chi.
    torsions = ("alpha", "beta", "gamma", "delta", "epsilon", "zeta", "chi")
    for name, delta in zip(torsions, deltas):
        current = getattr(pose, name)(rid)
        getattr(pose, "set_" + name)(rid, current + delta)
    return oldpose, pose

In [4]:
class LoadedPDB(pyrosetta.rosetta.core.pose.Pose):
    """``Pose`` subclass adding convenience accessors for residues, torsions and atoms.

    Construct it from a PDB file path or from an existing ``Pose``. In both cases
    the data is copy-constructed into this object through the ``Pose`` constructor.
    Residue indices use Rosetta's 1-based pose numbering unless stated otherwise.
    """

    def __init__(self,filename):
        """Build from ``filename``: a PDB path, or any ``Pose`` (including ``LoadedPDB``)."""
        # isinstance rather than an exact type check: a LoadedPDB or other Pose
        # subclass must be copied, not handed to pose_from_pdb as if it were a path.
        if isinstance(filename, rosetta.core.pose.Pose):
            pose = filename
        else:
            pose = pyrosetta.pose_from_pdb(filename)
        super(LoadedPDB, self).__init__(pose)

    def __call__(self,chn=None,idx=None):
        """Return self, a residue by pose index, or a residue by PDB chain/number.

        * ``p()``          -> ``p`` itself
        * ``p(5)``         -> ``p.residue(5)`` (pose numbering)
        * ``p("A", 12)``   -> the residue numbered 12 in chain ``A`` (PDB numbering)
        """
        if chn is None:
            return self
        if idx is None:
            return self.residue(int(chn))
        return self.residue(self.get_resid(chn, idx))

    def __len__(self):
        """Number of residues: ``pdb_info().nres()`` when PDB info exists, else ``size()``."""
        info = self.pdb_info()
        # pdb_info() may be unset (e.g. for poses not read from a PDB file; confirm).
        # Fall back to the pose's own residue count instead of crashing on None.
        return info.nres() if info is not None else self.size()

    def all_phi(self):
        """Array of ``phi`` for every residue, in pose order."""
        return self.generic_all("phi")

    def all_psi(self):
        """Array of ``psi`` for every residue, in pose order."""
        return self.generic_all("psi")

    def all_chi(self):
        """Array of ``chi`` for every residue, in pose order."""
        return self.generic_all("chi")

    def generic_all(self,param):
        '''
        A generic function to get all "param"s (one from each residue).
        For example:
            generic_all("phi") -> Same as all_phi()
            generic_all("alpha") -> Returns a list of all alpha angles (RNA only)
        '''
        getter = getattr(self, param)
        return np.array([getter(resno) for resno in range(1, len(self) + 1)])

    def all_generic(self,param):
        """Alias of :meth:`generic_all`."""
        return self.generic_all(param)

    def get_resid(self,chn,idx):
        """Map PDB (chain, residue number) to the 1-based pose index."""
        return self.pdb_info().pdb2pose(chn, int(idx))

    def get_atom(self,rid,aid):
        """Build an ``AtomID`` for atom ``aid`` of residue ``rid``.

        ``rid`` is a pose index (Python or numpy integer) or an object exposing
        ``seqpos()`` (e.g. a Residue). ``aid`` is a 1-based atom index (Python or
        numpy integer) or an atom name resolved via ``Residue.atom_index``.
        """
        if not isinstance(rid, (int, np.integer)):
            rid = rid.seqpos()
        rid = int(rid)
        if isinstance(aid, (int, np.integer)):
            return rosetta.core.id.AtomID(int(aid), rid)
        return rosetta.core.id.AtomID(self.residue(rid).atom_index(aid), rid)

    def get_conformation_parameter(self,rid,par="bond_length",*args):
        """Call ``self.conformation().<par>(...)`` on AtomIDs of atoms ``args`` in residue ``rid``.

        Every atom in ``args`` is looked up in the same residue ``rid`` via :meth:`get_atom`.
        """
        conformation = self.conformation()
        atom_ids = [self.get_atom(rid, atom) for atom in args]
        return getattr(conformation, par)(*atom_ids)

    def __iter__(self):
        """Iterate over residues ``1..len(self)``.

        Each call returns an independent iterator, so nested loops over the same
        pose no longer share (and corrupt) a single cursor. ``current_residue`` is
        still reset so the legacy ``iter(p); next(p)`` pattern keeps working.
        """
        self.current_residue = 0
        return (self.residue(resno) for resno in range(1, len(self) + 1))

    def __next__(self):
        """Legacy cursor-based stepping; requires a prior ``iter(self)`` call."""
        self.current_residue += 1
        if self.current_residue > len(self):
            raise StopIteration()
        return self.residue(self.current_residue)

    def clone_pose(self):
        """Return a plain ``Pose`` copy of this object."""
        return pyrosetta.rosetta.core.pose.Pose(self)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placement and precision used when building coordinate tensors from poses.
TORCH_FLOAT = torch.float32      # single-precision coordinates
DEVICE = torch.device("cpu")     # every coordinate tensor lives on the CPU
def assign_pose_by_pt(pose,pt_pose):
    """Overwrite every atom coordinate of ``pose`` from the rows of ``pt_pose``.

    Args:
        pose: any Rosetta ``Pose`` (``LoadedPDB`` works too); modified in place.
        pt_pose: tensor of shape ``(n_atoms, 3)``. Rows are ordered residue by
            residue (1..size) and atom by atom within each residue, which is the
            layout ``get_torch_rep`` produces.

    Returns:
        The same ``pose`` object, for chaining.

    Raises:
        ValueError: if ``pt_pose`` does not have exactly one row per atom.
    """
    coords = pt_pose.detach().cpu().numpy()
    atoms_per_residue = [pose.residue(resno).natoms() for resno in range(1, pose.size() + 1)]
    expected = (sum(atoms_per_residue), 3)
    if coords.shape != expected:
        # Previously surplus rows were silently ignored; fail loudly instead.
        raise ValueError(
            "pt_pose has shape %s but the pose needs %s" % (tuple(coords.shape), expected))

    row = 0
    for resno, natoms in enumerate(atoms_per_residue, start=1):
        for atomno in range(1, natoms + 1):
            x, y, z = (float(v) for v in coords[row])
            # Write through Pose.set_xyz so the pose itself records the new
            # coordinates. pose.residue() exposes a const Residue in Rosetta's
            # C++ API, and editing that object bypasses the pose's bookkeeping.
            pose.set_xyz(rosetta.core.id.AtomID(atomno, resno),
                         rosetta.numeric.xyzVector_double_t(x, y, z))
            row += 1
    return pose
def get_torch_rep(pose):
    """Stack the Cartesian coordinates of every atom of ``pose`` into a tensor.

    Atoms are ordered residue by residue (1..``pose.size()``) and, within each
    residue, in ``residue.atoms()`` order. This is the row layout
    ``assign_pose_by_pt`` expects.

    Returns:
        Tensor of shape ``(n_atoms, 3)``, dtype ``TORCH_FLOAT`` on ``DEVICE``.
        An empty pose gives shape ``(0, 3)``.
    """
    # Walk the pose directly instead of copy-constructing a LoadedPDB just to
    # iterate it. size() also works for poses that carry no PDB info.
    coords = []
    for resno in range(1, pose.size() + 1):
        for atom in pose.residue(resno).atoms():
            xyz = atom.xyz()
            coords.append((xyz.x, xyz.y, xyz.z))
    return torch.tensor(coords, dtype=TORCH_FLOAT, device=DEVICE).reshape(-1, 3)

In [None]:
def _rotate_about_axis(points, axis, angle):
    """Rotate each row of ``points`` (an ``(N, 3)`` tensor) by ``angle`` radians
    about the unit vector ``axis`` through the origin (right-hand rule).

    Uses Rodrigues' formula, so no in-plane basis is needed. This avoids the
    degenerate case where a chosen basis vector lies on the axis.
    """
    angle = torch.as_tensor(angle, dtype=points.dtype, device=points.device)
    axis_rows = axis.expand_as(points)
    along_axis = torch.sum(points * axis_rows, dim=-1, keepdim=True)
    cos_a = torch.cos(angle)
    return (points * cos_a
            + torch.cross(axis_rows, points, dim=-1) * torch.sin(angle)
            + axis_rows * along_axis * (1.0 - cos_a))


def align_poses(ref_struct,pose_struct,max_iter=5000,step_size=1e-1,metric=lambda v1,v2: torch.mean(torch.sum((v1-v2)**2,dim=-1),dim=0),silent=False,printevery=200):
    """Rigidly superimpose ``pose_struct`` onto ``ref_struct`` by gradient descent.

    Both structures are converted with ``get_torch_rep``, and the centroid of
    ``pose_struct`` is first moved onto that of ``ref_struct``. Each iteration
    then does three things:
      * evaluates ``metric(ref, pose)``, which must return a scalar tensor;
      * projects the per-atom gradient onto a rigid-body move, where the mean
        gradient gives a translation and the mean torque about the current
        centroid gives a rotation axis and magnitude;
      * rotates by ``-|torque| * step_size`` about that axis and translates by
        ``-mean_grad * step_size``.

    Args:
        ref_struct: reference pose (not modified).
        pose_struct: pose to align (not modified; a clone is returned).
        max_iter: number of descent steps.
        step_size: scale applied to both the translation and the rotation angle.
        metric: ``metric(ref_xyz, pose_xyz) -> scalar``; defaults to mean squared
            per-atom distance.
        silent: suppress progress printing.
        printevery: print ``(iteration, loss)`` every this many iterations.

    Returns:
        ``pose_struct.clone()`` with its coordinates set to the aligned positions.
    """
    ref = get_torch_rep(ref_struct)
    pose = get_torch_rep(pose_struct)
    # Each centroid is averaged over its own atoms, so this stays correct even
    # if a custom metric allows ref and pose to have different atom counts.
    pose = pose + (ref.mean(dim=0) - pose.mean(dim=0))

    for it in range(max_iter):
        pose = pose.detach().requires_grad_(True)
        loss = metric(ref, pose)
        loss.backward()
        grad = pose.grad

        with torch.no_grad():
            origin = pose.mean(dim=0)
            centred = pose - origin
            translation = grad.mean(dim=0)
            # dim=-1 explicitly: without it torch.cross picks the FIRST size-3
            # dimension, which is the atom axis when there are exactly 3 atoms.
            torque = torch.cross(centred, grad, dim=-1).mean(dim=0)
            torque_mag = torch.sqrt(torch.sum(torque ** 2))
            # A (near-)zero torque, e.g. already aligned, means no rotation. Skip
            # it rather than normalising by ~0, which would turn coordinates into NaN.
            if torque_mag.item() > 1e-12:
                centred = _rotate_about_axis(centred, torque / torque_mag,
                                             -torque_mag * step_size)
            pose = centred + origin - translation * step_size

        if (not silent) and (it % printevery == 0):
            print(it, loss.item())

    aligned_pose = pose_struct.clone()
    assign_pose_by_pt(aligned_pose, pose.detach())
    return aligned_pose