In [13]:
# Load processed sorted EC clusters and raw distances

import pickle as pkl
import numpy as np


# Single place to repoint the notebook at a different copy of the data.
CLEAN_SELECTION_DIR = '/home/seyonec/protein-conformal/clean_selection'


def _load_pickle(filename):
    """Return the object unpickled from `filename` inside CLEAN_SELECTION_DIR."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(CLEAN_SELECTION_DIR + '/' + filename, 'rb') as f:
        return pkl.load(f)


sorted_dict = _load_pickle('sorted_dict.pkl')
dists = _load_pickle('dists.pkl')
labels = _load_pickle('true_labels.pkl')

# The three containers are indexed in parallel downstream, so their lengths should agree.
len(sorted_dict), len(dists), len(labels)

(392, 392, 392)

In [14]:
# Count the dot-separated levels in an example EC-style identifier.
txt = '3.5.4.3'
txt.count('.') + 1

4

In [25]:
from protein_conformal.utils import scope_hierarchical_loss
def get_clean_dict(sorted_ec_dist, dists, labels):
    """
    Build one retrieval record per query protein from its EC-cluster distances.

    sorted_ec_dist: a dictionary where each key is a query protein with EC value key,
                    and each value is a dictionary of EC cluster center values and their
                    euclidean distances to the query protein. Its iteration order must
                    line up with the rows of `dists` and the entries of `labels`, because
                    all three are paired by position.

    dists: a 2D numpy array of distances between each test protein embedding and each EC cluster center embedding

    labels: a list of true EC labels, where each element in the list is a list of valid EC labels for the test protein.
            Every cluster center is scored against all valid EC labels of the query and the
            minimum loss is kept.

    Returns a list with one dict per query, holding the keys 'EC_id', 'test_ec',
    'EC_centroids', 'loss', 'exact', 'partial', 'S_i', 'Sum_i', 'Norm_S_i' and
    'Sum_Norm_S_i'. An empty `sorted_ec_dist` yields an empty list.
    """
    # Nothing to score; this also avoids indexing dists[0] on empty input.
    if len(sorted_ec_dist) == 0:
        return []

    num_train_clusters = len(dists[0])
    # Global extrema over every query/cluster pair, used for min-max scaling below.
    min_sim = np.min(dists)
    max_sim = np.max(dists)

    near_ids = []
    # Iterate the argument itself rather than a notebook-level `sorted_dict`,
    # so the result depends only on what the caller passed in.
    for i, key in enumerate(sorted_ec_dist):
        test_ec = labels[i]  # could be list of EC numbers if multiple valid.
        ec_cluster_centers = list(sorted_ec_dist[key].keys())

        # scope_hierarchical_loss returns a tuple; only its first element (the loss) is used.
        # For each cluster center keep the smallest loss over all valid labels.
        loss = [
            min(scope_hierarchical_loss(poss_ec, ec_cluster_centers[j])[0] for poss_ec in test_ec)
            for j in range(num_train_clusters)
        ]

        # NOTE(review): exact matches are flagged where the minimum loss equals 4, while
        # `partial` below treats loss <= 1 as a near match. Confirm against
        # scope_hierarchical_loss which end of the loss scale denotes an exact match.
        mask_exact = [x == 4 for x in loss]

        # partial: tolerate retrieving homolog with diff family but same superfamily (loss <= 1)
        mask_partial = [x <= 1 for x in loss]

        # Running total of the raw distances along this query's row.
        cum_dists = np.cumsum(dists[i])
        # Min-max scale the row with the global extrema computed above.
        norm_sim = (dists[i] - min_sim) / (max_sim - min_sim)
        sum_norm_s_i = np.cumsum(norm_sim)

        near_ids.append({
            'EC_id': key,
            'test_ec': test_ec,
            'EC_centroids': ec_cluster_centers,
            'loss': loss,
            'exact': mask_exact,
            'partial': mask_partial,
            'S_i': dists[i],
            'Sum_i': cum_dists,
            'Norm_S_i': norm_sim,
            'Sum_Norm_S_i': sum_norm_s_i,
        })
    return near_ids

In [26]:
# Build the per-query records from the loaded clusters, distances and labels.
clean_dict = get_clean_dict(sorted_ec_dist=sorted_dict, dists=dists, labels=labels)

In [27]:
# Spot-check the raw distance row stored for the query at index 60.
spot_check_record = clean_dict[60]
spot_check_record['S_i']

array([ 4.35125303,  4.76616192,  5.1906805 , ..., 22.18528366,
       22.23024178, 22.46767235])

In [28]:
# clean_dict is a list of dicts, so NumPy has to pickle it; reading the file back
# therefore needs np.load(..., allow_pickle=True).
np.save(file='clean_new_v_ec_cluster.npy', arr=clean_dict, allow_pickle=True)
