<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/ufidon/nlp/blob/main/qair.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
  </td>
  <td>
    <a target="_blank" href="https://kaggle.com/kernels/welcome?src=https://github.com/ufidon/nlp/blob/main/qair.ipynb"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" /></a>
  </td>
</table>
<br>

**Question Answering, Information Retrieval, and Retrieval Augmented Generation**

- 📝 SALP chapter 14

## Overview
- `Question answering (QA) systems` have supported human information needs since the 1960s, advancing over time with innovations like IBM’s Watson, which surpassed human performance on "Jeopardy!" in 2011.
  - They are integrated with search engines and modern large language models (LLMs), enhancing their ability to answer diverse queries.

- `Prompt-based methods` in modern QA use pretrained LLMs to directly answer fact-based questions from stored knowledge.
  - LLMs often produce `hallucinations`, confidently giving incorrect answers, especially in specialized domains, due to poor calibration.
  - `Simple prompting` also gives LLMs no access to proprietary or real-time data, since their knowledge cannot be updated after training.
  - `Retrieval-augmented generation (RAG)` overcomes these issues by combining `document retrieval` with LLMs, grounding answers in relevant, curated data and enabling responses based on proprietary information.
  - RAG uses `information retrieval techniques`, from tf-idf to neural models like BERT, to select relevant documents, enhancing answer reliability and context-awareness.

## Information Retrieval
- `Information retrieval (IR)` is the field concerned with retrieving media of all kinds based on a user's information need.
  - IR systems are often called `search engines`.
  - `Ad hoc retrieval` is a core IR task in which a user enters a query and the system returns an ordered set of relevant documents.

![The architecture of an ad hoc IR system](./images/qa/ir.png)

- Key components of IR include 
  - documents: text units like web pages or articles, 
  - collections: sets of documents, 
  - terms: words or phrases in documents, and 
  - queries: user-expressed needs.

- IR systems commonly use the `vector space` model, 
  - representing documents and queries as vectors based on word counts and ranking them by cosine similarity, 
  - following a bag-of-words approach where word positions are disregarded.

### Term weighting and document scoring
- IR scores a document’s match with a query by assigning term weights to each word, commonly using weighting schemes:
  - [tf-idf](https://scikit-learn.org/1.5/modules/feature_extraction.html): term-frequency — inverse document-frequency
  - [BM25](https://huggingface.co/blog/xhluca/bm25s): Best Match 25

- **Tf-idf** is the product of term frequency (**tf**) and inverse document frequency (**idf**): 
  - **Term frequency (tf)** measures word frequency in a document, using a logarithmic scale to avoid overemphasis on high counts:
    - $tf_{t, d} = 
    \begin{cases} 
      1 + \log_{10}(\text{count}(t, d)) & \text{if count}(t, d) > 0 \\
      0 & \text{otherwise}
   \end{cases}$

  - **Inverse Document Frequency (idf)** measures how informative a term is across documents, `decreasing` for terms appearing in many documents:
    - $idf_t = \log_{10} \dfrac{N}{df_t}$
    where $N$ is the total number of documents, and $df_t$ is the number of documents containing term $t$.
- The **tf-idf** score for a word $t$ in document $d$ is computed by multiplying `tf` and `idf`:
  - $\text{tf-idf}(t, d) = tf_{t, d} \cdot idf_t$
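
The three formulas above can be sketched in a few lines of plain Python. This is a minimal illustration, not a library implementation: the toy collection is invented, and returning 0 for a term that appears in no document is an assumption (the idf formula is undefined when $df_t = 0$).

```python
import math
from collections import Counter


def tf(term, doc_tokens):
    """Log-scaled term frequency: 1 + log10(count) if count > 0, else 0."""
    count = Counter(doc_tokens)[term]
    return 1 + math.log10(count) if count > 0 else 0.0


def idf(term, docs):
    """Inverse document frequency: log10(N / df_t)."""
    df = sum(1 for d in docs if term in d)
    # Assumption: a term seen in no document gets weight 0.
    return math.log10(len(docs) / df) if df > 0 else 0.0


def tf_idf(term, doc_tokens, docs):
    """tf-idf weight of `term` in one document of the collection."""
    return tf(term, doc_tokens) * idf(term, docs)


# Toy collection (made up for illustration), already tokenized.
docs = [
    "sweet sweet nurse love".split(),
    "sweet sorrow".split(),
    "how sweet is love".split(),
    "nurse".split(),
]

print(f"sweet in doc 0: {tf_idf('sweet', docs[0], docs):.3f}")
print(f"love  in doc 0: {tf_idf('love', docs[0], docs):.3f}")
```

Although "sweet" occurs twice in document 0, it appears in three of the four documents, so its idf pulls its weight below that of the rarer "love".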

### Document scoring
- Document scoring uses the **cosine similarity** between the query vector $𝐪$ and document vector $𝐝$, computed as:
  - $\text{score}(𝐪, 𝐝) = \cos(𝐪, 𝐝) = \dfrac{𝐪 \cdot 𝐝}{|𝐪||𝐝|}$

  - Cosine similarity can also be viewed as the **dot product of unit vectors** by normalizing both vectors to unit length before calculation:
    - $\text{score}(𝐪, 𝐝) = \cos(𝐪, 𝐝) = \left(\dfrac{𝐪}{|𝐪|}\right) \cdot \left(\dfrac{𝐝}{|𝐝|}\right)$
  - In `tf-idf`:
    - $\text{score}(𝐪, 𝐝) = \sum_{t \in 𝐪} \frac{\text{tf-idf}(t, 𝐪)}{\sqrt{\sum_{q_i \in 𝐪} \text{tf-idf}^2(q_i, 𝐪)}} \cdot \frac{\text{tf-idf}(t, 𝐝)}{\sqrt{\sum_{d_i \in 𝐝} \text{tf-idf}^2(d_i, 𝐝)}}$
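
As a rough sketch of the scoring formula above, the snippet below builds sparse tf-idf vectors as dictionaries and ranks a toy collection by cosine similarity. The documents and query are invented, and skipping query terms that never occur in the collection is an assumption made to keep idf defined.

```python
import math
from collections import Counter


def tfidf_vector(tokens, docs):
    """Map each term in `tokens` to its tf-idf weight w.r.t. the collection."""
    n_docs = len(docs)
    vec = {}
    for term, count in Counter(tokens).items():
        df = sum(1 for d in docs if term in d)
        if df == 0:
            continue  # assumption: terms unseen in the collection are skipped
        vec[term] = (1 + math.log10(count)) * math.log10(n_docs / df)
    return vec


def cosine(u, v):
    """Cosine similarity between two sparse vectors stored as dicts."""
    dot = sum(w * v.get(t, 0.0) for t, w in u.items())
    norm_u = math.sqrt(sum(w * w for w in u.values()))
    norm_v = math.sqrt(sum(w * w for w in v.values()))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


docs = [
    "sweet sweet nurse love".split(),
    "sweet sorrow".split(),
    "how sweet is love".split(),
    "nurse".split(),
]
query = "sweet love".split()

q_vec = tfidf_vector(query, docs)
scores = [(cosine(q_vec, tfidf_vector(d, docs)), i) for i, d in enumerate(docs)]
ranking = sorted(scores, reverse=True)  # best match first
for score, i in ranking:
    print(f"doc {i}: {score:.3f}")
```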

- Simplifications and **variants of tf-idf cosine** scoring are used in practice, such as omitting the **idf term** in the document to enhance performance.

  - **BM25** is a more advanced variant of tf-idf that adds two parameters to refine scoring:
    - **k** balances term frequency against IDF
    - **b** controls the strength of document length normalization

- **Stop lists** (lists of high-frequency words like "the," "a") were traditionally used to exclude common terms from queries and documents, reducing index size; 
  - modern IR systems now rarely use stop lists due to advancements in IDF weighting and processing efficiency.
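
To make the roles of **k** and **b** concrete, here is a sketch of one common formulation of BM25, $\sum_{t \in q} \log_{10}\frac{N}{df_t} \cdot \frac{tf_{t,d}}{k\,(1 - b + b\,|d|/|d_{avg}|) + tf_{t,d}}$, where $tf_{t,d}$ is the raw count. The defaults `k=1.2` and `b=0.75` are commonly cited values assumed here, and the toy collection is invented.

```python
import math
from collections import Counter


def bm25_score(query, doc, docs, k=1.2, b=0.75):
    """Score one tokenized document against a tokenized query with BM25."""
    n_docs = len(docs)
    avg_len = sum(len(d) for d in docs) / n_docs
    counts = Counter(doc)
    # Longer-than-average documents get a larger denominator when b > 0.
    length_norm = k * (1 - b + b * len(doc) / avg_len)
    score = 0.0
    for term in query:
        df = sum(1 for d in docs if term in d)
        if df == 0:
            continue  # assumption: terms unseen in the collection are skipped
        idf = math.log10(n_docs / df)
        tf = counts[term]  # raw count, saturated by the denominator
        score += idf * tf / (length_norm + tf)
    return score


docs = [
    "sweet sweet nurse love".split(),
    "sweet sorrow".split(),
    "how sweet is love".split(),
    "nurse".split(),
]
query = ["sweet"]

# Documents 1 and 2 both contain "sweet" once, but document 1 is shorter.
print(bm25_score(query, docs[1], docs), bm25_score(query, docs[2], docs))
# With b=0 length normalization is switched off and the two scores coincide.
print(bm25_score(query, docs[1], docs, b=0), bm25_score(query, docs[2], docs, b=0))
```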


### [Inverted Index](https://nlp.stanford.edu/IR-book/html/htmledition/a-first-take-at-building-an-inverted-index-1.html)
- In information retrieval, the goal is to find documents containing query terms, ignoring those with none of the terms.

- An `inverted index` structure, with a dictionary and postings lists, efficiently identifies relevant documents and stores term frequencies and positions.

- The dictionary links terms to postings lists, which provide document IDs and term-related data for fast score computation.

- Alternatives, such as bigram indexing and hashing, can improve efficiency in tasks like finding Wikipedia pages for question answering.
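
A minimal sketch of the dictionary-plus-postings idea, assuming tokenized documents and an in-memory dict (real systems use compressed on-disk structures). The function names `build_inverted_index` and `candidates` are made up for this example.

```python
from collections import defaultdict


def build_inverted_index(docs):
    """Return {term: {doc_id: [positions]}} for a list of tokenized documents."""
    index = defaultdict(dict)
    for doc_id, tokens in enumerate(docs):
        for pos, term in enumerate(tokens):
            index[term].setdefault(doc_id, []).append(pos)
    return index


def candidates(index, query):
    """Ids of documents that contain at least one query term."""
    found = set()
    for term in query:
        found.update(index.get(term, {}))  # postings list for this term
    return sorted(found)


docs = [
    "sweet sweet nurse love".split(),
    "sweet sorrow".split(),
    "how sweet is love".split(),
    "nurse".split(),
]
index = build_inverted_index(docs)

print(index["sweet"])                            # postings with positions
print(len(index["sweet"][0]))                    # term frequency of "sweet" in doc 0
print(candidates(index, ["nurse", "sorrow"]))    # only these docs need scoring
```

The term frequency is the length of a positions list and the document frequency is the length of a postings dict, so tf-idf or BM25 scores can be computed without rescanning the documents.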

## Information Retrieval with Dense Vectors
- Traditional IR methods like tf-idf and BM25 have the **Vocabulary Mismatch Problem**
  - They require `exact word overlap` between query and document, which limits retrieval when `synonyms` are used.

  - `Dense embeddings`, such as those generated by BERT, address this issue by encoding semantic meaning, rather than relying on exact word matches.

- **Single encoder (a)**: The query and document are fed into a BERT model, with a linear layer on the [CLS] token to predict similarity, allowing context-sensitive matching.
  - $\text{score}(𝐪,𝐝) = \text{softmax}(𝐔(𝐳))$
    - $𝐳 = \text{BERT}(𝐪;[\text{SEP}];𝐝)[\text{CLS}]$
  - Documents are split into short passages (e.g., 100 tokens) to fit within BERT's 512-token limit, and the model is fine-tuned on a relevance dataset for better retrieval accuracy.
  - `Single encoder` uses a `full BERT encoder` for each query-document pair, which is computationally expensive, 
    - as it requires encoding each document in the collection with every new query.

![Two ways to do dense retrieval](./images/qa/qdbert.png)


- **Bi-Encoder (b)** encodes each document only once and precomputes document vectors, enabling `fast` query processing by computing dot products between the query vector and precomputed document vectors.
  - $\text{score}(𝐪,𝐝) = 𝐳_𝐪 ⋅ 𝐳_𝐝$
    - $𝐳_𝐪 = \text{BERT}_Q(𝐪)[\text{CLS}]$
    - $𝐳_𝐝 = \text{BERT}_D(𝐝)[\text{CLS}]$
  - However, it sacrifices some accuracy by not capturing detailed token-level interactions between the query and document.
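
The bi-encoder scoring step can be sketched without a real model. The vectors below are invented stand-ins for $𝐳_𝐝$ and $𝐳_𝐪$ (a real system would obtain them from the two BERT encoders); only the dot-product ranking is illustrated.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Stand-ins for z_d = BERT_D(d)[CLS], computed once offline (values invented).
doc_vectors = {
    "doc_a": [0.9, 0.1, 0.0],
    "doc_b": [0.2, 0.8, 0.1],
    "doc_c": [0.0, 0.3, 0.9],
}


def rank(query_vector, doc_vectors):
    """Score every precomputed document vector against the query vector."""
    scored = [(dot(query_vector, z_d), doc_id) for doc_id, z_d in doc_vectors.items()]
    return sorted(scored, reverse=True)


# Stand-in for z_q = BERT_Q(q)[CLS]; only this vector is computed at query time.
z_q = [0.1, 0.9, 0.2]
ranked = rank(z_q, doc_vectors)
print([doc_id for _, doc_id in ranked])
```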

- Numerous approaches between the full encoder and the bi-encoder balance accuracy and efficiency by 
  - using cheaper ranking methods (e.g., BM25) to initially rank documents 
  - then applying costly BERT scoring to rerank only the top-ranked ones.
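
The retrieve-then-rerank pattern above might look like the following sketch. Both scorers are hypothetical stand-ins: `overlap_score` plays the role of a cheap ranker such as BM25, and `pretend_bert_score` plays the role of an expensive cross-encoder and records how often it is called.

```python
def retrieve_then_rerank(query, docs, cheap_score, costly_score, k=2):
    """Rank all docs with cheap_score, then rerank only the top k with costly_score."""
    first_pass = sorted(docs, key=lambda d: cheap_score(query, d), reverse=True)
    shortlist = first_pass[:k]
    return sorted(shortlist, key=lambda d: costly_score(query, d), reverse=True)


def overlap_score(query, doc):
    """Cheap stand-in scorer: number of distinct shared words."""
    return len(set(query.split()) & set(doc.split()))


calls = []  # documents seen by the expensive scorer


def pretend_bert_score(query, doc):
    """Expensive stand-in scorer: shared words normalized by document length."""
    calls.append(doc)
    return overlap_score(query, doc) / len(doc.split())


docs = [
    "my cat likes cat food and dog food too",
    "cat food",
    "stock market news",
    "dog food",
]
reranked = retrieve_then_rerank("cat food", docs, overlap_score, pretend_bert_score, k=2)
print(reranked)
print(f"expensive scorer called {len(calls)} times for {len(docs)} documents")
```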

- [ColBERT (Contextualized Late Interaction over BERT)](https://huggingface.co/colbert-ir/colbertv2.0) encodes the query and document separately into token-level representations, allowing it to pre-store document representations for faster scoring.

- ![A sketch of the ColBERT algorithm at inference time](./images/qa/colbert.png)

  - It computes relevance by summing the maximum similarity between each query token and the most contextually similar document token, highlighting `token-level contextual similarity`.
  - Query and document tokens are processed with BERT, with special `[Q]` and `[D]` tokens prepended; the output vectors pass through a linear layer that shrinks them for storage efficiency and are then scaled to unit length.
  - It requires end-to-end training, fine-tuning BERT encoders and linear layers using positive and negative document pairs, optimizing relevance scoring through cross-entropy loss.
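
The ColBERT relevance score described above, $\text{score}(q, d) = \sum_{i} \max_{j} \mathbf{E}_{q_i} \cdot \mathbf{E}_{d_j}$, can be sketched as follows. The two-dimensional unit vectors are invented stand-ins for the token embeddings; no BERT model is involved.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_vecs, doc_vecs):
    """Sum, over query tokens, of the best similarity to any document token."""
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


# Invented unit-length token vectors (one vector per token).
query_vecs = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.6, 0.8]]   # stored ahead of time
doc_b = [[0.0, 1.0]]               # stored ahead of time

print(maxsim_score(query_vecs, doc_a))
print(maxsim_score(query_vecs, doc_b))
```

Because the document token vectors do not depend on the query, they can be stored in advance, and only the cheap max and sum run at query time.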

- Training data for supervised methods like ColBERT often includes labeled positive and negative passages, with semi-supervised approaches or iterative methods used when labeled data is sparse.

- Efficient ranking of dense vectors uses `approximate nearest neighbor search algorithms` (e.g., [Faiss](https://huggingface.co/docs/datasets/en/faiss_es)) to quickly find the most similar document vectors for a query.

## Answering Questions with Retrieval-Augmented Generation
- The retrieval-augmented generation (RAG) approach to question answering involves 
  - retrieving relevant text segments (retriever) 
  - generating an answer based on these documents (reader).

![Retrieval-based question answering](./images/qa/rag.png)

- This two-stage model first uses dense retrievers to find supportive passages and then employs a large language model to generate the answer from the retrieved content, one token at a time.

### Retrieval-Augmented Generation
- RAG uses retrieved passages to help a language model generate answers, addressing limitations of simple conditional generation.
  - simple autoregressive language modeling: 
    - $\displaystyle p(x_1,⋯,x_n)=∏_{i=1}^n p(x_i|[Q:]; q; [A:]; x_{<i})$ 
- In RAG, a language model is conditioned on both the question and retrieved passages to reduce issues like hallucinations and provide evidence-based answers.
  - Simple prompting can work for basic fact-based questions, but RAG `improves reliability` by using `specific prompts` like "Based on these texts, answer this question."
  - $\displaystyle p(x_1,⋯,x_n)=∏_{i=1}^n p(x_i|R(q);\text{prompt};[Q:]; q; [A:]; x_{<i})$ 
- Effective RAG requires a well-performing retriever, often utilizing a `two-stage process` that ranks retrieved passages to improve accuracy.
  - For complex questions, RAG may use `multi-hop retrieval`, combining initial retrieval results with follow-up searches for more context.
- `Privacy and prompt engineering` are important considerations, especially in combining private and public data, with ongoing research focusing on refining the integration of retrieval and generation stages.
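
A sketch of how the conditioning context $R(q);\text{prompt};[Q:]; q; [A:]$ could be assembled as text. The word-overlap `retrieve` function is a toy stand-in for a real retriever, and the exact prompt layout is an assumption; the resulting string would then be passed to a language model, which is not shown.

```python
import re


def tokens(text):
    """Lowercased word tokens with punctuation removed."""
    return set(re.findall(r"\w+", text.lower()))


def retrieve(question, collection, k=2):
    """Toy retriever R(q): top-k passages by word overlap with the question."""
    q_terms = tokens(question)
    ranked = sorted(collection, key=lambda p: len(q_terms & tokens(p)), reverse=True)
    return ranked[:k]


def build_rag_prompt(question, passages):
    """Concatenate retrieved passages, an instruction, and the question."""
    lines = [f"retrieved passage {i}: {p}" for i, p in enumerate(passages, start=1)]
    lines.append("Based on these texts, answer this question.")
    lines.append(f"Q: {question}")
    lines.append("A:")
    return "\n".join(lines)


collection = [
    "On the Origin of Species was published in 1859.",
    "Charles Darwin wrote On the Origin of Species.",
    "Paris is the capital of France.",
]
question = "Who wrote the Origin of Species?"

retrieved = retrieve(question, collection)
prompt = build_rag_prompt(question, retrieved)
print(prompt)
```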

### Question Answering Datasets
- `Question-answering (QA) datasets` support both training and evaluating language models.  They vary by purpose, 
  - targeting natural information-seeking `questions`
  - `probing` system knowledge and reasoning.
- Natural question datasets include 
  - [Natural Questions](https://huggingface.co/datasets/google-research-datasets/natural_questions): based on Google queries, with answers derived from Wikipedia
  - [MS MARCO (Microsoft Machine Reading Comprehension)](https://huggingface.co/datasets/microsoft/ms_marco): Bing queries with human-generated answers and passages.
- Non-English natural question datasets include 
  - [DuReader](https://paperswithcode.com/dataset/dureader): Chinese search engine queries
  - [TyDi QA](https://huggingface.co/datasets/google-research-datasets/tydiqa): questions in 11 languages from Wikipedia passages.
- Probing datasets, such as [MMLU (Massive Multitask Language Understanding)](https://paperswithcode.com/dataset/mmlu), assess knowledge across multiple fields like medicine and computer science, using questions sourced from exams.
- Some datasets augment questions with passages, enabling reading comprehension tasks that require extracting answers from provided texts.

- QA tasks vary: 
  - `open book` QA uses retrieval-augmented generation, 
  - `closed book` QA answers directly from the model without retrieval.
- Answer formats differ, including multiple-choice and freeform, affecting model requirements for generating or selecting responses.

- Prompting styles in QA, impacting model performance on complex questions, can be 
  - zero-shot (only the question) 
  - few-shot (with examples).
- `MMLU` supports both zero-shot and few-shot prompting, making it suitable for testing various model capabilities.
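
The difference between zero-shot and few-shot prompting for a multiple-choice question can be sketched as plain string building. The layout below is an assumption for illustration, not the official MMLU template, and the arithmetic items are invented.

```python
def format_item(question, choices, answer=None):
    """Render one multiple-choice item; leave the answer blank if not given."""
    letters = "ABCD"
    lines = [question]
    lines += [f"({letters[i]}) {c}" for i, c in enumerate(choices)]
    lines.append("Answer:" + (f" ({answer})" if answer else ""))
    return "\n".join(lines)


def build_prompt(target, exemplars=()):
    """Zero-shot if `exemplars` is empty, few-shot otherwise."""
    blocks = [format_item(q, c, a) for q, c, a in exemplars]
    blocks.append(format_item(*target))  # the question to be answered
    return "\n\n".join(blocks)


exemplars = [
    ("What is 2 + 3?", ["4", "5", "6", "7"], "B"),
    ("What is 10 / 2?", ["2", "4", "5", "8"], "C"),
]
target = ("What is 3 * 3?", ["6", "9", "12", "27"])

zero_shot = build_prompt(target)
few_shot = build_prompt(target, exemplars)
print(few_shot)
```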

## Evaluating Question Answering
- **Exact Match**: For multiple-choice questions (like in MMLU), QA systems are evaluated by the `percentage of answers that exactly match the correct answer`.

- **Token F1 Score**: For free-text questions (e.g., Natural Questions), token F1 is used to measure `partial overlap between the predicted answer and the reference`, treating them as bags of tokens and averaging scores across all questions.

- **Mean Reciprocal Rank (MRR)**: For systems providing ranked answers, MRR evaluates the rank of the first correct answer, scoring each question based on `the reciprocal of this rank` and averaging scores across questions.
  - If no correct answer is given, the score for that question is zero, with some versions of MRR excluding these zero-score questions from the final calculation.
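
The three metrics can be sketched in plain Python. Lowercasing and whitespace tokenization are simplifying assumptions (official evaluation scripts usually apply more normalization), and this MRR version keeps zero-score questions in the average.

```python
from collections import Counter


def exact_match(pred, gold):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(pred.strip().lower() == gold.strip().lower())


def token_f1(pred, gold):
    """F1 over the bags of tokens of the prediction and the reference."""
    pred_toks, gold_toks = pred.lower().split(), gold.lower().split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def mean_reciprocal_rank(ranked_lists, gold_answers):
    """Average of 1/rank of the first correct answer (0 if none is correct)."""
    total = 0.0
    for ranked, gold in zip(ranked_lists, gold_answers):
        for rank, answer in enumerate(ranked, start=1):
            if answer == gold:
                total += 1.0 / rank
                break
    return total / len(ranked_lists)


print(exact_match("Paris", "paris "))
print(token_f1("the denver broncos", "denver broncos"))
print(mean_reciprocal_rank([["a", "b"], ["x", "y", "z"], ["p"]], ["b", "x", "q"]))
```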

- 🏃 Practice [Question answering from HuggingFace NLP](https://huggingface.co/learn/nlp-course/en/chapter7/7?fw=pt)

In [None]:
# 1. Install required libraries
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

In [None]:
# 2. Load [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)
# Note: the "squad" dataset ID is SQuAD 1.1; SQuAD 2.0 is published separately as "squad_v2"
from datasets import load_dataset
raw_datasets = load_dataset("squad")

In [None]:
# features of the dataset
raw_datasets

In [None]:
# 3. check the first element
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])
# In the answers field, `text` holds the answer string(s),
# and `answer_start` holds the starting character index of each answer in the context.

In [None]:
# there is only one possible answer in each training record
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

In [None]:
# but there may be several answers in each evaluation record
print(raw_datasets["validation"][0]["context"])
print(raw_datasets["validation"][0]["question"])
print(raw_datasets["validation"][0]["answers"])

In [None]:
print(raw_datasets["validation"][2]["context"])
print(raw_datasets["validation"][2]["question"])
print(raw_datasets["validation"][2]["answers"])

In [None]:
# 4. Processing the training data
# 4.1 prepare a [fast tokenizer](https://huggingface.co/docs/transformers/index)
from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# make sure the tokenizer is really fast
tokenizer.is_fast

In [None]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

# 4.2 The tokenizer will properly insert the special tokens to form a sentence like this:
# [CLS] question [SEP] context [SEP]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])
# The labels are the indices of the tokens that start and end the answer,
# and the model is tasked with predicting one start logit and one end logit
# per token in the input.

In [None]:
# 4.2 (continued) handle long contexts with a sliding window
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

# The example is split into four inputs 
# each contains the question and some part of the context 
# some do not contain the answer
#   the labels will be start_position = end_position = 0 
#   so we predict the [CLS] token
# nonzero (start_position, end_position) if the answer is in the context

In [None]:
# 4.3 overflow_to_sample_mapping and offset mappings
# https://huggingface.co/learn/nlp-course/chapter6/4
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
inputs.keys()

In [None]:
inputs["overflow_to_sample_mapping"]

In [None]:
# tokenizing several examples at once makes the mapping more informative
inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

In [None]:
# 4.4 find the answer
answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

In [None]:
# answer in the context
idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

In [None]:
# answer not in the context
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

decoded_example = tokenizer.decode(inputs["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")

In [None]:
# 4.5 put all preprocess together
max_length = 384
stride = 128


def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length", #!!!!
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [None]:
# 4.6 apply the preprocessing to the whole training set

train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)
# the preprocessing added roughly 1,000 samples,
# i.e. only a small share of the contexts are longer than max_length and get split

In [None]:
# 5. Processing the validation data
# preprocessing function
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

# apply the preprocessing function
validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)
# a couple of hundred samples are added
# i.e. most of the contexts in the validation set are short

In [None]:
# 6. Fine-tuning the model with the Trainer API
# 6.1 a taste on a small validation set
small_eval_set = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
import torch
from transformers import AutoModelForQuestionAnswering

# remove columns not needed by the model
eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

In [None]:
# format conversion
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

In [None]:
import collections

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)

In [None]:
# 6.2 pick the answer with the best logit score
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

    best_answer = max(answers, key=lambda x: x["logit_score"])
    predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})

In [None]:
# 6.3 evaluate
import evaluate

metric = evaluate.load("squad")

In [None]:
# theoretical/reference answer format
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set
]

In [None]:
# check the predicted answer
print(predicted_answers[0])
print(theoretical_answers[0])

In [None]:
# check the score
metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [None]:
# put all together `compute_metrics`
from tqdm.auto import tqdm


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [None]:
# evaluate our predictions
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

In [None]:
# 7. Fine-tuning the model
# 7.1 load a pretrained model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

In [None]:
# 7.2 prepare TrainingArguments
from transformers import TrainingArguments

args = TrainingArguments(
    "bert-finetuned-squad",
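    # No evaluation during training. This is the default, so the argument is omitted:
    # its name differs across transformers versions (evaluation_strategy vs. eval_strategy).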
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    push_to_hub=False,
)

In [None]:
# 7.3 prepare the Trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)

In [None]:
# launch the training:
trainer.train()

In [None]:
# evaluate the model
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets["validation"])

In [None]:
# 8. A custom training loop
# 8.1 Preparing everything for training
from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataset.set_format("torch")
validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    validation_set, collate_fn=default_data_collator, batch_size=8
)

In [None]:
# reinstantiate our model, 
# starting from the BERT pretrained model again:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

In [None]:
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=2e-5)

In [None]:
from accelerate import Accelerator

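# mixed precision is selected with `mixed_precision`; the old `fp16=True` flag is no longer accepted
accelerator = Accelerator(mixed_precision="fp16")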
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
output_dir = "./"

In [None]:
# 8.2 Training loop
from tqdm.auto import tqdm
import torch

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    start_logits = []
    end_logits = []
    accelerator.print("Evaluation!")
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(validation_dataset)]
    end_logits = end_logits[: len(validation_dataset)]

    metrics = compute_metrics(
        start_logits, end_logits, validation_dataset, raw_datasets["validation"]
    )
    print(f"epoch {epoch}:", metrics)

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        print(f"Training in progress epoch {epoch}")

In [None]:
from transformers import pipeline

# Load the model checkpoint from the current folder
model_checkpoint = "./"
question_answerer = pipeline("question-answering", model=model_checkpoint)

context = """
🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
between them. It's straightforward to train your models with one before loading them for inference with the other.
"""
question = "Which deep learning libraries back 🤗 Transformers?"
print(question_answerer(question=question, context=context))