<a href="https://colab.research.google.com/github/rahiakela/getting-started-with-google-bert/blob/main/4-part-1-bert-variants/1_extracting_embeddings_with_ALBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Extracting embeddings with ALBERT

One of the challenges with BERT is its size. BERT-base has 110 million parameters, which makes it expensive to train and slow at inference time. Increasing the model size generally improves results, but it also demands more computational resources. ALBERT was introduced to address this. ALBERT is a lite version of BERT with fewer parameters. It uses the following two techniques to reduce the number of parameters:

- Cross-layer parameter sharing
- Factorized embedding layer parameterization

By using the preceding two techniques, we can reduce the training time and inference time of the BERT model.
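
To make the second technique concrete, here is a rough parameter count for the embedding layer. It is a sketch under assumed sizes: the values of `vocab_size`, `hidden_size`, and `embedding_size` are illustrative and do not come from this notebook. The idea is that one large vocabulary-by-hidden matrix is replaced by a vocabulary-by-embedding matrix followed by a small embedding-by-hidden projection.

```python
# Illustrative sizes (assumptions, not taken from this notebook).
vocab_size = 30_000    # V: number of tokens in the vocabulary
hidden_size = 768      # H: hidden size of the encoder layers
embedding_size = 128   # E: smaller factorized embedding dimension

# Unfactorized embedding: a single V x H matrix.
direct_params = vocab_size * hidden_size

# Factorized embedding: a V x E matrix followed by an E x H projection.
factorized_params = vocab_size * embedding_size + embedding_size * hidden_size

print(f"direct:     {direct_params:,}")
print(f"factorized: {factorized_params:,}")
print(f"reduction:  {direct_params / factorized_params:.1f}x")
```

The saving grows with the vocabulary size, because the large `V` dimension is multiplied by the small `E` instead of `H`.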


### Cross-layer parameter sharing

## Setup

In [None]:
%%capture
!pip install torch==1.4.0
!pip install nlp==0.4.0
!pip install transformers==3.5.1

In [None]:
import torch
import numpy as np

from transformers import BertForQuestionAnswering, BertTokenizer, Trainer, TrainingArguments
from nlp import load_dataset

## Question-answering

In a question-answering task, we are given a question along with a paragraph containing an answer to the question. Our goal is to extract the answer from the paragraph for the given question.

The input to the BERT model will be a question-paragraph pair. That is, we feed a question and a paragraph containing the answer to the question to BERT and it has to extract the answer from the paragraph. So, essentially, BERT has to return the text span that contains the answer from the paragraph. 

Let's understand this with an example – consider the following question-paragraph pair:

```
Question = "What is the immune system?"

Paragraph = "The immune system is a system of many biological structures and processes within an organism that protects against disease. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from
the organism's own healthy tissue."
```

Now, our model has to extract an answer from the paragraph; it essentially has to return the text span containing the answer. So, it should return the following:

```
Answer = "a system of many biological structures and processes within an organism that protects against disease"
```

To do this, our model has to identify the starting and ending index of the text span containing the answer in the given paragraph. For example, take the question, "What is the immune system?" If our model determines that the answer starts at index 4 ("a") and ends at index 18 ("disease"), counting the paragraph's words from zero, then we can get the answer as shown here:

```
Paragraph = "The immune system is **a system of many biological structures and processes within an organism that protects against disease**. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism's own healthy tissue."
```
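
The idea of a span defined by a start and an end position can be sketched at the word level. This is a simplification: it splits on whitespace and counts words from zero, whereas the model itself works on WordPiece token positions in the combined question-paragraph sequence. The variable names are chosen for illustration only.

```python
paragraph = (
    "The immune system is a system of many biological structures and "
    "processes within an organism that protects against disease. To function "
    "properly, an immune system must detect a wide variety of agents, known as "
    "pathogens, from viruses to parasitic worms, and distinguish them from the "
    "organism's own healthy tissue."
)
answer = (
    "a system of many biological structures and processes within an organism "
    "that protects against disease"
)

# Split on whitespace and strip trailing punctuation so "disease." matches "disease".
words = [w.strip(".,") for w in paragraph.split()]
answer_words = answer.split()

# Find the first position where the answer's words appear consecutively.
n = len(answer_words)
start = next(i for i in range(len(words) - n + 1) if words[i:i + n] == answer_words)
end = start + n - 1  # inclusive end index

print(start, end)
print(" ".join(words[start:end + 1]))
```

Once the two positions are known, recovering the answer is a single slice, which is why the model only needs to predict a start and an end.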

If we get the probability of each token (word) in the paragraph being the starting and ending token of the answer, then we can easily extract the answer. But how can we achieve this? We use two vectors called the start vector $S$ and the end vector $E$. The values of the start and end vectors are learned during training.

First, we compute the probability of each token (word) in the paragraph being the starting token of the answer.

To compute this probability, for each token $i$, we compute the dot product between the representation of the token $R_i$ and the start vector $S$. Next, we apply the softmax function to the dot product $S.R_i$ and obtain the probability:

$$ P_i = \frac{e^{S.R_i}}{\sum_j{e^{S.R_j}}} $$

Next, we compute the starting index by selecting the index of the token that has the highest probability of being the starting token.

In a very similar fashion, we compute the probability of each token (word) in the paragraph being the ending token of the answer. To compute this probability, for each token $i$, we compute the dot product between the representation of the token $R_i$ and the end vector $E$.

Next, we apply the softmax function to the dot product $E.R_i$ and obtain the probability:

$$ P_i = \frac{e^{E.R_i}}{\sum_j{e^{E.R_j}}} $$

Next, we compute the ending index by selecting the index of the token that has the highest probability of being the ending token. Now, we can select the text span that contains the answer using the starting and ending index.
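
The two softmax formulas and the index selection can be sketched in plain Python. Everything numeric here is made up for illustration: the token representations `R` are toy 2-dimensional vectors, and `S` and `E` are hand-picked rather than learned.

```python
import math

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def softmax(scores):
    """Softmax over a list of scores (shifted by the max for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy token representations R_i (assumed values, not real BERT outputs).
tokens = ["the", "immune", "system", "is", "a", "system"]
R = [[0.1, 0.2], [0.0, 0.3], [0.2, 0.1], [0.1, 0.0], [2.0, 0.1], [0.3, 1.5]]
S = [1.0, 0.0]  # start vector (hand-picked for the example)
E = [0.0, 1.0]  # end vector (hand-picked for the example)

# Probability of each token being the start / end of the answer.
start_probs = softmax([dot(S, r) for r in R])
end_probs = softmax([dot(E, r) for r in R])

# Pick the most probable start and end positions.
start_index = max(range(len(R)), key=start_probs.__getitem__)
end_index = max(range(len(R)), key=end_probs.__getitem__)

print(start_index, end_index, tokens[start_index:end_index + 1])
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw dot products gives the same indexes as taking the argmax of the probabilities.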

As shown in the following figure, we first tokenize the question-paragraph pair and feed the tokens to the pre-trained BERT model, which returns the embeddings of all the tokens. Here, $R_1$ to $R_N$ denote the embeddings of the tokens in the question, and $R'_1$ to $R'_M$ denote the embeddings of the tokens in the paragraph.

After computing the embedding, we compute the dot product with the start/end vectors, apply the softmax function, and obtain the probabilities of each token in the paragraph being the start/end word as shown here:

<img src='https://github.com/rahiakela/img-repo/blob/master/getting-started-with-google-bert/fine-tuning-question-answering.png?raw=1' width='800'/>

We can see how we compute the probability of each token in the paragraph being the start/end word. Next, we select the text span containing the answer using the starting and ending indexes with the highest probability.



## Loading the model

We use the `bert-large-uncased-whole-word-masking-finetuned-squad` model, which is fine-tuned on the **Stanford Question Answering Dataset (SQuAD)**:

In [None]:
model = BertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

Next, we download and load the tokenizer.

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

Now that we have downloaded the model and tokenizer, let's preprocess the input.

## Preprocessing the dataset

First, we define the input to BERT, which is the question and paragraph text:

In [None]:
question = "What is the immune system?"
paragraph = "The immune system is a system of many biological structures and processes within an organism that protects against disease. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism's own healthy tissue."

Add a `[CLS]` token to the beginning of the question and a `[SEP]` token to the end of both the question and the paragraph:

In [None]:
question = "[CLS] " + question + "[SEP]"
paragraph = paragraph + "[SEP]"

Now, tokenize the question and paragraph:

In [None]:
question_tokens = tokenizer.tokenize(question)
paragraph_tokens = tokenizer.tokenize(paragraph)

Combine the question and paragraph tokens and convert them to `input_ids`:

In [None]:
tokens = question_tokens + paragraph_tokens
input_ids = tokenizer.convert_tokens_to_ids(tokens)
print(input_ids)

[101, 2054, 2003, 1996, 11311, 2291, 1029, 102, 1996, 11311, 2291, 2003, 1037, 2291, 1997, 2116, 6897, 5090, 1998, 6194, 2306, 2019, 15923, 2008, 18227, 2114, 4295, 1012, 2000, 3853, 7919, 1010, 2019, 11311, 2291, 2442, 11487, 1037, 2898, 3528, 1997, 6074, 1010, 2124, 2004, 26835, 2015, 1010, 2013, 18191, 2000, 26045, 16253, 1010, 1998, 10782, 2068, 2013, 1996, 15923, 1005, 1055, 2219, 7965, 8153, 1012, 102]


Next, we define `segment_ids`, which tells the model which tokens belong to the question and which belong to the paragraph. It is 0 for all the tokens of the question and 1 for all the tokens of the paragraph:

In [None]:
segment_ids = [0] * len(question_tokens)
segment_ids += [1] * len(paragraph_tokens)
print(segment_ids)

[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


Now we convert `input_ids` and `segment_ids` to tensors:

In [None]:
input_ids = torch.tensor([input_ids])
segment_ids = torch.tensor([segment_ids])

print(input_ids)
print(segment_ids)

tensor([[  101,  2054,  2003,  1996, 11311,  2291,  1029,   102,  1996, 11311,
          2291,  2003,  1037,  2291,  1997,  2116,  6897,  5090,  1998,  6194,
          2306,  2019, 15923,  2008, 18227,  2114,  4295,  1012,  2000,  3853,
          7919,  1010,  2019, 11311,  2291,  2442, 11487,  1037,  2898,  3528,
          1997,  6074,  1010,  2124,  2004, 26835,  2015,  1010,  2013, 18191,
          2000, 26045, 16253,  1010,  1998, 10782,  2068,  2013,  1996, 15923,
          1005,  1055,  2219,  7965,  8153,  1012,   102]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


Now that we have processed the input, let's feed it to the model and get the result.

## Getting the answer

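We feed `input_ids` and `segment_ids` to the model, which returns the start score and end score for every token. With the `transformers==3.5.1` version pinned in the setup, the model returns a plain tuple, so we can unpack it directly: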

In [None]:
start_scores, end_scores = model(input_ids, token_type_ids=segment_ids)

Now, we select `start_index`, the index of the token with the maximum start score, and `end_index`, the index of the token with the maximum end score:

In [None]:
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
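
Taking the two argmaxes independently works for this example, but nothing forces the end index to come after the start index. A possible safeguard, which is not part of this notebook's method, is to pick the pair with the highest combined score subject to `start <= end`. The sketch below uses plain Python lists of made-up scores; `best_span` and `max_answer_len` are hypothetical names introduced for illustration.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end) maximizing start_scores[start] + end_scores[end]
    with start <= end and a span of at most max_answer_len tokens."""
    best_score, best_start, best_end = float("-inf"), 0, 0
    for s, s_score in enumerate(start_scores):
        last = min(len(end_scores), s + max_answer_len)
        for e in range(s, last):
            score = s_score + end_scores[e]
            if score > best_score:
                best_score, best_start, best_end = score, s, e
    return best_start, best_end

# Made-up scores where independent argmaxes would give start=2, end=1 (invalid).
toy_start_scores = [0.1, 0.2, 3.0, 0.5, 0.1]
toy_end_scores = [0.3, 4.0, 0.2, 2.5, 0.1]

print(best_span(toy_start_scores, toy_end_scores))
```

To apply this to the model output, the score tensors could first be converted to lists, for example with `start_scores[0].tolist()`.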

That's it! Now, we print the text span between the start and end indexes as our answer:

In [None]:
print(" ".join(tokens[start_index: end_index + 1]))

a system of many biological structures and processes within an organism that protects against disease
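
Joining tokens with spaces reads cleanly here because every word in the answer is a single token. When a word is split into WordPiece sub-tokens, the continuation pieces carry a `##` prefix and need to be glued back onto the previous piece. The helper below is a minimal sketch; `merge_wordpieces` is a hypothetical name, and the sample token list is illustrative rather than actual tokenizer output from this notebook.

```python
def merge_wordpieces(tokens):
    """Join WordPiece tokens into text, attaching '##' continuation pieces
    to the preceding token."""
    words = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]  # drop the '##' marker and append to the previous word
        else:
            words.append(tok)
    return " ".join(words)

# Illustrative token list with one split word.
sample_tokens = ["known", "as", "pathogen", "##s"]
print(merge_wordpieces(sample_tokens))
```

Calling it on `tokens[start_index: end_index + 1]` in place of the plain `" ".join(...)` leaves this example's answer unchanged, since none of its words are split.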
