In [2]:
import sys
sys.path.append("../")

import os

In [3]:
import json
from collections import defaultdict
from pathlib import Path
import datasets
import pandas as pd
import numpy as np

In [4]:
import datasets
from transformers import AutoConfig, AutoTokenizer

In [13]:
def get_statistics(name, corpus, query_splits, tokenizer_name="distilbert-base-uncased"):
    """Count queries, documents and their tokens for one collection.

    Args:
        name: Label copied verbatim into the "name" field of the result.
        corpus: Location of a dataset readable by ``datasets.load_from_disk``
            whose records have a "text" field. Pass a falsy value to skip the
            document statistics.
        query_splits: Mapping ``split label -> location`` of datasets readable
            by ``datasets.load_from_disk`` whose records have a "query" field.
        tokenizer_name: Name handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        dict with the keys "name", "nQ", "nD", "nQToks" and "nDToks".
        "nQ"/"nQToks" are summed over every split in ``query_splits``;
        "nD"/"nDToks" are ``None`` when ``corpus`` is falsy.

    Side effects:
        Prints one line per query split with its query and token counts.
    """

    def count_tokens(dataset, field):
        # Token count = number of "input_ids" the tokenizer produces for the
        # field. Nothing is truncated here, so the tokenizer may warn about
        # sequences longer than the model maximum; that is harmless because
        # the ids are only counted, never fed to a model.
        # NOTE(review): the count presumably includes any special tokens the
        # tokenizer adds by default -- confirm if exact lengths matter.
        counts = dataset.map(lambda example: {"n_toks": len(tokenizer(example[field])["input_ids"])})
        return np.sum(counts["n_toks"])

    # Looked up by `count_tokens` through the closure, so it only has to exist
    # before the first call below.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    if corpus:
        corpus = datasets.load_from_disk(corpus)
        nD = len(corpus)
        nDToks = count_tokens(corpus, "text")
    else:
        nD = None
        nDToks = None

    nQ = 0
    nQToks = 0
    for split, path in query_splits.items():
        query = datasets.load_from_disk(path)
        t = count_tokens(query, "query")
        print("\t", split, "nQ:", len(query), "nQToks:", t)

        nQ += len(query)
        nQToks += t

    return {
        "name": name,
        "nQ": nQ,
        "nD": nD,
        "nQToks": nQToks,
        "nDToks": nDToks
    }



# One statistics row per query collection; both are measured against the same corpus.
stat_rows = [
    get_statistics("TREC-ToT dev", corpus_path, trec_queries),
    get_statistics("RedditTest", corpus_path, reddit_queries),
]

# Average lengths are total tokens divided by the number of queries / documents.
d_stats = pd.DataFrame(stat_rows)
d_stats = d_stats.assign(
    avgQLen=d_stats["nQToks"] / d_stats["nQ"],
    avgDLen=d_stats["nDToks"] / d_stats["nD"],
)
d_stats


Loading cached processed dataset at ../datasets/ToT/trec-corpus/cache-192251ea7652b11c.arrow
Loading cached processed dataset at ../datasets/ToT/trec-dev/cache-097ef614d33a8d0f.arrow
Loading cached processed dataset at ../datasets/ToT/trec-corpus/cache-192251ea7652b11c.arrow


	 trec-dev nQ: 150 nQToks: 25384


  0%|          | 0/933 [00:00<?, ?ex/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (920 > 512). Running this sequence through the model will result in indexing errors


	 reddit-test nQ: 933 nQToks: 145325


Unnamed: 0,name,nQ,nD,nQToks,nDToks,avgQLen,avgDLen
0,TREC-ToT dev,150,231852,25384,154600886,169.226667,666.808507
1,RedditTest,933,231852,145325,154600886,155.760986,666.808507


In [8]:
from eval_run import evaluate, compute_and_merge_mrr_cut

In [9]:
# Dataset locations. The defaults are the original author's local paths; set the
# environment variables below to run the notebook on another machine without
# editing this cell.

# Directory (relative to the notebook) that holds the converted datasets.
DEST_PATH = Path(os.environ.get("TOT_DEST_PATH", "../datasets/ToT"))

# Root of the TREC-ToT release; it contains a "public" and a "private" directory.
TREC_TOT_ROOT = Path(
    os.environ.get(
        "TREC_TOT_ROOT",
        "/Users/sam/workspaces/trec-tot-repos/trec-tot/datasets/TREC-TOT",
    )
)
trec_tot_dataset_path = TREC_TOT_ROOT / "public"
trec_tot_dataset_path_private = TREC_TOT_ROOT / "private"

# Directory with the Reddit query files ("<split>.jsonl").
reddit_path = Path(
    os.environ.get("REDDIT_TOMT_PATH", "/Users/sam/workspaces/tomt-data/trec_dataset/Movies/")
)


In [10]:
# "<collection>-<split>" -> {query id -> {wikipedia id -> 1}}; filled by the
# loading loops further down in this cell and read by the evaluation code.
qrels = {}
# An ordered tuple rather than a set: iterating a set of strings has no stable
# order across kernels, which would make the insertion order of `qrels` vary.
splits = ("dev", "test", "train")

# Starts empty; a later cell assigns the mapping of split name -> dataset path.
trec_queries = {}

def conv_q_trec_to_tev(q):
    """Convert one raw query record into the query schema used by this notebook.

    Reads ``q["id"]``, ``q["title"]`` and ``q["text"]``. The query string is the
    title, a period and a newline, then the text. Both passage lists start out
    empty and are fresh list objects on every call.
    """
    query_id = q["id"]
    query_text = ".\n".join((q["title"], q["text"]))
    return dict(
        query_id=query_id,
        query=query_text,
        positive_passages=[],
        negative_passages=[],
    )


def conv_d_trec_to_tev(d):
    """Convert one raw corpus record into the document schema used by this notebook.

    Copies ``d["doc_id"]`` to "docid", ``d["page_title"]`` to "title" and
    ``d["text"]`` to "text"; every other field of ``d`` is dropped.
    """
    # (output key, input key) pairs, in output order.
    field_map = (
        ("docid", "doc_id"),
        ("title", "page_title"),
        ("text", "text"),
    )
    return {target: d[source] for target, source in field_map}

def serial_dict(l):
    """Turn a sequence of dicts into a dict of lists (row layout -> column layout).

    Each key seen in any record maps to the list of its values, in record
    order. Returns a ``defaultdict(list)``.
    """
    columns = defaultdict(list)
    # Flatten every (key, value) pair of every record, keeping record order.
    all_pairs = (pair for record in l for pair in record.items())
    for field, value in all_pairs:
        columns[field].append(value)
    return columns
    

    
def _read_queries_and_qrels(jsonl_path):
    """Parse one queries JSONL file.

    Returns ``(queries, qrel)``: the converted query records in file order, and
    a mapping ``query id -> {wikipedia id: 1}`` naming the single relevant
    document of each query.
    """
    queries = []
    qrel = {}
    # Explicit UTF-8 so that parsing does not depend on the platform's default
    # text encoding.
    with open(jsonl_path, encoding="utf-8") as reader:
        for line in reader:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            q = json.loads(line)
            queries.append(conv_q_trec_to_tev(q))
            qrel[q["id"]] = {q["wikipedia_id"]: 1}
    return queries, qrel


# TREC-ToT splits. sorted() gives a stable processing order (and therefore a
# stable key order in `qrels`) even when `splits` is an unordered set.
for split in sorted(splits):
    # The test split is read from the private directory, all others from the public one.
    q_path = trec_tot_dataset_path if split != "test" else trec_tot_dataset_path_private
    queries, qrel = _read_queries_and_qrels(q_path / split / "queries.jsonl")

    qrels[f"trec-{split}"] = qrel
    trec_dset = datasets.Dataset.from_dict(serial_dict(queries))
    # trec_dset.save_to_disk(DEST_PATH / f"trec-{split}")


# Reddit splits: one "<split>.jsonl" file per split, same record layout.
splits = {"validation", "test", "train"}
for split in sorted(splits):
    queries, qrel = _read_queries_and_qrels(reddit_path / f"{split}.jsonl")

    qrels[f"reddit-{split}"] = qrel
    trec_dset = datasets.Dataset.from_dict(serial_dict(queries))
    # trec_dset.save_to_disk(DEST_PATH / f"reddit-{split}")



In [42]:
# Read the corpus (one JSON document per line) and convert every record to the
# document schema. Explicit UTF-8 keeps parsing independent of the platform's
# default text encoding; blank lines are skipped instead of crashing json.loads.
with open(trec_tot_dataset_path / "corpus.jsonl", encoding="utf-8") as reader:
    corpus = [conv_d_trec_to_tev(json.loads(line)) for line in reader if line.strip()]

# Built after the file is closed: the dataset only needs the in-memory records.
trec_corpus = datasets.Dataset.from_dict(serial_dict(corpus))
# trec_corpus.save_to_disk(DEST_PATH / "trec-corpus")


In [8]:
# Locations (under DEST_PATH) of the saved datasets that are passed to
# get_statistics(...): the document corpus plus the query splits to report on.
corpus_path = DEST_PATH / "trec-corpus"
# Uncomment an entry to include that split in the TREC-ToT statistics.
trec_queries = {
    #"trec-train": DEST_PATH / "trec-train",
    "trec-dev": DEST_PATH / "trec-dev",
    #"trec-test": DEST_PATH / "trec-test",
}

reddit_queries = {
    "reddit-test": DEST_PATH / "reddit-test"
}

In [6]:
def evaluate_runs(runs_folder: Path, metrics):
    """Evaluate the run file of every split in the global ``qrels``.

    For each split name in ``qrels`` the file ``runs_folder / "tot-<split>.run"``
    is read; every non-blank line must hold three whitespace-separated fields:
    query id, document id and score.

    Args:
        runs_folder: Directory containing one ``tot-<split>.run`` file per split.
        metrics: Iterable of metric names. Names starting with
            "recip_rank_cut_" are handled by ``compute_and_merge_mrr_cut``;
            all others are passed to ``evaluate``.

    Returns:
        dict ``split -> {metric -> mean of the values reported for that metric}``.

    Raises:
        AssertionError: if a run does not cover exactly the queries of its qrel.
    """
    # Split the metric names once, before the loop. Doing this inside the loop
    # and overwriting `metrics` drops the MRR@K names after the first split, so
    # every later split would be evaluated without them.
    mrr_cut = [name for name in metrics if name.startswith("recip_rank_cut_")]
    base_metrics = [name for name in metrics if name not in mrr_cut]

    results = {}
    for split, qrel in qrels.items():
        run = defaultdict(dict)
        with open(runs_folder / f"tot-{split}.run") as reader:
            for line in reader:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                qid, doc_id, score = line.split()
                run[qid][doc_id] = float(score)

        assert len(qrel) == len(run), f"{split}: {len(run)} queries in run, {len(qrel)} in qrel"
        # all(...) is required: asserting the bare list is true whenever the
        # qrel is non-empty, no matter which queries are missing.
        assert all(qid in run for qid in qrel), f"{split}: run is missing queries from the qrel"

        eval_res_queries = evaluate(run, qrel, base_metrics)
        # compute the MRR@K by cutting off the run at K, because trec_eval doesn't support @K
        agg, eval_res_queries = compute_and_merge_mrr_cut(run, qrel, mrr_cut, eval_res_queries)

        # NOTE(review): `values` are presumably the per-query scores -- confirm in eval_run.
        results[split] = {metric: np.mean(values) for metric, values in agg.items()}

    return results


In [11]:
# Display name of each model -> folder holding its "tot-<split>.run" files.
MODELS = {
    "DPR":  Path("./gathered_results/dpr_hs_db_3_runs/"),
    "TAS-B (0s)":  Path("./gathered_results/tas_b_zeroshot_runs/"),
    "MVRL": Path("./gathered_results_manual/MVRL_new_tot/"),
    "CLDRD": Path("./gathered_results/cldrd_runs/")
}

# Metric names handed to evaluate_runs.
METRICS = {"recip_rank", "ndcg_cut_10", "ndcg_cut_1000", "recall_1000"}

# Only these splits are reported; results for every other split are discarded.
sel_splits = {"reddit-test", "trec-dev"}

# One row per model, with one "<split>-<metric>" column per reported value.
result_rows = []
for model_name, runs_dir in MODELS.items():
    per_split = evaluate_runs(runs_dir, METRICS)
    row = {"model": model_name}
    for split_name, split_scores in per_split.items():
        if split_name in sel_splits:
            for metric_name, score in split_scores.items():
                row[f"{split_name}-{metric_name}"] = score
    result_rows.append(row)

results = pd.DataFrame(result_rows)

In [13]:
# Copy the results table to the system clipboard, without the index column.
results.to_clipboard(index=False)

In [12]:
# Display the results table: one row per model, one column per "<split>-<metric>".
results

Unnamed: 0,model,trec-dev-recip_rank,trec-dev-recall_1000,trec-dev-ndcg_cut_10,trec-dev-ndcg_cut_1000,reddit-test-recip_rank,reddit-test-recall_1000,reddit-test-ndcg_cut_10,reddit-test-ndcg_cut_1000
0,DPR,0.041916,0.34,0.040119,0.08261,0.026075,0.359057,0.028314,0.072324
1,TAS-B (0s),0.064457,0.453333,0.068048,0.118594,0.077386,0.524116,0.088931,0.146284
2,MVRL,0.033553,0.193333,0.035671,0.055117,0.024213,0.275456,0.027088,0.058864
3,CLDRD,0.057906,0.393333,0.059676,0.103649,0.056972,0.44373,0.063539,0.115164
