In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from compare_wandb import load_full_df, get_test_pdbs, load_run_dict

import wandb
api = wandb.Api()
entity = 'proteins'
project = 'iclr2021-rebuttal'

import boto3
import os
s3 = boto3.client("s3")
s3_bucket = "proteindata"

import torch

# load metatest pdbs
with open('metatest_fams.txt', 'r') as f:
    lines = f.readlines()
    metatest_pdbs = [l.strip() for l in lines]

from mogwai import models
from mogwai.utils.functional import apc
model_name = 'factored_attention'
fatt_model = models.get(model_name)
gremlin_model = models.get('gremlin')

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

from scipy.stats import spearmanr
from scipy.stats import pearsonr

import pickle as pkl

import h5py

In [3]:
head_sweep_runs = {
    'fatt-metatest-head-sweep-512': 'bxnkt0uq',
    'fatt-metatest-head-256': 'xuofwjtc',
    'fatt-metatest-head-sweep-64-kqe7or39': 'kqe7or39',
    'fatt-metatest-head-sweep-128': '32emd6ri',
    'fatt-metatest-head-sweep-8-32': '8yi6a4w5',
}

dict_of_dfs = load_run_dict(head_sweep_runs)
fatt_df = pd.concat(list(dict_of_dfs.values()))

100%|██████████| 748/748 [00:00<00:00, 32494.45it/s]

bxnkt0uq has 748 runs



100%|██████████| 748/748 [00:00<00:00, 35895.92it/s]

xuofwjtc has 748 runs



100%|██████████| 748/748 [00:00<00:00, 34070.78it/s]

kqe7or39 has 748 runs



100%|██████████| 748/748 [00:00<00:00, 34669.42it/s]

32emd6ri has 748 runs



  0%|          | 0/2233 [00:00<?, ?it/s]

8yi6a4w5 has 2233 runs


100%|██████████| 2233/2233 [00:03<00:00, 603.97it/s]


In [4]:
gremlin_runs = {'gremlin': 'dbuvl02g'}
gremlin_df_dict = load_run_dict(gremlin_runs)
gremlin_df = pd.concat(list(gremlin_df_dict.values()))

100%|██████████| 748/748 [00:00<00:00, 36760.67it/s]

dbuvl02g has 748 runs





In [5]:
def download_statedict(run_id, dest='fatt.h5'):
    run = api.run(f"{entity}/{project}/{run_id}")
    key = os.path.join("iclr-2021-factored-attention", *run.path, "model_state_dict.h5")
    with open(dest, 'wb') as f:
        s3.download_fileobj(s3_bucket, key, f)
    return dest

def get_fatt_run_id(pdb, attention_head_size, num_attention_heads, df):
    run_id = df[(df['pdb']==pdb) & (df['num_attention_heads']==num_attention_heads)]['run_id'].values
    return run_id[0]

def get_gremlin_run_id(pdb, df):
    run_id = df[df['pdb']==pdb]['run_id'].values
    return run_id[0]

def get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5'):
    run_id = get_gremlin_run_id(pdb, gremlin_df)
    f_statedict = download_statedict(run_id, dest=dest)
    statedict = torch.load(f_statedict)
    return statedict

def get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=32, dest='fatt.h5'):
    run_id = get_fatt_run_id(pdb, attention_head_size, num_attention_heads, fatt_df)
    f_statedict = download_statedict(run_id, dest=f'fatt_{num_attention_heads}.h5')
    statedict = torch.load(f_statedict)
    return statedict


In [21]:
f =  h5py.File('fatt_8.h5', 'r')
print(f.items())

OSError: Unable to open file (file signature not found)

In [19]:
!ls

bert-results	  gremlin-contacts	       __pycache__
compare_wandb.py  gremlin.h5		       randomize_value_weights.py
data		  gremlin-results	       README.md
download_data.sh  iclr2021-rebuttal	       requirements.txt
fatt_128.h5	  launch_jobs.sh	       sampled_pdbs.txt
fatt_16.h5	  LICENSE		       sweep_df.pkl
fatt_256.h5	  loggers.py		       train.py
fatt_32.h5	  main_figs.ipynb	       venv
fatt_512.h5	  Makefile		       wandb
fatt_64.h5	  metatest_fams.txt	       wandb-configs
fatt_8.h5	  mogwai		       weight_correlations.ipynb
fatt-pl-results   num_head_sweep_df.pkl        weight_correlations.py
fatt_transfer.py  paper_figs_no_weights.ipynb


In [6]:
gremlin_statedict = get_gremlin_statedict('2bfw_1_A', gremlin_df)
fatt_statedict = get_fatt_statedict('2bfw_1_A', 32, fatt_df)

q = fatt_statedict['query']
k = fatt_statedict['key']
v = fatt_statedict['value']

w = gremlin_statedict['weight']

In [7]:
print(q.shape, k.shape, v.shape, w.shape)

torch.Size([196, 32, 32]) torch.Size([196, 32, 32]) torch.Size([32, 20, 20]) torch.Size([196, 20, 196, 20])


In [8]:
def get_msa_hparams(pdb, df):
    pdb = df[df['pdb']==pdb]
    hparam_dict = pdb.to_dict()
    msa_length = int(list(hparam_dict['msa_length'].values())[0])
    num_seqs = int(list(hparam_dict['num_seqs'].values())[0])
    return {'msa_length': msa_length, 'num_seqs': num_seqs}

In [9]:

def get_correlations(w, w_fatt, L, make_plot=False):
    idx = np.triu_indices(L, 1)

    fatt_w_no_diag = w_fatt.detach()[idx[0], :, idx[1], :]
    gremlin_w_no_diag = w[idx[0], :, idx[1], :]

    fatt_w_compare_idx = torch.flatten(fatt_w_no_diag)
    gremlin_w_compare_idx = torch.flatten(gremlin_w_no_diag)
    w_spearman = spearmanr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    w_pearson = pearsonr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    
    if make_plot:
        # plotting correlation
        subset = np.random.choice(len(gremlin_w_compare_idx), size=100000) # plotting the whole thing takes time
        plt.title(f'Spearman: {w_spearman:.2f} Pearson: {w_pearson:.2f} PDB: {pdb}')
        plt.xlabel('gremlin w')
        plt.ylabel('fatt w')
        plt.xlim(-2, 2); plt.ylim(-2, 2)
        plt.scatter(gremlin_w_compare_idx[subset], fatt_w_compare_idx[subset], s=1)
        plt.show()
    return w_spearman, w_pearson

In [10]:
from mogwai.utils.functional import apc
from mogwai.metrics import precisions_in_range

def get_info(num_attention_heads, 
             pdb,
             attention_head_size,
             fatt_df=fatt_df, 
             gremlin_df=gremlin_df):
    fatt_statedict = get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=attention_head_size, dest=f'fatt_{num_attention_heads}.h5')
    gremlin_statedict = get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5')
    
    hparams = get_msa_hparams(pdb, df=fatt_df)
    msa_length = hparams['msa_length']
    num_seqs = hparams['num_seqs']
    hparams['attention_head_size'] = attention_head_size
    hparams['num_attention_heads'] = num_attention_heads
    # initialize a matrix, which will be overriden when we load later.
    hparams['true_contacts'] = torch.ones([hparams['msa_length'], hparams['msa_length']])
    model = fatt_model(**hparams)
    model.load_state_dict(fatt_statedict)
    
    w_fatt = model.compute_mrf_weight()
    w = gremlin_statedict['weight']
    
    metrics = {}
    metrics['pdb'] = pdb
    metrics['msa_length'] = msa_length
    metrics['num_seqs'] = num_seqs
    metrics['attention_head_size'] = attention_head_size
    metrics['num_attention_heads'] = num_attention_heads
    
    predictions = apc(model.get_contacts())
    targets = model._true_contacts
    
    short = precisions_in_range(predictions, targets, minsep=6, maxsep=13)
    for k, v in short.items():
        metrics[f'short_{k}'] = float(v.squeeze())
    medium = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in medium.items():
        metrics[f'medium_{k}'] = float(v.squeeze())
    long = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in long.items():
        metrics[f'long_{k}'] = float(v.squeeze())
    wspearman, wpearson = get_correlations(w, w_fatt, msa_length)
    metrics['w_spearman'] = wspearman
    metrics['w_pearson'] = wpearson
#     metrics['predicted_contacts_apc'] = predictions
#     metrics['true_contacts'] = targets
    return metrics

In [11]:
from functools import partial
from multiprocessing import Pool

info = []
num_heads_sweep = [8, 16, 32, 64, 128, 256, 512]
pdb = '2bfw_1_A'

def get_info_map(num_attention_heads):
    return get_info(num_attention_heads, attention_head_size=32, pdb=pdb)

with Pool() as p:
    new_info = p.map(get_info_map, num_heads_sweep)




IndexError: index 0 is out of bounds for axis 0 with size 0

In [15]:
info = []

for pdb in tqdm(metatest_pdbs[0:5]):
    for num_attention_heads in num_heads_sweep:
        info.append(get_info(num_attention_heads, attention_head_size=32, pdb=pdb))
        
sweep_df = pd.DataFrame.from_records(info)
sweep_df

100%|██████████| 5/5 [06:27<00:00, 77.46s/it] 


Unnamed: 0,pdb,msa_length,num_seqs,attention_head_size,num_attention_heads,short_auc,short_pr_at_l,short_pr_at_l_2,short_pr_at_l_5,short_pr_at_l_10,...,medium_pr_at_l_2,medium_pr_at_l_5,medium_pr_at_l_10,long_auc,long_pr_at_l,long_pr_at_l_2,long_pr_at_l_5,long_pr_at_l_10,w_spearman,w_pearson
0,6alm_1_A,338,2179,32,8,0.142447,0.091716,0.147929,0.19403,0.242424,...,0.029586,0.044776,0.060606,0.033325,0.026627,0.029586,0.044776,0.060606,0.070647,0.137098
1,6alm_1_A,338,2179,32,16,0.150456,0.10355,0.130178,0.208955,0.272727,...,0.029586,0.074627,0.121212,0.045219,0.023669,0.029586,0.074627,0.121212,0.098218,0.193777
2,6alm_1_A,338,2179,32,32,0.140956,0.091716,0.112426,0.208955,0.30303,...,0.04142,0.074627,0.151515,0.05844,0.032544,0.04142,0.074627,0.151515,0.136595,0.266383
3,6alm_1_A,338,2179,32,64,0.163534,0.094675,0.147929,0.223881,0.333333,...,0.053254,0.119403,0.151515,0.068318,0.038462,0.053254,0.119403,0.151515,0.182637,0.35056
4,6alm_1_A,338,2179,32,128,0.192881,0.118343,0.171598,0.283582,0.424242,...,0.053254,0.134328,0.181818,0.071363,0.032544,0.053254,0.134328,0.181818,0.239467,0.442238
5,6alm_1_A,338,2179,32,256,0.216074,0.130178,0.183432,0.298507,0.454545,...,0.065089,0.134328,0.212121,0.083589,0.038462,0.065089,0.134328,0.212121,0.300808,0.536104
6,6alm_1_A,338,2179,32,512,0.201583,0.130178,0.177515,0.253731,0.393939,...,0.071006,0.149254,0.181818,0.086667,0.044379,0.071006,0.149254,0.181818,0.379835,0.64053
7,3eud_1_A,99,22284,32,8,0.371543,0.20202,0.346939,0.578947,0.777778,...,0.326531,0.473684,0.666667,0.342777,0.222222,0.326531,0.473684,0.666667,0.0771,0.165929
8,3eud_1_A,99,22284,32,16,0.360457,0.191919,0.285714,0.526316,0.888889,...,0.387755,0.526316,0.555556,0.380981,0.252525,0.387755,0.526316,0.555556,0.104053,0.215534
9,3eud_1_A,99,22284,32,32,0.381963,0.232323,0.326531,0.578947,0.777778,...,0.44898,0.684211,0.777778,0.490119,0.30303,0.44898,0.684211,0.777778,0.140257,0.277972


In [16]:
from functools import partial
from multiprocessing import Pool

num_heads_sweep = [8, 16, 32, 64, 128, 256, 512]

info = []
try:
    for pdb in tqdm(metatest_pdbs[1:2]):
        get_info_map = partial(get_info, attention_head_size=32, pdb=pdb)
        with Pool() as p:
            new_info = p.map(get_info_map, num_heads_sweep)
        for i in new_info:
            info.append(i)
finally:
    print(pdb)
    sweep_df = pd.DataFrame.from_records(info)
    with open('num_head_sweep_df.pkl', 'wb') as f:
        pkl.dump(sweep_df, f)

  1%|          | 6/748 [07:46<16:02:25, 77.82s/it] 


KeyboardInterrupt: 

In [28]:
get_info_map = partial(get_info, attention_head_size=32, pdb='2bfw_1_A')

In [29]:
get_info_map(8)


In [34]:
print(get_info(num_attention_heads=8, attention_head_size=32, pdb='2bfw_1_A'))