# match the 2 data sources

This notebook takes the tfds `wikihow` dataset and matches it against our manually pre-processed data, which is used for ROUGE evaluation.

In [1]:
from pegasus.data import all_datasets
from pathlib import Path

In [3]:
# tfds dataset and split to compare against
# (the split could also be "train" or "validation")
input_pattern = "tfds:wikihow/all"
split = "test"

# shuffling is disabled: the comparison further down relies on record order
data = all_datasets.get_dataset(f"{input_pattern}-{split}", shuffle_files=False)

In [28]:
# Load our pre-processed data so that it can be compared with the tfds records
data_dir = Path("./data")  # not named `dir`, which would shadow the builtin
dir_src = data_dir / "test.source"
dir_tgt = data_dir / "test.target"
# One record per line. The files are iterated rather than read with
# f.read().splitlines(): splitlines() also breaks on characters such as \u2028
# or \x0c that may occur inside a record, whereas file iteration only breaks on
# real line endings. rstrip() removes the line ending together with any other
# trailing whitespace.
# The encoding is explicit because the tfds side is decoded as UTF-8, while the
# platform default used by open() is not guaranteed to be UTF-8.
with open(dir_src, 'rt', encoding='utf-8') as f: src = [l.rstrip() for l in f]
with open(dir_tgt, 'rt', encoding='utf-8') as f: tgt = [l.rstrip() for l in f]
assert len(src) == len(tgt), f"must have the same number of records, but got src={len(src)}, tgt={len(tgt)}"


In [29]:
import difflib
def str_compare(a, b):
    """
    Compare two strings character by character and report every difference.

    For each character that has to be deleted from / added to `a` in order to
    obtain `b`, a line is printed with the character, its position in `a` and
    up to 10 characters of `a` on either side as context. A length mismatch is
    reported as well.

    Args:
        a: reference string
        b: string to compare against `a`

    Returns:
        True if the strings match, False otherwise

    adapted from https://stackoverflow.com/a/17904977/9201239
    """

    match = True
    if len(a) != len(b):
        print(f"length mismatch: a={len(a)}, b={len(b)}")

    def context(s, i):
        # slice ends are exclusive, so clamping to len(s) keeps the last char
        start = max(i - 10, 0)
        end = min(i + 10, len(s))
        return s[start:end]

    # Index of the next unconsumed character of `a`. The index into the diff
    # stream can't be used as a position: every '+' entry advances the stream
    # without advancing `a`, so the two drift apart after the first insertion.
    pos_a = 0
    for s in difflib.ndiff(a, b):
        op = s[0]
        if op == ' ':
            pos_a += 1
        elif op == '-':
            match = False
            print(f'Delete "{s[-1]}" from position {pos_a}, ctx=[{context(a, pos_a)}]')
            pos_a += 1
        elif op == '+':
            match = False
            print(f'Add "{s[-1]}" to position {pos_a}, ctx=[{context(a, pos_a)}]')

    return match

In [48]:
# Indices of the records in our dataset (src/tgt) that are compared, in order,
# with the first records read from tfds.
#
# The records were located with grep, which reports 1-based line numbers,
# while Python lists are 0-based - hence the "- 1".
#
# an earlier selection was: [4909-1, 3164-1, 2920-1, 4541-1, 1781-1]
grep_line_numbers = (1045, 3090, 2634, 4169)
idmap = [line_no - 1 for line_no in grep_line_numbers]

# Odd cases (from the earlier selection):
#
# id 2920-1
#   test_articles/HowtoInstallKasperskyAntivirusonaSmartphone.txt
#   The newlines are inconsistent in tfds (some extra newlines that shouldn't
#   be there). It looks like this happens when the summary line has no period:
#   2 newlines get injected there.
#
# id 4541-1
#   test_articles/HowtoRefuseAnnoyingSalespeople3.txt
#   tfds capitalizes quoted text: Tell them “No calls, emails only.”
#   whereas ours is:              Tell them “no calls, emails only.”

In [49]:
def preview(s, limit=200):
    """
    Return the beginning of `s` for compact display in mismatch reports.

    Args:
        s: the string to shorten
        limit: maximum number of characters to keep (default: 200)

    Returns:
        `s` unchanged if it has at most `limit` characters, otherwise its
        first `limit` characters
    """
    # slicing already clamps to the string length, no bounds checks are needed
    return s[:limit]

# tfds is read sequentially: the n-th record read from tfds is compared with
# record idmap[n] of our files.
data_iter = iter(data)
# the loop variable is not called `id`, as that would shadow the builtin id()
# for the rest of the kernel session
for rec_id in idmap:
    d = next(data_iter)
    id_ok = True

    tf_src = d['inputs'].numpy().decode()
    # in our files newlines inside a record are encoded as "<n>"
    hf_src = src[rec_id].replace("<n>", "\n")
    if not str_compare(tf_src, hf_src):
        id_ok = False
        print(f"\nmismatching src texts\ntf=[{preview(tf_src)}]\nhf=[{preview(hf_src)}]")

    tf_tgt = d['targets'].numpy().decode()
    hf_tgt = tgt[rec_id].replace("<n>", "\n")
    if not str_compare(tf_tgt, hf_tgt):
        id_ok = False
        print(f"\nmismatching tgt texts\ntf=[{preview(tf_tgt)}]\nhf=[{preview(hf_tgt)}]")

    # only records whose source and target both match get a check mark
    if id_ok:
        print(f"✓ {rec_id}")

✓ 1044
✓ 3089
✓ 2633
✓ 4168
