In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from compare_wandb import load_full_df, get_test_pdbs, load_run_dict

import wandb
api = wandb.Api()
entity = 'proteins'
project = 'iclr2021-rebuttal'

import boto3
import os
s3 = boto3.client("s3")
s3_bucket = "proteindata"

import torch

# load metatest pdbs
with open('metatest_fams.txt', 'r') as f:
    lines = f.readlines()
    metatest_pdbs = [l.strip() for l in lines]

from mogwai import models
from mogwai.utils.functional import apc
model_name = 'factored_attention'
fatt_model = models.get(model_name)
gremlin_model = models.get('gremlin')

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

from scipy.stats import spearmanr
from scipy.stats import pearsonr

import pickle as pkl

import h5py

In [3]:
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "serif",
#     "font.serif": ["Times"],
#     "font.size": 8,
# })

In [4]:
head_sweep_runs = {
    'fatt-metatest-head-sweep-512': 'bxnkt0uq',
    'fatt-metatest-head-256': 'xuofwjtc',
    'fatt-metatest-head-sweep-64-kqe7or39': 'kqe7or39',
    'fatt-metatest-head-sweep-128': '32emd6ri',
    'fatt-metatest-head-sweep-8-32': '8yi6a4w5',
}

dict_of_dfs = load_run_dict(head_sweep_runs)
fatt_df = pd.concat(list(dict_of_dfs.values()))

100%|██████████| 748/748 [00:00<00:00, 32669.05it/s]

bxnkt0uq has 748 runs



100%|██████████| 748/748 [00:00<00:00, 33597.19it/s]

xuofwjtc has 748 runs



100%|██████████| 748/748 [00:00<00:00, 36077.55it/s]

kqe7or39 has 748 runs



100%|██████████| 748/748 [00:00<00:00, 34378.03it/s]

32emd6ri has 748 runs



  0%|          | 0/2233 [00:00<?, ?it/s]

8yi6a4w5 has 2233 runs


100%|██████████| 2233/2233 [00:03<00:00, 592.58it/s]


In [5]:
gremlin_runs = {'gremlin': 'dbuvl02g'}
gremlin_df_dict = load_run_dict(gremlin_runs)
gremlin_df = pd.concat(list(gremlin_df_dict.values()))

100%|██████████| 748/748 [00:00<00:00, 34646.83it/s]

dbuvl02g has 748 runs





In [6]:
def download_statedict(run_id, dest='fatt.h5'):
    run = api.run(f"{entity}/{project}/{run_id}")
    key = os.path.join("iclr-2021-factored-attention", *run.path, "model_state_dict.h5")
    with open(dest, 'wb') as f:
        s3.download_fileobj(s3_bucket, key, f)
    return dest

def get_fatt_run_id(pdb, attention_head_size, num_attention_heads, df):
    run_id = df[(df['pdb']==pdb) & (df['num_attention_heads']==num_attention_heads)]['run_id'].values
    return run_id[0]

def get_gremlin_run_id(pdb, df):
    run_id = df[df['pdb']==pdb]['run_id'].values
    return run_id[0]

def get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5'):
    run_id = get_gremlin_run_id(pdb, gremlin_df)
    f_statedict = download_statedict(run_id, dest=dest)
    statedict = torch.load(f_statedict)
    return statedict

def get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=32, dest='fatt.h5'):
    run_id = get_fatt_run_id(pdb, attention_head_size, num_attention_heads, fatt_df)
    f_statedict = download_statedict(run_id, dest=f'fatt_{num_attention_heads}.h5')
    statedict = torch.load(f_statedict)
    return statedict


In [7]:
!ls

bert-results			     Makefile
compare_wandb.py		     metatest_fams.txt
contact_maps_2bfw.pdf		     mogwai
contact_potts_3n2a_1_A_p_at_L_5.pdf  num_heads_sweep_df.pkl
contact_potts_3n2a_1_A_p_at_L.pdf    paper_figs_no_weights.ipynb
data				     plot_contacts_for_run_fig2.ipynb
download_data.sh		     potts_3n2a_1_A.pdf
fatt.h5				     predicted_contacts.npy
fatt-pl-results			     protbert_test_proteins.txt
fatt_transfer.py		     __pycache__
fig2_2bfw_contacts.pdf		     randomize_value_weights.py
fig2_2bfw_contacts.png		     README.md
gremlin-contacts		     requirements.txt
gremlin.h5			     roshan_contact_maps.ipynb
gremlin-results			     sampled_pdbs.txt
head_size_sweep.ipynb		     spagonerincon.npy
iclr2021-rebuttal		     train.py
launch_jobs.sh			     true_contacts.npy
launch_shards.sh		     venv
launch_weight_comparison.py	     wandb
LICENSE				     wandb-configs
loggers.py			     weight_correlations.ipynb
main_figs.ipynb			     weight_correlations.py


In [8]:
gremlin_statedict = get_gremlin_statedict('2bfw_1_A', gremlin_df)
fatt_statedict = get_fatt_statedict('2bfw_1_A', 32, fatt_df)

q = fatt_statedict['query']
k = fatt_statedict['key']
v = fatt_statedict['value']

w = gremlin_statedict['weight']

In [9]:
print(q.shape, k.shape, v.shape, w.shape)

torch.Size([196, 32, 32]) torch.Size([196, 32, 32]) torch.Size([32, 20, 20]) torch.Size([196, 20, 196, 20])


In [10]:
def get_msa_hparams(pdb, df):
    pdb = df[df['pdb']==pdb]
    hparam_dict = pdb.to_dict()
    msa_length = int(list(hparam_dict['msa_length'].values())[0])
    num_seqs = int(list(hparam_dict['num_seqs'].values())[0])
    return {'msa_length': msa_length, 'num_seqs': num_seqs}

In [11]:

def get_correlations(w, w_fatt, L, make_plot=False):
    idx = np.triu_indices(L, 1)

    fatt_w_no_diag = w_fatt.detach()[idx[0], :, idx[1], :]
    gremlin_w_no_diag = w[idx[0], :, idx[1], :]

    fatt_w_compare_idx = torch.flatten(fatt_w_no_diag)
    gremlin_w_compare_idx = torch.flatten(gremlin_w_no_diag)
    w_spearman = spearmanr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    w_pearson = pearsonr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    
    if make_plot:
        # plotting correlation
        subset = np.random.choice(len(gremlin_w_compare_idx), size=100000) # plotting the whole thing takes time
        plt.title(f'Spearman: {w_spearman:.2f} Pearson: {w_pearson:.2f} PDB: {pdb}')
        plt.xlabel('gremlin w')
        plt.ylabel('fatt w')
        plt.xlim(-2, 2); plt.ylim(-2, 2)
        plt.scatter(gremlin_w_compare_idx[subset], fatt_w_compare_idx[subset], s=1)
        plt.show()
    return w_spearman, w_pearson

In [12]:
from mogwai.utils.functional import apc
from mogwai.metrics import precisions_in_range

def get_info(num_attention_heads, 
             pdb,
             attention_head_size,
             fatt_df=fatt_df, 
             gremlin_df=gremlin_df):
    fatt_statedict = get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=attention_head_size, dest=f'fatt_{num_attention_heads}.h5')
    gremlin_statedict = get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5')
    
    hparams = get_msa_hparams(pdb, df=fatt_df)
    msa_length = hparams['msa_length']
    num_seqs = hparams['num_seqs']
    hparams['attention_head_size'] = attention_head_size
    hparams['num_attention_heads'] = num_attention_heads
    # initialize a matrix, which will be overriden when we load later.
    hparams['true_contacts'] = torch.ones([hparams['msa_length'], hparams['msa_length']])
    model = fatt_model(**hparams)
    model.load_state_dict(fatt_statedict)
    
    w_fatt = model.compute_mrf_weight()
    w = gremlin_statedict['weight']
    
    metrics = {}
    metrics['pdb'] = pdb
    metrics['msa_length'] = msa_length
    metrics['num_seqs'] = num_seqs
    metrics['attention_head_size'] = attention_head_size
    metrics['num_attention_heads'] = num_attention_heads
    
    predictions = apc(model.get_contacts())
    targets = model._true_contacts
    
    short = precisions_in_range(predictions, targets, minsep=6, maxsep=13)
    for k, v in short.items():
        metrics[f'short_{k}'] = float(v.squeeze())
    medium = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in medium.items():
        metrics[f'medium_{k}'] = float(v.squeeze())
    long = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in long.items():
        metrics[f'long_{k}'] = float(v.squeeze())
#     wspearman, wpearson = get_correlations(w, w_fatt, msa_length)
#     metrics['w_spearman'] = wspearman
#     metrics['w_pearson'] = wpearson
#     metrics['predicted_contacts_apc'] = predictions
#     metrics['true_contacts'] = targets
    return metrics

In [13]:
from functools import partial
from multiprocessing import Pool

info = []
num_heads_sweep = [8, 16, 32, 64, 128, 256, 512]
pdb = '2bfw_1_A'

def get_info_map(num_attention_heads):
    return get_info(num_attention_heads, attention_head_size=32, pdb=pdb)

with Pool() as p:
    new_info = p.map(get_info_map, num_heads_sweep)




IndexError: index 0 is out of bounds for axis 0 with size 0

In [None]:
info = []

for pdb in tqdm(metatest_pdbs[0:5]):
    for num_attention_heads in num_heads_sweep:
        info.append(get_info(num_attention_heads, attention_head_size=32, pdb=pdb))
        
sweep_df = pd.DataFrame.from_records(info)
sweep_df

In [None]:
from functools import partial
from multiprocessing import Pool

num_heads_sweep = [8, 16, 32, 64, 128, 256, 512]

info = []
try:
    for pdb in tqdm(metatest_pdbs[1:2]):
        get_info_map = partial(get_info, attention_head_size=32, pdb=pdb)
        with Pool() as p:
            new_info = p.map(get_info_map, num_heads_sweep)
        for i in new_info:
            info.append(i)
finally:
    print(pdb)
    sweep_df = pd.DataFrame.from_records(info)
    with open('num_head_sweep_df.pkl', 'wb') as f:
        pkl.dump(sweep_df, f)

In [None]:
fatt_df

In [None]:
info = []

for index, row in tqdm(fatt_df.iterrows()):
    info.append(get_info(num_attention_heads=row['num_attention_heads'], 
                         attention_head_size=row['attention_head_size'], 
                         pdb=row['pdb']))
    if index > 10:
        break

In [None]:
row['num_seqs']

In [None]:
n = fatt_df.shape[0]
fatt_df[:n//10]

In [None]:
with open('num_heads_sweep_df.pkl', 'rb') as f:
    sweep_df = pkl.load(f)
    

In [None]:
import seaborn as sns
sweep_df

In [None]:
plt.subplots(dpi=600)
sns.violinplot(x="num_attention_heads", 
               y="w_spearman", 
               data=sweep_df)
plt.title('spearman')

In [None]:
plt.subplots(dpi=600)
sns.violinplot(x="num_attention_heads", 
               y="w_pearson", 
               data=sweep_df)
plt.title('pearson')

In [None]:
# fig, ax = plt.subplots(figsize=(7., 3.5), dpi=600, ncols=3, sharey=True)

melted_df = pd.melt(sweep_df, id_vars=['num_attention_heads'], value_vars=['short_pr_at_l_5', 'medium_pr_at_l_5', 'long_pr_at_l_5'])
# plt.subplots(dpi=600)
sns.catplot(x="variable", 
            y="value", 
            hue="num_attention_heads",
            data=melted_df, 
            kind="violin",
            height=10, 
            aspect=2)
plt.title('short medium long L/5')

In [None]:
melted_df = pd.melt(sweep_df, id_vars=['num_attention_heads'], value_vars=['short_pr_at_l', 'medium_pr_at_l', 'long_pr_at_l'])
# plt.subplots(dpi=600)
sns.catplot(x="variable", 
            y="value", 
            hue="num_attention_heads",
            data=melted_df, 
            kind="violin",
            height=10, 
            aspect=2)
plt.title('short medium long L')