In [1]:
from pathlib import Path
from multiprocessing import Pool
from tqdm import tqdm
import pandas as pd
import torch
from scipy.special import softmax
from collections import defaultdict
import numpy as np
import pickle

In [2]:
# Root directory whose sub-directories hold the Gutenberg/HathiTrust alignment
# outputs (presumably one sub-directory per book — see the glob below).
book_paths = Path('/home/allekim/stonybook-data') / 'guten_hathi_alignment'

In [3]:
# Every batched detection-results file, one directory level below `book_paths`.
# Sorted so the task list (and therefore processing/progress order) does not
# depend on the filesystem's arbitrary directory-listing order.
parsed_book_results_paths = sorted(book_paths.glob('*/*_hathi_batched_results.dt'))

In [4]:
# Number of books with a batched detection-results file to score.
len(parsed_book_results_paths)

2975

In [5]:
def map_subtok_to_tok_idx(offsets):
    """Map each kept sub-token to the index of the token (word) it belongs to.

    Pairs whose end offset is 0 (e.g. ``(0, 0)``) are dropped from the output.
    A kept pair whose start offset is 1 starts a new token; every other kept
    pair continues the current token.

    Args:
        offsets: Iterable of ``(start, end)`` pairs for one sequence, such as a
            tokenizer ``offset_mapping`` row (list of pairs or 2-D array).

    Returns:
        List with one token index per kept pair. Pairs seen before the first
        start-offset-1 pair are assigned index -1.
    """
    word_ids = []
    word = -1
    for start, end in offsets:
        if end != 0:
            if start == 1:
                word += 1
            word_ids.append(word)
    return word_ids

In [6]:
def score_book(result_path, thresholds=(0.8, 0.9, 0.95, 0.99)):
    """Score one book's batched detector output against alignment-derived sentences.

    Reads, from ``result_path.parent``:
      * ``result_path`` itself (per-sentence, per-sub-token logits),
      * ``<hathi_id>_hathi_input.dt`` (per-sentence dicts with ``attention_mask``),
      * ``aligned_toks.csv`` (column ``sent_idx2`` holding ``(x, y)`` pairs).

    A sentence counts as "caught" at a threshold when any of its scored
    sub-tokens has class-1 probability (after softmax over the last axis) above
    that threshold. The first and last attended positions of every sentence are
    excluded from scoring.

    The summary is pickled to ``ocr_batched_detection_results.pkl`` next to
    ``result_path``. If that file already exists, the stored summary is loaded
    and returned without recomputing.

    Args:
        result_path: ``pathlib.Path`` to a ``*_hathi_batched_results.dt`` file
            whose stem starts with ``<hathi_id>_``.
        thresholds: Class-1 probability cut-offs to evaluate.

    Returns:
        ``(num_sentences, num_bad_sentences, threshold_results)`` where
        ``threshold_results[thres] == (num_caught, num_caught_that_are_bad)``.
    """
    output_path = result_path.parent / 'ocr_batched_detection_results.pkl'
    if output_path.exists():
        # Already scored on a previous run. The file was written by this function
        # (atomically, below), so unpickling it is trusted.
        with open(output_path, 'rb') as f:
            return pickle.load(f)

    hathi_id = result_path.stem.split('_')[0]
    input_path = result_path.parent / '{}_hathi_input.dt'.format(hathi_id)
    # NOTE(review): `eval` executes arbitrary text from the CSV. If `sent_idx2`
    # only ever holds tuple literals, `ast.literal_eval` would be a safer converter.
    aligned_df = pd.read_csv(result_path.parent / 'aligned_toks.csv', converters={'sent_idx2': eval})
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
    # reject non-tensor pickles in these files — confirm against the installed torch.
    inputs = torch.load(input_path)
    logits = torch.load(result_path)

    # Sentence indices taken from alignment pairs whose indices differ by at most 1;
    # pairs spanning a wider range are ignored.
    bad_sents = set()
    for first, last in aligned_df['sent_idx2']:
        if last - first <= 1:
            bad_sents.update((first, last))

    # np.asarray lets a CPU tensor be passed as well as an ndarray / nested list.
    probs = softmax(np.asarray(logits), axis=-1)

    caught = defaultdict(set)
    for idx, sent_input in enumerate(inputs):
        mask = np.where(np.array(sent_input['attention_mask']) == 1)
        # Keep attended positions, then drop the first and last of them
        # (presumably the special start/end tokens).
        class1_probs = probs[idx][mask][1:-1, 1]
        for thres in thresholds:
            if np.any(class1_probs > thres):
                caught[thres].add(idx)

    threshold_results = {
        thres: (len(caught[thres]), len(caught[thres] & bad_sents))
        for thres in thresholds
    }
    summary = (len(inputs), len(bad_sents), threshold_results)

    # Write to a temp file and rename into place so an interrupted run never
    # leaves a truncated pickle that the exists() check above would treat as done.
    tmp_path = output_path.with_name(output_path.name + '.tmp')
    with open(tmp_path, 'wb') as f:
        pickle.dump(summary, f, 4)
    tmp_path.replace(output_path)
    return summary

In [7]:
N_WORKERS = 30  # number of worker processes scoring books in parallel

# Bind the per-book return values of score_book (completion order, not input
# order, because of imap_unordered) so later cells can aggregate them.
with Pool(N_WORKERS) as p:
    book_scores = list(tqdm(p.imap_unordered(score_book, parsed_book_results_paths),
                            total=len(parsed_book_results_paths)))

100%|██████████| 2975/2975 [1:28:55<00:00,  1.79s/it]  
