In [10]:
import pandas as pd
import os
import math
from collections import Counter

import regex as re
from tqdm.notebook import tqdm

In [11]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import harmonize

In [2]:
# Load word frequency statistics for control features.
# Every non-empty line of the vocab file must be "<token>\t<count>".
word_freq = Counter()
# Explicit encoding: without it, decoding depends on the platform's locale.
with open("../data/wikitext-2_train_vocab.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of failing with an opaque unpacking error.
            continue
        token, freq = line.split("\t")
        word_freq[token] = int(freq)

In [35]:
def merge_model_results(log_target_code=None):
    """Harmonize model output with the gold-standard files and merge the rows.

    Walks ``../data/model_results/<model dir>/<corpus>/<result file>``, reads
    the matching gold-standard file from ``../data/human_rts/<corpus>/``,
    aligns both with ``harmonize.harmonize_rows`` and collects one output row
    per harmonized item. The merged table is written to
    ``../data/harmonized_results.csv`` and returned.

    NOTE: the statements marked ``DEV`` deliberately limit a run to the
    "dundee" corpus, to the first well-formed result file of that corpus, and
    to the first model directory. Remove them for a full run.

    Args:
        log_target_code: Passed through unchanged to
            ``harmonize.harmonize_rows`` as ``log_target_code``.

    Returns:
        A ``(df, mismatches)`` tuple: the merged ``pandas.DataFrame`` and a
        ``Counter`` accumulated from the per-file mismatch counts returned by
        ``harmonize.harmonize_rows``.
    """
    results_root = "../data/model_results"
    gold_root = "../data/human_rts"
    columns = ["code", "word", "surprisal", "psychometric", "corpus",
               "model", "training", "seed", "len", "freq"]

    def visible_entries(path):
        # Directory listing without hidden entries such as .DS_Store.
        return [entry for entry in os.listdir(path) if not entry.startswith(".")]

    mismatches = Counter()
    rows = []

    for model_dir in tqdm(visible_entries(results_root), desc="Harmonizing models"):
        for tc in visible_entries(os.path.join(results_root, model_dir)):
            # DEV: only the Dundee corpus is processed for now.
            if tc != "dundee":
                continue

            test_files = visible_entries(os.path.join(results_root, model_dir, tc))

            for result_file in tqdm(test_files, desc="Test files"):
                if result_file == "UNKS":
                    print("TODO: UNKS")
                    continue

                # Expected name: <test file>_<architecture>_<training data>_<seed>.csv
                parts = result_file.split("_")
                if len(parts) < 4:
                    # Skip instead of falling through: continuing here would
                    # reuse the previous file's metadata (or hit unbound names).
                    print("Skipping file with unexpected name:", result_file)
                    continue
                test_filename, model_architecture, training_data = parts[:3]
                seed = parts[3].replace(".csv", "")

                # Special handling for the Dundee corpus
                if tc == "dundee":
                    gold_test_filename = test_filename.replace("wrdp", "") + "_avg"
                    gold_standard = pd.read_csv(
                        os.path.join(gold_root, tc, gold_test_filename + ".txt"),
                        sep="\t", names=["word", "surprisal"])
                    # Build a per-row code: row index offset by 10000 times the
                    # number embedded in the test file name.
                    gold_standard.insert(0, "code", range(0, len(gold_standard)))
                    file_number = int(test_filename.replace("tx", "").replace("wrdp", ""))
                    gold_standard["code"] = gold_standard["code"] + file_number * 10000
                else:
                    gold_standard = pd.read_csv(
                        os.path.join(gold_root, tc, test_filename + ".txt"), sep="\t")

                model_file = "_".join([test_filename, model_architecture, training_data, seed]) + ".csv"
                model_results = pd.read_csv(
                    os.path.join(results_root, model_dir, tc, model_file), sep="\t")

                # PTB-de-process model results. regex=False makes the literal
                # replacement independent of the pandas version's default.
                model_results["token"] = (
                    model_results["token"]
                    .str.replace("-LRB-", "(", regex=False)
                    .str.replace("-RRB-", ")", regex=False)
                )

                # TODO: EOL handling

                # Only the third and fourth columns of the model output are used.
                model_rows = [tuple(row[2:4]) for row in model_results.values.tolist()]
                gold_rows = [tuple(row) for row in gold_standard.values.tolist()]

                harmonized_results, mismatches_i = harmonize.harmonize_rows(
                    gold_rows, model_rows, log_target_code=log_target_code)
                mismatches += mismatches_i

                rows.extend(
                    (x[2], x[0], x[1], x[4], tc, model_architecture, training_data, seed,
                     len(x[0]), math.log(word_freq[x[0]] + 1))
                    for x in harmonized_results
                )

                # DEV: stop after the first processed file.
                break

        # DEV: stop after the first model directory.
        break

    # Passing the column names to the constructor keeps an empty result valid;
    # assigning df.columns afterwards raises when no rows were collected.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv("../data/harmonized_results.csv")
    return df, mismatches

# Log the code with the most mismatches from the previous run, if there was
# one. On a fresh kernel `mismatches` does not exist yet (and it may be empty
# after a clean run), so fall back to None instead of raising.
_previous_mismatches = globals().get("mismatches")
_log_target_code = _previous_mismatches.most_common(1)[0][0] if _previous_mismatches else None
df, mismatches = merge_model_results(_log_target_code)

HBox(children=(FloatProgress(value=0.0, description='Harmonizing models', max=4.0, style=ProgressStyle(descrip…

HBox(children=(FloatProgress(value=0.0, description='Test files', max=80.0, style=ProgressStyle(description_wi…

CONTRACTION MATCH Dubai's Dubai ['Dubai', "'s"] ('Dubai', "'s")
CONTRACTION MATCH Dubai's Dubai ['Dubai', "'s"] ('Dubai', "'s")
	142592	glass	glass
	142593	and	and
	142594	steel	steel
	142595	Dubai's	Dubai
CONTRACTION MATCH Dubai's Dubai ['Dubai', "'s"] ('Dubai', "'s")
	142595	Dubai's	Dubai's
**	142596	go-getting,	's
**	142596	go-getting	's
	142597	onward-thrusting	UNK-LC-DASH
	142598	character,	,
	142598	character	,
	142598	character	UNK-LC-DASH
	142599	the	character
	142599	the	,
	142599	the	the
	142600	service	service




In [17]:
# Keep a shallow copy of the current mismatch counts under a separate name,
# presumably so they can be compared against a later run (TODO confirm).
from copy import copy
mismatches_old = copy(mismatches)
# Last expression of the cell: display the entries ordered by count, highest first.
mismatches_old.most_common()

[(60394, 3), (20516, 1), (20515, 1)]