In [1]:
import csv
import numpy as np
import os
import sys

sys.path.append(os.path.abspath('../benchmarking/'))
from utils import *

In [2]:
data_path = "/cluster/tufts/pettilab/shared/structure_comparison_data"

## which proteins can be used as queries (according to hierarchical condition)

In [3]:
def load_csv_to_list(file_path):
    with open(file_path, mode='r', newline='') as csvfile:
        return [item for row in csv.reader(csvfile) for item in row]

def list_to_csv(string_list, csv_file_path):
    # Open the CSV file in write mode
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        
        # Write each string as a separate row
        for string in string_list:
            writer.writerow([string])
            
def split_names_to_csv(names, output_prefix, chunk_size=10):
    for i in range(0, len(names), chunk_size):
        chunk = names[i:i + chunk_size]
        with open(f'{output_prefix}/query_list_{i // chunk_size}.csv', 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerows([[name] for name in chunk])  # Write each name in a new row

In [4]:
# make val queries for search by taking the intersection of queryProts.txt and validation.csv
data_path = "/cluster/tufts/pettilab/shared/structure_comparison_data"
all_val = load_csv_to_list(f"{data_path}/protein_data/validation.csv")
all_test = load_csv_to_list(f"{data_path}/protein_data/test.csv")

#these satisfy hierarchical condition
all_query = load_csv_to_list(f"{data_path}/protein_data/queryProts.txt") 
val_query = list(set(all_val).intersection(all_query))
test_query = list(set(all_test).intersection(all_query))

print("num val queries: ", len(val_query))
print("num test queries: ",len(test_query))

list_to_csv(val_query, f"{data_path}/protein_data/validation_queries.csv")
list_to_csv(test_query, f"{data_path}/protein_data/test_queries.csv")# forgot to remove 'd1o7d.2' so later manually removed from query_list_26

num val queries:  93
num test queries:  1442


## make two reference sets: all and no test

In [5]:
all_prot_names = load_csv_to_list(f"{data_path}/protein_data/validation.csv") + load_csv_to_list(f"{data_path}/protein_data/test.csv") + load_csv_to_list(f"{data_path}/protein_data/train.csv")
not_test_prot_names = load_csv_to_list(f"{data_path}/protein_data/validation.csv") + load_csv_to_list(f"{data_path}/protein_data/train.csv")
bad_list = ['d1e25a_','d1o7d.2','d1o7d.3']
for b in bad_list:
    all_prot_names.remove(b)
    try:
        not_test_prot_names.remove(b)
    except:
        print(f"didn't find {b} in not_test")
print("total proteins: ", len(all_prot_names))
print("total no test proteins: ", len(not_test_prot_names))

list_to_csv(all_prot_names, f"{data_path}/protein_data/ref_names.csv")
list_to_csv(not_test_prot_names, f"{data_path}/protein_data/ref_names_no_test.csv")

didn't find d1o7d.2 in not_test
total proteins:  11208
total no test proteins:  6504


## validation pairs will be any pairs that do not include a validation query

In [6]:
pairs_path = f"{data_path}/protein_data/pairs_validation.csv"
coord_path = f"{data_path}/protein_data/allCACoord.npz"

In [7]:
coord_d = np.load(coord_path)
n2l_d = make_name_to_length_d(coord_d)
bad_list = ['d1e25a_','d1o7d.2','d1o7d.3']

val_alignments = {}
first = True
with open(pairs_path, mode='r') as file:
    csv_reader = csv.reader(file)
    for row in csv_reader:
        pair = (row[1],row[2])
        if first:
            first = False
            continue    
        elif pair[0] in bad_list or pair[1] in bad_list:
            continue
        elif n2l_d[pair[0]]>512 or n2l_d[pair[1]]>512:
            continue
        elif pair[0] in val_query or pair[1] in val_query:
            continue
        else:
            val_alignments[f"{pair[0]},{pair[1]}"] = [int(i) for i in row[-1].strip('[]').split()]
           
print(len(val_alignments))
np.savez(f"{data_path}/protein_data/given_validation_alignments.npz", **val_alignments)

1518


## compute LDDT for the given alignments for validation pairs

In [8]:
def aln_list_to_tensor(l1, l2, aln_list):
    tensor = jnp.zeros(shape= (l1,l2))
    pairs = [(index, value) for index, value in enumerate(aln_list) if value != -1]
    indices = jnp.array([i for i, _ in pairs])
    values = jnp.array([j for _, j in pairs])
    tensor = tensor.at[indices, values].set(1)
    return tensor

In [9]:
%%time
aln_tensors = []
lddts = {}
for pair, aln in val_alignments.items():
    p1,p2 = pair.split(',')
    aln_tensor = aln_list_to_tensor(n2l_d[p1],n2l_d[p2], aln)
    lddts[pair] = lddt2(coord_d[p1], coord_d[p2], aln_tensor, jnp.sum((aln_tensor>0.95).astype(int)), coord_d[p1].shape[0]).item()

E1106 20:59:35.808946 2568860 hlo_lexer.cc:443] Failed to parse int literal: 26363099083674555315


CPU times: user 7min 6s, sys: 29.5 s, total: 7min 35s
Wall time: 10min 7s


In [10]:
len(lddts.keys())

1518

In [11]:
with open(f"{data_path}/protein_data/pairs_validation_lddts.csv", mode='w', newline='') as file:
    writer = csv.writer(file)
    for pair, value in lddts.items():
        a,b = pair.split(',')
        print(a,b,value)
        writer.writerow([a, b, value])

d1b0nb_ d1g2ya_ 0.8013308048248291
d1g2ya_ d1b0nb_ 0.7751798033714294
d1hw1a2 d2hs5a2 0.5899147391319275
d2hs5a2 d1hw1a2 0.6547322273254395
d1lf6a1 d1nc5a_ 0.3808688223361969
d1lf6a1 d2ahfa_ 0.3974071741104126
d1lf6a1 d3p2ca_ 0.5302639603614807
d1lf6a1 d2jg0a_ 0.47540396451950073
d1lf6a1 d1dl2a_ 0.37020981311798096
d1lf6a1 d2ri9a_ 0.36775654554367065
d1lf6a1 d1x9da1 0.3592124283313751
d1lf6a1 d1nxca_ 0.36536672711372375
d1nc5a_ d1lf6a1 0.4290983974933624
d1nc5a_ d2ahfa_ 0.6005796790122986
d1nc5a_ d2jg0a_ 0.40370798110961914
d1nc5a_ d1dl2a_ 0.4703938066959381
d1nc5a_ d2ri9a_ 0.45808184146881104
d1nc5a_ d1x9da1 0.4781067371368408
d1nc5a_ d1nxca_ 0.4696991443634033
d1nc5a_ d2g0da_ 0.35819679498672485
d2ahfa_ d1lf6a1 0.41260358691215515
d2ahfa_ d1nc5a_ 0.554099440574646
d2ahfa_ d3p2ca_ 0.38981685042381287
d2ahfa_ d2jg0a_ 0.41408637166023254
d2ahfa_ d1dl2a_ 0.4597906768321991
d2ahfa_ d2ri9a_ 0.43303534388542175
d2ahfa_ d1x9da1 0.44406890869140625
d2ahfa_ d1nxca_ 0.43582645058631897
d2ahfa_ 