In [1]:
import json

# Default threshold function
def get_f1(thresh, term_freqs, solution):
    """Return the micro-averaged F1 obtained with a fixed frequency threshold.

    For every document, each term whose frequency is strictly greater than
    ``thresh`` is predicted. True positives, false positives and false
    negatives are summed over all documents before precision, recall and F1
    are computed.

    Args:
        thresh: Frequency cutoff; a term is predicted when its value > thresh.
        term_freqs: Mapping of document id -> {term: frequency}.
        solution: Mapping of document id -> iterable of correct terms. Must
            contain every document id present in ``term_freqs``.

    Returns:
        The F1 score, or 0 when there is not a single true positive.

    Raises:
        KeyError: If a document in ``term_freqs`` is missing from ``solution``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    for doc, freqs in term_freqs.items():
        # Predict: dict keys are unique, so a set loses nothing and gives
        # constant-time membership tests below.
        predicted = {term for term, freq in freqs.items() if freq > thresh}
        expected = solution[doc]

        hits = len(predicted & set(expected))
        true_pos += hits
        false_pos += len(predicted) - hits
        # Iterate the original container so duplicate solution entries are
        # counted exactly as a plain list scan would count them.
        false_neg += sum(1 for term in expected if term not in predicted)

    # Without true positives precision and recall are both 0 and F1 would
    # divide by zero, so report 0 directly.
    if true_pos == 0:
        return 0

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return (2 * precision * recall) / (precision + recall)

def train(term_freqs, solution):
    """Learn a frequency threshold by stepping upward from 0 until F1 declines.

    The threshold is raised in increments of 0.001. The search stops once F1
    has strictly dropped on two consecutive steps, and the threshold just
    before the first of those drops is returned.

    If that pattern never occurs (F1 plateaus and then falls to 0, F1 is 0
    everywhere, or the training data is empty), the search stops as soon as the
    threshold reaches the largest frequency in ``term_freqs``. Beyond that
    point nothing can be predicted, F1 stays at 0 and the stopping pattern can
    no longer appear. In that case the threshold with the highest F1 seen so
    far is returned (the smallest one on ties).

    Args:
        term_freqs: Mapping of document id -> {term: frequency}.
        solution: Mapping of document id -> iterable of correct terms.

    Returns:
        The learned threshold as a float.
    """
    step_val = 0.001

    # No term can exceed this value, so get_f1 is 0 for any threshold >= it.
    max_freq = max(
        (val for freqs in term_freqs.values() for val in freqs.values()),
        default=0.0,
    )

    curr_thresh = 0.0
    thresholds = [curr_thresh, curr_thresh + step_val]
    f1s = [get_f1(candidate, term_freqs, solution) for candidate in thresholds]

    curr_thresh += step_val
    next_thresh_f1 = get_f1(curr_thresh + step_val, term_freqs, solution)

    # Stop on two consecutive strict drops: f1s[-2] > f1s[-1] > next.
    while not (next_thresh_f1 < f1s[-1] < f1s[-2]):
        if curr_thresh >= max_freq:
            # F1 is 0 from here on, so the stopping pattern is unreachable.
            # Fall back to the best threshold observed instead of looping
            # forever.
            best = max(range(len(f1s)), key=f1s.__getitem__)
            return thresholds[best]

        curr_thresh += step_val
        thresholds.append(curr_thresh)
        # The look-ahead value was computed for exactly this threshold, so
        # reuse it rather than scoring the same threshold twice.
        f1s.append(next_thresh_f1)
        next_thresh_f1 = get_f1(curr_thresh + step_val, term_freqs, solution)

    return curr_thresh - step_val

def predict(test_freqs, thresh):
    """Predict terms for each document using a per-document cutoff.

    The cutoff for a document is ``thresh`` when the document's mean term
    frequency is below ``thresh``, and the mean frequency otherwise. Terms
    whose frequency is strictly greater than the cutoff are predicted.

    Args:
        test_freqs: Mapping of document id -> {term: frequency}.
        thresh: Learned discrimination threshold.

    Returns:
        Mapping of document id -> list of predicted terms. A document with no
        term frequencies gets an empty list.
    """
    predictions = {}

    for doc, freqs in test_freqs.items():
        if not freqs:
            # No candidate terms: the mean is undefined (division by zero),
            # and there is nothing to predict anyway.
            predictions[doc] = []
            continue

        mean_freq = sum(freqs.values()) / len(freqs)
        cutoff = thresh if mean_freq < thresh else mean_freq
        predictions[doc] = [term for term, freq in freqs.items() if freq > cutoff]

    return predictions

In [None]:
# Read the per-document term frequencies and split the documents 80/20
# (in file order) into a training part and a test part
with open("./data/term_freqs_rev_3_all_terms.json", "r") as handle:
    temp = json.load(handle)

docs_list = list(temp)
partition = int(len(docs_list) * .8)
train_docs, test_docs = docs_list[:partition], docs_list[partition:]

# Read the solution terms for the documents loaded above
solution = {}
docs_list = set(docs_list)
with open("./data/pm_doc_term_counts.csv", "r") as handle:
    for line in handle:
        line = line.strip("\n").split(",")
        if line[0] not in docs_list:
            continue
        # Keep only documents that actually list at least one term;
        # empty fields are dropped
        terms = [term for term in line[1:] if term]
        if terms:
            solution[line[0]] = terms

# Keep only documents that have solution data. Some documents have none
# because they were never indexed, even if their references were.
train_freqs = {doc: temp[doc] for doc in train_docs if doc in solution}
test_freqs = {doc: temp[doc] for doc in test_docs if doc in solution}

In [None]:
# Load in MeSH data: the name of each term and the mean depth of its positions
term_names = {}
mean_term_depths = {}
with open("./data/mesh_data.tab", "r") as handle:
    for line in handle:
        line = line.strip("\n")
        if not line:
            # A blank line (e.g. at the end of the file) has no columns to
            # read and would otherwise raise IndexError below.
            continue
        line = line.split("\t")
        term_names[line[0]] = line[1]
        # Column 4 is a comma-separated list of dotted positions (presumably
        # MeSH tree numbers; verify against the data file). The depth of a
        # position is its number of dot-separated parts.
        posits = [len(posit.split(".")) for posit in line[4].split(",")]
        mean_term_depths[line[0]] = sum(posits) / len(posits)

uids = list(term_names.keys())

In [None]:
def _prf(true_pos, false_pos, false_neg):
    """Return (precision, recall, F1) for the given counts.

    All three are 0 when there are no true positives, which also avoids the
    division by zero in the F1 formula.
    """
    if true_pos == 0:
        return 0, 0, 0
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return precision, recall, (2 * precision * recall) / (precision + recall)


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0 when it is empty."""
    return sum(values) / len(values) if values else 0


thresh = train(train_freqs, solution)
print(f"Learned discrimination threshold: {thresh}\n")

preds = predict(test_freqs, thresh)

# Micro-average: pool the counts of every test document, then score once
true_pos = 0
false_pos = 0
false_neg = 0

for pmid, predicted in preds.items():
    expected = solution[pmid]
    expected_set = set(expected)
    predicted_set = set(predicted)
    true_pos += sum(1 for pred in predicted if pred in expected_set)
    false_pos += sum(1 for pred in predicted if pred not in expected_set)
    false_neg += sum(1 for sol in expected if sol not in predicted_set)

mi_precision, mi_recall, mi_f1 = _prf(true_pos, false_pos, false_neg)

print(f"Micro-averaged F1 from test set: {mi_f1}")
print(f"Micro-averaged precision from test set: {mi_precision}")
print(f"Micro-averaged recall from test set: {mi_recall}\n")

# Macro-average: score every term separately, then average over terms.
# Per-term counts are gathered in a single pass over the documents instead of
# rescanning every document for every uid. A term counts at most once per
# document, as with a plain membership test.
term_tp = {}
term_fp = {}
term_fn = {}

for pmid, predicted in preds.items():
    predicted_set = set(predicted)
    expected_set = set(solution[pmid])
    for term in predicted_set & expected_set:
        term_tp[term] = term_tp.get(term, 0) + 1
    for term in predicted_set - expected_set:
        term_fp[term] = term_fp.get(term, 0) + 1
    for term in expected_set - predicted_set:
        term_fn[term] = term_fn.get(term, 0) + 1

ma_ps = []
ma_rs = []
ma_f1s = []

for uid in uids:
    true_pos = term_tp.get(uid, 0)
    false_pos = term_fp.get(uid, 0)
    false_neg = term_fn.get(uid, 0)

    # Terms that never occur in a prediction or a solution are left out of
    # the average
    if true_pos + false_pos + false_neg == 0:
        continue

    term_precision, term_recall, term_f1 = _prf(true_pos, false_pos, false_neg)
    ma_ps.append(term_precision)
    ma_rs.append(term_recall)
    ma_f1s.append(term_f1)

ma_f1 = _mean(ma_f1s)
ma_recall = _mean(ma_rs)
ma_precision = _mean(ma_ps)

print(f"Macro-averaged F1 from test set: {ma_f1}")
print(f"Macro-averaged precision from test set: {ma_precision}")
print(f"Macro-averaged recall from test set: {ma_recall}\n")

# Example-based: score every document separately, then average over documents
eb_ps = []
eb_rs = []
eb_f1s = []

for pmid, predicted in preds.items():
    expected = solution[pmid]
    expected_set = set(expected)
    predicted_set = set(predicted)
    true_pos = sum(1 for pred in predicted if pred in expected_set)
    false_pos = sum(1 for pred in predicted if pred not in expected_set)
    false_neg = sum(1 for sol in expected if sol not in predicted_set)

    doc_precision, doc_recall, doc_f1 = _prf(true_pos, false_pos, false_neg)
    eb_ps.append(doc_precision)
    eb_rs.append(doc_recall)
    eb_f1s.append(doc_f1)

eb_f1 = _mean(eb_f1s)
eb_recall = _mean(eb_rs)
eb_precision = _mean(eb_ps)

print(f"Example-based F1 from test set: {eb_f1}")
print(f"Example-based precision from test set: {eb_precision}")
print(f"Example-based recall from test set: {eb_recall}\n")