In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import argparse
import sys
import os
import csv
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *

## Load data

In [3]:
data = "/cluster/tufts/pettilab/shared/structure_comparison_data"

In [4]:
pairs_path = f"{data}/train_test_val/pairs_validation.csv"
coord_path = f"{data}/alphabets_blosum_coordinates/allCACoord.npz"
oh_path = f"{data}/alphabets_blosum_coordinates/nH_oh.npz"
blosum_path = f"{data}/alphabets_blosum_coordinates/nH_mat.npy"
#oh2_path = f"{data}/alphabets_blosum_coordinates/3Di.npz"
#blosum2_path = f"{data}/alphabets_blosum_coordinates/mat3di.npy"

In [5]:
coord_d = np.load(coord_path)
oh_d = np.load(oh_path)
blosum = np.load(blosum_path)[:-1,:-1].astype(float)
#oh_d2 = np.load(oh2_path)
#blosum2 = np.load(blosum2_path)[:-1,:-1].astype(float)

In [6]:
pairs = []
with open(pairs_path, mode='r') as file:
    csv_reader = csv.reader(file)
    for row in csv_reader:
        #print(row)
        pairs.append((row[1],row[2]))
pairs = pairs[1:]
print(len(pairs))

2900


## Sort for better batching

In [7]:
# how long is each protein
name_to_length_d ={}
for key in oh_d.keys():
    name_to_length_d[key]= oh_d[key].shape[0]

In [8]:
# check that lengths are the same
if False:
    name_to_length_d2 ={}
    for key in oh_d2.keys():
        name_to_length_d2[key]= oh_d2[key].shape[0]

    for key in name_to_length_d.keys():
        if key in bad_list:
            continue
        if name_to_length_d[key]!=name_to_length_d2[key]:
            print("ISSUE")
            print(key)
            print(name_to_length_d[key],name_to_length_d2[key])
            bad_list.append(key)

In [9]:
# sort pairs by length of longer protein
pair_max_length_pairs = [(pair, max(name_to_length_d[pair[0]], name_to_length_d[pair[1]])) for pair in pairs]
sorted_keys = sorted(pair_max_length_pairs, key=lambda x: x[1])
sorted_pairs = [key for key, shape in sorted_keys]
pairs = sorted_pairs

## Run alignment benchmark

In [10]:
def run_in_batches(long_list, batch_size, params):
    result = []
    for i in range(0, len(long_list), batch_size):
        batch = long_list[i:i + batch_size]  # Get the current batch
        result.extend(run_batch(batch, params))   # Process and extend results
        print(f"finished batch {i}")
    return result

def run_batch(pairs, params):

     # compute max length of any protein
    names=[item for tup in pairs for item in tup]
    max_len = max([name_to_length_d[name] for name in names])
    pad_to = int(jnp.where(max_len < 1, 1, 2 ** jnp.ceil(jnp.log2(max_len))))
    print(pad_to)


    # get oh and coords for left and right part of pairs, padded
    lefts, left_lengths = pad_and_stack_manual([oh_d[pair[0]] for pair in pairs],pad_to = pad_to)
    rights, right_lengths = pad_and_stack_manual([oh_d[pair[1]] for pair in pairs], pad_to = pad_to)
    
    if params["use_two"]:
        lefts2, _ = pad_and_stack_manual([oh_d2[pair[0]] for pair in pairs],pad_to = pad_to)
        rights2, _ = pad_and_stack_manual([oh_d2[pair[1]] for pair in pairs], pad_to = pad_to)

    left_coords, _ = pad_and_stack_manual([coord_d[pair[0]] for pair in pairs],pad_to = pad_to)
    right_coords, _ = pad_and_stack_manual([coord_d[pair[1]] for pair in pairs], pad_to = pad_to)

    # make similarity matrices
    sim_tensor = vv_sim_mtx(lefts, rights, blosum)
    if params["use_two"]:
        sim_tensor *= params["w1"]
        sim_tensor += params["w2"]*vv_sim_mtx(lefts2, rights2, blosum2)

    # align (gap, open, temp)
    length_pairs = jnp.column_stack([left_lengths, right_lengths])
    aln_tensor = v_aln_w_sw(sim_tensor, length_pairs, params["gap_extend"], params["gap_open"],params["temp"])

    # compute lddts 
    lddts = vv_lddt(left_coords, right_coords, aln_tensor, jnp.sum((aln_tensor>0.95).astype(int), axis = [-2,-1]), jnp.array(left_lengths))

    return lddts

In [11]:
params = {}
params["gap_extend"] = -1.0
params["gap_open"] = -10.0
params["temp"] = 0.01
params["use_two"]= False
params["w1"] =0.5
params["w2"] =0.5

In [None]:
%%time
# faster with batch size 200 if you have sufficient memory for it
lddts=run_in_batches(pairs, 50, params)



128
finished batch 0
256
finished batch 50
256
finished batch 100
256
finished batch 150
256
finished batch 200
256
finished batch 250
256
finished batch 300
256
finished batch 350
256
finished batch 400
256
finished batch 450
256
finished batch 500
256
finished batch 550
256
finished batch 600
256
finished batch 650
256
finished batch 700
256
finished batch 750
256
finished batch 800
256
finished batch 850
256
finished batch 900
256
finished batch 950
256
finished batch 1000
256
finished batch 1050
512
finished batch 1100
512
finished batch 1150
512
finished batch 1200
512
finished batch 1250
512
finished batch 1300
512
finished batch 1350
512
finished batch 1400
512
finished batch 1450
512
finished batch 1500
512
finished batch 1550
512
finished batch 1600
512
finished batch 1650
512
finished batch 1700
512
finished batch 1750
512
finished batch 1800
512
finished batch 1850
512
finished batch 1900
512
finished batch 1950
512
finished batch 2000
512
finished batch 2050
512
finished ba

## Mini-grid search 

In [None]:
def plot(lddt_d):# Example dictionary (replace this with your actual data)
    data_dict = lddt_d

    # Get all unique x (o) and y (e) values
    x_values = sorted(set(o for o, e in data_dict.keys()))
    y_values = sorted(set(e for o, e in data_dict.keys()))

    # Create an empty 2D array to store the median values
    grid = np.zeros((len(y_values), len(x_values)))

    # Fill the grid with the median of the lists
    for (o, e), values in data_dict.items():
        median_value = np.mean(values)
        x_idx = x_values.index(o)
        y_idx = y_values.index(e)
        grid[y_idx, x_idx] = median_value

    # Create the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(grid, xticklabels=x_values, yticklabels=y_values, cmap='coolwarm', annot=True)

    # Add labels and title
    plt.xlabel('o values (x-axis)')
    plt.ylabel('e values (y-axis)')
    plt.title('Mean Value Heatmap')

    # Show the plot
    plt.show()

### Just n-hot

In [None]:
%%time
lddt_d = {}
open_choices = [-15.0,-10.0,-5.0, -1.0]
extend_choices =[-1.0,-2.0,-3.0]

params = {}
params["temp"] = 0.01
params["use_two"]= False
params["w1"] =0.0
params["w2"] =0.0

for o in open_choices:
    for e in extend_choices:
        params["gap_extend"] = e
        params["gap_open"] = o
        lddt_d[(o,e)] = run_in_batches(pairs, 100, params)


In [None]:
plot(lddt_d)

In [14]:
#to do with spearman correlation with the lddts from the dali alignments and with TM scores

### 3di and nhot with equal weights (in progress)

In [15]:
%%time
lddt_d = {}
open_choices = [-15.0,-10.0,-5.0, -1.0]
extend_choices =[-1.0,-2.0,-3.0]

params = {}
params["temp"] = 0.01
params["use_two"]= True
params["w1"] = 1.0
params["w2"] = 1.0

for o in open_choices:
    for e in extend_choices:
        params["gap_extend"] = e
        params["gap_open"] = o
        lddt_d[(o,e)] = run_in_batches(pairs, 200, params)

256


ValueError: index can't contain negative values

In [16]:
plot(lddt_d)

ValueError: zero-size array to reduction operation fmin which has no identity

<Figure size 576x432 with 0 Axes>