In [None]:
import time
from datetime import datetime
from collections import defaultdict
from typing import List, Dict, Union, Tuple
import copy
import os
import csv
import json
import jsonlines
from tqdm import tqdm
import math
import random
import numpy as np
import pandas as pd
import pytrec_eval
import torch
from torch import Tensor
from sentence_transformers import models, losses
from sentence_transformers import SentenceTransformer
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.search.sparse import SparseSearch


# Short metric names accepted by get_metrics -> pytrec_eval base measure names.
PYTREC_METRIC_MAPPING = dict(map="map_cut", rprec="Rprec", p="P", r="recall")


In [None]:
def set_seeds(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [None]:
def load_variables(path):
    """Load the variable-metadata JSON file.

    Args:
        path: Path to a file ending in ``.json``.

    Returns:
        The parsed JSON object. In this notebook it is indexed as
        ``variables[research_data_id][variable_id] -> metadata dict``.

    Raises:
        AssertionError: if *path* does not end with ``.json``.
    """
    assert path.endswith(".json"), f"Expected a .json file, got: {path}"
    # JSON is UTF-8; do not depend on the platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as fp:
        return json.load(fp)


def make_label(meta, sep=" [UNK] "):
    """Flatten a variable's metadata dict into a single label string.

    List values are joined with *sep*, any ";" inside a value is replaced by *sep*,
    and the resulting values are joined with *sep* in dict order. Values are
    presumably str or list of str (TODO confirm against the metadata file).
    """
    def _flatten(value):
        if isinstance(value, list):
            value = sep.join(value)
        return sep.join(value.split(";")) if ";" in value else value

    return sep.join(_flatten(value) for value in meta.values())


def get_instance_variables(df, variables):
    """Collect the candidate variables for one research-data group.

    The ``research_data`` value of the first row of *df* is split on ";" and every
    variable listed under each id in *variables* is turned into a label with
    ``make_label``.

    Returns:
        (instance_variables, id_mapping): variable id -> label text, and
        variable id -> its position (as str) in ``instance_variables``.
    """
    research_data_ids = df.iloc[0]["research_data"].split(";")
    instance_variables = {}
    for rd in research_data_ids:
        for var_id, meta in variables[rd].items():
            instance_variables[var_id] = make_label(meta)

    id_mapping = {var_id: str(position) for position, var_id in enumerate(instance_variables)}
    return instance_variables, id_mapping


def get_qrels(df, mapping):
    """Build relevance judgements (column lists) for the sentences in *df*.

    Only rows with ``is_variable == 1`` contribute. Their ``variable`` value is split
    on ";". "unk" ids are skipped, and ids missing from *mapping* are reported and skipped.

    Args:
        df: frame with columns uuid, is_variable, variable.
        mapping: variable id -> corpus ``_id``.

    Returns:
        Dict of parallel lists ``query-id`` (row position as str), ``corpus-id``,
        ``score`` (always 1) and ``uuid`` (the row's uuid).
    """
    qrels = {"query-id": [], "corpus-id": [], "score": [], "uuid": []}

    for position in range(df.shape[0]):
        row = df.iloc[position]
        if row["is_variable"] != 1:
            continue

        for var_id in row["variable"].split(";"):
            if var_id == "unk":
                # Unknown variables have no corpus entry to link to.
                continue
            if var_id not in mapping:
                print(f"{var_id} not in mapping!")
                continue

            qrels["query-id"].append(str(position))
            qrels["corpus-id"].append(str(mapping[var_id]))
            qrels["score"].append(1)
            qrels["uuid"].append(row["uuid"])
    return qrels


def get_corpus(df, ivariables, mapping):
    """Turn candidate variables into BEIR corpus records.

    Args:
        df: unused; kept so the signature parallels get_queries/get_qrels.
        ivariables: variable uuid -> label text.
        mapping: variable uuid -> corpus ``_id``.
    """
    return [
        {"_id": mapping[var_uuid], "title": "", "text": label, "uuid": var_uuid}
        for var_uuid, label in ivariables.items()
    ]


def get_queries(df):
    """Turn each row of *df* into a BEIR query record.

    ``_id`` is the row's position in *df* (as str), ``text`` is the ``sentence``
    column and ``uuid`` is the ``uuid`` column.
    """
    return [
        {"_id": str(position), "text": row["sentence"], "uuid": row["uuid"]}
        for position, (_, row) in enumerate(df.iterrows())
    ]


def save_jsonl(data, path):
    """Write each record of *data* as one JSON line to *path*, overwriting it."""
    with jsonlines.open(path, "w") as jsonl_writer:
        for record in data:
            jsonl_writer.write(record)


def save_files(queries, corpus, qrels, data_dir):
    """Write one BEIR dataset under *data_dir*.

    Creates ``queries.jsonl``, ``corpus.jsonl`` and ``qrels/all.tsv``. The qrels dict
    of column lists is written as a tab-separated file with a header row.
    """
    os.makedirs(os.path.join(data_dir, "qrels"), exist_ok=True)

    for filename, records in (("queries.jsonl", queries), ("corpus.jsonl", corpus)):
        save_jsonl(records, os.path.join(data_dir, filename))

    pd.DataFrame(qrels).to_csv(os.path.join(data_dir, "qrels", "all.tsv"), index=False, sep="\t")

In [None]:
class GenericDataLoader:
    """Loader for the BEIR-style files written by ``save_files``.

    Expected layout under ``data_folder``::

        corpus.jsonl       one {"_id", "title", "text", "uuid"} object per line
        queries.jsonl      one {"_id", "text", "uuid"} object per line
        qrels/<split>.tsv  header row, then (query-id, corpus-id, score, uuid) rows

    Queries are keyed by their ``uuid``. Qrels are keyed by the query uuid (4th TSV
    column) and map the *uuid* of each relevant corpus entry to its integer score.
    """

    def __init__(self, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
                 qrels_folder: str = "qrels", qrels_file: str = ""):
        self.corpus = {}
        self.queries = {}
        self.qrels = {}

        self.query_corpus_mapping = {}

        if prefix:
            # e.g. prefix="x" -> "x-queries.jsonl" and "x-qrels"
            query_file = prefix + "-" + query_file
            qrels_folder = prefix + "-" + qrels_folder

        self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
        self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
        self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
        self.qrels_file = qrels_file

    @staticmethod
    def check(fIn: str, ext: str):
        """Raise ValueError unless *fIn* exists and ends with *ext*."""
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide accurate file.".format(fIn))

        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    def load_custom(self) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and qrels from the configured file paths (``qrels_file`` as given).

        Returns:
            (corpus keyed by corpus ``_id``, queries restricted to those with qrels, qrels).
        """
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")

        if not self.corpus:
            self._load_corpus()

        if not self.queries:
            self._load_queries()

        # check() above guarantees the qrels file exists.
        self._load_qrels()
        self.queries = {qid: self.queries[qid] for qid in self.qrels}

        return self.corpus, self.queries, self.qrels

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load ``qrels/<split>.tsv`` together with the corpus and queries.

        Returns:
            corpus: corpus uuid -> {"text", "title", "_id"} (re-keyed by uuid);
            queries: query uuid -> text, restricted to queries that have qrels;
            qrels: query uuid -> {corpus uuid: score}.

        NOTE: the corpus is re-keyed in place, and the re-keyed entries no longer carry
        "uuid". A second ``load`` on the same instance therefore fails with KeyError;
        use a fresh loader per load.
        """
        self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")

        if not self.corpus:
            self._load_corpus()

        if not self.queries:
            self._load_queries()

        # check() above guarantees the qrels file exists.
        self._load_qrels()
        self.queries = {qid: self.queries[qid] for qid in self.qrels}

        # Re-key the corpus by uuid so it lines up with the uuid-keyed qrels.
        self.corpus = {v["uuid"]: {"text": v["text"], "title": v["title"], "_id": k} for k, v in self.corpus.items()}

        return self.corpus, self.queries, self.qrels

    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        """Load (once) and return the corpus keyed by corpus ``_id``."""
        self.check(fIn=self.corpus_file, ext="jsonl")

        if not self.corpus:
            self._load_corpus()

        return self.corpus

    def _load_corpus(self):
        # corpus _id -> {"text", "title", "uuid"}
        with open(self.corpus_file, encoding='utf8') as fIn:
            for line in fIn:
                record = json.loads(line)
                self.corpus[record.get("_id")] = {
                    "text": record.get("text"),
                    "title": record.get("title"),
                    "uuid": record.get("uuid"),
                }

    def _load_queries(self):
        # query uuid -> text (the per-file "_id" is not used as the key)
        with open(self.query_file, encoding='utf8') as fIn:
            for line in fIn:
                record = json.loads(line)
                self.queries[record.get("uuid")] = record.get("text")

    def _load_qrels(self):
        # Columns as written by save_files: query-id, corpus-id, score, uuid.
        # Qrels are keyed by the query uuid and the relevant corpus entry's uuid.
        with open(self.qrels_file, encoding="utf-8", newline="") as fh:
            reader = csv.reader(fh, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
            next(reader, None)  # skip the header row (tolerates an empty file)

            for row in reader:
                query_id, corpus_id, score = row[3], row[1], int(row[2])
                self.qrels.setdefault(query_id, {})[self.corpus[corpus_id]["uuid"]] = score

In [None]:
class MultiDatasetDataLoader:
    """Yield batches drawn from several datasets, one dataset per batch.

    Each batch is filled from a single dataset, picked from a shuffled index list in
    which dataset ``i`` appears ``weights[i]`` times. Within a batch, an example is
    skipped if one of its normalised texts (``strip().lower()``) or its ``guid`` was
    already seen while building that batch. With ``allow_swap`` (True by default), the
    first two texts of an accepted example are swapped in place with probability 0.5.

    Args:
        datasets: list of lists of examples exposing ``.texts`` (list of str) and
            ``.guid``. Each list is shuffled in place.
        batch_size_pairs: batch size for datasets whose first example has two texts;
            also determines ``len(self)``.
        batch_size_triplets: batch size for other datasets (defaults to batch_size_pairs).
        dataset_size_temp: when > 0, weight = max(1, int(p ** (1 / temp) * 1000)) with p
            the dataset's share of all examples; otherwise every dataset gets weight 100.

    Attributes:
        collate_fn: optional callable applied to each batch before it is yielded.
            ``None`` (the default) yields the raw list of examples. May be assigned
            after construction.
    """

    def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1):
        self.allow_swap = True
        self.batch_size_pairs = batch_size_pairs
        self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets
        self.collate_fn = None

        # Compute dataset weights
        self.dataset_lengths = list(map(len, datasets))
        self.dataset_lengths_sum = sum(self.dataset_lengths)

        weights = []
        if dataset_size_temp > 0:  # Scale probability with dataset size
            for dataset in datasets:
                prob = len(dataset) / self.dataset_lengths_sum
                weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000)))
        else:  # Equal weighting of all datasets
            weights = [100] * len(datasets)

        self.dataset_idx = []
        self.dataset_idx_pointer = 0

        for idx, weight in enumerate(weights):
            self.dataset_idx.extend([idx] * weight)
        random.shuffle(self.dataset_idx)

        self.datasets = []
        for dataset in datasets:
            random.shuffle(dataset)
            self.datasets.append({
                'elements': dataset,
                'pointer': 0,
                # Number of distinct example objects; used to detect unfillable batches.
                'n_unique': len({id(example) for example in dataset}),
            })

    def __iter__(self):
        for _ in range(len(self)):
            # Select dataset (reshuffle the weighted index list once it is exhausted)
            if self.dataset_idx_pointer >= len(self.dataset_idx):
                self.dataset_idx_pointer = 0
                random.shuffle(self.dataset_idx)

            dataset_idx = self.dataset_idx[self.dataset_idx_pointer]
            self.dataset_idx_pointer += 1

            # Select batch from this dataset
            dataset = self.datasets[dataset_idx]
            batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets

            batch = []
            texts_in_batch = set()
            guid_in_batch = set()
            # ids of distinct examples rejected since the last accepted one. A rejected
            # example has already added its texts/guid to the sets above, and those sets
            # only grow, so it can never be accepted later in this batch. Once every
            # distinct example is in here, the batch cannot grow any further and is
            # yielded as-is. Without this check, the loop never terminates when a dataset
            # has fewer distinct texts than batch_size.
            rejected_since_accept = set()
            while len(batch) < batch_size:
                example = dataset['elements'][dataset['pointer']]

                valid_example = True
                # First check if one of the texts is already in the batch
                for text in example.texts:
                    text_norm = text.strip().lower()
                    if text_norm in texts_in_batch:
                        valid_example = False

                    texts_in_batch.add(text_norm)

                # If the example has a guid, check if guid is in batch
                if example.guid is not None:
                    valid_example = valid_example and example.guid not in guid_in_batch
                    guid_in_batch.add(example.guid)

                if valid_example:
                    if self.allow_swap and random.random() > 0.5:
                        example.texts[0], example.texts[1] = example.texts[1], example.texts[0]

                    batch.append(example)
                    rejected_since_accept.clear()
                else:
                    rejected_since_accept.add(id(example))

                dataset['pointer'] += 1
                if dataset['pointer'] >= len(dataset['elements']):
                    dataset['pointer'] = 0
                    random.shuffle(dataset['elements'])

                if len(rejected_since_accept) >= dataset['n_unique']:
                    break

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        # Batches per epoch, based on the pair batch size.
        return int(self.dataset_lengths_sum / self.batch_size_pairs)

In [None]:
class SentenceBERT:
    """BEIR-compatible encoder wrapper around SentenceTransformer models.

    Args:
        model_path: a model name/path (one encoder shared by queries and documents),
            a ``(query_model_path, doc_model_path)`` tuple (separate encoders), or an
            already-built ``SentenceTransformer`` instance.
        sep: string inserted between a document's title and text.

    Raises:
        TypeError: if *model_path* is none of the supported kinds.
    """

    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
        self.sep = sep

        if isinstance(model_path, str):
            self.q_model = SentenceTransformer(model_path)
            self.doc_model = self.q_model

        elif isinstance(model_path, tuple):
            self.q_model = SentenceTransformer(model_path[0])
            self.doc_model = SentenceTransformer(model_path[1])

        elif isinstance(model_path, SentenceTransformer):
            self.q_model = model_path
            self.doc_model = self.q_model

        else:
            # Previously fell through silently, failing later with AttributeError.
            raise TypeError(f"Unsupported model_path type: {type(model_path).__name__}")

    def encode_queries(self, queries: List[str], batch_size: int = 16, show_progress_bar: bool = False, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode query strings with the query model, honouring *show_progress_bar*."""
        return self.q_model.encode(queries, batch_size=batch_size, show_progress_bar=show_progress_bar, **kwargs)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, show_progress_bar: bool = False, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode documents as ``title + sep + text`` (stripped) with the document model.

        Documents without a "title" key are encoded from their stripped text alone.
        """
        sentences = [
            (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
            for doc in corpus
        ]
        return self.doc_model.encode(sentences, batch_size=batch_size, show_progress_bar=show_progress_bar, **kwargs)


def get_mean_scores(all_scores):
    """Average per-group metric scores, weighting each group by its size.

    Args:
        all_scores: group name -> {"score": {metric: value}, "n": group size}.

    Returns:
        metric -> weighted mean over all groups that report it.
    """
    expanded = defaultdict(list)
    for group_entry in all_scores.values():
        for metric, value in group_entry["score"].items():
            # Repeating the value n times gives an n-weighted mean below.
            expanded[metric] += [value] * group_entry["n"]

    return {metric: sum(values) / len(values) for metric, values in expanded.items()}


def get_metrics(metrics_str, mapping=None):
    """Translate metric specs into pytrec_eval measure strings.

    Each spec is ``"name"`` or ``"name@k1,k2,..."`` (name is case-insensitive), e.g.
    ``"map@10,30"`` -> ``{"map": "map_cut.10,30"}`` and ``"rprec"`` -> ``{"rprec": "Rprec"}``.

    Args:
        metrics_str: iterable of metric spec strings.
        mapping: short name -> pytrec_eval base measure; defaults to PYTREC_METRIC_MAPPING.

    Returns:
        short name -> pytrec_eval measure. Unsupported or malformed specs (unknown
        name, empty cutoff list, more than one "@") are reported and skipped. A repeated
        name keeps its last spec.
    """
    if mapping is None:
        mapping = PYTREC_METRIC_MAPPING

    metrics = {}
    for spec in metrics_str:
        name, at_sign, cutoffs = spec.partition("@")
        name = name.lower()

        if name not in mapping or "@" in cutoffs or (at_sign and not cutoffs):
            print(f"Invalid or unsupported metric: {spec}")
            continue

        metrics[name] = f"{mapping[name]}.{cutoffs}" if at_sign else mapping[name]
    return metrics


def _evaluate(
    qrels: Dict[str, Dict[str, int]], 
    results: Dict[str, Dict[str, float]], 
    measures: Dict[str, str]
    ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
        
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, [k for _,k in measures.items()])
        scores = evaluator.evaluate(results)

        metrics = {_k:0.0 for _k,__ in scores[list(scores.keys())[0]].items()}

        for _,query_scores in scores.items():
            for score_name,score in query_scores.items():
                metrics[score_name] += score

        for k,v in metrics.items():
            metrics[k] = round(v/len(scores),5)

        return metrics


def split_variables(row):
    """Parse ``"[id-x,id-y]"`` into ``[("id", "x"), ("id", "y")]``.

    The sentinel values "No" and "NoSkip" are returned as a one-element list.
    """
    if row in ["No", "NoSkip"]:
        return [row]

    unbracketed = row.replace('[', '').replace(']', '')
    return [tuple(item.split('-')) for item in unbracketed.split(',')]
    
    
def get_variables(row):
    """Return the first element of every entry in ``row.variables`` (the variable ids)."""
    variable_ids = []
    for entry in row.variables:
        variable_ids.append(entry[0])
    return variable_ids


def drop_unk_variables(df):
    """Strip "unk" ids from the ``variable`` column and drop rows left without variables.

    ``variable`` holds ";"-separated variable ids. Each kept row's value is rewritten
    without its "unk" ids. Rows whose value is not a string, or that contain only
    "unk" ids, are removed. The input frame is not modified and the original index
    is preserved.

    Returns:
        A new DataFrame with the same columns as *df*, even when every row is dropped.
    """
    def _known_variables(value):
        if not isinstance(value, str):
            return ""
        return ";".join(v for v in value.split(";") if v != "unk")

    cleaned = df["variable"].map(_known_variables)
    keep = (cleaned != "").to_numpy()

    filtered = df.loc[keep].copy()
    # Write the cleaned ids back to the column downstream code (get_qrels) reads.
    filtered["variable"] = cleaned[keep].to_numpy()
    return filtered


def make_label2(row, answer, join_str):
    """Join a variable's label, topic, question and *answer* with *join_str*.

    *row* may be any mapping with "label", "topic" and "question" keys. Falsy
    values (None, "", ...) are rendered as empty strings.
    """
    def _text(value):
        return value if value else ""

    return join_str.join([_text(row["label"]), _text(row["topic"]), _text(row["question"]), _text(answer)])


def get_labels(df, sep_answers=False, join_str="[UNK]"):
    """Build corpus labels for survey variables.

    Args:
        df: variable table with columns id, label, topic, question, answer. It is not
            modified; missing values are treated as "" on an internal copy.
        sep_answers: if True, split ``answer`` on ";" and emit one label per answer
            option (repeating the variable id); otherwise emit one label per row.
        join_str: separator passed to ``make_label2``.

    Returns:
        (ids, labels): parallel lists.
    """
    ids = []
    labels = []

    # Non-inplace fill: the caller's DataFrame keeps its NaNs.
    filled = df.fillna("")

    for _, row in filled.iterrows():
        v_id = row["id"] if row["id"] else ""

        if sep_answers:  # split survey variable answers into separate corpus items
            answers = row["answer"].split(";")
        else:  # do not split survey variable answers
            answers = [row["answer"] if row["answer"] else ""]

        for v_answer in answers:
            ids.append(v_id)
            labels.append(make_label2(row, v_answer, join_str))

    return ids, labels


def make_beir_data(data_df, beir_data_dir, labels, ids):
    """Write a BEIR dataset (queries.jsonl, corpus.jsonl, qrels/all.tsv) to *beir_data_dir*.

    Queries are the ``text`` column keyed by the DataFrame index (as str). Each query's
    relevant documents are ``"v" + id`` for every id from ``get_variables(row)``, except
    rows whose ``variable`` is "No" or "NoSkip", which get none. The corpus maps
    ``ids[i]`` -> ``labels[i]`` (a repeated id keeps its last label).

    Returns:
        (queries_beir, corpus_beir, qrels_beir): the records / columns that were written.
    """
    queries = {}
    qrels = {}
    for index, row in data_df.iterrows():
        query_id = str(index)
        queries[query_id] = row.text
        if row.variable in ["No", "NoSkip"]:
            qrels[query_id] = []
        else:
            qrels[query_id] = ["v" + var_id for var_id in get_variables(row)]

    corpus = {ids[position]: label for position, label in enumerate(labels)}

    qrels_dir = os.path.join(beir_data_dir, "qrels")
    os.makedirs(qrels_dir, exist_ok=True)

    queries_beir = [{"_id": qid, "text": text} for qid, text in queries.items()]
    corpus_beir = [{"_id": doc_id, "title": "", "text": text} for doc_id, text in corpus.items()]

    judged_pairs = [(qid, doc_id) for qid, doc_ids in qrels.items() for doc_id in doc_ids]
    qrels_beir = {
        "query-id": [qid for qid, _ in judged_pairs],
        "corpus-id": [doc_id for _, doc_id in judged_pairs],
        "score": [1] * len(judged_pairs),
    }

    qrels_frame = pd.DataFrame.from_records(qrels_beir)
    qrels_frame[["query-id", "corpus-id", "score"]].to_csv(os.path.join(beir_data_dir, "qrels", "all.tsv"), index=False, sep="\t")

    with jsonlines.open(os.path.join(beir_data_dir, "queries.jsonl"), "w") as writer:
        writer.write_all(queries_beir)

    with jsonlines.open(os.path.join(beir_data_dir, "corpus.jsonl"), "w") as writer:
        writer.write_all(corpus_beir)

    return queries_beir, corpus_beir, qrels_beir


def load_retriever(model_path, batch_size, score_function, k_values, retriever_type="dense", show_progress_bar=False):
    """Wrap the requested search backend in a BEIR ``EvaluateRetrieval``.

    Args:
        model_path: model name/path or SentenceTransformer instance (used by dense and
            sparse search; not passed to BM25).
        batch_size: encoding batch size for dense/sparse search.
        score_function: similarity passed to EvaluateRetrieval (e.g. "cos_sim").
        k_values: cutoffs passed to EvaluateRetrieval.
        retriever_type: "dense", "sparse" or "BM25-<language>" (e.g. "BM25-english").
        show_progress_bar: forwarded to dense search.

    Raises:
        ValueError: for an unsupported model_path type or retriever_type.
    """
    # Short-circuit keeps plain strings (including "BM25") from touching SentenceTransformer.
    if not (isinstance(model_path, str) or isinstance(model_path, SentenceTransformer)):
        raise ValueError(f"Unknown model type for model: {model_path}")

    if retriever_type == "dense":
        model = DRES(SentenceBERT(model_path), batch_size=batch_size, show_progress_bar=show_progress_bar)
    elif "BM25-" in retriever_type:
        language = retriever_type.split("-")[-1]
        # Talks to an Elasticsearch server on localhost.
        model = BM25(index_name="svident", hostname="localhost", language=language, initialize=True, number_of_shards=1)
    elif retriever_type == "sparse":
        model = SparseSearch(models.SPARTA(model_path), batch_size=batch_size)
    else:
        raise ValueError(f"Unknown retriever type: {retriever_type}")

    return EvaluateRetrieval(model, score_function=score_function, k_values=k_values)


def reduce_precision(results, p=5):
    """Round every score of a nested ``{query: {doc: score}}`` mapping to *p* decimals."""
    rounded = {}
    for query_id, doc_scores in results.items():
        rounded[query_id] = {doc_id: round(score, p) for doc_id, score in doc_scores.items()}
    return rounded


def eval(df, variables, retriever, metrics):
    """Run retrieval and scoring separately for every research-data group in *df*.

    For each ``research_data`` group with at least one candidate variable, the group is
    written as a BEIR dataset under ``./temp/beir/<group>``, reloaded, retrieved
    against, and scored with ``_evaluate``. Groups without candidates are skipped.

    Note: this shadows the builtin ``eval`` at module scope; the name is kept for callers.

    Returns:
        (all_results, all_scores), both keyed by research_data value.
        ``all_scores[name] = {"score": metric dict, "n": number of rows in the group}``.
    """
    all_results = {}
    all_scores = {}

    print("Iterating over groups...")
    for name, group in tqdm(df.groupby("research_data")):
        ivariables, mapping = get_instance_variables(group, variables)
        if not ivariables:
            continue  # no candidate variables -> nothing to rank for this group

        data_dir = os.path.join(".", "temp", "beir", name)
        save_files(
            get_queries(group),
            get_corpus(group, ivariables, mapping),
            get_qrels(group, mapping),
            data_dir,
        )

        corpus, queries, qrels = GenericDataLoader(data_folder=data_dir).load(split="all")

        results = reduce_precision(retriever.retrieve(corpus, queries, return_sorted=True))
        all_results[name] = results
        all_scores[name] = {"score": _evaluate(qrels, results, metrics), "n": group.shape[0]}

    return all_results, all_scores


def _infer_split_type(input_file):
    """Derive the output-file prefix from substrings of *input_file* (first match wins)."""
    # Order matters: "test.tsv" must be tested before "train" because the default
    # test file lives under a ".../train/" directory.
    for needle, split_type in (
        ("explicit", "explicit"),
        ("other", "other"),
        ("train_unreleased", "train_unreleased"),
        ("test.tsv", "full_test"),
        ("train", "train"),
        ("test", "test"),
    ):
        if needle in input_file:
            return split_type
    return datetime.now().strftime("%Y%m%d-%H%M%S")


def evaluate(
    input_files,
    variables_file,
    model_path,
    output_dir,
    batch_size,
    score_function,
    metrics,
    k_values,
    model_save_path,
    retriever_type,
    show_progress_bar,
    ):
    """Evaluate a retriever on tab-separated input files and optionally save the outputs.

    For each input file: rows without ``research_data`` are dropped, "unk" variables are
    removed, retrieval is run per research-data group (see ``eval``), and the group
    scores are averaged with each group weighted by its size.

    Args:
        input_files: TSV paths with at least the columns research_data, variable, lang
            (plus the columns read by ``eval``).
        variables_file: JSON variable metadata (see ``load_variables``).
        model_path, batch_size, score_function, k_values, retriever_type, show_progress_bar:
            forwarded to ``load_retriever``.
        output_dir: root directory for outputs; if falsy, scores are only printed.
        metrics: metric specs understood by ``get_metrics`` (e.g. "map@10").
        model_save_path: run name; only its basename is used when it is an existing path.

    Outputs (when ``output_dir`` is set), under ``output_dir/<langs>/<run name>/``:
        ``<type>_results.json``, ``<type>_scores.json``, and ``submission.json`` (the
        results of all groups merged) for the full test file or BM25/sparse runs.
    """
    assert os.path.exists(variables_file)
    variables = load_variables(variables_file)

    # Normalize metrics
    metrics = get_metrics(metrics)

    retriever = load_retriever(model_path, batch_size, score_function, k_values, retriever_type, show_progress_bar)

    for input_file in input_files:
        assert os.path.exists(input_file)
        df = pd.read_csv(input_file, sep="\t")
        df = df[df["research_data"].notna()]  # drop rows w/o research data
        df = drop_unk_variables(df)  # only evaluate on valid rows
        print(f"Input file contains {df.shape[0]} instances.")

        split_type = _infer_split_type(input_file)

        results, scores = eval(df, variables, retriever, metrics)
        mean_scores = get_mean_scores(scores)
        print(mean_scores)

        if not output_dir:
            continue

        # sorted(): the directory name must not depend on set iteration order.
        langs = "-".join(sorted(set(df["lang"].to_list())))
        run_name = os.path.basename(model_save_path) if os.path.exists(model_save_path) else model_save_path
        output_subdir = os.path.join(output_dir, langs, run_name)
        os.makedirs(output_subdir, exist_ok=True)

        with open(os.path.join(output_subdir, f"{split_type}_results.json"), "w") as fp:
            json.dump(results, fp)

        with open(os.path.join(output_subdir, f"{split_type}_scores.json"), "w") as fp:
            json.dump(mean_scores, fp)

        if split_type == "full_test" or retriever_type.split("-")[0] in ["BM25", "sparse"]:
            # Submission: results of all research-data groups merged into one mapping.
            submission = {k: v for rd in results for k, v in results[rd].items()}
            with open(os.path.join(output_subdir, "submission.json"), "w") as fp:
                json.dump(submission, fp)

        print(f"Saved outputs to: {output_subdir}")

In [None]:
# model_path: choose from BM25, local paths, or huggingface hub models
# retriever_type: choose from BM25-english, BM25-german, sparse, dense
# If using BM25, make sure that Elasticsearch is running in the background.
# The baseline in the paper used Elasticsearch 7.9.2 (https://www.elastic.co/downloads/past-releases/elasticsearch-7-9-2)

model_path="BM25"
retriever_type="BM25-english"
# dataset_names ... subset are not read by the evaluation loop in this cell.
dataset_names=None
dataset_indicies=None
num_epochs=1
batch_size_pairs=64
batch_size_triplets=32
max_seq_length=128
no_amp=False
warmup_steps=500
subset=None
seeds=[0]
eval_input_files=["./sv-ident/data/train/test.tsv"]
eval_variables_file="./sv-ident/data/train/variables_metadata.json"
eval_output_dir="./results/"
eval_batch_size=8
eval_score_function="cos_sim"
eval_metrics=["map@10,30,50", "p@1,5,10,30", "r@1,5,10,30", "rprec"]
eval_k_values=[1,5,10,20]
show_progress_bar=False


for seed in seeds:
    set_seeds(seed)  # for reproducibility

    if model_path != "BM25" and retriever_type == "dense":
        # The module-level name `models` is rebound to beir.retrieval.models by the
        # imports cell, so import sentence_transformers.models under its own name.
        from sentence_transformers import models as st_models

        word_embedding_model = st_models.Transformer(model_path, max_seq_length=max_seq_length)
        pooling_model = st_models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        model.show_progress_bar = False
    else:
        model = None

    # Run name: seed, model name and start time.
    model_save_path = f"seed={seed}_" + os.path.basename(model_path).replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Evaluate model
    evaluate(
        eval_input_files,
        eval_variables_file,
        model if model is not None else model_path,
        eval_output_dir,
        eval_batch_size,
        eval_score_function,
        eval_metrics,
        eval_k_values,
        os.path.basename(model_save_path),
        retriever_type,
        show_progress_bar,
    )