# ProVis: Attention Visualizer for Proteins

In [1]:
import io
import urllib

import torch
from Bio.Data import SCOPData
from Bio.PDB import PDBParser, PPBuilder
from tape import TAPETokenizer, ProteinBertModel
import nglview

# RGB components (each in the 0-1 range) of the orange used for attention cylinders.
attn_color = [0.937, 0.522, 0.212]



In [2]:
def get_structure(pdb_id):
    """Download a PDB entry from RCSB and parse it with Biopython.

    Args:
        pdb_id: PDB identifier (e.g. ``'7HVP'``); it is interpolated into the
            RCSB download URL and also used as the id of the parsed structure.

    Returns:
        The object returned by ``PDBParser.get_structure`` for the downloaded file.

    Raises:
        urllib.error.URLError: If the download fails (``HTTPError`` for a
            non-success HTTP status such as an unknown id).
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly where it is used.
    import urllib.request

    url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    # The context manager closes the HTTP response even if reading/decoding fails.
    with urllib.request.urlopen(url) as resource:
        content = resource.read().decode('utf8')
    parser = PDBParser(QUIET=True)
    return parser.get_structure(pdb_id, io.StringIO(content))

In [3]:
def get_attn_data(chain, layer, head, min_attn, start_index=0, end_index=None, max_seq_len=1024):
    """Collect residue pairs of one chain whose attention reaches a threshold.

    The chain is converted to one-letter residue codes, run through the
    pretrained TAPE ``bert-base`` model, and the attention matrix of the
    requested layer/head is scanned for pairs ``(i, j)`` with ``i <= j``.

    Args:
        chain: Iterable of Biopython residues (each must provide
            ``get_resname()`` and atom lookup by name).
        layer: Zero-based index of the attention layer.
        head: Zero-based index of the attention head within that layer.
        min_attn: Minimum attention weight for a pair to be reported.
        start_index: First residue position (inclusive) to consider.
        end_index: Last residue position (exclusive); ``None`` means the end
            of the (possibly truncated) sequence. Values past the end are
            clamped to the sequence length.
        max_seq_len: Maximum model input length including the two special
            tokens; a falsy value disables truncation.

    Returns:
        A list of ``(attention, coord_i, coord_j)`` tuples, where the
        attention is the larger of the two directed weights between the
        residues and each coord is the ``[x, y, z]`` list of the residue's
        CA atom. Pairs involving a residue without a usable CA coordinate
        are omitted.

    Raises:
        ValueError: If the chain contains no recognised residue.
    """
    tokens = []
    coords = []
    for res in chain:
        t = SCOPData.protein_letters_3to1.get(res.get_resname(), "X")
        if len(t) != 1:
            # Keep exactly one token per residue; `tokens += t` would add one
            # entry per character and shift tokens relative to coords.
            t = 'X'
        tokens.append(t)
        # Unknown residues, and known residues lacking a CA atom, get no coordinate.
        if t != 'X' and 'CA' in res:
            coords.append(res['CA'].coord.tolist())
        else:
            coords.append(None)

    # Drop the trailing run of unknown residues.
    last_non_x = next(
        (i for i in reversed(range(len(tokens))) if tokens[i] != 'X'),
        None,
    )
    if last_non_x is None:
        raise ValueError('chain contains no recognised residues')
    tokens = tokens[:last_non_x + 1]
    coords = coords[:last_non_x + 1]

    # NOTE: the tokenizer and pretrained weights are loaded on every call.
    tokenizer = TAPETokenizer()
    model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)

    if max_seq_len:
        tokens = tokens[:max_seq_len - 2]  # Account for SEP, CLS tokens (added in next step)
    token_idxs = tokenizer.encode(tokens).tolist()
    # Internal sanity check: encoding adds exactly the two special tokens.
    if max_seq_len:
        assert len(token_idxs) == min(len(tokens) + 2, max_seq_len)
    else:
        assert len(token_idxs) == len(tokens) + 2

    inputs = torch.tensor(token_idxs).unsqueeze(0)
    with torch.no_grad():
        attns = model(inputs)[-1]
    # Select the requested layer, the single batch element and the head, and
    # remove attention from/to <CLS> (first) and <SEP> (last) tokens.
    attn = attns[layer][0, head, 1:-1, 1:-1]

    if end_index is None:
        end_index = len(tokens)
    else:
        # Never index past the (possibly truncated) sequence.
        end_index = min(end_index, len(tokens))

    # Currently non-directional: use the max of the two directed attentions.
    # Symmetrising once and converting to nested lists avoids two `.item()`
    # calls per residue pair.
    sym = torch.max(attn, attn.t()).tolist()

    attn_data = []
    for i in range(start_index, end_index):
        coord_i = coords[i]
        if coord_i is None:
            continue
        row = sym[i]
        for j in range(i, end_index):
            coord_j = coords[j]
            if coord_j is None:
                continue
            a = row[j]
            if a >= min_attn:
                attn_data.append((a, coord_i, coord_j))
    return attn_data

### Visualize head 7-1 (targets binding sites)

In [4]:
# Settings for head 7-1 (targets binding sites); layer and head are 1-based here.
pdb_id = '7HVP'
chain_ids = None  # None selects every chain of the model
layer, head = 7, 1
min_attn = 0.1
attn_scale = .9

# get_attn_data indexes layers and heads from zero.
layer_zero_indexed, head_zero_indexed = layer - 1, head - 1

structure = get_structure(pdb_id)
view = nglview.show_biopython(structure)
view.stage.set_parameters(backgroundColor="black", fogNear=50, fogFar=100)

models = list(structure.get_models())
if len(models) > 1:
    print(f'Warning: {len(models)} models. Using first one')
prot_model = models[0]

if chain_ids is None:
    chain_ids = [c.id for c in prot_model]
for chain_id in chain_ids:
    print(f'Loading chain {chain_id}')
    chain = prot_model[chain_id]
    attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)
    # One cylinder per reported residue pair; the last argument is the
    # attention weight scaled by attn_scale.
    for att, coords_from, coords_to in attn_data:
        view.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale)

# Last expression of the cell: displays the widget in the notebook.
view

Loading chain A
Loading chain B
Loading chain C


NGLWidget()

### Visualize head 12-4 (targets contact maps)

In [5]:
# Settings for head 12-4 (targets contact maps); layer and head are 1-based here.
pdb_id = '2KC7'
chain_ids = None  # None selects every chain of the model
layer, head = 12, 4
min_attn = 0.2
attn_scale = .5

# get_attn_data indexes layers and heads from zero.
layer_zero_indexed, head_zero_indexed = layer - 1, head - 1

structure = get_structure(pdb_id)
view2 = nglview.show_biopython(structure)
view2.stage.set_parameters(backgroundColor="black", fogNear=50, fogFar=100)

models = list(structure.get_models())
if len(models) > 1:
    print(f'Warning: {len(models)} models. Using first one')
prot_model = models[0]

if chain_ids is None:
    chain_ids = [c.id for c in prot_model]
for chain_id in chain_ids:
    print(f'Loading chain {chain_id}')
    chain = prot_model[chain_id]
    attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)
    # One cylinder per reported residue pair; the last argument is the
    # attention weight scaled by attn_scale.
    for att, coords_from, coords_to in attn_data:
        view2.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale)

# Last expression of the cell: displays the widget in the notebook.
view2

# To save: view2.download_image(filename="testing.png")

Loading chain A


NGLWidget()