<a href="https://colab.research.google.com/github/rvanasa/deep-antibody/blob/master/contacts_docked.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Google Drive setup

from IPython.display import clear_output
from google.colab import drive
# Mount Google Drive at /gdrive; a later cell symlinks ./outputs to a folder
# under "/gdrive/Shared drives/".
drive.mount('/gdrive')
# Clear whatever the mount step printed so the cell output stays empty.
clear_output()

In [None]:
#@title Workspace setup

# biopython supplies Bio.PDB (structure parsing); pdb-tools supplies the
# pdb_tidy / pdb_reatom / pdb_reres commands run on docked files further down.
!pip install -q biopython pdb-tools

from IPython.display import clear_output, display
# Drop the pip log from the cell output.
clear_output()

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import requests
import Bio
import Bio.PDB
from Bio.PDB import DSSP

# NOTE(review): neither value is referenced in the visible cells; presumably
# the number of neighbouring residues kept on each side of a contact and the
# resulting window length (buffer + centre + buffer) - confirm before relying
# on them.
contact_buffer = 4
contact_window_size = contact_buffer * 2 + 1

# Three-letter residue names. The trailing '???' entry has no counterpart in
# `oneletters`: this list has 21 items, `oneletters` has 20 characters.
amino_acids = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', '???']
# One-letter codes, position-aligned with the first 20 entries of `amino_acids`.
oneletters = 'ARNDCEQGHILKMFPSTWYV'
# Presumably the DSSP secondary-structure codes (DSSP is imported above);
# not used in the visible cells - TODO confirm.
ssletters = 'HBEGITS'

# Shared parser instance used by parse(); QUIET=True silences the warnings
# Biopython would otherwise emit while reading PDB files.
parser = Bio.PDB.PDBParser(QUIET=True)

def parse(ident, cache_dir=None):
  """Load a PDB structure from a local file or, for bare IDs, from RCSB.

  Args:
    ident: Either a path ending in '.pdb', which is read as-is, or a PDB
      identifier. An identifier resolves to '<cache_dir>/<ident>.pdb' and is
      downloaded from files.rcsb.org when that file does not exist yet.
    cache_dir: Optional directory for downloaded files; the current working
      directory is used when omitted. Created on demand.

  Returns:
    The structure returned by the module-level `parser`, named `ident`.

  Raises:
    requests.HTTPError: If the download is attempted and the server answers
      with an error status (e.g. unknown ID).
  """
  cache_dir = cache_dir.rstrip('/') + '/' if cache_dir else ''
  # Require the dot: a bare identifier that merely ends in the letters "pdb"
  # is still an identifier, not a path.
  filename = ident if ident.endswith('.pdb') else f'{cache_dir}{ident}.pdb'
  if '.' not in ident and not os.path.exists(filename):
    if cache_dir:
      os.makedirs(cache_dir, exist_ok=True)
    response = requests.get(f'https://files.rcsb.org/download/{ident}.pdb', timeout=60)
    # Fail here with a clear HTTP error instead of later inside the parser.
    response.raise_for_status()
    # Write straight to the final location; no temporary file in the CWD.
    with open(filename, 'wb') as f:
      f.write(response.content)
  return parser.get_structure(ident, filename)


def create_seq(rs):
  """Translate three-letter residue names into a one-letter sequence string.

  Args:
    rs: Iterable of residue names such as 'ALA'.

  Returns:
    A string with one character per input name. Names that have no one-letter
    code, including the '???' placeholder in `amino_acids`, become 'X'.
  """
  # zip() stops at the shorter input, so the trailing '???' entry of
  # `amino_acids` (which has no partner in `oneletters`) gets no code and falls
  # through to 'X' instead of indexing past the end of `oneletters`.
  codes = dict(zip(amino_acids, oneletters))
  return ''.join(codes.get(r, 'X') for r in rs)


def cmd(command):
  """Run one shell command, or every command in a (possibly nested) iterable.

  Commands are executed in order through `os.system`. Execution stops at the
  first command that reports a non-zero status.

  Raises:
    Exception: If a command exits with a non-zero status.
  """
  if isinstance(command, str):
    exit_status = os.system(command)
    if exit_status:
      raise Exception(f'Non-zero exit code in command: $ {command}')
    return

  # Anything that is not a string is treated as a collection of commands.
  for sub_command in command:
    cmd(sub_command)

In [None]:
#@title Contact point calculation

import numpy.linalg as lin

# Per-residue allowance used by compute_contacts() as a cheap pre-filter: the
# exact atom-to-atom distance is only computed for residue pairs whose
# centroids lie within `contact_dist_threshold + 2 * residue_contact_margin`.
# Presumably in the same units as the atom coordinates - TODO confirm.
residue_contact_margin = 5.5 # Empirical

def compute_contacts(filename, bname, bmodel, aname, amodel, keys, contact_dist_threshold):
  """Find antibody/antigen residue pairs in contact within a docked structure.

  For every model in the structure, each residue of the antigen chain is
  compared against each residue of the heavy and light antibody chains. A pair
  is reported when the smallest atom-to-atom distance between the two residues
  is at most `contact_dist_threshold`. Only residues whose name appears in
  `amino_acids` are considered.

  Args:
    filename: Passed to `parse()` to load the structure.
    bname: Antibody name; logged and stored as '<bname>.pdb' in `BFile`.
    bmodel: Antibody model label; logged and copied into `BModel`.
    aname: Antigen name; logged and stored as '<aname>.pdb' in `AFile`.
    amodel: Antigen model label; logged and copied into `AModel`.
    keys: Exactly three chain IDs, unpacked as (heavy, light, antigen). A
      three-character string works when every chain ID is one character.
      Heavy/light chains missing from a model are skipped.
    contact_dist_threshold: Largest closest-atom distance that still counts as
      a contact, in the units of the atom coordinates.

  Returns:
    A DataFrame with one row per contact and the columns BFile, BModel, BType,
    BKey, BIndex, BResidue, AFile, AModel, AKey, AIndex, AResidue, Distance.
    `BIndex` / `AIndex` are 0-based positions within the filtered residue list
    of the chain, not PDB residue numbers. The columns are present even when
    no contact is found.

  Raises:
    KeyError: If a model has no chain with the antigen chain ID.
  """
  hk, lk, ak = keys

  structure = parse(filename)
  print(bname, bmodel, hk, lk, '::', aname, amodel, ak)

  # Fixed schema so that an empty result still carries its column names.
  columns = [
      'BFile', 'BModel', 'BType', 'BKey', 'BIndex', 'BResidue',
      'AFile', 'AModel', 'AKey', 'AIndex', 'AResidue', 'Distance',
  ]

  def trim_residues(residues):
    # Keep only residues with a recognised name (drops waters, ligands, ...).
    return [r for r in residues if r.resname in amino_acids]

  def atom_coords(residues):
    # One (n_atoms, 3) array per residue.
    return [np.array([atom.coord for atom in r]) for r in residues]

  def centroids(coords):
    # Mean atom position of every residue.
    return np.array([c.mean(axis=0) for c in coords])

  data = []
  for model in structure.get_models():

    chain_map = {chain.id: chain for chain in model.get_chains()}

    H = chain_map.get(hk)
    L = chain_map.get(lk)
    ca = chain_map[ak]

    a_res = trim_residues(ca)
    a_coords = atom_coords(a_res)
    a_centers = centroids(a_coords)

    ct = 0
    for btype, cb in (('H', H), ('L', L)):
      if cb is None:
        continue

      b_res = trim_residues(cb)
      # A chain without any recognised residue yields a (0,)-shaped centroid
      # array that cannot be broadcast below; there is nothing to compare.
      if not a_res or not b_res:
        continue
      b_coords = atom_coords(b_res)
      b_centers = centroids(b_coords)

      # Cheap pre-filter on centroid distance before the exact atom-level
      # check; the margin allows for the extent of both residues.
      norms = lin.norm(a_centers[:, None] - b_centers, axis=2)
      locs = np.argwhere(norms <= contact_dist_threshold + residue_contact_margin * 2)
      for an, bn in locs:
        # Closest pair of atoms between the two residues.
        min_dist = np.min(lin.norm(a_coords[an][:, None] - b_coords[bn], axis=2))
        if min_dist <= contact_dist_threshold:
          data.append({
              'BFile': f'{bname}.pdb',
              'BModel': bmodel,
              'BType': btype,
              'BKey': cb.id,
              'BIndex': bn,
              'BResidue': b_res[bn].resname,
              'AFile': f'{aname}.pdb',
              'AModel': amodel,
              'AKey': ca.id,
              'AIndex': an,
              'AResidue': a_res[an].resname,
              'Distance': min_dist,
          })
          ct += 1

    # Per-model progress: model id and number of contacts found in it.
    print(model.id, ct)

  return pd.DataFrame(data, columns=columns)

In [None]:
#@title Contact point visualization

def plot_contacts(df_all, exponent):
  """Plot contact distributions for each antibody/antigen pairing.

  One figure is drawn per (BFile, BModel, AFile, AModel) group. Its first
  panel is a histogram of antigen residue indices in which every contact is
  weighted by `1 / Distance ** exponent`; it is followed by one hexbin panel
  per antibody chain showing antigen index against antibody index.

  Args:
    df_all: Contact table with the columns written by `compute_contacts`.
    exponent: Power applied to the distance in the histogram weights; 0 gives
      every contact the same weight.

  Returns:
    None. Prints 'Empty dataframe' and draws nothing when `df_all` has no rows.
  """
  if not len(df_all):
    print('Empty dataframe')
    return

  for (bfile, bmodel, afile, amodel), df in df_all.groupby(['BFile', 'BModel', 'AFile', 'AModel']):

    fig, (ax, *xs) = plt.subplots(1, 1 + len(df.BKey.unique()), figsize=(16, 4))

    ax.set_title(f'{bfile} {bmodel} :: {afile} {amodel}')
    ax.set_xlabel('Antigen')
    ax.set_ylabel('Weighted contacts')
    df.AIndex.hist(bins=40, weights=1 / df.Distance ** exponent, ax=ax)

    for (bkey, dfg), x in zip(df.groupby('BKey'), xs):
      x.set_title(bkey)
      x.set_xlabel('Antigen')
      x.set_ylabel('Antibody')
      x.hexbin(dfg.AIndex, dfg.BIndex, gridsize=30)

    plt.show()
    # This runs inside long loops; release the figure once it has been shown.
    plt.close(fig)

In [None]:
# Make ./outputs a symlink to the shared-drive folder; later cells read the
# docked '*_Docked_Hex.pdb' files from it and write contact CSVs into it.
# `rm -f` first removes a link left over from a previous run.
!rm -f ./outputs && ln -s "/gdrive/Shared drives/TA(CO)^2 Re-Epitoping/Data/Hex" ./outputs

In [None]:
# Fetch the data archive (-N: only when the remote copy is newer) and unpack
# it without overwriting existing files (-n), quietly (-q).
!wget -N https://raw.githubusercontent.com/rvanasa/deep-antibody/master/thera_collection.zip
!unzip -nq thera_collection.zip
clear_output()

# Presumably both CSVs come from the archive above - TODO confirm.
df_cov = pd.read_csv('cov_preprocessed.csv')
dfdx = pd.read_csv('docked_preprocessed.csv')
# Keep only the docked entries whose File also appears in the cov table.
dfdx = dfdx[dfdx.File.isin(df_cov.File)]
print(list(dfdx.File.unique()))

[]


In [None]:
ensemble_contact_dist_threshold = 3

!wget -N https://raw.githubusercontent.com/rvanasa/deep-antibody/master/thera_collection.zip
!unzip -nq thera_collection.zip
clear_output()

dft = pd.read_csv('thera_prioritized.csv')
dft = pd.concat([dfdx, dft]).reset_index()#.sort_values('File')

afiles = sorted(dfdx.File.unique())
bfiles = sorted(dft.File.unique())

# completed = np.zeros((len(afiles), len(bfiles)))

for i, brow in dft.iterrows():
  bfile, bmodel, hk, lk = brow[['File', 'Model', 'HKey', 'LKey']]
  bname = bfile.replace('.pdb', '')

  print(bname, bmodel, ''.join([hk, lk]))

  for j, arow in dft.iterrows():
    afile, amodel, ak = arow[['File', 'Model', 'AKey']]
    aname = afile.replace('.pdb', '')
    keys = ''.join([hk, lk, ak])

    parts = [bname, bmodel, aname, amodel, keys]
    part_str = '_'.join(str(s) for s in parts)

    # assert hk != lk != ak
    if not (hk != lk != ak):
        # print('Key collision')
        continue
    
    hex_path = f'outputs/{part_str}_Docked_Hex.pdb'
    if os.path.isfile(hex_path):
      # if afile in afiles:
      #   completed[afiles.index(afile), bfiles.index(bfile)] = 1

      csv_path = hex_path.replace('.pdb', '.csv')
      csv_path = csv_path[:csv_path.rindex('/')] + '/Contacts' + csv_path[csv_path.rindex('/'):]
      if os.path.isfile(csv_path):
        continue

      print(hex_path)
      cmd([
          f'rm -f input.pdb && pdb_tidy {hex_path} | pdb_reatom | pdb_reres > input.pdb',
      ])
      df = compute_contacts('input.pdb', *parts, ensemble_contact_dist_threshold)
      df.round(4).to_csv(csv_path, index=False)

      clear_output()
      plot_contacts(df, 2)

clear_output()
print('Done')

# import seaborn as sns
# fig, ax = plt.subplots(figsize=(20, 6))
# sns.heatmap(completed, ax=ax, cmap='Blues', xticklabels=bfiles, yticklabels=afiles, cbar=False)
# plt.show()

Done


In [None]:
# cmd([
#     f'rm -f input.pdb && pdb_tidy outputs/3gbm_0_6w41_0_HLC_Docked_Hex.pdb | pdb_reatom | pdb_reres > input.pdb',
# ])

# df = compute_contacts('input.pdb', '6w41', 0, '6w41', 0, ['H', 'L', 'C'])
# clear_output()
# display(df)
# plot_contacts(df, 0)
# plot_contacts(df, 3)