# 3D comparison

## Setup

In [16]:
#! pip install py3Dmol

In [17]:
import io

import numpy as np
import pandas as pd
import py3Dmol
from Bio import PDB

import bin.params as p

In [18]:
CLUSTERING_CHAINS = p.CLUSTERING_CHAINS
CHAINS = p.CHAINS

In [19]:
EXPERIMENT_SETTINGS = f'{CLUSTERING_CHAINS}_{CHAINS}'
EXP_STRING = f'(scheme={p.FINAL_NUMBERING_SCHEME}, clustering={CLUSTERING_CHAINS}, chains={CHAINS}, dataset=TEST)'
LOW = 10
HIGH = 50
ANGLE = -40
ZOOM = 1
SEQUENCE_ID = '7N8I:H'
#SEQUENCE_ID = '5DT1:H'
MODEL_NAME = 'lco_cont_window_r4_all_H_randomForestN30'

**Load prediction data:**

In [20]:
data = pd.read_csv(f'{p.DATA_DIR}/csv/test/results_{EXPERIMENT_SETTINGS}.csv', index_col=0)
data['buried'] = data['sasa'] <= LOW
data['exposed'] = data['sasa'] >= HIGH
data['overpred'] = (data['buried'] == True) & (data['prediction'] >= HIGH)
data['underpred'] = (data['exposed'] == True) & (data['prediction'] <= LOW)
data.head(n=6)

Unnamed: 0,scenario,sequence_id,position,sasa,prediction,error,abs_error,buried,exposed,overpred,underpred
0,lco_whole_sequence_all_H_BLknnwholeseqn3,6LCS:H,1,,31.7,,,False,False,False,False
1,lco_whole_sequence_all_H_BLknnwholeseqn3,6LDV:H,1,,,,,False,False,False,False
2,lco_whole_sequence_all_H_BLknnwholeseqn3,6LDW:H,1,,,,,False,False,False,False
3,lco_whole_sequence_all_H_BLknnwholeseqn3,6LDX:H,1,100.0,,,,False,True,False,False
4,lco_whole_sequence_all_H_BLknnwholeseqn3,6LDY:H,1,,,,,False,False,False,False
5,lco_whole_sequence_all_H_BLknnwholeseqn3,6LRA:H,1,80.2,65.533333,-14.666667,14.666667,False,True,False,False


**Load FreeSASA data:**

In [21]:
s = pd.read_csv(f'{p.DATA_DIR}/csv/sasa_aligned/sasa_{SEQUENCE_ID[-1]}.csv', index_col=0)
s.head(n=2)

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,141,142,143,143A,144,145,146,147,148,149
12E8:H,100.0,36.0,50.1,4.4,51.9,3.5,28.8,,68.3,71.2,...,28.5,13.3,1.4,,30.0,2.2,19.3,4.3,14.3,75.0
15C8:H,100.0,23.3,51.7,4.5,54.8,5.1,27.2,,65.8,80.3,...,51.2,12.0,2.5,,46.4,4.6,26.7,7.4,18.3,61.4


In [22]:
non_nans = (s.count(axis=1) - s.isnull().sum(axis=1)).sort_values(ascending=False)
non_nans.head(n=1)

5DT1:H    120
dtype: int64

**Load PDB for the structure:**

In [23]:
with open(f'{p.DATA_DIR}/pdb/incremental/{SEQUENCE_ID[:4].lower()}.pdb') as f: pdb_raw = f.read()

**Utility functions:**

In [24]:
def show_pdb(pdb, 
             show_sidechains = True, 
             color_map : dict = dict()):
    
    view = py3Dmol.view(width=800, height=600, js='https://cdnjs.cloudflare.com/ajax/libs/3Dmol/1.8.0/3Dmol.js')
    view.addModelsAsFrames(pdb)
    style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}
    if show_sidechains:
        style['stick'] = {'colorscheme': {'prop': 'b', 'map': color_map}}
    view.addSurface(py3Dmol.SAS, {'opacity': 0.4, 'color': 'grey'})
    view.setStyle({'model': -1}, style)
    return view.zoomTo()

In [25]:
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.

In [26]:
def overwrite_b_factors(pdb_str: str, chain_id: str, residue_data, vmin, vmax) -> str:
    """Overwrites the B-factors in pdb_str with contents of bfactors array.
    Args:
    pdb_str: An input PDB string.
    residue_data: A numpy array with shape [n_residues]
    vmin: Minimum value in residue_data used for normalization
    vmax: Maximum value in residue_data used for normalization
    Returns:
    A new PDB string with the B-factors replaced.
    """
    bands = np.ceil(((np.clip(residue_data, vmin, vmax) - vmin) / vmax * 100) // 20)
    bfactors = np.repeat(bands[:, np.newaxis], atom_type_num, axis=1)
    
    if bfactors.shape[-1] != atom_type_num:
        raise ValueError(f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')

    parser = PDB.PDBParser(QUIET=True)
    handle = io.StringIO(pdb_str)
    structure = parser.get_structure('', handle)

    curr_resid = ('', '', '')
    idx = -1
    for atom in structure.get_atoms():
        #print(str(atom))
        if atom.get_full_id()[2] != chain_id:
            continue
        atom_resid = atom.parent.get_id()
        if atom_resid != curr_resid:
            idx += 1
            if idx >= bfactors.shape[0]:
                raise ValueError(f'Index into bfactors exceeds number of residues. B-factors shape: {bfactors.shape}, idx: {idx}.')
        curr_resid = atom_resid
        atom.bfactor = bfactors[idx, atom_order['CA']]

    new_pdb = io.StringIO()
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(new_pdb)
    return new_pdb.getvalue()

---

## A: CDR viz

In [44]:
def cdrnumber(i):
    # @param i: aho position
    # return number of CDR
    # 0 for non-cdr position, 1 for CDR1,
    # 2 for CDR2, 3 for CDR3
    ii = int(i[:-1]) if not i[-1].isnumeric() else int(i)
    # we return multiples of 20 since overwrite_b_factors_function divides value by 20
    if 27 <= ii <= 40:
        return 40
    elif 58 <= ii <= 68:
        return 60
    elif 107 <= ii <= 138:
        return 80
    return 0

color_map = {
    0: 'white', 1: 'white',
    2: 'red',
    3: 'blue',
    4: 'green'
}

cdr = s.loc[SEQUENCE_ID].transpose().dropna().to_frame()
cdr['index'] = cdr.index
cdr['cdr'] = cdr['index'].apply(cdrnumber)
#sasa_default.index = range(1, len(sasa_default)+1)
cdr.drop(columns=['index', SEQUENCE_ID], inplace=True)
cdr.columns = [SEQUENCE_ID]
cdr = cdr[SEQUENCE_ID]
print(list(cdr))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 40, 40, 40, 40, 40, 40, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60, 60, 60, 60, 60, 60, 60, 60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [47]:
pdb_cdr = overwrite_b_factors(pdb_raw, SEQUENCE_ID[-1], cdr, vmin=0, vmax=100)
pdb_cdr = '\n'.join([line for line in pdb_cdr.split('\n') if len(line) < 22 or line[21] == 'H'])
view = show_pdb(pdb_cdr, color_map=color_map)
view.zoom(ZOOM); view.rotate(ANGLE)

  bfactors = np.repeat(bands[:, np.newaxis], atom_type_num, axis=1)


<py3Dmol.view at 0x2ad9076c5af0>

In [48]:
view.png()

---

## B: FreeSASA viz

In [49]:
color_map = { 0: '#0000ff', 1: '#6666ff', 2: '#ffffff', 3: '#ff6666', 4: '#ff0000', 5: '#aa0000' }
sasa_default = s.loc[SEQUENCE_ID].transpose().dropna(); #sasa_default.index = range(1, len(sasa_default) + 1); 
sasa_default

1      100.0
2       36.6
3       45.8
4        4.2
5       42.3
       ...  
144     38.9
145      0.6
146     29.4
147      6.0
148     72.8
Name: 7N8I:H, Length: 121, dtype: float64

In [50]:
pdb_sasa_de = overwrite_b_factors(pdb_raw, SEQUENCE_ID[-1], sasa_default, vmin=0, vmax=100)
pdb_sasa_de = '\n'.join([line for line in pdb_sasa_de.split('\n') if len(line) < 22 or line[21] == 'H'])
view = show_pdb(pdb_sasa_de, color_map=color_map)
view.zoom(ZOOM); view.rotate(ANGLE)

  bfactors = np.repeat(bands[:, np.newaxis], atom_type_num, axis=1)


<py3Dmol.view at 0x2ad90ca2dfa0>

In [33]:
view.png()

---

## C: Model SASA viz

In [34]:
color_map = { 0: '#0000ff', 1: '#6666ff', 2: '#ffffff', 3: '#ff6666', 4: '#ff0000', 5: '#aa0000' }
sasa_predicted = data[ (data['scenario'] == MODEL_NAME) & (data['sequence_id'] == SEQUENCE_ID) ][['position', 'prediction']]
sasa_predicted.index = sasa_predicted['position']
sasa_predicted = sasa_predicted.drop(columns='position').dropna()
#sasa_predicted.index = range(1, len(sasa_predicted) + 1)
sasa_predicted = sasa_predicted['prediction']
sasa_predicted

position
1      94.039122
2      27.535804
3      51.750840
4       5.149331
5      50.947158
         ...    
144    26.757445
145     0.908832
146    39.480453
147     7.094025
148    79.977219
Name: prediction, Length: 121, dtype: float64

In [35]:
pdb_sasa_pr = overwrite_b_factors(pdb_raw, SEQUENCE_ID[-1], sasa_predicted, vmin=0, vmax=100)
pdb_sasa_pr = '\n'.join([line for line in pdb_sasa_pr.split('\n') if len(line) < 22 or line[21] == 'H'])
view = show_pdb(pdb_sasa_pr, color_map=color_map)
view.zoom(ZOOM); view.rotate(ANGLE)

  bfactors = np.repeat(bands[:, np.newaxis], atom_type_num, axis=1)


<py3Dmol.view at 0x2ad908d74d30>

In [36]:
view.png()