In [5]:
import json

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from numba import jit

import numpy as np
import model

In [2]:
# Load the term-frequency JSON, then collect the rows of the counts CSV
# that belong to the documents being inspected.
with open("../data/term_freqs_rev_3_all_terms.json", "r") as handle:
    temp = json.load(handle)

docs_list = ["21364592", "19432821", "21749731"]

solution = {}
with open("../data/pm_doc_term_counts.csv", "r") as handle:
    # The first comma-separated field is the document id; every remaining
    # field on that line is stored, unmodified, as that document's entry.
    split_rows = (raw.strip("\n").split(",") for raw in handle)
    solution.update(
        (doc_id, fields) for doc_id, *fields in split_rows if doc_id in docs_list
    )

In [4]:
# Build the feature matrix (x) and label matrix (y) for the documents in docs_list.
# Both matrices share one column per UID, in the order the UIDs appear in
# mesh_data.tab; term_names maps each UID to the second tab-separated field.
uids = []
term_names = {}
with open("../data/mesh_data.tab", "r") as handle:
    for line in handle:
        line = line.strip("\n").split("\t")
        uids.append(line[0])
        term_names[line[0]] = line[1]

x = []
y = []
for doc in docs_list:
    # Look the per-document tables up once instead of once per UID.
    doc_freqs = temp[doc]
    # A set gives constant-time membership; the CSV-derived list would be
    # rescanned from the start for every UID.
    doc_labels = set(solution[doc])

    # truncate to save space: only the first six characters of each
    # frequency's string form are kept; UIDs missing from the document get 0.
    # NOTE(review): a value whose str() uses exponent notation
    # (e.g. 1.2345e-05) loses its exponent under this slice — confirm this
    # matches how the training features were produced before changing it.
    x.append(np.array([
        float(str(doc_freqs[uid])[:6]) if uid in doc_freqs else 0
        for uid in uids
    ]))

    # 1 where the UID is listed for this document in `solution`, else 0.
    y.append([1 if uid in doc_labels else 0 for uid in uids])

x = np.array(x)
y = np.array(y)

In [6]:
# Build the network through the local `model` module and load weights from
# the file named by weights_fp.
# NOTE(review): the meaning of the 1600 argument is defined inside
# model.get_model, which is not visible here — confirm before changing it.
weights_fp = "weights.current_best.hdf5"
mod = model.get_model(1600)
mod.load_weights(weights_fp)

In [7]:
# Run the loaded model on the feature matrix x and keep its output as y_hat.
y_hat = mod.predict(x)

In [14]:
# Turn the model scores into per-document lists of predicted UIDs.
# A UID is selected when its score is strictly greater than the threshold.
threshold = 0.33

# Start every document with an empty list so documents without a score row
# still have an entry.
predictions = {doc: [] for doc in docs_list}

# Row i of y_hat corresponds to docs_list[i]; column j corresponds to uids[j].
for row_idx, scores in enumerate(y_hat):
    selected = [uids[col_idx] for col_idx, score in enumerate(scores) if score > threshold]
    predictions[docs_list[row_idx]] = selected
    


In [18]:
# Print each document id followed by the names of its predicted terms,
# separated by "; ", with a blank line after every document.
for doc, doc_uids in predictions.items():
    preds = "; ".join(term_names[uid] for uid in doc_uids)
    print(f"{doc}: {preds}\n")

21364592: Adult; Aged; Antimetabolites, Antineoplastic; Antineoplastic Combined Chemotherapy Protocols; Aryl Hydrocarbon Hydroxylases; Carcinoma, Non-Small-Cell Lung; China; Drug Combinations; Female; Fluorouracil; Tegafur; Genotype; Humans; Male; Middle Aged; Oxonic Acid; Polymorphism, Genetic; Stomach Neoplasms; Treatment Outcome; Polymorphism, Single Nucleotide; Asian Continental Ancestry Group; Cytochrome P-450 CYP2A6

19432821: Adenocarcinoma; Aged; Animals; Cell Movement; Cytoskeletal Proteins; Female; Humans; Immunohistochemistry; Male; Microfilament Proteins; Middle Aged; Neoplasm Invasiveness; Neoplasm Transplantation; Pancreatic Neoplasms; Prognosis; Biomarkers, Tumor; Homeodomain Proteins; Carcinoma, Pancreatic Ductal; RNA, Small Interfering; Cell Line, Tumor; Mice

21749731: Amino Acid Sequence; Binding Sites; Enzyme Stability; Escherichia coli; Glucose; Kinetics; Ligands; Models, Molecular; Protein Binding; Protein Structure, Tertiary; Dimerization; Catalytic Domain; Bioca

In [None]:
@jit
def count_metrics(y, y_hat, threshold):
    """Tally thresholded scores against binary labels.

    A score counts as a positive prediction when it is strictly greater than
    ``threshold``. Only true positives, false positives and false negatives
    are tallied: true negatives are not tracked, and any cell whose label is
    neither 1.0 nor 0.0 contributes to no counter.

    Args:
        y: Two-level indexable of labels, read as ``y[row][col]``.
        y_hat: Scores read with the same ``[row][col]`` indexing as ``y``.
        threshold: Cut-off compared against each score.

    Returns:
        Tuple ``(true_pos, false_pos, false_neg)`` of integer counts.
    """
    tp = 0
    fp = 0
    fn = 0

    height = len(y)
    width = len(y[0])  # column count is taken from the first row

    for r in range(height):
        for c in range(width):
            label = y[r][c]
            score = y_hat[r][c]

            if label == 1.0:
                if score > threshold:
                    tp += 1
                elif score <= threshold:
                    # Explicit comparison (not a bare else) so a NaN score
                    # falls into neither bucket.
                    fn += 1
            elif label == 0.0 and score > threshold:
                fp += 1

    return tp, fp, fn

def test(weights_fp, logger, threshold, mod):
    mod.load_weights(weights_fp)

    test_ids = []
    with open("test_ids", "r") as handle:
        for line in handle:
            test_ids.append(line.strip("\n"))

    true_pos = 0
    false_pos = 0
    false_neg = 0
    
    batch_size = 16
    test_gen = DataGen(test_ids, batch_size)

    for batch in tqdm(test_gen):
        x = batch[0]
        y = batch[1]

        y_hat = mod.predict_on_batch(x)
        
        tp_temp, fp_temp, fn_temp = count_metrics(y, y_hat, threshold)
        true_pos += tp_temp
        false_pos += fp_temp
        false_neg += fn_temp
           
    if true_pos > 0:
        precision = true_pos / (true_pos + false_pos)
        recall = true_pos / (true_pos + false_neg)
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0