# Extractive Question Answering with Qdrant

In this notebook, we're going to build a small Extractive Question Answering system, a classic Natural Language Processing (NLP) task.

Question Answering systems can respond to user queries with precise answers. 'Extractive' means our system will pull the answer directly from a given context, rather than generating new text. It's like having your own personal librarian who knows every book cover to cover and can pull the perfect quote for any question you ask!

To make our 'AI Librarian', we will be using three main components:
1. **Qdrant**: Powers our performant vector search. It's our magic bookshelf that finds the right book.
2. **Retriever Model**: Embeds context passages and questions into numerical representations (vectors) that Qdrant can store and search efficiently.
3. **Reader Model**: Once Qdrant finds the most relevant passages for a question, our reader model goes through these passages to extract the precise answer.

## Install dependencies

Let's get started by installing prerequisite packages:

In [None]:
!pip install -qU datasets==2.12.0 qdrant-client==1.2.0 sentence-transformers==2.2.2 torch==2.0.1

### Import libraries

In [None]:
import torch
from datasets import load_dataset
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm.auto import tqdm
from typing import List

## Load and process dataset

We'll use the [DuoRC dataset](https://huggingface.co/datasets/duorc), containing questions, plots and answers crowd-sourced from Wikipedia and IMDb movie plots.

We generate embeddings for the context passages using the retriever, index them in the Qdrant vector database, and query to retrieve the top k most relevant contexts containing potential answers to our question. We then use the reader model to extract the answers from the returned contexts.

We load the dataset into a pandas dataframe, keep only the title and plot columns, and drop duplicate plots.

In [None]:
# load the duorc dataset into a pandas dataframe
df = load_dataset("duorc", "ParaphraseRC", split="train").to_pandas()
df = df[["title", "plot"]]  # select only title and plot column
print(f"Before removing duplicates: {len(df)}")

df = df.drop_duplicates(subset="plot")  # drop rows containing duplicate plot passages, if any
print(f"Unique Plots: {len(df)}")
df.head()

Before removing duplicates: 69524
Unique Plots: 5133


| | title | plot |
|---|---|---|
| 0 | Ghosts of Mars | Set in the second half of the 22nd century, Ma... |
| 15 | Noriko's Dinner Table | The film starts on December 12th, 2001 with a ... |
| 34 | Gutterballs | A brutally sadistic rape leads to a series of ... |
| 83 | An Innocent Man | Jimmie Rainwood (Tom Selleck) is a respected m... |
| 105 | The Sorcerer's Apprentice | Every hundred years, the evil Morgana (Kelly L... |
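The drop from 69,524 rows to 5,133 is consistent with each plot appearing once per question asked about it. As a rough illustration of what `drop_duplicates(subset="plot")` does (keep the first row for each distinct plot), here is a plain-Python sketch. The rows and the `drop_duplicate_plots` helper are made up for this example and are not part of the notebook.

```python
# Toy rows (invented): the same plot repeats once per question.
rows = [
    {"title": "Movie A", "plot": "Plot of A", "question": "Q1"},
    {"title": "Movie A", "plot": "Plot of A", "question": "Q2"},
    {"title": "Movie B", "plot": "Plot of B", "question": "Q3"},
]


def drop_duplicate_plots(rows):
    """Keep the first row for each distinct plot, preserving order."""
    seen = set()
    unique = []
    for row in rows:
        if row["plot"] not in seen:
            seen.add(row["plot"])
            # keep only the two columns the notebook uses
            unique.append({"title": row["title"], "plot": row["plot"]})
    return unique


unique_rows = drop_duplicate_plots(rows)
print(f"Before removing duplicates: {len(rows)}")
print(f"Unique Plots: {len(unique_rows)}")
```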


## Initialize Qdrant client
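The Qdrant collection stores vector representations of our context passages, which we can retrieve using another vector (the query vector). Passing `":memory:"` to `QdrantClient` runs Qdrant locally in memory, so no separate server is needed for this notebook.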

In [None]:
client = QdrantClient(":memory:")

## Create collection

Now we create a new collection called `extractive-question-answering`. The name is arbitrary; we can call the collection anything we want.

We specify the metric type as "cosine" and dimension or size as 384 because the retriever we use to generate context embeddings is optimized for cosine similarity and outputs 384-dimension vectors.

In [None]:
collection_name = "extractive-question-answering"

collections = client.get_collections()
print(collections)

# only create collection if it doesn't exist
if collection_name not in [c.name for c in collections.collections]:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,
            distance=models.Distance.COSINE,
        ),
    )
collections = client.get_collections()
print(collections)

collections=[]
collections=[CollectionDescription(name='extractive-question-answering')]


## Initialize retriever

Next, we need to initialize our retriever. The retriever will mainly do two things:

- Generate embeddings for all context passages (context vectors/embeddings)
- Generate embeddings for our questions (query vector/embedding)

The retriever will generate embeddings in a way that the questions and context passages containing answers to our questions are nearby in the vector space. We can use cosine similarity to calculate the similarity between the query and context embeddings to find the context passages that contain potential answers to our question.
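To make the similarity measure concrete, here is a minimal sketch of cosine similarity in plain Python. The three-dimensional vectors are invented for illustration; the real embeddings have 384 dimensions and come from the retriever.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Invented toy vectors: one passage points the same way as the query,
# the other is orthogonal to it.
query = [1.0, 0.0, 1.0]
close_passage = [2.0, 0.0, 2.0]
unrelated_passage = [0.0, 3.0, 0.0]

sim_close = cosine_similarity(query, close_passage)
sim_unrelated = cosine_similarity(query, unrelated_passage)
print(sim_close, sim_unrelated)
```

Cosine similarity ignores vector length and compares direction only, which is why the longer `close_passage` still scores as a near-perfect match.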

### Embedding model

We will use a SentenceTransformer model named `multi-qa-MiniLM-L6-cos-v1` as our retriever. It is designed for semantic search and was trained on 215M (question, answer) pairs from diverse sources. It's also quite competitive on two embedding and retrieval benchmarks: [MTEB](https://github.com/embeddings-benchmark/mteb) and [BEIR](https://arxiv.org/abs/2104.08663).

In [None]:
# set device to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the retriever model from huggingface model hub
retriever = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device=device)
retriever

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

## Generate Embeddings -> Store in Qdrant

Next, we need to generate embeddings for the context passages. We use `retriever.encode` for that.

When passing the documents to Qdrant, each document needs:
1. an id (a unique integer value),
2. a context embedding, and
3. a payload. The payload is a dictionary containing data relevant to the embedding, such as the title and plot of the context passage.

In [None]:
%%time

batch_size = 512  # specify batch size according to your RAM and compute, higher batch size = more RAM usage

for index in tqdm(range(0, len(df), batch_size)):
    i_end = min(index + batch_size, len(df))  # find end of batch
    batch = df.iloc[index:i_end]  # extract batch
    emb = retriever.encode(batch["plot"].tolist()).tolist()  # generate embeddings for batch
    meta = batch.to_dict(orient="records")  # get metadata
    ids = list(range(index, i_end))  # create unique IDs

    # upsert to qdrant
    client.upsert(
        collection_name=collection_name,
        points=models.Batch(ids=ids, vectors=emb, payloads=meta),
    )

collection_vector_count = client.get_collection(collection_name=collection_name).vectors_count
print(f"Vector count in collection: {collection_vector_count}")
assert collection_vector_count == len(df)
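The loop above uses row positions as ids, so the batch boundaries must cover every row exactly once. A small sketch to check that slicing logic in isolation, using the 5,133 unique plots and the batch size of 512 from this notebook (the `batch_ranges` helper is hypothetical and only mirrors the `index`/`i_end` arithmetic):

```python
def batch_ranges(total, batch_size):
    """Yield (start, end) index pairs that cover range(total) in order."""
    for start in range(0, total, batch_size):
        yield start, min(start + batch_size, total)


ranges = list(batch_ranges(5133, 512))
covered = sum(end - start for start, end in ranges)

print(f"Number of batches: {len(ranges)}")
print(f"Last batch: {ranges[-1]}")
print(f"Rows covered: {covered}")
```

The final batch is shorter than the others, which is why the loop clamps `i_end` with `min`.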

## Initialize Reader

We use the `bert-large-uncased-whole-word-masking-finetuned-squad` model from the Hugging Face model hub as our reader model. It is fine-tuned on the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/), which trains it to predict where an answer starts and ends within a given context. This span prediction is why we can use the model to extract answers from our context passages.

This is our reader component, which uses the retrieved contexts to extract an answer.
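As a simplified illustration of span extraction, the sketch below picks the start and end token whose combined score is highest, with the end not before the start. The tokens and scores are invented, and `best_span` is a hypothetical helper. The actual `transformers` pipeline is more involved (subword tokens, score normalization, long-context handling), so treat this as a sketch of the idea only.

```python
def best_span(start_scores, end_scores, max_answer_len=15):
    """Return ((start, end), score) for the highest-scoring valid span."""
    best = (0, 0)
    best_score = float("-inf")
    for s, s_score in enumerate(start_scores):
        # only consider ends at or after the start, within the length limit
        for e in range(s, min(s + max_answer_len, len(end_scores))):
            score = s_score + end_scores[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best, best_score


# Invented tokens and per-token scores for illustration.
tokens = ["the", "college", "is", "imperial", "college", "of", "engineering", "."]
start_scores = [0.1, 0.2, 0.0, 5.0, 0.3, 0.0, 0.2, 0.0]
end_scores = [0.0, 0.1, 0.0, 0.2, 0.5, 0.0, 4.0, 0.1]

(start, end), score = best_span(start_scores, end_scores)
answer = " ".join(tokens[start : end + 1])
print(answer, score)
```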

In [None]:
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"

# load the reader model into a question-answering pipeline
reader = pipeline("question-answering", model=model_name, tokenizer=model_name)
print(reader.model, reader)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

Now all the components we need are ready. Let's write some helper functions to execute our queries. The `get_relevant_plot` function retrieves the context passages most likely to contain the answer to our question from the Qdrant collection, and the `extract_answer` function extracts the answers from these context passages.

## Get context

The `get_relevant_plot()` function is your librarian for the stories stored in Qdrant.

You give it a question and the number of top matches you want. It encodes the question into a query vector with the retriever, asks Qdrant for the closest stored vectors, and returns the titles and plots of those matches.
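To show the shape of that lookup without a running collection, here is a toy brute-force version. The `points` list and the `toy_search` helper are hypothetical stand-ins for the Qdrant collection and `client.search`. Because the vectors are unit length (the retriever ends with a `Normalize()` layer), a plain dot product gives the cosine similarity.

```python
import heapq

# Hypothetical stand-in for the collection: (unit vector, payload) pairs.
points = [
    ([1.0, 0.0], {"title": "Movie A", "plot": "Plot of A"}),
    ([0.0, 1.0], {"title": "Movie B", "plot": "Plot of B"}),
    ([0.6, 0.8], {"title": "Movie C", "plot": "Plot of C"}),
]


def toy_search(query_vector, top_k):
    """Return [title, plot] pairs for the top_k most similar points."""

    def score(point):
        vector, _ = point
        # dot product equals cosine similarity for unit-length vectors
        return sum(q * v for q, v in zip(query_vector, vector))

    top = heapq.nlargest(top_k, points, key=score)
    return [[payload["title"], payload["plot"]] for _, payload in top]


context = toy_search([0.0, 1.0], top_k=2)
print(context)
```

Qdrant performs the same ranking, but with an index built for large collections instead of scanning every point.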

In [None]:
def get_relevant_plot(question: str, top_k: int) -> List[List[str]]:
    """
    Get the most relevant plots for a given question

    Args:
        question (str): What do we want to know?
        top_k (int): Top K results to return

    Returns:
        context (List[List[str]]): [title, plot] pairs, most similar first;
            empty if the search fails
    """
    try:
        encoded_query = retriever.encode(question).tolist()  # generate embeddings for the question

        result = client.search(
            collection_name=collection_name,
            query_vector=encoded_query,
            limit=top_k,
        )  # search qdrant collection for context passage with the answer

        context = [
            [x.payload["title"], x.payload["plot"]] for x in result
        ]  # extract title and plot from each hit's payload
        return context

    except Exception as e:
        print(e)  # print the error itself, not a set containing it
        return []  # keep the declared return type so callers can still iterate

## Extracting an answer

Here is how the engine operates:

1. The central part is the `extract_answer` function. It takes your question together with the context passages that `get_relevant_plot` retrieved from Qdrant.

2. Each context passage is passed to the `reader`, which extracts the span that best answers your question.

3. The function sorts all answers by confidence score, highest first, and attaches the movie title to each answer to show where it came from.

4. The result is a ranked list of candidate answers with their confidence scores and titles, printed in order.

That's it! Pass in a question and its retrieved context, and you get an ordered list of the best candidate answers. The advantage of this engine is that it also tells you where each answer came from and how confident the reader is in it.
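The ranking steps can be tried without loading the reader model. In the sketch below, `fake_reader` is a hypothetical stand-in that returns canned answers and scores, and `rank_answers` is an illustrative variant that returns the sorted list rather than printing it.

```python
def fake_reader(question, context):
    """Hypothetical stand-in for the reader: a canned answer per plot."""
    canned = {
        "Plot of A": {"answer": "answer from A", "score": 0.12},
        "Plot of B": {"answer": "answer from B", "score": 0.79},
    }
    return dict(canned[context])  # copy so callers can add keys safely


def rank_answers(question, context, reader_fn):
    """Read each [title, plot] pair and sort the answers by score."""
    results = []
    for title, plot in context:
        answer = reader_fn(question=question, context=plot)
        answer["title"] = title  # record where the answer came from
        results.append(answer)
    return sorted(results, key=lambda a: a["score"], reverse=True)


toy_context = [["Movie A", "Plot of A"], ["Movie B", "Plot of B"]]
ranked = rank_answers("toy question", toy_context, fake_reader)
for position, item in enumerate(ranked, start=1):
    print(position, item["answer"], item["title"], item["score"])
```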

In [None]:
def extract_answer(question: str, context: List[List[str]]):
    """
    Extract the answer from the context for a given question

    Args:
        question (str): The question to answer
        context (List[List[str]]): [title, plot] pairs, as returned by get_relevant_plot
    """
    results = []
    for c in context:
        # feed the reader the question and the plot text to extract an answer
        answer = reader(question=question, context=c[1])

        # add the title to the answer dict so both are printed together
        answer["title"] = c[0]
        results.append(answer)

    # sort the result based on the score from reader model
    sorted_result = sorted(results, key=lambda x: x["score"], reverse=True)
    for i, result in enumerate(sorted_result, start=1):
        print(f"{i}", end=" ")
        print(
            "Answer: ",
            result["answer"],
            "\n  Title: ",
            result["title"],
            "\n  score: ",
            result["score"],
        )


question = "In the movie 3 Idiots, what is the name of the college where the main characters Rancho, Farhan, and Raju study"
context = get_relevant_plot(question, top_k=1)
context

[['Three Idiots',

As we can see, the retriever is working fine and gets us the context passage that contains the answer to our question. Now let's use the reader to extract the exact answer from the context passage.

In [None]:
extract_answer(question, context)

1 Answer:  Imperial College of Engineering 
  Title:  Three Idiots 
  score:  0.9049272537231445


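The reader model extracted the correct answer from the context passage with a confidence score of about 0.90. Note that this score is the model's confidence in the extracted span, not an accuracy measurement. Let's run a few more queries.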

In [None]:
question = "Who hates Harry Potter?"
context = get_relevant_plot(question, top_k=1)
extract_answer(question, context)

1 Answer:  . 
  Title:  Harry Potter and the Half-Blood Prince 
  score:  0.15585105121135712


This might look like a simple question, but it's actually a pretty tough one for our model. The answer is not explicitly mentioned in the context passage, yet the reader still returns a span: here just `.`, with a low score of about 0.16.
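One way to handle such cases is to discard answers whose score falls below a cutoff. The sketch below applies this to the two results obtained so far. The 0.5 threshold and the `confident_answers` helper are assumptions chosen for illustration, not part of the original pipeline.

```python
MIN_SCORE = 0.5  # assumed cutoff, chosen only for illustration

# Results copied from the two queries run so far in this notebook.
answers = [
    {
        "answer": "Imperial College of Engineering",
        "title": "Three Idiots",
        "score": 0.9049272537231445,
    },
    {
        "answer": ".",
        "title": "Harry Potter and the Half-Blood Prince",
        "score": 0.15585105121135712,
    },
]


def confident_answers(answers, min_score=MIN_SCORE):
    """Keep only answers whose reader score reaches min_score."""
    return [a for a in answers if a["score"] >= min_score]


kept = confident_answers(answers)
for item in kept:
    print(item["answer"], item["title"], item["score"])
```

A suitable threshold depends on the data and on how costly a wrong answer is, so it is worth tuning on a held-out set of questions.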

In [None]:
question = "Who wants to kill Harry Potter?"
context = get_relevant_plot(question, top_k=1)
extract_answer(question, context)

1 Answer:  Lord Voldemort 
  Title:  Harry Potter and the Philosopher's Stone 
  score:  0.9568217992782593


In [None]:
question = "In the movie The Shawshank Redemption, what was the item that Andy Dufresne used to escape from Shawshank State Penitentiary?"
context = get_relevant_plot(question, top_k=1)
extract_answer(question, context)

1 Answer:  rock hammer 
  Title:  The Shawshank Redemption 
  score:  0.8666210770606995


Let's run another question, this time asking the retriever for the top 3 context passages.

In [None]:
question = "who killed the spy"
context = get_relevant_plot(question, top_k=3)
extract_answer(question, context)

1 Answer:  Soviet agents 
  Title:  Tinker, Tailor, Soldier, Spy 
  score:  0.7920866012573242
2 Answer:  Gila 
  Title:  Our Man Flint 
  score:  0.12037214636802673
3 Answer:  Gabriel's assassins 
  Title:  Live Free or Die Hard 
  score:  0.06259559094905853


### Cleaning up

We delete the collection from Qdrant to release the resources it holds. In a production environment you would normally keep the collection; the step is included here for completeness.

In [None]:
client.delete_collection(collection_name=collection_name)

True