In [None]:
import numpy as np
import mdtraj as md
import pickle
from endstate_rew.system import generate_molecule
from endstate_rew.constant import zinc_systems
import seaborn as sns
import seaborn_image as isns
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from os import path
from glob import glob
from matplotlib.offsetbox import (TextArea, DrawingArea, OffsetImage,
                                  AnnotationBbox)
#from PIL import Image
from rdkit.Chem import AllChem
import math
import os

In [None]:
def save_mol_pic(zinc_id:int, ff:str):
    """Draw a 2D picture of a ZINC system with atom indices and save it as PNG.

    Parameters
    ----------
    zinc_id : int
        Key into ``zinc_systems``; the entry is unpacked as ``(name, smiles)``.
    ff : str
        Force field label, passed to ``generate_molecule`` and used in the
        output file name.

    Returns
    -------
    str
        Name of the written file, ``'<name>_<ff>.png'``, relative to the
        current working directory.
    """
    from rdkit import Chem
    from rdkit.Chem.Draw import IPythonConsole
    from rdkit.Chem.Draw import rdMolDraw2D
    # also show atom indices for molecules rendered inline by the notebook
    IPythonConsole.drawOptions.addAtomIndices = True
    #IPythonConsole.molSize = 500,500

    name, _smiles = zinc_systems[zinc_id]
    # generate openff Molecule
    mol = generate_molecule(name=name, forcefield=ff, base='../../data/hipen_data')
    # convert openff object to rdkit mol object
    mol_rd = mol.to_rdkit()

    if zinc_id == 4:
        # NOTE: FIXME: this is a temporary workaround to fix the wrong indexing in rdkit
        # when using the RemoveHs() function
        mol_draw = Chem.RWMol(mol_rd)
        # remove all explicit H atoms, except the ones on the ring (for correct indexing):
        # the last atom is dropped 12 times in a row
        for _ in range(12):
            mol_draw.RemoveAtom(mol_draw.GetNumAtoms() - 1)
    else:
        # remove explicit H atoms
        mol_draw = Chem.RemoveHs(mol_rd)

    # get 2D representation
    AllChem.Compute2DCoords(mol_draw)
    d = rdMolDraw2D.MolDraw2DCairo(1500, 1000)
    options = d.drawOptions()
    options.fixedFontSize = 90
    options.fixedBondLength = 110
    options.annotationFontScale = 0.7
    options.addAtomIndices = True

    d.DrawMolecule(mol_draw)
    d.FinishDrawing()

    # hand the location back so callers do not have to rebuild the file name
    out_file = f'{name}_{ff}.png'
    d.WriteDrawingText(out_file)
    return out_file

###########################################################################################################################################################################
###########################################################################################################################################################################

def get_traj(samples:str, name: str, ff: str,
             base: str = '/home/stkaczyk/endstate_rew/notebooks/TMP_SWI',
             n_samples: int = 2, n_discard: int = 1000):
    """Load the stored samples of one endstate as an mdtraj trajectory.

    Parameters
    ----------
    samples : str
        Endstate selector: ``'mm'`` (lambda label ``0.0000``) or ``'qml'``
        (lambda label ``1.0000``).
    name : str
        System name; part of the directory and file name of the pickle file.
    ff : str
        Force field label; part of the directory name.
    base : str, optional
        Directory that contains the per-system folders.
    n_samples : int, optional
        Number of entries a pickle file must contain to count as complete.
    n_discard : int, optional
        Number of leading entries dropped from every complete file.

    Returns
    -------
    mdtraj.Trajectory or None
        Trajectory built from the retained entries of all complete files,
        using the topology of ``mol.pdb`` in the current working directory.
        ``None`` if no file matched, no file was complete, or nothing was
        left after discarding.

    Raises
    ------
    ValueError
        If ``samples`` is neither ``'mm'`` nor ``'qml'``.
    """
    # depending on endstate, get correct label
    lambda_labels = {'mm': '0.0000', 'qml': '1.0000'}
    if samples not in lambda_labels:
        raise ValueError(f"samples must be 'mm' or 'qml', got {samples!r}")
    endstate = lambda_labels[samples]

    # get pickle files for traj
    pickle_files = glob(f'{base}/{name}/switching_{ff}/{name}_samples_5000_steps_1000_lamb_{endstate}_nr_switches_2_switching_length_3.pickle')

    # list for collecting sampling data of all complete files
    coordinates = []
    n_complete = 0

    for run in pickle_files:
        # SECURITY: unpickling executes arbitrary code - only load files you produced yourself
        with open(run, "rb") as handle:
            coord = pickle.load(handle)

        # check, if sampling data is complete
        if len(coord) != n_samples:
            print(f'{run} file contains incomplete sampling data')
            continue

        n_complete += 1
        # remove the first n_discard samples
        coordinates.extend(coord[n_discard:])

    if not coordinates:
        # NOTE(review): with the defaults (n_samples=2, n_discard=1000) the slice of a
        # 2-entry sequence is empty, so this is the outcome for such files - confirm the
        # intended values for the switching data.
        if n_complete:
            print(f'No samples left for {name} after discarding the first {n_discard} entries')
        return None

    # load topology from pdb file (expected to exist in the working directory)
    # FIXME: try to load topology directly instead of going through a pdb file
    top = md.load("mol.pdb").topology

    # generate trajectory instance
    return md.Trajectory(xyz = coordinates, topology=top)
            
###########################################################################################################################################################################
###########################################################################################################################################################################


def get_indices(rot_bond, rot_bond_list:list, bonds:list):
    """Pick one heavy-atom dihedral around a rotatable bond.

    Parameters
    ----------
    rot_bond
        Position of the bond of interest inside ``rot_bond_list``.
    rot_bond_list : list
        Rotatable bonds; each entry provides ``atom1_index``, ``atom2_index``,
        ``atom1`` and ``atom2`` (the atoms expose ``element``).
    bonds : list
        All bonds to search for neighbours, with the same attributes; a
        neighbour is skipped when its ``element.name`` is ``'hydrogen'``.

    Returns
    -------
    list
        ``[[n1, a1, a2, n2]]`` where ``a1``/``a2`` are the atoms of the
        rotatable bond and ``n1``/``n2`` the first non-hydrogen neighbour found
        for each of them (in ``bonds`` order); ``[]`` if either atom has no
        such neighbour.
    """
    print(f'---------- Investigating bond nr {rot_bond} ----------')

    central = rot_bond_list[rot_bond]
    first = central.atom1_index
    second = central.atom2_index
    # labels are spelled out literally so the printed text does not depend on local names
    print(f'atom_1_idx={first!r}')
    print(f'rot_bond_list[rot_bond].atom1.element={central.atom1.element!r}')
    print(f'atom_2_idx={second!r}')
    print(f'rot_bond_list[rot_bond].atom2.element={central.atom2.element!r}')

    # heavy-atom neighbours of the first / second atom of the rotatable bond
    near_first = []
    near_second = []

    for bond in bonds:
        # treat both ends of the rotatable bond with the same rule:
        # (atom whose neighbours are wanted, atom at the other end, result list)
        for center, other_end, bucket in ((first, second, near_first),
                                          (second, first, near_second)):
            # find the partner of `center` in this bond, whichever slot it occupies
            if bond.atom1_index == center:
                partner, partner_idx = bond.atom2, bond.atom2_index
            elif bond.atom2_index == center:
                partner, partner_idx = bond.atom1, bond.atom1_index
            else:
                continue

            # hydrogens and the other atom of the rotatable bond do not count
            if partner.element.name == 'hydrogen' or partner_idx == other_end:
                continue
            bucket.append(partner_idx)

    # a torsion needs a heavy neighbour on both sides
    if not near_first or not near_second:
        print(f'No heavy atom torsions found for bond {rot_bond}')
        return []

    return [[near_first[0], first, second, near_second[0]]]

###########################################################################################################################################################################
###########################################################################################################################################################################

def vis_torsions(zinc_id: int, ff:str):
    """Plot MM versus QML heavy-atom torsion distributions of one ZINC system.

    The molecule is written to ``mol.pdb`` (``get_traj`` reads the topology
    from that file), one heavy-atom dihedral is selected per rotatable bond
    with ``get_indices``, and the dihedral angles of the ``'mm'`` and
    ``'qml'`` trajectories are histogrammed. The figure (molecule picture on
    top, one panel per torsion) is saved to
    ``torsion_profiles_<ff>/<name>_<ff>.png``.

    Parameters
    ----------
    zinc_id : int
        Key into ``zinc_systems``; the entry is unpacked as ``(name, smiles)``.
    ff : str
        Force field label, used for molecule generation, trajectory lookup and
        output names.

    Returns
    -------
    None
        Progress is printed; nothing is plotted if no heavy-atom torsion or no
        trajectory data is available.
    """
    ############################################ LOAD MOLECULE AND GET BOND INFO ##########################################################################################

    # get zinc_id(name of the zinc system) and smiles
    name, _smiles = zinc_systems[zinc_id]

    print(  f'################################## SYSTEM {name} ##################################')

    # generate mol
    # NOTE(review): save_mol_pic passes base='../../data/hipen_data' to generate_molecule,
    # this call relies on the default - confirm both resolve to the same data.
    mol = generate_molecule(forcefield=ff, name=name)

    # write mol as pdb (get_traj loads the topology from this file)
    mol.to_file("mol.pdb", file_format="pdb")

    # get all bonds
    bonds = mol.bonds

    # get all rotatable bonds
    rot_bond_list = mol.find_rotatable_bonds()
    print(len(rot_bond_list),'rotatable bonds found.')
    print(f'{rot_bond_list=}')

    ################################################## GET HEAVY ATOM TORSIONS ##########################################################################################

    # bond nr -> atom indices of its heavy atom torsion (only bonds that have one)
    torsion_indices = {}

    for rot_bond in range(len(rot_bond_list)):

        # get atom indices of the torsion around this rotatable bond
        indices = get_indices(rot_bond= rot_bond, rot_bond_list=rot_bond_list, bonds=bonds)
        print(indices)

        if len(indices) > 0:
            print(f'Dihedrals are computed for bond nr {rot_bond}')
            torsion_indices[rot_bond] = indices
        else:
            print(f'No dihedrals will be computed for bond nr {rot_bond}')

    # bond nr -> dihedral angles; kept local so results of earlier calls cannot leak in
    data_mm = {}
    data_qml = {}

    if torsion_indices:
        # the trajectories do not depend on the bond, so load them only once
        traj_mm = get_traj(samples='mm', name=name, ff=ff)
        traj_qml = get_traj(samples='qml', name=name, ff=ff)
        print(traj_mm)

        # get_traj returns None when no usable sampling data was found
        if traj_mm and traj_qml:
            for rot_bond, indices in torsion_indices.items():
                data_mm[rot_bond] = md.compute_dihedrals(traj_mm, indices, periodic=True, opt=True) #* 180.0 / np.pi
                data_qml[rot_bond] = md.compute_dihedrals(traj_qml, indices, periodic=True, opt=True) # * 180.0 / np.pi
        else:
            print(f'Trajectory data cannot be found for {name}')

    ################################################## PLOT TORSION PROFILES ##########################################################################################

    if not data_mm:
        print(f'No torsion profile can be generated for {name}')
        return

    # generate molecule picture; it is written as '<name>_<ff>.png' into the working directory
    save_mol_pic(zinc_id = zinc_id, ff = ff)

    # one axis for the picture plus one per torsion
    n_torsions = len(data_mm)
    fig, axs = plt.subplots(n_torsions+1, 1, figsize = (8, n_torsions*2+6), dpi = 400)
    fig.suptitle(f'Torsion profile of {name} ({ff})', fontsize = 13, weight= 'bold')

    # read the picture from where save_mol_pic wrote it and flip it, so it is displayed correctly
    image = np.flipud(mpimg.imread(f'{name}_{ff}.png'))

    # plot the molecule image on the first axis
    axs[0].imshow(image)
    axs[0].axis('off')

    # ticks every pi/4 from -pi to pi; linspace guarantees exactly one tick per label
    tick_labels = ['-π', '-3π/4', '-π/2', '-π/4', '0', 'π/4', 'π/2', '3π/4', 'π']
    tick_positions = np.linspace(-np.pi, np.pi, len(tick_labels))

    # iterate over all torsions and plot results
    for ax, (rot_bond, indices) in zip(axs[1:], torsion_indices.items()):
        # add atom indices as plot title
        ax.set_title(f'Torsion {indices[0]}')
        sns.histplot(
            ax = ax,
            data = {'mm samples': data_mm[rot_bond].squeeze(), 'qml samples': data_qml[rot_bond].squeeze()},
            bins = 100, #not sure how many bins to use
            kde = True,
            alpha = 0.5,
            stat="density"
        )
        # adjust axis labelling
        ax.set(xlim=(-np.pi,np.pi))
        ax.set_xticks(tick_positions, tick_labels)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    axs[-1].set_xlabel('Dihedral angle')

    fig.tight_layout()

    # the figure is left open so the notebook can still display it inline
    out_dir = f"torsion_profiles_{ff}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/{name}_{ff}.png")

In [None]:
# torsion profiles for the zinc_systems entry with key 1, using the 'charmmff' force field label
vis_torsions(zinc_id=1, ff='charmmff')