In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from compare_wandb import load_full_df, get_test_pdbs, load_run_dict

import wandb
api = wandb.Api()
entity = 'proteins'
project = 'iclr2021-rebuttal'

import boto3
import os
s3 = boto3.client("s3")
s3_bucket = "proteindata"

import torch

# load metatest pdbs
with open('metatest_fams.txt', 'r') as f:
    lines = f.readlines()
    metatest_pdbs = [l.strip() for l in lines]

from mogwai import models
from mogwai.utils.functional import apc
model_name = 'factored_attention'
fatt_model = models.get(model_name)
gremlin_model = models.get('gremlin')

import matplotlib.pyplot as plt
import numpy as np

from tqdm import tqdm

from scipy.stats import spearmanr
from scipy.stats import pearsonr


In [3]:
head_sweep_runs = {
    'fatt-metatest-head-sweep-512': 'bxnkt0uq',
    'fatt-metatest-head-256': 'xuofwjtc',
    'fatt-metatest-head-sweep-64-kqe7or39': 'kqe7or39',
    'fatt-metatest-head-sweep-128': '32emd6ri',
    'fatt-metatest-head-sweep-8-32': '8yi6a4w5',
}

dict_of_dfs = load_run_dict(head_sweep_runs)
fatt_df = pd.concat(list(dict_of_dfs.values()))

100%|██████████| 748/748 [00:00<00:00, 34552.96it/s]

bxnkt0uq has 748 runs



100%|██████████| 748/748 [00:00<00:00, 33531.84it/s]

xuofwjtc has 748 runs



100%|██████████| 748/748 [00:00<00:00, 32964.59it/s]

kqe7or39 has 748 runs



100%|██████████| 748/748 [00:00<00:00, 33895.93it/s]

32emd6ri has 748 runs



  0%|          | 0/2233 [00:00<?, ?it/s]

8yi6a4w5 has 2233 runs


100%|██████████| 2233/2233 [00:03<00:00, 675.85it/s]


In [4]:
gremlin_runs = {'gremlin': 'dbuvl02g'}
gremlin_df_dict = load_run_dict(gremlin_runs)
gremlin_df = pd.concat(list(gremlin_df_dict.values()))

100%|██████████| 748/748 [00:00<00:00, 34989.96it/s]

dbuvl02g has 748 runs





In [5]:
def download_statedict(run_id, dest='fatt.h5'):
    run = api.run(f"{entity}/{project}/{run_id}")
    key = os.path.join("iclr-2021-factored-attention", *run.path, "model_state_dict.h5")
    with open(dest, 'wb') as f:
        s3.download_fileobj(s3_bucket, key, f)
    return dest

def get_fatt_run_id(pdb, attention_head_size, num_attention_heads, df):
    run_id = df[(df['pdb']==pdb) & (df['num_attention_heads']==num_attention_heads)]['run_id'].values
    return run_id[0]

def get_gremlin_run_id(pdb, df):
    run_id = df[df['pdb']==pdb]['run_id'].values
    return run_id[0]

def get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5'):
    run_id = get_gremlin_run_id(pdb, gremlin_df)
    f_statedict = download_statedict(run_id, dest=dest)
    statedict = torch.load(f_statedict)
    return statedict

def get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=32, dest='fatt.h5'):
    run_id = get_fatt_run_id(pdb, attention_head_size, num_attention_heads, fatt_df)
    f_statedict = download_statedict(run_id, dest=dest)
    statedict = torch.load(f_statedict)
    return statedict


In [6]:
gremlin_statedict = get_gremlin_statedict('2bfw_1_A', gremlin_df)
fatt_statedict = get_fatt_statedict('2bfw_1_A', 32, fatt_df)

q = fatt_statedict['query']
k = fatt_statedict['key']
v = fatt_statedict['value']

w = gremlin_statedict['weight']

In [7]:
print(q.shape, k.shape, v.shape, w.shape)

torch.Size([196, 32, 32]) torch.Size([196, 32, 32]) torch.Size([32, 20, 20]) torch.Size([196, 20, 196, 20])


In [8]:
def get_msa_hparams(pdb, df):
    pdb = df[df['pdb']==pdb]
    hparam_dict = pdb.to_dict()
    msa_length = int(list(hparam_dict['msa_length'].values())[0])
    num_seqs = int(list(hparam_dict['num_seqs'].values())[0])
    return {'msa_length': msa_length, 'num_seqs': num_seqs}

In [9]:

def get_correlations(w, w_fatt, L, make_plot=False):
    idx = np.triu_indices(L, 1)

    fatt_w_no_diag = w_fatt.detach()[idx[0], :, idx[1], :]
    gremlin_w_no_diag = w[idx[0], :, idx[1], :]

    fatt_w_compare_idx = torch.flatten(fatt_w_no_diag)
    gremlin_w_compare_idx = torch.flatten(gremlin_w_no_diag)
    w_spearman = spearmanr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    w_pearson = pearsonr(fatt_w_compare_idx, gremlin_w_compare_idx)[0]
    
    if make_plot:
        # plotting correlation
        subset = np.random.choice(len(gremlin_w_compare_idx), size=100000) # plotting the whole thing takes time
        plt.title(f'Spearman: {w_spearman:.2f} Pearson: {w_pearson:.2f} PDB: {pdb}')
        plt.xlabel('gremlin w')
        plt.ylabel('fatt w')
        plt.xlim(-2, 2); plt.ylim(-2, 2)
        plt.scatter(gremlin_w_compare_idx[subset], fatt_w_compare_idx[subset], s=1)
        plt.show()
    return w_spearman, w_pearson

In [49]:
from mogwai.utils.functional import apc
from mogwai.metrics import precisions_in_range

def get_info(attention_head_size, 
             num_attention_heads, 
             pdb, 
             fatt_df=fatt_df, 
             gremlin_df=gremlin_df):
    fatt_statedict = get_fatt_statedict(pdb, num_attention_heads, fatt_df, attention_head_size=attention_head_size, dest='fatt.h5')
    gremlin_statedict = get_gremlin_statedict(pdb, gremlin_df, dest='gremlin.h5')

    hparams = get_msa_hparams(pdb, df=fatt_df)
    msa_length = hparams['msa_length']
    num_seqs = hparams['num_seqs']
    hparams['attention_head_size'] = attention_head_size
    hparams['num_attention_heads'] = num_attention_heads
    # initialize a matrix, which will be overriden when we load later.
    hparams['true_contacts'] = torch.ones([hparams['msa_length'], hparams['msa_length']])
    model = fatt_model(**hparams)
    model.load_state_dict(torch.load('fatt.h5'))
    
    w_fatt = model.compute_mrf_weight()
    w = gremlin_statedict['weight']
    
    metrics = {}
    metrics['pdb'] = pdb
    metrics['msa_length'] = msa_length
    metrics['num_seqs'] = num_seqs
    metrics['attention_head_size'] = attention_head_size
    metrics['num_attention_heads'] = num_attention_heads
    
    predictions = apc(model.get_contacts())
    targets = model._true_contacts
    
    short = precisions_in_range(predictions, targets, minsep=6, maxsep=13)
    for k, v in short.items():
        metrics[f'short_{k}'] = float(v.squeeze())
    medium = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in medium.items():
        metrics[f'medium_{k}'] = float(v.squeeze())
    long = precisions_in_range(predictions, targets, minsep=13, maxsep=25)
    for k, v in long.items():
        metrics[f'long_{k}'] = float(v.squeeze())
    wspearman, wpearson = get_correlations(w, w_fatt, msa_length)
    metrics['w_spearman'] = wspearman
    metrics['w_pearson'] = wpearson
#     metrics['predicted_contacts_apc'] = predictions
#     metrics['true_contacts'] = targets
    return metrics

In [50]:

info = []
for pdb in tqdm(metatest_pdbs[2:3]):
    for num_attention_heads in [8, 16, 32, 64, 128, 256, 512]:
        info.append(get_info(attention_head_size=32, num_attention_heads=num_attention_heads, pdb=pdb))

100%|██████████| 1/1 [01:24<00:00, 84.04s/it]


In [51]:
import pickle as pkl

sweep_df = pd.DataFrame.from_records(info)

with open('num_head_sweep_df.pkl', 'wb') as f:
    pkl.dump(sweep_df, f)



In [52]:
sweep_df

Unnamed: 0,pdb,msa_length,num_seqs,attention_head_size,num_attention_heads,short_auc,short_pr_at_l,short_pr_at_l_2,short_pr_at_l_5,short_pr_at_l_10,...,medium_pr_at_l_2,medium_pr_at_l_5,medium_pr_at_l_10,long_auc,long_pr_at_l,long_pr_at_l_2,long_pr_at_l_5,long_pr_at_l_10,w_spearman,w_pearson
0,1rkt_1_B,200,24747,32,8,0.176282,0.1,0.14,0.25,0.45,...,0.16,0.275,0.4,0.177962,0.085,0.16,0.275,0.4,0.086878,0.145674
1,1rkt_1_B,200,24747,32,16,0.185198,0.1,0.15,0.275,0.5,...,0.18,0.35,0.5,0.216202,0.11,0.18,0.35,0.5,0.108326,0.185643
2,1rkt_1_B,200,24747,32,32,0.209821,0.105,0.17,0.325,0.5,...,0.2,0.375,0.65,0.251353,0.13,0.2,0.375,0.65,0.140836,0.235991
3,1rkt_1_B,200,24747,32,64,0.227103,0.11,0.19,0.35,0.55,...,0.21,0.45,0.75,0.266514,0.11,0.21,0.45,0.75,0.183175,0.297988
4,1rkt_1_B,200,24747,32,128,0.214597,0.11,0.17,0.325,0.55,...,0.21,0.425,0.75,0.273587,0.135,0.21,0.425,0.75,0.233324,0.362816
5,1rkt_1_B,200,24747,32,256,0.25599,0.125,0.21,0.375,0.65,...,0.23,0.475,0.7,0.285829,0.14,0.23,0.475,0.7,0.289185,0.430172
6,1rkt_1_B,200,24747,32,512,0.247069,0.12,0.2,0.375,0.65,...,0.25,0.5,0.75,0.303419,0.145,0.25,0.5,0.75,0.345778,0.488394
