In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import argparse
import sys
import os
import csv
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import scipy.stats as ss
from utils import *

In [2]:
jax.devices()

[CudaDevice(id=0)]

# Load data to get given lddt list in right order

In [3]:
data_path = "/cluster/tufts/pettilab/shared/structure_comparison_data"

In [4]:
val_aln_path = f"{data_path}/protein_data/given_validation_alignments.npz"
coord_path = f"{data_path}/protein_data/allCACoord.npz"
given_lddt_path = f"{data_path}/protein_data/pairs_validation_lddts.csv"


nh_path = f"{data_path}/blurry_vec/nHot.npz"
tmat_path = f"{data_path}/blurry_vec/transition_mtx.npy"
jbl_path = f"{data_path}/blurry_vec/jaccard_blosum_int.npy"

## check that databases have same keys and same length sequences

In [5]:
#hot_d; protien_name: L by alphabet size or num bins + 1; encode sequence in alphabet or nhots
#coord_d; protein_name: L by 3; encodes 3D coordinates

In [6]:
coord_d = np.load(coord_path)
oh_d = np.load(nh_path)
bad_list = check_keys_and_lengths(oh_d, coord_d)

in hot_d but not coord_d:
set()
in coord_d but not hot_d:
{'d1o7d.2', 'd1o7d.3'}


In [7]:
n2l_d = make_name_to_length_d(coord_d)

In [8]:
bad_list.append('d1e25a_')
print(bad_list)

['d1o7d.2', 'd1o7d.3', 'd1e25a_']


In [9]:
# load in validation pairs, their alignments and lddt of them (precomputed in organize_val_and_train)
val_aln_d = dict(np.load(val_aln_path))
val_aln_d_new = {}
for key, val in val_aln_d.items():
    val_aln_d_new[tuple(key.split(','))] = val

val_aln_d = val_aln_d_new
val_aln_d_new = {}
given_lddt_d = {}

# Open the CSV file for reading
with open(given_lddt_path, mode='r') as file:
    reader = csv.reader(file)
    for row in reader:
        a, b, value = row[0], row[1], float(row[2])  # Convert value to float
        given_lddt_d[(a, b)] = value
        
print(len(given_lddt_d.keys()))
check_keys(given_lddt_d, val_aln_d)

1518
all keys match


[]

In [10]:
for key in val_aln_d.keys():
    if key[0] in bad_list or key[1] in bad_list:
        raise ValueError(f"pair {key} is bad and should not be used")

## Sort for better batching

In [11]:
# sort pairs by length of longer protein
pair_list = sorted(list(val_aln_d.keys())) 
pair_max_length_pairs = [(pair, max(n2l_d[pair[0]], n2l_d[pair[1]])) for pair in pair_list]
sorted_keys = sorted(pair_max_length_pairs, key=lambda x: (x[1],x[0][0],x[0][1]))
sorted_pairs = [key for key, shape in sorted_keys]
pairs = sorted_pairs
given_lddt_list = [given_lddt_d[pair] for pair in pairs]

## Plot results

In [12]:
def plot(lddt_d, mode = "mean"):# Example dictionary (replace this with your actual data)
    data_dict = lddt_d

    # Get all unique x (o) and y (e) values
    x_values = sorted(set(o for o, e in data_dict.keys()))
    y_values = sorted(set(e for o, e in data_dict.keys()))

    # Create an empty 2D array to store the median values
    grid = np.zeros((len(y_values), len(x_values)))

    # Fill the grid with the median of the lists
   
    for (o, e), values in data_dict.items():
        if mode == "mean":
            val = np.mean(values)
        elif mode == "median":
            val = np.median(values)
        elif mode == "spearman_lddt":
            val =  ss.spearmanr(values, given_lddt_list).correlation
        elif mode == "spearman_tm":
            val =  ss.spearmanr(values, tm_list).correlation
        elif mode == "geo_mean":
            val=(ss.spearmanr(values, given_lddt_list).correlation*np.mean(values))**(1/2.0)
        x_idx = x_values.index(o)
        y_idx = y_values.index(e)
        grid[y_idx, x_idx] = val
    # Create the heatmap
    plt.figure(figsize=(5, 3))
    sns.heatmap(grid, xticklabels=x_values, yticklabels=y_values, cmap='Blues', annot=True,annot_kws={"size": 8})

    # Add labels and title
    plt.xlabel('o values (x-axis)')
    plt.ylabel('e values (y-axis)')
    plt.title(f'{mode} heatmap')

    # Show the plot
    plt.show()

In [13]:
def get_max_key_by_spearman(your_dict):
    return max(your_dict, key=lambda key: ss.spearmanr(your_dict[key], given_lddt_list).correlation)
def get_max_key_by_mean(your_dict):
    return max(your_dict, key=lambda key: np.mean(your_dict[key]))
def get_max_key_by_median(your_dict):
    return max(your_dict, key=lambda key: np.median(your_dict[key]))

In [14]:
data_path = "/cluster/tufts/pettilab/shared/structure_comparison_data"

lddt_d_d ={}
for alphabet in ["3Di_3Dn_aa","3Di_3Dn","aa_3Dn"]:
#for alphabet in ["3Di",  "graph_clusters",  "MI"]:
    lddt_d_d[alphabet] = pickle.load(open(f"{data_path}/alphabets/{alphabet}_lddt_grid.pkl", "rb"))

In [15]:
for alphabet in lddt_d_d.keys():
    lddt_d=lddt_d_d[alphabet]
    options = list(lddt_d.keys())
    m_key=get_max_key_by_mean(lddt_d)
    s_key = get_max_key_by_spearman(lddt_d)
    s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
    m = np.mean(lddt_d[m_key])
    #ns = len(get_pairs_to_search(m_key,s_key, options))
    print(f"{alphabet:<15} {s:.4f} {m:.4f} {s_key[-1]} {m_key[-1]}")    

3Di_3Dn_aa      0.9545 0.5642 (0.3, 0.0, 0.1) (0.2, 0.1, 0.1)
3Di_3Dn         0.9508 0.5601 0.6 0.6
aa_3Dn          0.9080 0.5224 0.6 0.6


In [16]:
for alphabet in lddt_d_d.keys():
    lddt_d=lddt_d_d[alphabet]
    options = list(lddt_d.keys())
    m_key=get_max_key_by_mean(lddt_d)
    s_key = get_max_key_by_spearman(lddt_d)
    s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
    m = np.mean(lddt_d[m_key])
    #ns = len(get_pairs_to_search(m_key,s_key, options))
    print(f"{alphabet:<15} {s:.4f} {m:.4f} {s_key[-1]} {m_key[-1]}")  
    for key,ls in lddt_d.items():
        sp=ss.spearmanr(ls, given_lddt_list).correlation
        mp=np.mean(ls)
        if sp>.995*s or mp>.995*m:
            print(f"{alphabet:<15} {sp:.4f} {mp:.4f} {key[-1]} ") 

3Di_3Dn_aa      0.9545 0.5642 (0.3, 0.0, 0.1) (0.2, 0.1, 0.1)
3Di_3Dn_aa      0.9515 0.5469 (0.1, 0.0, 0.3) 
3Di_3Dn_aa      0.9432 0.5626 (0.1, 0.1, 0.2) 
3Di_3Dn_aa      0.9501 0.5630 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9539 0.5607 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9522 0.5446 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9464 0.5633 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9493 0.5642 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9518 0.5604 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9495 0.5625 (0.3, 0.0, 0.1) 
3Di_3Dn_aa      0.9519 0.5624 (0.3, 0.0, 0.1) 
3Di_3Dn_aa      0.9545 0.5590 (0.3, 0.0, 0.1) 
3Di_3Dn_aa      0.9451 0.5615 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9481 0.5634 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9507 0.5626 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9514 0.5572 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9501 0.5614 (0.4, 0.0, 0.0) 
3Di_3Dn_aa      0.9526 0.5595 (0.4, 0.0, 0.0) 
3Di_3Dn_aa      0.9525 0.5523 (0.4, 0.0, 0.0) 
3Di_3Dn         0.9508 0.5601 0.6 0.6
3Di_3Dn         0.9436 0.5592 0.6 
3Di_3Dn         0.9

## Decide what grid to use

In [17]:
def get_pairs_to_search(p1, p2, options):
    a,b = p1
    c,d = p2
    # Get the min and max for ranges
    x_min, x_max = min(a, c), max(a, c)
    y_min, y_max = min(b, d), max(b, d)

    # Filter pairs that satisfy the constraints
    return [(x, y) for x,y in options if x_min <= x <= x_max and y_min <= y <= y_max]

### 3Di-3Dn

In [18]:
# will grid search 3Di_3Dn at .6 only

In [19]:
oew_params = {}
alphabet="3Di_3Dn"    
lddt_d=lddt_d_d[alphabet]
options = list(set([_[:-1] for _ in lddt_d.keys()]))
m_key=get_max_key_by_mean(lddt_d)
s_key = get_max_key_by_spearman(lddt_d)
s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
print(s_key, m_key)
m = np.mean(lddt_d[m_key])
params = [(float(k[0]),float(k[1]), 0.6) for k in get_pairs_to_search(m_key[:-1],s_key[:-1], options)]
ns = len(params)
oew_params[alphabet] = params
print(f"{alphabet:<15} {s:.4f} {m:.4f} {ns}")    


(np.int64(-6), np.float64(-0.5), 0.6) (np.int64(-10), np.float64(-0.5), 0.6)
3Di_3Dn         0.9508 0.5601 3


### aa-3Dn 

In [20]:
# will grid search at .4 and .6

In [21]:
alphabet="aa_3Dn"    
lddt_d=lddt_d_d[alphabet]
options = list(set([_[:-1] for _ in lddt_d.keys()]))
m_key=get_max_key_by_mean(lddt_d)
s_key = get_max_key_by_spearman(lddt_d)
s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
print(s_key, m_key)
m = np.mean(lddt_d[m_key])
params = [(float(k[0]),float(k[1]), 0.6) for k in get_pairs_to_search(m_key[:-1],s_key[:-1], options)]
ns = len(params)
oew_params[alphabet] = params
print(f"{alphabet:<15} {s:.4f} {m:.4f} {ns}")    

(np.int64(-4), np.float64(-0.5), 0.6) (np.int64(-6), np.float64(-0.5), 0.6)
aa_3Dn          0.9080 0.5224 2


In [22]:
lddt_d=lddt_d_d[alphabet]
options = list(lddt_d.keys())
m_key=get_max_key_by_mean(lddt_d)
s_key = get_max_key_by_spearman(lddt_d)
s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
m = np.mean(lddt_d[m_key])
#ns = len(get_pairs_to_search(m_key,s_key, options))
print(f"{alphabet:<15} {s:.4f} {m:.4f} {s_key[-1]} {m_key[-1]}")  
for key,ls in lddt_d.items():
    sp=ss.spearmanr(ls, given_lddt_list).correlation
    mp=np.mean(ls)
    if sp>.995*s or mp>.995*m:
        print(f"{alphabet:<15} {sp:.4f} {mp:.4f} {key} ")
        if key[-1]==0.4:
            oew_params[alphabet].append((float(key[0]),float(key[1]), 0.4))


aa_3Dn          0.9080 0.5224 0.6 0.6
aa_3Dn          0.8974 0.5204 (np.int64(-8), np.float64(-0.5), 0.4) 
aa_3Dn          0.9003 0.5224 (np.int64(-6), np.float64(-0.5), 0.6) 
aa_3Dn          0.9080 0.5182 (np.int64(-4), np.float64(-0.5), 0.6) 
aa_3Dn          0.9056 0.4968 (np.int64(-2), np.float64(-1.0), 0.6) 
aa_3Dn          0.9059 0.5182 (np.int64(-4), np.float64(-0.5), 0.7) 
aa_3Dn          0.9052 0.4914 (np.int64(-2), np.float64(-0.5), 0.7) 


In [23]:
oew_params

{'3Di_3Dn': [(-10.0, -0.5, 0.6), (-6.0, -0.5, 0.6), (-8.0, -0.5, 0.6)],
 'aa_3Dn': [(-6.0, -0.5, 0.6), (-4.0, -0.5, 0.6), (-8.0, -0.5, 0.4)]}

### 3Di 3Dn aa

In [24]:
# will benchmark anything within top .997 of top spearman and top lddt

In [25]:
alphabet = "3Di_3Dn_aa"
oew_params[alphabet] = []
lddt_d=lddt_d_d[alphabet]
options = list(lddt_d.keys())
m_key=get_max_key_by_mean(lddt_d)
s_key = get_max_key_by_spearman(lddt_d)
s = ss.spearmanr(lddt_d[s_key], given_lddt_list).correlation
m = np.mean(lddt_d[m_key])
#ns = len(get_pairs_to_search(m_key,s_key, options))
print(f"{alphabet:<15} {s:.4f} {m:.4f} {s_key[-1]} {m_key[-1]}")  
for key,ls in lddt_d.items():
    sp=ss.spearmanr(ls, given_lddt_list).correlation
    mp=np.mean(ls)
    if sp>.997*s or mp>.997*m:
        print(f"{alphabet:<15} {sp:.4f} {mp:.4f} {key[0]} {key[1]} {key[-1]} ") 
        oew_params[alphabet].append((float(key[0]),float(key[1]), key[-1]))

3Di_3Dn_aa      0.9545 0.5642 (0.3, 0.0, 0.1) (0.2, 0.1, 0.1)
3Di_3Dn_aa      0.9432 0.5626 -6 -0.5 (0.1, 0.1, 0.2) 
3Di_3Dn_aa      0.9501 0.5630 -6 -0.5 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9539 0.5607 -4 -0.5 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9522 0.5446 -2 -0.5 (0.2, 0.0, 0.2) 
3Di_3Dn_aa      0.9464 0.5633 -8 -0.5 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9493 0.5642 -6 -0.5 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9518 0.5604 -4 -0.5 (0.2, 0.1, 0.1) 
3Di_3Dn_aa      0.9519 0.5624 -6 -0.5 (0.3, 0.0, 0.1) 
3Di_3Dn_aa      0.9545 0.5590 -4 -0.5 (0.3, 0.0, 0.1) 
3Di_3Dn_aa      0.9481 0.5634 -8 -0.5 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9507 0.5626 -6 -0.5 (0.3, 0.1, 0.0) 
3Di_3Dn_aa      0.9526 0.5595 -6 -0.5 (0.4, 0.0, 0.0) 
3Di_3Dn_aa      0.9525 0.5523 -4 -0.5 (0.4, 0.0, 0.0) 


In [26]:
oew_params

{'3Di_3Dn': [(-10.0, -0.5, 0.6), (-6.0, -0.5, 0.6), (-8.0, -0.5, 0.6)],
 'aa_3Dn': [(-6.0, -0.5, 0.6), (-4.0, -0.5, 0.6), (-8.0, -0.5, 0.4)],
 '3Di_3Dn_aa': [(-6.0, -0.5, (0.1, 0.1, 0.2)),
  (-6.0, -0.5, (0.2, 0.0, 0.2)),
  (-4.0, -0.5, (0.2, 0.0, 0.2)),
  (-2.0, -0.5, (0.2, 0.0, 0.2)),
  (-8.0, -0.5, (0.2, 0.1, 0.1)),
  (-6.0, -0.5, (0.2, 0.1, 0.1)),
  (-4.0, -0.5, (0.2, 0.1, 0.1)),
  (-6.0, -0.5, (0.3, 0.0, 0.1)),
  (-4.0, -0.5, (0.3, 0.0, 0.1)),
  (-8.0, -0.5, (0.3, 0.1, 0.0)),
  (-6.0, -0.5, (0.3, 0.1, 0.0)),
  (-6.0, -0.5, (0.4, 0.0, 0.0)),
  (-4.0, -0.5, (0.4, 0.0, 0.0))]}

## Write config file for each gap/open and extend that we plan to test

In [31]:
# for two alphabets
name = {}
name["3Di"] = "3Di"
name["aa"] = "aa"
name["3Dn"] = "graph_clusters"

ref_path =f"protein_data/ref_names_no_test.csv"
query_list_dir_path = f"protein_data/validation_queries_by_10"
coord_path = "protein_data/allCACoord.npz"
for alphabet in oew_params.keys():
    if alphabet == "3Di_3Dn_aa": continue
    a1, a2 = alphabet.split('_')
    a1 = name[a1]
    a2 = name[a2]
    p1 = pickle.load(open(f"{data_path}/alphabets/{a1}_karlin_params.pkl", "rb"))
    p2 = pickle.load(open(f"{data_path}/alphabets/{a2}_karlin_params.pkl", "rb"))
    for val in oew_params[alphabet]:
        go,ge,w1 = val
        print(go,ge,w1)
        w2=1-w1
        path_to_config = f"val_search_combos/{alphabet}/{alphabet}_{go}_{ge}_{w1}_config"
        with open(path_to_config, 'w') as file:
            file.write(f"data_path: {data_path}" + '\n')
            file.write(f"coord_d: {coord_path}" + '\n')
            file.write(f"oh_d1: alphabets/{a1}.npz" +'\n')
            file.write(f"blosum1: alphabets/{a1}_blosum.npy" +'\n')
            file.write(f"oh_d2: alphabets/{a2}.npz" +'\n')
            file.write(f"blosum2: alphabets/{a2}_blosum.npy" +'\n')
            file.write(f"gap_open: {go}" + '\n')
            file.write(f"gap_extend: {ge}" + '\n')
            file.write(f"w1: {w1}" + '\n')
            file.write(f"w2: {w2}" + '\n')
            file.write(f"use_two: True" + '\n')
            file.write(f"lam: {p1["lam"]}" + '\n')
            file.write(f"k: {p1["k"]}" + '\n')
            file.write(f"lam2: {p2["lam"]}" + '\n')
            file.write(f"k2: {p2["k"]}" + '\n')
            file.write(f"refs: {ref_path}"+ '\n')
            file.write(f"query_list_dir: {query_list_dir_path}"+ '\n') 

-10.0 -0.5 0.6
-6.0 -0.5 0.6
-8.0 -0.5 0.6
-6.0 -0.5 0.6
-4.0 -0.5 0.6
-8.0 -0.5 0.4


In [32]:
# for three alphabets
ref_path =f"protein_data/ref_names_no_test.csv"
query_list_dir_path = f"protein_data/validation_queries_by_10"
coord_path = "protein_data/allCACoord.npz"
alphabet = "3Di_3Dn_aa"
a1, a2, a3 = alphabet.split('_')
a1 = name[a1]
a2 = name[a2]
a3 = name[a3]
p1 = pickle.load(open(f"{data_path}/alphabets/{a1}_karlin_params.pkl", "rb"))
p2 = pickle.load(open(f"{data_path}/alphabets/{a2}_karlin_params.pkl", "rb"))
p3 = pickle.load(open(f"{data_path}/alphabets/{a3}_karlin_params.pkl", "rb"))

for val in oew_params[alphabet]:
    go,ge,ws = val
    w1 = .2+ws[0]
    w2 = .2+ws[1]
    w3 = .2+ws[2]
    path_to_config = f"val_search_combos/{alphabet}/{alphabet}_{go}_{ge}_{w1:.1f}_{w2:.1f}_{w3:.1f}_config"
    with open(path_to_config, 'w') as file:
        file.write(f"data_path: {data_path}" + '\n')
        file.write(f"coord_d: {coord_path}" + '\n')
        file.write(f"oh_d1: alphabets/{a1}.npz" +'\n')
        file.write(f"blosum1: alphabets/{a1}_blosum.npy" +'\n')
        file.write(f"oh_d2: alphabets/{a2}.npz" +'\n')
        file.write(f"blosum2: alphabets/{a2}_blosum.npy" +'\n')
        file.write(f"oh_d3: alphabets/{a3}.npz" +'\n')
        file.write(f"blosum3: alphabets/{a3}_blosum.npy" +'\n')
        file.write(f"gap_open: {go}" + '\n')
        file.write(f"gap_extend: {ge}" + '\n')
        file.write(f"w1: {w1}" + '\n')
        file.write(f"w2: {w2}" + '\n')
        file.write(f"w3: {w3}" + '\n')
        file.write(f"use_three: True" + '\n')
        file.write(f"lam: {p1["lam"]}" + '\n')
        file.write(f"k: {p1["k"]}" + '\n')
        file.write(f"lam2: {p2["lam"]}" + '\n')
        file.write(f"k2: {p2["k"]}" + '\n')
        file.write(f"lam3: {p3["lam"]}" + '\n')
        file.write(f"k3: {p3["k"]}" + '\n')
        file.write(f"refs: {ref_path}"+ '\n')
        file.write(f"query_list_dir: {query_list_dir_path}"+ '\n') 