In [50]:
import torch
from helper.model import *
from helper.utils import *
from helper.distance_map import *
from helper.evaluate import *
import pandas as pd
import warnings
def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``: accepts any arguments, does nothing."""
    return None


# Swap the module-level function so every later ``warnings.warn(...)`` call made
# through the ``warnings`` module is silently dropped for the rest of the session.
warnings.warn = warn

In [51]:
# ---- run configuration ----
args_train_data = "uniref100_full"
args_test_data = "price_149"
args_model_name = "split100_ensemble/split100_5"

# Set to True to bypass the trained network and use the input embedding as-is.
USE_PRETRAINED_EMBEDDING_ONLY = False

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
dtype = torch.float32

# load the id/EC mappings for the train and test sets
id_ec_train, ec_id_dict_train = get_ec_id_dict(
    './data/' + args_train_data + '.csv')
id_ec_test, _ = get_ec_id_dict(
    './data/' + args_test_data + '.csv')

# load model
if USE_PRETRAINED_EMBEDDING_ONLY:
    # no model used for pretrained embedding: identity on the first argument
    model = lambda *args: args[0]
else:
    model = LayerNormNet(512, 128, device, dtype)
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from a GPU still loads when the CPU fallback is active.
    checkpoint = torch.load('./model/' + args_model_name + '.pth',
                            map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()

# compute the train/test embeddings; this is inference only, so autograd
# bookkeeping is switched off to avoid holding a graph over the whole train set
with torch.no_grad():
    emb_train = model(esm_embedding(ec_id_dict_train, device, dtype))
    emb_test = model_embedding_test(id_ec_test, model, device, dtype)

100%|██████████| 5242/5242 [03:33<00:00, 24.58it/s] 


In [52]:
# Distance map between the test embeddings and the train-set embeddings
# (dot=False, as before), then tabulated as a DataFrame for the writers below.
eval_dist = get_dist_map_test(
    emb_train,
    emb_test,
    ec_id_dict_train,
    id_ec_test,
    device,
    dtype,
    dot=False,
)
eval_df = pd.DataFrame.from_dict(eval_dist)

The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([149, 128])


100%|██████████| 5242/5242 [00:11<00:00, 460.05it/s]


Calculating eval distance map, between 149 test ids and 5242 train EC cluster centers


149it [00:00, 1052.10it/s]


In [53]:
def get_pred_probs(out_filename, pred_type="_maxsep"):
    """Read a prediction CSV and turn each row's distances into probabilities.

    The file ``out_filename + pred_type + '.csv'`` is read row by row. The
    first column of every row is skipped; each remaining cell is expected to
    look like ``EC:3.5.2.6/10.8359`` (EC number, a slash, then the distance).

    Args:
        out_filename: Path prefix of the prediction file (no suffix).
        pred_type: Suffix appended to ``out_filename`` before ``.csv``.

    Returns:
        A list with one 1-D tensor per CSV row. Each tensor is the softmax of
        the negated distances of that row, in column order, so a smaller
        distance receives a larger probability. A row without predictions
        yields an empty tensor.

    Raises:
        IndexError: If a cell does not contain both ``:`` and ``/``.
        ValueError: If the text after the slash is not a number.
    """
    import csv  # local import: the notebook does not import csv explicitly

    pred_probs = []
    # The context manager guarantees the handle is closed, even on a parse error.
    with open(out_filename + pred_type + '.csv', 'r', newline='') as result:
        for row in csv.reader(result, delimiter=','):
            # get the distance 10.8359 from EC:3.5.2.6/10.8359 and negate it
            neg_dists = [
                -float(pred_ec_dist.split(":")[1].split("/")[1])
                for pred_ec_dist in row[1:]
            ]
            logits = torch.tensor(neg_dists, dtype=torch.get_default_dtype())
            # dim is explicit: the implicit-dimension form is deprecated
            pred_probs.append(torch.nn.functional.softmax(logits, dim=0))

    return pred_probs


def get_pred_dist(out_filename, pred_type="_maxsep"):
    """Read a prediction CSV and return the raw distance of every prediction.

    The file ``out_filename + pred_type + '.csv'`` is read row by row. The
    first column of every row is skipped; each remaining cell is expected to
    look like ``EC:3.5.2.6/10.8359`` (EC number, a slash, then the distance).

    Args:
        out_filename: Path prefix of the prediction file (no suffix).
        pred_type: Suffix appended to ``out_filename`` before ``.csv``.

    Returns:
        A list with one 1-D tensor per CSV row holding the distances of that
        row, unmodified and in column order. A row without predictions yields
        an empty tensor.

    Raises:
        IndexError: If a cell does not contain both ``:`` and ``/``.
        ValueError: If the text after the slash is not a number.
    """
    import csv  # local import: the notebook does not import csv explicitly

    pred_dists = []
    # The context manager guarantees the handle is closed, even on a parse error.
    with open(out_filename + pred_type + '.csv', 'r', newline='') as result:
        for row in csv.reader(result, delimiter=','):
            # get the distance 10.8359 from EC:3.5.2.6/10.8359
            dists = [
                float(pred_ec_dist.split(":")[1].split("/")[1])
                for pred_ec_dist in row[1:]
            ]
            pred_dists.append(
                torch.tensor(dists, dtype=torch.get_default_dtype()))
    return pred_dists

In [54]:
# Run the max-separation selection on the distance table (presumably written
# under the out_filename prefix -- confirm in the helper), then read the
# predicted labels and the per-row probabilities back from that prefix.
out_filename = './eval/' + args_test_data
write_max_sep_choices(
    eval_df,
    out_filename,
    first_grad=True,
    use_max_grad=False,
)
pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
pred_probs = get_pred_probs(out_filename, pred_type='_maxsep')

In [55]:
def get_ec_pos_dict(mlb, true_label, pred_label):
    """Map every EC label seen in the true or predicted sets to its column index.

    Each label set is encoded with ``mlb.transform``; the non-zero columns of
    the encoded row are paired with the labels that ``mlb.inverse_transform``
    recovers from that same row.

    Args:
        mlb: Fitted label binarizer exposing ``transform`` and
            ``inverse_transform`` (presumably a scikit-learn
            ``MultiLabelBinarizer`` -- confirm against the caller).
        true_label: Sequence of label collections, one per sample.
        pred_label: Sequence of predicted label collections, one per sample.

    Returns:
        dict: label -> column index in the binarized matrix. Keys are inserted
        in first-seen order, true labels before predicted ones.
    """
    label_pos_dict = {}
    # True labels are visited before predictions, as in the original loops.
    for label_sets in (true_label, pred_label):
        for labels in label_sets:
            # Encode once and reuse the row for both the names and the positions.
            encoded = mlb.transform([labels])
            ecs = mlb.inverse_transform(encoded)[0]
            positions = np.nonzero(encoded)[1]
            label_pos_dict.update(zip(ecs, positions))

    return label_pos_dict

In [56]:
# Rebuild the distance table and set the output prefix for this evaluation.
eval_df = pd.DataFrame.from_dict(eval_dist)
out_filename = './eval/' + args_test_data
# top-10 export (to _top10.csv) is left disabled:
# _ = write_top10_choices(eval_df, out_filename)

# maximum separation results
write_max_sep_choices(eval_df, out_filename, first_grad=True, use_max_grad=False)

# predictions versus ground truth
pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
true_label, all_label = get_true_labels('./data/' + args_test_data)
pre, rec, f1, roc, acc = get_eval_metrics(pred_label, true_label, all_label)

# One print call; sep='\n' reproduces the original line-by-line report.
print(
    "############ EC calling results using maximum separation ############",
    '-' * 75,
    f'>>> total samples: {len(true_label)} | total ec {len(all_label)} |\n'
    f'precision | recall | F1 | AUC | accuracy',
    f'{pre:.5} , {rec:.5} , {f1:.5} , {roc:.7} , {acc:.5}',
    '-' * 75,
    sep='\n',
)

############ EC calling results using maximum separation ############
---------------------------------------------------------------------------
>>> total samples: 149 | total ec 56 |
precision | recall | F1 | AUC | accuracy
0.51864 , 0.44737 , 0.46279 , 0.7234345 , 0.45638
---------------------------------------------------------------------------


In [57]:
# Rebuild the distance table and set the output prefix for this evaluation.
eval_df = pd.DataFrame.from_dict(eval_dist)
out_filename = './eval/' + args_test_data
# top-10 export (to _top10.csv) is left disabled:
# _ = write_top10_choices(eval_df, out_filename)

# maximum separation results
write_max_sep_choices(
    eval_df,
    out_filename,
    first_grad=True,
    use_max_grad=False,
)

# predicted labels, their per-row probabilities, and the ground truth
pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
pred_probs = get_pred_probs(out_filename, pred_type='_maxsep')
true_label, all_label = get_true_labels('./data/' + args_test_data)

# This variant also receives the probabilities alongside the hard labels.
pre, rec, f1, roc, acc = get_eval_metrics_new(
    pred_label,
    pred_probs,
    true_label,
    all_label,
)

# One print call; sep='\n' reproduces the original line-by-line report.
print(
    "############ Maximum separation w correct AUC ############",
    '-' * 75,
    f'>>> total samples: {len(true_label)} | total ec {len(all_label)} |\n'
    f'precision | recall | F1 | AUC | accuracy',
    f'{pre:.5} , {rec:.5} , {f1:.5} , {roc:.7} , {acc:.5}',
    '-' * 75,
    sep='\n',
)

############ Maximum separation w correct AUC ############
---------------------------------------------------------------------------
>>> total samples: 149 | total ec 56 |
precision | recall | F1 | AUC | accuracy
0.51864 , 0.44737 , 0.46279 , 0.7234575 , 0.45638
---------------------------------------------------------------------------
