## Data harmonization

This notebook harmonizes (aligns) human reading time data from multiple corpora with language model surprisal data. This is necessary because the various reading time corpora and language models all deploy different preprocessing/tokenization procedures.

This notebook reads from the following directories:

- `data/model_results/<architecture>/<corpus>/<test-file>_<architecture>_<training-corpus>_<seed>.csv` contains tab-separated language model surprisal data for language model `<architecture>`, trained on `<training-corpus>` with seed `<seed>` and evaluated on the file `<test-file>` of the reading time corpus `<corpus>`
- `data/human_rts/<corpus>/*.txt` contains human reading time data, as produced by the script `scripts/average_sprs_script.Rmd`

This notebook writes the following table to the file `data/harmonized_results.csv`:

| code | word | surprisal | psychometric | corpus | model | training | seed | len | freq |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| unique integer identifying a token within a test corpus | token | model surprisal estimate | human psychometric (e.g. firstpass RT, SPR measure) for token | test corpus | model architecture name | model training corpus | model training seed | token length | token log-frequency |

This can then be read by the main analysis notebook, `notebooks/analysis.Rmd`.

In [1]:
import pandas as pd
import os
import math
from collections import Counter

import regex as re
from tqdm.notebook import tqdm

In [2]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../src")
import harmonize

In [3]:
# Load word frequency statistics for control features
# Load word frequency statistics for control features.
# Each non-blank line of the vocabulary file is expected to be "<token>\t<count>".
word_freq = Counter()
# Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.
with open("../data/wikitext-2_train_vocab.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line at the end of the file).
            continue
        token, freq = line.split("\t")
        word_freq[token] = int(freq)

In [6]:
# Start with an empty mismatch tally. The harmonization cell below removes known
# non-issue codes from it, picks its most common remaining code to log, and then
# rebinds it to the counts returned by the latest merge_model_results run.
mismatches = Counter()

In [27]:
def merge_model_results(log_target_code=None):
    """Harmonize every model-results file with its human reading-time data.

    Walks ``../data/model_results/<model>/<test corpus>/`` and, for each results
    file named ``<test file>_<architecture>_<training data>_<seed>.csv``, loads
    the matching gold-standard file from ``../data/human_rts/<test corpus>/`` and
    aligns the two with ``harmonize.harmonize_rows``. The combined table is
    written to ``../data/harmonized_results.csv`` and also returned.

    Entries named ``UNKS`` and files whose name has fewer than four
    ``_``-separated parts are reported and skipped.

    Args:
        log_target_code: Passed through unchanged to ``harmonize.harmonize_rows``.

    Returns:
        A ``(df, mismatches)`` tuple. ``df`` has the columns ``code``, ``word``,
        ``surprisal``, ``psychometric``, ``corpus``, ``model``, ``training``,
        ``seed``, ``len`` and ``freq``. ``mismatches`` is the sum of the
        per-file mismatch counters returned by ``harmonize.harmonize_rows``.

    Raises:
        ValueError: If a test-corpus directory is not one of ``dundee``,
            ``bnc-brown`` or ``natural-stories``, since no gold-standard loader
            exists for it.
    """
    results_root = "../data/model_results"
    rts_root = "../data/human_rts"
    columns = ["code", "word", "surprisal", "psychometric", "corpus",
               "model", "training", "seed", "len", "freq"]
    # PTB bracket escapes and the literal characters they stand for.
    ptb_escapes = {"-LRB-": "(", "-RRB-": ")", "-LSB-": "[", "-RSB-": "]"}

    def visible_entries(path):
        # Directory listing without hidden entries (names starting with ".").
        return [f for f in os.listdir(path) if not f.startswith(".")]

    mismatches = Counter()
    final_rows = []

    for m in tqdm(visible_entries(results_root), desc="Harmonizing models"):
        model_dir = os.path.join(results_root, m)
        for tc in visible_entries(model_dir):
            corpus_dir = os.path.join(model_dir, tc)
            rts_dir = os.path.join(rts_root, tc)

            for tf in tqdm(visible_entries(corpus_dir), desc="Test files"):
                if tf == "UNKS":
                    print("TODO: UNKS")
                    continue

                # Skip badly named files instead of carrying on with the
                # previous iteration's (or undefined) filename parts.
                name_parts = tf.split("_")
                if len(name_parts) < 4:
                    print("Skipping model results file with unexpected name: " + tf)
                    continue
                test_filename, model_architecture, training_data = name_parts[:3]
                seed = name_parts[3].replace(".csv", "")

                results_name = "_".join([test_filename, model_architecture, training_data, seed])
                model_results = pd.read_csv(os.path.join(corpus_dir, results_name + ".csv"), sep="\t")

                if tc == "dundee":
                    # Special handling for the Dundee corpus: the gold file has no
                    # header row and no token codes, so codes are built here as
                    # <text number> * 10000 + <row position>.
                    gold_test_filename = test_filename.replace("wrdp", "") + "_avg"
                    # NOTE(review): the second column is labelled "surprisal" but is
                    # read from the human_rts directory; the label is discarded below
                    # when the frame is turned into tuples.
                    gold_standard = pd.read_csv(os.path.join(rts_dir, gold_test_filename + ".txt"), sep="\t",
                                                names=["word", "surprisal"])
                    text_number = int(test_filename.replace("tx", "").replace("wrdp", ""))
                    gold_standard.insert(0, "code", range(len(gold_standard)))
                    gold_standard["code"] = gold_standard["code"] + text_number * 10000
                elif tc == "bnc-brown":
                    gold_standard = pd.read_csv(os.path.join(rts_dir, test_filename + ".csv"), sep=",",
                                                usecols=["code", "word", "psychometric"])

                    # Model tokenization differs a little bit here.
                    for ptb_quote in ("``", "''"):
                        model_results["token"] = model_results["token"].str.replace(ptb_quote, '"', regex=False)
                elif tc == "natural-stories":
                    gold_standard = pd.read_csv(os.path.join(rts_dir, test_filename + ".txt"), sep="\t")
                else:
                    # Fail loudly rather than reuse the previous file's gold standard.
                    raise ValueError("No gold-standard loader for test corpus: " + tc)

                # PTB-de-process model results. regex=False keeps these literal
                # replacements regardless of the installed pandas version's default.
                for escaped, literal in ptb_escapes.items():
                    model_results["token"] = model_results["token"].str.replace(escaped, literal, regex=False)

                # Only the third and fourth columns of the model results are
                # handed to the harmonizer.
                model_rows = [tuple(row)[2:4] for row in model_results.values.tolist()]
                gold_rows = [tuple(row) for row in gold_standard.values.tolist()]

                harmonized_results, mismatches_i = harmonize.harmonize_rows(gold_rows, model_rows,
                                                                            log_target_code=log_target_code)
                mismatches += mismatches_i

                # Reorder each harmonized row into the output schema and append the
                # control features: token length and log(frequency + 1).
                final_rows.extend(
                    (x[2], x[0], x[1], x[4], tc, model_architecture, training_data, seed,
                     len(x[0]), math.log(word_freq[x[0]] + 1))
                    for x in harmonized_results
                )

    # Passing the column names to the constructor also works when no rows were produced.
    df = pd.DataFrame(final_rows, columns=columns)
    df.to_csv("../data/harmonized_results.csv")
    return df, mismatches

# Token codes whose mismatches are treated as known non-issues, grouped by test
# corpus and by the reason noted next to each group.
non_issue_mismatches = [
    ## dundee
    # comma-separated list (lots of punctuation joins)
    172179, 50578, 180516, 91736, 102630,
    
    # lotsa punct
    200203, 160452, 180430,
    
    # lotsa unk
    60765, 60766, 21970,
    
    # punct + unk
    131873, 51548, 142006, 182019, 112194,
    
    ## natural stories
    # lotsa unk
    60391,
    # unk+punct
    60676,
    
    # bnc-brown
    # punct
    20165, 20178, 26533, 29273,
    # punct and unks
    32117, 17107,
]
# Drop the known non-issues from the previous run's tally. Deleting a key that a
# Counter does not contain is a no-op, so codes absent from that run are fine.
for code in non_issue_mismatches:
    del mismatches[code]
# Log the most frequent remaining mismatch on this pass, or nothing if none remain.
to_log = mismatches.most_common(1)[0][0] if mismatches else None
df, mismatches = merge_model_results(to_log)

HBox(children=(FloatProgress(value=0.0, description='Harmonizing models', max=4.0, style=ProgressStyle(descrip…

HBox(children=(FloatProgress(value=0.0, description='Test files', max=80.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=4.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=4.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=180.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=180.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=9.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=10.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=140.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=7.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='Test files', max=7.0, style=ProgressStyle(description_wid…





In [28]:
# Show the per-code mismatch counts returned by the latest merge_model_results run.
mismatches

Counter({172179: 29,
         131873: 14,
         50578: 24,
         51548: 14,
         160452: 8,
         180516: 24,
         60766: 6,
         180430: 8,
         182019: 8,
         200203: 20,
         112194: 8,
         142006: 14,
         21970: 8,
         102630: 8,
         60765: 14,
         91736: 14,
         60391: 13,
         60676: 8,
         26533: 30,
         32117: 30,
         17107: 9,
         29273: 9})

In [29]:
# Preview the harmonized table returned by merge_model_results.
df

Unnamed: 0,code,word,surprisal,psychometric,corpus,model,training,seed,len,freq
0,140000,Lesley,25.244897,286.375000,dundee,5gram,bllip-md,1111,6,0.000000
1,140002,a,5.052998,224.000000,dundee,5gram,bllip-md,1111,1,10.495626
2,140003,51-year-old,12.149018,334.300000,dundee,5gram,bllip-md,1111,11,0.000000
3,140004,from,8.795982,198.000000,dundee,5gram,bllip-md,1111,4,9.130214
4,140006,explains,12.865601,252.875000,dundee,5gram,bllip-md,1111,8,3.988984
...,...,...,...,...,...,...,...,...,...,...
1691369,35749,been,3.947763,284.739583,bnc-brown,gpt-2,bllip-xs-gptbpe,1581807512,4,8.090709
1691370,35750,in,5.606664,288.440417,bnc-brown,gpt-2,bllip-xs-gptbpe,1581807512,2,10.714040
1691371,35751,precisely,16.519783,324.473750,bnc-brown,gpt-2,bllip-xs-gptbpe,1581807512,9,2.564949
1691372,35752,such,11.342088,325.680833,bnc-brown,gpt-2,bllip-xs-gptbpe,1581807512,4,7.368340
