In [1]:
import pandas as pd
import os
import re
import math
from collections import Counter

from tqdm.notebook import tqdm

In [2]:
# Load word frequency statistics for control features.
# Each non-blank line of the vocab file is expected to be "<token>\t<count>".
word_freq = Counter()
# An explicit encoding keeps the result independent of the platform default
# (assumes the vocab file is UTF-8 -- TODO confirm).
with open("../data/wikitext-2_train_vocab.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing newline at the end of the file)
            # would otherwise break the two-value unpacking below.
            continue
        token, freq = line.split("\t")
        word_freq[token] = int(freq)

In [149]:
# Harmonize lists of <(word, int),(word, int)> pairs
# Discards pairs where words do not match

# Runs of non-word characters at the very end / very start of a token.
punct_at_end_re = re.compile(r"\W+$")
punct_at_start_re = re.compile(r"^\W+")

# A character that is neither an apostrophe nor a space (group 1) followed by
# one of the listed suffixes (group 2); substituting r"\1 \2" inserts a space
# between them.
# NOTE(review): the bare `not` alternative also matches inside ordinary words
# (e.g. "another"), and the trailing bare `'` matches any apostrophe.
contractions_re = re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T|not|NOT|'[sS]|'[mM]|'[dD]|')")

# Filled in by `harmonize_rows`: code -> number of times a run of mismatches
# that started at that code reached five consecutive iterations.
mismatches = Counter()

# `mismatches_old` is a snapshot saved by another cell after a previous run.
# On a fresh kernel it does not exist yet (and it may be empty), so look it up
# defensively instead of failing with NameError / IndexError.
_previous_mismatches = globals().get("mismatches_old")
if _previous_mismatches:
    # Code with the highest mismatch count in the previous run; `harmonize_rows`
    # prints the tokens within +/- 5 codes of it for inspection.
    biggest_mismatch_code = _previous_mismatches.most_common(1)[0][0]
else:
    # No previous run to inspect: -inf keeps the `code > x - 5 and code < x + 5`
    # window in `harmonize_rows` permanently closed, so nothing is printed.
    biggest_mismatch_code = float("-inf")

def _pop_front_or_none(items):
    """Remove and return the first element of ``items``, or ``None`` if it is empty."""
    return items.pop(0) if items else None


def harmonize_rows(ref, d):
    """Align model tokens with reading-time tokens and keep the pairs that match.

    The two sequences are walked in parallel. When the current tokens differ,
    a chain of heuristics (UNK tokens, trailing/leading punctuation on the
    reading-time token, split contractions, non-alphabetic tokens, "EOL"
    markers) decides which side to advance or rewrite before re-checking.

    Parameters
    ----------
    ref : list of tuple
        Reading-time rows, each ``(code, rt_token, rt)``.
    d : list of tuple
        Model rows, each ``(model_token, surprisal)``.

    Returns
    -------
    list of tuple
        One ``(model_token, surprisal, code, rt_token, rt)`` tuple per matched
        pair. ``rt_token`` is the token after any punctuation stripping, and
        for a re-joined contraction ``surprisal`` is the sum over its pieces.

    Notes
    -----
    * Processing stops once 10 or fewer model rows remain unread, or when the
      reading-time rows are exhausted, so trailing tokens are never compared.
    * The inputs are copied; the caller's lists are not modified.
    * Side effect: when a run of mismatching iterations reaches five, the code
      at which the run started is counted in the module-level ``mismatches``.
    * Debug aid: if a module-level ``biggest_mismatch_code`` exists, rows whose
      code lies strictly within 5 of it are printed.
    """
    # Work on private copies so the caller's lists are left intact.
    ref = list(ref)
    d = list(d)

    result = []
    if not ref or not d:
        return result

    # The debug window is optional: the global may be missing on a fresh kernel.
    focus_code = globals().get("biggest_mismatch_code")

    curr_d = d.pop(0)
    curr_ref = ref.pop(0)

    # [length of the current mismatch run, code at which the run started]
    mismatched = [0, None]
    # `curr_ref` becomes None once the reading-time rows run out.
    while curr_ref is not None and len(d) > 10:
        model_token, surprisal = curr_d
        code, rt_token, rt = curr_ref

        if focus_code is not None and focus_code - 5 < code < focus_code + 5:
            print("\t", code if code != focus_code else "**" + str(code), model_token, surprisal, rt_token)

        if mismatched[0] == 5:
            mismatches[mismatched[1]] += 1

        if model_token == rt_token:
            result.append(curr_d + curr_ref)
            curr_d = d.pop(0)
            curr_ref = _pop_front_or_none(ref)
            mismatched = [0, None]
            continue

        if mismatched[0] == 0:
            mismatched = [1, code]
        else:
            mismatched[0] += 1

        # If current token is unked, then pop both
        if "UNK" in model_token:
            curr_d = d.pop(0)
            curr_ref = _pop_front_or_none(ref)
        # If current ref has trailing punctuation, remove and re-check
        elif punct_at_end_re.search(rt_token):
            curr_ref = (code, punct_at_end_re.sub("", rt_token), rt)

            # If next model token(s) are that punctuation, drop those tokens
            match = punct_at_end_re.findall(rt_token)[0]
            while d and match.startswith(d[0][0]):
                match = match[len(d[0][0]):]
                d.pop(0)
        # If current ref has leading punctuation, remove and re-check
        elif punct_at_start_re.search(rt_token):
            curr_ref = (code, punct_at_start_re.sub("", rt_token), rt)

            # If current model token(s) are that punctuation, drop those tokens
            match = punct_at_start_re.findall(rt_token)[0]
            while d and match.startswith(model_token):
                match = match[len(model_token):]
                model_token, surprisal = d.pop(0)

            curr_d = (model_token, surprisal)
        # de-tokenize PTB splits of contractions
        elif contractions_re.search(rt_token):
            ideal_tokenized_form = tuple(contractions_re.sub(r"\1 \2", rt_token).split(" "))

            # Make sure we have the expanded form here in the model tokenization
            future_model_tokens = d[:len(ideal_tokenized_form) - 1]
            model_token_full = [model_token] + [tok for tok, _ in future_model_tokens]
            if ideal_tokenized_form == tuple(model_token_full):
                # Build a synthetic `curr_d` by joining the pieces and adding
                # their surprisals; the consumed pieces are dropped from `d`.
                curr_d = ("".join(model_token_full), surprisal + sum(surp for _, surp in future_model_tokens))
                d = d[len(future_model_tokens):]
            else:
                curr_d = d.pop(0)
        # If current ref is not purely alphabetic, pop both
        elif not rt_token.isalpha():
            curr_ref = _pop_front_or_none(ref)
            curr_d = d.pop(0)
        # If the current word is the end of a line
        elif "EOL" in rt_token:
            curr_ref = _pop_front_or_none(ref)
            curr_d = d.pop(0)
        else:
            curr_d = d.pop(0)

    return result

In [148]:
def merge_model_results():
    
    final_df = []
    
    models = [f for f in os.listdir("../data/model_results") if not f.startswith(".")]
    for m in tqdm(models, desc="Harmonizing models"):
        tqdm.write("Harmonizing results for " + m)
        test_corpus = [f for f in os.listdir("../data/model_results/" + m) if not f.startswith(".")]
        for tc in test_corpus:
            test_files = [f for f in os.listdir("../data/model_results/" + m + "/" + tc) if not f.startswith(".")]
            
            for tf in tqdm(test_files, desc="Test files"):
                if tf == "UNKS":
                    print("TODO: UNKS")
                    continue
                
                try:
                    tf = tf.split("_")
                    test_filename = tf[0]
                    model_architecture = tf[1]
                    training_data = tf[2]
                    seed = tf[3].replace(".csv", "")
                except:
                    print(tf)
                    
                if tc == "dundee":
                    continue
                
                # Special handling for the Dundee corpus
                if tc == "dundee":
                    gold_test_filename = test_filename.replace("wrdp", "") + "_avg"
                    gold_standard = pd.read_csv("../data/human_rts/" + tc + "/" + gold_test_filename + ".txt", sep="\t", names=["word", "surprisal"])
                    gold_standard.insert(0, 'code', range(0,len(gold_standard)))
                    gold_standard["code"] = gold_standard["code"] + int(test_filename.replace("tx", "").replace("wrdp", "")) * 10000
                else:
                    gold_standard = pd.read_csv("../data/human_rts/" + tc + "/" + test_filename + ".txt", sep="\t")
                                    
                model_results = "_".join([test_filename, model_architecture, training_data, seed])
                model_path = "/".join(["../data/model_results", m, tc, model_results])
                model_results = pd.read_csv(model_path+".csv", sep="\t")
                
                # PTB-de-process model results
                model_results.token = model_results.token.str.replace("-LRB-", "(")
                model_results.token = model_results.token.str.replace("-RRB-", ")")
            
                # TODO: EOL Handleing
                
                model_results = [tuple(x)[2:4] for x in model_results.values.tolist()]
                gold_standard = [tuple(x) for x in gold_standard.values.tolist()]
                
                harmonized_results = harmonize_rows(gold_standard, model_results)
                
                result = [tuple((x[2], x[0], x[1], x[4], tc, model_architecture, training_data, seed, len(x[0]), math.log(word_freq[x[0]]+1))) for x in harmonized_results]
                final_df.extend(result)
                
    df = pd.DataFrame(final_df)
    df.columns = ["code", "word", "surprisal", "psychometric", "corpus", "model", "training", "seed", "len", "freq"]
    df.head()
    df.to_csv("../data/harmonized_results.csv")
    return df

# Build the harmonized table for all models; the function also writes it to
# ../data/harmonized_results.csv.
df = merge_model_results()

HBox(children=(FloatProgress(value=0.0, description='Harmonizing models', max=4.0, style=ProgressStyle(descrip…

Harmonizing results for 5gram


HBox(children=(FloatProgress(value=0.0, description='Test files', max=80.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=4.0, style=ProgressStyle(description_wid…

	 50502 make 7.891324061672032 make
	 50503 my 11.031233966474128 my
	 50504 hair 9.035856975741627 hair
	 50505 like 11.489323788952227 like
	 **50506 Elvis 18.36795988068078 Elvis's?'
	 **50506 Elvis 18.36795988068078 Elvis's
	 50507 's 8.84148195957268 I
	 50507 ? 14.337447935463047 I
	 50507 ' 4.130624192383613 I
	 50507 I 10.713245580325928 I
	 50508 asked 8.177121540433163 asked.
	 50508 asked 8.177121540433163 asked
	 50509 The 2.6239870675945296 The
	 50510 barber 24.2958996791218 barber,
	 50510 barber 24.2958996791218 barber
	 50502 make 9.068704210480652 make
	 50503 my 12.934874773941024 my
	 50504 hair 9.483685675121167 hair
	 50505 like 11.226122102907862 like
	 **50506 Elvis 20.85481548613818 Elvis's?'
	 **50506 Elvis 20.85481548613818 Elvis's
	 50507 's 8.35127796323083 I
	 50507 ? 14.868813485445399 I
	 50507 ' 5.101830475959195 I
	 50507 I 12.910367591482274 I
	 50508 asked 9.389981133860633 asked.
	 50508 asked 9.389981133860633 asked
	 50509 The 2.5940555494832624 T

HBox(children=(FloatProgress(value=0.0, description='Test files', max=180.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…

	 50502 make 2.53153 make
	 50503 my 5.08552 my
	 50504 hair 7.62017 hair
	 50505 like 5.27123 like
	 **50506 Elvis 10.7144 Elvis's?'
	 **50506 Elvis 10.7144 Elvis's
	 50507 's 11.8409 I
	 50507 ? 2.91505 I
	 50507 ' 6.04245 I
	 50507 I 7.416860000000001 I
	 50508 asked 6.267119999999999 asked.
	 50508 asked 6.267119999999999 asked
	 50509 The 1.8348099999999998 The
	 50510 UNK-LC-er 7.82015 barber,
	 50502 make 2.88469 make
	 50503 my 4.447690000000001 my
	 50504 hair 8.193760000000001 hair
	 50505 like 6.21475 like
	 **50506 Elvis 10.5518 Elvis's?'
	 **50506 Elvis 10.5518 Elvis's
	 50507 's 10.0458 I
	 50507 ? 0.23756 I
	 50507 ' 6.964010000000001 I
	 50507 I 5.5833900000000005 I
	 50508 asked 6.827189999999999 asked.
	 50508 asked 6.827189999999999 asked
	 50509 The 1.8960400000000002 The
	 50510 UNK-LC-er 8.2993 barber,
	 50502 make 3.76663 make
	 50503 my 4.3631 my
	 50504 hair 6.84824 hair
	 50505 like 5.4942 like
	 **50506 Elvis 8.980789999999999 Elvis's?'
	 **50506 Elvis 8.9807

KeyboardInterrupt: 

In [146]:
from copy import copy
# Snapshot of the mismatch tally from the run above. The cell that defines the
# harmonization regexes reads `mismatches_old` to choose `biggest_mismatch_code`,
# so this snapshot is an input to that cell even though it appears later in the
# notebook.
mismatches_old = copy(mismatches)
# Display the (code, count) pairs, most frequent first.
mismatches_old.most_common()

[(50506, 24),
 (60391, 24),
 (70189, 24),
 (10776, 20),
 (70248, 16),
 (70140, 14),
 (20277, 8),
 (20515, 8),
 (40199, 8),
 (40745, 8),
 (50428, 8),
 (60633, 8),
 (60735, 8),
 (60945, 8),
 (70153, 8),
 (70164, 8),
 (70247, 8),
 (70595, 8),
 (70643, 8),
 (80142, 8),
 (100302, 8),
 (100311, 8),
 (100321, 8),
 (20516, 6),
 (10878, 5),
 (10879, 5),
 (10880, 5),
 (10881, 5),
 (10882, 5)]