In [1]:
import pandas as pd
from transformers import (
    DPRContextEncoder,
    DPRQuestionEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoderTokenizer,
)
import torch.nn.functional as F
import torch
import numpy as np
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
# (To force CPU execution, replace the expression below with torch.device("cpu").)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  return torch._C._cuda_getDeviceCount() > 0


# Test

In [3]:
def test(question_encoder, context_encoder, top_k=5, data_path="test_data.json"):
    """Print, per test query, how many of the top-ranked passages are positives.

    Reads a JSON list from ``data_path``. Each entry is a dict whose first key
    is the query text and whose value holds two passage lists, ``"pos"`` and
    ``"neg"``. For every entry the query and all of its passages are embedded
    with the given encoders, the passages are ranked by cosine similarity to
    the query, and the number of ``"pos"`` passages among the ``top_k``
    best-ranked ones is printed.

    Args:
        question_encoder: Model called as ``model(input_ids, attention_mask)``
            whose output exposes ``pooler_output``; embeds the query.
        context_encoder: Same calling contract; embeds the passages.
        top_k: Number of highest-scoring passages inspected per query. Capped
            at the number of passages the entry actually provides.
        data_path: Path of the JSON test file.

    Returns:
        None. One ``Accuracy (<query>): <hits>/<inspected>`` line is printed
        per entry.

    Note:
        Uses the notebook-level ``device``; the encoders are expected to
        already be on that device.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        test_data = json.load(f)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base"
    )
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base"
    )
    for data in test_data:
        # The entry's (first) key doubles as the query text.
        label = next(iter(data))
        pos = data[label]["pos"]
        neg = data[label]["neg"]
        passages = pos + neg
        if not passages:
            # Nothing to rank for this query.
            continue

        # Tokenize the question and the context
        tokenized_question = question_tokenizer(
            label, return_tensors="pt", padding="max_length", max_length=512, truncation=True
        )
        question_input_ids = tokenized_question["input_ids"]
        question_attention_mask = tokenized_question["attention_mask"]

        tokenized_context = context_tokenizer(
            passages, return_tensors="pt", padding="max_length", max_length=512, truncation=True
        )
        context_input_ids = tokenized_context["input_ids"]
        context_attention_mask = tokenized_context["attention_mask"]

        # Encode the question and the context. This is evaluation only, so
        # gradient tracking is disabled to avoid building an autograd graph.
        with torch.no_grad():
            question_output = question_encoder(
                question_input_ids.to(device), question_attention_mask.to(device)
            ).pooler_output
            context_output = context_encoder(
                context_input_ids.to(device), context_attention_mask.to(device)
            ).pooler_output
            # The single question embedding broadcasts against every passage
            # embedding, giving one score per passage.
            scores = F.cosine_similarity(question_output, context_output)

        # torch.topk raises if k exceeds the number of scores, so cap it.
        k = min(top_k, len(passages))
        _, indices = torch.topk(scores, k)
        relevant_passages = [passages[i] for i in indices.tolist()]
        num_correct = sum(1 for p in relevant_passages if p in pos)
        print(f"Accuracy ({label}): {num_correct}/{len(relevant_passages)}")

In [4]:
# Load the pretrained DPR question and context encoders, move each one to
# `device`, then run the retrieval check on the pair.
question_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
question_encoder = question_encoder.to(device)

context_encoder = DPRContextEncoder.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)
context_encoder = context_encoder.to(device)

test(question_encoder, context_encoder)

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the

Accuracy (Java_Developer): 2/5
Accuracy (Web_Developer): 1/5
Accuracy (Software_Developer): 1/5
Accuracy (Python_Developer): 4/5
Accuracy (Front_End_Developer): 3/5
Accuracy (Network_Administrator): 3/5
Accuracy (Database_Administrator): 3/5
Accuracy (Project_manager): 2/5
Accuracy (Systems_Administrator): 2/5
Accuracy (Security_Analyst): 2/5
Accuracy (Java_Developer): 1/5
Accuracy (Web_Developer): 3/5
Accuracy (Software_Developer): 2/5
Accuracy (Python_Developer): 3/5
Accuracy (Front_End_Developer): 2/5
Accuracy (Network_Administrator): 4/5
Accuracy (Database_Administrator): 4/5
Accuracy (Project_manager): 2/5
Accuracy (Systems_Administrator): 3/5
Accuracy (Security_Analyst): 4/5
Accuracy (Java_Developer): 4/5
Accuracy (Web_Developer): 2/5
Accuracy (Software_Developer): 2/5
Accuracy (Python_Developer): 2/5
Accuracy (Front_End_Developer): 1/5
Accuracy (Network_Administrator): 2/5
Accuracy (Database_Administrator): 1/5
Accuracy (Project_manager): 1/5
Accuracy (Systems_Administrator): 3/