In [1]:
# -*- coding: utf-8 -*-
"""
Atomic radii and hydrophobicity values

*) radii adapted from parameters used in MSMS
*) hydrophobic contribution for amino acid residue atoms and backbone atoms
    are based on values obtained by Ghose and Crippen
*) values for N- and C-termini not included:
    N: -0.4748 (N-term)
    C: -0.1703 (C-term)
    OT1/2: -0.1962

Description v.0.1 (19.01.2010)
radii column #0: adjusted vdw radii
radii column #1: "all atom" Radius set
                 based on the "vdw.hydrogen" set of radii used in MidasPlus;
                 http://www.cgl.ucsf.edu/chimera/1.1700/docs/UsersGuide/midas/vdwtables.html
                 assignments are made according to standard atom and residue names.
HC  column #2: HC values for HP calculation 1986
HC  column #3: HC values for HP calculation 1989
HC  column #4: HC values for HP calculation 1998
"""
# Lookup table of per-atom parameters.
#
# Keys are (residue name, atom name) tuples using PDB-style, space-padded
# atom names. Keys whose first element is None are residue-independent
# fallbacks; the shortened names among them (3 or 2 characters) are matched
# against atom.name[:3] / atom.name[:2] by PDBStructure.assign_rad_hp_chg.
#
# Each value is a list:
#   [0] adjusted vdW radius
#   [1] "all atom" radius
#   [2] HC value (1986)
#   [3] HC value (1989)
#   [4] HC value (1998)
#   [5] integer code that is not described in the module docstring
#       NOTE(review): meaning cannot be determined from this file - TODO document
radii = {
    ("ALA", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("ALA", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("ALA", " CB "): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("ALA", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("ALA", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("ARG", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("ARG", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("ARG", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("ARG", " CG "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("ARG", " CD "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    ("ARG", " NE "): [1.60, 1.45, -0.0886, -0.0455, -0.4204, 1],
    ("ARG", " CZ "): [2.10, 1.50, -0.2692, 0.1847, 0.1388, 0],
    ("ARG", " NH1"): [1.60, 1.45, -0.2817, -0.2894, -0.5851, 1],
    ("ARG", " NH2"): [1.60, 1.45, -0.2817, -0.2894, -0.5851, 1],
    ("ARG", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("ARG", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("ASN", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("ASN", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("ASN", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("ASN", " CG "): [2.10, 1.50, -0.1703, 0.0709, -0.1002, 0],
    ("ASN", " OD1"): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("ASN", " ND2"): [1.60, 1.45, -0.5992, -0.7048, -0.7185, 1],
    ("ASN", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("ASN", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("ASP", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("ASP", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("ASP", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("ASP", " CG "): [2.10, 1.50, -0.1703, 0.0709, -0.1002, 0],
    ("ASP", " OD1"): [1.60, 1.35, -0.1962, -0.2587, -0.1216, 2],
    ("ASP", " OD2"): [1.60, 1.35, -0.1962, -0.2587, -0.1216, 2],
    ("ASP", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("ASP", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("CYS", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("CYS", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("CYS", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("CYS", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("CYS", " CB "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    ("CYS", " SG "): [1.89, 1.35, 0.6449, 0.4008, 0.5110, 5],
    ("GLU", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("GLU", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("GLU", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("GLU", " CG "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("GLU", " CD "): [2.10, 1.50, -0.1703, 0.0709, -0.1002, 0],
    ("GLU", " OE1"): [1.60, 1.35, -0.1962, -0.2587, -0.1216, 2],
    ("GLU", " OE2"): [1.60, 1.35, -0.1962, -0.2587, -0.1216, 2],
    ("GLU", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("GLU", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("GLN", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("GLN", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("GLN", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("GLN", " CG "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("GLN", " CD "): [2.10, 1.50, -0.1703, 0.0709, -0.1002, 0],
    ("GLN", " OE1"): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("GLN", " NE2"): [1.60, 1.45, -0.5992, -0.7048, -0.7185, 1],
    ("GLN", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("GLN", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("GLY", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("GLY", " CA "): [2.23, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("GLY", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("GLY", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("HIS", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("HIS", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("HIS", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("HIS", " CG "): [2.10, 1.70, 0.3345, 0.1600, 0.1492, 4],
    ("HIS", " CD2"): [2.10, 1.70, 0.3952, 0.1569, 0.2578, 4],
    ("HIS", " ND1"): [1.60, 1.70, -0.0210, 0.0938, 0.0223, 3],
    ("HIS", " CE1"): [2.10, 1.70, -0.3509, 0.2027, 0.4154, 4],
    ("HIS", " NE2"): [1.60, 1.70, 0.3493, 0.4198, 0.1259, 7],
    ("HIS", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("HIS", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("ILE", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("ILE", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("ILE", " CB "): [2.37, 1.25, 0.1514, 0.0785, 0.0660, 5],
    ("ILE", " CG2"): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("ILE", " CG1"): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("ILE", " CD1"): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("ILE", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("ILE", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("LEU", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("LEU", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("LEU", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("LEU", " CG "): [2.37, 1.25, 0.1514, 0.0785, 0.0660, 5],
    # NOTE(review): column 1 is 1.70 for the two LEU methyl carbons, while the
    # other entries with 2.17 in column 0 (ALA CB, ILE CG2/CD1, VAL CG1/CG2, ...)
    # use 1.25 - verify against the source radius table.
    ("LEU", " CD1"): [2.17, 1.70, 0.6594, 0.6483, 0.6420, 5],
    ("LEU", " CD2"): [2.17, 1.70, 0.6594, 0.6483, 0.6420, 5],
    ("LEU", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("LEU", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("LYS", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("LYS", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("LYS", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("LYS", " CG "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("LYS", " CD "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("LYS", " CE "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    # NOTE(review): column 1 is 1.25 here; the other non-aromatic nitrogens in
    # this table use 1.45 - verify against the source radius table.
    ("LYS", " NZ "): [1.60, 1.25, -0.4748, -0.5333, -0.7499, 1],
    ("LYS", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("LYS", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("MET", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("MET", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("MET", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("MET", " CG "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    ("MET", " SD "): [1.89, 1.35, 1.0339, 0.6145, 0.5906, 5],
    ("MET", " CE "): [2.17, 1.25, 0.1460, 0.2430, 0.4143, 5],
    ("MET", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("MET", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("PHE", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("PHE", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("PHE", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("PHE", " CG "): [2.10, 1.70, 0.3345, 0.1600, 0.1492, 4],
    ("PHE", " CD1"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("PHE", " CD2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("PHE", " CE1"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("PHE", " CE2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("PHE", " CZ "): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("PHE", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("PHE", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("PRO", " N  "): [1.80, 1.45, 0.3990, 0.3954, 0.0132, 0],
    ("PRO", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("PRO", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("PRO", " CG "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("PRO", " CD "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    ("PRO", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("PRO", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("SER", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("SER", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("SER", " CB "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 0],
    ("SER", " OG "): [1.60, 1.35, -0.4220, -0.1858, -0.4603, 3],
    ("SER", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("SER", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("THR", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("THR", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("THR", " CB "): [2.37, 1.25, -0.5156, -0.0792, 0.0536, 0],
    ("THR", " OG1"): [1.60, 1.35, -0.4220, -0.1858, -0.4603, 3],
    ("THR", " CG2"): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("THR", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("THR", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("TRP", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("TRP", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("TRP", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 0],
    ("TRP", " CG "): [2.10, 1.70, 0.3345, 0.1600, 0.1492, 4],
    ("TRP", " CD2"): [2.10, 1.70, 0.3345, 0.1600, 0.1492, 4],
    ("TRP", " CE2"): [2.10, 1.70, 0.2455, -0.2782, 0.2813, 4],
    # NOTE(review): column 2 is 0.3147 here but 0.3174 for the other TRP/PHE/TYR
    # ring carbons that share the remaining columns - possible digit
    # transposition; verify against the source HC table.
    ("TRP", " CE3"): [2.10, 1.70, 0.3147, 0.3411, 0.3050, 4],
    ("TRP", " CD1"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TRP", " NE1"): [1.60, 1.70, -0.1946, -0.4366, -0.2660, 6],
    ("TRP", " CZ2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TRP", " CZ3"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TRP", " CH2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TRP", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("TRP", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("TYR", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("TYR", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("TYR", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 4],
    ("TYR", " CG "): [2.10, 1.70, 0.3345, 0.1600, 0.1492, 0],
    ("TYR", " CD1"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TYR", " CE1"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TYR", " CD2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TYR", " CE2"): [2.10, 1.70, 0.3174, 0.3411, 0.3050, 4],
    ("TYR", " CZ "): [2.10, 1.70, -0.1153, -0.1033, 0.1539, 4],
    ("TYR", " OH "): [1.60, 1.35, 0.1509, 0.1600, -0.1163, 3],
    ("TYR", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("TYR", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("VAL", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],
    ("VAL", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("VAL", " CB "): [2.37, 1.25, 0.1514, 0.0785, 0.0660, 5],
    ("VAL", " CG1"): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("VAL", " CG2"): [2.17, 1.25, 0.6594, 0.6483, 0.6420, 5],
    ("VAL", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("VAL", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    ("MSE", " N  "): [1.80, 1.45, -0.2289, -0.3788, -0.6149, 1],  # selenomethionine
    ("MSE", " CA "): [2.37, 1.25, -0.4621, -0.1012, 0.3663, 0],
    ("MSE", " CB "): [2.23, 1.25, 0.4616, 0.3963, 0.4562, 5],
    ("MSE", " CG "): [2.23, 1.25, -0.0867, 0.0466, 0.2196, 5],
    ("MSE", " CE "): [2.17, 1.25, 0.1460, 0.2430, 0.4143, 5],
    ("MSE", "SE  "): [1.89, 1.35, 1.0339, 0.6145, 0.5906, 5],
    ("MSE", " C  "): [1.70, 1.50, -0.1703, 0.0709, -0.1002, 4],
    ("MSE", " O  "): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 2],
    # Residue-independent fallbacks (residue name is None).
    (None, " OXT"): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 0],
    (None, "XE  "): [2.16, 2.16, 0.0, 0.0, 0.0, 0],  # Xenon
    (None, "CA  "): [2.00, 2.00, 0.0, 0.0, 0.0, 0],  # Calcium
    (None, "ZN  "): [2.00, 2.00, 0.0, 0.0, 0.0, 0],  # Zinc
    # Prefix keys: matched against the first 3 or 2 characters of the atom name.
    (None, " C2"): [1.70, 1.70, 0.0, 0.0, 0.0, 0],
    (None, " C4"): [1.70, 1.70, 0.0, 0.0, 0.0, 0],
    (None, " C5"): [1.70, 1.70, 0.0, 0.0, 0.0, 0],
    (None, " C6"): [1.70, 1.70, 0.0, 0.0, 0.0, 0],
    (None, " C8"): [1.70, 1.70, 0.0, 0.0, 0.0, 0],
    (None, " C "): [1.70, 1.25, 0.0, 0.0, 0.0, 0],
    (None, " C"): [1.70, 1.25, 0.0, 0.0, 0.0, 0],  # all C's if C2-C8 not found
    (None, " N4"): [1.60, 1.45, 0.0, 0.0, 0.0, 0],
    (None, " N"): [1.60, 1.70, 0.0, 0.0, 0.0, 0],  # all N's if NX not found
    (None, " O"): [1.60, 1.35, 0.0, 0.0, 0.0, 0],  # all remaining O's
    (None, " P"): [2.00, 1.70, 0.0, 0.0, 0.0, 0],  # all remaining P's
    (None, " S"): [2.00, 1.70, 0.0, 0.0, 0.0, 0],  # all remaining S's
    (None, " OT1"): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 0],  # Terminal Oxygen 1
    (None, " OT2"): [1.60, 1.35, -0.1729, -0.3514, -0.0233, 0],  # Terminal Oxygen 2
    (None, " H"): [1.00, 1.00, 0.0, 0.0, 0.0, 0],  # Hydrogen
}

# Lookup table of per-atom charges.
#
# Keys follow the same scheme as `radii`: (residue name, atom name) with
# PDB-style padded atom names, or (None, atom name) for residue-independent
# entries. PDBStructure.assign_rad_hp_chg tries the residue-specific key first
# and assigns a charge of 0.0 to atoms that match no key.
# The charge of a charged group is spread evenly over the listed atoms
# (e.g. 3 x 1/3 for ARG, 2 x -0.5 for the ASP/GLU carboxylates).
charges = {
    ("ARG", " NE "): 1.0 / 3.0,
    ("ARG", " NH1"): 1.0 / 3.0,
    ("ARG", " NH2"): 1.0 / 3.0,
    ("ASP", " OD1"): -0.5,
    ("ASP", " OD2"): -0.5,
    ("GLU", " OE1"): -0.5,
    ("GLU", " OE2"): -0.5,
    ("HIS", " ND1"): 1.0 / 6.0,
    ("HIS", " NE2"): 1.0 / 6.0,
    ("LYS", " NZ "): 1.0,
    # Backbone partial charges are intentionally disabled:
    # (None, " N  "): -0.084,  # peptide bond
    # (None, " C  "): 0.506,
    # (None, " O  "): -0.422,
    (None, " OXT"): -1.0,  # C-terminal carboxylate
    (None, " OT1"): -0.5,
    (None, " OT2"): -0.5,
    (None, "MG  "): 2.0,  # Magnesium
    (None, "CA  "): 2.0,  # Calcium
    (None, "ZN  "): 2.0,  # Zinc
}


In [2]:
# -*- coding: utf-8 -*-
"""
class definitions for molecular structures (in PDB format)
"""
from io import TextIOWrapper  # necessary to check for file objects
from logging import getLogger


logger = getLogger(__name__)


class PDBAtom:
    """
    Class for a single atom in a (protein) structure.

    The standard PDB fields are stored as attributes (``atom_number``,
    ``name``, ``residue``, ``chain``, ``x``/``y``/``z``, ...). Text fields such
    as ``name`` and ``residue_number`` keep the raw, space-padded column text
    of the PDB line. In addition, a few numeric annotation slots (``HP``,
    ``radius``, ``charge``, ``par1``, ...) and a free-form ``prop_dic``
    dictionary are available for derived properties.
    """

    def __init__(self, line=None):
        """
        Create an atom with default values and optionally fill it from a
        PDB ATOM/HETATM line.

        :param line: PDB-formatted ATOM or HETATM line, or None for a blank atom
        """
        self.het = False
        self.atom_number = 0
        self.name = " X "
        self.alternate = ""
        self.residue = "UNK"
        self.chain = "X"
        self.residue_number = 0
        self.insertion = " "
        self.x = 0.0
        self.y = 0.0
        self.z = 0.0
        self.occupancy = 0.0
        self.bfactor = 0.0

        self.segi = " "
        self.element = " "
        self.crg = "  "

        # other possible atom parameters
        self.keep_old_bfactor = self.bfactor
        self.HP = 0.0
        self.radius = 2.0
        self.charge = 0.0
        self.formal_charge = 0.0
        self.par1 = 0.0
        self.par2 = 0.0
        self.pseudo_id = 0

        # property dictionary: {"prop1": value, "prop2": value}
        self.prop_dic = {}

        if line is not None:  # read data from a PDB ATOM or HETATM line
            self.add_line(line)

    def add_line(self, line):
        """
        Read atom data from a line of PDB-formatted text.

        The line is assumed to start with 'ATOM' or 'HETATM'; fields are taken
        from the fixed PDB columns. Missing occupancy / B-factor columns fall
        back to 1.0 / 20.0.

        :param line: PDB ATOM or HETATM line (a trailing line break is ignored)
        :raises ValueError: if the serial number or a coordinate is not numeric
        :raises IndexError: if the line ends before the insertion-code column
        """
        # Lines read from a file object still carry their terminator, which
        # would otherwise end up inside the segi/element/charge fields.
        line = line.rstrip("\r\n")

        self.het = line[:6] == "HETATM"
        self.atom_number = int(line[6:11])
        self.name = line[12:16]
        self.alternate = line[16]
        self.residue = line[17:20]
        self.chain = line[21]
        self.residue_number = line[22:26]  # kept as the raw 4-character string
        self.insertion = line[26]
        self.x = float(line[30:38])
        self.y = float(line[38:46])
        self.z = float(line[46:54])

        # in some PDB files, everything beyond the coordinates is missing
        try:
            self.occupancy = float(line[54:60])
        except ValueError:
            self.occupancy = 1.0

        try:
            self.bfactor = float(line[60:66])
        except ValueError:
            self.bfactor = 20.0

        # Segment id occupies PDB columns 73-76. Reading five characters here
        # (as was done before) shifts element and charge by one column when the
        # atom is written back with get_pdbstr().
        self.segi = line[72:76]
        self.element = line[76:78]
        self.crg = line[78:80]
        self.keep_old_bfactor = self.bfactor

    def set_prop(self, prop, value=None):
        """
        Set the existing attribute 'prop' of the atom to 'value'.
        Unknown attribute names are ignored.
        """
        if hasattr(self, prop):
            setattr(self, prop, value)

    def add_prop_dic(self, prop, value=None):
        """
        Add or update an entry of the custom property dictionary.
        """
        self.prop_dic.update({prop: value})

    def _print_prop_info(self):
        """Print a header line naming the built-in and custom properties."""
        u = []
        u.extend(
            [
                "keep_old_bfactor",
                "radius          ",
                "charge          ",
                "formal_charge   ",
                "par1            ",
                "par2            ",
                "HP              ",
                "pseudo_id       ",
            ]
        )
        u.extend([key for key in self.prop_dic.keys()])
        u = ["%-16s" % (str(i)) for i in u]
        print(" ".join(u))

    def print_prop(self, form="table"):
        """
        Print the non-standard properties in one of several formats.

        :param form: 'table'       PDB line followed by one property per line
                     'catalophore' comma-separated id, coordinates, B-factor, HP
                     any other value (e.g. 'csv'): comma-separated list of all
                     built-in properties followed by the prop_dic values
        """
        if form == "table":
            print(self.__str__())
            print("Properties:")
            print("-----------")
            print(".keep_old_bfactor:   %s" % self.keep_old_bfactor)
            print(".radius          :   %s" % self.radius)
            print(".charge          :   %s" % self.charge)
            print(".formal_charge   :   %s" % self.formal_charge)
            print(".par1            :   %s" % self.par1)
            print(".par2            :   %s" % self.par2)
            print(".HP              :   %s" % self.HP)
            print(".pseudo_id       :   %s" % self.pseudo_id)
            print("-------------------------")
            print("Additional: .prop_dic")
            for key, value in self.prop_dic.items():
                print("%s: %s" % (key, value))

        elif form == "catalophore":
            print(
                ",".join(
                    [
                        str(a)
                        for a in [
                            self.atom_number,
                            self.residue_number,
                            self.x,
                            self.y,
                            self.z,
                            self.keep_old_bfactor,
                            self.HP,
                        ]
                    ]
                )
            )

        else:
            u = [
                int(self.atom_number),  # point_id
                int(self.residue_number),  # cav_id
                float(self.x),  # x
                float(self.y),  # y
                float(self.z),  # z
            ]
            u.extend(
                [
                    self.keep_old_bfactor,
                    self.radius,
                    self.charge,
                    self.formal_charge,
                    self.par1,
                    self.par2,
                    self.HP,
                    self.pseudo_id,
                ]
            )
            u.extend([value for value in self.prop_dic.values()])
            u = ["%s" % (str(i)) for i in u]
            print(",".join(u))

    def __str__(self):
        return self.get_pdbstr()

    def get_pdbstr(self, prop=None, prop_key=None):
        """
        Return a PDB-formatted line.
        The B-factor column can be replaced by the value in 'prop' or an entry
        from 'prop_dic' (if prop == 'prop_dic').

        If the property or dictionary key is not found, or its value cannot be
        written as a number, the B-factor column is set to -1.

        :param prop: property of PDBAtom object or 'prop_dic'
        :param prop_key: key in PDBAtom.prop_dic
        :return: PDB-formatted line (without line break)
        """
        if self.het:
            start = "HETATM"
        else:
            start = "ATOM  "

        if prop is None or prop == "bfactor":  # use the value in atom.bfactor
            value = self.bfactor
        elif prop == "prop_dic":  # use a value from 'prop_dic'
            value = self.prop_dic.get(prop_key, -1.0)
        else:  # use another property
            value = getattr(self, prop, -1.0)

        # The column is written with %6.2f; a None or non-numeric value would
        # raise instead of producing the documented -1 fallback.
        try:
            value = float(value)
        except (TypeError, ValueError):
            value = -1.0

        return (
            "%-6s%5d %4s%1s%3s %1s%4d%1s   %8.3f%8.3f%8.3f%6.2f%6.2f      %-4s%2s%2s"
            % (
                start,
                int(self.atom_number),
                str(self.name),
                str(self.alternate),
                str(self.residue),
                str(self.chain),
                int(self.residue_number),
                str(self.insertion),
                float(self.x),
                float(self.y),
                float(self.z),
                float(self.occupancy),
                value,  # replacement value for the B-factor column
                str(self.segi),
                str(self.element),
                str(self.crg),
            )
        )


class PDBStructure:
    """
    Class for a protein structure which is read from a file-like,
    PDB-formatted object.

    Attributes:
        remarks: all non-atom lines, in input order
        atom:    PDBAtom instances for every ATOM/HETATM line
        water:   the subset of ``atom`` that belongs to water residues
    """

    def __init__(self, file_obj=None):
        """
        Read PDB data line by line. Non-atom lines are stored in a
        separate list.

        :param file_obj: a text file object, a list/tuple of PDB lines, or a
                         string holding a complete PDB file; None creates an
                         empty structure. Other types are ignored with a
                         warning.
        """
        self.remarks = []  # list that contains non-atom lines
        self.atom = []  # list of instances of PdbAtom
        self.water = []  # water coordinates

        # The following attributes are currently not needed
        # self.deleted_atom = []      # as atom[], but for deleted atoms (H, HETATM,...)
        # self.dummy = []             # dummy coordinates

        # # pseudo center atoms (CavBase-like pseudo atoms)
        # self.pseudo_aliphatic = []  # aliphatic pseudo atoms
        # self.pseudo_pi = []         # pi pseudo atoms
        # self.pseudo_donor = []      # donor atoms
        # self.pseudo_acceptor = []   # acceptor atoms
        # self.pseudo_DON_ACC = []    # donor_acceptor atoms

        if file_obj is None:
            return

        if isinstance(file_obj, (list, tuple, TextIOWrapper)):  # 'file'
            lines = file_obj
        elif isinstance(file_obj, str):
            lines = file_obj.split("\n")  # assuming PyMOL pdbstr
        else:
            logger.warning(
                "Unsupported input type %s; no atoms were read.",
                type(file_obj).__name__,
            )
            return

        for line in lines:
            if self._is_atom(line):
                # Lines coming from a file object end with a line break that
                # must not leak into the fixed-column atom fields.
                self.atom.append(PDBAtom(line.rstrip("\r\n")))
            else:
                self.remarks.append(line)

        self.add_water()

    @staticmethod
    def _is_atom(line):
        """Checks whether the line contains atom data."""
        return line[:4] == "ATOM" or line[:6] == "HETATM"

    def __str__(self):
        return self.get_pdbstr()

    def get_pdbstr(self, prop=None, prop_key=None):
        """
        Return the PDB lines of all atoms as one newline-terminated string.

        :param prop: forwarded to PDBAtom.get_pdbstr (B-factor replacement)
        :param prop_key: forwarded to PDBAtom.get_pdbstr
        """
        return (
            "\n".join([i.get_pdbstr(prop=prop, prop_key=prop_key) for i in self.atom])
            + "\n"
        )

    def add_water(self, atoms=None, resn_list=("H2O", "HOH", "WAT")):
        """
        Append water atoms to the separate atom list called water.

        Note that atoms are appended on every call; calling this twice with
        the same atoms lists them twice.

        :param atoms: list of atom objects (default: self.atom)
        :param resn_list: residue names that identify water molecules
        """
        if atoms is None:
            atoms = self.atom

        for atom in atoms:
            if atom.residue in resn_list:
                self.water.append(atom)

    def change_bfac(
        self, atom_list=None, from_prop="par1", prop_key=None, to_prop="bfactor"
    ):
        """
        Copy one atom property into another (by default into the B-factor).

        :param atom_list: list of PdbAtom objects, or the name of an attribute
                          of this structure that holds such a list
                          (default: self.atom)
        :param from_prop: "par1", "par2", "radius", "charge", "occupancy", ...
                          any attribute of the atoms;
                          "prop_dic" reads the entry 'prop_key' of each atom's
                          property dictionary;
                          "restore" writes keep_old_bfactor back to bfactor
        :param prop_key:  key used when from_prop == "prop_dic"
        :param to_prop:   attribute that receives the value (default
                          "bfactor"); the original B-factor stays available
                          in "keep_old_bfactor"
        """
        if atom_list is None:
            atom_list = self.atom
        elif isinstance(atom_list, str):
            # Only a string can name an attribute; calling hasattr() with a
            # list of atoms raises TypeError.
            if not hasattr(self, atom_list):
                logger.warning("PDBStructure has no atom list named %s.", atom_list)
                return
            atom_list = getattr(self, atom_list)

        if len(atom_list) == 0:
            logger.warning("No atoms in selection.")
            return

        # Availability of a property is checked on the first atom only.
        first = atom_list[0]

        if from_prop == "prop_dic":  # read from property dictionary
            if not prop_key:
                logger.warning("No prop_key given for prop_dic lookup.")
                return
            if first.prop_dic.get(prop_key) is None:
                logger.warning("No %s property found in atom prop_dic.", str(prop_key))
                return
            if not hasattr(first, to_prop):
                logger.warning("No %s property found in atom.", str(to_prop))
                return
            for atom in atom_list:
                setattr(atom, to_prop, atom.prop_dic.get(prop_key))

        elif from_prop == "restore":  # restore original B-factor
            for atom in atom_list:
                atom.bfactor = atom.keep_old_bfactor

        elif hasattr(first, from_prop) and hasattr(first, to_prop):
            for atom in atom_list:
                setattr(atom, to_prop, getattr(atom, from_prop))

        else:
            logger.warning(
                "Property %s or %s not found in atom.", str(from_prop), str(to_prop)
            )

    def assign_rad_hp_chg(
        self, HP_dict=None, HP_column=2, rad_dict=None, rad_column=0, chg_dict=None
    ):
        """
        Assign radii, hydrophobicity (HC) values and charges to all atoms.

        Every table is queried with keys of decreasing specificity:
        (residue, name), (None, name), (None, name[:3]), (None, name[:2]).
        Atoms without a match get radius 2.0 / HC 0.0 (with a warning) and
        charge 0.0.

        :param HP_dict: table of value lists holding HC values
                        (default: module-level radii table)
        :param HP_column: index of the HC value inside a table entry
        :param rad_dict: table of value lists holding radii
                         (default: module-level radii table)
        :param rad_column: index of the radius inside a table entry
        :param chg_dict: table of charges (default: module-level charges table)
        :return: None, the atoms are modified in place
        """
        if rad_dict is None:
            rad_dict = radii  # internal radius dictionary
        if HP_dict is None:
            HP_dict = radii  # internal radius/HP dictionary
        if chg_dict is None:
            chg_dict = charges

        for atom in self.atom:
            queries = [  # atom queries with decreasing specificity
                (atom.residue, atom.name),
                (None, atom.name),
                (None, atom.name[:3]),
                (None, atom.name[:2]),
            ]

            # set atom radii
            for query in queries:
                value = rad_dict.get(query)
                if value:
                    atom.radius = value[rad_column]
                    break
            else:
                logger.warning(
                    "ATOM %s %s %s unknown. Radius is set to %.1f",
                    atom.atom_number,
                    atom.residue,
                    atom.name,
                    2.0,
                )
                atom.radius = 2.0

            # set HC values
            for query in queries:
                value = HP_dict.get(query)
                if value:
                    atom.HP = value[HP_column]
                    break
            else:
                logger.warning(
                    "ATOM %s %s %s unknown. HC value is set to %.1f",
                    atom.atom_number,
                    atom.residue,
                    atom.name,
                    0.0,
                )
                atom.HP = 0.0

            # set partial charges
            for query in queries:
                value = chg_dict.get(query)
                # An explicit charge of 0.0 is a valid entry and must stop the
                # search instead of falling through to a less specific key.
                if value is not None:
                    atom.charge = value
                    break
            else:
                atom.charge = 0.0


In [3]:
# -*- coding: utf-8 -*-
"""
function definitions for LigSite and cavity detection
"""
import itertools
import math

import numpy as np
from numpy.lib.stride_tricks import as_strided



# Floating-point dtype constant for this module.
# 32-bit precision provides speed-up in HP annotation
FP_DTYPE = np.float32


def setup_grid(coords, cushion, d_grid, init, dtype):
    """
    Build and initialize a Grid that encloses all atom coordinates.

    The bounding box of ``coords`` is enlarged by ``cushion`` on every side
    and snapped outwards to multiples of the grid spacing ``d_grid``.

    :param coords: array of cartesian atom coordinates, one row per atom
    :param cushion: margin added around the bounding box (cartesian units)
    :param d_grid: grid spacing
    :param init: initial value, passed on to Grid
    :param dtype: data type, passed on to Grid
    :return: the new Grid instance
    """
    # bounding box corners, expressed in (integral) grid units
    lower = np.floor((coords.min(axis=0) - cushion) / d_grid)
    upper = np.ceil((coords.max(axis=0) + cushion) / d_grid)

    # number of grid points per axis, both end points included
    n_points = (upper - lower + 1).astype(int)

    return Grid(
        origin=lower * d_grid,
        extent=n_points,
        d=d_grid,
        init=init,
        dtype=dtype,
    )


def mask_grid(
    grid, coords, radii, radius_factor, probe_radius, softness, protein_flag, soft_flag
):
    """
    Flag the grid points that are covered by atoms.

    Around every atom, points closer than the inner ("hard") radius are set
    to ``protein_flag``; points between the inner and the outer ("soft")
    radius are set to ``soft_flag`` unless they already carry
    ``protein_flag``. The grid is modified in place through the sub-grids
    returned by ``grid.get_subgrid``.

    :param grid: grid object providing ``origin``, ``d``, ``extent`` and
                 ``get_subgrid``
    :param coords: cartesian atom coordinates, one row per atom
    :param radii: atom radii, one per atom
    :param radius_factor: scaling factor applied to the radii
    :param probe_radius: value added to the scaled radii
    :param softness: width of the soft shell; subtracted from the radius
                     after conversion to grid units
    :param protein_flag: value written inside the hard radius
    :param soft_flag: value written inside the soft shell
    """
    spacing = grid.d
    # atom centers and radii expressed in grid units
    centers = (coords - grid.origin) / spacing
    reach = (radii * radius_factor + probe_radius) / spacing

    outer_sq = reach**2  # squared outer, soft radius
    # NOTE(review): if softness exceeds the radius this is the square of a
    # negative number, i.e. a non-zero hard radius - confirm that callers
    # never pass such values.
    inner_sq = (reach - softness) ** 2  # squared inner, hard radius

    # index boxes around the atoms, clipped to the grid dimensions
    reach_col = reach.reshape(-1, 1)
    box_lo = np.clip(
        np.floor(centers - reach_col).astype(int), (0, 0, 0), grid.extent
    )
    box_hi = np.clip(
        np.ceil(centers + reach_col).astype(int) + 1, (0, 0, 0), grid.extent
    )

    for idx, (lo, hi) in enumerate(zip(box_lo, box_hi)):
        x0, y0, z0 = lo
        x1, y1, z1 = hi
        window = grid.get_subgrid(x0, y0, z0, x1, y1, z1)

        # squared distance of every point of the box to the atom center
        cx, cy, cz = centers[idx]
        gx, gy, gz = np.ogrid[x0:x1, y0:y1, z0:z1]
        dist_sq = (gx - cx) ** 2 + (gy - cy) ** 2 + (gz - cz) ** 2

        # hard core first, so the soft shell never overwrites protein points
        window[dist_sq < inner_sq[idx]] = protein_flag

        in_shell = (dist_sq < outer_sq[idx]) & (window != protein_flag)
        window[in_shell] = soft_flag


def do_ligsite(masked_grid, protein_flag):
    """
    Calculate LigSite scores for a masked grid (in place).

    Every grid line along the three axes and along the four space diagonals
    is handed to '_analyze_line', which increments the points enclosed by
    protein grid points on that line.

    masked_grid ... grid object providing 'get_grid', 'nx', 'ny', 'nz', 'offset'
    protein_flag ... value that denotes protein grid points
    """
    cells = masked_grid.get_grid()
    nx, ny, nz = masked_grid.nx, masked_grid.ny, masked_grid.nz

    def scan(anchor, n_points, step):
        # only the start address of 'anchor' is used; 'step' is the stride
        # that moves one grid point along the diagonal
        _analyze_line(
            as_strided(anchor, shape=(n_points,), strides=(step,)), protein_flag
        )

    # lines parallel to the x-axis
    for iy in range(ny):
        for iz in range(nz):
            _analyze_line(cells[:, iy, iz], protein_flag)

    # lines parallel to the y-axis
    for ix in range(nx):
        for iz in range(nz):
            _analyze_line(cells[ix, :, iz], protein_flag)

    # lines parallel to the z-axis
    for ix in range(nx):
        for iy in range(ny):
            _analyze_line(cells[ix, iy, :], protein_flag)

    forward = slice(None)
    backward = slice(None, None, -1)

    # space diagonals (1,1,1), (-1,1,1), (1,-1,1) and (-1,-1,1);
    # the last one is equivalent to (1,1,-1)
    for sx, sy in ((1, 1), (-1, 1), (1, -1), (-1, -1)):
        step = masked_grid.offset(sx, sy, 1)

        # a reversed slice makes the view start at the far end of the axis
        x_entry = forward if sx > 0 else backward
        y_entry = forward if sy > 0 else backward

        # number of grid points left in x / y when starting at a given index
        x_room = [nx - ix if sx > 0 else ix + 1 for ix in range(nx)]
        y_room = [ny - iy if sy > 0 else iy + 1 for iy in range(ny)]

        # diagonals starting on the z = 0 face
        for ix in range(nx):
            for iy in range(ny):
                scan(cells[ix, iy, :], min(x_room[ix], y_room[iy], nz), step)

        # diagonals starting on the y entry face (z = 0 is already covered)
        for ix in range(nx):
            for iz in range(1, nz):
                scan(cells[ix, y_entry, iz], min(x_room[ix], ny, nz - iz), step)

        # diagonals starting on the x entry face; the row that lies on the
        # y entry face and the z = 0 layer are already covered
        iy_rows = range(1, ny) if sy > 0 else range(ny - 1)
        for iy in iy_rows:
            for iz in range(1, nz):
                scan(cells[x_entry, iy, iz], min(nx, y_room[iy], nz - iz), step)


def _analyze_line(line, protein_flag):
    """
    Analyze line in grid.
    Analyze a line from the masked grid and increment
    all ligsite scores for points between two protein grid points

    assumes "line" to be an integer or float numpy array

    protein_flag ... value to denote protein grid points
    """
    # indices of the protein grid points
    ind = np.argwhere(line == protein_flag).flatten()
    if ind.shape[0] > 0:  # if there are any protein points in the line
        start = int(ind[0]) + 1
        end = int(ind[-1])
        if (end - start) > 0:
            # part of the line between the first and the last protein grid point
            tmp = line[start:end]
            # increment all non-protein grid points in that part
            tmp[tmp > protein_flag] += 1


def find_cavities(ligsite_grid, cutoff, gap, min_size, max_size, radius=1.4, vol_res=3.):
    """
    Group the grid points with a score of at least 'cutoff' into cavities.

    Two points belong to the same cavity when their indices differ by at
    most ``1 + gap`` along every axis (directly or via a chain of such
    points). Only cavities with 'min_size' <= number of points <= 'max_size'
    are kept.

    :return: list of cavity objects, sorted by size in descending order
    """
    spacing = ligsite_grid.d
    origin = ligsite_grid.origin
    scores = ligsite_grid.get_grid()

    # indices of all qualifying grid points that are not yet assigned
    pending = set(map(tuple, np.argwhere(scores >= cutoff)))

    # index offsets to every neighbour within reach, the point itself excluded
    reach = range(-1 - gap, 2 + gap)
    shell = np.array(
        [step for step in itertools.product(reach, repeat=3) if step != (0, 0, 0)]
    ).reshape(1, -1, 3)

    found = []
    while pending:
        members = []  # points of the cavity that is currently grown
        frontier = [pending.pop()]
        while frontier:
            members.extend(frontier)

            # all neighbours of the current frontier
            candidates = (np.array(frontier).reshape(-1, 1, 3) + shell).reshape(-1, 3)
            reached = pending.intersection({tuple(row) for row in candidates})

            # reached points leave the pool and become the next frontier
            pending.difference_update(reached)
            frontier = list(reached)

        if min_size <= len(members) <= max_size:
            found.append(
                _construct_cavity(members, scores, spacing, origin, radius, vol_res)
            )

    # largest cavity first, then attach the PDBStructure objects
    found.sort(key=lambda cav: cav.size, reverse=True)
    _gen_cav_pdb(found)

    return found


def find_cavities_dist(ligsite_grid, cutoff, max_dist, min_size, max_size, radius=1.4, vol_res=3.):
    """
    Group the grid points with a score of at least 'cutoff' into cavities,
    linking points that are at most 'max_dist' apart.

    'max_dist' is divided by the grid spacing before it is compared with the
    distances between grid indices. Only cavities with
    'min_size' <= number of points <= 'max_size' are kept.

    :return: list of cavity objects, sorted by size in descending order
    """
    spacing = ligsite_grid.d
    origin = ligsite_grid.origin
    scores = ligsite_grid.get_grid()

    hits = np.argwhere(scores >= cutoff)  # indices of qualifying grid points
    n_hits = hits.shape[0]
    unvisited = np.ones(n_hits, dtype=bool)  # True ... point not yet assigned

    # full matrix of pairwise squared distances in grid units
    sq_len = (hits * hits).sum(axis=1).reshape(1, -1)
    sq_dist = (sq_len + sq_len.T) - 2 * np.dot(hits, hits.T)

    # squared cutoff in grid units
    sq_limit = max_dist**2 / spacing**2

    found = []
    for first in range(n_hits):
        if not unvisited[first]:
            continue  # already part of an earlier cavity

        unvisited[first] = False
        members = []
        frontier = [first]
        while frontier:
            upcoming = []
            for j in frontier:
                members.append(tuple(hits[j]))
                # unassigned points within reach of point j
                fresh = np.argwhere((sq_dist[j] <= sq_limit) & unvisited)[:, 0]
                unvisited[fresh] = False
                upcoming.extend(fresh)
            frontier = upcoming

        if min_size <= len(members) <= max_size:
            found.append(
                _construct_cavity(members, scores, spacing, origin, radius, vol_res)
            )

    # largest cavity first, then attach the PDBStructure objects
    found.sort(key=lambda cav: cav.size, reverse=True)
    _gen_cav_pdb(found)

    return found


def _construct_cavity(points, grid, d_grid, origin, radius=1.4, vol_res=3.):
    """
    Build a Cavity object from a collection of grid indices.

    points ... sequence of (i, j, k) grid indices
    grid ... score array the indices refer to
    d_grid ... grid spacing
    origin ... cartesian coordinates of grid point (0, 0, 0)
    radius ... sphere radius used for the volume estimate
    vol_res ... the volume is estimated on a grid with spacing d_grid / vol_res
    """
    idx = np.array(points)
    xyz = idx.astype(FP_DTYPE) * d_grid + origin  # cartesian coordinates

    cavity = Cavity(idx, xyz, d_grid)
    cavity.volume = estimate_volume(xyz, radius, d_grid / vol_res)

    # scores of the cavity points, stored under the key "LIG"
    i, j, k = idx[:, 0], idx[:, 1], idx[:, 2]
    cavity.annotations["LIG"] = grid[i, j, k]

    return cavity


def _gen_cav_pdb(cavities):
    """
    generate PDBStructure objects for the cavities
    """
    n_at = 1
    for n_cav, cav in enumerate(cavities):
        cav.pdb = PDBStructure()
        if n_cav >= 9999:  # just in case that there are more than 9999 cavities
            n_cav -= 9999

        for n_point, xyz in enumerate(cav.coords):
            # generate a dummy atom
            atom = PDBAtom(
                f"HETATM{n_at:5d}  XP  CAV X   1       3.500  53.900  22.400  1.00  7.00"
            )
            # set the actual parameters
            atom.x, atom.y, atom.z = xyz
            atom.residue_number = n_cav + 1
            # store the LigSite score in the B-factor column
            atom.bfactor = cav.annotations["LIG"][n_point]
            # set the element to "X" (necessary for reading with Yasara)
            atom.element = "X"
            atom.keep_old_bfactor = atom.bfactor
            atom.add_prop_dic("LIG", atom.bfactor)
            cav.pdb.atom.append(atom)
            n_at += 1
            if n_at > 99999:
                n_at -= 99999


class Cavity:
    """
    definition of a cavity object

    A cavity is a set of grid points, stored both as grid indices ('points')
    and as cartesian coordinates ('coords'), together with per-point
    annotation arrays and an optional PDBStructure representation.
    """

    def __init__(self, points, coords, d_grid):
        """
        :param points: grid indices of the cavity points, one row per point
        :param coords: cartesian coordinates of the cavity points
        :param d_grid: grid spacing
        """
        self.size = len(points)
        self.volume = 0.
        self.d_grid = d_grid
        self.points = points  # cavity points in grid coordinates
        self.coords = coords  # cartesian cavity coordinates

        self.annotations = dict()  # cavity point annotations
        self.pdb = None  # cavity as a PDBStructure object

        self.shaped = False  # indicate, whether the cavity was shaped

    def __str__(self):
        return self.get_pdbstr()

    def get_pdbstr(self, annotation=None):
        """
        Return the cavity as a PDB formatted string, preceded by REMARK lines.

        :param annotation: if given, it is passed on to the PDBStructure so
                           that this property replaces the B-factor column
        """
        header = self._get_pdb_header()
        if annotation is None:  # use the standard B-factor column
            return header + self.pdb.get_pdbstr()
        else:  # replace the B-factor column with annotation value
            return header + self.pdb.get_pdbstr("prop_dic", annotation)

    def _get_pdb_header(self):
        """
        Return REMARK lines with the number of grid points and the volume.
        """
        header = [
            "REMARK",
            f"REMARK number of grid_points:{self.size:7d}",
            f"REMARK approximate cavity volume:{self.volume:7.0f} Angs.**3",
        ]
        return "\n".join(header) + "\n"

    def write(self, file_obj, file_format="pdb", annotation=None):
        """
        Write the cavity to an open file object.

        :param file_obj: object with a 'write' method
        :param file_format: "pdb" or "csv"; any other value writes nothing
        :param annotation: "pdb": name of the property written to the
                           B-factor column (None ... standard B-factor);
                           "csv": annotation key or collection of keys whose
                           values are appended to the coordinates
                           (None ... coordinates only); keys without a
                           matching annotation are ignored
        """
        if file_format == "pdb":
            file_obj.write(self.get_pdbstr(annotation))
        elif file_format == "csv":
            if annotation is None:
                # default: no annotation columns (iterating None would fail)
                keys = ()
            elif isinstance(annotation, str):
                # a bare string is one key, not a sequence of characters
                keys = (annotation,)
            else:
                keys = annotation
            scores = [self.annotations[key] for key in keys if key in self.annotations]
            for i, (x, y, z) in enumerate(self.coords):
                data = [x, y, z]
                data.extend(score[i] for score in scores)
                line = ", ".join(f"{item:f}" for item in data)
                file_obj.write(f"{line}\n")

    def annotate(self, atom_coords, chg, hp, c1=1.0, c2=4.0):
        """
        Annotate cavity points with Coulomb and hydrophobic potentials

        The potentials are stored in the annotations "CP" and "HP" and are
        copied to the atoms of the PDBStructure, which must exist already.
        """
        cp, lp = calculate_cp_lp(self.coords, atom_coords, chg, hp, c1, c2)
        self.annotations["CP"] = cp
        self.annotations["HP"] = lp

        # write value to PDBatom properties 'CP' and 'HP'
        for i, (cp_value, lp_value) in enumerate(zip(cp, lp)):
            self.pdb.atom[i].set_prop("HP", lp_value)
            self.pdb.atom[i].add_prop_dic("HP", lp_value)
            self.pdb.atom[i].add_prop_dic("CP", cp_value)

    def shape(self, shaping_obj, d_max):
        """
        Shape a cavity based on points in 'shaping_obj' and a maximum distance of 'd_max'.
        The function modifies the original cavity!
        Points, coordinates, annotations and PDB atoms are cropped; the
        stored volume is left as it is.

        :param d_max: maximum distance of a cavity point from any point in 'shaping_obj' to be retained
        :param shaping_obj: Nx3 numpy array, coordinates of 'shaping_obj'
        :return: number of retained cavity points
        """
        flags = crop_pointcloud(self.coords, shaping_obj.astype(FP_DTYPE), d_max)

        num_retained = np.sum(flags)  # number of retained points
        self.size = num_retained

        self.points = self.points[flags]
        self.coords = self.coords[flags]

        # crop every annotation array, keeping the dictionary object itself
        for key, value in list(self.annotations.items()):
            self.annotations[key] = value[flags]

        atom = self.pdb.atom
        self.pdb.atom = [atom[i] for i, flag in enumerate(flags) if flag]

        self.shaped = True

        return num_retained


def calc_dist(cloud_1, cloud_2, do_sqrt=False):
    """
    calculate the distance matrix for two point clouds

    The squared distances are obtained from the expansion
    |a - b|**2 = |a|**2 + |b|**2 - 2 * a.b, which needs a single matrix
    product. For nearly coincident points the rounding errors of that
    expression can exceed the true value and yield a slightly negative
    result; such values are clamped to zero, so that the result is never
    negative and the square root never produces NaN.

    :param cloud_1: N x k numpy array
    :param cloud_2: M x k numpy array
           (k ... dimension of the clouds)
    :param do_sqrt: if True return the euclidean distances,
                    otherwise the squared distances
    :return: N x M numpy array with (squared) euclidean distances
    """
    norm_1 = (cloud_1 * cloud_1).sum(axis=1).reshape(-1, 1)
    norm_2 = (cloud_2 * cloud_2).sum(axis=1).reshape(1, -1)
    dist = (norm_1 + norm_2) - 2.0 * np.dot(cloud_1, cloud_2.T)

    # remove negative values that are pure rounding artefacts
    np.maximum(dist, 0, out=dist)

    if do_sqrt:
        return np.sqrt(dist, out=dist)
    return dist


def crop_pointcloud(cloud_1, cloud_2, d_max=2.0):
    """
    select the points of one point cloud that are close to another cloud

    A point of 'cloud_1' is retained, if its distance to the nearest point
    of 'cloud_2' does not exceed 'd_max'.

    :param cloud_1: N x k numpy array, the cloud to be cropped
    :param cloud_2: M x k numpy array, the reference cloud
                  (k ... dimension of the clouds)
    :param d_max: maximum distance for a point in 'cloud_1' to be retained
    :return: boolean array with shape (N,)
             True ... point is retained
             False ... point is deleted
    """
    # everything is compared as squared distances, no square root required
    threshold_sq = d_max**2
    nearest_sq = calc_dist(cloud_1, cloud_2).min(axis=1)

    return nearest_sq <= threshold_sq


def calculate_cp_lp(points, atom_coords, chg, hp, c1=1.0, c2=4.0):
    """
    Calculate the Coulomb and hydrophobic potential at points using the coordinates
    in 'atom_coords' and the (partial) charges and atomic hydrophobicity values in 'chg' and 'hp'.

    Coulomb potential: 557.0 * sum(chg / dist**2), i.e. the relative permittivity
    is set equal to the distance.
    Hydrophobic potential: average of 'hp', weighted with the Fermi-type function
    (exp(-c1 * c2) + 1) / (exp(c1 * (dist - c2)) + 1).

    The distance matrix returned by 'calc_dist' is reused as work buffer for all
    intermediate results, therefore the order of the statements matters.

    :param points: points at which the potentials will be calculated (M, 3)
    :param atom_coords: protein coordinates (N, 3)
    :param chg: (partial) charges assigned to the protein atoms (N,)
    :param hp: hydrophobicity parameters assigned to the protein atoms (N,)
    :param c1: parameter for Fermi-weighting
    :param c2: parameter for Fermi-weighting
    :return: cp, lp (M,)
    """
    # squared distances between all points and all atoms
    dist_sq = calc_dist(points, atom_coords, do_sqrt=False)

    # Coulomb potential, relative permittivity = distance
    # cp = 557.0 * sum(chg / dist**2)
    # NOTE(review): no guard against a squared distance of zero (point on top
    # of an atom) is visible here - confirm that callers exclude this case
    cp = 557.0 * np.sum(np.divide(chg[np.newaxis, :], dist_sq), axis=1)

    # calculate distances, dist_sq is overwritten!
    dist = np.sqrt(dist_sq, out=dist_sq)

    # fermi_wt = (math.exp(-c1 * c2) + 1.0) / (np.exp(c1 * (dist - c2)) + 1.0)
    # calculation of the weighting term is split up and uses
    # the 'out' parameter to prevent creation of a new array at each step
    # 'fermi_wt' is only another name for the buffer of 'dist' (and 'dist_sq')
    fermi_wt = dist  # COPY if 'dist' is used later
    np.subtract(fermi_wt, c2, out=fermi_wt)  # dist - c2
    np.multiply(fermi_wt, c1, out=fermi_wt)  # c1 * (dist - c2)
    if FP_DTYPE == np.float32:
        # clipping is necessary to prevent overflow in np.exp
        np.clip(fermi_wt, -103.9, 88.7, out=fermi_wt)
    np.exp(fermi_wt, out=fermi_wt)  # np.exp(c1 * (dist - c2))
    np.add(fermi_wt, 1.0, out=fermi_wt)  # np.exp(c1 * (dist - c2)) + 1.0
    # (math.exp(-c1 * c2) + 1.0) / (np.exp(c1 * (dist - c2)) + 1.0)
    np.divide(math.exp(-c1 * c2) + 1.0, fermi_wt, out=fermi_wt)

    # sum of weights per point, limited to a minimum of 1.e-6 so that the
    # normalization below cannot divide by zero
    # MUST be calculated first, because the next step overwrites 'fermi_wt'
    sum_wt = np.clip(np.sum(fermi_wt, axis=1), 1.0e-6, np.inf)
    # weighted sum of atom contributions
    lp = np.sum(np.multiply(fermi_wt, hp[np.newaxis, :], out=fermi_wt), axis=1)

    # normalize with the sum of weights
    np.divide(lp, sum_wt, out=lp)

    return cp, lp


def estimate_volume(points, radius: float, d_grid: float) -> float:
    """
    estimate the volume covered by spheres around the points of a cloud

    points ... 3D coordinates of the cloud
    radius ... radius of the sphere placed at each point
    d_grid ... spacing of the grid used for counting

    A grid with a cushion of 'radius + d_grid' is set up around the cloud and
    all voxels inside the spheres are flagged with 1; the estimated volume is
    n_voxels * d_grid**3
    """
    # every point gets the same sphere radius
    sphere_radii = radius * np.ones(points.shape[0], dtype=FP_DTYPE)

    voxels = setup_grid(points, radius + d_grid, d_grid, 0, np.int8)
    # no radius scaling, no probe, no soft shell: flagged voxels hold 1
    mask_grid(voxels, points, sphere_radii, 1.0, 0.0, 0.0, 1, 0)

    n_filled = np.sum(voxels.get_grid())
    return n_filled * d_grid ** 3


class Grid:
    """
    Class for a grid object.

    Extension (# of grid points) given by a tuple of
    length 3. Default initialization with value "0".

    USAGE: Grid(origin, extent, d, init=0, dtype=np.int32)
              origin: (x, y, z)  cartesian coordinates of grid point (0, 0, 0)
              extent: (nx, ny, nz)  number of grid points in x, y and z
              d:      grid spacing
              init:   initialization value of the grid points (standard=0)
              dtype:  data type of the grid array
    """

    def __init__(
        self, origin=(0.0, 0.0, 0.0), extent=(50, 50, 50), d=0.7, init=0, dtype=np.int32
    ):
        self.nx, self.ny, self.nz = extent
        self.extent = extent
        self.origin = np.array(origin, dtype=FP_DTYPE)
        self.d = d
        self._grid = np.zeros(extent, dtype=dtype)
        if init != 0:
            # fill the array; assigning 'init' directly would replace the
            # array by a scalar and break every other method
            self._grid.fill(init)

    def get_subgrid(self, x0, y0, z0, x1, y1, z1):
        """
        Return a view of a sub-grid of the internal np.array
        :param x0: start x-index
        :param y0: start y-index
        :param z0: start z-index
        :param x1: end x-index (exclusive)
        :param y1: end y-index (exclusive)
        :param z1: end z-index (exclusive)
        :return: reference to the sub-grid
        """
        return self._grid[x0:x1, y0:y1, z0:z1]

    def get_grid(self):
        """
        Return the complete grid
        :return: reference to the internal np.array
        """
        return self._grid

    def get_value(self, indices):
        """
        Return value of grid point
        :param indices: tuple of indices (i, j, k); cannot be a np.array
        :return: value at grid point
        """
        return self._grid[indices]

    def set_value(self, indices, value):
        """
        Set value of a grid point
        :param indices: tuple of indices (i, j, k); cannot be a np.array
        :param value: new value
        """
        self._grid[indices] = value

    def is_valid_index(self, index):
        """
        Check whether 'index' contains valid indices of the grid.
        :param index: tuple of indices (i, j, k); cannot be an np.array
        :return: True/False
        """
        x, y, z = index
        return 0 <= x < self.nx and 0 <= y < self.ny and 0 <= z < self.nz

    def coordinates(self, index):
        """
        Return the cartesian coordinates of a grid point
        :param index: indices (i, j, k), tuple, list, np.array
        :return: np.array with cartesian coordinates
        """
        return self.origin + np.array(index, dtype=FP_DTYPE) * self.d

    def offset(self, dx, dy, dz):
        """
        Return the offset in bytes between two grid points that are
        (dx, dy, dz) index steps apart. The formula corresponds to an array
        that is stored row by row, with z as the fastest running index.
        :param dx: increment in x
        :param dy: increment in y
        :param dz: increment in z
        :return: offset, to be used for as_strided
        """
        return ((dx * self.ny + dy) * self.nz + dz) * self._grid.itemsize


In [4]:
# -*- coding: utf-8 -*-
"""
function definitions for LigSite and cavity detection
"""
import itertools
import math

import numpy as np
from numpy.lib.stride_tricks import as_strided


# 32-bit precision provides speed-up in HP annotation
FP_DTYPE = np.float32


def setup_grid(coords, cushion, d_grid, init, dtype):
    """
    Create a grid that encloses all coordinates plus a cushion.

    coords ... cartesian coordinates, one row per point
    cushion ... margin added around the bounding box of 'coords'
    d_grid ... grid spacing
    init ... initialization value of the grid points
    dtype ... data type of the grid array
    """
    # bounding box incl. cushion, as whole numbers of grid steps
    lower = np.floor((coords.min(axis=0) - cushion) / d_grid)
    upper = np.ceil((coords.max(axis=0) + cushion) / d_grid)

    # both end points are part of the grid, hence the "+ 1"
    n_points = (upper - lower + 1).astype(int)

    return Grid(origin=lower * d_grid, extent=n_points, d=d_grid, init=init, dtype=dtype)


def mask_grid(
    grid, coords, radii, radius_factor, probe_radius, softness, protein_flag, soft_flag
):
    """
    Flag the grid points that are covered by atoms.

    Every atom is treated as a sphere with radius
    ``radii * radius_factor + probe_radius``. Grid points closer to the atom
    centre than that radius reduced by 'softness' are set to 'protein_flag';
    the remaining points inside the sphere are set to 'soft_flag', unless they
    already carry 'protein_flag'. The grid is modified in place.

    grid ... object providing 'origin', 'd', 'extent' and 'get_subgrid'
    coords ... cartesian atom coordinates, one row per atom
    radii ... atom radii, one value per atom
    softness ... subtracted from the radius after conversion to grid units
    """
    # express atom centres and sphere radii in units of the grid spacing
    centres = (coords - grid.origin) / grid.d
    outer = (radii * radius_factor + probe_radius) / grid.d
    outer_sq = outer**2  # squared soft (outer) radius
    inner_sq = (outer - softness) ** 2  # squared hard (inner) radius

    # index bounding box around each sphere, clipped to the grid
    reach = outer.reshape(-1, 1)
    lower = np.clip(np.floor(centres - reach).astype(int), (0, 0, 0), grid.extent)
    upper = np.clip(
        np.ceil(centres + reach).astype(int) + 1, (0, 0, 0), grid.extent
    )

    for idx, centre in enumerate(centres):
        x0, y0, z0 = lower[idx]
        x1, y1, z1 = upper[idx]
        box = grid.get_subgrid(x0, y0, z0, x1, y1, z1)

        # squared distance of every point in the box from the atom centre
        gx, gy, gz = np.ogrid[x0:x1, y0:y1, z0:z1]
        dist_sq = (
            (gx - centre[0]) ** 2 + (gy - centre[1]) ** 2 + (gz - centre[2]) ** 2
        )

        # hard core first, so that the soft shell never overwrites it
        box[dist_sq < inner_sq[idx]] = protein_flag
        box[(dist_sq < outer_sq[idx]) & (box != protein_flag)] = soft_flag


def do_ligsite(masked_grid, protein_flag):
    """
    Calculate LigSite scores for a masked grid (in place).

    Every grid line along the three axes and along the four space diagonals
    is handed to '_analyze_line', which increments the points enclosed by
    protein grid points on that line.

    masked_grid ... grid object providing 'get_grid', 'nx', 'ny', 'nz', 'offset'
    protein_flag ... value that denotes protein grid points
    """
    cells = masked_grid.get_grid()
    nx, ny, nz = masked_grid.nx, masked_grid.ny, masked_grid.nz

    def scan(anchor, n_points, step):
        # only the start address of 'anchor' is used; 'step' is the stride
        # that moves one grid point along the diagonal
        _analyze_line(
            as_strided(anchor, shape=(n_points,), strides=(step,)), protein_flag
        )

    # lines parallel to the x-axis
    for iy in range(ny):
        for iz in range(nz):
            _analyze_line(cells[:, iy, iz], protein_flag)

    # lines parallel to the y-axis
    for ix in range(nx):
        for iz in range(nz):
            _analyze_line(cells[ix, :, iz], protein_flag)

    # lines parallel to the z-axis
    for ix in range(nx):
        for iy in range(ny):
            _analyze_line(cells[ix, iy, :], protein_flag)

    forward = slice(None)
    backward = slice(None, None, -1)

    # space diagonals (1,1,1), (-1,1,1), (1,-1,1) and (-1,-1,1);
    # the last one is equivalent to (1,1,-1)
    for sx, sy in ((1, 1), (-1, 1), (1, -1), (-1, -1)):
        step = masked_grid.offset(sx, sy, 1)

        # a reversed slice makes the view start at the far end of the axis
        x_entry = forward if sx > 0 else backward
        y_entry = forward if sy > 0 else backward

        # number of grid points left in x / y when starting at a given index
        x_room = [nx - ix if sx > 0 else ix + 1 for ix in range(nx)]
        y_room = [ny - iy if sy > 0 else iy + 1 for iy in range(ny)]

        # diagonals starting on the z = 0 face
        for ix in range(nx):
            for iy in range(ny):
                scan(cells[ix, iy, :], min(x_room[ix], y_room[iy], nz), step)

        # diagonals starting on the y entry face (z = 0 is already covered)
        for ix in range(nx):
            for iz in range(1, nz):
                scan(cells[ix, y_entry, iz], min(x_room[ix], ny, nz - iz), step)

        # diagonals starting on the x entry face; the row that lies on the
        # y entry face and the z = 0 layer are already covered
        iy_rows = range(1, ny) if sy > 0 else range(ny - 1)
        for iy in iy_rows:
            for iz in range(1, nz):
                scan(cells[x_entry, iy, iz], min(nx, y_room[iy], nz - iz), step)


def _analyze_line(line, protein_flag):
    """
    Analyze line in grid.
    Analyze a line from the masked grid and increment
    all ligsite scores for points between two protein grid points

    assumes "line" to be an integer or float numpy array

    protein_flag ... value to denote protein grid points
    """
    # indices of the protein grid points
    ind = np.argwhere(line == protein_flag).flatten()
    if ind.shape[0] > 0:  # if there are any protein points in the line
        start = int(ind[0]) + 1
        end = int(ind[-1])
        if (end - start) > 0:
            # part of the line between the first and the last protein grid point
            tmp = line[start:end]
            # increment all non-protein grid points in that part
            tmp[tmp > protein_flag] += 1


def find_cavities(ligsite_grid, cutoff, gap, min_size, max_size, radius=1.4, vol_res=3.):
    """
    Group the grid points with a score of at least 'cutoff' into cavities.

    Two points belong to the same cavity when their indices differ by at
    most ``1 + gap`` along every axis (directly or via a chain of such
    points). Only cavities with 'min_size' <= number of points <= 'max_size'
    are kept.

    :param ligsite_grid: grid object providing 'd', 'origin' and 'get_grid'
    :param cutoff: minimum score of a cavity grid point
    :param gap: number of non-qualifying grid points that may be bridged
    :param min_size: minimum number of grid points of a cavity
    :param max_size: maximum number of grid points of a cavity
    :param radius: sphere radius used for the volume estimate
    :param vol_res: the volume is estimated on a grid with spacing d / vol_res
    :return: list of cavity objects, sorted by size in descending order
    """
    # (a stray, result-less 'super(Grid, ligsite_grid).__init__' expression
    # was removed here; it only caused a TypeError for non-Grid objects)
    spacing = ligsite_grid.d
    origin = ligsite_grid.origin
    scores = ligsite_grid.get_grid()

    # indices of all qualifying grid points that are not yet assigned
    pending = set(map(tuple, np.argwhere(scores >= cutoff)))

    # index offsets to every neighbour within reach, the point itself excluded
    reach = range(-1 - gap, 2 + gap)
    shell = np.array(
        [step for step in itertools.product(reach, repeat=3) if step != (0, 0, 0)]
    ).reshape(1, -1, 3)

    found = []
    while pending:
        members = []  # points of the cavity that is currently grown
        frontier = [pending.pop()]
        while frontier:
            members.extend(frontier)

            # all neighbours of the current frontier
            candidates = (np.array(frontier).reshape(-1, 1, 3) + shell).reshape(-1, 3)
            reached = pending.intersection({tuple(row) for row in candidates})

            # reached points leave the pool and become the next frontier
            pending.difference_update(reached)
            frontier = list(reached)

        if min_size <= len(members) <= max_size:
            found.append(
                _construct_cavity(members, scores, spacing, origin, radius, vol_res)
            )

    # largest cavity first, then attach the PDBStructure objects
    found.sort(key=lambda cav: cav.size, reverse=True)
    _gen_cav_pdb(found)

    return found


def find_cavities_dist(ligsite_grid, cutoff, max_dist, min_size, max_size, radius=1.4, vol_res=3.):
    """
    Group the grid points with a score of at least 'cutoff' into cavities,
    linking points that are at most 'max_dist' apart.

    'max_dist' is divided by the grid spacing before it is compared with the
    distances between grid indices. Only cavities with
    'min_size' <= number of points <= 'max_size' are kept.

    :return: list of cavity objects, sorted by size in descending order
    """
    spacing = ligsite_grid.d
    origin = ligsite_grid.origin
    scores = ligsite_grid.get_grid()

    hits = np.argwhere(scores >= cutoff)  # indices of qualifying grid points
    n_hits = hits.shape[0]
    unvisited = np.ones(n_hits, dtype=bool)  # True ... point not yet assigned

    # full matrix of pairwise squared distances in grid units
    sq_len = (hits * hits).sum(axis=1).reshape(1, -1)
    sq_dist = (sq_len + sq_len.T) - 2 * np.dot(hits, hits.T)

    # squared cutoff in grid units
    sq_limit = max_dist**2 / spacing**2

    found = []
    for first in range(n_hits):
        if not unvisited[first]:
            continue  # already part of an earlier cavity

        unvisited[first] = False
        members = []
        frontier = [first]
        while frontier:
            upcoming = []
            for j in frontier:
                members.append(tuple(hits[j]))
                # unassigned points within reach of point j
                fresh = np.argwhere((sq_dist[j] <= sq_limit) & unvisited)[:, 0]
                unvisited[fresh] = False
                upcoming.extend(fresh)
            frontier = upcoming

        if min_size <= len(members) <= max_size:
            found.append(
                _construct_cavity(members, scores, spacing, origin, radius, vol_res)
            )

    # largest cavity first, then attach the PDBStructure objects
    found.sort(key=lambda cav: cav.size, reverse=True)
    _gen_cav_pdb(found)

    return found


def _construct_cavity(points, grid, d_grid, origin, radius=1.4, vol_res=3.):
    """
    Build a Cavity object from a list of grid indices.

    :param points: sequence of (i, j, k) grid indices belonging to the cavity
    :param grid: array with the LigSite scores
    :param d_grid: grid spacing
    :param origin: cartesian coordinates of the grid origin
    :param radius: sphere radius used for the volume estimation
    :param vol_res: the volume is estimated on a grid with spacing d_grid / vol_res
    :return: Cavity with estimated volume and 'LIG' annotation
    """
    grid_idx = np.array(points)
    # cartesian position of every cavity point
    xyz = grid_idx.astype(FP_DTYPE) * d_grid + origin

    cavity = Cavity(grid_idx, xyz, d_grid)
    cavity.volume = estimate_volume(xyz, radius, d_grid / vol_res)
    # LigSite score of every cavity point, looked up by its (i, j, k) index
    cavity.annotations["LIG"] = grid[tuple(grid_idx.T)]

    return cavity


def _gen_cav_pdb(cavities):
    """
    generate PDBStructure objects for the cavities
    """
    n_at = 1
    for n_cav, cav in enumerate(cavities):
        cav.pdb = PDBStructure()
        if n_cav >= 9999:  # just in case that there are more than 9999 cavities
            n_cav -= 9999

        for n_point, xyz in enumerate(cav.coords):
            # generate a dummy atom
            atom = PDBAtom(
                f"HETATM{n_at:5d}  XP  CAV X   1       3.500  53.900  22.400  1.00  7.00"
            )
            # set the actual parameters
            atom.x, atom.y, atom.z = xyz
            atom.residue_number = n_cav + 1
            # store the LigSite score in the B-factor column
            atom.bfactor = cav.annotations["LIG"][n_point]
            # set the element to "X" (necessary for reading with Yasara)
            atom.element = "X"
            atom.keep_old_bfactor = atom.bfactor
            atom.add_prop_dic("LIG", atom.bfactor)
            cav.pdb.atom.append(atom)
            n_at += 1
            if n_at > 99999:
                n_at -= 99999


class Cavity:
    """
    A single cavity: a set of grid points with coordinates and per-point annotations.

    Attributes:
        size        number of cavity points
        volume      estimated volume, set by the code that creates the cavity
        d_grid      spacing of the grid the cavity was detected on
        points      cavity points in grid coordinates (array, one row per point)
        coords      cartesian coordinates of the cavity points (same order as 'points')
        annotations dictionary: annotation name -> array with one value per point
        pdb         cavity as PDBStructure object (None until it has been generated)
        shaped      True after 'shape' has been applied
    """

    def __init__(self, points, coords, d_grid):
        self.size = len(points)
        self.volume = 0.
        self.d_grid = d_grid
        self.points = points  # cavity points in grid coordinates
        self.coords = coords  # cartesian cavity coordinates

        self.annotations = dict()  # cavity point annotations
        self.pdb = None  # cavity as a PDBStructure object

        self.shaped = False  # indicate, whether the cavity was shaped

    def __str__(self):
        return self.get_pdbstr()

    def get_pdbstr(self, annotation=None):
        """
        Return the cavity in PDB format, preceded by REMARK lines.

        :param annotation: name of the annotation written to the B-factor column;
                           None keeps the standard B-factor
        :return: string with PDB-formatted data
        """
        header = self._get_pdb_header()
        if annotation is None:  # use the standard B-factor column
            return header + self.pdb.get_pdbstr()
        # replace the B-factor column with annotation value
        return header + self.pdb.get_pdbstr("prop_dic", annotation)

    def _get_pdb_header(self):
        """
        Generate REMARK lines with the number of points and the volume of the cavity
        :return: string with REMARK lines, terminated by a newline
        """
        header = [
            "REMARK",
            f"REMARK number of grid_points:{self.size:7d}",
            f"REMARK approximate cavity volume:{self.volume:7.0f} Angs.**3",
        ]
        return "\n".join(header) + "\n"

    def write(self, file_obj, file_format="pdb", annotation=None):
        """
        Write the cavity to an open file.

        :param file_obj: object with a 'write' method
        :param file_format: 'pdb' or 'csv'; any other value writes nothing
        :param annotation: 'pdb': name of the annotation for the B-factor column (or None);
                           'csv': annotation name or list/tuple of names, written as
                           additional columns after x, y, z; unknown names are skipped,
                           None writes the coordinates only
        """
        if file_format == "pdb":
            file_obj.write(self.get_pdbstr(annotation))
        elif file_format == "csv":
            # normalize 'annotation' to a list of keys; iterating None would raise and
            # iterating a plain string would yield its characters instead of the name
            if annotation is None:
                keys = []
            elif isinstance(annotation, str):
                keys = [annotation]
            else:
                keys = list(annotation)

            scores = [self.annotations[key] for key in keys if key in self.annotations]
            for i, (x, y, z) in enumerate(self.coords):
                data = [x, y, z]
                data.extend([score[i] for score in scores])
                line = ", ".join([f"{item:f}" for item in data])
                file_obj.write(f"{line}\n")

    def annotate(self, atom_coords, chg, hp, c1=1.0, c2=4.0):
        """
        Annotate cavity points with Coulomb and hydrophobic potentials.
        The values are stored in annotations 'CP' and 'HP' and in the atoms of 'self.pdb'.

        :param atom_coords: coordinates of the protein atoms
        :param chg: (partial) charges of the protein atoms
        :param hp: hydrophobicity parameters of the protein atoms
        :param c1: parameter for Fermi-weighting
        :param c2: parameter for Fermi-weighting
        """
        cp, lp = calculate_cp_lp(self.coords, atom_coords, chg, hp, c1, c2)
        self.annotations["CP"] = cp
        self.annotations["HP"] = lp

        # write value to PDBatom properties 'CP' and 'HP'
        for i, (cp_value, lp_value) in enumerate(zip(cp, lp)):
            self.pdb.atom[i].set_prop("HP", lp_value)
            self.pdb.atom[i].add_prop_dic("HP", lp_value)
            self.pdb.atom[i].add_prop_dic("CP", cp_value)

    def shape(self, shaping_obj, d_max):
        """
        Shape a cavity based on points in 'shaping_obj' and a maximum distance of 'd_max'.
        The function modifies the original cavity!

        :param d_max: maximum distance of a cavity point from any point in 'shaping_obj' to be retained
        :param shaping_obj: Nx3 numpy array, coordinates of 'shaping_obj'
        :return: number of retained cavity points
        """
        flags = crop_pointcloud(self.coords, shaping_obj.astype(FP_DTYPE), d_max)

        # plain int, so that 'size' has the same type as in __init__
        num_retained = int(np.sum(flags))
        self.size = num_retained

        self.points = self.points[flags]
        self.coords = self.coords[flags]

        # values are replaced, keys stay the same, so this is safe during iteration
        for key, value in self.annotations.items():
            self.annotations[key] = value[flags]

        atom = self.pdb.atom
        self.pdb.atom = [atom[i] for i, flag in enumerate(flags) if flag]

        self.shaped = True

        return num_retained


def calc_dist(cloud_1, cloud_2, do_sqrt=False):
    """
    calculate the distance matrix for two point clouds
    :param cloud_1: N x k numpy array
    :param cloud_2: M x k numpy array
           (k ... dimension of the clouds)
    :param do_sqrt: if True return euclidean distances,
                    otherwise (default) return SQUARED euclidean distances
    :return: N x M numpy array with (squared) euclidean distances, never negative
    """
    # |a - b|**2 = |a|**2 + |b|**2 - 2 * a.b, evaluated for all pairs at once
    norm_1 = (cloud_1 * cloud_1).sum(axis=1).reshape(-1, 1)
    norm_2 = (cloud_2 * cloud_2).sum(axis=1).reshape(1, -1)
    dist = (norm_1 + norm_2) - 2.0 * np.dot(cloud_1, cloud_2.T)

    # for (nearly) coincident points the expansion above can round to a tiny
    # negative number, which would turn into NaN in np.sqrt
    np.maximum(dist, 0.0, out=dist)

    if do_sqrt:
        return np.sqrt(dist, out=dist)
    return dist


def crop_pointcloud(cloud_1, cloud_2, d_max=2.0):
    """
    crop one point cloud based on the points of another point cloud,
    all points in 'cloud_1', which are not farther than 'd_max' from any point
    in 'cloud_2', are retained

    :param cloud_1: N x k numpy array, the cloud to be cropped
    :param cloud_2: M x k numpy array, the reference cloud
                  (k ... dimension of the clouds)
    :param d_max: maximum distance for a point in 'cloud_1' to be retained
    :return: boolean array with shape (N,)
             True ... point is retained
             False ... point is deleted
             (all False if 'cloud_2' is empty)
    """
    if len(cloud_2) == 0:
        # no reference points: nothing can be close to them; without this guard
        # the minimum over an empty axis raises a ValueError
        return np.zeros(len(cloud_1), dtype=bool)

    sq_dist = calc_dist(cloud_1, cloud_2)  # all pairwise SQUARED distances
    min_sq_dist = sq_dist.min(axis=1)  # squared distance to the closest reference point

    return min_sq_dist <= d_max**2


def calculate_cp_lp(points, atom_coords, chg, hp, c1=1.0, c2=4.0):
    """
    Calculate the Coulomb and hydrophobic potential at points using the coordinates
    in 'atom_coords' and the (partial) charges and atomic hydrophobicity values in 'chg' and 'hp'.

    Coulomb potential: cp = 557.0 * sum(chg / r**2), i.e. the relative permittivity
    is set equal to the distance 'r'.
    Hydrophobic potential: Fermi-weighted average of the atomic values,
    weight = (exp(-c1 * c2) + 1) / (exp(c1 * (r - c2)) + 1),
    lp = sum(weight * hp) / sum(weight); the sum of weights is clipped to >= 1.e-6.

    NOTE(review): a point coinciding with an atom gives a zero distance and therefore
    a division by zero in the Coulomb term - callers presumably never pass such points.

    :param points: points at the which the potentials will be calculated (M, 3)
    :param atom_coords: protein coordinates (N, 3)
    :param chg: (partial) charges assigned to the protein atoms (N,)
    :param hp: hydrophobicity parameters assigned to the protein atoms (N,)
    :param c1: parameter for Fermi-weighting
    :param c2: parameter for Fermi-weighting
    :return: cp, lp (M,)
    """
    # calculate squared distances, shape (M, N)
    dist_sq = calc_dist(points, atom_coords, do_sqrt=False)

    # Coulomb potential, relative permittivity = distance
    # cp = 557.0 * sum(chg / dist**2)
    cp = 557.0 * np.sum(np.divide(chg[np.newaxis, :], dist_sq), axis=1)

    # calculate distances, dist_sq is overwritten!
    # ('dist' and 'dist_sq' refer to the same array from here on)
    dist = np.sqrt(dist_sq, out=dist_sq)

    # fermi_wt = (math.exp(-c1 * c2) + 1.0) / (np.exp(c1 * (dist - c2)) + 1.0)
    # calculation of the weighting term is split up and uses
    # the 'out' parameter to prevent creation of a new array at each step
    fermi_wt = dist  # COPY if 'dist' is used later
    np.subtract(fermi_wt, c2, out=fermi_wt)  # dist - c2
    np.multiply(fermi_wt, c1, out=fermi_wt)  # c1 * (dist - c2)
    if FP_DTYPE == np.float32:
        # clipping is necessary to prevent overflow in np.exp
        # (FP_DTYPE is a module-level constant defined outside this function)
        np.clip(fermi_wt, -103.9, 88.7, out=fermi_wt)
    np.exp(fermi_wt, out=fermi_wt)  # np.exp(c1 * (dist - c2))
    np.add(fermi_wt, 1.0, out=fermi_wt)  # np.exp(c1 * (dist - c2)) + 1.0
    # (math.exp(-c1 * c2) + 1.0) / (np.exp(c1 * (dist - c2)) + 1.0)
    np.divide(math.exp(-c1 * c2) + 1.0, fermi_wt, out=fermi_wt)

    # sum of weights, clipped to 1.e-6
    # MUST be calculated first, because the next step overwrites 'fermi_wt'
    sum_wt = np.clip(np.sum(fermi_wt, axis=1), 1.0e-6, np.inf)
    # weighted sum of atom contributions
    lp = np.sum(np.multiply(fermi_wt, hp[np.newaxis, :], out=fermi_wt), axis=1)

    # normalize by the sum of weights (in place)
    np.divide(lp, sum_wt, out=lp)

    return cp, lp


def estimate_volume(points, radius: float, d_grid: float) -> float:
    """
    Estimate the volume occupied by a point cloud.

    A sphere of the given radius is placed at every point and the cloud is put on
    a regular grid; the volume is the number of grid points (voxels) flagged by
    'mask_grid' multiplied by the voxel volume d_grid**3.

    :param points: 3D coordinates of the cloud
    :param radius: radius of the sphere placed at each point
    :param d_grid: spacing of the grid used for counting
    :return: estimated volume
    """
    # every point gets the same sphere radius
    sphere_radii = np.ones(points.shape[0], dtype=FP_DTYPE) * radius
    # grid around the cloud, padded by one sphere radius plus one grid step
    voxel_grid = setup_grid(points, radius + d_grid, d_grid, 0, np.int8)
    mask_grid(voxel_grid, points, sphere_radii, 1.0, 0.0, 0.0, 1, 0)

    n_voxels = np.sum(voxel_grid.get_grid())
    return n_voxels * d_grid ** 3


class Grid:
    """
    Class for a grid object.

    Extension (# of grid points) given by a tuple of
    length 3. Default initialization with value "0".

    USAGE: Grid(origin, extent, d=0.7, init=0, dtype=np.int32)
              origin: tuple:(x,y,z)  cartesian coordinates of grid point (0, 0, 0)
              extent: tuple:(x,y,z)  grid extension number of grid points
                              in (x , y , z)
              d:           :grid spacing
              init:        :initialization value of grid point (standard=0)
              dtype:       :numerical type of the grid values
    """

    def __init__(
        self, origin=(0.0, 0.0, 0.0), extent=(50, 50, 50), d=0.7, init=0, dtype=np.int32
    ):
        self.nx, self.ny, self.nz = extent
        self.extent = extent
        self.origin = np.array(origin, dtype=FP_DTYPE)
        self.d = d
        self._grid = np.zeros(extent, dtype=dtype)
        if init != 0:
            # fill the array; assigning 'init' to self._grid would replace
            # the whole grid by a scalar
            self._grid[...] = init

    def get_subgrid(self, x0, y0, z0, x1, y1, z1):
        """
        Return a view of a sub-grid of the internal np.array
        :param x0: start x-index
        :param y0: start y-index
        :param z0: start z-index
        :param x1: end x-index (exclusive)
        :param y1: end y-index (exclusive)
        :param z1: end z-index (exclusive)
        :return: reference to the sub-grid
        """
        return self._grid[x0:x1, y0:y1, z0:z1]

    def get_grid(self):
        """
        Return the complete grid
        :return: reference to the internal np.array (not a copy)
        """
        return self._grid

    def get_value(self, indices):
        """
        Return value of grid point
        :param indices: tuple of indices (i, j, k); cannot be a np.array
        :return: value at grid point
        """
        return self._grid[indices]

    def set_value(self, indices, value):
        """
        Set value of a grid point
        :param indices: tuple of indices (i, j, k); cannot be a np.array
        :param value: new value
        """
        self._grid[indices] = value

    def is_valid_index(self, index):
        """
        Check whether 'index' contains valid indices of the grid.
        :param index: tuple of indices (i, j, k); cannot be an np.array
        :return: True/False
        """
        x, y, z = index
        return 0 <= x < self.nx and 0 <= y < self.ny and 0 <= z < self.nz

    def coordinates(self, index):
        """
        Return the cartesian coordinates of a grid point
        :param index: indices (i, j, k), tuple, list, np.array
        :return: np.array with cartesian coordinates
        """
        return self.origin + np.array(index, dtype=FP_DTYPE) * self.d

    def offset(self, dx, dy, dz):
        """
        Return the offset in bytes
        :param dx: increment in x
        :param dy: increment in y
        :param dz: increment in z
        :return: offset, to be used for as_strided
        """
        # element offset (last index runs fastest) times the size of one element
        return ((dx * self.ny + dy) * self.nz + dz) * self._grid.itemsize


In [5]:
# -*- coding: utf-8 -*-
"""
class definitions for CavFind
"""
import os
import re
import time
from datetime import datetime
from logging import getLogger

import numpy as np



logger = getLogger(__name__)

# Default parameters of a calculation; CavFind works on a copy of this dictionary
# if no settings are passed to its constructor.
# 'tool_tips' and 'param_types' below use the same keys.
default_settings = {
    # PDB interpretation
    "keep_hydrogens": False,
    "keep_hetatms": False,
    "keep_waters": False,
    "resn_water": ("HOH", "WAT", "H2O"),
    "alternate": "A",  # set to 'all' to include all alternate conformations
    # Grid setup
    "grid_spacing": 0.7,
    "probe_radius": 1.4,
    "softness": 0.5,
    "cushion": 0.0,
    "radii": "UA",  # 'UA' or 'AA', selects the column of the radii table
    "radius_factor": 1.0,
    # LigSite
    "ligsite_cutoff": 5,
    "gap": 0,
    # True: cluster cavity points by distance ('max_dist'), False: use 'gap'
    "original_ligsite": False,
    # should be approx. sqrt(3)*grid_spacing for comparable results
    "max_dist": 1.22,
    # resolution for volume estimation, grid_spacing divided by this number
    "vol_resol": 3.0,
    # 'old_cavfind': False,
    "min_size": 4,
    "max_size": 99999,
    "annotate": True,
    # 'min_volume': 4 * 0.7 ** 3,
    # 'max_volume': 99999 * 0.7 ** 3,
    "split_files": False,
    # Annotation and Cropping
    "shape_dmax": 3.0,
    "shape_object": "sele",
    "cp_limit": 25.0,
}

# Short help texts for the setting parameters, one entry per key of 'default_settings'
tool_tips = {
    # PDB interpretation
    "keep_hydrogens": "Keep H-atoms for LigSite calculation (default: False).",
    "keep_hetatms": "Keep hetero-components for LigSite calculation (default: False).",
    "keep_waters": "Keep water molecules for LigSite calculation (default: False).",
    "resn_water": "Residue names for 'water'.",
    "alternate": "alternate conformation to be kept (default: 'A'), "
    + "use 'all' to keep all alternates.",
    # Grid setup
    "grid_spacing": "Grid spacing in Angs. (default: 0.7).",
    "probe_radius": "Probe size in Angs. (default: 1.4).",
    "softness": "Soft shell in grid units (default: 0.5).",
    "cushion": "Cushion around the coordinates (default: 0.).",
    "radii": "Atomic radii to use for: 'UA' (united-atoms, default), 'AA' (all-atom).",
    "radius_factor": "Multiplicative factor for atom radii (default: 1.0).",
    # LigSite
    "ligsite_cutoff": "Cutoff value for the LigSite algorithm (default: 5).",
    "gap": "Maximum gap for cavity coalescence (default: 0).",
    "original_ligsite": "Find cavities using a distance cutoff.",
    "max_dist": "Maximum distance used in cavity detection (default: 1.22)",
    "vol_resol": "Resolution for volume estimation, higher is more accurate (default: 3.0).",
    # 'old_cavfind': "Use the old GrowFromSeed algorithm (default: False).",
    "min_size": "Minimum size of a cavity in grid points (default: 4).",
    "max_size": "Maximum size of a cavity in grid points (default: 99999).",
    # typo fixed in the displayed text: 'hydrobicity' -> 'hydrophobicity'
    "annotate": "Annotate cavities with Coulomb potential and hydrophobicity. Disable for large structures.",
    # 'min_volume': "Minimum size of a cavity in Angs.**3.",
    # 'max_volume': "Maximum size of a cavity in Angs.**3.",
    "split_files": "Generate separate files for each cavity (default: False).",
    # Annotation and Cropping
    "shape_dmax": "Maximum distance of a grid point to the shaping object (default: 3.0).",
    "shape_object": "PyMOL object/selection used for cavity shaping (default: sele).",
    "cp_limit": "Limit for displaying the Coulomb potential",
}

# Type code for every key of 'default_settings' (legend inside the dictionary).
# NOTE(review): the consumer of these codes is not part of this section, presumably
# they are used to convert/validate user input - confirm before changing the codes.
param_types = {
    # types of setting parameters:
    # 0 ... boolean
    # 1 ... integer
    # 2 ... float
    # 3 ... str
    # 4 ... tuple
    # PDB interpretation
    "keep_hydrogens": 0,
    "keep_hetatms": 0,
    "keep_waters": 0,
    "resn_water": 4,
    "alternate": 3,
    # Grid setup
    "grid_spacing": 2,
    "probe_radius": 2,
    "softness": 2,
    "cushion": 2,
    "radii": 3,
    "radius_factor": 2,
    # LigSite
    "ligsite_cutoff": 1,
    "gap": 1,
    "original_ligsite": 0,
    "max_dist": 2,
    "vol_resol": 2,
    # 'old_cavfind': 0,
    "min_size": 1,
    "max_size": 1,
    "annotate": 0,
    # 'min_volume': 2,
    # 'max_volume': 2,
    "split_files": 0,
    # Annotation and Cropping
    "shape_dmax": 2,
    "shape_object": 3,
    "cp_limit": 2,
}

# global parameters
GRID_DTYPE = np.int8  # numerical type of the grid object
# both flags are negative and fit into the int8 range (-128 ... 127) of GRID_DTYPE
PROTEIN_FLAG = -100  # flag for grid points occupied by the protein
SOFT_FLAG = -99  # flag for grid points in the 'soft shell' around protein atoms


class CavFind:
    """
    CavFind object

    Bundles one cavity calculation: the settings, the (cleaned) input structure,
    the LigSite grid and the detected cavities.

    Typical use:
        cf = CavFind(name, obj_name)   # optionally with a settings dictionary
        cf.struct_from_pdb(pdb_data)   # import and clean the structure
        cf.run()                       # LigSite calculation and cavity detection
        cf.write_cavities(filename)    # export the cavities
    """

    def __init__(self, name, obj_name, settings=None):
        """
        :param name: name of the calculation, used in log and progress messages
        :param obj_name: name of the object the calculation refers to (only stored here)
        :param settings: dictionary with setting parameters (a shallow copy is stored);
                         None uses a copy of 'default_settings'
        """
        self.name = name
        self.obj_name = obj_name
        self.time_of_creation = datetime.now()

        if settings is not None:
            self.settings = settings.copy()
        else:
            self.settings = default_settings.copy()

        # structure for which cavities should be calculated (PDBStructure)
        self.pdb = None

        self.coords = None  # atom coordinates as numpy array
        self.radii = None  # atom radii as numpy array
        self.hp = None  # hydrophobicity parameters as numpy array
        self.charges = None  # partial charges

        self.grid = None  # grid object for LigSite algorithm
        self.cavities = None  # cavities detected in structure

    def struct_from_pdb(self, file_obj):
        """
        import a structure, assuming PDB-formatted data
        clean the structure according to setting parameters:
        'keep_hydrogens', 'keep_hetatms', 'keep_waters', 'alternate'
        afterwards radii, hydrophobicity parameters and charges are assigned and
        stored together with the coordinates as numpy arrays

        :param file_obj: PDB data, file, list/tuple of lines, PyMOL pdbstr
        """
        self.pdb = PDBStructure(file_obj)

        keep_hydrogens = self.settings["keep_hydrogens"]
        keep_hetatms = self.settings["keep_hetatms"]
        keep_waters = self.settings["keep_waters"]
        alternate = self.settings["alternate"]
        resn_water = self.settings["resn_water"]

        if (
            not keep_hydrogens
            or not keep_hetatms
            or not keep_waters
            or alternate != "all"
        ):  # remove at least some atoms
            keep_atoms = []
            for atom in self.pdb.atom:

                alt_flag = atom.alternate in (alternate, " ", "") or alternate == "all"
                if not alt_flag:  # only keep atoms with the correct alternate flag
                    continue

                # atom names starting with a digit or 'H' are treated as hydrogens
                is_hydrogen = re.match("^[0-9H]", atom.name.strip()) is not None
                is_water = atom.residue in resn_water

                # NOTE(review): a water stored as ATOM record (atom.het is False) is
                # kept by this branch regardless of 'keep_waters' - confirm intent
                if not is_hydrogen and not atom.het:
                    # is not a hydrogen, not a HETATM and has the 'correct' alternate flag
                    keep_atoms.append(atom)
                    continue

                if is_water and not is_hydrogen and keep_waters:
                    # is a water, not a hydrogen and waters should be kept
                    keep_atoms.append(atom)
                    continue

                if atom.het and not is_hydrogen and not is_water and keep_hetatms:
                    # is a HETATM, not a hydrogen and HETATMs should be kept
                    keep_atoms.append(atom)
                    continue

                if is_hydrogen and keep_hydrogens:
                    # is a hydrogen and hydrogen should be kept
                    keep_atoms.append(atom)

            self.pdb.atom = keep_atoms

        # unknown values of 'radii' fall back to column 0 ('UA')
        rad_column = {"UA": 0, "AA": 1}.get(self.settings["radii"], 0)
        self.pdb.assign_rad_hp_chg(rad_column=rad_column)

        # Generate numpy arrays for coordinates, radii and hydrophobicity parameters
        coords = []
        radii = []
        hp = []
        charges = []
        for atom in self.pdb.atom:
            coords.append((atom.x, atom.y, atom.z))
            radii.append(atom.radius)
            hp.append(atom.HP)
            charges.append(atom.charge)
        self.coords = np.array(coords, dtype=FP_DTYPE)
        self.radii = np.array(radii, dtype=FP_DTYPE)
        self.hp = np.array(hp, dtype=FP_DTYPE)
        self.charges = np.array(charges, dtype=FP_DTYPE)
        # '.2f' (two decimals); the former '2f' only set a minimum width of 2
        logger.info(f"Total charge: {np.sum(self.charges):.2f}")

    def run(self, rerun=False, progress_callback=None):
        """
        perform a LigSite calculation and find cavities
        if 'rerun' is True, rerun cavity detection on existing grid,
        e.g. with different 'ligsite_cutoff', 'gap' or 'min/max_size'
        old cavities are deleted!

        :param rerun: skip grid setup and LigSite calculation, reuse 'self.grid'
        :param progress_callback: optional object with an 'emit(message)' method
        :return: self
        """
        if not rerun:
            # perform a complete Ligsite calculation
            self._emit_message(f"{self.name}:\nsetup grid", progress_callback)

            # setup and mask the grid
            start = time.perf_counter()
            self.grid = setup_grid(
                self.coords,
                self.settings["cushion"],
                self.settings["grid_spacing"],
                0,
                GRID_DTYPE,
            )
            mask_grid(
                self.grid,
                self.coords,
                self.radii,
                self.settings["radius_factor"],
                self.settings["probe_radius"],
                self.settings["softness"],
                PROTEIN_FLAG,
                SOFT_FLAG,
            )
            stop = time.perf_counter()
            logger.info(f"{self.name}: {stop - start:.2f} sec.")

            # run the LigSite algorithm and find cavities
            self._emit_message(
                f"{self.name}:\nanalyse grid\n({np.prod(self.grid.extent)} points)",
                progress_callback,
            )

            start = time.perf_counter()
            do_ligsite(self.grid, PROTEIN_FLAG)
            stop = time.perf_counter()
            logger.info(f"{self.name}: {stop - start:.2f} sec.")

        # detect cavities in the grid (entry point in case of re-run)
        self._emit_message(
            f"{self.name}:\ndetect cavities\n(cutoff: {self.settings['ligsite_cutoff']})",
            progress_callback,
        )

        start = time.perf_counter()
        if self.settings["original_ligsite"]:
            self.cavities = find_cavities_dist(
                self.grid,
                self.settings["ligsite_cutoff"],
                self.settings["max_dist"],
                self.settings["min_size"],
                self.settings["max_size"],
                self.settings["probe_radius"],
                self.settings["vol_resol"],
            )
        else:
            self.cavities = find_cavities(
                self.grid,
                self.settings["ligsite_cutoff"],
                self.settings["gap"],
                self.settings["min_size"],
                self.settings["max_size"],
                self.settings["probe_radius"],
                self.settings["vol_resol"],
            )
        stop = time.perf_counter()
        logger.info(f"{self.name}: {stop - start:.2f} sec.")

        if self.settings["annotate"]:
            # annotate cavities
            self._emit_message(
                f"{self.name}:\nannotate\n({len(self.cavities)} cavities)",
                progress_callback,
            )

            start = time.perf_counter()
            for cav in self.cavities:
                cav.annotate(self.coords, self.charges, self.hp)
            stop = time.perf_counter()
            logger.info(f"{self.name}: {stop - start:.2f} sec.")
        else:
            logger.info(f"{self.name}: {len(self.cavities)} cavities detected.")

        return self  # necessary for asynchronous execution

    @staticmethod
    def _emit_message(message, progress_callback):
        """
        Send a progress message to the callback (if there is one) and to the log;
        newlines are replaced by blanks for the log.
        """
        if progress_callback:
            progress_callback.emit(message)
        logger.info(message.replace("\n", " "))

    def _get_pdb_header(self):
        """
        Generate REMARK lines for PDB-output,
        containing the settings used in the calculation
        :return: string with REMARK lines
        """
        header = [
            "REMARK",
            f"REMARK     alternate flag: {self.settings['alternate']}",
        ]

        if not self.settings["keep_hydrogens"]:
            header.append("REMARK                     *** hydrogen atoms excluded")
        if not self.settings["keep_waters"]:
            header.append("REMARK                     *** water molecules excluded")
        if not self.settings["keep_hetatms"]:
            header.append("REMARK                     *** hetero atoms excluded")

        header.extend(
            [
                "REMARK",
                f"REMARK       grid spacing:{self.settings['grid_spacing']:8.3f}",
                f"REMARK            cushion:{self.settings['cushion']:7.2f}",
                f"REMARK       probe radius:{self.settings['probe_radius']:7.2f}",
                f"REMARK           softness:{self.settings['softness']:7.2f}",
                f"REMARK      radius factor:{self.settings['radius_factor']:7.2f}",
                "REMARK",
                f"REMARK      cavity cutoff:{self.settings['ligsite_cutoff']:7.2f}",
                f"REMARK      gap parameter:{self.settings['gap']:7.2f}",
                f"REMARK   min. cavity size:{self.settings['min_size']:7d} points",
                f"REMARK   max. cavity size:{self.settings['max_size']:7d} points",
                "REMARK",
                "",  # the joined header ends with a newline
            ]
        )

        return "\n".join(header)

    def write_cavities(
        self,
        filename="cav.pdb",
        file_format="pdb",
        indices=None,
        annotation=None,
        header="",
    ):
        """
        Write the detected cavities to file(s).

        'pdb': one file per annotation with all selected cavities; the 'LIG' file is
               named <filename>, the others <basename>_<annotation>.pdb; with setting
               'split_files' an additional file per cavity is written
               (suffix _<number>, 4 digits, starting at 1)
        'csv': one file per cavity (<basename>_<number>.csv) with the columns
               x, y, z followed by the selected annotations
        Nothing is written if there are no cavities.

        :param filename: name of the output file; if the extension does not match
                         'file_format', the proper extension is appended
        :param file_format: 'pdb' or 'csv'
        :param indices: container with (zero-based) indices of the cavities to be
                        written, None writes all cavities
        :param annotation: annotation name(s): list/tuple or comma-separated string;
                           'LIG' is always included, unavailable names are dropped
        :param header: text written at the top of every file
        :return: None
        """
        if not self.cavities:
            # no calculation yet (None) or nothing detected (empty list);
            # self.cavities[0] below would raise
            logger.warning("No cavities available, nothing written.")
            return

        # build list of annotations
        annotation_list = []
        if isinstance(annotation, (list, tuple)):
            annotation_list = list(annotation)
        elif isinstance(annotation, str):
            annotation_list = [
                item.strip() for item in annotation.split(",") if item.strip()
            ]

        if "LIG" not in annotation_list:  # always save the LigSite score
            annotation_list.append("LIG")

        # remove unavailable annotations
        annotation_list = [
            key for key in annotation_list if key in self.cavities[0].annotations
        ]

        basename, extension = os.path.splitext(filename)

        if file_format == "pdb":
            if extension != ".pdb":
                basename = filename
                extension = ".pdb"
            pdb_header = self._get_pdb_header()

            for key in annotation_list:
                annotation_header = f"REMARK Annotation: {key}\n"
                # write one file with all selected cavities included
                if key == "LIG":  # naming compatibility with CavFind
                    fname = f"{basename}{extension}"
                else:
                    fname = f"{basename}_{key}{extension}"

                with open(fname, "w") as file:
                    file.write(header)
                    file.write(pdb_header)
                    file.write(annotation_header)
                    for i, cav in enumerate(self.cavities):
                        # either write all cavities or those in 'indices'
                        if indices is None or i in indices:
                            cav.write(file, file_format, key)
                    file.write("END\n")
                    logger.info(f"Cavities written to: {fname}")

                if self.settings["split_files"]:
                    # write a separate file for each selected cavity
                    for i, cav in enumerate(self.cavities):
                        # either write all cavities or those in 'indices'
                        if indices is None or i in indices:
                            if key == "LIG":  # naming compatibility with CavFind
                                fname = f"{basename}_{i + 1:04d}{extension}"
                            else:
                                fname = f"{basename}_{key}_{i + 1:04d}{extension}"

                            with open(fname, "w") as file:
                                file.write(header)
                                file.write(pdb_header)
                                file.write(annotation_header)
                                cav.write(file, file_format, key)
                                file.write("END\n")
                                logger.info(f"Cavity #{i + 1:d} written to: {fname}")

        # separate file for each cavity with all selected annotations included
        elif file_format == "csv":
            if extension != ".csv":
                basename = filename
                extension = ".csv"

            content = "# columns: x, y, z, " + ", ".join(annotation_list) + "\n"

            for i, cav in enumerate(self.cavities):
                # either write all cavities or those in 'indices'
                if indices is None or i in indices:
                    fname = f"{basename}_{i + 1:04d}{extension}"
                    with open(fname, "w") as file:
                        file.write(header)
                        file.write(f"# number of grid points:{cav.size:6d}\n")
                        file.write(content)
                        cav.write(file, "csv", annotation_list)
                        logger.info(f"Cavity #{i + 1} written to: {fname}")


In [11]:
# Example workflow: read a PDB file, detect the cavities and write them to disk.
# NOTE(review): 'path/to/file.pdb' looks like a placeholder - replace it with a real
# file before running this cell.
a = CavFind('name', 'obj_name')  # no settings given, i.e. 'default_settings' are used
with open('path/to/file.pdb', 'r') as file:
    pdb_string = file.read()
a.struct_from_pdb(pdb_string)  # import and clean the structure
a.run()  # complete calculation (rerun=False)
a.write_cavities()  # default arguments: file name 'cav.pdb', format 'pdb'
