In [1]:
import pandas as pd
import os
import re
import math
from collections import Counter

from tqdm.notebook import tqdm

In [2]:
# Load word frequency statistics for control features
# Each non-empty line of the vocab file is parsed as "<token>\t<count>".
word_freq = Counter()
# Explicit UTF-8 so the result does not depend on the platform's default encoding.
with open("../data/wikitext-2_train_vocab.txt", "r", encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file).
        if not line.strip():
            continue
        # Split on the LAST tab only and do not strip the token itself, so that
        # tokens made of (or containing) whitespace are still parsed correctly.
        token, freq = line.rstrip("\n").rsplit("\t", 1)
        word_freq[token] = int(freq)

In [48]:
# Scratch check of the contraction pattern against a non-contraction word.
# NOTE(review): `contractions_re` is defined in a later cell, so this cell fails
# with NameError on a fresh kernel unless that cell has been run first.
# NOTE(review): the recorded output below (a match on "o'") appears to come from
# an earlier version of the pattern -- confirm and re-run or delete this cell.
contractions_re.search("o'clock")

<re.Match object; span=(0, 2), match="o'">

In [61]:
# Harmonize lists of <(word, int),(word, int)> pairs
# Discards pairs where words do not match

# Runs of non-word characters glued to the end / start of a reference token.
punct_at_end_re = re.compile(r"\W+$")
punct_at_start_re = re.compile(r"^\W+")

# Finds a clitic (or "not") directly attached to a preceding character, so a
# detokenized contraction can be re-split PTB-style (e.g. "don't" -> "do n't").
contractions_re = re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T|not|NOT|'[sS]|'[mM]|'[dD])")

# Tally, keyed by reference code, of places where a run of consecutive
# mismatches reached length 5. Filled in by `harmonize_rows`.
mismatches = Counter()

# Debug aid: when a snapshot from a previous harmonization pass exists
# (`mismatches_old`, created in a separate cell), trace the rows around its
# second-most-common mismatch code (or the only code, if there is just one).
# On a fresh kernel there is no snapshot yet, so tracing is simply disabled
# instead of raising NameError.
_previous_mismatches = globals().get("mismatches_old")
if _previous_mismatches:
    biggest_mismatch_code = _previous_mismatches.most_common(2)[-1][0]
else:
    biggest_mismatch_code = None


def harmonize_rows(ref, d):
    """
    Align model tokens with reference tokens and return the matched pairs.

    Both inputs are treated as queues and are consumed IN PLACE (items are
    popped from the front), so callers must not rely on their contents
    afterwards.

    Args:
        ref: list of `(code, token, rt)` tuples (the reference / gold rows).
        d: list of `(token, surprisal)` tuples (the model rows).

    Returns:
        A list of `(model_token, surprisal, code, rt_token, rt)` tuples, one per
        position where the model token equals the (possibly punctuation-stripped)
        reference token. Unmatched tokens are discarded.

    Side effects:
        Updates the module-level `mismatches` counter, and prints trace lines
        for codes within 5 of `biggest_mismatch_code` when that is set.

    Notes:
        The main loop stops once 10 or fewer model tokens remain in `d`, so the
        tail of the input is not aligned.
    """
    result = []
    # Nothing to align; the original code raised IndexError here.
    if not ref or not d:
        return result

    curr_d = d.pop(0)
    curr_ref = ref.pop(0)

    def drop_current_punct(punct, model_token, surprisal):
        """
        Skip model tokens, starting with the current one, that spell out the
        prefix of `punct`. Returns the first `(token, surprisal)` that is kept.
        """
        # `d` / empty-token guards avoid IndexError and a runaway loop.
        while d and model_token and punct.startswith(model_token):
            punct = punct[len(model_token):]
            model_token, surprisal = d.pop(0)
        return model_token, surprisal

    def drop_upcoming_punct(punct):
        """Drop queued model tokens (after the current one) that spell out `punct`."""
        while d and d[0][0] and punct.startswith(d[0][0]):
            punct = punct[len(d[0][0]):]
            d.pop(0)

    def process_detokenized(curr_d, curr_ref):
        """
        Remove detokenized trailing and leading punctuation from around the current
        matched data--reference pair, or merge a PTB-split contraction.
        NB updates `d` in-place.

        Returns:
            updated: True if the surrounding context was updated.
            curr_d: updated curr_d pointer
            curr_ref: updated curr_ref pointer
        """
        model_token, surprisal = curr_d
        code, rt_token, rt = curr_ref

        leading = punct_at_start_re.search(rt_token)
        trailing = punct_at_end_re.search(rt_token)

        if leading or trailing:
            # Strip the punctuation from the reference token ...
            stripped = rt_token
            if trailing:
                stripped = punct_at_end_re.sub("", stripped)
            if leading:
                stripped = punct_at_start_re.sub("", stripped)
            curr_ref = (code, stripped, rt)

            # ... and drop the model tokens that correspond to it. Leading
            # punctuation is matched against the current model token onwards,
            # trailing punctuation against the tokens that follow it.
            if leading:
                curr_d = drop_current_punct(leading.group(0), model_token, surprisal)
            if trailing:
                drop_upcoming_punct(trailing.group(0))
            return True, curr_d, curr_ref

        # Process PTB detokenized contractions
        if contractions_re.search(rt_token):
            ideal_tokenized_form = tuple(contractions_re.sub(r"\1 \2", rt_token).split(" "))

            # Make sure we have the expanded form here in the model tokenization
            future_model_tokens = d[:len(ideal_tokenized_form) - 1]
            model_token_full = [model_token] + [tok for tok, _ in future_model_tokens]
            if ideal_tokenized_form == tuple(model_token_full):
                # Build a synthetic `curr_d` whose surprisal is the sum over the pieces
                curr_d = ("".join(model_token_full), surprisal + sum(surp for _, surp in future_model_tokens))
                # Pop off ALL absorbed tokens. (Previously one fewer was popped,
                # leaving the clitic in the queue as a spurious extra token.)
                del d[:len(future_model_tokens)]
            else:
                # Model tokenization does not line up: skip the current model token.
                curr_d = d.pop(0)
            return True, curr_d, curr_ref

        return False, curr_d, curr_ref

    # [length of the current run of consecutive mismatches, code where it started]
    mismatched = [0, None]
    while len(d) > 10:
        model_token, surprisal = curr_d
        code, rt_token, rt = curr_ref

        # Optional debug trace around the code picked from the previous pass.
        if (biggest_mismatch_code is not None
                and biggest_mismatch_code - 5 < code < biggest_mismatch_code + 5):
            print("\t", code if code != biggest_mismatch_code else "**" + str(code), model_token, surprisal, rt_token)

        if mismatched[0] == 5:
            mismatches[mismatched[1]] += 1

        if model_token == rt_token:
            result.append(curr_d + curr_ref)
            # Reference exhausted: everything alignable has been aligned.
            if not ref:
                break
            curr_d = d.pop(0)
            curr_ref = ref.pop(0)
            mismatched = [0, None]
            continue

        if mismatched[0] == 0:
            mismatched = [1, code]
        else:
            mismatched[0] += 1

        if "UNK" in model_token:
            # Current model token is unked: give up on this pair, but first let
            # any punctuation attached to the reference token consume its model
            # tokens (called for its side effect on `d` only).
            process_detokenized(curr_d, curr_ref)
            advance_ref = True
        else:
            updated, curr_d, curr_ref = process_detokenized(curr_d, curr_ref)
            if updated:
                # Punctuation context was modified -- run back through the loop.
                continue
            # Non-alphabetic reference tokens and end-of-line markers are
            # skipped together with the current model token; otherwise only the
            # model side advances.
            advance_ref = (not rt_token.isalpha()) or ("EOL" in rt_token)

        if advance_ref:
            if not ref:
                break
            curr_ref = ref.pop(0)
        if not d:
            break
        curr_d = d.pop(0)

    return result

In [64]:
def merge_model_results():
    """
    Harmonize every model result file with its reference file and merge the
    outcome into a single DataFrame.

    Walks `../data/model_results/<model>/<corpus>/<file>`, where each file name is
    expected to look like `<testfile>_<architecture>_<trainingdata>_<seed>.csv`,
    reads the matching reference file from `../data/human_rts/<corpus>/`, aligns
    the two with `harmonize_rows`, and collects one row per matched token.

    Returns:
        A DataFrame with columns
        ["code", "word", "surprisal", "psychometric", "corpus", "model",
         "training", "seed", "len", "freq"].

    Side effects:
        Writes the DataFrame to `../data/harmonized_results.csv`.
    """
    results_root = "../data/model_results"
    rts_root = "../data/human_rts"
    columns = ["code", "word", "surprisal", "psychometric", "corpus", "model", "training", "seed", "len", "freq"]

    def visible_entries(path):
        """List directory entries, skipping hidden files such as .DS_Store."""
        return [f for f in os.listdir(path) if not f.startswith(".")]

    final_df = []

    for m in tqdm(visible_entries(results_root), desc="Harmonizing models"):
        model_dir = os.path.join(results_root, m)
        for tc in visible_entries(model_dir):
            corpus_dir = os.path.join(model_dir, tc)

            for test_file in tqdm(visible_entries(corpus_dir), desc="Test files"):
                if test_file == "UNKS":
                    print("TODO: UNKS")
                    continue

                # Skip files whose name does not follow the expected pattern.
                # (A bare `except:` used to print the name and then carry on
                # with stale or undefined values from a previous iteration.)
                parts = test_file.split("_")
                if len(parts) < 4:
                    print(parts)
                    continue
                test_filename, model_architecture, training_data = parts[:3]
                seed = parts[3].replace(".csv", "")

                # Special handling for the Dundee corpus
                if tc == "dundee":
                    # Dundee reference files have no header and no code column;
                    # synthesize codes as <text number> * 10000 + <row index>.
                    # NOTE(review): the second column is named "surprisal" here but
                    # ends up in the "psychometric" output column -- confirm naming.
                    gold_test_filename = test_filename.replace("wrdp", "") + "_avg"
                    gold_standard = pd.read_csv(
                        os.path.join(rts_root, tc, gold_test_filename + ".txt"),
                        sep="\t",
                        names=["word", "surprisal"],
                    )
                    gold_standard.insert(0, "code", range(0, len(gold_standard)))
                    text_number = int(test_filename.replace("tx", "").replace("wrdp", ""))
                    gold_standard["code"] = gold_standard["code"] + text_number * 10000
                else:
                    gold_standard = pd.read_csv(os.path.join(rts_root, tc, test_filename + ".txt"), sep="\t")

                model_results_name = "_".join([test_filename, model_architecture, training_data, seed])
                model_results = pd.read_csv(os.path.join(corpus_dir, model_results_name + ".csv"), sep="\t")

                # PTB-de-process model results. `regex=False` makes this a literal
                # replacement regardless of the pandas version's default.
                model_results["token"] = model_results["token"].str.replace("-LRB-", "(", regex=False)
                model_results["token"] = model_results["token"].str.replace("-RRB-", ")", regex=False)

                # TODO: EOL handling

                # Columns 2 and 3 of the model file are taken as (token, surprisal).
                # NOTE(review): positional -- confirm against the model output format.
                model_rows = [tuple(x)[2:4] for x in model_results.values.tolist()]
                gold_rows = [tuple(x) for x in gold_standard.values.tolist()]

                harmonized_results = harmonize_rows(gold_rows, model_rows)

                # Each harmonized row is (token, surprisal, code, rt_token, rt).
                final_df.extend(
                    (
                        code, token, surprisal, rt, tc, model_architecture, training_data, seed,
                        len(token), math.log(word_freq[token] + 1),
                    )
                    for token, surprisal, code, _rt_token, rt in harmonized_results
                )

    # Passing `columns=` to the constructor also works when no rows were
    # collected (assigning `df.columns` afterwards raises in that case).
    df = pd.DataFrame(final_df, columns=columns)
    df.to_csv("../data/harmonized_results.csv")
    return df

# Run the full harmonization over every model/corpus directory. As a side
# effect this writes ../data/harmonized_results.csv (see merge_model_results).
df = merge_model_results()

HBox(children=(FloatProgress(value=0.0, description='Harmonizing models', max=4.0, style=ProgressStyle(descrip…

HBox(children=(FloatProgress(value=0.0, description='Test files', max=80.0, style=ProgressStyle(description_wi…

	 10815 are 4.989645686576483 are
	 10816 managed 12.16162562596682 managed
	 10817 well 11.032294735942669 well.
	 10817 well 11.032294735942669 well
	 10818 The 2.6143962363012534 The
	 **10819 itineraries 23.277708461407077 itineraries
	 10820 are 4.409351497001103 are
	 10821 staggered 18.280159535978814 staggered,
	 10821 staggered 18.280159535978814 staggered
	 10822 so 8.724991273728696 so
	 10823 that 3.6133377111871794 that
	 10815 are 5.990439327539865 are
	 10816 managed 13.062634923593324 managed
	 10817 well 13.52512985985983 well.
	 10817 well 13.52512985985983 well
	 10818 The 2.6239870675945296 The
	 **10819 itineraries 24.2958996791218 itineraries
	 10820 are 8.780729047411162 are
	 10821 staggered 18.263566571897076 staggered,
	 10821 staggered 18.263566571897076 staggered
	 10822 so 8.532761293339076 so
	 10823 that 3.660728319538267 that
	 10815 are 8.709829815971526 are
	 10816 managed 14.409655363666534 managed
	 10817 well 13.032860778582856 well.
	 10817 well 13

HBox(children=(FloatProgress(value=0.0, description='Test files', max=4.0, style=ProgressStyle(description_wid…

	 10815 for 8.085041562034883 for
	 10816 him 8.533129478492526 him.
	 10816 him 8.533129478492526 him
	 10817 Through 11.345982538323064 Through
	 10818 the 3.025154724181375 the
	 **10819 door 11.6839989265989 door
	 10820 at 7.198121859863185 at
	 10821 this 6.315596459367128 this
	 10822 moment 7.852679586889052 moment,
	 10822 moment 7.852679586889052 moment
	 10823 the 5.780729681919688 the
	 10815 for 4.975892782868375 for
	 10816 him 9.317530544477405 him.
	 10816 him 9.317530544477405 him
	 10817 Through 12.189891261407421 Through
	 10818 the 7.309594791161437 the
	 **10819 door 11.994864680055725 door
	 10820 at 6.857649617283381 at
	 10821 this 6.5426875405254314 this
	 10822 moment 16.779938316450767 moment,
	 10822 moment 16.779938316450767 moment
	 10823 the 4.603129249070879 the
	 10815 for 6.704265053147996 for
	 10816 him 8.691044339315543 him.
	 10816 him 8.691044339315543 him
	 10817 Through 11.320327199529212 Through
	 10818 the 2.5000317903162967 the
	 **10819 door

HBox(children=(FloatProgress(value=0.0, description='Test files', max=180.0, style=ProgressStyle(description_w…

	 10815 are 1.31036 are
	 10816 managed 8.58905 managed
	 10817 well 6.432980000000001 well.
	 10817 well 6.432980000000001 well
	 10818 The 1.94433 The
	 **10819 itineraries 18.6649 itineraries
	 10820 are 1.9314099999999998 are
	 10821 staggered 11.1542 staggered,
	 10821 staggered 11.1542 staggered
	 10822 so 5.73721 so
	 10823 that 1.41741 that
	 10815 are 1.46861 are
	 10816 managed 9.03858 managed
	 10817 well 7.57768 well.
	 10817 well 7.57768 well
	 10818 The 1.9422099999999998 The
	 **10819 itineraries 14.9307 itineraries
	 10820 are 10.2443 are
	 10821 staggered 12.9162 staggered,
	 10821 staggered 12.9162 staggered
	 10822 so 5.813359999999999 so
	 10823 that 0.755184 that
	 10815 are 1.32726 are
	 10816 managed 11.5906 managed
	 10817 well 7.99488 well.
	 10817 well 7.99488 well
	 10818 The 1.8960400000000002 The
	 **10819 UNK-LC-s 6.76508 itineraries
	 10820 are 1.58479 are
	 10821 staggered 12.0464 staggered,
	 10821 staggered 12.0464 staggered
	 10822 so 7.19947 so
	 108

HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…

	 10815 for 4.24944 for
	 10816 him 5.12664 him.
	 10816 him 5.12664 him
	 10817 Through 8.620619999999999 Through
	 10818 the 2.20007 the
	 **10819 door 9.53625 door
	 10820 at 6.51611 at
	 10821 this 5.92406 this
	 10822 moment 6.7252 moment,
	 10822 moment 6.7252 moment
	 10823 the 1.35589 the
	 10815 for 4.46002 for
	 10816 him 4.67466 him.
	 10816 him 4.67466 him
	 10817 Through 8.688989999999999 Through
	 10818 the 2.08802 the
	 **10819 door 8.8551 door
	 10820 at 8.112960000000001 at
	 10821 this 5.280130000000001 this
	 10822 moment 7.1075 moment,
	 10822 moment 7.1075 moment
	 10823 the 1.4933299999999998 the
	 10815 for 4.5141599999999995 for
	 10816 him 3.7459800000000003 him.
	 10816 him 3.7459800000000003 him
	 10817 Through 8.1828 Through
	 10818 the 1.2917299999999998 the
	 **10819 door 10.1763 door
	 10820 at 7.28285 at
	 10821 this 4.40284 this
	 10822 moment 6.05649 moment,
	 10822 moment 6.05649 moment
	 10823 the 1.83495 the
	 10815 for 4.26962 for
	 10816 him 4.200

HBox(children=(FloatProgress(value=0.0, description='Test files', max=180.0, style=ProgressStyle(description_w…

	 10815 are 2.430396 are
	 10816 managed 8.221906 managed
	 10817 well 6.750433 well.
	 10817 well 6.750433 well
	 10818 The 0.0 The
	 **10819 itineraries 14.887201999999998 itineraries
	 10820 are 3.285192 are
	 10821 staggered 11.897801 staggered,
	 10821 staggered 11.897801 staggered
	 10822 so 4.4835199999999995 so
	 10823 that 2.533829 that
	 10815 are 1.8869189999999998 are
	 10816 managed 7.347428 managed
	 10817 well 9.477857 well.
	 10817 well 9.477857 well
	 10818 The 0.0 The
	 **10819 itineraries 14.348993 itineraries
	 10820 are 2.9902919999999997 are
	 10821 staggered 11.189746000000001 staggered,
	 10821 staggered 11.189746000000001 staggered
	 10822 so 5.179259 so
	 10823 that 2.5191689999999998 that
	 10815 are 3.4603230000000003 are
	 10816 managed 10.292397 managed
	 10817 well 9.782263 well.
	 10817 well 9.782263 well
	 10818 The 0.0 The
	 **10819 UNK-LC-s 5.599352 itineraries
	 10820 are 4.107126999999999 are
	 10821 staggered 10.846631 staggered,
	 10821 staggered 

HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…

	 10815 for 4.414386 for
	 10816 him 4.665321 him.
	 10816 him 4.665321 him
	 10817 Through 0.0 Through
	 10818 the 1.082107 the
	 **10819 door 8.885702 door
	 10820 at 4.449256 at
	 10821 this 4.5606290000000005 this
	 10822 moment 4.772809 moment,
	 10822 moment 4.772809 moment
	 10823 the 2.5990029999999997 the
	 10815 for 3.016497 for
	 10816 him 5.83325 him.
	 10816 him 5.83325 him
	 10817 Through 0.0 Through
	 10818 the 1.015254 the
	 **10819 door 6.158196 door
	 10820 at 5.302896 at
	 10821 this 4.2847599999999995 this
	 10822 moment 4.946548 moment,
	 10822 moment 4.946548 moment
	 10823 the 2.513059 the
	 10815 for 4.800268 for
	 10816 him 5.096895 him.
	 10816 him 5.096895 him
	 10817 Through 0.0 Through
	 10818 the 1.170036 the
	 **10819 door 7.426614 door
	 10820 at 4.4859599999999995 at
	 10821 this 4.647754 this
	 10822 moment 4.959256 moment,
	 10822 moment 4.959256 moment
	 10823 the 2.6752740000000004 the
	 10815 for 3.339613 for
	 10816 him 4.527211 him.
	 10816 him 4

HBox(children=(FloatProgress(value=0.0, description='Test files', max=140.0, style=ProgressStyle(description_w…

	 10815 are 1.3525383472442627 are
	 10816 managed 12.693967819213867 managed
	 10817 well 8.556056022644043 well.
	 10817 well 8.556056022644043 well
	 10818 The 0.0 The
	 **10819 itineraries 22.67350959777832 itineraries
	 10820 are 1.99642550945282 are
	 10821 staggered 16.241657257080078 staggered,
	 10821 staggered 16.241657257080078 staggered
	 10822 so 5.404982089996338 so
	 10823 that 3.235528945922852 that
	 10815 are 2.1147537231445312 are
	 10816 managed 13.210108757019045 managed
	 10817 well 3.252722978591919 well.
	 10817 well 3.252722978591919 well
	 10818 The 0.0 The
	 **10819 itineraries 25.73154878616333 itineraries
	 10820 are 3.0941650867462163 are
	 10821 staggered 11.782011032104492 staggered,
	 10821 staggered 11.782011032104492 staggered
	 10822 so 3.051880359649658 so
	 10823 that 2.767695665359497 that
	 10815 are 2.1234374046325684 are
	 10816 managed 13.389694213867188 managed
	 10817 well 14.573084831237793 well.
	 10817 well 14.573084831237793 well
	 10818

HBox(children=(FloatProgress(value=0.0, description='Test files', max=7.0, style=ProgressStyle(description_wid…

	 10815 for 8.735595703125 for
	 10816 him 1.119791865348816 him.
	 10816 him 1.119791865348816 him
	 10817 Through 0.0 Through
	 10818 the 1.6268385648727417 the
	 **10819 door 16.033885955810547 door
	 10820 at 7.239103317260742 at
	 10821 this 6.783059120178223 this
	 10822 moment 2.4769375324249268 moment,
	 10822 moment 2.4769375324249268 moment
	 10823 the 2.1590681076049805 the
	 10815 for 1.9955505132675169 for
	 10816 him 2.22986364364624 him.
	 10816 him 2.22986364364624 him
	 10817 Through 0.0 Through
	 10818 the 2.1855244636535645 the
	 **10819 door 10.929271697998047 door
	 10820 at 6.911056041717528 at
	 10821 this 6.597612857818604 this
	 10822 moment 6.033666133880615 moment,
	 10822 moment 6.033666133880615 moment
	 10823 the 2.6941044330596924 the
	 10815 for 3.7969083786010738 for
	 10816 him 5.870589256286621 him.
	 10816 him 5.870589256286621 him
	 10817 Through 0.0 Through
	 10818 the 2.443039894104004 the
	 **10819 door 11.040163040161133 door
	 10820 at 6.780903

In [63]:
# Snapshot the mismatch tally produced by the harmonization run above.
# NOTE(review): the cell that defines `harmonize_rows` reads `mismatches_old` to
# choose which code to trace, and also resets `mismatches`, so this cell has to be
# run after one harmonization pass and before that cell is re-run. This ordering
# is hidden kernel state; it does not survive Restart & Run All as laid out here.
from copy import copy
mismatches_old = copy(mismatches)  # shallow copy, detached from later resets of `mismatches`
# Display (code, count) pairs, most frequent first.
mismatches_old.most_common()

[(60394, 20),
 (20515, 8),
 (20516, 6),
 (10878, 5),
 (10879, 5),
 (10880, 5),
 (10881, 5),
 (10882, 5)]