In [1]:
from CLEAN.infer import infer_pvalue

# Dataset identifiers shared by the rest of the notebook (later cells read
# `train_data` and `test_data` from the kernel namespace).
train_data = "split100"
test_data = "new"

# Baseline run of CLEAN's p-value based EC calling; with report_metrics=True
# the recorded output below shows precision / recall / F1 / AUC.
infer_pvalue(
    train_data,
    test_data,
    p_value=1e-5,
    nk_random=20,
    report_metrics=True,
    pretrained=True,
)

  from .autonotebook import tqdm as notebook_tqdm


The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])


100%|██████████| 5242/5242 [00:00<00:00, 27790.19it/s]


Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers


392it [00:00, 1002.13it/s]
100%|██████████| 5242/5242 [00:00<00:00, 38467.94it/s]
20000it [00:15, 1292.98it/s]
100%|██████████| 392/392 [00:09<00:00, 43.54it/s]


############ EC calling results using random chosen 20k samples ############
---------------------------------------------------------------------------
>>> total samples: 392 | total ec: 177 
>>> precision: 0.558 | recall: 0.477| F1: 0.482 | AUC: 0.737 
---------------------------------------------------------------------------


In [1]:
## Similar code to the selection methods, except we just want to extract the raw euclidean distance maps for any pair of train and test data
import torch
from CLEAN.utils import * 
from CLEAN.model import LayerNormNet
from CLEAN.distance_map import *
from CLEAN.evaluate import *
import pandas as pd
import warnings

def get_eval_dist_map(train_data, test_data, pretrained=True, model_name=None):
    """Build the raw distance map between test sequences and train EC cluster centers.

    Mirrors the setup of CLEAN's selection/inference code but stops once the
    distance map is available, so callers can post-process the raw distances.

    Args:
        train_data: Stem of the training CSV, read from ``./data/<train_data>.csv``.
            When ``pretrained`` is true it also names the checkpoint
            ``./data/pretrained/<train_data>.pth``. The values ``"split70"`` and
            ``"split100"`` additionally select precomputed train embeddings.
        test_data: Stem of the test CSV, read from ``./data/<test_data>.csv``.
        pretrained: If true, load weights from ``./data/pretrained/``; otherwise
            load ``./data/model/<model_name>.pth``.
        model_name: Checkpoint stem used when ``pretrained`` is false. Required
            in that case.

    Returns:
        Whatever ``get_dist_map_test`` returns. Later cells in this notebook
        treat it as ``{test_id: {ec_number: distance}}`` -- TODO confirm against
        ``CLEAN.distance_map``.

    Raises:
        ValueError: If ``pretrained`` is false and ``model_name`` is ``None``.
        Exception: If the requested checkpoint file does not exist (the original
            ``FileNotFoundError`` is chained as ``__cause__``).
    """
    # Fail fast on an unusable argument combination; previously this surfaced
    # later as a TypeError from `'./data/model/' + None`.
    if not pretrained and model_name is None:
        raise ValueError("model_name must be provided when pretrained=False")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float32

    # Only the EC -> ids mapping of the train split is needed below.
    _, ec_id_dict_train = get_ec_id_dict('./data/' + train_data + '.csv')
    id_ec_test, _ = get_ec_id_dict('./data/' + test_data + '.csv')

    # load checkpoints
    # NOTE: change this to LayerNormNet(512, 256, device, dtype)
    # and rebuild with [python build.py install]
    # if inferencing on model trained with supconH loss
    model = LayerNormNet(512, 128, device, dtype)

    if pretrained:
        weights_path = './data/pretrained/' + train_data + '.pth'
        missing_message = 'No pretrained weights for this training data'
    else:
        weights_path = './data/model/' + model_name + '.pth'
        missing_message = 'No model found!'

    try:
        checkpoint = torch.load(weights_path, map_location=device)
    except FileNotFoundError as error:
        # Chain the original error so the offending path stays in the traceback.
        raise Exception(missing_message) from error

    model.load_state_dict(checkpoint)
    model.eval()

    # load precomputed EC cluster center embeddings if possible
    # NOTE(review): these precomputed files are used even when pretrained=False;
    # presumably they were produced by the pretrained weights -- confirm this is
    # intended when a custom `model_name` is supplied.
    if train_data == "split70":
        emb_train = torch.load('./data/pretrained/70.pt', map_location=device)
    elif train_data == "split100":
        emb_train = torch.load('./data/pretrained/100.pt', map_location=device)
    else:
        # Inference only: skip building an autograd graph for the train embeddings.
        with torch.no_grad():
            emb_train = model(esm_embedding(ec_id_dict_train, device, dtype))

    emb_test = model_embedding_test(id_ec_test, model, device, dtype)
    eval_dist = get_dist_map_test(emb_train, emb_test, ec_id_dict_train, id_ec_test, device, dtype)

    # Kept for its global side effect. It runs after the distance map is built,
    # so nothing computed in this function depends on it.
    # NOTE(review): presumably a leftover from the selection code -- confirm
    # whether any downstream cell relies on it.
    seed_everything()

    return eval_dist

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Relies on `train_data` / `test_data` assigned in the first cell of the notebook.
# Raw distance map for the chosen train/test pair.
eval_dist = get_eval_dist_map(
    train_data=train_data,
    test_data=test_data,
    pretrained=True,
)

# Ground-truth labels for the same test split (path is passed without an extension).
(true_label, all_label) = get_true_labels('./data/' + test_data)


The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])


100%|██████████| 5242/5242 [00:00<00:00, 29275.38it/s]


Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers


392it [00:00, 1219.90it/s]


In [None]:
# Test ids, in the iteration order of `eval_dist`.
test_ids = [*eval_dist]

# For every test id, re-order its inner {key: distance} mapping by ascending
# distance (dicts preserve insertion order, so the result stays sorted).
sorted_dict = {
    test_id: dict(sorted(ec_to_dist.items(), key=lambda pair: pair[1]))
    for test_id, ec_to_dist in eval_dist.items()
}


In [None]:
## Build a 2-D numpy array of distances with one row per entry of `sorted_dict`,
## in the dict's iteration order. Within a row, values appear in the order
## stored in the inner dict (ascending distance, as sorted in the previous cell).
## Only the distances are materialised here; the matching keys remain
## available through `sorted_dict`.
import numpy as np

# Inner values are assumed to be plain numeric scalars -- TODO confirm against
# what get_dist_map_test stores.
dist_rows = [list(ec_to_dist.values()) for ec_to_dist in sorted_dict.values()]

# Size the matrix from the longest row rather than from the first one, so an
# empty `sorted_dict` or a row longer than the first no longer raises
# IndexError. When all rows have equal length this matches the original shape.
n_cols = max((len(row) for row in dist_rows), default=0)

## init 2d np array with 0's; rows shorter than `n_cols` stay zero-padded
dists = np.zeros((len(dist_rows), n_cols))
for i, row in enumerate(dist_rows):
    dists[i, :len(row)] = row


In [None]:
import numpy as np
# Locate the first row of a 2-D array whose values are not in ascending order.
def find_non_ascending_row(arr):
    """Return the index of the first row of ``arr`` that is not non-decreasing.

    A row counts as ascending when every consecutive difference is ``>= 0``;
    a row containing NaN therefore counts as non-ascending. Returns ``-1``
    when every row is ascending (including when there are no rows).
    """
    row_is_ascending = np.all(np.diff(arr, axis=1) >= 0, axis=1)
    offenders = np.flatnonzero(~row_is_ascending)
    if offenders.size == 0:
        return -1  # every row is ascending
    return int(offenders[0])

# Sanity check: `dists` was filled from the distance-sorted `sorted_dict`, so
# every row is expected to be ascending and this should print -1. Any other
# value is the index of the first offending row.
non_ascending_row_index = find_non_ascending_row(dists)
print(non_ascending_row_index)


In [None]:
# Persist the sorted distance dict, the distance matrix and the true labels as
# pickles for downstream use.
import os
import pickle

# Single place to change the destination. The default is the original
# hard-coded location; set the CLEAN_SELECTION_DIR environment variable to
# write somewhere else without editing the notebook.
output_dir = os.environ.get(
    'CLEAN_SELECTION_DIR',
    '/home/seyonec/protein-conformal/clean_selection',
)
# Create the directory up front so a missing folder does not abort the dump.
os.makedirs(output_dir, exist_ok=True)

# File stem -> object to pickle; each is written to <output_dir>/<stem>.pkl.
artifacts = {
    'sorted_dict': sorted_dict,
    'dists': dists,
    'true_labels': true_label,
}
for stem, payload in artifacts.items():
    with open(os.path.join(output_dir, stem + '.pkl'), 'wb') as f:
        pickle.dump(payload, f)