In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from compare_wandb import load_full_df, get_test_pdbs, load_run_dict

import wandb
api = wandb.Api()
entity = 'proteins'
project = 'iclr2021-rebuttal'

import boto3
import os
s3 = boto3.client("s3")
s3_bucket = "proteindata"

import torch

# Read the meta-test family identifiers: one entry per line of the text file,
# with surrounding whitespace (including the trailing newline) removed.
with open('metatest_fams.txt', 'r') as f:
    lines = list(f)
    metatest_pdbs = list(map(str.strip, lines))

from mogwai import models
from mogwai.utils.functional import apc
model_name = 'factored_attention'
fatt_model = models.get(model_name)
gremlin_model = models.get('gremlin')

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

from scipy.stats import spearmanr
from scipy.stats import pearsonr

import pickle as pkl

import h5py

In [3]:
# Display name -> id for each head-count sweep, in the order they are loaded.
head_sweep_runs = dict([
    ('fatt-metatest-head-sweep-512', 'bxnkt0uq'),
    ('fatt-metatest-head-256', 'xuofwjtc'),
    ('fatt-metatest-head-sweep-64-kqe7or39', 'kqe7or39'),
    ('fatt-metatest-head-sweep-128', '32emd6ri'),
    ('fatt-metatest-head-sweep-8-32', '8yi6a4w5'),
])

# One DataFrame per sweep, then stacked row-wise into a single frame.
# The original per-sweep index labels are kept, so labels repeat across sweeps.
dict_of_dfs = load_run_dict(head_sweep_runs)
fatt_df = pd.concat([sweep_frame for sweep_frame in dict_of_dfs.values()])

100%|██████████| 748/748 [00:00<00:00, 32564.95it/s]

bxnkt0uq has 748 runs



100%|██████████| 748/748 [00:00<00:00, 33427.87it/s]

xuofwjtc has 748 runs



100%|██████████| 748/748 [00:00<00:00, 35868.42it/s]

kqe7or39 has 748 runs



100%|██████████| 748/748 [00:00<00:00, 34037.14it/s]

32emd6ri has 748 runs



  0%|          | 0/2233 [00:00<?, ?it/s]

8yi6a4w5 has 2233 runs


100%|██████████| 2233/2233 [00:03<00:00, 684.03it/s]


In [4]:
# Single gremlin baseline sweep, loaded the same way as the attention sweeps.
gremlin_runs = dict(gremlin='dbuvl02g')
gremlin_df_dict = load_run_dict(gremlin_runs)
gremlin_df = pd.concat([baseline_frame for baseline_frame in gremlin_df_dict.values()])

100%|██████████| 748/748 [00:00<00:00, 34226.89it/s]

dbuvl02g has 748 runs





In [5]:
def download_statedict(run_id, dest='fatt.h5'):
    """Download the saved model state dict of a wandb run from S3.

    The run is looked up through the module-level wandb ``api`` using
    ``entity``/``project``; the object is fetched from ``s3_bucket`` with the
    module-level ``s3`` client.

    Args:
        run_id: wandb run id (the last component of the run path).
        dest: Local file path the object is written to (overwritten if present).

    Returns:
        ``dest``, unchanged, so the call can be chained into a loader.
    """
    # Local import: S3 object keys always use '/' regardless of the host OS,
    # whereas os.path.join would insert '\\' on Windows and miss the object.
    import posixpath

    run = api.run(f"{entity}/{project}/{run_id}")
    key = posixpath.join("iclr-2021-factored-attention", *run.path, "model_state_dict.h5")
    with open(dest, 'wb') as f:
        s3.download_fileobj(s3_bucket, key, f)
    return dest

def get_fatt_run_id(pdb, attention_head_size, num_attention_heads, df):
    """Return the run id of the factored-attention run matching the given settings.

    Args:
        pdb: Value to match in the ``pdb`` column.
        attention_head_size: Value to match in the ``attention_head_size``
            column. Only applied when ``df`` has that column, so frames without
            it keep working as before.
        num_attention_heads: Value to match in the ``num_attention_heads`` column.
        df: DataFrame with at least ``pdb``, ``num_attention_heads`` and
            ``run_id`` columns.

    Returns:
        The ``run_id`` of the first matching row.

    Raises:
        IndexError: If no row matches.
    """
    mask = (df['pdb'] == pdb) & (df['num_attention_heads'] == num_attention_heads)
    # Previously the head size was accepted but never used in the filter, so a
    # run with a different head size could be returned silently.
    if 'attention_head_size' in df.columns:
        mask &= df['attention_head_size'] == attention_head_size

    run_id = df[mask]['run_id'].values
    if len(run_id) == 0:
        raise IndexError(
            f"no run found for pdb={pdb!r}, attention_head_size={attention_head_size!r}, "
            f"num_attention_heads={num_attention_heads!r}"
        )
    return run_id[0]

def get_gremlin_run_id(pdb, df):
    """Look up the run id of the first row of ``df`` whose ``pdb`` equals ``pdb``.

    Raises IndexError when no row matches.
    """
    is_requested_pdb = df['pdb'] == pdb
    matching_ids = df.loc[is_requested_pdb, 'run_id'].values
    return matching_ids[0]

def get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5'):
    """Fetch and deserialize the gremlin state dict for ``pdb``.

    Resolves the run id from ``gremlin_df``, downloads the checkpoint to
    ``dest`` and returns whatever ``torch.load`` reads from that file.
    """
    checkpoint_path = download_statedict(get_gremlin_run_id(pdb, gremlin_df), dest=dest)
    return torch.load(checkpoint_path)

def get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=32, dest='fatt.h5'):
    """Fetch and deserialize the factored-attention state dict for one run.

    Args:
        pdb: Identifier matched against the ``pdb`` column of ``fatt_df``.
        num_attention_heads: Head count of the run to fetch.
        fatt_df: DataFrame of factored-attention runs used to resolve the run id.
        attention_head_size: Head size forwarded to the run-id lookup.
        dest: Local path the checkpoint is downloaded to. Pass a distinct path
            per call when running several downloads concurrently.

    Returns:
        The object produced by ``torch.load`` on the downloaded file.
    """
    run_id = get_fatt_run_id(pdb, attention_head_size, num_attention_heads, fatt_df)
    # Honour the caller's destination; it used to be ignored in favour of a
    # hard-coded f'fatt_{num_attention_heads}.h5'.
    f_statedict = download_statedict(run_id, dest=dest)
    # NOTE(review): torch.load unpickles the file; only use with trusted buckets.
    statedict = torch.load(f_statedict)
    return statedict


In [7]:
!ls

bert-results		     metatest_fams.txt
compare_wandb.py	     mogwai
contact_maps_2bfw.pdf	     paper_figs_no_weights.ipynb
contact_potts_3n2a_1_A.pdf   plot_contacts_for_run_fig2.ipynb
data			     potts_3n2a_1_A.pdf
download_data.sh	     predicted_contacts.npy
fatt.h5			     protbert_test_proteins.txt
fatt-pl-results		     __pycache__
fatt_transfer.py	     randomize_value_weights.py
fig2_2bfw_contacts.pdf	     README.md
fig2_2bfw_contacts.png	     requirements.txt
gremlin-contacts	     roshan_contact_maps.ipynb
gremlin.h5		     sampled_pdbs.txt
gremlin-results		     spagonerincon.npy
iclr2021-rebuttal	     train.py
launch_jobs.sh		     true_contacts.npy
launch_weight_comparison.py  venv
LICENSE			     wandb
loggers.py		     wandb-configs
main_figs.ipynb		     weight_correlations.ipynb
Makefile		     weight_correlations.py


In [8]:
# Pull both checkpoints for the example protein (32-head attention run).
gremlin_statedict = get_gremlin_statedict('2bfw_1_A', gremlin_df)
fatt_statedict = get_fatt_statedict('2bfw_1_A', 32, fatt_df)

# Attention parameters, unpacked in query / key / value order.
q, k, v = (fatt_statedict[param_name] for param_name in ('query', 'key', 'value'))

# Gremlin coupling tensor.
w = gremlin_statedict['weight']

In [9]:
print(q.shape, k.shape, v.shape, w.shape)

torch.Size([196, 32, 32]) torch.Size([196, 32, 32]) torch.Size([32, 20, 20]) torch.Size([196, 20, 196, 20])


In [10]:
def get_msa_hparams(pdb, df):
    """Return ``msa_length`` and ``num_seqs`` for ``pdb`` as plain ints.

    The rows of ``df`` whose ``pdb`` column equals ``pdb`` are converted to a
    column -> {index label -> value} mapping, and the first entry of each of
    the two columns is cast to ``int``.
    """
    matching_rows = df[df['pdb'] == pdb]
    columns = matching_rows.to_dict()

    def _first_as_int(column_name):
        # First value stored for this column in the to_dict() mapping.
        return int(list(columns[column_name].values())[0])

    return {name: _first_as_int(name) for name in ('msa_length', 'num_seqs')}

In [11]:

def get_correlations(w, w_fatt, L, make_plot=False, pdb=None):
    """Correlate two coupling tensors over all position pairs ``i < j``.

    Both tensors are indexed as ``t[i, :, j, :]`` for every pair from
    ``np.triu_indices(L, 1)``, flattened, and compared element-wise.

    Args:
        w: Reference (gremlin) weight tensor, indexable as ``w[i, :, j, :]``.
        w_fatt: Factored-attention weight tensor with the same layout.
        L: Number of positions; determines which ``(i, j)`` pairs are compared.
        make_plot: If True, scatter-plot a random sample of 100000 entries
            (drawn with replacement) of one flattened tensor against the other.
        pdb: Optional identifier appended to the plot title. Previously the
            title read an undeclared global ``pdb``, which failed on a fresh
            kernel or silently showed whichever protein a later cell had set.

    Returns:
        ``(spearman, pearson)`` correlation coefficients.
    """
    # Strictly upper-triangular position pairs: the diagonal and mirrored
    # (j, i) entries are left out of the comparison.
    row_idx, col_idx = (torch.as_tensor(axis) for axis in np.triu_indices(L, 1))

    fatt_w_no_diag = w_fatt.detach()[row_idx, :, col_idx, :]
    gremlin_w_no_diag = w.detach()[row_idx, :, col_idx, :]

    # Hand SciPy / matplotlib plain NumPy arrays rather than relying on
    # implicit tensor -> array conversion.
    fatt_w_compare_idx = torch.flatten(fatt_w_no_diag).cpu().numpy()
    gremlin_w_compare_idx = torch.flatten(gremlin_w_no_diag).cpu().numpy()
    w_spearman = spearmanr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    w_pearson = pearsonr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]

    if make_plot:
        # Plotting every entry takes a long time, so plot a random sample.
        subset = np.random.choice(len(gremlin_w_compare_idx), size=100000)
        title = f'Spearman: {w_spearman:.2f} Pearson: {w_pearson:.2f}'
        if pdb is not None:
            title += f' PDB: {pdb}'
        plt.title(title)
        plt.xlabel('gremlin w')
        plt.ylabel('fatt w')
        plt.xlim(-2, 2); plt.ylim(-2, 2)
        plt.scatter(gremlin_w_compare_idx[subset], fatt_w_compare_idx[subset], s=1)
        plt.show()
    return w_spearman, w_pearson

In [24]:
from mogwai.utils.functional import apc
from mogwai.metrics import precisions_in_range

def get_info(num_attention_heads,
             pdb,
             attention_head_size,
             fatt_df=fatt_df,
             gremlin_df=gremlin_df):
    """Rebuild one trained factored-attention model and score its contacts.

    Downloads the state dicts for ``pdb``, instantiates ``fatt_model`` with the
    MSA hyperparameters recorded in ``fatt_df``, loads the trained weights, and
    evaluates APC-corrected contact predictions at short, medium and long
    sequence separation.

    Args:
        num_attention_heads: Head count of the run to evaluate.
        pdb: Identifier of the protein family to evaluate.
        attention_head_size: Head size of the run to evaluate.
        fatt_df: DataFrame of factored-attention runs (bound at definition time).
        gremlin_df: DataFrame of gremlin runs (bound at definition time).

    Returns:
        Dict with ``pdb``, ``msa_length``, ``num_seqs``, ``attention_head_size``,
        ``num_attention_heads`` and one ``{short,medium,long}_<name>`` float per
        entry returned by ``precisions_in_range``.
    """
    fatt_statedict = get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=attention_head_size, dest=f'fatt_{num_attention_heads}.h5')
    # NOTE(review): every caller shares 'gremlin.h5'; when this function runs
    # in a multiprocessing Pool the workers write the same file concurrently.
    gremlin_statedict = get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5')

    hparams = get_msa_hparams(pdb, df=fatt_df)
    msa_length = hparams['msa_length']
    num_seqs = hparams['num_seqs']
    hparams['attention_head_size'] = attention_head_size
    hparams['num_attention_heads'] = num_attention_heads
    # Placeholder contact map so the model can be constructed; it is
    # overridden when the state dict is loaded below.
    hparams['true_contacts'] = torch.ones([hparams['msa_length'], hparams['msa_length']])
    model = fatt_model(**hparams)
    model.load_state_dict(fatt_statedict)

    # Only needed by the disabled weight-correlation metrics further down.
    w_fatt = model.compute_mrf_weight()
    w = gremlin_statedict['weight']

    metrics = {}
    metrics['pdb'] = pdb
    metrics['msa_length'] = msa_length
    metrics['num_seqs'] = num_seqs
    metrics['attention_head_size'] = attention_head_size
    metrics['num_attention_heads'] = num_attention_heads

    predictions = apc(model.get_contacts())
    targets = model._true_contacts

    # Separation windows per range. The long range used to repeat the medium
    # window (minsep=13, maxsep=25), so long_* duplicated medium_*. It now
    # starts where medium ends and leaves maxsep at the function's default.
    # NOTE(review): assumes precisions_in_range treats an omitted maxsep as
    # "no upper bound" — confirm against mogwai.metrics.
    separation_ranges = {
        'short': {'minsep': 6, 'maxsep': 13},
        'medium': {'minsep': 13, 'maxsep': 25},
        'long': {'minsep': 25},
    }
    for range_name, sep_kwargs in separation_ranges.items():
        precisions = precisions_in_range(predictions, targets, **sep_kwargs)
        for metric_name, value in precisions.items():
            metrics[f'{range_name}_{metric_name}'] = float(value.squeeze())
#     wspearman, wpearson = get_correlations(w, w_fatt, msa_length)
#     metrics['w_spearman'] = wspearman
#     metrics['w_pearson'] = wpearson
#     metrics['predicted_contacts_apc'] = predictions
#     metrics['true_contacts'] = targets
    return metrics

In [None]:
from functools import partial
from multiprocessing import Pool

# `info` is reset here but not filled in this cell; the results of the
# parallel run land in `new_info`.
info = []
# Head counts to evaluate for the single protein below.
num_heads_sweep = [8, 16, 32, 64, 128, 256, 512]
pdb = '2bfw_1_A'

def get_info_map(num_attention_heads):
    """Evaluate one head count for the module-level ``pdb`` with head size 32.

    Reads the global ``pdb`` at call time, so the result depends on whichever
    cell last assigned it.
    """
    return get_info(num_attention_heads, attention_head_size=32, pdb=pdb)

# One task per head count, spread over a default-sized process pool; `map`
# returns the results in the same order as `num_heads_sweep`.
with Pool() as p:
    new_info = p.map(get_info_map, num_heads_sweep)




In [None]:
# Serial sweep: every head count for the first five meta-test proteins.
# `num_heads_sweep` is not defined in this cell; it must already exist in the
# kernel from another cell.
info = []

# NOTE: the loop rebinds the module-level `pdb`, which other cells also read.
for pdb in tqdm(metatest_pdbs[0:5]):
    for num_attention_heads in num_heads_sweep:
        info.append(get_info(num_attention_heads, attention_head_size=32, pdb=pdb))
        
# One row per (pdb, head count) metrics dict.
sweep_df = pd.DataFrame.from_records(info)
sweep_df

In [None]:
from functools import partial
from multiprocessing import Pool

# Powers of two from 8 through 512.
num_heads_sweep = [2 ** exponent for exponent in range(3, 10)]

info = []
try:
    for pdb in tqdm(metatest_pdbs[1:2]):
        # Fix the protein and head size; the pool supplies the head count.
        get_info_map = partial(get_info, pdb=pdb, attention_head_size=32)
        with Pool() as p:
            new_info = p.map(get_info_map, num_heads_sweep)
        info.extend(new_info)
finally:
    # Runs on success and on failure: report the protein reached and persist
    # whatever results were collected so far.
    print(pdb)
    sweep_df = pd.DataFrame.from_records(info)
    with open('num_head_sweep_df.pkl', 'wb') as f:
        pkl.dump(sweep_df, f)

In [12]:
fatt_df

Unnamed: 0,sweep_name,model,pdb,msa_length,pdb_idx,num_seqs,run_state,pr_at_L,pr_at_L_apc,pr_at_L_5,pr_at_L_5_apc,auc,auc_apc,use_bias,run_id,attention_head_size,num_attention_heads,log_num_seqs
0,fatt-metatest-head-sweep-512,factored_attention,2w3o_1_A,100.0,747,17351.0,finished,0.460000,0.690000,0.850000,0.950000,0.841365,0.917629,True,jlap14s0,32,512,9.761405
1,fatt-metatest-head-sweep-512,factored_attention,1eqz_1_B,108.0,746,4997.0,finished,0.074074,0.166667,0.095238,0.285714,0.115649,0.358059,True,rqn7bxru,32,512,8.516593
2,fatt-metatest-head-sweep-512,factored_attention,3no0_1_A,276.0,745,9408.0,finished,0.612319,0.684783,0.945455,0.909091,0.882672,0.902000,True,p55s7hrh,32,512,9.149316
3,fatt-metatest-head-sweep-512,factored_attention,1xju_1_A,156.0,744,2110.0,finished,0.358974,0.500000,0.483871,0.548387,0.523101,0.545718,True,s1xrufs4,32,512,7.654443
4,fatt-metatest-head-sweep-512,factored_attention,4ew5_1_B,97.0,743,2254.0,finished,0.154639,0.257732,0.315789,0.421053,0.348058,0.483167,True,rlwj19ts,32,512,7.720462
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2228,fatt-metatest-head-sweep-8-32,factored_attention,1zli_1_A,306.0,4,2169.0,finished,0.307190,0.300654,0.606557,0.639344,0.564446,0.573865,True,q63hntpn,32,8,7.682022
2229,fatt-metatest-head-sweep-8-32,factored_attention,1u3j_1_B,99.0,3,1167.0,finished,0.242424,0.292929,0.578947,0.631579,0.520355,0.566619,True,r4v88yge,32,8,7.062192
2230,fatt-metatest-head-sweep-8-32,factored_attention,3uv1_1_B,190.0,5,2175.0,finished,0.100000,0.142105,0.210526,0.236842,0.197405,0.223467,True,sw1mcdz6,32,8,7.684784
2231,fatt-metatest-head-sweep-8-32,factored_attention,6alm_1_A,338.0,0,2179.0,finished,0.130178,0.159763,0.238806,0.253731,0.217803,0.254888,True,yqeeidut,32,8,7.686621


In [26]:
# Smoke-test get_info on the first few rows of the combined run table.
info = []

# `index` is the DataFrame index label, not a positional counter. fatt_df is a
# concatenation of several frames, so labels can repeat; the loop stops at the
# first row whose label exceeds 10 (that row is still processed).
# `row` stays bound after the loop and is inspected by a later cell.
for index, row in tqdm(fatt_df.iterrows()):
    info.append(get_info(num_attention_heads=row['num_attention_heads'], 
                         attention_head_size=row['attention_head_size'], 
                         pdb=row['pdb']))
    if index > 10:
        break

11it [01:29,  8.11s/it]


In [20]:
row['num_seqs']

17351.0

In [33]:
n = fatt_df.shape[0]
fatt_df[:n//10]

Unnamed: 0,sweep_name,model,pdb,msa_length,pdb_idx,num_seqs,run_state,pr_at_L,pr_at_L_apc,pr_at_L_5,pr_at_L_5_apc,auc,auc_apc,use_bias,run_id,attention_head_size,num_attention_heads,log_num_seqs
0,fatt-metatest-head-sweep-512,factored_attention,2w3o_1_A,100.0,747,17351.0,finished,0.460000,0.690000,0.850000,0.950000,0.841365,0.917629,True,jlap14s0,32,512,9.761405
1,fatt-metatest-head-sweep-512,factored_attention,1eqz_1_B,108.0,746,4997.0,finished,0.074074,0.166667,0.095238,0.285714,0.115649,0.358059,True,rqn7bxru,32,512,8.516593
2,fatt-metatest-head-sweep-512,factored_attention,3no0_1_A,276.0,745,9408.0,finished,0.612319,0.684783,0.945455,0.909091,0.882672,0.902000,True,p55s7hrh,32,512,9.149316
3,fatt-metatest-head-sweep-512,factored_attention,1xju_1_A,156.0,744,2110.0,finished,0.358974,0.500000,0.483871,0.548387,0.523101,0.545718,True,s1xrufs4,32,512,7.654443
4,fatt-metatest-head-sweep-512,factored_attention,4ew5_1_B,97.0,743,2254.0,finished,0.154639,0.257732,0.315789,0.421053,0.348058,0.483167,True,rlwj19ts,32,512,7.720462
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
516,fatt-metatest-head-sweep-512,factored_attention,3rnk_1_B,104.0,231,5326.0,finished,0.278846,0.519231,0.850000,0.850000,0.717754,0.827431,True,ufvy48ir,32,512,8.580356
517,fatt-metatest-head-sweep-512,factored_attention,3rv0_1_A,239.0,230,8649.0,finished,0.158996,0.221757,0.276596,0.382979,0.248720,0.303282,True,vp90isjw,32,512,9.065199
518,fatt-metatest-head-sweep-512,factored_attention,1sgj_1_A,231.0,229,7248.0,finished,0.601732,0.688312,0.739130,0.782609,0.735768,0.766259,True,v4gmm13m,32,512,8.888481
519,fatt-metatest-head-sweep-512,factored_attention,4n2i_1_A,662.0,228,522.0,finished,0.075529,0.175227,0.219697,0.356061,0.216997,0.344392,True,ew6lftom,32,512,6.257668
