In [1]:
import json
import os

import jsonlines
import pandas as pd

from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

In [2]:
# === Set your hyperparameters here ===

# Choose the dataset language from ["en", "de"]
lang = "en"
assert lang in ["en", "de"]

# Set to True to split survey variables with sub-questions into separate corpus items
sep_answers = False
assert isinstance(sep_answers, bool)

# The string used to join survey variable information
join_str = "[UNK]"
assert isinstance(join_str, str)

# Evaluate each of the top k values
k_values = [1, 3, 5, 10, 100]
assert isinstance(k_values, list)
# This element check used to be a bare expression whose result was discarded,
# so non-integer cut-offs were never rejected.
assert all(isinstance(e, int) for e in k_values)

# Embedding similarity function from ["cos_sim", "dot_score"]
score_function = "cos_sim"
assert score_function in ["cos_sim", "dot_score"]

# Any sentence-transformers (https://www.sbert.net/docs/pretrained_models.html)
# or HuggingFace model (https://huggingface.co/models) works
pretrained_model = "all-MiniLM-L6-v2"
assert isinstance(pretrained_model, str)

# Set batch size for inference
batch_size = 16
assert isinstance(batch_size, int)

# ================ End ================

In [3]:
def split_variables(row):
    """Parse a raw ``variable`` cell into a list of dash-separated tuples.

    Args:
        row: Either one of the sentinels ``"No"`` / ``"NoSkip"`` or a string
            holding comma-separated items, optionally wrapped in square
            brackets, where each item is itself dash-separated
            (e.g. ``"[a-b,c-d]"``).

    Returns:
        ``[row]`` unchanged for the two sentinels. Otherwise a list with one
        tuple per non-empty item, produced by splitting the item on ``"-"``.
        Surrounding whitespace is stripped from every part so that
        ``"[a-b, c-d]"`` and ``"[a-b,c-d]"`` parse identically, and an empty
        list such as ``"[]"`` yields ``[]`` rather than ``[('',)]``.
    """
    if row in ("No", "NoSkip"):
        return [row]

    tokens = row.replace('[', '').replace(']', '').split(',')
    return [
        tuple(part.strip() for part in token.split('-'))
        for token in tokens
        if token.strip()  # skip empty items left by "[]" or trailing commas
    ]
    
def get_variables(row):
    """Return the first element of every entry stored in ``row.variables``."""
    first_fields = []
    for entry in row.variables:
        first_fields.append(entry[0])
    return first_fields

def make_label(row, answer, join_str):
    """Build the text of one corpus item from a survey-variable record.

    Args:
        row: Mapping-like record providing the keys ``"label"``, ``"topic"``
            and ``"question"``.
        answer: Answer text to append; a falsy value becomes ``""``.
        join_str: Separator placed between the four fields.

    Returns:
        The label, topic, question and answer joined by ``join_str``, in that
        order, with every falsy field replaced by an empty string.
    """
    # The previous version also read row["id"] into an unused local, which
    # made records without an "id" key fail even though the id is not part
    # of the returned text.
    fields = [row["label"], row["topic"], row["question"], answer]
    return join_str.join(field if field else "" for field in fields)

def get_labels(df, sep_answers=False, join_str="[UNK]"):
    """Turn a table of survey variables into parallel id and text lists.

    Args:
        df: DataFrame with the columns ``"id"``, ``"label"``, ``"topic"``,
            ``"question"`` and ``"answer"``. It is not modified.
        sep_answers: When True, the ``"answer"`` cell is split on ``";"`` and
            one corpus item is emitted per piece; otherwise one item per row.
        join_str: Separator handed to ``make_label``.

    Returns:
        ``(ids, labels)``: two lists of equal length, where ``labels[k]`` is
        the text built by ``make_label`` for the variable ``ids[k]``.
    """
    # Fill missing values on a copy: the in-place fill used previously
    # silently rewrote the caller's DataFrame.
    filled = df.fillna("")

    ids = []
    labels = []

    for _, row in filled.iterrows():
        # The id is constant for all answers of a row, so resolve it once.
        v_id = row["id"] if row["id"] else ""

        if sep_answers:  # split survey variable answers into separate corpus items
            answers = row["answer"].split(";")
        else:  # do not split survey variable answers
            answers = [row["answer"]]

        # NOTE(review): with sep_answers=True every piece shares the same id,
        # so a consumer that keys on id (such as the corpus dict built in
        # make_beir_data) keeps only the last piece — confirm this is intended.
        for v_answer in answers:
            ids.append(v_id)
            labels.append(make_label(row, v_answer, join_str))

    return ids, labels

In [4]:
# Load raw data

# Both languages are read from the same layout under data/trial. The German
# vocabulary path previously lacked the "trial" component, unlike the English
# one and every other path in this notebook.
# NOTE(review): confirm that data/trial/vocabulary/de.tsv is the intended file.
data_path = f"../../data/trial/test/{lang}.tsv"
variables_path = f"../../data/trial/vocabulary/{lang}.tsv"

# Queries: rename the "is_variable" column to "label" and parse the raw
# "variable" column into a list stored under "variables".
data_df = pd.read_csv(data_path, sep="\t").rename(columns={"is_variable": "label"})
data_df["variables"] = data_df.variable.apply(split_variables)

# Survey-variable vocabulary that becomes the retrieval corpus.
variable_df = pd.read_csv(variables_path, sep="\t")

ids, labels = get_labels(variable_df, sep_answers, join_str)

In [5]:
def make_beir_data(data_df, beir_data_dir):
    """Write queries, corpus and relevance judgements to ``beir_data_dir``.

    Three files are produced: ``queries.jsonl``, ``corpus.jsonl`` and
    ``qrels/all.tsv``. The corpus is taken from the module-level ``ids`` and
    ``labels`` lists, which must already be defined when this is called.

    Args:
        data_df: DataFrame with the columns ``text``, ``variable`` and
            ``variables``; its index values become the query ids.
        beir_data_dir: Output directory; it and its ``qrels`` subdirectory
            are created when missing.

    Returns:
        ``(queries_beir, corpus_beir, qrels_beir)``: the list of query
        records, the list of corpus records, and a dict of three parallel
        lists (``"query-id"``, ``"corpus-id"``, ``"score"``).
    """
    queries = {}
    qrels = {}
    for i, row in data_df.iterrows():
        query_id = str(i)
        queries[query_id] = row.text
        if row.variable in ("No", "NoSkip"):
            # Sentinel values mean the query has no relevant variables.
            qrels[query_id] = []
        else:
            # Corpus ids carry a "v" prefix in front of the variable number.
            qrels[query_id] = ["v" + x for x in get_variables(row)]

    # Later duplicates of an id overwrite earlier ones.
    corpus = dict(zip(ids, labels))

    beir_qrels_dir = os.path.join(beir_data_dir, "qrels")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(beir_qrels_dir, exist_ok=True)

    queries_beir = [{"_id": k, "text": v} for k, v in queries.items()]
    corpus_beir = [{"_id": k, "title": "", "text": v} for k, v in corpus.items()]

    qrels_beir = {"query-id": [], "corpus-id": [], "score": []}
    for k, vals in qrels.items():
        for v in vals:
            qrels_beir["query-id"].append(k)
            qrels_beir["corpus-id"].append(v)
            qrels_beir["score"].append(1)  # every listed variable is equally relevant

    df = pd.DataFrame.from_records(qrels_beir)
    df[["query-id", "corpus-id", "score"]].to_csv(
        os.path.join(beir_qrels_dir, "all.tsv"), index=False, sep="\t"
    )

    with jsonlines.open(os.path.join(beir_data_dir, "queries.jsonl"), "w") as writer:
        writer.write_all(queries_beir)

    with jsonlines.open(os.path.join(beir_data_dir, "corpus.jsonl"), "w") as writer:
        writer.write_all(corpus_beir)

    return queries_beir, corpus_beir, qrels_beir

In [6]:
# Load data in BEIR format

# Write queries.jsonl, corpus.jsonl and qrels/all.tsv for the selected language.
# The in-memory return values are kept under underscore names and are not
# used by the loader call below.
beir_data_dir = f"../../data/trial/beir_data/{lang}"
_queries_beir, _corpus_beir, _qrels_beir = make_beir_data(data_df, beir_data_dir)

# Read the data back through the BEIR loader.
# NOTE(review): split="all" presumably selects the qrels/all.tsv file written
# by make_beir_data — confirm against the BEIR loader documentation.
corpus, queries, qrels = GenericDataLoader(data_folder=beir_data_dir).load(split="all")

100%|██████████| 182/182 [00:00<00:00, 98613.01it/s]


In [7]:
# Initialize retriever model

# Wrap the pretrained sentence-embedding model in BEIR's exact dense search
# (DenseRetrievalExactSearch), then hand it to the retrieval/evaluation helper
# together with the similarity function and cut-offs from the hyperparameter cell.
model = DRES(models.SentenceBERT(pretrained_model), batch_size=batch_size)
retriever = EvaluateRetrieval(model, score_function=score_function, k_values=k_values)

In [8]:
# Retrieve scored corpus items for every query (no metrics are computed in this cell)

# `results` maps each query id to a {corpus id: similarity score} dict, as the
# dump of `results` further down shows.
results = retriever.retrieve(corpus, queries, return_sorted=True)

Batches: 100%|██████████| 2/2 [00:00<00:00, 70.60it/s]
Batches: 100%|██████████| 12/12 [00:00<00:00, 152.60it/s]


In [10]:
# Save files

# `json` is imported in the notebook's import cell rather than here, so the
# full dependency list is visible in one place.

# Relevance judgements as loaded from the BEIR data folder.
with open("./qrels_dense.json", "w", encoding="utf-8") as fp:
    json.dump(qrels, fp)

# Retrieval run produced by the dense retriever.
with open("./run_dense.json", "w", encoding="utf-8") as fp:
    json.dump(results, fp)

In [23]:
# Inspect the relevance judgements: query id -> {corpus id: relevance score}.
# NOTE(review): this displays the whole dict; consider showing only a few entries.
qrels

{'3': {'v311': 1, 'v286': 1},
 '4': {'v726': 1, 'v723': 1},
 '8': {'v237': 1,
  'v238': 1,
  'v239': 1,
  'v240': 1,
  'v241': 1,
  'v242': 1,
  'v243': 1,
  'v247': 1},
 '12': {'v732': 1, 'v734': 1},
 '22': {'v205': 1, 'v204': 1, 'v183': 1},
 '24': {'v223': 1, 'v229': 1, 'v688': 1, 'v668': 1},
 '25': {'v333': 1, 'v336': 1, 'v685': 1},
 '29': {'v180': 1, 'v181': 1, 'v182': 1},
 '35': {'v1909': 1, 'v442': 1, 'v462': 1},
 '36': {'v136': 1, 'v137': 1, 'v138': 1, 'v139': 1, 'v237': 1},
 '37': {'v182': 1, 'v210': 1, 'v213': 1},
 '40': {'v288': 1, 'v287': 1, 'v290': 1, 'v270': 1},
 '43': {'v185': 1, 'v184': 1},
 '46': {'v580': 1, 'v1504': 1, 'v454': 1, 'v623': 1},
 '50': {'v253': 1, 'v255': 1, 'v257': 1, 'v258': 1, 'v270': 1},
 '52': {'v732': 1, 'v734': 1, 'v735': 1},
 '60': {'v328': 1, 'v329': 1},
 '61': {'v272': 1, 'v136': 1, 'v273': 1},
 '63': {'v729': 1, 'v915': 1, 'v731': 1, 'v955': 1},
 '64': {'v266': 1,
  'v261': 1,
  'v262': 1,
  'v263': 1,
  'v267': 1,
  'v269': 1,
  'v265': 1},
 '6

In [16]:
# Inspect the retrieval run: query id -> {corpus id: similarity score}.
# NOTE(review): this displays the whole dict; consider showing only a few entries.
results

{'3': {'v311': 0.661536455154419,
  'v274': 0.41382619738578796,
  'v277': 0.4083348214626312,
  'v275': 0.40753433108329773,
  'v276': 0.38659316301345825,
  'v278': 0.3268101215362549,
  'v1320': 0.32371217012405396,
  'v289': 0.31318366527557373,
  'v325': 0.3082313537597656,
  'v1329': 0.3062721788883209,
  'v258': 0.3048103451728821,
  'v257': 0.3044426441192627,
  'v270': 0.30044472217559814,
  'v271': 0.30041348934173584,
  'v295': 0.29970914125442505,
  'v321': 0.299460232257843,
  'v253': 0.29767948389053345,
  'v286': 0.29756683111190796,
  'v265': 0.2936401963233948,
  'v247': 0.2878466546535492,
  'v290': 0.2863495945930481,
  'v264': 0.2842598259449005,
  'v293': 0.2827727198600769,
  'v1330': 0.26995334029197693,
  'v263': 0.26902416348457336,
  'v273': 0.2687966823577881,
  'v249': 0.2673758268356323,
  'v237': 0.2661028504371643,
  'v272': 0.2660942077636719,
  'v287': 0.2658703923225403,
  'v255': 0.2621314823627472,
  'v243': 0.26099416613578796,
  'v267': 0.259049564