In [7]:
import json

# Default threshold function
def get_f1(thresh_constant, term_freqs, solution):
    """Return the pooled (micro-averaged) F1 of mean-plus-constant thresholding.

    A term is predicted for a document when its frequency is strictly greater
    than that document's mean term frequency plus ``thresh_constant``. True
    positives, false positives and false negatives are summed over all
    documents before precision, recall and F1 are computed.

    Args:
        thresh_constant: Offset added to each document's mean term frequency.
        term_freqs: Mapping of document id -> {term: frequency}.
        solution: Mapping of document id -> list of reference terms. It must
            contain every document id present in ``term_freqs``.

    Returns:
        The F1 score, or 0 when there are no true positives.

    Raises:
        KeyError: If a document in ``term_freqs`` is missing from ``solution``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0

    for doc, freqs in term_freqs.items():
        if freqs:
            cutoff = sum(freqs.values()) / len(freqs) + thresh_constant
            predicted = [term for term, freq in freqs.items() if freq > cutoff]
        else:
            # No terms at all: nothing can be predicted (and the mean is undefined).
            predicted = []

        actual = solution[doc]
        # Sets make each membership check O(1) instead of a list scan.
        actual_set = set(actual)
        predicted_set = set(predicted)

        hits = sum(1 for term in predicted if term in actual_set)
        true_pos += hits
        false_pos += len(predicted) - hits
        # Walk the reference list itself so repeated reference terms are
        # counted exactly as a list scan would count them.
        false_neg += sum(1 for term in actual if term not in predicted_set)

    if true_pos == 0:
        return 0

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return (2 * precision * recall) / (precision + recall)

def train(term_freqs, solution, step_val=0.001):
    """Search upward from 0 for the threshold constant where F1 starts to fall.

    The constant is raised in increments of ``step_val``. The search stops as
    soon as F1 has strictly decreased over two consecutive steps, and returns
    the constant that preceded those two drops.

    Args:
        term_freqs: Mapping of document id -> {term: frequency}, passed to
            ``get_f1``.
        solution: Mapping of document id -> reference terms, passed to
            ``get_f1``.
        step_val: Positive increment between candidate constants.

    Returns:
        The selected threshold constant. If F1 reaches 0 before two
        consecutive drops are seen, the evaluated constant with the highest
        F1 (the smallest one on ties) is returned instead.

    Raises:
        ValueError: If ``step_val`` is not positive.
    """
    if step_val <= 0:
        raise ValueError("step_val must be positive")

    thresholds = [0.0, 0.0 + step_val]
    f1s = [get_f1(thresh, term_freqs, solution) for thresh in thresholds]

    curr_thresh_constant = thresholds[-1]
    next_thresh_f1 = get_f1(curr_thresh_constant + step_val, term_freqs, solution)

    # Stop once the last three scores are strictly decreasing.
    while not (f1s[-2] > f1s[-1] > next_thresh_f1):
        if f1s[-1] == 0 and next_thresh_f1 == 0:
            # get_f1 returns 0 when there are no true positives, and raising
            # the threshold only removes predictions, so F1 cannot recover
            # and the stop condition could never become true. Fall back to
            # the best constant evaluated so far rather than looping forever.
            return thresholds[f1s.index(max(f1s))]

        curr_thresh_constant += step_val
        thresholds.append(curr_thresh_constant)
        # The look-ahead score was computed for exactly this constant, so
        # reuse it instead of evaluating the same threshold a second time.
        f1s.append(next_thresh_f1)
        next_thresh_f1 = get_f1(curr_thresh_constant + step_val, term_freqs, solution)

    return curr_thresh_constant - step_val

def predict(test_freqs, thresh_constant):
    """Predict terms for every document using the learned threshold constant.

    A term is predicted when its frequency is strictly greater than the
    document's mean term frequency plus ``thresh_constant``.

    Args:
        test_freqs: Mapping of document id -> {term: frequency}.
        thresh_constant: Offset added to each document's mean term frequency.

    Returns:
        Mapping of document id -> list of predicted terms, in the order the
        terms appear in the document's frequency dict. A document with no
        terms gets an empty list.
    """
    predictions = {}

    for doc, freqs in test_freqs.items():
        if not freqs:
            # The mean of an empty dict is undefined; nothing can be predicted.
            predictions[doc] = []
            continue

        cutoff = sum(freqs.values()) / len(freqs) + thresh_constant
        predictions[doc] = [term for term, freq in freqs.items() if freq > cutoff]

    return predictions

In [2]:
# Read the per-document term frequencies, then split the document ids 80/20
# (in file order) into training and test candidates.
with open("./data/term_freqs_rev_3_all_terms.json", "r") as handle:
    temp = json.load(handle)

docs_list = list(temp)
partition = int(len(docs_list) * .8)
train_docs, test_docs = docs_list[:partition], docs_list[partition:]

# Read the reference terms: first CSV field is the document id, the remaining
# non-empty fields are that document's terms.
solution = {}
docs_list = set(docs_list)
with open("./data/pm_doc_term_counts.csv", "r") as handle:
    for line in handle:
        doc_id, *fields = line.strip("\n").split(",")
        if doc_id not in docs_list:
            continue
        # Documents whose row carries no terms are left out of `solution`
        terms = [term for term in fields if term]
        if terms:
            solution[doc_id] = terms

# Keep only documents that have reference terms; the rest cannot be scored.
train_freqs = {doc: temp[doc] for doc in train_docs if doc in solution}
test_freqs = {doc: temp[doc] for doc in test_docs if doc in solution}

In [3]:
# Load in MeSH data from a tab-separated file.
# Column 0 is used as the term id, column 1 as its display name, and column 4
# as a comma-separated list of dot-delimited positions (presumably MeSH tree
# numbers -- confirm against the file's producer).
term_names = {}
mean_term_depths = {}
with open("./data/mesh_data.tab", "r") as handle:
    for line in handle:
        line = line.strip("\n")
        if not line:
            # A blank line (e.g. at the end of the file) has no columns to
            # index; skip it instead of failing with IndexError.
            continue
        line = line.split("\t")
        term_names[line[0]] = line[1]
        # Depth of one position = number of dot-separated components;
        # store the mean over all of the term's positions.
        posits = [len(posit.split(".")) for posit in line[4].split(",")]
        mean_term_depths[line[0]] = sum(posits) / len(posits)

uids = list(term_names.keys())

In [9]:
def _precision_recall_f1(true_pos, false_pos, false_neg):
    """Return (precision, recall, F1); all three are 0 without true positives."""
    if true_pos == 0:
        return 0, 0, 0
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return precision, recall, (2 * precision * recall) / (precision + recall)


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0 for an empty list."""
    return sum(values) / len(values) if values else 0


def _doc_counts(predicted, actual):
    """Return (true_pos, false_pos, false_neg) for a single document.

    ``actual`` is walked as given, so a reference term listed twice and not
    predicted counts as two false negatives.
    """
    predicted_set = set(predicted)
    actual_set = set(actual)
    hits = sum(1 for term in predicted if term in actual_set)
    misses = sum(1 for term in actual if term not in predicted_set)
    return hits, len(predicted) - hits, misses


thresh_constant = train(train_freqs, solution)
print(f"Learned discrimination threshold constant: {thresh_constant}\n")

preds = predict(test_freqs, thresh_constant)

# Per-document counts, computed once and shared by the micro-averaged and
# example-based metrics.
doc_counts = {pmid: _doc_counts(preds[pmid], solution[pmid]) for pmid in preds}

# --- Micro-averaged: pool the counts over all documents -------------------
true_pos = sum(counts[0] for counts in doc_counts.values())
false_pos = sum(counts[1] for counts in doc_counts.values())
false_neg = sum(counts[2] for counts in doc_counts.values())

mi_precision, mi_recall, mi_f1 = _precision_recall_f1(true_pos, false_pos, false_neg)

print(f"Micro-averaged F1 from test set: {mi_f1}")
print(f"Micro-averaged precision from test set: {mi_precision}")
print(f"Micro-averaged recall from test set: {mi_recall}\n")

# --- Macro-averaged: one score per term, then average over terms ----------
# Count, per term, the documents in which it is a TP / FP / FN in a single
# pass over the documents instead of scanning every document for every UID.
term_true_pos = {}
term_false_pos = {}
term_false_neg = {}

for pmid in preds:
    predicted_set = set(preds[pmid])
    actual_set = set(solution[pmid])
    for term in predicted_set & actual_set:
        term_true_pos[term] = term_true_pos.get(term, 0) + 1
    for term in predicted_set - actual_set:
        term_false_pos[term] = term_false_pos.get(term, 0) + 1
    for term in actual_set - predicted_set:
        term_false_neg[term] = term_false_neg.get(term, 0) + 1

ma_ps = []
ma_rs = []
ma_f1s = []

# Iterate in `uids` order so the scores are accumulated in the same order as
# a per-UID scan; only terms that occur in the test set contribute.
for uid in uids:
    true_pos = term_true_pos.get(uid, 0)
    false_pos = term_false_pos.get(uid, 0)
    false_neg = term_false_neg.get(uid, 0)

    if true_pos + false_pos + false_neg > 0:
        ma_precision, ma_recall, ma_f1 = _precision_recall_f1(true_pos, false_pos, false_neg)
        ma_ps.append(ma_precision)
        ma_rs.append(ma_recall)
        ma_f1s.append(ma_f1)

ma_f1 = _mean(ma_f1s)
ma_recall = _mean(ma_rs)
ma_precision = _mean(ma_ps)

print(f"Macro-averaged F1 from test set: {ma_f1}")
print(f"Macro-averaged precision from test set: {ma_precision}")
print(f"Macro-averaged recall from test set: {ma_recall}\n")

# --- Example-based: one score per document, then average over documents ---
eb_ps = []
eb_rs = []
eb_f1s = []

for pmid in preds:
    eb_precision, eb_recall, eb_f1 = _precision_recall_f1(*doc_counts[pmid])
    eb_ps.append(eb_precision)
    eb_rs.append(eb_recall)
    eb_f1s.append(eb_f1)

eb_f1 = _mean(eb_f1s)
eb_recall = _mean(eb_rs)
eb_precision = _mean(eb_ps)

print(f"Example-based F1 from test set: {eb_f1}")
print(f"Example-based precision from test set: {eb_precision}")
print(f"Example-based recall from test set: {eb_recall}\n")

Learned discrimination threshold constant: 0.009000000000000001

Micro-averaged F1 from test set: 0.4798893772454778
Micro-averaged precision from test set: 0.5024845355903074
Micro-averaged recall from test set: 0.4592388424940755

Macro-averaged F1 from test set: 0.28809422130121976
Macro-averaged precision from test set: 0.38261781501437614
Macro-averaged recall from test set: 0.26782157999696654

Example-based F1 from test set: 0.47100622014450827
Example-based precision from test set: 0.5087594123136442
Example-based recall from test set: 0.48054287510120836



In [15]:
def predict(test_freqs, thresh_constant, solution=None, min_predictions=20):
    """Predict terms per document and, optionally, score each document.

    A term is predicted when its frequency is strictly greater than the
    document's mean term frequency plus ``thresh_constant``.

    This definition replaces the earlier two-argument ``predict`` in the
    notebook namespace. ``solution`` is therefore optional: called without it,
    the function returns only the predictions dict, so code written against
    the two-argument form keeps working after this cell has run.

    Args:
        test_freqs: Mapping of document id -> {term: frequency}.
        thresh_constant: Offset added to each document's mean term frequency.
        solution: Mapping of document id -> list of reference terms, or None
            to skip scoring.
        min_predictions: A document is included in the filtered F1 dict only
            when it has strictly more predicted terms than this.

    Returns:
        ``predictions`` when ``solution`` is None; otherwise the tuple
        ``(f1s, f1s_all, predictions)`` where ``f1s_all`` maps every document
        to its F1 and ``f1s`` holds only the documents passing the
        ``min_predictions`` filter.

    Raises:
        KeyError: If scoring is requested and a document is missing from
            ``solution``.
    """
    predictions = {}

    # Predict
    for doc, freqs in test_freqs.items():
        if freqs:
            cutoff = sum(freqs.values()) / len(freqs) + thresh_constant
            predictions[doc] = [term for term, freq in freqs.items() if freq > cutoff]
        else:
            # The mean of an empty dict is undefined; nothing can be predicted.
            predictions[doc] = []

    if solution is None:
        return predictions

    f1s = {}
    f1s_all = {}

    for pmid, predicted in predictions.items():
        actual = solution[pmid]
        actual_set = set(actual)
        predicted_set = set(predicted)

        true_pos = sum(1 for term in predicted if term in actual_set)
        false_pos = len(predicted) - true_pos
        # Walk the reference list itself so repeated reference terms are
        # counted exactly as a list scan would count them.
        false_neg = sum(1 for term in actual if term not in predicted_set)

        if true_pos == 0:
            f1 = 0
        else:
            precision = true_pos / (true_pos + false_pos)
            recall = true_pos / (true_pos + false_neg)
            f1 = (2 * precision * recall) / (precision + recall)

        if len(predicted) > min_predictions:
            f1s[pmid] = f1

        f1s_all[pmid] = f1

    return f1s, f1s_all, predictions
        
# Score the test documents with the learned threshold constant.
# f1s_all holds the F1 of every test document; f1s holds only the documents
# with more than 20 predicted terms.
f1s, f1s_all, predictions = predict(test_freqs, thresh_constant, solution)

In [16]:
# Extremes of the per-document F1 scores held in `f1s`
max_f1, min_f1 = max(f1s.values()), min(f1s.values())

# [pmid, F1] pairs from `f1s`, best score first
sorted_f1s = [[pmid, score] for pmid, score in f1s.items()]
sorted_f1s.sort(key=lambda pair: pair[1], reverse=True)

# Ten best documents, and up to fifteen with an F1 strictly between 0.44 and 0.46
max_pmids = sorted_f1s[:10]
med_pmids = [pair for pair in sorted_f1s if 0.44 < pair[1] < 0.46][:15]

# The ten worst documents are taken from `f1s_all` rather than `f1s`
sorted_f1s_all = [[pmid, score] for pmid, score in f1s_all.items()]
sorted_f1s_all.sort(key=lambda pair: pair[1], reverse=True)
min_pmids = sorted_f1s_all[-10:]

print(f"Maximum F1: {max_f1}")
print(f"Minimum F1: {min_f1}")

Maximum F1: 0.8
Minimum F1: 0.07692307692307693


In [18]:
# Report the two highest-scoring documents: predicted vs. reference term names.
# Reference names are de-duplicated for BOTH documents so the two listings are
# built the same way (dict.fromkeys drops repeats before sorting).
best_pmid = max_pmids[0][0]
max_preds = "; ".join(sorted(term_names[pred] for pred in predictions[best_pmid]))
max_sol = "; ".join(sorted(dict.fromkeys(term_names[sol] for sol in solution[best_pmid])))
print(f"F1 score for {best_pmid}: {f1s[best_pmid]}")
print(f"Predicted terms ({len(predictions[best_pmid])}) for PMID {best_pmid}: \n{max_preds}")
print(f"\nActual terms ({len(solution[best_pmid])}) applied to PMID {best_pmid}: \n{max_sol}")
print(f"\nNumber of MeSH terms applied to all of {best_pmid}'s references: {len(test_freqs[best_pmid])}\n")

second_best_pmid = max_pmids[1][0]
max_preds = "; ".join(sorted(term_names[pred] for pred in predictions[second_best_pmid]))
max_sol = "; ".join(sorted(dict.fromkeys(term_names[sol] for sol in solution[second_best_pmid])))
print(f"F1 score for {second_best_pmid}: {f1s[second_best_pmid]}")
print(f"Predicted terms ({len(predictions[second_best_pmid])}) for PMID {second_best_pmid}: \n{max_preds}")
print(f"\nActual terms ({len(solution[second_best_pmid])}) applied to PMID {second_best_pmid}: \n{max_sol}")
print(f"\nNumber of MeSH terms applied to all of {second_best_pmid}'s references: {len(test_freqs[second_best_pmid])}")

F1 score for 25686058: 0.8
Predicted terms (22) for PMID 25686058: 
Adult; Aged; Aged, 80 and over; Angiogenesis Inhibitors; Antibodies, Monoclonal, Humanized; Bevacizumab; Female; Fluorescein Angiography; Glucocorticoids; Humans; Injections; Macular Edema; Male; Middle Aged; Retinal Vein Occlusion; Retrospective Studies; Tomography, Optical Coherence; Treatment Outcome; Triamcinolone Acetonide; Vascular Endothelial Growth Factor A; Visual Acuity; Vitreous Body

Actual terms (23) applied to PMID 25686058: 
Adult; Aged; Aged, 80 and over; Angiogenesis Inhibitors; Bevacizumab; Dose-Response Relationship, Drug; Female; Fluorescein Angiography; Follow-Up Studies; Fundus Oculi; Glucocorticoids; Humans; Intravitreal Injections; Macular Edema; Male; Middle Aged; Retinal Vein Occlusion; Retrospective Studies; Time Factors; Tomography, Optical Coherence; Treatment Outcome; Triamcinolone Acetonide; Visual Acuity

Number of MeSH terms applied to all of 25686058's references: 84

F1 score for 2136

In [22]:
# Report two of the lowest-scoring documents (entries 3 and 1 of min_pmids):
# predicted vs. reference term names, with repeated reference names removed.
examining = min_pmids[3][0]
predicted_names = sorted(term_names[pred] for pred in predictions[examining])
reference_names = sorted(dict.fromkeys(term_names[sol] for sol in solution[examining]))
min_preds = "; ".join(predicted_names)
min_sol = "; ".join(reference_names)
print(f"F1 score for {examining}: {f1s_all[examining]}")
print(f"Predicted terms ({len(predictions[examining])}) for PMID {examining}: \n{min_preds}")
print(f"\nActual terms ({len(solution[examining])}) applied to PMID {examining}: \n{min_sol}")
print(f"\nNumber of MeSH terms applied to all of {examining}'s references: {len(test_freqs[examining])}\n")

also_examining = min_pmids[1][0]
predicted_names = sorted(term_names[pred] for pred in predictions[also_examining])
reference_names = sorted(dict.fromkeys(term_names[sol] for sol in solution[also_examining]))
min_preds = "; ".join(predicted_names)
min_sol = "; ".join(reference_names)
print(f"F1 score for {also_examining}: {f1s_all[also_examining]}")
print(f"Predicted terms ({len(predictions[also_examining])}) for PMID {also_examining}: \n{min_preds}")
print(f"\nActual terms ({len(solution[also_examining])}) applied to PMID {also_examining}: \n{min_sol}")
print(f"\nNumber of MeSH terms applied to all of {also_examining}'s references: {len(test_freqs[also_examining])}\n")

F1 score for 26160390: 0
Predicted terms (4) for PMID 26160390: 
Fungi; Molecular Structure; Plants; Species Specificity

Actual terms (8) applied to PMID 26160390: 
Anti-Infective Agents; Bacteria; Endophytes; Fusarium; Microbial Sensitivity Tests; Opuntia; Pyrrolidinones; Tetrahydronaphthalenes

Number of MeSH terms applied to all of 26160390's references: 65

F1 score for 25609924: 0
Predicted terms (9) for PMID 25609924: 
Animals; Antioxidants; Cholesterol; Flavonoids; Humans; Lipoproteins, LDL; Male; Phenols; Rats

Actual terms (11) applied to PMID 25609924: 
Chromatography, High Pressure Liquid; Chromatography, Reverse-Phase; Gas Chromatography-Mass Spectrometry; Hydroxymethylglutaryl-CoA Reductase Inhibitors; Hypercholesterolemia; Magnoliopsida; Phytotherapy; Plant Extracts; Plant Leaves; Plants, Medicinal; Tandem Mass Spectrometry

Number of MeSH terms applied to all of 25609924's references: 298



In [23]:
# Report two mid-range documents (entries 4 and 1 of med_pmids): predicted vs.
# reference term names. Reference names are de-duplicated for BOTH documents
# so the two listings are built the same way.
curr_pmid = med_pmids[4][0]
med_preds = "; ".join(sorted(term_names[pred] for pred in predictions[curr_pmid]))
med_sol = "; ".join(sorted(dict.fromkeys(term_names[sol] for sol in solution[curr_pmid])))
print(f"F1 score for {curr_pmid}: {f1s[curr_pmid]}")
print(f"Predicted terms ({len(predictions[curr_pmid])}) for PMID {curr_pmid}: \n{med_preds}")
print(f"\nActual terms ({len(solution[curr_pmid])}) applied to PMID {curr_pmid}: \n{med_sol}")
print(f"\nNumber of MeSH terms applied to all of {curr_pmid}'s references: {len(test_freqs[curr_pmid])}\n")

other_pmid = med_pmids[1][0]
med_preds = "; ".join(sorted(term_names[pred] for pred in predictions[other_pmid]))
med_sol = "; ".join(sorted(dict.fromkeys(term_names[sol] for sol in solution[other_pmid])))
print(f"F1 score for {other_pmid}: {f1s[other_pmid]}")
print(f"Predicted terms ({len(predictions[other_pmid])}) for PMID {other_pmid}: \n{med_preds}")
print(f"\nActual terms ({len(solution[other_pmid])}) applied to PMID {other_pmid}: \n{med_sol}")
print(f"\nNumber of MeSH terms applied to all of {other_pmid}'s references: {len(test_freqs[other_pmid])}")

F1 score for 29351962: 0.45161290322580644
Predicted terms (21) for PMID 29351962: 
Amino Acid Sequence; Bacterial Proteins; DEAD-box RNA Helicases; Endoribonucleases; Escherichia coli; Escherichia coli Proteins; Exoribonucleases; Models, Molecular; Molecular Sequence Data; Multienzyme Complexes; Nucleic Acid Conformation; Polyribonucleotide Nucleotidyltransferase; Protein Binding; Protein Structure, Tertiary; RNA; RNA Helicases; RNA Stability; RNA, Bacterial; RNA, Messenger; Saccharomyces cerevisiae; Saccharomyces cerevisiae Proteins

Actual terms (10) applied to PMID 29351962: 
Bacteria; DEAD-box RNA Helicases; Endoribonucleases; Exosomes; Mitochondria; Multienzyme Complexes; Polyribonucleotide Nucleotidyltransferase; RNA; RNA Helicases; RNA Stability

Number of MeSH terms applied to all of 29351962's references: 271

F1 score for 20170555: 0.4571428571428571
Predicted terms (21) for PMID 20170555: 
Adolescent; Adult; Aged; Antigens, Bacterial; Antitubercular Agents; BCG Vaccine; Chi