# Dense Retrieval for Relation Extraction

## 1. Load Corpus

In [17]:
# load entities
import json

with open('../data/relations.json', 'r') as fin:
    relations = json.load(fin)

print(len(relations), 'relations')

8913 relations


In [18]:
# a subset of questions from https://github.com/askplatypus/wikidata-simplequestions
from beir.datasets.data_loader import GenericDataLoader

data_path = '/ivi/ilps/personal/svakule/spoken_qa/'
dataset = 'WD18_relations/'
split = 'valid'

query_path = data_path + dataset + "%s_original.jsonl" % split  # original text questions
# query_path = data_path + dataset + "%s_wav2vec2-base-960h.jsonl" % split  # questions transcribed from synthethised speech

qrels_path = data_path + dataset + "%s.tsv" % split
corpus_path = data_path + dataset + "relations.jsonl"


corpus, queries, qrels = GenericDataLoader(
    corpus_file=corpus_path, 
    query_file=query_path, 
    qrels_file=qrels_path).load_custom()

## 2. Evaluate with BEIR

In [19]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models

In [20]:
# TAS-B trained on original corrupted WD18
model_name = '/ivi/ilps/personal/svakule/msmarco/output/msmarco-distilbert-base-tas-b-WD18-original-corrupted'
model = DRES(models.SentenceBERT(model_name))
retriever = EvaluateRetrieval(model, score_function="dot")
results1 = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results1, retriever.k_values)
# acc = precision['P@1']
print(ndcg, _map, recall, precision)

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

{'NDCG@1': 0.03978, 'NDCG@3': 0.06112, 'NDCG@5': 0.07255, 'NDCG@10': 0.08704, 'NDCG@100': 0.1335, 'NDCG@1000': 0.17667} {'MAP@1': 0.03978, 'MAP@3': 0.05582, 'MAP@5': 0.0622, 'MAP@10': 0.06808, 'MAP@100': 0.07594, 'MAP@1000': 0.07728} {'Recall@1': 0.03978, 'Recall@3': 0.07646, 'Recall@5': 0.10408, 'Recall@10': 0.14942, 'Recall@100': 0.38623, 'Recall@1000': 0.74196} {'P@1': 0.03978, 'P@3': 0.02549, 'P@5': 0.02082, 'P@10': 0.01494, 'P@100': 0.00386, 'P@1000': 0.00074}


In [21]:
# TAS-B trained on original WD18
model_name = '/ivi/ilps/personal/svakule/msmarco/output/msmarco-distilbert-base-tas-b-WD18'
model = DRES(models.SentenceBERT(model_name))
retriever = EvaluateRetrieval(model, score_function="dot")
results1 = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results1, retriever.k_values)
# acc = precision['P@1']
print(ndcg, _map, recall, precision)

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

{'NDCG@1': 0.04081, 'NDCG@3': 0.06281, 'NDCG@5': 0.07195, 'NDCG@10': 0.08323, 'NDCG@100': 0.12757, 'NDCG@1000': 0.17606} {'MAP@1': 0.04081, 'MAP@3': 0.05719, 'MAP@5': 0.06225, 'MAP@10': 0.06688, 'MAP@100': 0.07436, 'MAP@1000': 0.07588} {'Recall@1': 0.04081, 'Recall@3': 0.07914, 'Recall@5': 0.1014, 'Recall@10': 0.13644, 'Recall@100': 0.36294, 'Recall@1000': 0.76175} {'P@1': 0.04081, 'P@3': 0.02638, 'P@5': 0.02028, 'P@10': 0.01364, 'P@100': 0.00363, 'P@1000': 0.00076}


In [22]:
# TAS-B winner!
model = DRES(models.SentenceBERT("msmarco-distilbert-base-tas-b"))
retriever = EvaluateRetrieval(model, score_function="dot")
results1 = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results1, retriever.k_values)
acc = precision['P@1']
# print(acc)
print(ndcg, _map, recall, precision)

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

{'NDCG@1': 0.29761, 'NDCG@3': 0.38768, 'NDCG@5': 0.40999, 'NDCG@10': 0.43107, 'NDCG@100': 0.47815, 'NDCG@1000': 0.49608} {'MAP@1': 0.29761, 'MAP@3': 0.36535, 'MAP@5': 0.37782, 'MAP@10': 0.38648, 'MAP@100': 0.3953, 'MAP@1000': 0.39594} {'Recall@1': 0.29761, 'Recall@3': 0.45239, 'Recall@5': 0.50618, 'Recall@10': 0.57152, 'Recall@100': 0.80132, 'Recall@1000': 0.94394} {'P@1': 0.29761, 'P@3': 0.1508, 'P@5': 0.10124, 'P@10': 0.05715, 'P@100': 0.00801, 'P@1000': 0.00094}


In [23]:
# DistilBERT v3 cosine
model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"))
retriever = EvaluateRetrieval(model, score_function="cos_sim")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
precision['P@1']

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.17766

In [24]:
# DistilBERT dot
model = DRES(models.SentenceBERT("msmarco-distilbert-base-dot-prod-v3"))
retriever = EvaluateRetrieval(model, score_function="dot")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
precision['P@1']

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.1488

In [25]:
# DistilBERT v2
model = DRES(models.SentenceBERT("msmarco-distilbert-base-v2"))
retriever = EvaluateRetrieval(model, score_function="cos_sim")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
precision['P@1']

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.20342

In [26]:
# ANCE
model = DRES(models.SentenceBERT("msmarco-roberta-base-ance-fristp"))
retriever = EvaluateRetrieval(model, score_function="dot")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
precision['P@1']

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

Batches:   0%|          | 0/70 [00:00<?, ?it/s]

0.16941

## 3. Error Analysis

In [16]:
import random

top_k = 10
n_times = 1000

for _ in range(n_times):
    query_id, ranking_scores = random.choice(list(results1.items()))
    scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
    correct_id = list(qrels[query_id].keys())[0]
    correct_label = relations[correct_id]
    for rank in range(top_k):
        doc_id = scores_sorted[rank][0]
        if doc_id == correct_id and rank==0:
            break
        if corpus[doc_id].get("text") == correct_label and rank==0:
            break
        else:
            if rank==0:
                print("\nQuery : %s, %s" % (queries[query_id], correct_label))
            print("Rank %d: %s - %s" % (rank+1, doc_id, corpus[doc_id].get("text")))
            if doc_id == correct_id:
                print('Correct!')
#             break

KeyError: 'Q4877206'