In [362]:
import numpy as np
import pandas as pd

class topo2rest():
   def __init__(self, ifile:str, temps:list = [300.0, 500.0], nreps:int = 20, kappa:str = 'off'):
      '''Convert processed topology to REST2/3 input topology
         E_tot = gamma*E^{pp} + sqrt(gamma)*E^{pw} + E^{ww}
         gamma = T_0/T_i
         for REST2/3 bonds and angles are not scaled per the REST2 paper,
         however, for REST2 LJ parameters epsilon_i is scaled by epsilon_i*gamma
         for tempered atoms and all others are unmodified. 
         In the case of REST3, with the additon of sqrt(gamma)*kappa*E^{pw} which differs 
         from sqrt(gamma)*E^{pw}, we produced the combination rule 2 nonbonded terms
         involving protein-water interactions to override the pw nonbonded interactions
         as including gamma*kappa with each hot atom epsilon_i would result in the 
         incorrect form of: E_tot = gamma*kappa*E^{pp} + sqrt(gamma*kappa)*E^{pw} + E^{ww}:
         rather than the correct form: 
         E_tot = gamma*kappa*E^{pp} + sqrt(gamma)*kappa*E^{pw} + E^{ww}
         input
         ifile = inputtopology.top
         temps = ["lower temp":float, "upper temp":float ]; temperature range of replicas
         kappa:float = kappa scaling; if not equal to 1 REST3 implementation active
         hot_m = hot molecule; [0] for most systems will select the protein'''

      self.nreps = nreps
      self._kappa = None
      self.hot_m = None
      self.nmol = None
      self.molecules = None
      # self.get_scaled_charges = {}
      self.tempreps = self.compute_temperatures(temps)
      self.lambdai = self.compute_lambda()
      self.sections = {}
      self.sections_out = {}
      self.scaled_dihedrals = {}
      self.scaled_dihedral_types = {}
      self.scaled_atomtypes = {}
      self.scaled_charges={}
      self.scaled_nonb={}
      self.hard_order_sections = [ "defaults", "atomtypes", "nonbond_params", "bondtypes", \
                                   "constrainttypes", "angletypes", "dihedraltypes", "molecules", "moleculetype"]
      with open(ifile) as topo:
         self.readfile = topo.readlines()
      self.gather_param_sections()
      self.kappa = kappa
      
      
   def compute_lambda(self):
      return [ self.tempreps[0]/Ti for Ti in self.tempreps ]
   
   def compute_temperatures(self, temp_range:list):
      from numpy import log, exp
      tlow, thigh = temp_range
      temps = []
      for i in range(self.nreps):
         temps.append(tlow*exp((i)*log(thigh/tlow)/(self.nreps-1)))
      return np.array(temps)
   
   @property
   def kappa(self):
      return self._kappa
   
   @kappa.setter
   def kappa(self, kappa_set:str):
      def create_kappa(temp):
         '''Hardcoded curve fit to kappa values for 16 replicas to interpolate kappa for any number of replicas'''
         from numpy import log,round
         ab = np.array([0.1448285 , 0.33109532])
         return round(ab[0]*log(temp/ab[1]),3)
         
      assert isinstance(kappa_set,str)
      if kappa_set == 'on':
         arr = np.zeros(self.nreps)
         print(np.where(self.tempreps<=330.0)[0])
         kappa_off = np.where(self.tempreps<=330.0)[0][-1]
         print(kappa_off)
         arr[:kappa_off+1] = 1.0
         print(arr)
         arr[kappa_off+1:] = create_kappa(self.tempreps[kappa_off+1:])
         self._kappa = {i:j for i,j in zip(self.lambdai,arr)}
      elif kappa_set == 'off':
         self._kappa = {i:1.0 for i in self.lambdai}
      else: print("kappa must either be set to \'on\' or \'off\' !")
      
   def get_scaled_charges(self, hot_m:list = [0]):
      
      # charges_new = {}

      for hot in hot_m :
         a=[i.split() if len(i.split()) == 8 else i.split()[:-3] for i in self.sections['moleculetype'][hot]['atoms'] if len(i.split()) != 0]

         for l in self.lambdai:
            c=[]
            
            for i in a:
               
               s_charge=float(i[6])*np.sqrt(l) # scalling with sqrt(lambda)
               stringout = f'{i[0]:>6} {"s"+i[1]:>10} {i[2]:>6} {i[3]:>6} {i[4]:>6} {i[5]:>6} {s_charge:>10.4f} {float(i[7]):>10.2f}\n'
               c.append(stringout)
               

            self.scaled_charges[l]=self.sections['moleculetype'].copy()
            self.scaled_charges[l][hot]['atoms'] = c

   def populate_out(self, lambdai:int=1.0):
      a=self.sections['defaults']+['\n']
      # a=a+['\n']+self.sections['bondtypes']
      a=a+['[ atomtypes ]\n']+self.scaled_atomtypes[lambdai]+['\n']
      a=a+['[ nonbond_params ]\n']+self.scaled_nonb[lambdai]+['\n']
      a=a+self.sections['bondtypes']+['\n']
      a=a+self.sections['constrainttypes']+['\n']
      a=a+self.sections['angletypes']+['\n']
      a=a+['[ dihedraltypes ]\n']+self.scaled_dihedral_types[lambdai]+['\n']
      # a=a+
      self.sections_out[lambdai]=a+['\n']+self.sections['bondtypes']
   
   def parse_section(self,trunks):
      first_round = True
      output = []
      for line in self.readfile[trunks:]:
         if "[" not in line and first_round!=True and len(line.split()) != 0:
            output.append(line)
         elif ";" == line[0]: 
            #output.append(line)
            continue
         elif "[" in line and first_round!=True: 
            break
         first_round = False
      return output

   def get_molecule_atomtypes(self,molecule:int=0):
      import pandas as pd
      from numpy import sqrt
      b=[]
      for i in self.sections['moleculetype'][molecule]['atoms']:
         if len(i.split())>1 and ';' not in i.split()[0] and '[' not in i :
            b.append(i.split()[:2])
      dataset = np.array(b,dtype=object)
      # Need to grab unique atoms
      atomtypes = dataset[:,1]
      return np.unique(atomtypes)
   
   def get_scale_nonbonded(self):
      from numpy import sqrt, float32
   
      def compute_se(e1:float,e2:float,s1:float,s2:float):
         eij = sqrt(float32(e1)*float32(e2))
         sij = 0.5*(float32(s1)+float32(s2))
         return eij,sij
      
      hot_atoms = np.array([i for j in self.hot_m for i in self.get_molecule_atomtypes(j) ])
      
      atomtypes_new = {}
      nonbonded_new = {}

      atomtypes = [i.split()[:-2] if i.split()[-2] == ';' else i.split()[:-1] if i.split()[-1] == ';' \
                               else i.split() for i in self.sections['atomtypes'] if ';' not in i[:3] if '[' not in i[:3] \
                               if '\n' not in i[:3]]
      #[print(len(at_),at_) for at_ in atomtypes]
      for lambdai in self.lambdai:
         at_out = []
         for _at in atomtypes:
            e_scaled = float32(lambdai) * float32(_at[-1])
            stringout = f'{_at[0]:<12} {_at[1]:<6} {_at[2]:>6} {_at[3]:>8}{_at[4]:^5}{_at[5]:>11}{_at[6]:>13}\n'
            at_out.append(stringout)
            stringout = f'{"s"+_at[0]:<12} {_at[1]:<6} {_at[2]:>6} {_at[3]:>8}{_at[4]:^5}{_at[5]:>11}{e_scaled:>13.5e}\n'
            at_out.append(stringout)
         atomtypes_new[lambdai] = at_out

         nbp_out = []
         for nbp in self.sections['nonbond_params']:
            nbp_ = nbp.split()
            if '[' in nbp_[0] or '\n' in nbp[:3]:
               continue
            elif ';' in nbp_[0]:
               #nbp_out.append(nbp)
               continue
            elif len(nbp.split()) == 5:
               nbp_out.append(nbp)
               check_hot = np.isin(np.array(nbp_[:2]), hot_atoms)
               if check_hot.all():
                  eps_scaled = float32(lambdai) * float32(nbp_[4])
                  stringout = f'{"s"+nbp_[0]:>5} {"s"+nbp_[1]:>4} {nbp_[2]:>5} {nbp_[3]:>10} {eps_scaled:>8.4f}\n'
                  nbp_out.append(stringout)
               elif check_hot.any():
                  if self.kappa[lambdai] == 1.0:
                     eps_scaled = lambdai * float32(nbp_[4])
                  elif check_hot.any():
                     eps_scaled = float32(self.kappa[lambdai]) * sqrt(float32(lambdai)) * float32(nbp_[4])
                  idx_ = np.where(check_hot)[0][0]
                  if idx_: stringout = f'{nbp_[0]:>5} {"s"+nbp_[1]:>4} {nbp_[2]:>5} {nbp_[3]:>10} {eps_scaled:>8.4f}\n'
                  else: stringout = f'{"s"+nbp_[0]:>5} {nbp_[1]:>4} {nbp_[2]:>5} {nbp_[3]:>10} {eps_scaled:>8.4f}\n'
                  nbp_out.append(stringout)
            else: print(f'here is the problem line: {nbp} {nbp_[0]}')

         if self.kappa[lambdai] != 1.00:
            print(f'kappa {self.kappa[lambdai]} for lambda {lambdai}')
            at_in = np.loadtxt(at_out,dtype=object)
            df_at = pd.DataFrame(at_in, columns=['name','atnum','mass','charge','ptype','sigma','epsilon'], dtype=object)            
            atom1_ = df_at[df_at['name'].str.contains('s')]
            atom2_ = df_at[~df_at['name'].str.contains('s')]

            for _, atom1 in atom1_.iterrows():
               for _, atom2 in atom2_.iterrows():
                  e1 = float32(atom1.epsilon)
                  e2 = float32(atom2.epsilon)
                  s1 = float32(atom1.sigma)
                  s2 = float32(atom2.sigma)
                  funct = str(1)
                  eij, sij = compute_se(e1,e2,s1,s2)
                  scaled_eij = float32(self.kappa[lambdai]) * sqrt(float32(lambdai)) * eij
                  stringout = f'{atom1.name:>5} {atom2.name:>4} {funct:^5} {sij:>10.4f} {scaled_eij:>8.4f}'
                  nbp_out.append(stringout)
         nonbonded_new[lambdai] = nbp_out
      self.scaled_nonb = nonbonded_new
      self.scaled_atomtypes = atomtypes_new
      pass
   
   def show_molecule_names(self):
      try:
         molecules = [" ".join(i.split()[:-1]) for i in self.sections['molecules'] if '[' not in i and ';' not in i.split()[0]] 
         molecules = {num: i for num,i in enumerate(molecules)}
         print("\n".join([f'{key}: {mol}' for key,mol in zip(molecules.keys(),molecules.values())]))
         if self.molecules == None: self.molecules = molecules
         if self.nmol == None: self.nmol = max(self.molecules.keys())
      except:
         print("Do you have molecules?")

   def get_molecule_atoms(self):
      import pandas as pd
      b=[]
      for i in self.sections['atoms']:
         i.split()
         if len(i.split())>1 and i.split()[0]!=';' and i.split()[0]!='[' :
            b.append(i.split())
   
   def _moleculetype_sub(self,linestart:int):
      first_round = True
      output = []
      for line in self.readfile[linestart:]:
         if "[" not in line and ';' not in line[:3]:
            output.append(line)
         elif ";" == line[0]:
            #output.append(line)
            continue
         elif "[" in line and first_round!=True: 
            break
         first_round = False
      return output 
   
   def identify_moltype_sections(self,trunks:int):
      section_start = []
      for i, line in enumerate(self.readfile[trunks:]):
         if '[' in line and 'moleculetype' not in line and 'system' not in line:
            section_start.append(i+trunks)
         elif i != 0 and 'moleculetype' in line or 'system' in line:
            break
      return section_start

   def parse_moleculetypes(self,trunks:int):
      first_round = True
      output = {}
      sections = self.identify_moltype_sections(trunks)
      output['header'] = self.readfile[trunks:trunks+3]
      for section in sections:
         section_ = self.readfile[section].split()[1]
         output[section_] = self._moleculetype_sub(section)
      return output
   
   def get_scale_dehedrals_(self, hot_m:list = [0]):

      dihedrals_new = {}
      dihedral_types_new = {}
      dihedral_types = [" ".join(i.split()[:-2]) if i.split()[-2] == ';' else i.split()[:-1] if i.split()[-1] == ';' \
                               else i.split() for i in self.sections['dihedraltypes'] if ';' not in i[:3] if '[' not in i[:3] \
                               if '\n' not in i[:3]]
      for hot in hot_m:
         dihedrals = self.sections['moleculetype'][hot]['dihedrals']
         for lambdai in self.lambdai:
            dih_new = []
            dih_types_new = []
            for dihedral in dihedrals:
               dls_ = dihedral.split()
               dls_ = [entry for entry in dls_ if len(entry) != 0]
               if len(dls_) == 5:
                  dih_new.append(dihedral)
               elif len(dls_) == 8:
                  Kscaled = float(dls_[6])*lambdai
                  stringout = f'{dls_[0]:>5} {dls_[1]:>5} {dls_[2]:>5} {dls_[3]:>5} {dls_[4]:^9}{dls_[6]:<10}{Kscaled:<10.3f}{dls_[7]}\n'
                  dih_new.append(stringout)
               else: 
                  print(f'Warning: incorrect parsing of dihedrals section\n expecting 5 or 8 columns found {len(dls_)}\n'+dihedral)
            dihedrals_new[lambdai] = dih_new
         self.scaled_dihedrals[hot] = dihedrals_new

      for lambdai in self.lambdai:
         for dihedraltype in dihedral_types:
            dtls_ = dihedraltype.split()
            Kscaled = float(dtls_[6])*lambdai
            if len(dtls_) == 8:
               dih_types_new.append(dihedraltype)
               check_atom_X = np.isin(np.array(dtls_[:4]), np.array(['X']))
               if sum(check_atom_X) == 0:
                  stringout = f'{"s"+dtls_[0]:>5} {"s"+dtls_[1]:>5} {"s"+dtls_[2]:>5} {"s"+dtls_[3]:>5} {dtls_[4]:^9}{dtls_[6]:<10}{Kscaled:<10.3f}{dtls_[7]}\n'
                  dih_types_new.append(stringout)
               elif dtls_[0] == 'X':
                  if sum(check_atom_X) == 1:
                     stringout = f'{dtls_[0]:>5} {"s"+dtls_[1]:>5} {"s"+dtls_[2]:>5} {"s"+dtls_[3]:>5} {dtls_[4]:^9}{dtls_[6]:<10}{Kscaled:<10.3f}{dtls_[7]}\n' 
                     dih_types_new.append(stringout)
                  elif dtls_[3] == 'X':
                     stringout = f'{dtls_[0]:>5} {"s"+dtls_[1]:>5} {"s"+dtls_[2]:>5} {dtls_[3]:>5} {dtls_[4]:^9}{dtls_[6]:<10}{Kscaled:<10.3f}{dtls_[7]}\n' 
                     dih_types_new.append(stringout)
                  elif dtls_[1] == 'X':
                     stringout = f'{dtls_[0]:>5} {dtls_[1]:>5} {"s"+dtls_[2]:>5} {"s"+dtls_[3]:>5} {dtls_[4]:^9}{dtls_[6]:<10}{Kscaled:<10.3f}{dtls_[7]}\n' 
                     dih_types_new.append(stringout)
                  else: print(f'In dihedraltype: something slipped through\n{dihedraltype}')
                  
               else: print(f'warning: parameter not found for {dtls_[:4]}')
            else: print('warning: dihedraltype not processed:\n '+dihedraltype)
         dihedral_types_new[lambdai] = dih_types_new
      self.scaled_dihedral_types = dihedral_types_new


   def gather_param_sections(self):
      for section in self.hard_order_sections[:-1]:
         is_select = [ i for i, line in enumerate(self.readfile) if section in line ]
         in_select = [f' [ {section} ] \n']
         for i in is_select:
            in_select += self.parse_section(i)
         self.sections[section]=in_select
      section = self.hard_order_sections[-1]
      is_select = [ i for i, line in enumerate(self.readfile) if section in line ] 
      moltype_dict = {} 
      for i in range(len(is_select)):
         moltype_dict[i] = self.parse_moleculetypes(is_select[i])
      self.sections[section] = moltype_dict
         

In [363]:
test = topo2rest('../async_topo/processed.top',kappa='on',nreps=8)

[0 1]
1
[1. 1. 0. 0. 0. 0. 0. 0.]


In [364]:
test.hot_m=[0]

In [365]:
test.show_molecule_names()

0: Protein_chain_A
1: L47
2: SOL
3: NA


In [366]:
test.kappa

{1.0: 1.0,
 0.9296239874987812: 1.0,
 0.8642007581331342: 1.007,
 0.8033817547751941: 1.018,
 0.7468429503578841: 1.028,
 0.6942831215470505: 1.039,
 0.64542224390567: 1.05,
 0.6: 1.06}

In [367]:
test.tempreps

array([300.        , 322.71112195, 347.14156077, 373.42147518,
       401.69087739, 432.10037907, 464.81199375, 500.        ])

In [368]:
test.lambdai

[1.0,
 0.9296239874987812,
 0.8642007581331342,
 0.8033817547751941,
 0.7468429503578841,
 0.6942831215470505,
 0.64542224390567,
 0.6]

In [369]:
test.get_scale_nonbonded()
test.get_scaled_charges()
test.get_scale_dehedrals_()

kappa 1.007 for lambda 0.8642007581331342
kappa 1.018 for lambda 0.8033817547751941
kappa 1.028 for lambda 0.7468429503578841
kappa 1.039 for lambda 0.6942831215470505
kappa 1.05 for lambda 0.64542224390567
kappa 1.06 for lambda 0.6
 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 5 or 8 columns found 0


 expecting 

In [370]:
print(test.scaled_charges.keys())
print(test.scaled_atomtypes.keys())
print(test.scaled_dihedral_types.keys())
print(test.scaled_nonb.keys())



dict_keys([1.0, 0.9296239874987812, 0.8642007581331342, 0.8033817547751941, 0.7468429503578841, 0.6942831215470505, 0.64542224390567, 0.6])
dict_keys([1.0, 0.9296239874987812, 0.8642007581331342, 0.8033817547751941, 0.7468429503578841, 0.6942831215470505, 0.64542224390567, 0.6])
dict_keys([1.0, 0.9296239874987812, 0.8642007581331342, 0.8033817547751941, 0.7468429503578841, 0.6942831215470505, 0.64542224390567, 0.6])
dict_keys([1.0, 0.9296239874987812, 0.8642007581331342, 0.8033817547751941, 0.7468429503578841, 0.6942831215470505, 0.64542224390567, 0.6])


In [371]:
test.scaled_atomtypes[1.0]

['Br           35      79.90   0.0000  A  0.00000e+00  0.00000e+00\n',
 'sBr          35      79.90   0.0000  A  0.00000e+00  0.00000e+00\n',
 'C            6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sC           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'C6           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sC6          6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'C5           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sC5          6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'CA           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sCA          6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'CB           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sCB          6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'CC           6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'sCC          6       12.01   0.0000  A  3.39967e-01  3.59824e-01\n',
 'CK  

In [372]:
test.populate_out()

In [378]:
test.scaled_charges[1.0].keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

In [374]:
for l in test.lambdai :
    f = open(f'./scalled_toplogies/topol-{l}.top','w+')

    [f.write(i) for i in test.sections_out[1.0]]

    f.close()