# Full pipeline evaluation

This notebook is used to evaluate the entire pipeline. It compares the predictions of the pipeline with the ground truth author and affiliation data, along with predictions made using other strategies.

Ground truth authors and affiliations were cataloged by hand using SHROOM, and are downloaded as Cocina from SDR by the `preprints:download` task (see README.md). This needs to be run prior to running this notebook.

Article plain texts are extracted from the PDFs using the `preprints:clean` task (see README.md). This also needs to be run prior to running this notebook.

In [None]:
# Resolve the project root: the parent of the directory the notebook runs from.
import os
import pathlib
import sys

PROJECT_ROOT = pathlib.Path(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir))
)
root = str(PROJECT_ROOT)  # string form of the same path

# Put scripts/ on the import path (position 1, right after the first entry)
# so the notebook helper functions can be imported.
sys.path.insert(1, os.fspath(PROJECT_ROOT / 'scripts'))
from notebook_utils import get_preprint_text, get_gold_affiliations, load_predictions

# Load both spaCy pipelines: the pretrained transformer pipeline (with its
# parser component switched off) and the project's trained text categorizer.
import spacy

ner = spacy.load("en_core_web_trf")
ner.disable_pipes("parser")

textcat = spacy.load(PROJECT_ROOT / 'training' / 'textcat' / 'model-best')


  model.load_state_dict(torch.load(filelike, map_location=device))


In [None]:
# Build the evaluation table: one row per preprint, with its plain text and its
# hand-cataloged (gold) affiliations.
import pandas as pd
preprints = pd.read_csv(PROJECT_ROOT / 'assets' / 'preprints.csv')

# Create both columns up front so they exist even if the CSV has no rows.
preprints['gold'] = ''
preprints['text'] = ''

# add the full text and gold affiliations to the data table
for i, row in preprints.iterrows():
    preprint_id = row['OpenAlex ID']
    preprint_text = get_preprint_text(preprint_id)
    preprints.at[i, 'gold'] = get_gold_affiliations(preprint_id)
    preprints.at[i, 'text'] = preprint_text

# keep only the columns we need
preprints = preprints[['OpenAlex ID', 'DRUID', 'text', 'gold']]

# Limit to rows where we have gold affiliations. The explicit copy makes the
# result an independent frame, so later cells can add columns to it without
# pandas warning about assignment to a slice.
preprints = preprints[preprints['gold'] != ''].copy()

Preprint text not found for W3091005730
Preprint text not found for W3185060415


In [6]:
import json

from utils import get_affiliation_dict, analyze_pdf_text
from tqdm.notebook import tqdm

# set this and run cell to force re-running predictions
FORCE_RERUN = True

# add a column for predictions
preprints['pred'] = ''

# NOTE(review): `results_path` is not defined in any cell visible here. It is
# presumably the directory that load_predictions() reads from -- confirm, and
# define it in the setup cell so this cell works on a fresh kernel.

# Run prediction for every preprint when there are no saved predictions, or
# when a re-run has been forced; otherwise reuse what is already saved.
predictions = load_predictions()
if not predictions or FORCE_RERUN:
    if predictions:
        # saved predictions exist, so we are only here because of FORCE_RERUN
        print("FORCE_RERUN is set, running prediction for all preprints")
    else:
        print("No predictions found, running prediction for all preprints")
    for i, row in tqdm(preprints.iterrows(), total=len(preprints), desc="Predicting"):
        preprint_id = row['OpenAlex ID']
        preprint_file = PROJECT_ROOT / "assets" / "preprints" / "txt" / f"{preprint_id}.txt"
        pdf_text = preprint_file.read_text(encoding='utf-8')
        try:
            result = analyze_pdf_text(pdf_text, textcat, ner)
            affiliations = get_affiliation_dict(result)
        except ValueError as e:
            # Analysis failed for this preprint: report it and save an empty
            # result so the preprint still gets a predictions file.
            print(f"Error analyzing {preprint_id}: {e}")
            affiliations = {}
        with (results_path / f"{preprint_id}.json").open(mode="w") as f:
            json.dump(affiliations, f)
    # re-read so `predictions` reflects the files just written
    predictions = load_predictions()
else:
    print("Using saved predictions")

# set predictions for each preprint in the data table; rows without a saved
# prediction keep the empty-string placeholder
for i, row in preprints.iterrows():
    preprint_id = row['OpenAlex ID']
    if preprint_id in predictions:
        preprints.at[i, 'pred'] = predictions[preprint_id]

No predictions found, running prediction for all preprints


Predicting:   0%|          | 0/98 [00:00<?, ?it/s]

Error analyzing W3178821884: No affiliations found in document.
Error analyzing W3116436840: No affiliations found in document.
Error analyzing W4226047880: No affiliations found in document.


In [9]:
# calculate some accuracy statistics for authors
def author_accuracy(gold, pred):
    """Return the fraction of gold authors that are also present in pred.

    An author counts as correct only when `author in pred` is true, so the
    gold and predicted names must match exactly. When gold contains no
    authors there is nothing to miss and the score is 1.
    """
    total = 0
    correct = 0
    for author in gold:
        total += 1
        if author in pred:
            correct += 1
    return correct / total if total > 0 else 1


for i, row in preprints.iterrows():
    preprints.at[i, 'authors_accuracy'] = author_accuracy(row.gold, row.pred)

# split the preprints into perfect, partial and zero scores
author_acc_mean = preprints['authors_accuracy'].mean()
author_acc_1 = preprints[preprints['authors_accuracy'] == 1]
author_acc_0 = preprints[preprints['authors_accuracy'] == 0]
author_acc_mid = preprints[(preprints['authors_accuracy'] > 0) & (preprints['authors_accuracy'] < 1)]

# get some author statistics
print("AUTHORS")
# Format the percentage with a format spec rather than round(2) * 100, which
# can print float artifacts such as 56.99999999999999%.
print("  avg accuracy\t\t", f"{author_acc_mean * 100:.1f}%")
print("  count of 100%\t\t", len(author_acc_1))
print("  count of 1-99%\t", len(author_acc_mid))
print("  count of 0%\t\t", len(author_acc_0))

AUTHORS
  avg accuracy		 49.0%
  count of 100%		 23
  count of 1-99%	 38
  count of 0%		 37


In [10]:
# calculate some accuracy statistics for affiliations
def affiliation_accuracy(gold, pred):
    """Return the fraction of gold (author, affiliation) pairs found in pred.

    A pair counts as correct only when the author is present in pred and the
    affiliation is contained in `pred[author]`, so both must match exactly.
    When gold contains no affiliations at all the score is 1.
    """
    total = 0
    correct = 0
    for author in gold:
        for affiliation in gold[author]:
            total += 1
            if author in pred and affiliation in pred[author]:
                correct += 1
    return correct / total if total > 0 else 1


for i, row in preprints.iterrows():
    preprints.at[i, 'affiliations_accuracy'] = affiliation_accuracy(row.gold, row.pred)

# split the preprints into perfect, partial and zero scores
affil_acc_mean = preprints['affiliations_accuracy'].mean()
affil_acc_1 = preprints[preprints['affiliations_accuracy'] == 1]
affil_acc_0 = preprints[preprints['affiliations_accuracy'] == 0]
affil_acc_mid = preprints[(preprints['affiliations_accuracy'] > 0) & (preprints['affiliations_accuracy'] < 1)]

# get some affiliation statistics
print("\nAFFILIATIONS")
# Format the percentage with a format spec rather than round(2) * 100, which
# can print float artifacts such as 28.999999999999996%.
print("  avg accuracy\t\t", f"{affil_acc_mean * 100:.1f}%")
print("  count of 100%\t\t", len(affil_acc_1))
print("  count of 1-99%\t", len(affil_acc_mid))
print("  count of 0%\t\t", len(affil_acc_0))



AFFILIATIONS
  avg accuracy		 5.0%
  count of 100%		 0
  count of 1-99%	 13
  count of 0%		 85
